# MNIST using minijax
Train a multilayer perceptron (MLP) for MNIST handwritten digit recognition.


In [4]:
import itertools as it

import numpy as np

from minijax.compute_graph import make_graph
from minijax.core import add, div, mul, sqrt, square, sub
from minijax.eval import Array, ones, zeros
from minijax.grad import value_and_grad
from minijax.jit import jit
from minijax.nested_containers import map_structure
from minijax.nn import cross_entropy, init_mlp, mlp
from minijax.vmap import vmap
from mnist_dataset import load_mnist

## Model Setup

In [5]:
# Network architecture: flattened input size and the output width of each layer.
in_size = 784  # flat 28x28 images
layers = [128, 10]  # 10 classes

# Training hyperparameters, read by the training cells further down.
num_epochs = 10
batch_size = 32
learning_rate = 1e-3

# Build the initial MLP parameters.
# NOTE(review): rng_key=0 presumably seeds the initialisation so runs are
# reproducible -- confirm against minijax.nn.init_mlp.
params = init_mlp(in_size, layers, rng_key=0)

In [6]:
# Smoke test: run the un-batched MLP on a single all-ones 28x28 image.
# The printed result holds one value per class (10 entries).
x = ones((28, 28))
y = mlp(x, params)
print(y)

Array([ 1.9688427   0.02259841 -0.05577275  0.72828809  0.62225548  2.10965871
       -1.57625537 -1.56898686  0.32959784  1.34568429])


## vmap

In [7]:

# Batched version of the MLP, used by `loss` and `accuracy` below.
# NOTE(review): (0, None) presumably follows the JAX `in_axes` convention --
# map over axis 0 of `x`, pass `params` through unchanged -- confirm in minijax.vmap.
model = vmap(mlp, (0, None))
# Two flat example inputs: one all-ones row and one all-zeros row.
x = Array([[1.0] * in_size, [0.0] * in_size])
y = model(x, params)
print(y)

Array([[ 1.9688427   0.02259841 -0.05577275  0.72828809  0.62225548  2.10965871
        -1.57625537 -1.56898686  0.32959784  1.34568429]
       [ 0.          0.          0.          0.          0.          0.
         0.          0.          0.          0.        ]])


## make_graph
Let's inspect the compute graph of the model.

In [8]:

# Build the compute graph of `model` for the example inputs and print it;
# the printout lists the graph inputs, each intermediate op, and the output.
cg = make_graph(model)(x, params)
print(cg)

input: a[2, 784] b[784, 128] c[128] d[128, 10] e[10]
  f[2, 784] = reshape[new_shape: (2, -1)] a[2, 784]
  g[2, 128] = dot f[2, 784] b[784, 128]
  h[2, 128] = add g[2, 128] c[128]
  i[2, 128] = relu h[2, 128]
  j[2, 10] = dot i[2, 128] d[128, 10]
  k[2, 10] = add j[2, 10] e[10]
output: k[2, 10]


## Loss & grad

In [9]:
def loss(x, y_true, params):
    """Cross-entropy between the batched model's output for `x` and `y_true`.

    Args:
        x: Batch of inputs, forwarded to the notebook-level `model`.
        y_true: Target labels, forwarded unchanged to `cross_entropy`.
        params: MLP parameters, forwarded to `model`.

    Returns:
        Whatever `cross_entropy` returns for (model output, y_true).
    """
    return cross_entropy(model(x, params), y_true)


# One-hot encoded class labels
# => vectors are 1.0 at the correct class
# One row per example in `x`: the first is labelled class 0, the second class 1.
y_true = Array(
    [
        [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
    ]
)
# value_and_grad(loss) returns the loss value together with a tuple of
# gradients, one entry per positional argument (x, y_true, params).
loss_val, grads = value_and_grad(loss)(x, y_true, params)

print("Loss:", loss_val)
# Prints the full gradient tuple, so the output is long.
print("Gradients:", grads)

Loss: Array(1.8139983104301463)
Gradients: (Array([[ 0.01794723 -0.02101453 -0.02960505 ...  0.0151266  -0.00292274
         0.00611423]
       [ 0.          0.          0.         ...  0.          0.
         0.        ]]), Array([[0.66270576 1.63582791 1.67501349 1.28298307 1.33599937 0.59229776
        2.4352548  2.43162054 1.4823282  0.97428497]
       [1.15129255 1.15129255 1.15129255 1.15129255 1.15129255 1.15129255
        1.15129255 1.15129255 1.15129255 1.15129255]]), [{'weight': Array([[ 0.         -0.09364378  0.         ...  0.          0.
         0.        ]
       [ 0.         -0.09364378  0.         ...  0.          0.
         0.        ]
       [ 0.         -0.09364378  0.         ...  0.          0.
         0.        ]
       ...
       [ 0.         -0.09364378  0.         ...  0.          0.
         0.        ]
       [ 0.         -0.09364378  0.         ...  0.          0.
         0.        ]
       [ 0.         -0.09364378  0.         ...  0.          0.
      

## Load Dataset

In [None]:
# Load the four MNIST arrays and report their shapes.
print("Loading MNIST dataset...")
train_images, train_labels, test_images, test_labels = load_mnist()
print(f"Training set: {train_images.shape} images, {train_labels.shape} labels")
print(f"Test set: {test_images.shape} images, {test_labels.shape} labels")

# Cell output: the first training image together with its label.
train_images[0], train_labels[0]

## Training with jit

In [None]:
def accuracy(x, y_true, params):
    """Fraction of samples whose arg-max prediction equals the arg-max of `y_true`.

    Computed with plain numpy on the underlying `.array` values rather than
    with minijax ops; it is only used for logging.

    Args:
        x: Batch of inputs for the notebook-level `model`.
        y_true: Target labels; the class is taken as the arg-max over the last axis.
        params: MLP parameters.

    Returns:
        The numpy mean of the per-sample "prediction matches target" flags.
    """
    predicted_classes = np.argmax(model(x, params).array, axis=-1)
    target_classes = np.argmax(y_true.array, axis=-1)
    return np.mean(predicted_classes == target_classes)


def adam(params, grads, opt_state, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam update to a nested container of parameters.

    Args:
        params: Nested container of parameter arrays.
        grads: Gradients with the same nesting as `params`.
        opt_state: 4-tuple (first moments, second moments, beta1 power,
            beta2 power) as produced by `init_adam_state` or by a previous
            call to this function.
        lr: Step size.
        beta1: Decay rate of the first-moment running average.
        beta2: Decay rate of the second-moment running average.
        eps: Constant added to the denominator of the update.

    Returns:
        (new_params, new_opt_state), where new_opt_state has the same
        4-tuple layout as `opt_state`.
    """
    first_moments, second_moments, beta1_pow_prev, beta2_pow_prev = opt_state

    def next_first_moment(grad, prev):
        # beta1 * prev + (1 - beta1) * grad
        carried = mul(Array(beta1), prev)
        fresh = mul(Array(1 - beta1), grad)
        return add(carried, fresh)

    def next_second_moment(grad, prev):
        # beta2 * prev + (1 - beta2) * grad**2
        carried = mul(Array(beta2), prev)
        fresh = mul(Array(1 - beta2), square(grad))
        return add(carried, fresh)

    # map_structure applies a function leaf-by-leaf across the nested containers.
    m_new = map_structure(next_first_moment, grads, first_moments)
    v_new = map_structure(next_second_moment, grads, second_moments)

    # Advance the running powers beta**t that drive the bias correction.
    beta1_pow = mul(beta1_pow_prev, Array(beta1))
    beta2_pow = mul(beta2_pow_prev, Array(beta2))

    def step(param, m, v):
        # Bias-corrected moment estimates.
        m_hat = div(m, sub(Array(1), beta1_pow))
        v_hat = div(v, sub(Array(1), beta2_pow))
        # param - lr * m_hat / (sqrt(v_hat) + eps)
        scaled = mul(Array(lr), div(m_hat, add(sqrt(v_hat), Array(eps))))
        return sub(param, scaled)

    updated_params = map_structure(step, params, m_new, v_new)
    return updated_params, (m_new, v_new, beta1_pow, beta2_pow)


def init_adam_state(params):
    """Build the initial optimiser state consumed by `adam`.

    Args:
        params: Nested container of parameter arrays.

    Returns:
        4-tuple (first moments, second moments, beta1 power, beta2 power):
        both moment containers mirror `params` with zero arrays of matching
        shape, and both powers start at Array(1).
    """

    def zeros_like(p):
        return zeros(p.shape)

    first_moments = map_structure(zeros_like, params)
    second_moments = map_structure(zeros_like, params)
    return first_moments, second_moments, Array(1), Array(1)


@jit  # jit the whole step; per the author this is only a small speed-up
def train_step(x, y_true, params, opt_state):
    """Run one optimisation step on a single batch.

    Differentiates `loss` at (x, y_true, params), keeps only the gradient
    with respect to `params`, and applies one `adam` update using the
    notebook-level `learning_rate`.

    Args:
        x: Batch of inputs.
        y_true: Batch of target labels.
        params: Current MLP parameters.
        opt_state: Current Adam state (see `init_adam_state`).

    Returns:
        (new_params, new_opt_state, loss_val)
    """
    loss_val, all_grads = value_and_grad(loss)(x, y_true, params)
    # One gradient per positional argument; those for x and y_true are unused.
    _x_grad, _y_grad, param_grads = all_grads
    updated_params, updated_state = adam(params, param_grads, opt_state, lr=learning_rate)
    return updated_params, updated_state, loss_val

In [None]:
# Fresh Adam state (zero moments, beta powers of 1) for the current `params`.
# NOTE: this cell rebinds the global `params` and `opt_state`; re-running it
# continues training from the already-updated parameters.
opt_state = init_adam_state(params)

# Number of full batches per epoch; only used for the progress percentage below.
epoch_len = train_images.shape[0] // batch_size
np_rng = np.random.default_rng(1)  # fixed seed for the shuffling order
for t in range(num_epochs):
    # Draw a new random ordering of the training indices for this epoch.
    rand_perm = np_rng.permutation(train_images.shape[0])
    loss_vals = []  # batch losses collected since the last progress print
    # it.batched yields tuples of up to `batch_size` indices (the last may be shorter).
    for i, batch_idx in enumerate(it.batched(rand_perm, batch_size)):
        x, y = Array(train_images[batch_idx, :]), Array(train_labels[batch_idx, :])
        params, opt_state, loss_val = train_step(x, y, params, opt_state)

        loss_vals.append(loss_val.array.item())
        # Every 100 batches: print the mean loss over those batches, then reset the buffer.
        if i % 100 == 99:
            avg_loss = sum(loss_vals) / len(loss_vals)
            loss_vals = []
            print(f"[Epoch {t + 1}, {(i + 1) / epoch_len:3.0%}]: loss = {avg_loss:.4f}")

    # End of epoch: evaluate loss and accuracy on the whole test set in a single batch.
    test_loss = loss(Array(test_images), Array(test_labels), params).array.item()
    test_accuracy = accuracy(Array(test_images), Array(test_labels), params).item()
    print(f"Epoch {t + 1}: test loss: {test_loss:.4f}, test accuracy: {test_accuracy:.2%}")