In [1]:
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import tensorflow as tf
import keras

# Load MNIST dataset and standardize the data
mnist = keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

2025-11-17 11:05:21.661135: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


In [2]:
class MyTrainer(keras.Model):
    def __init__(self, model):
        super().__init__()
        self.model = model
        # Create loss and metrics here.
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy()
        self.accuracy_metric = keras.metrics.SparseCategoricalAccuracy()

    @property
    def metrics(self):
        # List metrics here.
        return [self.accuracy_metric]

    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self.model(x, training=True)  # Forward pass
            # Compute loss value
            loss = self.loss_fn(y, y_pred)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        for metric in self.metrics:
            metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value.
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        x, y = data

        # Inference step
        y_pred = self.model(x, training=False)

        # Update metrics
        for metric in self.metrics:
            metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}

    def call(self, x):
        # Equivalent to `call()` of the wrapped keras.Model
        x = self.model(x)
        return x

In [3]:
# A model defined using Sequential API
model_a = keras.models.Sequential(
    [
        keras.layers.Flatten(input_shape=(28, 28)),
        keras.layers.Dense(256, activation="relu"),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation="softmax"),
    ]
)

# A model defined using Functional API
func_input = keras.Input(shape=(28, 28, 1))
x = keras.layers.Flatten(input_shape=(28, 28))(func_input)
x = keras.layers.Dense(512, activation="relu")(x)
x = keras.layers.Dropout(0.4)(x)
func_output = keras.layers.Dense(10, activation="softmax")(x)

model_b = keras.Model(func_input, func_output)

  super().__init__(**kwargs)
I0000 00:00:1763373949.466731    9503 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 8971 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2080 Ti, pci bus id: 0000:08:00.0, compute capability: 7.5


In [4]:
trainer_1 = MyTrainer(model_a)
trainer_2 = MyTrainer(model_b)

In [5]:
trainer_1.compile(optimizer=keras.optimizers.SGD())
trainer_1.fit(
    x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test)
)

trainer_2.compile(optimizer=keras.optimizers.Adam())
trainer_2.fit(
    x_train, y_train, epochs=5, batch_size=64, validation_data=(x_test, y_test)
)

Epoch 1/5


2025-11-17 11:06:06.260689: I external/local_xla/xla/service/service.cc:163] XLA service 0x7cff780073c0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-11-17 11:06:06.260699: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA GeForce RTX 2080 Ti, Compute Capability 7.5
2025-11-17 11:06:06.292719: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-11-17 11:06:06.367965: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 91600


[1m182/938[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 834us/step - sparse_categorical_accuracy: 0.3585

I0000 00:00:1763373966.893779    9618 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2ms/step - sparse_categorical_accuracy: 0.7705 - val_sparse_categorical_accuracy: 0.8888
Epoch 2/5
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 970us/step - sparse_categorical_accuracy: 0.8782 - val_sparse_categorical_accuracy: 0.9086
Epoch 3/5
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 972us/step - sparse_categorical_accuracy: 0.8948 - val_sparse_categorical_accuracy: 0.9154
Epoch 4/5
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 966us/step - sparse_categorical_accuracy: 0.9054 - val_sparse_categorical_accuracy: 0.9222
Epoch 5/5
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 970us/step - sparse_categorical_accuracy: 0.9128 - val_sparse_categorical_accuracy: 0.9270
Epoch 1/5
[1m938/938[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 2ms/step - sparse_categorical_accuracy: 0.9209 - val_sparse_categorical_accuracy: 0.9617
Epoch 2/5
[

<keras.src.callbacks.history.History at 0x7d009d912530>