# Handwriting recognition


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax

from flax import linen as nn
from flax.training import train_state
import optax

### Data import and visualization


Import the MNIST train dataset ([https://en.wikipedia.org/wiki/MNIST_database](https://en.wikipedia.org/wiki/MNIST_database))


In [None]:
# This dataset is contained in the sample data directory of Google Colab online runtimes.
# The comma-separated file is parsed into a 2-D floating point array (one CSV row per array row).
data = np.genfromtxt(fname="./mnist_train_small.csv", delimiter=",")
data.shape

Store the data in a 4-th order tensor (samples, x-pixel, y-pixel, channels) and the labels in a vector.
**NOTE:** The labels are the first column of the data matrix


In [None]:
# SOLUTION-BEGIN
# Column 0 holds the digit label; every remaining column is a pixel value.
labels = data[:, 0]
# Turn each row of pixels into a 28x28 single-channel image and rescale by 1/255
# (presumably the raw pixel values lie in 0..255, giving intensities in [0, 1]).
x_data = np.reshape(data[:, 1:], (-1, 28, 28, 1)) / 255
labels.shape, x_data.shape
# SOLUTION-END

Visualize the first 30 pictures with the corresponding labels


In [None]:
# SOLUTION-BEGIN
# 3 x 10 grid of axes, flattened so that it can be walked in reading order
fig, axs = plt.subplots(nrows=3, ncols=10, figsize=(20, 6))
axs = axs.ravel()
# Pair each axis with the corresponding image and label of the first 30 samples
for ax, image_i, label_i in zip(axs, x_data[:30], labels[:30]):
    ax.imshow(image_i, cmap="gray")
    ax.set_title(int(label_i))
    ax.axis("off")
# SOLUTION-END

Create a [one-hot](https://en.wikipedia.org/wiki/One-hot) representation of the labels, that is a matrix where each row corresponds to a sample and each column corresponds to a class (i.e. a digit).
The entries of the matrix are 1 if the sample corresponds to that digit, 0 otherwise.


In [None]:
# SOLUTION-BEGIN
# Compare every label against the digits 0..9 through broadcasting:
# entry (s, d) is 1 when sample s carries label d and 0 otherwise.
# The number of rows follows `labels`, so the cell works for any dataset size
# (a label outside 0..9 yields an all-zero row).
labels_onehot = (labels[:, None] == np.arange(10)).astype(np.float64)
# SOLUTION-END

Check that the matrix has exactly one element "1" in each row.


In [None]:
# SOLUTION-BEGIN
# With 0/1 entries, a row holds exactly one "1" if and only if its sum is 1,
# so both the smallest and the largest row sum are expected to equal 1.
row_sums = labels_onehot.sum(axis=1)
np.min(row_sums), np.max(row_sums)
# SOLUTION-END

### ANN training


Define the architecture of the neural network.
For more details on CNNs see https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks


In [None]:
class CNN(nn.Module):
    """Small convolutional classifier returning one unnormalized logit per digit.

    Two stages of (3x3 convolution -> ReLU -> 2x2 average pooling), with 32 and
    64 feature maps respectively, feed a 256-unit dense layer followed by a
    10-unit output layer.
    """

    @nn.compact
    def __call__(self, x):
        # Convolutional stages; each pooling uses a 2x2 window with stride 2.
        # The layers are created in the same order as before, so the
        # automatically assigned layer names are unchanged.
        for n_features in (32, 64):
            x = nn.Conv(features=n_features, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Flatten: keep the leading axis, merge all the remaining ones
        x = jnp.reshape(x, (x.shape[0], -1))
        x = nn.relu(nn.Dense(features=256)(x))

        # No softmax is applied here on purpose:
        # `softmax_cross_entropy` expects unnormalized logits.
        # (There is also the `softmax_cross_entropy_with_integer_labels`
        # version, which uses integer target labels.)
        # Applying a softmax first would turn the logits into probabilities, and
        # the loss might become numerically unstable and incorrect. Optax/JAX
        # expects to handle the softmax internally in a stable way (using
        # logsumexp tricks).
        return nn.Dense(features=10)(x)  # There are 10 classes in MNIST


cnn = CNN()

# Build a textual summary of the network, using a fixed PRNG key and a dummy
# all-zero batch holding one 28x28 single-channel image.
table = cnn.tabulate(
    jax.random.PRNGKey(0),
    jnp.zeros(shape=(1, 28, 28, 1)),
    console_kwargs=dict(width=200),
)
print(table)

Write a function to compute the **cross entropy loss** and the **accuracy** of the model given as parameters the unnormalized logits (the output of the CNN) and the one-hot encoded target value. To compute the loss you can exploit `optax.softmax_cross_entropy`, check the documentation for the details.


In [None]:
def compute_metrics(logits, labels_onehot):
    """Return the mean cross-entropy loss and the accuracy of a batch.

    Args:
        logits: unnormalized network outputs, one row of class scores per sample.
        labels_onehot: one-hot encoded targets with the same layout as `logits`.

    Returns:
        A dict with the keys "loss" and "accuracy".
    """
    # SOLUTION-BEGIN
    per_sample_loss = optax.softmax_cross_entropy(logits, labels_onehot)
    loss = jnp.mean(per_sample_loss)
    # A sample is correct when the highest-scoring class matches the position of
    # the largest entry of its target row
    predicted_class = jnp.argmax(logits, -1)
    target_class = jnp.argmax(labels_onehot, -1)
    accuracy = jnp.mean(predicted_class == target_class)
    # SOLUTION-END
    return {"loss": loss, "accuracy": accuracy}

Here we define functions used for the training


In [None]:
@jax.jit
def loss_fn(params, x, y):
    """Evaluate the network `cnn` on `x` and return `(loss, logits)`.

    The loss is the mean softmax cross entropy between the logits and the
    targets `y`. The logits are returned as a second output so that callers
    differentiating this function with `has_aux=True` can reuse them.
    """
    # SOLUTION-BEGIN
    logits = cnn.apply(dict(params=params), x)
    per_sample_loss = optax.softmax_cross_entropy(logits=logits, labels=y)
    loss = jnp.mean(per_sample_loss)
    # SOLUTION-END
    return loss, logits


@jax.jit
def train_step(state, x, y):
    """Perform one optimizer update on the batch `(x, y)`.

    Returns the updated train state together with the metrics dict computed by
    `compute_metrics` from the logits obtained before the update.
    """
    # SOLUTION-BEGIN
    # Differentiate with respect to the first argument (the parameters);
    # the logits come back as the auxiliary output of `loss_fn`.
    loss_and_grad = jax.value_and_grad(loss_fn, has_aux=True)
    (_, logits), grads = loss_and_grad(state.params, x, y)
    metrics = compute_metrics(logits, y)
    state = state.apply_gradients(grads=grads)
    # SOLUTION-END
    return state, metrics


def eval_model(state, dataset):
    """Evaluate the current parameters on a whole dataset in a single pass.

    Args:
        state: train state providing `apply_fn` and `params`.
        dataset: dict with the entries "image" and "label".

    Returns:
        The tuple `(loss, accuracy)` computed by `compute_metrics`.
    """
    # SOLUTION-BEGIN
    images, targets = dataset["image"], dataset["label"]
    logits = state.apply_fn({"params": state.params}, images)
    metrics = compute_metrics(logits, targets)
    # SOLUTION-END
    return metrics["loss"], metrics["accuracy"]

Define the function for one training epoch


In [None]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Train for one epoch on shuffled mini-batches.

    Args:
        state: current train state.
        train_ds: dict with the entries "image" and "label".
        batch_size: number of samples per mini-batch.
        epoch: epoch number, only used in the printed progress line.
        rng: PRNG key used to shuffle the samples.

    Returns:
        The updated state and a dict with the batch-averaged metrics.
    """
    # SOLUTION-BEGIN
    n_samples = len(train_ds["image"])
    n_steps = n_samples // batch_size

    # Shuffle the sample indices, drop the samples that do not fill a complete
    # batch, and arrange the remaining ones as one row of indices per step.
    shuffled = jax.random.permutation(rng, n_samples)
    batch_indices = shuffled[: n_steps * batch_size].reshape((n_steps, batch_size))

    batch_metrics = []
    for indices in batch_indices:
        images = train_ds["image"][indices, ...]
        targets = train_ds["label"][indices, ...]
        state, metrics = train_step(state, images, targets)
        batch_metrics.append(metrics)

    # Average every metric over the batches of this epoch
    training_epoch_metrics = {}
    for name in batch_metrics[0]:
        training_epoch_metrics[name] = np.mean([m[name] for m in batch_metrics])

    # No newline at the end: the caller completes the row with further columns
    print(
        f"{epoch:04}  | "
        f"{training_epoch_metrics['loss']:.4e} | "
        f"    {training_epoch_metrics['accuracy'] * 100:.2f} | ",
        end=""
    )

    # SOLUTION-END

    return state, training_epoch_metrics

Prepare data by randomizing and creating the train-validation split


In [None]:
np.random.seed(0)

# SOLUTION-BEGIN
# Shuffle the sample indices, then cut them once: the first 80% are used for
# training, the remaining ones for validation.
train_perc = 0.8
n_samples = x_data.shape[0]
n_train_samples = int(train_perc * n_samples)
perm = np.random.permutation(n_samples)
train_idxs, valid_idx = np.split(perm, [n_train_samples])
# SOLUTION-END

train_ds = dict(
    image=jnp.array(x_data[train_idxs]),
    label=jnp.array(labels_onehot[train_idxs], dtype=jnp.float32),
)
valid_ds = dict(
    image=jnp.array(x_data[valid_idx]),
    label=jnp.array(labels_onehot[valid_idx], dtype=jnp.float32),
)

Run training


In [None]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

# Per-epoch metric histories, kept for graph visualization
training_losses, training_accuracies = [], []
valid_losses, valid_accuracies = [], []

# Hyperparameters
num_epochs = 10
batch_size = 64
learning_rate = 0.001

# Initialize the CNN parameters and bundle them with an Adam optimizer
# SOLUTION-BEGIN
params = cnn.init(init_rng, jnp.ones((1, 28, 28, 1)))["params"]
tx = optax.adam(learning_rate)
state = train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)
# SOLUTION-END

print("epoch | train loss | train acc | valid loss | valid acc")
for epoch in range(1, num_epochs + 1):
    # SOLUTION-BEGIN
    # Use a separate PRNG key to shuffle the training data in every epoch
    rng, input_rng = jax.random.split(rng)
    # Run one full training epoch; it prints the first columns of the row
    state, train_metrics = train_epoch(state, train_ds, batch_size, epoch, input_rng)
    # Evaluate on the validation set after each training epoch and finish the row
    valid_loss, valid_accuracy = eval_model(state, valid_ds)
    print(f"{valid_loss:.4e} | " f"{valid_accuracy * 100:.2f}")
    # Store metrics for graph visualization
    for history, value in (
        (training_losses, train_metrics["loss"]),
        (training_accuracies, train_metrics["accuracy"]),
        (valid_losses, valid_loss),
        (valid_accuracies, valid_accuracy),
    ):
        history.append(value)
    # SOLUTION-END

## Testing


Load the dataset `sample_data/mnist_test.csv` and normalize and shape the data.


In [None]:
data_test = np.genfromtxt(fname="./mnist_test.csv", delimiter=",")
data_test.shape
# Column 0 holds the digit label; the remaining columns are the pixel values,
# rescaled by 1/255. `x_test` keeps one flat row of pixels per sample.
labels_test = data_test[:, 0]
x_test = data_test[:, 1:] / 255

# One-hot targets: entry (s, d) is 1.0 when sample s carries label d, else 0.0
labels_onehot_test = (labels_test[:, None] == np.arange(10)).astype(np.float64)


test_ds = dict(
    image=jnp.array(np.reshape(x_test, (-1, 28, 28, 1))),
    label=jnp.array(labels_onehot_test),
)

Compute the accuracy of the classifier on this dataset.


In [None]:
# SOLUTION-BEGIN
# Loss and accuracy of the trained model on the test set, one value per line
test_loss, test_accuracy = eval_model(state, test_ds)
print(
    f"Loss: {test_loss:.2e}",
    f"Accuracy: {test_accuracy * 100.:.2f}%",
    sep="\n",
)
# SOLUTION-END

Use the following script to visualize the predictions on a bunch of test images.


In [None]:
# Index of the first test image to show and number of images to show
offset, n_images = 0, 40

images_per_row = 10
# Unnormalized logits of the trained network for every test image
y_predicted = state.apply_fn(dict(params=state.params), test_ds["image"])


def draw_bars(ax, y_predicted, label):
    """Draw the ten class scores as a bar chart and highlight the predicted class.

    Args:
        ax: axes to draw on.
        y_predicted: the ten scores of one sample (the y-axis is fixed to [0, 1]).
        label: true class of the sample.

    The bar of the highest-scoring class is colored green when it matches
    `label` and red otherwise.
    """
    bars = ax.bar(range(10), y_predicted)
    ax.set_ylim([0, 1])
    ax.set_xticks(range(10))

    label_predicted = np.argmax(y_predicted)
    bars[label_predicted].set_color("green" if label == label_predicted else "red")


import math

# Every image takes two grid rows: the picture on top, its class scores below
n_rows = 2 * math.ceil(n_images / images_per_row)
# squeeze=False keeps `axs` two-dimensional even for a single-column grid
_, axs = plt.subplots(
    n_rows, images_per_row, figsize=(3 * images_per_row, 3 * n_rows), squeeze=False
)
for i in range(n_images):
    row, col = divmod(i, images_per_row)
    # Use the same sample index for the image, the label and the prediction,
    # so that the three stay aligned when `offset` is not zero.
    sample = offset + i

    axs[2 * row, col].imshow(x_test[sample].reshape((28, 28)), cmap="gray")
    axs[2 * row, col].set_title(int(labels_test[sample]))
    axs[2 * row, col].axis("off")

    # The network returns logits; softmax turns them into scores in [0, 1]
    draw_bars(
        axs[2 * row + 1, col], jax.nn.softmax(y_predicted[sample]), labels_test[sample]
    )

# Adversarial attacks

You have trained your classifier. Cool, isn't it? Let us now try to fool it.

An adversarial attack consists of an (almost imperceptible) modification of the image, aimed at fooling the classifier into making a mistake.
See e.g. [this article](https://www.wired.com/story/tesla-speed-up-adversarial-example-mgm-breach-ransomware/)

To hack the classifier, compute the gradient of the cross-entropy loss function with respect to the input (not to the parameters!). Then, superimpose a multiple of the gradient to the original image. See e.g. [this article](https://www.tensorflow.org/tutorials/generative/adversarial_fgsm).

Namely, follow these steps:

1. Compute the gradient of `loss_fn` with respect to the **input of the CNN (the image)**
2. Compute the perturbed image, by summing $\epsilon \text{ sign}(\nabla_x \texttt{loss\_fn})$
3. Clip the result in $[0, 1]$

Visualize the original and the hacked images and the corresponding prediction of the classifier.


In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt


# FGSM attack
def fgsm_attack(params, image, label, epsilon):
    """Return `image` perturbed by `epsilon` times the sign of the loss gradient.

    The gradient of `loss_fn` is taken with respect to the image (not the
    parameters), and the perturbed image is clipped to [0, 1].
    """
    # SOLUTION-BEGIN
    def loss_wrt_image(img):
        # `params` and `label` are held fixed; only the image is differentiated
        return loss_fn(params, img, label)

    # `loss_fn` also returns the logits, hence has_aux=True; they are not needed here
    gradient, _ = jax.grad(loss_wrt_image, has_aux=True)(image)

    # Move every pixel by epsilon in the direction of its gradient sign,
    # then clip back to the valid pixel range
    adv_image = jnp.clip(image + epsilon * jnp.sign(gradient), 0.0, 1.0)
    # SOLUTION-END
    return adv_image


# By trial-and-error I give you the information that for the following images
# an `epsilon = 0.05` is large enough to fool the CNN
epsilon = 0.05
for idx in [11, 66, 115, 244]:
    # SOLUTION-BEGIN
    # The slice keeps a leading axis of length 1 on the image;
    # the label is the plain one-hot row of the same sample.
    x = test_ds["image"][idx : idx + 1]
    y = test_ds["label"][idx]

    # Prediction on the unmodified image
    logits = cnn.apply({"params": state.params}, x)
    true_pred = jnp.argmax(logits, axis=-1)
    print("Original prediction:", true_pred)

    # Adversarial example and the prediction on it
    x_adv = fgsm_attack(state.params, x, y, epsilon)
    logits_adv = cnn.apply({"params": state.params}, x_adv)
    adv_pred = jnp.argmax(logits_adv, axis=-1)
    print("Adversarial prediction:", adv_pred)

    # Original (left) and adversarial (right) image side by side
    plt.figure(figsize=(6, 3))
    panels = (
        (f"Original ({true_pred})", x),
        (f"Adversarial ({adv_pred})", x_adv),
    )
    for position, (panel_title, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(panel_title)
        plt.imshow(picture[0, :, :, 0], cmap="gray")

    plt.show()
    # SOLUTION-END

In [None]:
def get_first_layer_output(cnn_module, params, input_data):
    """
    Manually applies the first convolutional layer and ReLU.

    Args:
        cnn_module: not used by this function; the layer is rebuilt here.
        params: parameter dict of the whole network; only the entry 'Conv_0'
            is read.
        input_data: input passed to the convolution.

    Returns:
        The ReLU of the convolution output.
    """
    # Rebuild a standalone convolution with 32 features and a 3x3 kernel.
    # 'Conv_0' is the name Flax typically gives by default to the first
    # nn.Conv that has no explicit name.
    first_conv = nn.Conv(features=32, kernel_size=(3, 3), name='Conv_0')

    # Run it with only its own subset of the parameters (weights and biases)
    first_conv_variables = {'params': params['Conv_0']}
    pre_activation = first_conv.apply(first_conv_variables, input_data)

    return nn.relu(pre_activation)

idx = 0
first_layer_output = get_first_layer_output(
    cnn, state.params, test_ds["image"][idx : idx + 1]
)
# A bare expression in the middle of a cell is not displayed, so print the shape
print(first_layer_output.shape)

# Show every feature map (last axis) in one grid figure. The number of maps is
# read from the array instead of being hard-coded, and a single figure avoids
# opening one figure per map (Matplotlib warns when many figures are open).
n_maps = first_layer_output.shape[-1]
maps_per_row = 8
n_map_rows = (n_maps + maps_per_row - 1) // maps_per_row  # ceiling division
_, map_axs = plt.subplots(
    n_map_rows,
    maps_per_row,
    figsize=(2 * maps_per_row, 2 * n_map_rows),
    squeeze=False,
)
for i, ax in enumerate(map_axs.ravel()):
    ax.axis("off")
    # Grid cells beyond the last feature map stay empty
    if i < n_maps:
        ax.imshow(first_layer_output[0, :, :, i])
        ax.set_title(str(i))