<a href="https://colab.research.google.com/github/mohsenh17/jaxLearning/blob/main/mlpMNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [39]:
import numpy as np
import jax.numpy as jnp
from jax.scipy.special import logsumexp
import jax
from jax import jit, vmap, pmap, grad, value_and_grad

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

In [40]:
# Size of one MNIST image as a 2-tuple; np.prod of it (784) is used
# elsewhere in this notebook as the flattened input width of the MLP.
mnist_img_size = (28, 28)
def mlp_model_initialization(parent_key, layer_widths):
    """Create randomly initialised parameters for a fully connected network.

    Args:
        parent_key: JAX PRNG key; it is split once per layer, and each
            per-layer key is split again into a weight key and a bias key.
        layer_widths: sequence of layer sizes, input width first and output
            width last. ``len(layer_widths) - 1`` layers are created.

    Returns:
        A list with one dict per layer, each holding ``'weights'`` of shape
        ``(n_in, n_out)`` and ``'biases'`` of shape ``(n_out,)``.
    """
    fan_ins = layer_widths[:-1]
    fan_outs = layer_widths[1:]
    layer_keys = jax.random.split(parent_key, num=len(layer_widths)-1)

    def init_layer(layer_key, fan_in, fan_out):
        # Independent sub-keys so weights and biases are not correlated.
        w_key, b_key = jax.random.split(layer_key)
        # Standard-normal draws scaled by sqrt(2 / fan_in).
        scaled_weights = jax.random.normal(w_key, (fan_in, fan_out)) * jnp.sqrt(2 / fan_in)
        # Biases are unscaled standard-normal draws.
        raw_biases = jax.random.normal(b_key, (fan_out,))
        return {'weights': scaled_weights, 'biases': raw_biases}

    return [
        init_layer(layer_key, fan_in, fan_out)
        for fan_in, fan_out, layer_key in zip(fan_ins, fan_outs, layer_keys)
    ]

# Smoke test: build a 784-512-512-10 network from a fixed PRNG key and print
# the shape of every parameter array. `key` and `mlp_params` are reused by
# later cells of this notebook.
key = jax.random.PRNGKey(0)
mlp_params = mlp_model_initialization(key, [784, 512, 512, 10])
print(jax.tree.map(lambda x: x.shape, mlp_params))

[{'biases': (512,), 'weights': (784, 512)}, {'biases': (512,), 'weights': (512, 512)}, {'biases': (10,), 'weights': (512, 10)}]


In [88]:
def mlp_forward(params, x):
    """Run one example through the MLP and return class probabilities.

    Args:
        params: sequence of layer dicts with ``'weights'`` and ``'biases'``
            entries; every layer but the last is followed by a ReLU.
        x: input vector whose length matches the first layer's weight rows.

    Returns:
        The softmax of the last layer's affine output. Note these are
        probabilities, not raw logits.
    """
    *relu_layers, output_layer = params

    activations = x
    for layer in relu_layers:
        pre_activation = jnp.dot(activations, layer['weights']) + layer['biases']
        activations = jax.nn.relu(pre_activation)

    scores = jnp.dot(activations, output_layer['weights']) + output_layer['biases']
    return jax.nn.softmax(scores)

# Shape checks for the forward pass, single-example and batched.

# Single example: a random vector of length prod(mnist_img_size) goes in,
# one output vector comes out.

dummy_img_flat = np.random.randn(np.prod(mnist_img_size))
print(dummy_img_flat.shape)

prediction = mlp_forward(mlp_params, dummy_img_flat)
print(prediction.shape)

# Batched version: vmap maps over axis 0 of the inputs while the parameters
# (in_axes None) are shared. loss_fn and accuracy in this notebook call
# batched_MLP_predict, so this definition must run before them.
batched_MLP_predict = vmap(mlp_forward, in_axes=(None, 0))

dummy_imgs_flat = np.random.randn(16, np.prod(mnist_img_size))
print(dummy_imgs_flat.shape)
predictions = batched_MLP_predict(mlp_params, dummy_imgs_flat)
print(predictions.shape)

(784,)
(10,)
(16, 784)
(16, 10)


In [71]:
def custom_transform(x):
    """Convert ``x`` to a float32 NumPy array and flatten it to one dimension.

    No value rescaling is applied; only the dtype and shape change.
    """
    as_float = np.array(x, dtype=np.float32)
    return as_float.ravel()

def custom_collate_fn(batch):
    """Collate a list of ``(image, label)`` samples into two NumPy arrays.

    Args:
        batch: sequence of samples whose first element is the image and
            whose second element is the label.

    Returns:
        ``(imgs, labels)`` where each is ``np.array`` of the stacked
        per-sample values, in the same order as ``batch``.
    """
    # zip(*batch) turns a list of samples into one tuple per field.
    columns = list(zip(*batch))
    label_column = np.array(columns[1])
    image_column = np.array(columns[0])
    return image_column, label_column

# Mini-batch size shared by both loaders.
batch_size = 128
# Train and test splits are stored under separate root directories; every
# sample is passed through custom_transform (flattened float32 array).
train_dataset = MNIST(root='train_mnist', train=True, download=True, transform=custom_transform)
test_dataset = MNIST(root='test_mnist', train=False, download=True, transform=custom_transform)

# custom_collate_fn returns NumPy arrays. Both loaders are created with
# drop_last=True; only the training loader is shuffled.
train_loader = DataLoader(train_dataset, batch_size, shuffle=True, collate_fn=custom_collate_fn, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size, shuffle=False, collate_fn=custom_collate_fn, drop_last=True)

# Sanity check: pull one batch and print its shapes and element dtypes.
batch_data = next(iter(train_loader))
imgs = batch_data[0]
lbls = batch_data[1]
print(imgs.shape, imgs[0].dtype, lbls.shape, lbls[0].dtype)

# Loading the whole dataset into memory
# NOTE(review): these arrays are built from the datasets' .data / .targets
# attributes directly, so custom_transform (and its float32 cast) is not
# applied here - presumably the stored dtype is kept; confirm.
train_images = jnp.array(train_dataset.data).reshape(len(train_dataset), -1)
train_lbls = jnp.array(train_dataset.targets)

test_images = jnp.array(test_dataset.data).reshape(len(test_dataset), -1)
test_lbls = jnp.array(test_dataset.targets)

(128, 784) float32 (128,) int64


In [91]:
def loss_fn(params, imgs, lbls):
    """Cross-entropy style loss of the batched MLP on one mini-batch.

    Args:
        params: MLP parameters as accepted by ``batched_MLP_predict``.
        imgs: batch of flattened images, one example per row.
        lbls: integer class labels, one per example.

    Returns:
        ``(loss, probs)``. ``probs`` are the model's softmax outputs after
        clipping, returned as auxiliary data. ``loss`` is the negated mean of
        ``log(probs) * one_hot(lbls)`` taken over every (example, class)
        entry, i.e. the per-example cross-entropy divided by the number of
        classes.
    """
    # batched_MLP_predict already applies softmax, so these are probabilities.
    probs = batched_MLP_predict(params, imgs)
    # Floor the probabilities so that log() below never sees an exact zero.
    probs = jnp.clip(probs, 1e-10, 1-1e-10)
    labels_one_hot = jax.nn.one_hot(lbls, probs.shape[-1])
    # The one-hot mask keeps only the log-probability of the true class.
    loss = -jnp.mean(jnp.log(probs) * labels_one_hot)
    return loss, probs

def accuracy(params, dataset_imgs, dataset_lbls):
    """Fraction of examples whose highest-scoring class equals the label.

    Args:
        params: MLP parameters as accepted by ``batched_MLP_predict``.
        dataset_imgs: flattened images, one example per row.
        dataset_lbls: integer class labels aligned with ``dataset_imgs``.

    Returns:
        Mean of the element-wise ``label == prediction`` comparison.
    """
    class_scores = batched_MLP_predict(params, dataset_imgs)
    predicted = jnp.argmax(class_scores, axis=1)
    is_correct = dataset_lbls == predicted
    return jnp.mean(is_correct)
@jit
def update(params, imgs, lbls, lr=0.01):
    """Perform one gradient-descent step on a mini-batch.

    Args:
        params: current MLP parameters (pytree).
        imgs: batch of flattened images.
        lbls: integer labels for ``imgs``.
        lr: step size multiplied into every gradient.

    Returns:
        ``(loss, logits, new_params)`` where ``loss`` and ``logits`` are the
        two outputs of ``loss_fn`` evaluated at ``params`` (before the step)
        and ``new_params`` has the same structure as ``params``.
    """
    # has_aux=True: loss_fn returns (loss, aux); only loss is differentiated.
    loss_and_grads = value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = loss_and_grads(params, imgs, lbls)

    def sgd_step(param, gradient):
        return param - lr*gradient

    new_params = jax.tree.map(sgd_step, params, grads)
    return loss, logits, new_params

# Create the MLP that is actually trained: flattened-image input width, two
# hidden layers, one output unit per MNIST class. Reuses the notebook's `key`.
MLP_params = mlp_model_initialization(key, [np.prod(mnist_img_size), 512, 256, len(MNIST.classes)])

num_epochs = 20
for epoch in range(num_epochs):
    # One pass of mini-batch gradient descent over the training loader.
    for imgs, lbls in train_loader:
        loss, logits, MLP_params = update(MLP_params, imgs, lbls)

    # Loss of the last mini-batch of this epoch (not an epoch average).
    print(loss)

    # Evaluate on the full in-memory train and test sets once per epoch.
    train_acc = accuracy(MLP_params, train_images, train_lbls)
    test_acc = accuracy(MLP_params, test_images, test_lbls)
    print(f'Epoch {epoch}, train acc = {train_acc} test acc = {test_acc}')



10
0.79526097
Epoch 0, train acc = 0.613099992275238 test acc = 0.6128000020980835
0.68373513
Epoch 1, train acc = 0.6370333433151245 test acc = 0.6345999836921692
0.9147736
Epoch 2, train acc = 0.6436333656311035 test acc = 0.6358000040054321
0.9293503
Epoch 3, train acc = 0.6506500244140625 test acc = 0.64410001039505
0.7195587
Epoch 4, train acc = 0.6570667028427124 test acc = 0.6502000093460083
0.8274972
Epoch 5, train acc = 0.6597000360488892 test acc = 0.6488999724388123
0.85227454
Epoch 6, train acc = 0.6636833548545837 test acc = 0.6539999842643738
0.7735247
Epoch 7, train acc = 0.6659833192825317 test acc = 0.6565999984741211
0.5037578
Epoch 8, train acc = 0.7561833262443542 test acc = 0.7479999661445618
0.48337212
Epoch 9, train acc = 0.7600666880607605 test acc = 0.7487999796867371
0.55207837
Epoch 10, train acc = 0.7646833658218384 test acc = 0.7519999742507935
0.59893435
Epoch 11, train acc = 0.7681833505630493 test acc = 0.7565999627113342
0.58268565
Epoch 12, train acc =