In [1]:
import tensorflow as tf
import tensorflow_datasets as tfd
import tensorflow.keras as tfk

In [2]:
dataset, info = tfd.load('mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = dataset['train'], dataset['test']

In [3]:
def convert_types(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    return image, label

In [4]:
mnist_train = mnist_train.map(convert_types).shuffle(10000).batch(32)
mnist_test = mnist_test.map(convert_types).batch(32)

In [5]:
class ConvNet(tfk.Model):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = tfk.layers.Conv2D(32, 3, activation='relu')
        self.flatten = tfk.layers.Flatten()
        self.d1 = tfk.layers.Dense(128, activation='relu')
        self.d2 = tfk.layers.Dense(10, activation='softmax')
        
    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

In [6]:
model = ConvNet()

In [14]:
loss_object = tfk.losses.SparseCategoricalCrossentropy()
optimizer = tfk.optimizers.Adam()

In [15]:
train_loss = tfk.metrics.Mean(name='train_loss')
train_accuracy = tfk.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tfk.metrics.Mean(name='test_loss')
test_accuracy = tfk.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [20]:
@tf.function
def train_step(X, y):
    with tf.GradientTape() as tape:
        ŷ = model(X)
        loss = loss_object(y, ŷ)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    
    train_loss(loss)
    train_accuracy(y, ŷ)

In [21]:
@tf.function
def test_step(X, y):
    ŷ = model(X)
    loss = loss_object(y, ŷ)
    
    test_loss(loss)
    test_accuracy(y, ŷ)

In [None]:
X, y = next(iter(mnist_train))
ŷ = model(X)
loss = loss_object(y, ŷ)

In [22]:
EPOCHS = 5

for epoch in range(EPOCHS):
    for X, y in mnist_train:
        train_step(X, y)
    
    for X, y in mnist_test:
        test_step(X, y)
        
    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print (template.format(epoch+1,
                           train_loss.result(), 
                           train_accuracy.result()*100,
                           test_loss.result(), 
                           test_accuracy.result()*100))

Epoch 1, Loss: 0.14631910622119904, Accuracy: 95.60499572753906, Test Loss: 0.06119333952665329, Test Accuracy: 98.04000091552734
Epoch 2, Loss: 0.09565486758947372, Accuracy: 97.10083770751953, Test Loss: 0.06235846132040024, Test Accuracy: 97.9749984741211
Epoch 3, Loss: 0.07158130407333374, Accuracy: 97.82666778564453, Test Loss: 0.06021523103117943, Test Accuracy: 98.09667205810547
Epoch 4, Loss: 0.05717364326119423, Accuracy: 98.257080078125, Test Loss: 0.061921581625938416, Test Accuracy: 98.11499786376953
Epoch 5, Loss: 0.04770169034600258, Accuracy: 98.54367065429688, Test Loss: 0.06262506544589996, Test Accuracy: 98.1520004272461
