In [22]:
import numpy as np
import tensorflow as tf

# Model subclassing

Build a fully customizable model by subclassing `tf.keras.Model` and defining your own forward pass.

- Create layers in the `__init__` method and set them as attributes of the class instance. 
- Define the forward pass in the `call` method.

## 1. Subclassing Model

In [23]:
class MNISTModel(tf.keras.Model):
    """Fully connected classifier built by subclassing `tf.keras.Model`.

    The layers are created once in `__init__` and chained in `call`:
    Flatten -> Dense(256, relu) -> Dense(128, relu) -> Dense(num_classes)
    -> Softmax.
    """

    def __init__(self, num_classes=10):
        """Create the layers and keep them as attributes of the instance.

        Args:
            num_classes: Number of units of the last Dense layer
                (defaults to 10).
        """
        super().__init__(name='mnist_model')
        self.num_classes = num_classes

        # Layers are created in the same order in which `call` applies them.
        self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
        self.dense_1 = tf.keras.layers.Dense(256, activation='relu')
        self.dense_2 = tf.keras.layers.Dense(128, activation='relu')
        self.dense_3 = tf.keras.layers.Dense(num_classes)
        self.output_layer = tf.keras.layers.Softmax()

    def call(self, inputs):
        """Forward pass: run `inputs` through every layer in sequence.

        Returns:
            The output of the final Softmax layer.
        """
        flat = self.flatten(inputs)
        hidden = self.dense_2(self.dense_1(flat))
        logits = self.dense_3(hidden)
        return self.output_layer(logits)

## 2. Train and Test

In [24]:
# Fetch the MNIST train/test split through the Keras datasets helper.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Divide the image arrays by 255.0; the label arrays are left as loaded.
x_train = x_train / 255.0
x_test = x_test / 255.0

In [25]:
model = MNISTModel(num_classes=10)

# The compile step specifies the training configuration.
# The true labels are integers rather than one-hot vectors, so the
# SparseCategoricalCrossentropy loss is used; with one-hot labels,
# CategoricalCrossentropy should be used instead.
# MNISTModel.call already ends with a Softmax layer, so the loss receives
# probabilities, not logits: from_logits must therefore be False. Passing
# True would declare the softmax output to be raw logits.
model.compile(optimizer=tf.keras.optimizers.RMSprop(0.001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])

# Train for 5 epochs with mini-batches of 32 samples.
model.fit(x_train, y_train, batch_size=32, epochs=5)

Train on 60000 samples
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x2192b39c588>

## 3. Save and Restore Model

In [29]:
# Save the model in the TensorFlow SavedModel format (save_format='tf').
# A forward slash is used as the path separator: it keeps the path portable
# and avoids a backslash inside a normal string literal being parsed as an
# escape sequence ('\m' is an invalid escape and triggers a warning).
path = 'ckpt_nn/mnist_subclass_model'
model.save(path, save_format='tf')

# Predictions of the in-memory model on the test images, kept so they can be
# compared with the predictions of the model restored from `path`.
predictions = model.predict(x_test)

INFO:tensorflow:Assets written to: ckpt_nn\mnist_subclass_model\assets


In [30]:
# Rebuild the model from what was saved under `path`.
new_model = tf.keras.models.load_model(filepath=path)

# Predict on the same test images with the restored model.
new_predictions = new_model.predict(x=x_test)

In [31]:
# Compare results: the restored model must reproduce the predictions of the
# original model within rtol=1e-6 / atol=1e-6. assert_allclose raises an
# AssertionError when the two arrays differ by more than that.
# (numpy is imported as `np` with the other imports at the top of the notebook.)
np.testing.assert_allclose(predictions, new_predictions, rtol=1e-6, atol=1e-6)