In [1]:
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
keras.mixed_precision.set_global_policy("mixed_float16")

In [2]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
assert x_train.shape == (60000, 28, 28)
assert x_test.shape == (10000, 28, 28)
assert y_train.shape == (60000,)
assert y_test.shape == (10000,)

x_train = x_train.reshape((60000, 28, 28, 1))
x_test = x_test.reshape((10000, 28, 28, 1))

x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
y_train = y_train.astype("float32")
y_test = y_test.astype("float32")

In [3]:
inputs = keras.Input(shape=(28, 28, 1), name="digits")
x = keras.layers.Conv2D(
    64, kernel_size=3, strides=2, padding="same", activation="relu"
)(inputs)
x = keras.layers.Conv2D(
    64, kernel_size=3, strides=2, padding="same", activation="relu"
)(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(512, activation="relu")(x)
x = keras.layers.Dense(10, activation="softmax")(x)

model = keras.Model(inputs=inputs, outputs=x, name="mnist_model")
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.RMSprop(),
    metrics=["accuracy"],
)

model.summary()

In [4]:
history = model.fit(x_train, y_train, batch_size=64, epochs=10, validation_split=0.2)

Epoch 1/10
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 7ms/step - accuracy: 0.8808 - loss: 0.3688 - val_accuracy: 0.9760 - val_loss: 0.0710
Epoch 2/10
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 0.9839 - loss: 0.0539 - val_accuracy: 0.9843 - val_loss: 0.0535
Epoch 3/10
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - accuracy: 0.9908 - loss: 0.0304 - val_accuracy: 0.9840 - val_loss: 0.0544
Epoch 4/10
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 0.9937 - loss: 0.0191 - val_accuracy: 0.9860 - val_loss: 0.0548
Epoch 5/10
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 0.9956 - loss: 0.0138 - val_accuracy: 0.9872 - val_loss: 0.0505
Epoch 6/10
[1m750/750[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 0.9979 - loss: 0.0081 - val_accuracy: 0.9877 - val_loss: 0.0551
Epoch 7/10
[1m750/750[0m 

In [5]:
results = model.evaluate(x_test, y_test, batch_size=128)

[1m79/79[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 14ms/step - accuracy: 0.9867 - loss: 0.0768
