<a href="https://colab.research.google.com/github/kbrezinski/JAX-Practice/blob/main/nn_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MLP Training on MNIST

In [None]:
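# Notebook outline:
#   1. initialise the MLP parameters and write predict()
#   2. load MNIST through a torch DataLoader
#   3. define the loss function and run the training loop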
import jax
import numpy as np
import jax.numpy as jnp

from jax.scipy.special import logsumexp
from jax import jit, vmap, pmap, grad, value_and_grad

In [173]:
# constants
seed = 0  # seed for the parent JAX PRNG key created below
mnist_img_size = 784  # length of one flattened MNIST image (28 * 28 pixels)

def init_MLP(layer_widths, parent_key, scale=0.01):
  """Create randomly initialised parameters for a fully connected network.

  Args:
    layer_widths: sequence of layer sizes, input width first, output width last.
    parent_key: JAX PRNG key; it is split into one sub-key per layer.
    scale: multiplier applied to every standard-normal sample.

  Returns:
    A list holding one `[weights, biases]` pair per layer, where weights have
    shape `(out_width, in_width)` and biases have shape `(out_width,)`.
  """
  layer_keys = jax.random.split(parent_key, num=len(layer_widths) - 1)
  layer_dims = zip(layer_widths[:-1], layer_widths[1:])

  def init_layer(n_in, n_out, layer_key):
    # Independent sub-keys so weights and biases are not drawn from the same stream.
    w_key, b_key = jax.random.split(layer_key)
    weights = scale * jax.random.normal(w_key, shape=(n_out, n_in))
    biases = scale * jax.random.normal(b_key, shape=(n_out,))
    return [weights, biases]

  return [
      init_layer(n_in, n_out, layer_key)
      for (n_in, n_out), layer_key in zip(layer_dims, layer_keys)
  ]

# Build the root PRNG key through the public `jax.random` namespace
# (`jax.jax` is not part of the public API and fails on current JAX).
parent_key = jax.random.PRNGKey(seed)
MLP_params = init_MLP([mnist_img_size, 512, 256, 10], parent_key)
# Display the shape of every parameter array. `jax.tree_util.tree_map` is the
# stable spelling; the top-level `jax.tree_map` alias was deprecated and removed.
jax.tree_util.tree_map(lambda x: x.shape, MLP_params)

[[(512, 784), (512,)], [(256, 512), (256,)], [(10, 256), (10,)]]

In [None]:
def predict_MLP(params, x):
  """Run a single flat input through the MLP and return log class probabilities.

  Args:
    params: list of `[weights, biases]` pairs, one per layer; every layer but
      the last is followed by a ReLU.
    x: 1-D input vector whose length matches the first layer's input width.

  Returns:
    1-D array with one log-softmax value per output unit.
  """
  def dense(layer, inputs):
    # Affine map of one layer, e.g. (512, 784) . (784,) + (512,)
    w, b = layer
    return jnp.dot(w, inputs) + b

  hidden = x
  for layer in params[:-1]:
    hidden = jax.nn.relu(dense(layer, hidden))

  logits = dense(params[-1], hidden)

  # log-softmax: logit_i - log(sum_j exp(logit_j))
  return logits - logsumexp(logits)

# Seeded generator so the dummy inputs are identical on every run.
rng = np.random.default_rng(seed)

# Smoke test on a single flattened image of shape (784,).
dummy_img_flat = rng.standard_normal(mnist_img_size)
assert dummy_img_flat.shape == (mnist_img_size,)
prediction = predict_MLP(MLP_params, dummy_img_flat)
assert prediction.shape == (10,)

# Smoke test on a batch of 16 flattened images, shape (16, 784).
dummy_imgs_flats = rng.standard_normal((16, mnist_img_size))
# in_axes=(None, 0): params are shared across the batch (not mapped), while the
# inputs are mapped over their leading (batch) axis.
batched_MLP_predict = vmap(predict_MLP, in_axes=(None, 0))
predictions = batched_MLP_predict(MLP_params, dummy_imgs_flats)
assert predictions.shape == (16, 10)

In [171]:
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
## adding the dataloader

batch_size = 256  # number of samples per mini-batch handed to the DataLoaders


def transform_ds(x):
  """Convert one image to a flat float32 vector scaled by 1/255."""
  pixels = np.array(x, dtype=np.float32)
  return (pixels / 255.).ravel()

def collate_custom(batch):
  """Stack a list of `(image, label)` samples into two NumPy arrays.

  Args:
    batch: list of per-sample tuples whose first two fields are the image
      and its label.

  Returns:
    `(images, labels)` as NumPy arrays with the batch on the leading axis.
  """
  # Turn [(img0, lbl0), (img1, lbl1), ...] into ((img0, img1, ...), (lbl0, lbl1, ...)).
  columns = tuple(zip(*batch))
  images, labels = columns[0], columns[1]
  return np.array(images), np.array(labels)

# Datasets apply `transform_ds` per sample, so the loaders yield flattened
# float32 images scaled to [0, 1] together with their integer labels.
train_ds = MNIST(root='train_mnist', train=True, download=True, transform=transform_ds)
train_dl = DataLoader(train_ds, batch_size, shuffle=True, collate_fn=collate_custom, drop_last=True)

test_ds = MNIST(root='test_mnist', train=False, download=True, transform=transform_ds)
# NOTE(review): drop_last=True discards the final partial batch of the test set;
# kept as in the original since `test_dl` is not consumed in the visible cells.
test_dl = DataLoader(test_ds, batch_size, shuffle=False, collate_fn=collate_custom, drop_last=True)

# Whole-dataset arrays used for the per-epoch accuracy report. `.data` holds the
# raw 0-255 pixels and does not go through `transform`, so the same float32 / 255
# scaling is applied here; otherwise the model would be evaluated on inputs
# scaled differently from the ones it is trained on.
train_images = jnp.array(
    train_ds.data.numpy().reshape(len(train_ds), -1).astype(np.float32) / 255.)
train_lbls = jnp.array(train_ds.targets.numpy())

test_images = jnp.array(
    test_ds.data.numpy().reshape(len(test_ds), -1).astype(np.float32) / 255.)
test_lbls = jnp.array(test_ds.targets.numpy())

In [175]:
## training loop functions
epochs = 5  # number of full passes over train_dl made by the main training loop

def accuracy(params, dataset_imgs, dataset_lbls):
  """Return the fraction of samples whose highest-scoring class equals the label.

  Args:
    params: MLP parameters as accepted by `batched_MLP_predict`.
    dataset_imgs: batch of flattened images, one per row.
    dataset_lbls: class index for each row of `dataset_imgs`.
  """
  log_probs = batched_MLP_predict(params, dataset_imgs)
  pred_classes = jnp.argmax(log_probs, axis=1)
  is_correct = dataset_lbls == pred_classes
  return jnp.mean(is_correct)

def loss_fn(params, imgs, gt_labels):
  """Negative mean of the label-weighted log-probabilities for a batch.

  Args:
    params: MLP parameters as accepted by `batched_MLP_predict`.
    imgs: batch of flattened images, one per row.
    gt_labels: per-sample label weights with the same shape as the predictions
      (the training loop passes one-hot vectors).

  Returns:
    Scalar loss. The mean runs over every entry of the (b, 10) product, not
    only over the batch axis, so with one-hot labels the value is the summed
    cross-entropy divided by `b * 10`.
  """
  # feedforward pass; log_probs has size (b, 10)
  log_probs = batched_MLP_predict(params, imgs)
  weighted = log_probs * gt_labels
  return -jnp.mean(weighted)

@jit  # compile the whole step once; batch shapes are constant because drop_last=True
def update(params, imgs, gt_labels, lr=0.01):
  """Perform one gradient-descent step on a mini-batch.

  Args:
    params: current MLP parameters (list of `[weights, biases]` pairs).
    imgs: batch of flattened images, one per row.
    gt_labels: label weights for the batch, as expected by `loss_fn`.
    lr: step size multiplying the gradient.

  Returns:
    `(loss, new_params)`: the loss evaluated at the incoming `params`, and the
    parameters after the step, with the same tree structure as `params`.
  """
  # Loss value and gradient w.r.t. `params` (the first argument) in one pass.
  loss, grads = value_and_grad(loss_fn)(params, imgs, gt_labels)
  # `jax.tree_multimap` was removed from JAX; `jax.tree_util.tree_map` accepts
  # several trees and applies the function leaf by leaf.
  new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
  return loss, new_params

## main training loop: one parameter update per mini-batch, accuracy report per epoch
num_classes = len(MNIST.classes)  # width of the one-hot label vectors
log_every = 100                   # print the batch loss once per this many steps

for epoch in range(epochs):
  for idx, (imgs, labels) in enumerate(train_dl):
    gt_labels = jax.nn.one_hot(labels, num_classes)
    loss, MLP_params = update(MLP_params, imgs, gt_labels)

    if not idx % log_every:
      print(loss)

  print(f"Epoch: {epoch} | Train acc: {accuracy(MLP_params, train_images, train_lbls)} | Test acc: {accuracy(MLP_params, test_images, test_lbls)}")
      

0.23023362
0.23022087


KeyboardInterrupt: ignored