In [24]:
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

import tensorflow as tf

def get_mnist_model():
    inputs = keras.Input(shape=(28 * 28,))
    features = layers.Dense(512, activation="relu")(inputs) 
    features = layers.Dropout(0.5)(features)
    outputs = layers.Dense(10, activation="softmax")(features)
    model = keras.Model(inputs, outputs)
    return model

(images, labels), (test_images, test_labels) = mnist.load_data()
images = images.reshape((60_000, 28 * 28)).astype("float32") / 32
test_images = test_images.reshape((10_000, 28 * 28)).astype("float32") / 32
train_images, val_images = images[10_000:], images[:10_000]
train_labels, val_labels = labels[10_000:], labels[:10_000]

In [40]:
# Listing 7.21 - Writing a step-by-step training loop: the training step function

model = get_mnist_model()

loss_fn = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.RMSprop()
metrics = [keras.metrics.SparseCategoricalAccuracy()]
loss_tracking_metric = keras.metrics.Mean()

@tf.function
def training_step(inputs, targets):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    loss = loss_fn(targets, predictions)
  gradients = tape.gradient(loss, model.trainable_weights)
  optimizer.apply_gradients(zip(gradients, model.trainable_weights))
  logs = {}
  for metric in metrics:
    metric.update_state(targets, predictions)
    logs[metric.name] = metric.result()

  loss_tracking_metric.update_state(loss)
  logs["loss"] = loss_tracking_metric.result()
  return logs




In [41]:
# Listing 7.22 - Writing a step-by-step training loop: resetting the metrics

def reset_metrics():
  for metric in metrics:
    metric.reset_state()
  loss_tracking_metric.reset_state()

In [42]:
# Listing 7.23 - Writing a step-by-step training loop: the loop itself

training_dataset = tf.data.Dataset.from_tensor_slices(
  (train_images, train_labels))
training_dataset = training_dataset.batch(32)
epochs = 3

for epoch in range(epochs):
  reset_metrics()
  for inputs_batch, targets_batch in training_dataset:
    logs = training_step(inputs_batch, targets_batch)
  print(f"Results at the end of epoch {epoch}")
  for key, value in logs.items():
    print(f"...{key}: {value:.4f}")

Results at the end of epoch 0
...sparse_categorical_accuracy: 0.9023
...loss: 0.3947
Results at the end of epoch 1
...sparse_categorical_accuracy: 0.9466
...loss: 0.2323
Results at the end of epoch 2
...sparse_categorical_accuracy: 0.9555
...loss: 0.2064


In [43]:
# Listing 7.24 - Writing a step-by-step evaluation loop
# Listing 7.25 - Adding a `@tf.function` decorator to our evaluation-step function

@tf.function
def test_step(inputs, targets):
  predictions = model(inputs, training=False)
  loss = loss_fn(targets, predictions)

  logs = {}
  for metric in metrics:
    metric.update_state(targets, predictions)
    logs["val_" + metric.name] = metric.result()
  
  loss_tracking_metric.update_state(loss)
  logs["val_loss"] = loss_tracking_metric.result()
  return logs


val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
val_dataset = val_dataset.batch(32)
reset_metrics()

for inputs_batch, targets_batch in val_dataset:
  logs = test_step(inputs_batch, targets_batch)
print("Evaluation results:")
for key, value in logs.items():
  print(f"...{key}: {value:.4f}")

Evaluation results:
...val_sparse_categorical_accuracy: 0.9684
...val_loss: 0.1708


In [57]:
# Listing 7.26 - Implementing a custom training step to use with `fit()`

loss_fn = keras.losses.SparseCategoricalCrossentropy()
loss_tracker = keras.metrics.Mean(name="loss")

class CustomModel(keras.Model):
  # We overide the train_step method
  def train_step(self, data):
    inputs, targets = data
    with tf.GradientTape() as tape:
      predictions = self(inputs, training=True)
      loss = self.compiled_loss(targets, predictions) # Compute the loss viae self.compiled_loss
    gradients = tape.gradient(loss, self.trainable_weights)
    self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))
    self.compiled_metrics.update_state(targets, predictions) # Update the model's metrics via self.compiled_metrics
    return {m.name: m.result() for m in self.metrics} # Return a dict mapping metric names to their current value.

In [58]:
inputs = keras.Input(shape=(28 * 28))
features = layers.Dense(512, activation="relu")(inputs)
features = layers.Dropout(0.5)(features)
outputs = layers.Dense(10, activation="softmax")(features)
model = CustomModel(inputs, outputs)

model.compile(optimizer=keras.optimizers.RMSprop(),
              loss=keras.losses.SparseCategoricalCrossentropy(),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])
model.fit(train_images, train_labels, epochs=3)



Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.src.callbacks.History at 0x3ab414a30>