In [None]:
from jax.scipy.special import logsumexp
from jax.example_libraries import optimizers
import numpy as onp
import jax.numpy as np
from jax import grad, jit, vmap, value_and_grad
from jax import random

import torch
from torchvision import datasets, transforms

In [None]:
# Root PRNG key used for parameter initialization; a fixed seed keeps runs reproducible.
key = random.PRNGKey(seed=1)

In [None]:
batch_size = 100

# Normalize mean/std for the single grayscale channel
# (presumably the commonly cited MNIST training-set statistics).
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# One shared preprocessing pipeline so train and test inputs are
# transformed identically (previously duplicated in both loaders).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=mnist_transform),
    batch_size=batch_size,
    shuffle=True,
)

# Evaluation order does not affect aggregate metrics, so the test set is not
# shuffled (deterministic evaluation). download=True lets this split load even
# if the train download above was skipped.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, download=True, transform=mnist_transform),
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
def ReLU(x):
    """Elementwise rectified linear unit: returns max(x, 0) for every entry."""
    return np.maximum(x, 0)


def relu_layer(x, w, b):
    """Dense layer followed by ReLU, i.e. ReLU(np.dot(w, x) + b)."""
    pre_activation = np.dot(w, x) + b
    return ReLU(pre_activation)

In [None]:
def initialize_mlp(sizes, key, scale=1e-2):
    """Randomly initialize the weights and biases of a fully connected network.

    Args:
        sizes: layer widths, input width first (e.g. ``[784, 512, 10]``).
        key: JAX PRNG key; it is split into ``len(sizes)`` sub-keys, one per layer
            (the last sub-key is unused).
        scale: multiplier applied to the standard-normal draws for both weights
            and biases. The default ``1e-2`` matches the previous hard-coded value.

    Returns:
        A list with one ``(w, b)`` tuple per consecutive width pair ``(m, n)``.
        ``w`` has shape ``(n, m)``, so ``np.dot(w, x)`` maps a length-``m`` vector
        to length ``n``. ``b`` has shape ``(n,)``.
    """
    # Split count kept at len(sizes) (one more than needed) so the generated
    # parameters stay identical to earlier runs of this notebook.
    keys = random.split(key, len(sizes))

    def initialize_layer(m, n, layer_key):
        # Independent sub-keys for the weight matrix and the bias vector.
        w_key, b_key = random.split(layer_key)
        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))

    return [initialize_layer(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]

# Flattened 28x28 input image -> two hidden layers of 512 units -> 10 output classes.
layer_sizes = [28 * 28, 512, 512, 10]
params = initialize_mlp(sizes=layer_sizes, key=key)

In [None]:
def relu_layer(x, w, b):
    """A single fully connected layer followed by ReLU.

    Computes ``max(0, np.dot(w, x) + b)``.

    Args:
        x: input vector of length ``m`` (one example; batching is done with ``vmap``).
        w: weight matrix of shape ``(n, m)``, the layout ``initialize_mlp`` produces.
        b: bias vector of shape ``(n,)``.

    Returns:
        The length-``n`` activation vector.
    """
    # Weights are stored (out_features, in_features), so the matrix goes on the
    # left. np.dot(x, w) would raise a shape mismatch for any non-square layer.
    return np.maximum(0, np.dot(w, x) + b)

In [None]:
def forward_pass(params, features):
    """Compute log-class-probabilities for a single example.

    Args:
        params: list of ``(w, b)`` pairs with ``w`` shaped ``(out, in)``, as
            produced by ``initialize_mlp``.
        features: one input example. It is flattened first, so either a flat
            vector or an image-shaped array with as many elements as the first
            layer's input width is accepted.

    Returns:
        Log-softmax of the final layer's logits: a 1-D vector whose
        exponentials sum to 1.
    """
    activations = np.ravel(features)
    for w, b in params[:-1]:
        # Hidden layer ReLU(w @ x + b), inlined so this function does not depend
        # on the separately defined relu_layer helper.
        activations = np.maximum(0, np.dot(w, activations) + b)
    w, b = params[-1]
    logits = np.dot(w, activations) + b
    # logits is 1-D for a single example, so normalize over the last axis.
    # axis=1 would be out of bounds here.
    return logits - logsumexp(logits, axis=-1, keepdims=True)


# Batched version: params are shared (None); features are mapped over axis 0.
batch_forward = vmap(forward_pass, in_axes=(None, 0), out_axes=0)

In [None]:
def one_hot(x, k, dtype=np.float32):
    """Encode a 1-D array of integer labels as a ``(len(x), k)`` indicator matrix.

    Row ``i`` holds 1 in column ``x[i]`` and 0 elsewhere (all zeros when
    ``x[i]`` is outside ``[0, k)``), cast to ``dtype``.
    """
    class_ids = np.arange(k)
    matches = x[:, None] == class_ids
    return np.array(matches, dtype)