In [7]:
from tensorflow import keras
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras import layers

def get_mnist_model():
    inputs = keras.Input(shape=(28*28,))
    features = layers.Dense(512, activation='relu')(inputs)
    features = layers.Dropout(0.5)(features)
    output = layers.Dense(10, activation='softmax')(features)

    model = keras.Model(inputs=inputs, outputs=output)
    return model

model = get_mnist_model()

loss_fn = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.RMSprop()
metrics = [keras.metrics.SparseCategoricalAccuracy()]
loss_tracking_metric = keras.metrics.Mean()

@tf.function #Compile tensorflow code to computational graph for global optimization. It enables fast execution
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss_val = loss_fn(targets, predictions)
    gradients = tape.gradient(loss_val, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

    logs={}
    for metric in metrics:
        metric.update_state(predictions, targets)
        logs[metric.name] = metric.result()

    loss_tracking_metric.update_state(loss_val)
    logs['loss'] = loss_tracking_metric.result()
    return logs

def reset_metrics():
    for metric in metrics:
        metric.reset_state()
    loss_tracking_metric.reset_state()

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.reshape((60000, 28*28)).astype('float32') / 255
test_images = test_images.reshape((10000, 28*28)).astype('float32') / 255

train_images, val_train_images = train_images[10000:], train_images[:10000]
train_labels, val_train_labels = train_labels[10000:], train_labels[:10000]

training_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
training_dataset = training_dataset.batch(32)

epochs = 3
for epoch in range(epochs):
    reset_metrics()
    for input_batch, input_target in training_dataset:
        logs = train_step(input_batch, input_target)
    print(f'Results at the end of epoch #{epoch}')
    for key, value in logs.items():
        print(f'{key} : {value:.4f}')




Results at the end of epoch #0
sparse_categorical_accuracy : 0.0000
loss : 0.2904
Results at the end of epoch #1
sparse_categorical_accuracy : 0.0005
loss : 0.1652
Results at the end of epoch #2
sparse_categorical_accuracy : 0.0012
loss : 0.1391
