In [28]:
import numpy as np

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax import jit, vmap, pmap, grad

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# MLP training on MNIST

In [None]:
# TODO: init the MLP and add the predict function
# TODO: add data loading in PyTorch
# TODO: add the training loop and loss fn

In [30]:
# Notebook-wide configuration constants.
seed = 0                      # base random seed for the notebook
mnist_img_size = (28, 28)     # (height, width) of one MNIST image; its product is the flattened input size
batch_size = 128              # number of samples per batch handed to the DataLoader

def init_MLP(layer_widths, rng_key, scale=0.01):
    """Build randomly initialised parameters for a fully connected network.

    Args:
        layer_widths: Sequence of layer sizes, input width first and output
            width last; every adjacent pair defines one dense layer.
        rng_key: JAX PRNG key. It is split into one key per layer, and each
            layer key is split again into a weight key and a bias key.
        scale: Factor multiplied onto the standard-normal samples.

    Returns:
        A list holding one ``[weight, bias]`` list per layer, where ``weight``
        has shape ``(out_width, in_width)`` and ``bias`` has shape
        ``(out_width,)``.
    """
    def make_layer(fan_in, fan_out, layer_key):
        # Independent keys so that weights and biases are not correlated.
        w_key, b_key = jax.random.split(layer_key)
        weight = scale * jax.random.normal(w_key, (fan_out, fan_in))
        bias = scale * jax.random.normal(b_key, (fan_out,))
        return [weight, bias]

    layer_keys = jax.random.split(rng_key, num=len(layer_widths) - 1)
    fan_pairs = zip(layer_widths[:-1], layer_widths[1:])

    return [
        make_layer(fan_in, fan_out, layer_key)
        for (fan_in, fan_out), layer_key in zip(fan_pairs, layer_keys)
    ]

# tests
# Build the PRNG key from the notebook-level `seed` instead of a second,
# hard-coded 0, so changing `seed` actually changes the initialisation.
key = jax.random.PRNGKey(seed)
MLP_params = init_MLP([784, 512, 256, 10], key)

# Show only the shape of every parameter array, not the values.
print(jax.tree.map(lambda x: x.shape, MLP_params))

[[(512, 784), (512,)], [(256, 512), (256,)], [(10, 256), (10,)]]


In [6]:
def MLP_predict(params, x):
    """Forward one example through the MLP and return log-probabilities.

    Args:
        params: Sequence of ``(weight, bias)`` pairs, one per layer. Every
            layer except the last is followed by a ReLU; the last layer
            produces the logits.
        x: Input for a single example (the notebook batches this function
            with ``vmap`` rather than passing a batch directly).

    Returns:
        The logits minus their ``logsumexp``, i.e. the log-softmax of the
        final layer's output.
    """
    activation = x
    for weight, bias in params[:-1]:
        activation = jax.nn.relu(jnp.dot(weight, activation) + bias)

    out_weight, out_bias = params[-1]
    scores = jnp.dot(out_weight, activation) + out_bias

    # Subtracting logsumexp normalises the scores in log space.
    log_normaliser = logsumexp(scores)
    return scores - log_normaliser

# tests
# Draw the dummy image from a generator seeded with the notebook-level `seed`
# so this cell is reproducible across kernel restarts (the global
# np.random.randn state is not seeded anywhere in the notebook).
dummy_img_flat = np.random.default_rng(seed).standard_normal(int(np.prod(mnist_img_size)))
print(dummy_img_flat.shape)

# Single-example prediction, compiled with jit.
predict = jit(MLP_predict)
print(predict(MLP_params, dummy_img_flat).shape)

# Batched prediction: share the parameters (in_axes=None) and map over the
# leading axis of the inputs (in_axes=0).
batched_MLP_predict = jit(vmap(MLP_predict, in_axes=(None, 0)))
print(batched_MLP_predict(MLP_params, np.stack([dummy_img_flat, dummy_img_flat])).shape)

(784,)
(10,)
(2, 10)


In [53]:
def custom_transform(img):
    """Turn an image into a flat float32 NumPy vector.

    Args:
        img: Anything ``np.array`` can convert (it is passed as the
            ``transform`` of the MNIST datasets in this notebook).

    Returns:
        A 1-D ``np.float32`` array containing all elements of ``img``.
    """
    pixels = np.array(img, dtype=np.float32)
    return pixels.ravel()

def custom_collate_fn(batch):
    """Collate per-sample tuples into NumPy arrays.

    Args:
        batch: Sequence of samples, each a tuple whose first element is the
            image array and whose second element is the label.

    Returns:
        A tuple ``(images, labels)``: the images stacked along a new leading
        axis, and the labels gathered into a single array.
    """
    # zip(*batch) regroups per-sample tuples into one tuple per field.
    columns = tuple(zip(*batch))
    return np.stack(columns[0]), np.array(columns[1])


# MNIST train/test splits, downloaded if missing into separate root folders.
# custom_transform is applied to every image as it is read from the dataset.
train_dataset = MNIST(root='./train_mnist', train=True, download=True, transform=custom_transform)
test_dataset = MNIST(root='./test_mnist', train=False, download=True, transform=custom_transform)

# Shuffled training batches; custom_collate_fn builds each batch as NumPy arrays.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)

# Sanity check: pull one batch and inspect its shapes and container type.
batch_data = next(iter(train_loader))
print(batch_data[0].shape)  # images shape
print(batch_data[1].shape)  # labels shape
print(type(batch_data[0]))

(128, 784)
(128,)
<class 'numpy.ndarray'>


# Visualizations

In [None]:
# TODO: visualize the MLP weights
# TODO: visualize embeddings using t-SNE
# TODO: detect dead neurons

# Parallelization