In [1]:
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras

In [2]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Sanity-check the canonical MNIST split sizes; the reshapes below rely on them.
assert (x_train.shape, y_train.shape) == ((60000, 28, 28), (60000,))
assert (x_test.shape, y_test.shape) == ((10000, 28, 28), (10000,))

# Append a single channel axis to match the model's (28, 28, 1) input, then
# cast the uint8-range pixels to float32 and scale them into [0, 1].
x_train = x_train.reshape((60000, 28, 28, 1)).astype("float32") / 255
x_test = x_test.reshape((10000, 28, 28, 1)).astype("float32") / 255

# Labels stay as class ids (0-9), only their dtype changes to float32.
y_train, y_test = y_train.astype("float32"), y_test.astype("float32")

In [3]:
# Seed Python, NumPy and the backend RNGs before any layer is created so that
# weight initialisation (and other Keras-driven randomness such as batch
# shuffling in `fit`) is reproducible on Restart & Run All.
SEED = 42
keras.utils.set_random_seed(SEED)

# Small CNN: two stride-2 "same" convolutions downsample 28x28 -> 14x14 -> 7x7,
# then a dense head classifies the flattened feature map.
inputs = keras.Input(shape=(28, 28, 1), name="digits")
x = keras.layers.Conv2D(
    64, kernel_size=3, strides=2, padding="same", activation="relu"
)(inputs)
x = keras.layers.Conv2D(
    64, kernel_size=3, strides=2, padding="same", activation="relu"
)(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(512, activation="relu")(x)
# Softmax over the 10 digit classes. The "sparse" crossentropy below takes
# class ids as targets (not one-hot vectors), which is how y_train is stored.
x = keras.layers.Dense(10, activation="softmax")(x)

model = keras.Model(inputs=inputs, outputs=x, name="mnist_model")
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=keras.optimizers.RMSprop(),
    metrics=["accuracy"],
)

model.summary()

In [4]:
# Train for 10 epochs in mini-batches of 128. validation_split=0.2 makes Keras
# hold out the last 20% of the training arrays and report them as val_* metrics;
# the per-epoch metrics are kept in `history.history`.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=128,
    epochs=10,
    validation_split=0.2,
)

Epoch 1/10
375/375 ━━━━━━━━━━━━━━━━━━━━ 5s 9ms/step - accuracy: 0.8478 - loss: 0.4840 - val_accuracy: 0.9753 - val_loss: 0.0813
Epoch 2/10
375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - accuracy: 0.9797 - loss: 0.0620 - val_accuracy: 0.9833 - val_loss: 0.0582
Epoch 3/10
375/375 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - accuracy: 0.9891 - loss: 0.0330 - val_accuracy: 0.9855 - val_loss: 0.0467
Epoch 4/10
375/375 ━━━━━━━━━━━━━━━━━━━━ 4s 9ms/step - accuracy: 0.9935 - loss: 0.0202 - val_accuracy: 0.9858 - val_loss: 0.0520
Epoch 5/10
375/375 ━━━━━━━━━━━━━━━━━━━━ 4s 10ms/step - accuracy: 0.9957 - loss: 0.0131 - val_accuracy: 0.9863 - val_loss: 0.0502
Epoch 6/10
375/375 ━━━━━━━━━━━━━━━━━━━━ 4s 9ms/step - accuracy: 0.9967 - loss: 0.0094 - val_accuracy: 0.9875 - val_loss: 0.0523
Epoch 7/10
375/375

In [5]:
# Score the trained model once on the untouched test split. `results` holds the
# loss followed by the compiled metrics (here: accuracy).
results = model.evaluate(
    x=x_test,
    y=y_test,
    batch_size=128,
)

79/79 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9867 - loss: 0.0661
