In [35]:
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

In [36]:
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

In [37]:
train_images.shape

(60000, 28, 28)

In [40]:
training_dataset = tf.data.Dataset.from_tensor_slices((tf.reshape(train_images, [-1, 28*28]), train_labels))
training_dataset = training_dataset.batch(32)

In [41]:
def get_mnist_model():

  inputs = keras.Input(shape = (28*28,))
  features = layers.Dense(512, activation="relu")(inputs)
  features = layers.Dropout(0.5)(features)
  outputs = layers.Dense(10, activation = "softmax")(features)
  model = keras.Model(inputs, outputs)
  return model

In [42]:
model = get_mnist_model()

optimizer = keras.optimizers.RMSprop()
loss_fn = keras.losses.SparseCategoricalCrossentropy()
metrics = [keras.metrics.SparseCategoricalAccuracy()]
loss_tracking_metric = keras.metrics.Mean()

def train_step(inputs, targets):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training = True)
    loss = loss_fn(targets, predictions)
  gradients = tape.gradient(loss, model.trainable_weights)
  optimizer.apply_gradients(zip(gradients, model.trainable_weights))

  logs = {}
  for metric in metrics:
    metric.update_state(targets, predictions)
    logs[metric.name] = metric.result()

  loss_tracking_metric.update_state(loss)
  logs["loss"] = loss_tracking_metric.result()
  return logs

In [43]:
def reset_metrics():
  for metric in metrics:
    metric.reset_state()
  loss_tracking_metric.reset_state()

In [44]:
epochs = 3

for epoch in range(epochs):
  reset_metrics()
  for inputs_batch, targets_batch in training_dataset:
    logs = train_step(inputs_batch, targets_batch)
  print(f"Results at the end of epoch: {epoch}")
  for key, value in logs.items():
    print(f"...{key}: {value:.4f}")


Results at the end of epoch: 0
...sparse_categorical_accuracy: 0.8371
...loss: 4.4400
Results at the end of epoch: 1
...sparse_categorical_accuracy: 0.8877
...loss: 1.0226
Results at the end of epoch: 2
...sparse_categorical_accuracy: 0.9051
...loss: 0.9037


In [51]:
def test_step(inputs, targets):
  predictions = model(inputs, training = False)
  loss = loss_fn(targets, predictions)

  logs = {}
  for metric in metrics:
    metric.update_state(targets, predictions)
    logs["val" + metric.name] = metric.result()

  loss_tracking_metric.update_state(loss)
  logs["loss"] = loss_tracking_metric.result()

  return logs


test_dataset = tf.data.Dataset.from_tensor_slices((tf.reshape(test_images, [-1, 28*28]), tf.squeeze(test_labels))).batch(32)

reset_metrics()

for inputs_batch, targets_batch in test_dataset:
  logs = test_step(inputs_batch, targets_batch)
print(f"Evaluation results:")
for key, value in logs.items():
  print(f"...{key}: {value:.4f}")

Evaluation results:
...valsparse_categorical_accuracy: 0.9437
...loss: 0.7116


In [56]:
# Make it faster by adding tf.function() by compiling eager code into Computation graph
# Will loose out on debugging power, however!

@tf.function
def test_step(inputs, targets):
  predictions = model(inputs, training = False)
  loss = loss_fn(targets, predictions)

  logs = {}
  for metric in metrics:
    metric.update_state(targets, predictions)
    logs["val" + metric.name] = metric.result()

  loss_tracking_metric.update_state(loss)
  logs["loss"] = loss_tracking_metric.result()

  return logs

test_dataset = tf.data.Dataset.from_tensor_slices(
    (tf.reshape(test_images, [-1, 28*28]), tf.squeeze(test_labels))
).batch(32)

reset_metrics()

for inputs_batch, targets_batch in test_dataset:
  logs = test_step(inputs_batch, targets_batch)
print(f"Evaluation results:")
for key, value in logs.items():
  print(f"...{key}: {value:.4f}")

Evaluation results:
...valsparse_categorical_accuracy: 0.9437
...loss: 0.7116


In [None]:
# Leveraging fit() with a custom training loop

class CustomModel(keras.Model):
  def train_step(self, data):
    inputs, targets = data
    with tf.GradientTape() as tape:
      predictions = self(inputs, training = True)
      loss = self.compiled_loss(targets, predictions)

    gradients = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    self.compiled_metrics.upadte_stae(targets, predictions)
    return {m.name: m.result() for m in metrics}


In [60]:
inputs = keras.Input(shape=(28, 28))
x = layers.Flatten()(inputs)
x = layers.Dense(512, activation="relu")(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs, outputs)

model.compile(optimizer = keras.optimizers.RMSprop(),
              loss = keras.losses.SparseCategoricalCrossentropy(),
              metrics = [keras.metrics.SparseCategoricalAccuracy()])

model.fit(train_images, train_labels, epochs = 3)

Epoch 1/3
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 2ms/step - loss: 11.6266 - sparse_categorical_accuracy: 0.7935
Epoch 2/3
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2ms/step - loss: 1.0484 - sparse_categorical_accuracy: 0.8828
Epoch 3/3
[1m1875/1875[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 2ms/step - loss: 0.8693 - sparse_categorical_accuracy: 0.9047


<keras.src.callbacks.history.History at 0x7fee3bdeb410>