In [16]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Conv2D, Flatten
from tensorflow.keras import Model
import numpy as np

In [11]:
## load data
## MNIST pixels are integers in [0, 255]; rescale them to [0, 1].
## The division yields float64, so cast once to float32 here instead of
## carrying double-width arrays through the rest of the notebook.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()
X_train = (X_train / 255.0).astype(np.float32)
X_test = (X_test / 255.0).astype(np.float32)

In [12]:
## shapes of the train / test image arrays, one per line
print(X_train.shape, X_test.shape, sep='\n')

(60000, 28, 28)
(10000, 28, 28)


In [28]:
## Add channel dim
## Guarded on ndim so that re-running this cell does not append a second
## trailing axis to arrays that already have one.
if X_train.ndim == 3:
    X_train = X_train[:, :, :, tf.newaxis]
if X_test.ndim == 3:
    X_test = X_test[..., tf.newaxis] ## equivalent version

In [29]:
## shapes after the channel axis has been appended, one per line
print(X_train.shape, X_test.shape, sep='\n')

(60000, 28, 28, 1)
(10000, 28, 28, 1)


In [42]:
## Build tf.data pipelines out of (images, labels) pairs.
## Training data: shuffled through a 10000-element buffer, then batched by 32.
## Test data: batched by 32 only (no shuffling).
train_ds = (
    tf.data.Dataset
    .from_tensor_slices((X_train, y_train))
    .shuffle(buffer_size=10000)
    .batch(batch_size=32)
)
test_ds = (
    tf.data.Dataset
    .from_tensor_slices((X_test, y_test))
    .batch(batch_size=32)
)

In [44]:
## model definition via Keras model subclassing
class my_model(Model):
    '''Conv2D -> Flatten -> Dense(30, relu) -> Dense(10) classifier.

    The last layer has no activation, so the model returns raw logits.
    '''

    def __init__(self):
        super().__init__()
        # Layers are created in the order in which `call` applies them.
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.dense1 = Dense(30, activation='relu')
        self.dense2 = Dense(10)  # no activation: outputs are logits

    def call(self, x):
        '''Run the input batch `x` through every layer in sequence and return the logits.'''
        for layer in (self.conv1, self.flatten, self.dense1, self.dense2):
            x = layer(x)
        return x

In [45]:
## Instantiate the network. train_step / test_step refer to this
## module-level `model` directly rather than receiving it as an argument.
model = my_model()

In [47]:
## Loss, optimizer and running metrics used by the training / test steps.

## The model returns logits (its last Dense layer has no activation),
## hence from_logits=True.
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

## Metrics accumulate over batches; their results are what gets printed per epoch.
train_loss, train_acc = (
    tf.keras.metrics.Mean(name='train_loss'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='train_acc'),
)
test_loss, test_acc = (
    tf.keras.metrics.Mean(name='test_loss'),
    tf.keras.metrics.SparseCategoricalAccuracy(name='test_acc'),
)

In [55]:
@tf.function
def train_step(images, labels):
    '''Run one optimisation step on a single batch and update the training metrics.

    images: batch of inputs fed to the module-level `model`
    labels: integer class labels for that batch
    '''
    # Record the forward pass so the loss can be differentiated afterwards.
    with tf.GradientTape() as tape:
        logits = model(images, training=True)
        batch_loss = loss_obj(labels, logits)  # cross-entropy between labels and logits

    # The same variable list is used for the gradients and for the update.
    weights = model.trainable_variables
    grads = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

    # Accumulate the running loss / accuracy for the current epoch.
    train_loss(batch_loss)
    train_acc(labels, logits)

In [56]:
@tf.function
def test_step(images, labels):
    '''Evaluate the module-level `model` on one batch and update the test metrics.

    images: batch of inputs
    labels: integer class labels for that batch
    '''
    # Inference mode: no gradient tape and no weight update.
    logits = model(images, training=False)
    batch_loss = loss_obj(labels, logits)

    # Accumulate the running loss / accuracy for the current epoch.
    test_loss(batch_loss)
    test_acc(labels, logits)

In [61]:
def _reset_metric(metric):
    '''Clear a metric's accumulated state, whichever reset method it exposes.'''
    # `reset_state` is the current Keras spelling; `reset_states` is the older
    # one. Prefer the former and fall back to the latter for older versions.
    reset = getattr(metric, 'reset_state', None)
    if reset is None:
        reset = metric.reset_states
    reset()


def training(train_datas, test_datas, model, epochs, train_loss, train_acc, test_loss, test_acc):
    '''
    Train our model and print the metrics after every epoch.

    train_datas: iterable of (batched data, vector of labels) tuples used for training
    test_datas: iterable of (batched data, vector of labels) tuples used for evaluation
    model: model built (not used by this loop itself: train_step / test_step
           operate on the module-level `model`)
    epochs: nbr of epochs
    train_loss: metric object reporting the training loss (tf.keras.metrics.*)
    train_acc: metric object reporting the training accuracy (tf.keras.metrics.*)
    test_loss: metric object reporting the testing loss (tf.keras.metrics.*)
    test_acc: metric object reporting the testing accuracy (tf.keras.metrics.*)

    NOTE: train_step / test_step update the module-level metric objects, so the
    metrics passed here must be those same objects for the printed values to
    reflect the training that was just run.

    Returns None; results are reported through print().
    '''
    tracked_metrics = (train_loss, train_acc, test_loss, test_acc)

    for epoch in range(epochs):
        # reset metrics at the start of each epoch so results are per-epoch
        for metric in tracked_metrics:
            _reset_metric(metric)

        for images, labels in train_datas:
            train_step(images, labels)
        for images, labels in test_datas:
            test_step(images, labels)

        template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
        print(template.format(epoch+1,
                        train_loss.result(),
                        train_acc.result(),
                        test_loss.result(),
                        test_acc.result()))


In [62]:
## run the training loop for 4 epochs
training(
    train_datas=train_ds,
    test_datas=test_ds,
    model=model,
    epochs=4,
    train_loss=train_loss,
    train_acc=train_acc,
    test_loss=test_loss,
    test_acc=test_acc,
)

Epoch 1, Loss: 0.013752027414739132, Accuracy: 0.9952166676521301, Test Loss: 0.0707167312502861, Test Accuracy: 0.9818000197410583
Epoch 2, Loss: 0.010401680134236813, Accuracy: 0.9965999722480774, Test Loss: 0.06813216954469681, Test Accuracy: 0.9819999933242798
Epoch 3, Loss: 0.007962152361869812, Accuracy: 0.9973000288009644, Test Loss: 0.07883007079362869, Test Accuracy: 0.9829000234603882
Epoch 4, Loss: 0.006514339707791805, Accuracy: 0.9978500008583069, Test Loss: 0.08836805820465088, Test Accuracy: 0.98089998960495
