In [5]:
import tensorflow as tf
from tensorflow.keras import layers as L


class CNNResUnit(tf.keras.models.Model):
  """Residual block: a stack of same-padded ReLU Conv2D layers plus a skip connection.

  Args:
    layers: number of Conv2D layers applied before the skip connection is added.
    num_filters: filters in every Conv2D layer, i.e. the channel count of the output.
    filter_size: side length of the square convolution kernel.

  The shortcut is the identity when the input already has `num_filters` channels.
  Otherwise it is a learned 1x1 convolution. Without the projection, adding
  (..., C) to (..., num_filters) only works by accident when C == 1 (broadcast)
  and fails for any other C.
  """

  def __init__(self, layers, num_filters, filter_size, **kwargs):
    super().__init__(**kwargs)
    self.num_filters = num_filters
    self.hidden = [
        L.Conv2D(num_filters, (filter_size, filter_size), activation="relu", padding="same")
        for _ in range(layers)
    ]
    # Projection shortcut. It is only called, and therefore only built, when the
    # input channel count differs from num_filters.
    self.projection = L.Conv2D(num_filters, (1, 1), padding="same")

  def call(self, inputs):
    x = inputs
    for layer in self.hidden:
      x = layer(x)
    in_channels = inputs.shape[-1]
    # An unknown (None) channel dim keeps the plain identity shortcut, as before.
    if in_channels is None or in_channels == self.num_filters:
      shortcut = inputs
    else:
      shortcut = self.projection(inputs)
    return shortcut + x


class DenseResUnit(tf.keras.models.Model):
  """Residual block: a stack of ReLU Dense layers plus a skip connection.

  Args:
    layers: number of Dense layers applied before the skip connection is added.
    units: width of every Dense layer, i.e. the size of the output's last axis.

  The shortcut is the identity when the input's last dimension already equals
  `units`. Otherwise it is a learned linear projection. Without it, the sum only
  works when the widths match or the input width is 1 (broadcast).
  """

  def __init__(self, layers, units, **kwargs):
    super().__init__(**kwargs)
    self.units = units
    self.hidden = [L.Dense(units, activation="relu") for _ in range(layers)]
    # Linear projection shortcut. It is only called, and therefore only built,
    # when the input width differs from `units`.
    self.projection = L.Dense(units)

  def call(self, inputs):
    x = inputs
    for layer in self.hidden:
      x = layer(x)
    in_width = inputs.shape[-1]
    # An unknown (None) width keeps the plain identity shortcut, as before.
    if in_width is None or in_width == self.units:
      shortcut = inputs
    else:
      shortcut = self.projection(inputs)
    return x + shortcut


class ResnetModel(tf.keras.Model):
  """Small residual classifier.

  Pipeline: a convolutional residual block, then a stack of independent dense
  residual blocks (applied along the last axis), then flatten and a softmax head.

  Args:
    num_classes: width of the softmax output layer. Defaults to 10.
    num_dense_blocks: number of separate DenseResUnit blocks to stack. Defaults to 3.
  """

  def __init__(self, num_classes=10, num_dense_blocks=3, **kwargs):
    super().__init__(**kwargs)
    self.cnnres = CNNResUnit(2, 64, 3)
    # One DenseResUnit per stage. Re-calling a single instance in a loop would
    # silently share (tie) the weights of every stage.
    self.dense_blocks = [DenseResUnit(2, 64) for _ in range(num_dense_blocks)]
    # Create Flatten once here rather than instantiating a new layer on every call.
    self.flatten = L.Flatten()
    self.out = L.Dense(num_classes, activation="softmax")

  def call(self, inputs):
    x = self.cnnres(inputs)
    for block in self.dense_blocks:
      x = block(x)
    x = self.flatten(x)
    return self.out(x)



In [6]:
import numpy as np

# Seed TensorFlow's global RNG so weight initialisation and shuffling are repeatable.
tf.random.set_seed(42)

(x_t, y_t), (x_v, y_v) = tf.keras.datasets.mnist.load_data()
# Scale pixels to [0, 1] as float32. Dividing the raw integer arrays by 255.
# would give float64 and double memory for no benefit.
# Then append a trailing channel axis for Conv2D.
x_t = np.expand_dims(x_t.astype(np.float32) / 255.0, axis=-1)
x_v = np.expand_dims(x_v.astype(np.float32) / 255.0, axis=-1)

model = ResnetModel()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["acc"])

In [7]:
# Stop once validation loss stops improving, and roll back to the best weights,
# so the 20-epoch budget is a cap rather than something to interrupt by hand.
# NOTE: x_v/y_v doubles as the early-stopping set, so its metrics are no longer
# an unbiased held-out estimate.
history = model.fit(
    x_t,
    y_t,
    validation_data=(x_v, y_v),
    epochs=20,
    batch_size=64,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=3, restore_best_weights=True
        )
    ],
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20

KeyboardInterrupt: ignored