In [1]:
import keras
import tensorflow as tf  # only for tf.data

# Report which Keras release is installed and which backend it is running on.
print("Keras version", keras.version())
print("Keras is running on", keras.config.backend())

2025-11-17 14:47:50.861523: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


Keras version 3.12.0
Keras is running on tensorflow


In [2]:
# tf.data is a great API for putting together a data stream.
# It works whether you use the TensorFlow, PyTorch or Jax backend,
# as long as you use it in the data stream only and not inside of a model.

BATCH_SIZE = 256

(x_train, train_labels), (x_test, test_labels) = keras.datasets.mnist.load_data()


def add_channel_axis(image, label):
    """Append a trailing channel axis to the image (1-channel monochrome)."""
    return tf.expand_dims(image, axis=-1), label


# Training stream. The order of the stages matters:
#   cache   -> keep the mapped examples in memory after the first pass
#   shuffle -> shuffle individual examples, so every epoch sees different
#              batch compositions (shuffling after batch() would only permute
#              whole, fixed batches)
#   batch   -> group the freshly shuffled examples
#   repeat  -> loop forever; epoch length is set by STEPS_PER_EPOCH in fit()
train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))
train_data = train_data.map(add_channel_axis)
train_data = train_data.cache()
# A buffer as large as the dataset gives a uniform shuffle of all examples.
train_data = train_data.shuffle(len(x_train), reshuffle_each_iteration=True)
train_data = train_data.batch(BATCH_SIZE)
train_data = train_data.repeat()

# Evaluation stream: no shuffling needed, the whole test set is one batch.
test_data = tf.data.Dataset.from_tensor_slices((x_test, test_labels))
test_data = test_data.map(add_channel_axis)
test_data = test_data.batch(len(x_test))
test_data = test_data.cache()

# Number of full batches in one pass over the training set.
STEPS_PER_EPOCH = len(train_labels) // BATCH_SIZE
EPOCHS = 5

I0000 00:00:1763387273.017068   19385 gpu_device.cc:2020] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 8948 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 2080 Ti, pci bus id: 0000:08:00.0, compute capability: 7.5


In [3]:
class MnistModel(keras.Model):
    """Small convolutional digit classifier written as a subclassed Functional model.

    The layer graph is assembled inside ``__init__`` and passed to
    ``keras.Model.__init__`` as ``inputs`` / ``outputs``. The same network
    could have been written with ``keras.Sequential``, which is just
    syntactic sugar for simple functional models.
    """

    def __init__(self, **kwargs):
        # uint8 images with a single (monochrome) channel; height and width
        # are left unspecified.
        image = keras.layers.Input(shape=(None, None, 1), dtype="uint8")

        # Layers listed in the order they are applied.
        pipeline = [
            # pixel format conversion from uint8 to float32
            keras.layers.Rescaling(1 / 255.0),
            # 3 convolutional layers
            keras.layers.Conv2D(
                filters=16, kernel_size=3, padding="same", activation="relu"
            ),
            keras.layers.Conv2D(
                filters=32, kernel_size=6, padding="same", activation="relu", strides=2
            ),
            keras.layers.Conv2D(
                filters=48, kernel_size=6, padding="same", activation="relu", strides=2
            ),
            # pooling followed by 2 dense layers
            keras.layers.GlobalAveragePooling2D(),
            keras.layers.Dense(48, activation="relu"),
            keras.layers.Dropout(0.4),
            keras.layers.Dense(
                10, activation="softmax", name="classification_head"  # 10 classes
            ),
        ]

        # Thread the symbolic tensor through every layer in turn.
        tensor = image
        for layer in pipeline:
            tensor = layer(tensor)

        # Calling keras.Model.__init__ with inputs/outputs creates the
        # Functional model.
        super().__init__(inputs=image, outputs=tensor, **kwargs)

In [4]:
model = MnistModel()

# Labels are plain integer class ids, hence the "sparse" loss and metric.
model.compile(
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
    optimizer="adam",
)

# steps_per_epoch is passed explicitly together with the training stream.
history = model.fit(
    x=train_data,
    validation_data=test_data,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
)

Epoch 1/5


2025-11-17 14:48:20.019487: I external/local_xla/xla/service/service.cc:163] XLA service 0x748724006f00 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-11-17 14:48:20.019500: I external/local_xla/xla/service/service.cc:171]   StreamExecutor device (0): NVIDIA GeForce RTX 2080 Ti, Compute Capability 7.5
2025-11-17 14:48:20.045876: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2025-11-17 14:48:20.180051: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:473] Loaded cuDNN version 91600
2025-11-17 14:48:20.368084: I external/local_xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:546] Omitted potentially buggy algorithm eng14{k25=0} for conv (f32[256,16,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[256,1,28,28]{3,2,1,0}, f32[16,1,3,3]{3,2,1,0}, f32[16]{0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_cal

[1m 36/234[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 3ms/step - loss: 2.2828 - sparse_categorical_accuracy: 0.1348

I0000 00:00:1763387301.522011   19509 device_compiler.h:196] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
2025-11-17 14:48:21.830143: I external/local_xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:546] Omitted potentially buggy algorithm eng14{k25=0} for conv (f32[96,16,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[96,1,28,28]{3,2,1,0}, f32[16,1,3,3]{3,2,1,0}, f32[16]{0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"activation_mode":"kNone","conv_result_scale":1,"side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]}


[1m224/234[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 8ms/step - loss: 1.8690 - sparse_categorical_accuracy: 0.3135

2025-11-17 14:48:23.638494: I external/local_xla/xla/service/gpu/autotuning/conv_algorithm_picker.cc:546] Omitted potentially buggy algorithm eng14{k25=0} for conv (f32[10000,16,28,28]{3,2,1,0}, u8[0]{0}) custom-call(f32[10000,1,28,28]{3,2,1,0}, f32[16,1,3,3]{3,2,1,0}, f32[16]{0}), window={size=3x3 pad=1_1x1_1}, dim_labels=bf01_oi01->bf01, custom_call_target="__cudnn$convBiasActivationForward", backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"activation_mode":"kRelu","conv_result_scale":1,"side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false,"reification_cost":[]}


[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 18ms/step - loss: 1.3795 - sparse_categorical_accuracy: 0.5103 - val_loss: 0.4426 - val_sparse_categorical_accuracy: 0.8967
Epoch 2/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.5441 - sparse_categorical_accuracy: 0.8305 - val_loss: 0.2538 - val_sparse_categorical_accuracy: 0.9266
Epoch 3/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.4047 - sparse_categorical_accuracy: 0.8800 - val_loss: 0.1860 - val_sparse_categorical_accuracy: 0.9434
Epoch 4/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.3440 - sparse_categorical_accuracy: 0.8982 - val_loss: 0.1655 - val_sparse_categorical_accuracy: 0.9538
Epoch 5/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 4ms/step - loss: 0.3058 - sparse_categorical_accuracy: 0.9083 - val_loss: 0.1501 - val_sparse_categorical_accuracy: 0.9543


In [5]:
# Fresh, untrained instance used only to inspect the architecture.
model = MnistModel()

# Model summary works: the subclass was initialised with inputs/outputs,
# so the layer graph is known.
model.summary()


# The layer graph can also be traversed recursively.
def walk_layers(layer):
    """Print the name of every leaf layer under `layer`, depth-first.

    Anything exposing a `layers` attribute is treated as a container and is
    descended into without being printed itself; everything else is a leaf
    whose `name` is printed.
    """
    if not hasattr(layer, "layers"):
        print(layer.name)
        return
    for child in layer.layers:
        walk_layers(child)


# Print a heading, then list the model's leaf layers by name.
print("\nWalking model layers:\n")
walk_layers(model)


Walking model layers:

input_layer_1
rescaling_1
conv2d_3
conv2d_4
conv2d_5
global_average_pooling2d_1
dense_1
dropout_1
classification_head


In [6]:
model = MnistModel()

# Model surgery on the functional graph: keep everything up to the original
# classification head and attach a new head instead.
# The input tensor is stored as `model_input` (not `input`) so the Python
# builtin input() is not shadowed for the rest of the notebook.
model_input = model.input
# cut before the classification head: this is the tensor that feeds it
features = model.get_layer("classification_head").input

# add a new classification head
binary_output = keras.layers.Dense(
    1,  # single sigmoid unit for binary classification
    activation="sigmoid",
    name="binary_classification_head",
)(features)

# create a new functional model from the original input to the new head
binary_model = keras.Model(model_input, binary_output)

binary_model.summary()

In [7]:
def to_zero_vs_rest(x, y):
    """Relabel a (features, labels) element: 1 where the label is 0, else 0."""
    is_zero = tf.math.equal(y, tf.zeros_like(y))
    return x, tf.cast(is_zero, dtype=tf.uint8)


# new datasets with 0 / 1 labels (1 = digit '0', 0 = all other digits)
bin_train_data = train_data.map(to_zero_vs_rest)
bin_test_data = test_data.map(to_zero_vs_rest)

# appropriate loss and metric for binary classification
binary_model.compile(
    loss="binary_crossentropy",
    metrics=["binary_accuracy"],
    optimizer="adam",
)

history = binary_model.fit(
    x=bin_train_data,
    validation_data=bin_test_data,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
)

Epoch 1/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 8ms/step - binary_accuracy: 0.8997 - loss: 0.3010 - val_binary_accuracy: 0.9261 - val_loss: 0.1839
Epoch 2/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - binary_accuracy: 0.9565 - loss: 0.1285 - val_binary_accuracy: 0.9770 - val_loss: 0.0713
Epoch 3/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - binary_accuracy: 0.9760 - loss: 0.0791 - val_binary_accuracy: 0.9861 - val_loss: 0.0495
Epoch 4/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - binary_accuracy: 0.9848 - loss: 0.0528 - val_binary_accuracy: 0.9795 - val_loss: 0.0656
Epoch 5/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - binary_accuracy: 0.9878 - loss: 0.0417 - val_binary_accuracy: 0.9883 - val_loss: 0.0375


In [8]:
class MnistDictModel(keras.Model):
    """Convolutional digit classifier whose input is a dictionary.

    The Functional graph is built in ``__init__`` from a dict of inputs with
    the single key ``"image"`` and handed to ``keras.Model.__init__``.
    """

    def __init__(self, **kwargs):
        def conv(**options):
            # All convolutions share "same" padding and a ReLU activation.
            return keras.layers.Conv2D(padding="same", activation="relu", **options)

        # The model input is a dictionary with one entry, "image":
        # uint8, 1-channel monochrome, height and width left unspecified.
        image = keras.layers.Input(shape=(None, None, 1), dtype="uint8", name="image")
        model_inputs = {"image": image}

        # pixel format conversion from uint8 to float32
        hidden = keras.layers.Rescaling(1 / 255.0)(model_inputs["image"])

        # 3 conv layers
        hidden = conv(filters=16, kernel_size=3)(hidden)
        hidden = conv(filters=32, kernel_size=6, strides=2)(hidden)
        hidden = conv(filters=48, kernel_size=6, strides=2)(hidden)

        # pooling, then 2 dense layers with dropout in between
        hidden = keras.layers.GlobalAveragePooling2D()(hidden)
        hidden = keras.layers.Dense(48, activation="relu")(hidden)
        hidden = keras.layers.Dropout(0.4)(hidden)
        class_probs = keras.layers.Dense(
            10, activation="softmax", name="classification_head"  # 10 classes
        )(hidden)

        # Calling keras.Model.__init__ with inputs/outputs creates the
        # Functional model.
        super().__init__(inputs=model_inputs, outputs=class_probs, **kwargs)

In [9]:
model = MnistDictModel()


def as_image_dict(x, y):
    """Wrap the features in a dict under the "image" key; labels pass through."""
    return {"image": x}, y


# reformat the datasets so that the features arrive as a dictionary
dict_train_data = train_data.map(as_image_dict)
dict_test_data = test_data.map(as_image_dict)

model.compile(
    loss="sparse_categorical_crossentropy",
    metrics=["sparse_categorical_accuracy"],
    optimizer="adam",
)

history = model.fit(
    x=dict_train_data,
    validation_data=dict_test_data,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
)

Epoch 1/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 9ms/step - loss: 1.4163 - sparse_categorical_accuracy: 0.5044 - val_loss: 0.6028 - val_sparse_categorical_accuracy: 0.8534
Epoch 2/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.5876 - sparse_categorical_accuracy: 0.8186 - val_loss: 0.2925 - val_sparse_categorical_accuracy: 0.9160
Epoch 3/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.4368 - sparse_categorical_accuracy: 0.8693 - val_loss: 0.2301 - val_sparse_categorical_accuracy: 0.9321
Epoch 4/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.3649 - sparse_categorical_accuracy: 0.8924 - val_loss: 0.1818 - val_sparse_categorical_accuracy: 0.9439
Epoch 5/5
[1m234/234[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - loss: 0.3143 - sparse_categorical_accuracy: 0.9069 - val_loss: 0.1717 - val_sparse_categorical_accuracy: 0.9478
