<a href="https://colab.research.google.com/github/kbrezinski/JAX-Practice/blob/main/nn_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MLP Training on MNIST

In [33]:
# init MLP and predict()
# torch dataloader
# training loop, loss fn

import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax import jit, vmap, pmap, grad
import jax

In [27]:
# --- configuration constants ---
seed = 0  # global random seed; the JAX PRNG key below is derived from it
mnist_img_size = 28 * 28  # one MNIST image flattened into a vector: 784 pixels

def init_MLP(layer_widths, parent_key, scale=0.01):
  """Create randomly initialised ``[weights, bias]`` pairs for a fully connected MLP.

  Args:
    layer_widths: sequence of layer sizes, input first (e.g. ``[784, 512, 256, 10]``).
      Each consecutive pair ``(in, out)`` defines one dense layer.
    parent_key: JAX PRNG key. It is split once per layer, and each layer key is
      split again into a weight key and a bias key.
    scale: multiplier applied to the standard-normal samples.

  Returns:
    A list with one ``[weights, bias]`` list per layer, where ``weights`` has
    shape ``(out, in)`` and ``bias`` has shape ``(out,)``.
  """
  layer_keys = jax.random.split(parent_key, num=len(layer_widths) - 1)
  fan_pairs = zip(layer_widths[:-1], layer_widths[1:])

  def _init_layer(fan_in, fan_out, layer_key):
    # Separate sub-keys so weights and bias come from independent random streams.
    w_key, b_key = jax.random.split(layer_key)
    weights = scale * jax.random.normal(w_key, shape=(fan_out, fan_in))
    bias = scale * jax.random.normal(b_key, shape=(fan_out,))
    return [weights, bias]

  return [_init_layer(fan_in, fan_out, layer_key)
          for (fan_in, fan_out), layer_key in zip(fan_pairs, layer_keys)]

# Seeded PRNG key so parameter initialisation is reproducible across runs.
parent_key = jax.random.PRNGKey(seed)
# Layer widths: flattened image input -> two hidden layers (512, 256) -> 10 outputs.
MLP_params = init_MLP([mnist_img_size, 512, 256, 10], parent_key)
# Show the shape of every parameter array. `jax.tree_util.tree_map` is used
# because the top-level `jax.tree_map` alias is deprecated in newer JAX releases.
jax.tree_util.tree_map(lambda x: x.shape, MLP_params)

[[(512, 784), (512,)], [(256, 512), (256,)], [(10, 256), (10,)]]

In [36]:
def predict_MLP(params, x):
  """Run one flattened input through the MLP and return log-probabilities.

  Args:
    params: list of ``[weights, bias]`` pairs (as produced by ``init_MLP``).
      Every pair except the last is a ReLU hidden layer; the last pair is the
      linear output layer.
    x: a single input vector. ``jnp.dot(w, x) + b`` assumes no batch
      dimension; map over a batch with ``vmap`` instead.

  Returns:
    Log-softmax of the output-layer logits: ``z_i - log(sum_j exp(z_j))``.
  """
  h = x
  for layer_w, layer_b in params[:-1]:
    # With this notebook's widths the first layer is (512, 784) @ (784,) + (512,) -> (512,)
    h = jax.nn.relu(jnp.dot(layer_w, h) + layer_b)

  out_w, out_b = params[-1]
  z = jnp.dot(out_w, h) + out_b

  # Log-softmax: log(exp(z_i) / sum_j exp(z_j)) == z_i - logsumexp(z).
  return z - logsumexp(z)

# Smoke test on a random "image". A seeded generator keeps the input (and hence
# the prediction) reproducible under Restart & Run All.
dummy_img_flat = np.random.default_rng(seed).standard_normal(mnist_img_size)
assert dummy_img_flat.shape == (mnist_img_size,)

prediction = predict_MLP(MLP_params, dummy_img_flat)
# predict_MLP returns log-probabilities, so their exponentials must sum to 1.
assert jnp.isclose(jnp.exp(prediction).sum(), 1.0, atol=1e-5)
prediction.shape
 

In [None]:
#