# Libraries

In [2]:
import os
# Select the JAX backend. This assignment is placed before `import keras`
# because Keras picks its backend from KERAS_BACKEND when it is imported.
os.environ["KERAS_BACKEND"] = "jax"

import keras
import numpy as np

# Report the installed Keras version and the backend that is actually active.
print(keras.__version__)
print(keras.config.backend())

3.10.0
jax


# Loading the dataset

In [8]:
# Fetch the MNIST train/test splits as (images, labels) pairs.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Emit the shape report in a single call; each argument becomes one output line.
print(
    "--- MNIST Dataset Shapes ---",
    f"Training Images (x_train): {x_train.shape} -> {x_train.shape[0]} images, each {x_train.shape[1]}x{x_train.shape[2]} pixels",
    f"Training Labels (y_train): {y_train.shape} -> {y_train.shape[0]} labels",
    f"Test Images (x_test):      {x_test.shape} -> {x_test.shape[0]} images, each {x_test.shape[1]}x{x_test.shape[2]} pixels",
    f"Test Labels (y_test):      {y_test.shape} -> {y_test.shape[0]} labels",
    "--------------------------",
    sep="\n",
)

--- MNIST Dataset Shapes ---
Training Images (x_train): (60000, 28, 28) -> 60000 images, each 28x28 pixels
Training Labels (y_train): (60000,) -> 60000 labels
Test Images (x_test):      (10000, 28, 28) -> 10000 images, each 28x28 pixels
Test Labels (y_test):      (10000,) -> 10000 labels
--------------------------


In [9]:
# Peek at the first two training labels as loaded.
y_train[:2]

array([5, 0], dtype=uint8)

# Preprocess data

In [10]:
# Flatten each 28x28 image into a 784-element vector and scale pixel values
# from 0-255 to the 0-1 range.
# The rank check keeps this cell idempotent: the raw images are 3-D
# (N, 28, 28), so once an array has been flattened it is left alone and
# re-running the cell cannot divide by 255 a second time.
if x_train.ndim == 3:
    x_train = x_train.reshape(-1, 28 * 28).astype("float32") / 255.0
if x_test.ndim == 3:
    x_test = x_test.reshape(-1, 28 * 28).astype("float32") / 255.0

# One-Hot Encoding

In [11]:
# Convert integer class labels to one-hot vectors of length `num_classes`.
# Only encode while the labels are still 1-D, so re-running this cell does not
# pass already-encoded labels through `to_categorical` again.
num_classes = 10
if y_train.ndim == 1:
    y_train = keras.utils.to_categorical(y_train, num_classes)
if y_test.ndim == 1:
    y_test = keras.utils.to_categorical(y_test, num_classes)

In [12]:
# Show the first two training labels again to inspect their current form.
y_train[:2]

array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [13]:
# Display the shape of the training label array.
y_train.shape

(60000, 10)

# Defining the Neural Network Architecture

In [16]:
def create_mnist_model(input_dim=784, hidden_units=128, dropout_rate=0.2, n_classes=None):
    """Build an uncompiled single-hidden-layer classifier for flattened images.

    Calling it with no arguments yields the same architecture as before:
    784 inputs -> Dense(128, relu) -> Dropout(0.2) -> Dense(num_classes, softmax).

    Args:
        input_dim: Length of each flattened input vector (784 = 28 * 28).
        hidden_units: Number of units in the hidden ReLU layer.
        dropout_rate: Rate passed to the Dropout layer after the hidden layer.
        n_classes: Number of output units. When None, the notebook-level
            `num_classes` is used, so that variable must already be defined.

    Returns:
        A `keras.Sequential` model that still needs to be compiled.
    """
    if n_classes is None:
        # Resolve the default at call time to keep following the global setting.
        n_classes = num_classes

    model = keras.Sequential([
        keras.layers.Input(shape=(input_dim,)),
        keras.layers.Dense(hidden_units, activation="relu"),
        keras.layers.Dropout(dropout_rate),
        keras.layers.Dense(n_classes, activation="softmax"),
    ])
    return model

# Instantiating and Compiling the Model

In [42]:
# Fix the random seeds before the model is created so that the run is
# repeatable: training is stochastic, and without a seed every execution of
# the notebook produces different metrics.
SEED = 42
keras.utils.set_random_seed(SEED)

# Instantiate and compile the model
model = create_mnist_model()
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
    # Categorical (not sparse) loss/metric, because the labels are one-hot vectors.
    loss=keras.losses.CategoricalCrossentropy(),
    metrics=[keras.metrics.CategoricalAccuracy()],
)

# Training the Model

In [43]:
# Train for 5 epochs in mini-batches of 128, holding out 10% of the training
# set for validation. The returned object is kept in `history`.
history = model.fit(
    x=x_train,
    y=y_train,
    validation_split=0.1,
    epochs=5,
    batch_size=128,
)

Epoch 1/5
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 5ms/step - categorical_accuracy: 0.7874 - loss: 0.7408 - val_categorical_accuracy: 0.9508 - val_loss: 0.1748
Epoch 2/5
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step - categorical_accuracy: 0.9346 - loss: 0.2250 - val_categorical_accuracy: 0.9658 - val_loss: 0.1261
Epoch 3/5
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step - categorical_accuracy: 0.9521 - loss: 0.1642 - val_categorical_accuracy: 0.9720 - val_loss: 0.1031
Epoch 4/5
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step - categorical_accuracy: 0.9615 - loss: 0.1290 - val_categorical_accuracy: 0.9730 - val_loss: 0.0929
Epoch 5/5
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 1ms/step - categorical_accuracy: 0.9673 - loss: 0.1088 - val_categorical_accuracy: 0.9752 - val_loss: 0.0838


# Evaluate the Model's Performance

In [27]:
# Score the trained model on the held-out test set.
print("\nEvaluating the model...")
loss, accuracy = model.evaluate(x=x_test, y=y_test)

# Report both metrics to four decimal places, one per line.
print(f"Test Loss: {loss:.4f}", f"Test Accuracy: {accuracy:.4f}", sep="\n")


Evaluating the model...
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step - categorical_accuracy: 0.9675 - loss: 0.1062
Test Loss: 0.0938
Test Accuracy: 0.9716


# Inference with the Trained Model

In [46]:
# Pick one test image and run it through the trained model.
num_sample = 15
sample_input = x_test[num_sample]
print(f"Original single sample shape: {sample_input.shape}") # (784,)

# Add a batch dimension at axis 0
sample_input_batch = np.expand_dims(sample_input, axis=0)
print(f"Shape after adding batch dimension: {sample_input_batch.shape}\n") # (1, 784)

print("Prediction...")
prediction = model.predict(sample_input_batch) # Pass the batch-ready input
print(f"\nRaw prediction output:\n{prediction}") # This will be (1, 10)

# Decode both the prediction and the one-hot label with argmax. This avoids
# the float-equality search (`== 1.0`) that raises IndexError when no entry
# matches exactly, and keeps the two decodings consistent.
print(f"\nPrediction for a sample input: {np.argmax(prediction[0])}")
print(f"True Label for sample input: {np.argmax(y_test[num_sample])} ")

Original single sample shape: (784,)
Shape after adding batch dimension: (1, 784)

Prediction...
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step

Raw prediction output:
[[7.7806098e-06 1.8786210e-04 6.5804066e-05 2.6002495e-02 6.9100747e-06
  9.7333771e-01 6.3755738e-06 9.0677504e-06 2.2939921e-04 1.4657571e-04]]

Prediction for a sample input: 5
True Label for sample input: 5 
