# MNIST digits recognition
## Feedforward network with 2 hidden layers
Uses [Keras](https://keras.io/) with [JAX](https://github.com/jax-ml/jax) as the neural-network backend.

In [None]:
# Install JAX with support for Google Colab's Tensor Processing Units (TPUs)
!pip install "jax[tpu]"

# Install Keras
!pip install keras-cv
!pip install keras-hub
!pip install keras

In [None]:
# Import installed software and put things in place
import os
os.environ["KERAS_BACKEND"] = "jax"
import matplotlib.pyplot as plt
import numpy as np
from keras import layers
from keras import ops
import keras
# print(keras.__version__)
# print(keras.backend.backend())

# Model `model` below is a feedforward network

- Input layer of dimension `784` (each MNIST digit is a 28-by-28 image, which has `28*28 = 784` pixels when “flattened”).
- Dense (“fully connected”) hidden layer 1 of dimension `64`, with ReLU activation.
- Dense (“fully connected”) hidden layer 2 of dimension `64`, with ReLU activation.
- Dense (“fully connected”) output layer of dimension `10`, with no activation (its 10 outputs are raw scores, or “logits”, one per digit).

The `model` is then created as a Keras model object.
It will process `784`-dimensional inputs and generate `10`-dimensional outputs.

- Each of the 64 units of hidden layer 1 has 784 linear coefficients (weights) plus 1 bias = 785 parameters, for a total of 64*785 = `50,240` parameters.
- Each of the 64 units of hidden layer 2 has 64 linear coefficients plus 1 bias = 65 parameters, for a total of 64*65 = `4,160` parameters.
- Each of the 10 units of the output layer has 64 linear coefficients plus 1 bias = 65 parameters, for a total of 10*65 = `650` parameters.

Therefore, `model` has `50,240 + 4,160 + 650 = 55,050` total parameters.
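The parameter count above follows one rule: a Dense layer mapping `n_in` inputs to `n_out` outputs has `(n_in + 1) * n_out` parameters. A minimal sketch of that arithmetic in plain Python (the helper name `dense_param_counts` is hypothetical, not part of Keras); the total should match the “Total params” line printed by `model.summary()` below.

```python
def dense_param_counts(layer_dims):
    """Return per-layer parameter counts for consecutive Dense layers.

    layer_dims lists the input dimension followed by each layer's output
    dimension, e.g. [784, 64, 64, 10]. Each layer has an (n_in x n_out)
    weight matrix plus n_out biases, i.e. (n_in + 1) * n_out parameters.
    """
    return [(n_in + 1) * n_out for n_in, n_out in zip(layer_dims[:-1], layer_dims[1:])]

counts = dense_param_counts([784, 64, 64, 10])
print(counts)       # [50240, 4160, 650]
print(sum(counts))  # 55050
```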

**Note:** The MNIST data set has *not* yet been loaded at all!
Absolutely *no* data proper has been used at this point.


In [None]:
dim_hl_1 = 64 # @param {type: "integer"}
dim_hl_2 = 64 # @param {type: "integer"}

inputs = keras.Input(shape=(784,))  # Input layer shape for flattened MNIST images
hidden1 = layers.Dense(dim_hl_1, activation="relu")(inputs)  # First hidden layer
hidden2 = layers.Dense(dim_hl_2, activation="relu")(hidden1)  # Second hidden layer
# Output layer must have dimension 10 = number of categories.
outputs = layers.Dense(10)(hidden2)  # No ReLU here b/c outputs are "logits".
model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model")
model.summary()


In [None]:
keras.utils.plot_model(model, "feedforward_model.png")

In [None]:
keras.utils.plot_model(model, "feedforward_model_with_shape_info.png", show_shapes=True)

# Load MNIST dataset

Each MNIST image of a handwritten digit 0, 1, 2, …, 8, 9 is encoded as a 28-by-28 matrix whose entries are 8-bit unsigned integers (i.e., each entry is an integer in the range from `0` to `2^8 - 1 = 255`).
Each integer represents the grayscale value of a pixel (0 = black, 255 = white).

For computational efficiency and numerical stability, each grayscale value is rescaled (divided by 255) to a floating-point number in the closed interval `[0, 1]`, stored as a `float32` (32-bit floating-point number), and each 28-by-28 matrix is “flattened” into a single vector of `28*28 = 784` such entries.

Since the training set consists of 60,000 such digit images, each flattened to a vector of 784 entries and stored as one *row*, it gives a large matrix `x_train` of size `60,000 × 784`; similarly, the test set of 10,000 images gives a `10,000 × 784` matrix `x_test`.
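To see the reshaping and rescaling without downloading anything, here is a small sketch on a *synthetic* stand-in for MNIST (random 8-bit pixels, an assumption made purely to illustrate shapes and value ranges). It applies the same `reshape(...).astype("float32") / 255` recipe as the next cell.

```python
import numpy as np

# Synthetic stand-in for MNIST: 3 "images" of 28x28 unsigned 8-bit pixels.
rng = np.random.default_rng(0)
images = rng.integers(0, 255, size=(3, 28, 28), dtype=np.uint8, endpoint=True)
images[0, 0, 0] = 255  # make sure the maximum value 255 occurs

# Flatten each 28x28 image into one row of 784 entries and rescale to [0, 1].
x = images.reshape(len(images), 28 * 28).astype("float32") / 255

print(x.shape, x.dtype)         # (3, 784) float32
print(x.min() >= 0.0, x.max())  # True 1.0  (255 maps to exactly 1.0)
```

Row `x[0]` is image 0 read row by row; `x[0].reshape(28, 28)` recovers the picture, which is what the prediction cell at the end does before plotting.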


In [None]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train = x_train.reshape(60000, 784).astype("float32") / 255
x_test = x_test.reshape(10000, 784).astype("float32") / 255

# Compile and train the model

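The model's parameters are initialized randomly when its layers are built (already in the model-building cell above, before the model is “compiled”). By default, each `Dense` layer draws its weight matrix uniformly at random from a symmetric interval around zero (the “Glorot uniform” initializer: zero mean, with a variance determined by the input and output dimensions of the layer, but otherwise not of immediate interest to us right now), and sets its biases to zero.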
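For the curious, a minimal NumPy sketch of the Glorot (Xavier) uniform formula that Keras's default `glorot_uniform` kernel initializer is named after: weights are drawn from `[-limit, limit]` with `limit = sqrt(6 / (fan_in + fan_out))`, which gives variance `2 / (fan_in + fan_out)`. This reproduces the formula only, not Keras's own code; the helper name is hypothetical.

```python
import numpy as np

def glorot_uniform(fan_in, fan_out, rng):
    """Sample a (fan_in, fan_out) weight matrix uniformly from [-limit, limit],
    with limit = sqrt(6 / (fan_in + fan_out))."""
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=(fan_in, fan_out))

rng = np.random.default_rng(0)
W1 = glorot_uniform(784, 64, rng)  # same shape as hidden layer 1's weight matrix
b1 = np.zeros(64)                  # biases start at zero

# Uniform on [-a, a] has variance a^2 / 3 = 2 / (fan_in + fan_out).
print(W1.shape)                  # (784, 64)
print(W1.var(), 2 / (784 + 64))  # empirical vs. theoretical variance (close)
```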

When `model` is applied to an input of size `m × 784`, it is applied to *each* of the `m` rows separately, because `model` accepts as input a vector of size 784 and outputs one of size 10;
in other words, `model` treats those `m` rows as completely independent of each other (in practice all rows are processed at once by matrix multiplications, which gives the same result).
The result is a matrix of size `m × 10`.
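A small NumPy sketch of this row-independence: a 784 → 64 → 64 → 10 ReLU network with *random* weights (illustrative values only, not the trained `model`) gives the same `m × 10` output whether it is applied to the whole matrix at once or to one row at a time.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(z):
    return np.maximum(z, 0.0)

# Random weights with the same shapes as model's three Dense layers.
W1, b1 = rng.normal(size=(784, 64)) * 0.05, np.zeros(64)
W2, b2 = rng.normal(size=(64, 64)) * 0.1, np.zeros(64)
W3, b3 = rng.normal(size=(64, 10)) * 0.1, np.zeros(10)

def forward(x):
    """Apply the network to x of shape (m, 784) or (784,); returns logits."""
    h1 = relu(x @ W1 + b1)
    h2 = relu(h1 @ W2 + b2)
    return h2 @ W3 + b3  # output layer: no activation

m = 5
X = rng.random((m, 784))                        # m inputs, one per row
batched = forward(X)                            # all rows at once
row_by_row = np.stack([forward(x) for x in X])  # one row at a time
print(batched.shape)                            # (5, 10)
print(np.allclose(batched, row_by_row))         # True
```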

`batch_size=64` means that each “training step” of the model uses *only* a batch of `m = 64` inputs (i.e., 64 rows of `x_train`) at a time: the rows are shuffled at the start of each pass and then taken 64 at a time until every row has been used once.
Such a complete pass is called a (single) “epoch”; over all 60,000 rows it would consist of `⌈60,000/64⌉ = 938` steps (937 full batches plus a final batch of 32 rows).
The process is restarted (reshuffling and bringing back all the rows of `x_train`) and carried out for a total number of epochs `num_epochs = 20`.
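The batching in one epoch can be sketched as follows: shuffle all row indices, then cut the shuffled order into consecutive chunks of 64. This mirrors what `model.fit` does with its default `shuffle=True`; the generator `epoch_batches` is a hypothetical illustration, not a Keras function.

```python
import numpy as np

def epoch_batches(n_rows, batch_size, rng):
    """Yield index arrays for one epoch: shuffle all row indices, then cut
    the shuffled order into consecutive batches (the last may be smaller)."""
    order = rng.permutation(n_rows)
    for start in range(0, n_rows, batch_size):
        yield order[start:start + batch_size]

rng = np.random.default_rng(0)
batches = list(epoch_batches(60_000, 64, rng))
print(len(batches), len(batches[-1]))  # 938 32

# Every row index appears exactly once per epoch.
seen = np.concatenate(batches)
print(np.array_equal(np.sort(seen), np.arange(60_000)))  # True
```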

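(Actually, `validation_split=0.2` means that the last 20% of `x_train`, i.e., 12,000 images, is set aside before any shuffling and never used for training; the same 12,000 images are used in every epoch only to compute the validation loss and accuracy (`val_loss`, `val_accuracy`). An epoch therefore consists of `48,000/64 = 750` batches of exactly 64 training images each.)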
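The split arithmetic as a short sketch (variable names are illustrative; this follows the description above rather than Keras's internal code): the held-out rows are the *last* 20% of the arrays, and only the remaining 48,000 rows are batched.

```python
import math
import numpy as np

n_train, batch_size, validation_split = 60_000, 64, 0.2

# Hold out the last 20% of the rows (before any shuffling) for validation.
n_val = round(n_train * validation_split)
rows = np.arange(n_train)
train_idx, val_idx = rows[:n_train - n_val], rows[n_train - n_val:]

steps_per_epoch = math.ceil(len(train_idx) / batch_size)
print(len(train_idx), len(val_idx))  # 48000 12000
print(steps_per_epoch)               # 750
```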

Keras has various built-in optimization algorithms; this model's algorithm is the *root mean-square propagation* or `RMSprop`.
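The idea of RMSprop is to keep a running average `v` of each parameter's squared gradient and divide the step by `sqrt(v)`, so that every parameter takes steps of roughly the learning rate in size. A textbook sketch on a one-parameter toy problem, using the same default hyperparameters as `keras.optimizers.RMSprop()` (`learning_rate=0.001`, `rho=0.9`, `epsilon=1e-7`). Keras's own implementation also offers momentum and a centered variant and may place `epsilon` slightly differently; the helper below is hypothetical.

```python
import numpy as np

def rmsprop_step(w, grad, v, lr=0.001, rho=0.9, eps=1e-7):
    """One RMSprop update: update the running mean of squared gradients,
    then step against the gradient scaled by 1 / sqrt(v)."""
    v = rho * v + (1 - rho) * grad**2
    w = w - lr * grad / (np.sqrt(v) + eps)
    return w, v

# Toy problem: minimize f(w) = (w - 3)^2, whose gradient is 2 (w - 3).
w, v = 0.0, 0.0
for _ in range(2000):
    w, v = rmsprop_step(w, 2 * (w - 3), v, lr=0.01)
print(w)  # close to 3
```

With a fixed learning rate the iterate ends up hovering very close to the minimum rather than landing on it exactly, since the normalized step size stays near `lr`.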

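How well the model fits the data is quantified using a “loss” function. A larger loss means worse predictions (which typically goes with lower accuracy; accuracy itself is tracked separately as a metric), so the goal as the model is trained is *reducing* or “minimizing” the loss.
(The loss function used in this implementation is *sparse categorical cross-entropy*: “sparse” because the labels `y_train` are integers 0, 1, …, 9 rather than one-hot vectors, and `from_logits=True` because the model outputs logits, which the loss converts to probabilities with a softmax. Entropy and related quantities will be introduced later in the semester.)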
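A NumPy sketch of what this loss computes: for each example, minus the log of the softmax probability assigned to the correct label, averaged over the batch. It uses toy sizes (2 examples, 3 classes instead of 10) and a hypothetical helper name, and checks that integer labels give the same value as ordinary categorical cross-entropy with one-hot targets.

```python
import numpy as np

def sparse_ce_from_logits(logits, labels):
    """Mean sparse categorical cross-entropy for integer labels.

    logits: (m, k) raw model outputs; labels: (m,) integers in 0..k-1.
    Per-row loss = -log softmax(logits)[label], via a stable log-softmax.
    """
    shifted = logits - logits.max(axis=1, keepdims=True)  # avoid overflow in exp
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
labels = np.array([0, 2])

# Same value as categorical cross-entropy with one-hot targets.
one_hot = np.eye(3)[labels]
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
dense_ce = -(one_hot * np.log(probs)).sum(axis=1).mean()
print(np.isclose(sparse_ce_from_logits(logits, labels), dense_ce))  # True
```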

In [None]:
num_epochs = 20  # @param {type: "integer"}

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.RMSprop(),
    metrics=["accuracy"],
)

history = model.fit(x_train, y_train, batch_size=64, epochs=num_epochs, validation_split=0.2)

test_scores = model.evaluate(x_test, y_test, verbose=2)
print("Test loss:", test_scores[0])
print("Test accuracy:", test_scores[1])

In [None]:
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()

In [None]:
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()


The code below randomly chooses 5 of the 10,000 images in the test set and applies `model` to each one to assign a label (0, 1, …, 9): the predicted label is the digit whose logit (equivalently, whose softmax probability) is largest.
Very few images in the test set should be misclassified, so the `True` and `Pred`icted labels should usually agree.
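Why taking `np.argmax` of the raw logits is enough: softmax is strictly increasing in each logit, so the largest logit always corresponds to the largest probability. A quick check with made-up logits for the digits 0 to 9 (illustrative numbers, not output of the trained model):

```python
import numpy as np

def softmax(z):
    """Convert a vector of logits into probabilities that sum to 1."""
    e = np.exp(z - z.max())
    return e / e.sum()

logits = np.array([-1.2, 0.3, 4.1, 0.0, 2.2, -0.5, 1.0, 0.7, -2.0, 3.9])
probs = softmax(logits)
print(np.argmax(logits), np.argmax(probs))  # 2 2
```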

In [None]:
# Select a few random indices from the test set
num_samples = 5
indices = np.random.choice(len(x_test), num_samples, replace=False)

plt.figure(figsize=(10, 2))
for i, idx in enumerate(indices):
    img = x_test[idx]
    # If images are flattened, reshape to 28x28
    if img.shape[-1] != 28:
        img = img.reshape(28, 28)
    plt.subplot(1, num_samples, i + 1)
    plt.imshow(img, cmap='gray')
    plt.axis('off')
    # Predict label and compare with true label.
    pred = model.predict(np.expand_dims(x_test[idx], axis=0), verbose=0)
    pred_label = np.argmax(pred, axis=1)[0]
    true_label = y_test[idx]
    plt.title(f"True: {true_label}\nPred: {pred_label}")
plt.tight_layout()
plt.show()