<a href="https://colab.research.google.com/github/mohsenh17/jaxLearning/blob/main/mlpMNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import jax
from jax import jit, vmap, pmap, grad, value_and_grad

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

In [18]:
mnist_img_size = (28, 28)
def mlp_model_initialization(parent_key, layer_widths):
    """Create randomly initialised parameters for a fully connected MLP.

    Args:
        parent_key: JAX PRNG key; it is split into one sub-key per layer.
        layer_widths: sequence of layer sizes, input first, e.g. [784, 512, 512, 10].

    Returns:
        A list with one dict per consecutive pair of widths (n_in, n_out), holding
        'weights' of shape (n_in, n_out) drawn from a normal scaled by sqrt(2 / n_in)
        (He-style), and 'biases' of shape (n_out,) drawn from a standard normal.
    """
    layer_keys = jax.random.split(parent_key, num=len(layer_widths) - 1)
    fan_pairs = zip(layer_widths[:-1], layer_widths[1:])

    def init_layer(fan_in, fan_out, layer_key):
        # Separate sub-keys so weights and biases are drawn independently.
        w_key, b_key = jax.random.split(layer_key)
        return {
            'weights': jax.random.normal(w_key, (fan_in, fan_out)) * jnp.sqrt(2 / fan_in),
            'biases': jax.random.normal(b_key, (fan_out,)),
        }

    return [init_layer(fan_in, fan_out, layer_key)
            for (fan_in, fan_out), layer_key in zip(fan_pairs, layer_keys)]

# test: initialise the MNIST MLP and check that every layer has the expected shapes
key = jax.random.PRNGKey(0)
mlp_layer_widths = [784, 512, 512, 10]
mlp_params = mlp_model_initialization(key, mlp_layer_widths)
mlp_param_shapes = jax.tree.map(lambda x: x.shape, mlp_params)
print(mlp_param_shapes)

# The input layer must accept a flattened MNIST image.
assert mlp_layer_widths[0] == np.prod(mnist_img_size)
# One {'weights': (n_in, n_out), 'biases': (n_out,)} dict per consecutive pair of widths.
assert mlp_param_shapes == [
    {'weights': (n_in, n_out), 'biases': (n_out,)}
    for n_in, n_out in zip(mlp_layer_widths[:-1], mlp_layer_widths[1:])
]

[{'biases': (512,), 'weights': (784, 512)}, {'biases': (512,), 'weights': (512, 512)}, {'biases': (10,), 'weights': (512, 10)}]


In [23]:
def mlp_forward(params, x, return_log_probs=False):
    """Apply a ReLU MLP to `x` and return class probabilities.

    Args:
        params: list of layer dicts with 'weights' (n_in, n_out) and 'biases' (n_out,),
            e.g. as produced by `mlp_model_initialization`. All but the last layer
            are followed by a ReLU.
        x: input features whose last axis has size n_in of the first layer
            (a single flattened image here; use `vmap` for batches).
        return_log_probs: Python bool. If True, return log-probabilities computed
            with a numerically stable log-softmax instead of probabilities. It must
            be static (mark it with `static_argnames` if this function is `jit`-ed).

    Returns:
        Softmax probabilities over the last axis, whose size is n_out of the final
        layer. When `return_log_probs=True`, their logarithms are returned instead.
    """
    *hidden_layers, output_layer = params
    for layer in hidden_layers:
        x = jax.nn.relu(jnp.dot(x, layer['weights']) + layer['biases'])
    # Unnormalised scores of the output layer -- these are the actual logits.
    logits = jnp.dot(x, output_layer['weights']) + output_layer['biases']
    if return_log_probs:
        # log_softmax stays finite where log(softmax(...)) would underflow to -inf.
        return jax.nn.log_softmax(logits, axis=-1)
    return jax.nn.softmax(logits, axis=-1)

# tests: shape / sanity checks for mlp_forward on random (untrained) inputs

# Seeded generator so Restart & Run All reproduces the same dummy inputs.
dummy_rng = np.random.default_rng(0)
n_pixels = int(np.prod(mnist_img_size))
n_classes = mlp_params[-1]['biases'].shape[0]

# test single example
dummy_img_flat = dummy_rng.standard_normal(n_pixels)
print(dummy_img_flat.shape)

prediction = mlp_forward(mlp_params, dummy_img_flat)
print(prediction.shape)
assert prediction.shape == (n_classes,)
# mlp_forward returns softmax probabilities, so they must sum to one.
assert jnp.allclose(prediction.sum(), 1.0, atol=1e-5)

# test batched function
batched_MLP_predict = vmap(mlp_forward, in_axes=(None, 0))

dummy_imgs_flat = dummy_rng.standard_normal((16, n_pixels))
print(dummy_imgs_flat.shape)
predictions = batched_MLP_predict(mlp_params, dummy_imgs_flat)
print(predictions.shape)
assert predictions.shape == (dummy_imgs_flat.shape[0], n_classes)
# vmap must agree with applying the model to one example at a time
# (loose tolerance: batched vs. single matmuls may round differently on accelerators).
assert jnp.allclose(predictions[0], mlp_forward(mlp_params, dummy_imgs_flat[0]), atol=1e-3)

(784,)
(10,)
(16, 784)
(16, 10)
