# 7 Working with Keras: A deep dive

In [20]:
import statistics
import time

import tensorflow as tf
from tqdm import tqdm

## 7.3 Using built-in training and evaluation loops

**The standard workflow: `compile()`, `fit()`, `evaluate()`, `predict()`**

In [3]:
def get_mnist_model(model_class=tf.keras.Model, hidden_units=512, dropout_rate=0.5):
    """Build an uncompiled MNIST classifier that takes flattened 28x28 images.

    Architecture: Dense(hidden_units, relu) -> Dropout(dropout_rate) ->
    Dense(10, softmax).

    Args:
        model_class: Class used to wrap the functional graph. Pass a
            `tf.keras.Model` subclass (for example one overriding
            `train_step`) to reuse this architecture with custom training.
        hidden_units: Width of the hidden Dense layer.
        dropout_rate: Fraction of hidden activations dropped while training.

    Returns:
        A `model_class` instance mapping 784-feature inputs to 10 class
        probabilities.
    """
    inputs = tf.keras.Input(shape=(28 * 28,))
    features = tf.keras.layers.Dense(hidden_units, activation="relu")(inputs)
    features = tf.keras.layers.Dropout(dropout_rate)(features)
    outputs = tf.keras.layers.Dense(10, activation="softmax")(features)
    return model_class(inputs, outputs)

# The first NUM_VAL_SAMPLES training images are held out for validation.
NUM_VAL_SAMPLES = 10_000

(images, labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# Flatten each 28x28 image into a 784-vector and scale pixels to [0, 1].
# reshape(-1, ...) infers the sample count instead of hard-coding 60000 / 10000.
images = images.reshape((-1, 28 * 28)).astype("float32") / 255
test_images = test_images.reshape((-1, 28 * 28)).astype("float32") / 255
train_images, val_images = images[NUM_VAL_SAMPLES:], images[:NUM_VAL_SAMPLES]
train_labels, val_labels = labels[NUM_VAL_SAMPLES:], labels[:NUM_VAL_SAMPLES]

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step


In [4]:
# Standard Keras workflow: build, compile, train with a held-out validation
# split, then evaluate on the test set and compute class probabilities.
model = get_mnist_model()
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="rmsprop",
    metrics=["accuracy"],
)
model.fit(
    x=train_images,
    y=train_labels,
    validation_data=(val_images, val_labels),
    epochs=3,
)
test_metrics = model.evaluate(x=test_images, y=test_labels)
predictions = model.predict(x=test_images)

Epoch 1/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 14s 5ms/step - accuracy: 0.8674 - loss: 0.4525 - val_accuracy: 0.9575 - val_loss: 0.1415
Epoch 2/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/step - accuracy: 0.9517 - loss: 0.1630 - val_accuracy: 0.9697 - val_loss: 0.1104
Epoch 3/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 8s 4ms/step - accuracy: 0.9621 - loss: 0.1316 - val_accuracy: 0.9731 - val_loss: 0.1003
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.9680 - loss: 0.1128
313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step


### 7.3.1 Writing your own metrics

**Implementing a custom metric by subclassing the `Metric` class**

In [5]:
class RootMeanSquaredError(tf.keras.metrics.Metric):
    """Streaming RMSE between one-hot encoded sparse labels and predicted probabilities.

    Accumulates the summed squared error over all classes and the number of
    samples seen. `result()` returns sqrt(summed_squared_error / num_samples).
    `sample_weight` is accepted for API compatibility but ignored.
    """

    def __init__(self, name="rmse", **kwargs):
        super().__init__(name=name, **kwargs)
        self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")
        # Kept as float32 rather than int32: an int32 variable would be placed
        # on the CPU automatically by TF, see:
        # https://github.com/keras-team/keras/issues/20250#issuecomment-2344087536
        self.total_samples = self.add_weight(
            name="total_samples", initializer="zeros", dtype="float32"
        )

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate squared error and sample count for one batch.

        Args:
            y_true: Integer class labels; assumed shape (batch,) or (batch, 1).
            y_pred: Predicted class probabilities; assumed (batch, num_classes).
            sample_weight: Ignored.
        """
        # Flatten labels so a (batch, 1) input does not become a
        # (batch, 1, num_classes) one-hot tensor that silently broadcasts.
        labels = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        one_hot = tf.one_hot(labels, depth=tf.shape(y_pred)[-1], dtype=y_pred.dtype)
        squared_error_sum = tf.reduce_sum(tf.square(one_hot - y_pred))
        self.mse_sum.assign_add(tf.cast(squared_error_sum, tf.float32))
        # tf.shape() yields int32, while total_samples is float32: cast explicitly.
        num_samples = tf.cast(tf.shape(y_pred)[0], tf.float32)
        self.total_samples.assign_add(num_samples)

    def result(self):
        # divide_no_nan returns 0 (instead of NaN from 0/0) before any update.
        return tf.sqrt(
            tf.math.divide_no_nan(
                tf.cast(self.mse_sum, tf.float32),
                tf.cast(self.total_samples, tf.float32),
            )
        )

    def reset_state(self):
        self.mse_sum.assign(0.0)
        self.total_samples.assign(0.0)

In [6]:
# Same training run as above, now also reporting the custom RMSE metric.
model = get_mnist_model()
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer="rmsprop",
    metrics=["accuracy", RootMeanSquaredError()],
)
model.fit(
    x=train_images,
    y=train_labels,
    validation_data=(val_images, val_labels),
    epochs=3,
)
test_metrics = model.evaluate(x=test_images, y=test_labels)

Epoch 1/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 13s 7ms/step - accuracy: 0.8615 - loss: 0.4598 - rmse: 0.4441 - val_accuracy: 0.9598 - val_loss: 0.1464 - val_rmse: 0.2500
Epoch 2/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - accuracy: 0.9522 - loss: 0.1615 - rmse: 0.2694 - val_accuracy: 0.9685 - val_loss: 0.1071 - val_rmse: 0.2145
Epoch 3/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/step - accuracy: 0.9610 - loss: 0.1334 - rmse: 0.2445 - val_accuracy: 0.9711 - val_loss: 0.1048 - val_rmse: 0.2081
313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - accuracy: 0.9680 - loss: 0.1122 - rmse: 0.2229


---

## 7.4 Writing your own training and evaluation loops

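### 7.4.1 Training versus inference

Layers such as `Dropout` behave differently during training and inference. The steps below therefore call the model with `training=True` when updating weights and with `training=False` when evaluating.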

### 7.4.2 Low-level usage of metrics

In [7]:
# Feed a metric by hand: sparse integer targets vs. one-hot "predictions".
targets = [0, 1, 2]
predictions = [
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
]
metric = tf.keras.metrics.SparseCategoricalAccuracy()
metric.update_state(targets, predictions)
current_result = metric.result()
print(f"result: {current_result:.2f}")

result: 1.00


In [8]:
# A Mean metric keeps a running average across successive update_state calls.
mean_tracker = tf.keras.metrics.Mean()
values = list(range(5))
for value in values:
    mean_tracker.update_state(value)
print(f"Mean of values: {mean_tracker.result():.2f}")

Mean of values: 2.00


### 7.4.3 A complete training and evaluation loop

**Writing a step-by-step training loop: the training step function**

In [22]:
# Objects shared by the hand-written train/test steps: the model, its loss
# and optimizer, the metrics to report, and a running mean of the loss.
model = get_mnist_model()
optimizer = tf.keras.optimizers.RMSprop()
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
loss_tracking_metric = tf.keras.metrics.Mean()
metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]

def train_step(inputs, targets):
    """Run one optimization step on a batch and return the running metrics.

    Uses the module-level `model`, `loss_fn`, `optimizer`, `metrics` and
    `loss_tracking_metric`. The returned dict maps each metric's name, plus
    "loss", to its value accumulated since the last `reset_state()`.
    """
    # Record the forward pass so gradients w.r.t. the weights can be taken.
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(targets, predictions)
    weights = model.trainable_weights
    grads = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

    for m in metrics:
        m.update_state(targets, predictions)
    loss_tracking_metric.update_state(loss)

    logs = {m.name: m.result() for m in metrics}
    logs["loss"] = loss_tracking_metric.result()
    return logs

**Writing a step-by-step training loop: resetting the metrics**

In [23]:
def reset_metrics():
    """Clear the accumulated state of every tracked metric, including the loss mean."""
    for tracked in (*metrics, loss_tracking_metric):
        tracked.reset_state()

**Writing a step-by-step training loop: the loop itself**

In [24]:
# Shuffle before batching so each epoch visits the samples in a new order
# (fit() shuffles by default; without this the batches repeat identically).
training_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
training_dataset = training_dataset.shuffle(buffer_size=len(train_images)).batch(32)
epochs = 3
for epoch in range(epochs):
    reset_metrics()
    logs = {}  # stays empty (instead of undefined) if no batch is produced
    for inputs_batch, targets_batch in tqdm(training_dataset, desc=f"epoch {epoch}"):
        logs = train_step(inputs_batch, targets_batch)
    print()
    print(f"Results at the end of epoch {epoch}")
    for key, value in logs.items():
        print(f"...{key}: {value:.4f}")
    print("---")

100%|██████████| 1563/1563 [01:22<00:00, 19.05it/s]



Results at the end of epoch 0
...sparse_categorical_accuracy: 0.9152
...loss: 0.2896
---


100%|██████████| 1563/1563 [00:45<00:00, 34.18it/s]



Results at the end of epoch 1
...sparse_categorical_accuracy: 0.9539
...loss: 0.1599
---


100%|██████████| 1563/1563 [00:45<00:00, 34.03it/s]


Results at the end of epoch 2
...sparse_categorical_accuracy: 0.9640
...loss: 0.1274
---





**Writing a step-by-step evaluation loop**

In [25]:
def test_step(inputs, targets):
    """Evaluate one batch in inference mode and return running validation metrics.

    Updates the module-level `metrics` and `loss_tracking_metric`; every key of
    the returned dict is prefixed with "val_".
    """
    predictions = model(inputs, training=False)
    loss = loss_fn(targets, predictions)

    for m in metrics:
        m.update_state(targets, predictions)
    loss_tracking_metric.update_state(loss)

    logs = {"val_" + m.name: m.result() for m in metrics}
    logs["val_loss"] = loss_tracking_metric.result()
    return logs

# Eager evaluation pass over the validation set, timing every test_step call.
val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
val_dataset = val_dataset.batch(32)
reset_metrics()
times = []
logs = {}
for inputs_batch, targets_batch in val_dataset:
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    logs = test_step(inputs_batch, targets_batch)
    times.append(time.perf_counter() - start)
if times:
    print("Mean test_step time:", statistics.mean(times))
    print("Median test_step time:", statistics.median(times))
print("Evaluation results:")
for key, value in logs.items():
    print(f"...{key}: {value:.4f}")

Mean test_step time: 0.008391315563799094
Evaluation results:
...val_sparse_categorical_accuracy: 0.9680
...val_loss: 0.1164


In [26]:
%%timeit
# Times a full eager pass over val_dataset. Metrics are not reset in this
# cell, so their state keeps accumulating across timeit runs; only the
# timing is of interest here.
for inputs_batch, targets_batch in val_dataset:
    logs = test_step(inputs_batch, targets_batch)

3.49 s ± 1.05 s per loop (mean ± std. dev. of 7 runs, 1 loop each)


### 7.4.4 Make it fast with tf.function

**Adding a `tf.function` decorator to our evaluation-step function**

In [27]:
@tf.function
def test_step_compiled(inputs, targets):
    """Graph-compiled twin of `test_step`: same updates, same "val_"-prefixed logs.

    The Python loop over `metrics` runs once, while tf.function traces the body.
    """
    predictions = model(inputs, training=False)
    loss = loss_fn(targets, predictions)

    for m in metrics:
        m.update_state(targets, predictions)
    loss_tracking_metric.update_state(loss)

    logs = {"val_" + m.name: m.result() for m in metrics}
    logs["val_loss"] = loss_tracking_metric.result()
    return logs

# Same timed evaluation pass, now through the tf.function-compiled step;
# compare with the eager timings above.
val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
val_dataset = val_dataset.batch(32)
reset_metrics()
times = []
logs = {}
for inputs_batch, targets_batch in val_dataset:
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    logs = test_step_compiled(inputs_batch, targets_batch)
    times.append(time.perf_counter() - start)
if times:
    # The mean includes one-off tracing cost. The first call traces the
    # function, and the final smaller batch (10000 % 32 = 16 samples) has a
    # new shape that may trigger a retrace. The median is robust to those
    # outliers and reflects the steady-state per-step cost.
    print("Mean test_step time:", statistics.mean(times))
    print("Median test_step time:", statistics.median(times))
print("Evaluation results:")
for key, value in logs.items():
    print(f"...{key}: {value:.4f}")

Mean test_step time: 0.0017196583671691698
Evaluation results:
...val_sparse_categorical_accuracy: 0.9680
...val_loss: 0.1164


In [28]:
%%timeit
# Times a full pass over val_dataset through the tf.function-compiled step,
# for comparison with the eager %%timeit cell. Metrics are not reset here,
# so their state keeps accumulating across timeit runs.
for inputs_batch, targets_batch in val_dataset:
    logs = test_step_compiled(inputs_batch, targets_batch)

432 ms ± 8.82 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### 7.4.5 Leveraging fit() with a custom training loop

**Implementing a custom training step to use with `fit()`**

In [29]:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()


class CustomModel(tf.keras.Model):
    """Model whose `fit()` runs a hand-written training step.

    Only the loss is tracked, and it is reported under the "loss" key.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-instance tracker. A module-level Mean would be shared by every
        # CustomModel instance, mixing the running losses of separate models.
        self.loss_tracker = tf.keras.metrics.Mean(name="loss")

    def train_step(self, data):  # ← `train_step` is what we need to rewrite
        inputs, targets = data
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            loss = loss_fn(targets, predictions)
        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))

        self.loss_tracker.update_state(loss)
        return {"loss": self.loss_tracker.result()}

    @property
    def metrics(self):
        # Listing the tracker here lets fit() reset it at the start of each epoch.
        return [self.loss_tracker]

In [30]:
# Same MNIST architecture as get_mnist_model(), wrapped in CustomModel so
# that fit() drives the overridden train_step.
inputs = tf.keras.Input(shape=(28 * 28,))
features = tf.keras.layers.Dense(units=512, activation="relu")(inputs)
features = tf.keras.layers.Dropout(rate=0.5)(features)
outputs = tf.keras.layers.Dense(units=10, activation="softmax")(features)
model = CustomModel(inputs, outputs)

model.compile(optimizer=tf.keras.optimizers.RMSprop())
model.fit(x=train_images, y=train_labels, epochs=3)

Epoch 1/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 2ms/step - loss: 0.4564
Epoch 2/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 0.1687
Epoch 3/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 0.1256


<keras.src.callbacks.history.History at 0x7f5d0a019210>

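Updated from the DLWP version: this `train_step` computes the loss with `self.compute_loss()` and updates the metrics configured in `compile()` through `self.metrics`, instead of hard-coding a loss function and tracker. See the [Keras guide on customizing `fit()`](https://keras.io/guides/custom_train_step_in_tensorflow/#a-first-simple-example).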

In [31]:
class CustomModel(tf.keras.Model):
    """Model with a custom train_step that reuses the loss and metrics from compile().

    `self.compute_loss` applies the loss passed to `compile()`. `self.metrics`
    is expected to contain a loss tracker named "loss" plus the compiled
    metrics.
    """

    def train_step(self, data):
        # Batches may arrive as (inputs, targets) or (inputs, targets, sample_weight).
        if len(data) == 3:
            inputs, targets, sample_weight = data
        else:
            inputs, targets = data
            sample_weight = None
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            #           ↓ compute loss (weighted when sample_weight is given)
            loss = self.compute_loss(
                y=targets, y_pred=predictions, sample_weight=sample_weight
            )
        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))
        #    ↓ compiled metrics for updating
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(targets, predictions, sample_weight=sample_weight)
        return {m.name: m.result() for m in self.metrics}  # ← metrics for reporting

In [32]:
# Same architecture again; this time the loss and metrics come from compile()
# and are consumed inside CustomModel.train_step.
inputs = tf.keras.Input(shape=(28 * 28,))
features = tf.keras.layers.Dense(units=512, activation="relu")(inputs)
features = tf.keras.layers.Dropout(rate=0.5)(features)
outputs = tf.keras.layers.Dense(units=10, activation="softmax")(features)
model = CustomModel(inputs, outputs)

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.RMSprop(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
model.fit(x=train_images, y=train_labels, epochs=3)

Epoch 1/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 3ms/step - sparse_categorical_accuracy: 0.8628 - loss: 0.4515
Epoch 2/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 7s 2ms/step - sparse_categorical_accuracy: 0.9506 - loss: 0.1660
Epoch 3/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/step - sparse_categorical_accuracy: 0.9642 - loss: 0.1243


<keras.src.callbacks.history.History at 0x7f5cfe6fe740>