In [1]:
import jax.numpy as jnp
import jax 
from jax import jit
import matplotlib.pyplot as plt
import timeit
from jax import lax

'''for dataloading'''
import torch
import torchvision

In [22]:
def init_Conv2D(key, in_channels, out_channels, kernel_shape):
    key1, key2 = jax.random.split(key)
    kernel = jax.random.normal(key1, (out_channels, in_channels) + kernel_shape)
    bias = jax.random.normal(key2, (1, out_channels, 1, 1))

    return dict(kernel=kernel, bias=bias)   

def init_fc(key, input_dim, output_dim):
    key1, key2 = jax.random.split(key)
    weights = jax.random.normal(key1, (input_dim, output_dim))
    bias = jax.random.normal(key2, (output_dim,))

    return dict(weights=weights, bias=bias)


In [23]:
def forward_Conv2D(params, x):
    return lax.conv(x, params['kernel'], (1, 1), 'VALID') + params['bias']

def forward_fc(params, x):
    return jnp.dot(x, params['weights']) + params['bias']

In [24]:

def init_mnist_convnet(key):
    layers = []
    key, *subkey = jax.random.split(key, num=6)
    layers.append(init_Conv2D(subkey[0], 1, 1, (1, 1))) #28x28x1 -> 21x21x4
    layers.append(init_Conv2D(subkey[1], 1, 1, (1, 1))) #21x21x4 -> 14x14x
    layers.append(init_fc(subkey[2], 28*28, 200)) #14x14x8 -> 10
    layers.append(init_fc(subkey[3], 200, 10)) #10 -> 10
    return layers


def forward_mnist_convnet(params, x):
    x = jnp.reshape(x, (x.shape[0], -1))
    x = forward_fc(params[2], x)
    x = jax.nn.relu(x)
    x = forward_fc(params[3], x)
    x = jax.nn.log_softmax(x, axis=1)
    return x

def cross_entropy_loss(params, x, y):
    pred = forward_mnist_convnet(params, x)
    return -jnp.mean(jnp.sum(y * pred, axis=1))

def mse_loss(params, x, y):
    return jnp.mean((forward_mnist_convnet(params, x) - y) ** 2) 

@jit
def update(params, x: jnp.ndarray, y: jnp.ndarray, lr: float):

    grads = jax.grad(cross_entropy_loss)(params, x, y)    
    return jax.tree_util.tree_map(
       lambda p, g: p - lr * g, params, grads
    )

In [25]:
#load minst
train_data = torchvision.datasets.MNIST(
    root='.',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

test_data = torchvision.datasets.MNIST(
    root='.',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

In [26]:
batch_size = 100

train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=False
)

key = jax.random.PRNGKey(0)
params = init_mnist_convnet(key)
lr = 0.01
num_epochs = 20

for epoch in range(num_epochs):
    loss_sum = 0
    for i, (x, y) in enumerate(train_loader):
        x = x.numpy()
        y = y.numpy()
        x = x.reshape(-1, 1, 28, 28)
        y = jax.nn.one_hot(y, 10)
        loss = cross_entropy_loss(params, x, y)
        params = update(params, x, y, lr)
        loss_sum += loss

    correct = 0
    total = 0
    for x, y in test_loader:
        x = x.numpy()
        y = y.numpy()
        x = x.reshape(-1, 1, 28, 28)
        y_pred = forward_mnist_convnet(params, x)
        y_pred = jnp.argmax(y_pred, axis=1)
        correct += (y_pred == y).sum()
        total += y.shape[0]
    print(f'Epoch: {epoch}, Train loss: {loss_sum / i}, Test accuracy: {correct / total}')

Epoch: 0, Train loss: 16.53502655029297, Test accuracy: 0.789900004863739


KeyboardInterrupt: 