# JAX Implementation of MNIST

## Import Modules and Data
I will bring in the data with torchvision. For this module, I'll load the data manually

In [1]:
import torch
import jax
import optax
from torch.utils.data import Dataset
from torchvision import datasets
import jax.numpy as jnp
from jax import random, grad, jit, value_and_grad, vmap

In [2]:
training_data = datasets.MNIST(root = './data/mnist', train=True, download=True)
test_data = datasets.MNIST(root = './data/mnist', train=False, download=True)
train_images_np = training_data.data.numpy()
train_labels_np = training_data.targets.numpy()
test_images_np = test_data.data.numpy()
test_labels_np = test_data.targets.numpy()

In [3]:
train_images_jnp = jnp.array(train_images_np)
train_labels_jnp = jnp.array(train_labels_np)

test_images_jnp = jnp.array(test_images_np)
test_labels_jnp = jnp.array(test_labels_np)

In [4]:
print(f"train_images_jnp shape: {train_images_jnp.shape}")
print(f"train_labels_jnp shape: {train_labels_jnp.shape}")
print(f"test_images_jn shape: {test_images_jnp.shape}")
print(f"test_labels_jnp shape: {test_labels_jnp.shape}")

train_images_jnp shape: (60000, 28, 28)
train_labels_jnp shape: (60000,)
test_images_jn shape: (10000, 28, 28)
test_labels_jnp shape: (10000,)


## Preprocess Data
- Normalization: Bring range to [0,1]
We will skip standardization, which is more for pre-trained models
- Flatten: instead of 28*28 images, we are using (784,) vectors

In [5]:
train_images_jnp = train_images_jnp.astype(jnp.float32) / 255.0
test_images_jnp = test_images_jnp.astype(jnp.float32) / 255.0
train_images_jnp = train_images_jnp.reshape(train_images_jnp.shape[0], 784)
test_images_jnp = test_images_jnp.reshape(test_images_jnp.shape[0], 784)

print(f"train_images_jnp shape: {train_images_jnp.shape}")
print(f"train_labels_jnp shape: {train_labels_jnp.shape}")
print(f"test_images_jn shape: {test_images_jnp.shape}")
print(f"test_labels_jnp shape: {test_labels_jnp.shape}")

train_images_jnp shape: (60000, 784)
train_labels_jnp shape: (60000,)
test_images_jn shape: (10000, 784)
test_labels_jnp shape: (10000,)


## Create Forward pass and Loss in JAX

In [6]:
def init_mlp_params(layer_sizes, key):
    params = []
    keys = random.split(key, len(layer_sizes) - 1)
    for nin, nout, layer_key in zip(layer_sizes[:-1], layer_sizes[1:], keys):
        w_key, b_key = random.split(layer_key)
        layer_params = {
            'w': random.normal(w_key, (nout, nin)),
            'b': jnp.zeros((nout,))
        }
        params.append(layer_params)
    return params

#inputs shape is (784,)
def mlp_apply(params, inputs):
    x = inputs
    for layer_params in params[:-1]:
        z = layer_params['w'] @ x + layer_params['b']
        x = jax.nn.relu(z)
    final_layer_params = params[-1]
    output = final_layer_params['w'] @ x + final_layer_params['b']
    return output

#inputs shape is (128, 784)
def batched_loss_fn(params, inputs, targets):
    predictions = vmap(mlp_apply, in_axes=(None,0))(params, inputs)
    return jnp.mean(optax.losses.softmax_cross_entropy_with_integer_labels(predictions, labels=targets))

weight_decay = 1e-4  # A good starting point
solver = optax.adamw(learning_rate=0.001, weight_decay=weight_decay)
@jit 
def batched_train_step(params, inputs, targets, opt_state):
    loss, gradients = value_and_grad(batched_loss_fn)(params, inputs, targets)
    updates, opt_state = solver.update(gradients, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss

## Initalize Model and test the forward pass and batched loss

- The single prediction should look like 10 random numbers
- The Cross Entropy Loss function should hover around -ln(1/10) for each number, around 2.3

In [7]:
# Chatgpt told me to use these layer sizes, optimizer, and learning rate. Since this is not a hyperparameter tuning exercise, I will just use them
layer_sizes = [784, 512, 256, 10]
key = random.PRNGKey(42)
params = init_mlp_params(layer_sizes, key)
opt_state = solver.init(params)

#single pass
single_prediction = mlp_apply(params, train_images_jnp[0])
print(f'single prediction (one hot encoded vector): \n{single_prediction}')

#batched loss
batch_size = 128
single_batched_loss = batched_loss_fn(params, train_images_jnp[0:batch_size], train_labels_jnp[0:batch_size])
print(f'\ncross entropy loss of numbers above for batch of {batch_size}: \n{single_batched_loss}')

single prediction (one hot encoded vector): 
[  391.79626   1618.063      696.4814     367.7426    2863.2422
  2379.1455   -1363.0208    -735.86993     74.458496   840.37585 ]

cross entropy loss of numbers above for batch of 128: 
1868.5386962890625


## Train the model

In [8]:
#redeclaring for idempotency
layer_sizes = [784, 512, 256, 10]
key = random.PRNGKey(42)
params = init_mlp_params(layer_sizes, key)
opt_state = solver.init(params)

epochs = 30
shuffling_key = random.PRNGKey(42)
n_train_samples = train_images_jnp.shape[0]

def test_model(params, test_data, test_labels):
    predictions = vmap(mlp_apply, in_axes=(None,0))(params, test_data)
    predicted_labels = jnp.argmax(predictions, axis=1)
    correct_count = jnp.sum(predicted_labels == test_labels)
    return correct_count / test_labels.shape[0]

accuracy = test_model(params, test_images_jnp, test_labels_jnp)
print(f'Accuracy before training: {accuracy}')
for epoch in range(epochs):
    shuffling_key, subkey = random.split(shuffling_key)
    permuted_indices = random.permutation(subkey, n_train_samples)
    for i in range(0, n_train_samples, batch_size):
        batch_indices = batch_indices = permuted_indices[i : i + batch_size]
        params, opt_state, loss = batched_train_step(params, train_images_jnp[batch_indices], train_labels_jnp[batch_indices], opt_state)
    accuracy = test_model(params, test_images_jnp, test_labels_jnp)
    print(f'Accuracy after epoch {epoch}: {accuracy}')

Accuracy before training: 0.062300000339746475
Accuracy after epoch 0: 0.8841000199317932
Accuracy after epoch 1: 0.9138000011444092
Accuracy after epoch 2: 0.9204999804496765
Accuracy after epoch 3: 0.9311000108718872
Accuracy after epoch 4: 0.9369000196456909
Accuracy after epoch 5: 0.9373000264167786
Accuracy after epoch 6: 0.9424999952316284
Accuracy after epoch 7: 0.9426000118255615
Accuracy after epoch 8: 0.945900022983551
Accuracy after epoch 9: 0.9402999877929688
Accuracy after epoch 10: 0.9459999799728394
Accuracy after epoch 11: 0.9473999738693237
Accuracy after epoch 12: 0.9531999826431274
Accuracy after epoch 13: 0.9539999961853027
Accuracy after epoch 14: 0.9534000158309937
Accuracy after epoch 15: 0.9545000195503235
Accuracy after epoch 16: 0.9556999802589417
Accuracy after epoch 17: 0.9591000080108643
Accuracy after epoch 18: 0.957099974155426
Accuracy after epoch 19: 0.9567000269889832
Accuracy after epoch 20: 0.9589999914169312
Accuracy after epoch 21: 0.95870000123977

In [9]:
accuracy = test_model(params, test_images_jnp, test_labels_jnp)
print(f'final accuracy: {accuracy}')

final accuracy: 0.960099995136261


## Some Reflections

So where can I improve...

1) Data loading: loading all the data into memory will crash for bigger datasets. I can fix this with a data loader, which is an abstraction I will start using in the next module.
2) From now on, I can encapsulate params and opt_state into one TrainState named tuple. This will ease complexity as more training things come into play like batch normalization statistics

Review:

1) Numpy operations. That sum(==) statement I had to copy from chatgpt, and it should be idiomatically engrained
2) jnp axes. I'm not exactly sure what vmap(in_axes(None, 0)) or axis=1 really means.
3) Keys. Understand exactly what splitting keys, does because now it's just a magical random number generator