In [3]:
import numpy as np

from torchvision.datasets import MNIST
import jax
import jax.numpy as jnp

from mlp import MLP, CCE, accuracy, update, train, validate
from dataset import DataLoader

In [4]:
mlp = MLP()
params = mlp.init_layer(layer_widths=[512, 256, 256])
jax.tree.map(lambda x: x.shape, params)

{'bias': [(512,), (256,), (256,), (10,)],
 'weights': [(784, 512), (512, 256), (256, 256), (256, 10)]}

In [5]:
key = jax.random.PRNGKey(seed=0)
temp_image = jax.random.normal(key, shape=(11, 784))

temp_y_indices = np.random.randint(0, 10, (11, ))
temp_y = jnp.zeros((11, 10))
temp_y = temp_y.at[jnp.arange(temp_y_indices.shape[0]), temp_y_indices].set(1.0)

temp_image.shape, temp_y.shape, temp_y_indices

((11, 784), (11, 10), array([0, 7, 6, 3, 4, 4, 6, 9, 6, 1, 5]))

In [6]:
outs = MLP.forward(params, temp_image)
outs.shape

(11, 10)

In [7]:
loss, new_params = update(params, temp_image, temp_y)
loss, new_params.keys()

(Array(0.23043767, dtype=float32), dict_keys(['bias', 'weights']))

In [8]:
train_dataset = MNIST(root='train_mnist', train=True, download=True)
train_dataloader = DataLoader(train_dataset.data, train_dataset.targets, batch_size=6000)

In [11]:
jax.print(mlp)

<mlp.MLP at 0x720c0235b5b0>

In [9]:
loss = train(mlp, train_dataloader, epochs=100)
loss

Epoch: 100%|██████████| 100/100 [00:16<00:00,  5.94it/s, train_loss=0.0652]


[0.23033944368362427,
 0.23019034415483475,
 0.2300218552350998,
 0.22984422743320465,
 0.2296554610133171,
 0.2294418841600418,
 0.22919444143772125,
 0.2289143681526184,
 0.2286090672016144,
 0.22828963547945022,
 0.22795817106962205,
 0.22760650515556335,
 0.22722183018922806,
 0.22678298205137254,
 0.22624792009592057,
 0.22562865614891053,
 0.22498219609260559,
 0.22430087476968766,
 0.22357094138860703,
 0.22278331518173217,
 0.22192663252353667,
 0.22098679691553116,
 0.21990866363048553,
 0.2185285657644272,
 0.21701250821352006,
 0.21539339870214463,
 0.21361469179391862,
 0.21164893805980683,
 0.20947340577840806,
 0.20706476122140885,
 0.204399211704731,
 0.20145314931869507,
 0.19820394068956376,
 0.19462998062372208,
 0.1907206282019615,
 0.18647219538688659,
 0.18189695477485657,
 0.17702219039201736,
 0.1718883380293846,
 0.16655167639255525,
 0.161080938577652,
 0.15555654019117354,
 0.1500645250082016,
 0.14469461888074875,
 0.13952914923429488,
 0.13463062196969985,
 

In [10]:
val_dataloader = DataLoader(train_dataset.data,
                            train_dataset.targets,
                            return_labels=True,
                            batch_size=6000)
val_acc = validate(mlp, val_dataloader)
val_acc

Validating


0.8051166534423828

In [None]:
test_dataset = MNIST(root="test_mnist", train=False, download=True)
test_dataloader = DataLoader(test_dataset.data,
                             test_dataset.targets,
                             return_labels=True,
                             batch_size=100)

In [None]:
test_acc = validate(mlp, test_dataloader)
test_acc