# Subclassing the Model Class

In [29]:
import tensorflow as tf

In [70]:
# Handle to the Keras MNIST dataset module; its `load_data()` is called in the
# "Load the training data" section below.
mnist = tf.keras.datasets.mnist


We first declare and name the individual layers in the constructor (`__init__`).

In [105]:
class MyModel(tf.keras.Model):
    """Small convolutional classifier built by subclassing `tf.keras.Model`.

    Layers are created in `__init__` and chained, in order, in `call`:

        x0  Conv2D, 32 filters, 5x5 kernel, ReLU
        x1  MaxPool2D
        x2  Dropout(0.4)
        x3  Conv2D, 64 filters, 5x5 kernel, ReLU
        x4  MaxPool2D
        x5  Dropout(0.4)
        x6  Flatten
        x7  Dropout(0.4)
        output_pred  Dense(10, softmax)

    Because the final layer applies softmax, the output is a probability
    vector over 10 classes (pair it with a non-logit loss such as
    'sparse_categorical_crossentropy').
    """

    def __init__(self, **kwargs):
        """Create the layers.

        Args:
            **kwargs: forwarded unchanged to `tf.keras.Model.__init__`
                (e.g. `name=`); calling with no arguments behaves as before.
        """
        super().__init__(**kwargs)

        # Each layer is assigned to an attribute so Keras tracks its weights;
        # the x0..x7 names are exactly what `call` looks up. No `input_shape`
        # is passed: a subclassed model sizes its layers the first time it is
        # called on real data, so that argument would be ignored here.
        self.x0 = tf.keras.layers.Conv2D(filters=32,
                                         kernel_size=5,
                                         activation='relu')
        self.x1 = tf.keras.layers.MaxPool2D()
        self.x2 = tf.keras.layers.Dropout(0.4)

        self.x3 = tf.keras.layers.Conv2D(filters=64,
                                         kernel_size=5,
                                         activation='relu')
        self.x4 = tf.keras.layers.MaxPool2D()
        self.x5 = tf.keras.layers.Dropout(0.4)

        self.x6 = tf.keras.layers.Flatten()
        self.x7 = tf.keras.layers.Dropout(0.4)

        self.output_pred = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, training=None):
        """Run the forward pass.

        Args:
            inputs: image batch; this notebook feeds float32 tensors of
                shape (N, 28, 28, 1).
            training: passed explicitly to the Dropout layers (x2, x5, x7) so
                dropout follows the train/inference mode. `None` leaves the
                mode to Keras, which sets it during `fit` / `evaluate`.

        Returns:
            The output of `output_pred`: 10 softmax values per input row.
        """
        x = inputs
        # Apply all eight hidden layers x0 .. x7 in order.
        for i in range(8):
            layer = getattr(self, f'x{i}')
            if isinstance(layer, tf.keras.layers.Dropout):
                x = layer(x, training=training)
            else:
                x = layer(x)

        return self.output_pred(x)

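#### Define the forward pass in the `call` method

`call` passes the input through the layers `x0`–`x7` in order and returns the output of the final softmax `Dense` layer (`output_pred`).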



#### Load the training and test data


In [82]:
# `load_data()` returns ((x_train, y_train), (x_test, y_test)); keep the raw
# tuple around as `mnist_data` and unpack each split from it.
mnist_data = mnist.load_data()
x_train, y_train = mnist_data[0]
x_test, y_test = mnist_data[1]

In [83]:
# Rescale pixel values by 1/255 and store them as float32; labels become int64.
# Note: these lines rebind x_train/x_test, so re-running this cell would divide
# by 255 a second time -- run it exactly once (Restart & Run All is safe).
x_train = tf.cast(x_train / 255., dtype=tf.float32)
x_test = tf.cast(x_test / 255., dtype=tf.float32)

y_train = tf.cast(y_train, dtype=tf.int64)
y_test = tf.cast(y_test, dtype=tf.int64)


In [84]:
# Expand the input dimension: add a trailing single-channel axis so each image
# is (28, 28, 1), the layout Conv2D uses by default (channels last).
# `-1` lets tf.reshape infer the number of images, avoiding a full copy of the
# tensor to NumPy just to read `.shape[0]`. Reshaping to the same target shape
# is a no-op, so this cell is safe to re-run.
x_train = tf.reshape(x_train, [-1, 28, 28, 1])
x_test = tf.reshape(x_test, [-1, 28, 28, 1])


### Build, compile, train, and evaluate the model

In [102]:
# Instantiate the subclassed model. Its layer weights are created lazily, the
# first time the model is called on data (here, inside `model.fit`).
model = MyModel()

In [103]:
# Adam with its default settings. The loss is the *sparse* categorical
# cross-entropy because the labels are integer class ids (not one-hot) and the
# model's last layer already applies softmax.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'],
)

In [104]:
# Train for 20 epochs in mini-batches of 128; the returned History object is
# the cell's displayed output.
model.fit(x=x_train, y=y_train, epochs=20, batch_size=128)

1
2
3
4
5
6
7


W0813 16:17:13.631650 140462407554880 deprecation.py:323] From /root/.virtualenvs/tfs/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Train on 60000 samples
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x7fbfb3814b50>

In [106]:
# Score the held-out split; returns [loss, accuracy] (the metric set in compile).
model.evaluate(x=x_test, y=y_test)



[0.01953036632850417, 0.9935]