# MNIST digits recognition
## Feedforward network with 2 hidden layers
Uses [Keras](https://keras.io/) with [JAX](https://github.com/jax-ml/jax) as the neural-network backend.

In [None]:
# Install JAX for the Google Colab TPU runtime (the "[tpu]" extra).
# NOTE(review): on a CPU/GPU runtime this extra is presumably the wrong one
# (plain "jax" or a CUDA-specific extra instead) — confirm against the JAX
# installation guide for the runtime in use.
!pip install "jax[tpu]"

# Install Keras and companion libraries.
# NOTE(review): keras-cv and keras-hub are not imported by any cell shown
# here; they look removable — confirm nothing else relies on them first.
!pip install keras-cv
!pip install keras-hub
!pip install keras

In [None]:
# Import installed software and put things in place
import os
os.environ["KERAS_BACKEND"] = "jax"
import matplotlib.pyplot as plt
import numpy as np
from keras import layers
from keras import ops
import keras
# print(keras.__version__)
# print(keras.backend.backend())

# Feedforward network
## Describe the topology (connectivity graph) of the model
- Input layer of dimension 784 = 28×28 (one flattened MNIST image).
- “Dense” (fully connected) hidden layer 1 of dimension 64.
- “Dense” (fully connected) hidden layer 2 of dimension 64.
- “Dense” (fully connected) output layer of dimension 10 (one raw score, or logit, per digit class).

The model is then created.
It maps 784-dimensional inputs to 10-dimensional outputs.

**Note:** the MNIST data set has not been loaded yet!
No actual data is used at this point; only the structure of the network is defined.

In [None]:
# Functional-API definition of a 784 -> 64 -> 64 -> 10 feedforward network.
inputs = keras.Input(shape=(784,))  # Input layer shape for flattened MNIST images (28 * 28)
x = layers.Dense(64, activation="relu")(inputs)  # First hidden layer
x = layers.Dense(64, activation="relu")(x)  # Second hidden layer
# No activation here: the output layer emits raw logits (one per digit class).
# That is why the loss is compiled with `from_logits=True`.
outputs = layers.Dense(10)(x)  # Output layer (logits)
model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model")
model.summary()


In [None]:
# Render the model's layer graph to a PNG file (layer names only).
# NOTE: plot_model relies on the `pydot` package and Graphviz being installed.
keras.utils.plot_model(model, to_file="feedforward_model.png")

In [None]:
# Same graph as above, but annotated with each layer's input/output shapes.
keras.utils.plot_model(
    model,
    to_file="feedforward_model_with_shape_info.png",
    show_shapes=True,
)

In [None]:
# Load MNIST as (train, test) pairs of 28x28 images and integer digit labels.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Flatten each image to a 784-vector and scale pixel values from [0, 255]
# to [0, 1]. Using -1 for the sample axis (instead of hard-coding 60000 /
# 10000) keeps this correct if the arrays are ever subset.
x_train = x_train.reshape(-1, 28 * 28).astype("float32") / 255
x_test = x_test.reshape(-1, 28 * 28).astype("float32") / 255

model.compile(
    # Labels are integer class ids (hence the "Sparse" variant). The model
    # outputs logits with no softmax, hence from_logits=True.
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.RMSprop(),
    metrics=["accuracy"],
)

# validation_split=0.2 holds out the last 20% of the training arrays as a
# validation set; it is reported per epoch but never used for weight updates.
# NOTE: re-running this cell continues training the existing `model`; re-run
# the model-definition cell first to start again from fresh weights.
history = model.fit(x_train, y_train, batch_size=64, epochs=20, validation_split=0.2)

# evaluate() returns [loss, *metrics] in the order given to compile().
test_scores = model.evaluate(x_test, y_test, verbose=2)
print("Test loss:", test_scores[0])
print("Test accuracy:", test_scores[1])

In [None]:
# Per-epoch training vs. validation accuracy, taken from the fit() history.
for series in ('accuracy', 'val_accuracy'):
    plt.plot(history.history[series])
plt.title('model accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend(['train', 'val'], loc='upper left')
plt.show()

In [None]:
# Per-epoch training vs. validation loss, taken from the fit() history.
for series in ('loss', 'val_loss'):
    plt.plot(history.history[series])
plt.title('model loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend(['train', 'val'], loc='upper left')
plt.show()


In [None]:
# Show a few random test images with their true and predicted labels.
num_samples = 5
indices = np.random.choice(len(x_test), num_samples, replace=False)

# Predict all selected samples in one batched call rather than one
# model.predict() per image inside the plotting loop. The model outputs
# logits, so argmax over the class axis gives the predicted digit.
logits = model.predict(x_test[indices], verbose=0)
pred_labels = np.argmax(logits, axis=1)

plt.figure(figsize=(10, 2))
for i, (idx, pred_label) in enumerate(zip(indices, pred_labels)):
    # x_test holds flattened 784-vectors; reshape(28, 28) restores the image
    # (and is a no-op reshape if the data were already 28x28).
    img = x_test[idx].reshape(28, 28)
    plt.subplot(1, num_samples, i + 1)
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.title(f"True: {y_test[idx]}\nPred: {pred_label}")
plt.tight_layout()
plt.show()