# Imports

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras

"""
A convolutional model that classifies the quintessential MNIST handwritten-digit dataset.
"""

# Dataset

In [None]:
# Load the MNIST train/test splits as NumPy arrays.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Cast to float32 *before* scaling: dividing the raw arrays by 255.0 first
# would materialise a float64 copy of each split that is thrown away
# immediately afterwards.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Append a trailing channels axis so the images can feed a Conv2D layer.
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# Shuffle and batch the training data. The shuffle buffer spans the whole
# training set so every epoch sees a uniformly random order; a smaller buffer
# only shuffles within a sliding window. The test set is batched but left in
# its original order.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(len(x_train))
    .batch(32)
)
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

# Model

In [None]:
#create model
class MyModel(keras.Model):
    """Small convolutional classifier that emits 10 raw scores per image.

    Layer stack: Conv2D(32 filters, 3x3, ReLU) -> Flatten -> Dense(128, ReLU)
    -> Dropout(0.2) -> Dense(10). The last layer has no activation, so the
    outputs are logits and must be paired with a loss configured with
    ``from_logits=True``.
    """

    def __init__(self):
        super().__init__()
        # No ``input_shape`` argument: a subclassed model builds its layers
        # from the first batch it is called on, so the hint has no effect
        # (and newer Keras versions warn about it).
        self.conv1 = keras.layers.Conv2D(32, 3, activation='relu')
        self.flatten = keras.layers.Flatten()
        self.d1 = keras.layers.Dense(128, activation='relu')
        self.drop1 = keras.layers.Dropout(.2)
        self.d2 = keras.layers.Dense(10)

    def call(self, x, training=None):
        """Run the forward pass and return the logits.

        Args:
            x: Batch of input images.
            training: Whether the call is part of a training step. Forwarded
                to the dropout layer so it is only active while training;
                ``None`` lets Keras decide from the calling context.

        Returns:
            The output of the final Dense(10) layer (un-normalised logits).
        """
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        x = self.drop1(x, training=training)
        return self.d2(x)

# Build the network, then the objective and optimizer used to fit it.
# The model's last layer returns raw logits, hence from_logits=True.
model = MyModel()
loss_object = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Running metrics, one (mean loss, accuracy) pair per data split.
train_loss, train_accuracy = (
    keras.metrics.Mean(name='train_loss'),
    keras.metrics.SparseCategoricalAccuracy(name='train_accuracy'),
)
test_loss, test_accuracy = (
    keras.metrics.Mean(name='test_loss'),
    keras.metrics.SparseCategoricalAccuracy(name='test_accuracy'),
)

# Training

In [None]:
#do a training step
@tf.function
def train_step(images, labels):
    """Apply one optimizer update for a single batch and record train metrics.

    Args:
        images: Batch of input images fed to ``model``.
        labels: Integer class labels matching ``images``.
    """
    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
        logits = model(images, training=True)
        batch_loss = loss_object(labels, logits)

    weights = model.trainable_variables
    grads = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

    # Fold this batch into the running training metrics.
    train_loss.update_state(batch_loss)
    train_accuracy.update_state(labels, logits)

#do a test step
@tf.function
def test_step(images, labels):
    """Evaluate a single batch in inference mode and record test metrics.

    Args:
        images: Batch of input images fed to ``model``.
        labels: Integer class labels matching ``images``.
    """
    logits = model(images, training=False)
    batch_loss = loss_object(labels, logits)

    # Fold this batch into the running test metrics; no weights are updated.
    test_loss.update_state(batch_loss)
    test_accuracy.update_state(labels, logits)

In [None]:
# Train for a fixed number of epochs, evaluating on the test set after each.
EPOCHS = 5
for epoch in range(EPOCHS):
    # The metric objects accumulate across calls, so clear them at the start
    # of every epoch to report per-epoch values. ``reset_state()`` is the
    # current Keras method name; the older ``reset_states()`` spelling is
    # deprecated and no longer exists in recent Keras releases.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_state()

    for images, labels in train_ds:
        train_step(images, labels)

    for images, labels in test_ds:
        test_step(images, labels)

    # Accuracies are fractions in [0, 1]; scale them to percentages.
    print(
        f'Epoch {epoch + 1}, '
        f'Loss: {train_loss.result()} '
        f'Accuracy: {train_accuracy.result()*100}, '
        f'Test Loss: {test_loss.result()}, '
        f'Test Accuracy: {test_accuracy.result()*100}'
    )

Epoch 1, Loss: 0.16067983210086823 Accuracy: 95.17666625976562, Test Loss: 0.06209205463528633, Test Accuracy: 97.87999725341797
Epoch 2, Loss: 0.056645724922418594 Accuracy: 98.18500518798828, Test Loss: 0.05078515782952309, Test Accuracy: 98.33999633789062
Epoch 3, Loss: 0.035098131746053696 Accuracy: 98.86500549316406, Test Loss: 0.051028694957494736, Test Accuracy: 98.47000122070312
Epoch 4, Loss: 0.024233978241682053 Accuracy: 99.2066650390625, Test Loss: 0.04713011533021927, Test Accuracy: 98.50999450683594
Epoch 5, Loss: 0.018569566309452057 Accuracy: 99.40666198730469, Test Loss: 0.05826923996210098, Test Accuracy: 98.52999877929688
