In [1]:
import jax.numpy as jnp
import jax 
from jax import jit
import matplotlib.pyplot as plt
import timeit
from jax import lax

'''for dataloading'''
import torch
import torchvision

In [2]:
def init_Conv2D(key, in_channels, out_channels, kernel_shape):
    key1, key2 = jax.random.split(key)
    kernel = jax.random.normal(key1, (out_channels, in_channels) + kernel_shape)
    bias = jax.random.normal(key2, (1, out_channels, 1, 1))

    return dict(kernel=kernel, bias=bias)   

def init_fc(key, input_dim, output_dim):
    key1, key2 = jax.random.split(key)
    weights = jax.random.normal(key1, (input_dim, output_dim))
    bias = jax.random.normal(key2, (output_dim,))

    return dict(weights=weights, bias=bias)


In [3]:
def forward_Conv2D(params, x):
    return lax.conv(x, params['kernel'], (1, 1), 'VALID') + params['bias']

def forward_fc(params, x):
    return jnp.dot(x, params['weights']) + params['bias']

In [4]:

def init_mnist_convnet(key):
    layers = []
    key, *subkey = jax.random.split(key, num=6)
    layers.append(init_Conv2D(subkey[0], 1, 1, (1, 1))) #28x28x1 -> 21x21x4
    layers.append(init_Conv2D(subkey[1], 1, 1, (1, 1))) #21x21x4 -> 14x14x
    layers.append(init_fc(subkey[2], 28*28, 200)) #14x14x8 -> 10
    layers.append(init_fc(subkey[3], 200, 10)) #10 -> 10
    return layers


def forward_mnist_convnet(params, x):
    x = jnp.reshape(x, (x.shape[0], -1))
    x = forward_fc(params[2], x)
    x = jax.nn.relu(x)
    x = forward_fc(params[3], x)
    x = jax.nn.log_softmax(x, axis=1)
    return x

def cross_entropy_loss(params, x, y):
    pred = forward_mnist_convnet(params, x)
    return -jnp.mean(jnp.sum(y * pred, axis=1))

def mse_loss(params, x, y):
    return jnp.mean((forward_mnist_convnet(params, x) - y) ** 2) 

@jit
def update(params, x: jnp.ndarray, y: jnp.ndarray, lr: float):

    grads = jax.grad(cross_entropy_loss)(params, x, y)    
    return jax.tree_util.tree_map(
       lambda p, g: p - lr * g, params, grads
    )

In [5]:
#load minst
train_data = torchvision.datasets.MNIST(
    root='.',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

test_data = torchvision.datasets.MNIST(
    root='.',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)

In [6]:
batch_size = 100

train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=batch_size,
    shuffle=False
)

key = jax.random.PRNGKey(0)
params = init_mnist_convnet(key)
lr = 0.01
num_epochs = 20

for epoch in range(num_epochs):
    loss_sum = 0
    for i, (x, y) in enumerate(train_loader):
        x = x.numpy()
        y = y.numpy()
        x = x.reshape(-1, 1, 28, 28)
        y = jax.nn.one_hot(y, 10)
        loss = cross_entropy_loss(params, x, y)
        params = update(params, x, y, lr)
        loss_sum += loss

    correct = 0
    total = 0
    for x, y in test_loader:
        x = x.numpy()
        y = y.numpy()
        x = x.reshape(-1, 1, 28, 28)
        y_pred = forward_mnist_convnet(params, x)
        y_pred = jnp.argmax(y_pred, axis=1)
        correct += (y_pred == y).sum()
        total += y.shape[0]
    print(f'Epoch: {epoch}, Train loss: {loss_sum / i}, Test accuracy: {correct / total}')

2024-04-24 14:24:08.676623: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


Epoch: 0, Train loss: 16.510223388671875, Test accuracy: 0.7916000485420227
Epoch: 1, Train loss: 6.261163711547852, Test accuracy: 0.8375000357627869
Epoch: 2, Train loss: 4.716853618621826, Test accuracy: 0.8541000485420227
Epoch: 3, Train loss: 3.888270378112793, Test accuracy: 0.86680006980896
Epoch: 4, Train loss: 3.342078924179077, Test accuracy: 0.8747000694274902
Epoch: 5, Train loss: 2.9581427574157715, Test accuracy: 0.8828000426292419
Epoch: 6, Train loss: 2.662611246109009, Test accuracy: 0.8861000537872314
Epoch: 7, Train loss: 2.4234538078308105, Test accuracy: 0.8891000151634216
Epoch: 8, Train loss: 2.2271769046783447, Test accuracy: 0.893000066280365
Epoch: 9, Train loss: 2.0588293075561523, Test accuracy: 0.8958000540733337
Epoch: 10, Train loss: 1.9198358058929443, Test accuracy: 0.8997000455856323
Epoch: 11, Train loss: 1.794801950454712, Test accuracy: 0.9003000259399414
Epoch: 12, Train loss: 1.690253734588623, Test accuracy: 0.9027000665664673
Epoch: 13, Train lo