In [1]:
%tensorflow_version 2.x

TensorFlow 2.x selected.


In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

# Loading Data

In [0]:
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

In [9]:
print(x_train.shape)

(60000, 28, 28, 1)


In [0]:
x_train = tf.expand_dims(x_train, axis=-1)
x_test = tf.expand_dims(x_test, axis=-1)

In [0]:
train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(1000).batch(32)

In [0]:
test_ds = tf.data.Dataset.from_tensor_slices(
    (x_test, y_test)).shuffle(1000).batch(32)

# Model Subclassing API

In [0]:
class my_model(Model):
    def __init__(self):
        super(my_model, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10)
    
    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

In [0]:
model = my_model()

# Training

- Select metrics to measure the loss and the accuracy of the model. 
- These metrics accumulate the values over epochs and then print the overall result.

In [0]:
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

In [0]:
train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name="test_loss")
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="test_accuracy")

In [0]:
# train the model
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        # training=True is only needed if there are layers with different
        # behavior during training versus inference (e.g. Dropout).
        predictions = model(images, training=True)
        loss_calc_train = loss(labels, predictions)
    
    gradients = tape.gradient(loss_calc_train, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss_calc_train)
    train_accuracy(labels, predictions)


In [0]:
# test the model
@tf.function
def test_step(images, labels):
    # training=False is only needed if there are layers with different
    # behavior during training versus inference (e.g. Dropout).
    predictions = model(images, training=False)
    loss_calc_test = loss(labels, predictions)

    test_loss(loss_calc_test)
    test_accuracy(labels, predictions)



In [65]:
epochs = 10
for epoch in range(epochs):
    train_loss.reset_states()
    train_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()

    for images, labels in train_ds:
        train_step(images, labels)

    for test_images, test_labels in test_ds:
        test_step(images, labels)
    
    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(epoch+1, train_loss.result(), train_accuracy.result()*100, test_loss.result(), test_accuracy.result()*100))


Epoch 1, Loss: 0.020601395517587662, Accuracy: 99.33333587646484, Test Loss: 0.3043467104434967, Test Accuracy: 96.875
Epoch 2, Loss: 0.01295520830899477, Accuracy: 99.57833099365234, Test Loss: 0.0023995027877390385, Test Accuracy: 100.0
Epoch 3, Loss: 0.007603609003126621, Accuracy: 99.75, Test Loss: 6.908534851390868e-05, Test Accuracy: 100.0
Epoch 4, Loss: 0.008921196684241295, Accuracy: 99.66666412353516, Test Loss: 0.00014284989447332919, Test Accuracy: 100.0
Epoch 5, Loss: 0.005234827287495136, Accuracy: 99.81500244140625, Test Loss: 4.0059967432171106e-05, Test Accuracy: 100.0
Epoch 6, Loss: 0.003951547667384148, Accuracy: 99.86333465576172, Test Loss: 0.002484978409484029, Test Accuracy: 100.0
Epoch 7, Loss: 0.003997230436652899, Accuracy: 99.86000061035156, Test Loss: 8.642549573778524e-07, Test Accuracy: 100.0
Epoch 8, Loss: 0.0028298525139689445, Accuracy: 99.91000366210938, Test Loss: 4.959200305165723e-05, Test Accuracy: 100.0
Epoch 9, Loss: 0.003775108139961958, Accuracy