# Handwriting recognition


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import jax.numpy as jnp
import jax

from flax import linen as nn
from flax.training import train_state
import optax

### Data import and visualization


Import the MNIST train dataset ([https://en.wikipedia.org/wiki/MNIST_database](https://en.wikipedia.org/wiki/MNIST_database))


In [None]:
# Load the small MNIST training set from CSV into a NumPy array.
# The file is contained in the sample data directory of Google Colab online
# runtimes; the relative path below expects it in the current working directory.
data = np.genfromtxt("./mnist_train_small.csv", delimiter=",")
# Last expression of the cell, so the notebook displays the array shape.
data.shape

Store the data in a 4th-order tensor (samples, x-pixel, y-pixel, channels) and the labels in a vector.
**NOTE:** The labels are the first column of the data matrix


In [None]:
# FILL HERE

Visualize the first 30 pictures with the corresponding labels


In [None]:
# FILL HERE

Create a [one-hot](https://en.wikipedia.org/wiki/One-hot) representation of the labels, that is, a matrix where each row corresponds to a sample and each column corresponds to a class (i.e. a digit).
The entries of the matrix are 1 if the sample corresponds to that digit, 0 otherwise.


In [None]:
# FILL HERE

Check that the matrix has exactly one element "1" in each row.


In [None]:
# FILL HERE

### ANN training


Define the architecture of the neural network.
For more details on CNNs see https://stanford.edu/~shervine/teaching/cs-230/cheatsheet-convolutional-neural-networks


In [None]:
class CNN(nn.Module):
    """Convolutional classifier that maps a batch of images to 10 class logits.

    Architecture: two (3x3 convolution -> ReLU -> 2x2 average pooling) stages
    with 32 and 64 features, a flatten, a 256-unit dense layer with ReLU and a
    final 10-unit dense layer.
    """

    @nn.compact
    def __call__(self, x):
        """Return the unnormalized logits for the input batch `x`."""
        # Convolutional stages, created in the same order as they are applied.
        for n_features in (32, 64):
            x = nn.Conv(features=n_features, kernel_size=(3, 3))(x)
            x = nn.avg_pool(nn.relu(x), window_shape=(2, 2), strides=(2, 2))

        # Flatten everything but the leading (batch) axis.
        n_samples = x.shape[0]
        hidden = nn.relu(nn.Dense(features=256)(x.reshape((n_samples, -1))))

        # No softmax is applied on purpose: `optax.softmax_cross_entropy` (and
        # its `softmax_cross_entropy_with_integer_labels` variant, which takes
        # integer target labels) expects unnormalized logits. Applying a
        # softmax first would turn the logits into probabilities and could make
        # the loss numerically unstable and incorrect; Optax/JAX handles the
        # softmax internally in a stable way (logsumexp trick).
        return nn.Dense(features=10)(hidden)  # There are 10 classes in MNIST


# Instantiate the network.
cnn = CNN()

# Build a textual summary of the model from a dummy batch that holds a single
# 28x28 image with one channel; `console_kwargs` sets the width of the table.
table = cnn.tabulate(
    jax.random.PRNGKey(0), jnp.zeros((1, 28, 28, 1)), console_kwargs={"width": 200}
)

print(table)

Write a function that computes the **cross entropy loss** and the **accuracy** of the model, given as parameters the unnormalized logits (the output of the CNN) and the one-hot encoded target values. To compute the loss you can use `optax.softmax_cross_entropy`; check its documentation for the details.


In [None]:
def compute_metrics(logits, labels_onehot):
    """Return the cross entropy loss and the accuracy of a set of predictions.

    Exercise stub: the body must define `loss` and `accuracy` before the
    `return` statement.

    Args:
        logits: unnormalized logits, i.e. the output of the CNN.
        labels_onehot: one-hot encoded target labels.

    Returns:
        A dict with the keys "loss" and "accuracy".
    """
# FILL HERE
    return {"loss": loss, "accuracy": accuracy}

Here we define functions used for the training


In [None]:
@jax.jit
def loss_fn(params, x, y):
    """Return the loss and the logits (jit-compiled).

    Exercise stub: the body must define `loss` and `logits` before the
    `return` statement.

    Args:
        params: model parameters.
        x: input of the CNN (presumably a batch of images -- confirm once
            the body is implemented).
        y: targets (presumably one-hot encoded labels -- confirm once the
            body is implemented).

    Returns:
        The tuple `(loss, logits)`.
    """
# FILL HERE
    return loss, logits


@jax.jit
def train_step(state, x, y):
    """Run one training step and return the state with its metrics (jit-compiled).

    Exercise stub: the body must produce the updated `state` and define
    `metrics` before the `return` statement.

    Args:
        state: training state (presumably a Flax `TrainState` -- confirm once
            the body is implemented).
        x: inputs of the step (presumably a batch of images -- confirm).
        y: targets of the step (presumably one-hot encoded labels -- confirm).

    Returns:
        The tuple `(state, metrics)`.
    """
# FILL HERE
    return state, metrics


def eval_model(state, dataset):
    """Return the loss and the accuracy of the model on `dataset`.

    Exercise stub: the body must define `metrics`, a mapping that provides
    the keys "loss" and "accuracy".

    Args:
        state: training state holding the model to evaluate.
        dataset: data to evaluate on (presumably a dict with "image" and
            "label" entries, like `train_ds` -- confirm once implemented).

    Returns:
        The tuple `(metrics["loss"], metrics["accuracy"])`.
    """
# FILL HERE
    return metrics["loss"], metrics["accuracy"]

Define the function for one training epoch


In [None]:
def train_epoch(state, train_ds, batch_size, epoch, rng):
    """Train for one epoch and return the state with the metrics of the epoch.

    Exercise stub: the body must produce the updated `state` and define
    `training_epoch_metrics` before the `return` statement.

    Args:
        state: training state at the beginning of the epoch.
        train_ds: training data (presumably a dict with "image" and "label"
            entries -- confirm once the body is implemented).
        batch_size: number of samples per batch.
        epoch: index of the current epoch.
        rng: JAX PRNG key (presumably used to shuffle the samples -- confirm).

    Returns:
        The tuple `(state, training_epoch_metrics)`.
    """
# FILL HERE

    return state, training_epoch_metrics

Prepare data by randomizing and creating the train-validation split


In [None]:
# Seed NumPy's global random generator so that the split is reproducible.
np.random.seed(0)

# FILL HERE
# NOTE: the dictionaries below read `train_idxs` and `valid_idx` (mind the
# different suffixes) as well as `x_data` and `labels_onehot`; all four names
# must be defined before this point.

# Training split: images and one-hot labels (as float32) stored as JAX arrays.
train_ds = {
    "image": jnp.array(x_data[train_idxs]),
    "label": jnp.array(labels_onehot[train_idxs], dtype=jnp.float32),
}
# Validation split, with the same layout as the training split.
valid_ds = {
    "image": jnp.array(x_data[valid_idx]),
    "label": jnp.array(labels_onehot[valid_idx], dtype=jnp.float32),
}

Run training


In [None]:
# Root PRNG key; `init_rng` is split off from it (presumably for the parameter
# initialization -- confirm once the cell is completed).
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

# Initialize lists to store metrics for graph visualization
training_losses = []
training_accuracies = []
valid_losses = []
valid_accuracies = []

# Hyperparameters
num_epochs = 10
batch_size = 64
learning_rate = 0.001

# Initialize the CNN and optimizer
# NOTE: later cells read the trained model from a global named `state`
# (through `state.apply_fn` and `state.params`).
# FILL HERE

# Header of the table with the per-epoch metrics.
print("epoch | train loss | train acc | valid loss | valid acc")
# Epochs are numbered starting from 1.
for epoch in range(1, num_epochs + 1):
# FILL HERE

## Testing


Load the dataset `sample_data/mnist_test.csv` and normalize and shape the data.


In [None]:
# Load the MNIST test set: the first column holds the labels, the remaining
# columns hold the pixels of the 28x28 images.
data_test = np.genfromtxt("./mnist_test.csv", delimiter=",")
# A bare `data_test.shape` in the middle of a cell displays nothing: print it.
print(data_test.shape)
labels_test = data_test[:, 0]
# Rescale the pixel intensities by 255 (the images stay flattened here; they
# are reshaped only when stored in `test_ds`).
x_test = data_test[:, 1:] / 255

# One-hot encoding: entry (n, d) is 1.0 exactly when sample n has label d.
labels_onehot_test = (labels_test[:, None] == np.arange(10)).astype(np.float64)


# Same layout and label dtype as the training and validation dictionaries.
test_ds = {
    "image": jnp.array(x_test.reshape((-1, 28, 28, 1))),
    "label": jnp.array(labels_onehot_test, dtype=jnp.float32),
}

Compute the accuracy of the classifier on this dataset.


In [None]:
# FILL HERE

Use the following script to visualize the predictions on a bunch of test images.


In [None]:
# Index of the first test image to display and number of images to display.
offset = 0
n_images = 40

# Number of images in each row of the figure.
images_per_row = 10
# Model output for the whole test set. These are presumably the unnormalized
# logits of the CNN: the softmax is applied only when the bars are drawn.
y_predicted = state.apply_fn({"params": state.params}, test_ds["image"])


def draw_bars(ax, y_predicted, label):
    """Draw the ten class scores as a bar chart and highlight the prediction.

    The bar of the class with the highest score is colored green when that
    class equals `label`, red otherwise.

    Args:
        ax: axes to draw on.
        y_predicted: the ten class scores of one sample.
        label: true class of the sample.
    """
    bars = ax.bar(range(10), y_predicted)
    ax.set_ylim([0, 1])
    ax.set_xticks(range(10))

    # The predicted class is the one with the highest score.
    predicted = np.argmax(y_predicted)
    bars[predicted].set_color("green" if label == predicted else "red")


import math

# Every group of `images_per_row` images takes two rows of subplots: the
# pictures on top and the bar charts of the predicted probabilities below.
n_rows = 2 * math.ceil(n_images / images_per_row)
# `squeeze=False` keeps `axs` two-dimensional even with a single column.
_, axs = plt.subplots(
    n_rows, images_per_row, figsize=(3 * images_per_row, 3 * n_rows), squeeze=False
)
for i in range(n_images):
    row, col = divmod(i, images_per_row)
    sample = offset + i  # position of the image inside the test set

    axs[2 * row, col].imshow(x_test[sample].reshape((28, 28)), cmap="gray")
    axs[2 * row, col].set_title(int(labels_test[sample]))
    axs[2 * row, col].axis("off")

    # `y_predicted` covers the whole test set, so it has to be indexed with the
    # same offset as the image and the label (indexing it with `i` alone shows
    # the wrong prediction as soon as `offset` is not zero).
    draw_bars(
        axs[2 * row + 1, col], jax.nn.softmax(y_predicted[sample]), labels_test[sample]
    )

# Hide the subplots left over when `n_images` is not a multiple of
# `images_per_row`.
for i in range(n_images, (n_rows // 2) * images_per_row):
    row, col = divmod(i, images_per_row)
    axs[2 * row, col].axis("off")
    axs[2 * row + 1, col].axis("off")

# Adversarial attacks

You have trained your classifier. Cool, isn't it? Let us now try to fool it.

An adversarial attack consists of an (almost imperceptible) modification of the image, aimed at fooling the classifier into making a mistake.
See e.g. [this article](https://www.wired.com/story/tesla-speed-up-adversarial-example-mgm-breach-ransomware/)

To hack the classifier, compute the gradient of the cross-entropy loss function with respect to the input (not the parameters!). Then, superimpose a multiple of the sign of the gradient onto the original image. See e.g. [this article](https://www.tensorflow.org/tutorials/generative/adversarial_fgsm).

Namely, follow these steps:

1. Compute the gradient of `loss_fn` with respect to the **input of the CNN (the image)**
2. Compute the perturbed image by adding $\epsilon \, \text{sign}(\nabla_x \texttt{loss\_fn})$ to the original image
3. Clip the result in $[0, 1]$

Visualize the original and the hacked images and the corresponding prediction of the classifier.


In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt


# FGSM attack
def fgsm_attack(params, image, label, epsilon):
    """Return an adversarially perturbed version of `image` (exercise stub).

    Following the steps listed in the text of the notebook, the body should:
      1. compute the gradient of `loss_fn` with respect to the input image
         (not the parameters);
      2. add `epsilon * sign(gradient)` to the image;
      3. clip the result to [0, 1];
    and store the outcome in `adv_image`.

    Args:
        params: parameters of the trained model.
        image: image to perturb.
        label: target associated with the image (presumably one-hot encoded,
            as required by the loss -- confirm once implemented).
        epsilon: magnitude of the perturbation.

    Returns:
        The perturbed image `adv_image`.
    """
# FILL HERE
    return adv_image


# Found by trial and error: for the images selected below, `epsilon = 0.05`
# is large enough to fool the CNN.
epsilon = 0.05
# Indices of the images to attack (presumably positions in the test set --
# confirm once the loop body is implemented).
for idx in [11, 66, 115, 244]:
# FILL HERE