In [12]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow import keras
from tensorflow.keras.datasets import mnist

In [13]:
def get_mnist_model(input_dim=28*28, hidden_units=512, dropout_rate=0.5, num_classes=10):
  """Build a small fully connected classifier with the Keras functional API.

  The network is: input vector -> Dense(relu) -> Dropout -> Dense(softmax).
  Called with no arguments it builds the same model as before
  (784 inputs, 512 hidden units, dropout 0.5, 10 output classes).

  Args:
    input_dim: Length of each flattened input vector.
    hidden_units: Width of the hidden Dense layer.
    dropout_rate: Rate passed to the Dropout layer.
    num_classes: Number of units in the softmax output layer.

  Returns:
    An uncompiled `keras.Model` mapping the input vector to class scores.
  """
  inputs = keras.Input(shape=(input_dim,))
  features = layers.Dense(hidden_units, activation="relu")(inputs)
  features = layers.Dropout(dropout_rate)(features)
  outputs = layers.Dense(num_classes, activation="softmax")(features)
  return keras.Model(inputs, outputs)

In [14]:
# Load MNIST as two (images, labels) pairs: the first is later split into
# training/validation data, the second is the held-out test set.
(images, labels), (test_images, test_labels) = mnist.load_data()

In [15]:
# Flatten every image into a 28*28 vector and divide by 255 (maps pixel values
# into [0, 1], assuming 8-bit intensities). Row counts come from the arrays
# themselves rather than being hard-coded.
# NOTE: `images` and `test_images` are overwritten in place, so re-running this
# cell without re-running the load cell would divide by 255 a second time.
images = images.reshape((len(images), 28*28)).astype("float32")/255
test_images = test_images.reshape((len(test_images), 28*28)).astype("float32")/255

# The first `num_val_samples` examples are held out for validation; the rest train.
num_val_samples = 10000
train_images, val_images = images[num_val_samples:], images[:num_val_samples]
train_labels, val_labels = labels[num_val_samples:], labels[:num_val_samples]

In [16]:
# Fresh model plus the module-level objects that the custom training and
# evaluation steps read as globals.
model = get_mnist_model()

# "Sparse" loss/accuracy variants: presumably the labels are integer class ids
# rather than one-hot vectors -- confirm against the loaded labels.
loss_fn = keras.losses.SparseCategoricalCrossentropy()
# RMSprop with its default hyperparameters.
optimizer = keras.optimizers.RMSprop()
# Metrics updated with (targets, predictions) on every batch.
metrics = [keras.metrics.SparseCategoricalAccuracy()]
# Running mean of the per-batch loss values.
loss_tracking_metric = keras.metrics.Mean()

In [17]:
@tf.function
def training_step(inputs, targets):
  """Run one gradient update on a batch and report the running metrics.

  Uses the module-level `model`, `loss_fn`, `optimizer`, `metrics` and
  `loss_tracking_metric`.

  Args:
    inputs: Batch of model inputs.
    targets: Batch of labels matching `inputs`.

  Returns:
    Dict mapping each metric's name to its current result, plus "loss"
    mapped to the running mean of the batch losses.
  """
  # Forward pass under the tape so the loss can be differentiated.
  # training=True makes layers such as Dropout run in training mode.
  with tf.GradientTape() as tape:
    predictions = model(inputs, training = True)
    loss = loss_fn(targets, predictions)

  weights = model.trainable_weights
  gradients = tape.gradient(loss, weights)
  optimizer.apply_gradients(zip(gradients, weights))

  # Fold this batch into every running statistic first ...
  for metric in metrics:
    metric.update_state(targets, predictions)
  loss_tracking_metric.update_state(loss)

  # ... then read the accumulated values back out.
  logs = {metric.name: metric.result() for metric in metrics}
  logs["loss"] = loss_tracking_metric.result()
  return logs

def reset_metrics():
  """Reset every entry of `metrics`, then the loss tracker, to its initial state."""
  # `loss_tracking_metric` is a single object, so it is appended to the list.
  for tracker in [*metrics, loss_tracking_metric]:
    tracker.reset_state()

In [18]:
# Training loop
# Slice the training arrays into (image, label) pairs and group them in batches of 32.
training_dataset = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels)
).batch(32)

epochs = 3

for epoch in range(epochs):
  # Start every epoch with zeroed running statistics.
  reset_metrics()
  for inputs_batch, targets_batch in training_dataset:
    logs = training_step(inputs_batch, targets_batch)
  # `logs` now holds the values returned for the last batch of the epoch.
  print(f"Results at the end of epoch {epoch}")
  for key, value in logs.items():
    print(f"...{key}: {value:.4f}")

Results at the end of epoch 0
...sparse_categorical_accuracy: 0.9150
...loss: 0.2864
Results at the end of epoch 1
...sparse_categorical_accuracy: 0.9548
...loss: 0.1580
Results at the end of epoch 2
...sparse_categorical_accuracy: 0.9633
...loss: 0.1324


In [19]:
# Evaluation loop
@tf.function
def test_step(inputs, targets):
  """Evaluate one batch without updating weights and report the running metrics.

  Uses the module-level `model`, `loss_fn`, `metrics` and
  `loss_tracking_metric`.

  Args:
    inputs: Batch of model inputs.
    targets: Batch of labels matching `inputs`.

  Returns:
    Dict mapping "val_" + each metric's name to its current result, plus
    "val_loss" mapped to the running mean of the batch losses.
  """
  # training=False: the model is called in inference mode.
  predictions = model(inputs, training = False)
  loss = loss_fn(targets, predictions)

  # Fold this batch into every running statistic first ...
  for metric in metrics:
    metric.update_state(targets, predictions)
  loss_tracking_metric.update_state(loss)

  # ... then read the accumulated values back out under "val_"-prefixed keys.
  logs = {f"val_{metric.name}": metric.result() for metric in metrics}
  logs["val_loss"] = loss_tracking_metric.result()
  return logs

# Slice the validation arrays into (image, label) pairs and group them in batches of 32.
val_dataset = tf.data.Dataset.from_tensor_slices(
    (val_images, val_labels)
).batch(32)

# Clear anything the shared metrics accumulated earlier so the results
# below cover only this evaluation pass.
reset_metrics()

for inputs_batch, targets_batch in val_dataset:
  logs = test_step(inputs_batch, targets_batch)
# `logs` now holds the values returned for the last validation batch.
print("Evaluation. results:")
for key, value in logs.items():
  print(f"...{key}: {value:.4f}")

Evaluation. results:
...val_sparse_categorical_accuracy: 0.9678
...val_loss: 0.1251
