In [1]:
import numpy

In [2]:
from jax import random as jrng
from jax import numpy as jnp
import jax

In [3]:
from functools import partial

In [4]:
from layers import Linear, ReLU, Softmax
from model import Model
from optimizers import SGD, Adam
from functionals import cross_entropy

In [5]:
from pl_bolts.datamodules import MNISTDataModule

In [6]:
# Root JAX PRNG key with a fixed seed; it is handed to Model below
# (presumably for parameter initialisation), so runs are repeatable.
rng = jrng.PRNGKey(seed=1234)

In [7]:
# Fully connected classifier: 784 -> 512 -> 512 -> 10 with ReLU between
# hidden layers and a Softmax on the output.
# NOTE(review): the network already ends in Softmax; confirm that
# `cross_entropy` expects probabilities rather than logits, otherwise the
# softmax would be applied twice.
mymodel = Model(
    rng,
    [
        Linear(784, 512),
        ReLU(),
        Linear(512, 512),
        ReLU(),
        Linear(512, 10),
        Softmax(),
    ],
    loss=cross_entropy,
)

In [8]:
# Adam optimizer bound to the model; learning rate 1e-5.
# NOTE(review): this learning rate is small for Adam; confirm it is intentional.
optim = Adam(mymodel, lr=0.00001)

In [9]:
# MNIST data module rooted at ./mnist/.
data = MNISTDataModule(data_dir='./mnist/')
# Presumably downloads the dataset into data_dir if it is missing.
data.prepare_data()

In [10]:
# Training uses batches of 256. Validation only runs forward passes,
# so it uses larger batches of 1024. Calls run left to right: train first.
train_loader, val_loader = (
    data.train_dataloader(batch_size=256),
    data.val_dataloader(batch_size=1024),
)

In [None]:
# Train for `n_epochs`. After each epoch, print an exponentially smoothed
# training loss and the accuracy on the validation loader.
n_epochs = 100

# Running (EMA) training loss; None until the first batch has been seen.
# (numpy.Inf, previously used as this sentinel, no longer exists in NumPy 2.0.)
loss_running = None

for ei in range(n_epochs):
    # ---- training pass ----
    mymodel.train()
    for x, y in train_loader:
        # Batches expose `.data.numpy()` (torch tensors); flatten each image
        # to a vector to match the Linear(784, ...) input layer.
        x_, y_ = x.data.numpy(), y.data.numpy()
        x_ = x_.reshape(x_.shape[0], -1)

        loss, grad = mymodel.loss_grad(x_, y_)
        optim.step(grad)

        if loss_running is None:
            loss_running = loss
        else:
            # EMA: 95% previous running loss, 5% current batch loss.
            loss_running = 0.95 * loss_running + 0.05 * loss

    # ---- validation pass ----
    mymodel.eval()
    n_corrects = 0
    n_all = 0
    # x_ / y_ are deliberately left at module scope: later inspection cells
    # read the last batch processed here.
    for x, y in val_loader:
        x_, y_ = x.data.numpy(), y.data.numpy()
        x_ = x_.reshape(x_.shape[0], -1)

        # Predicted class = index of the largest model output in each row.
        yp = jnp.argmax(mymodel.forward(x_), -1)

        n_all = n_all + len(y_)
        n_corrects = n_corrects + jnp.sum(y_ == yp)

    # An empty validation loader would otherwise raise ZeroDivisionError.
    val_acc = n_corrects / n_all if n_all else float('nan')
    print(F'epoch {ei+1} loss {loss_running} val acc {val_acc}')

epoch 1 loss 6.501333236694336 val acc 0.84814453125
epoch 2 loss 6.3682732582092285 val acc 0.872314453125
epoch 3 loss 6.285041809082031 val acc 0.88427734375
epoch 4 loss 6.2357378005981445 val acc 0.891357421875
epoch 5 loss 6.187690258026123 val acc 0.8974609375
epoch 6 loss 6.138576030731201 val acc 0.904296875
epoch 7 loss 6.127128601074219 val acc 0.90625
epoch 8 loss 6.092522144317627 val acc 0.9091796875
epoch 9 loss 6.082522869110107 val acc 0.911865234375
epoch 10 loss 6.05191707611084 val acc 0.914794921875
epoch 11 loss 6.02699089050293 val acc 0.9169921875
epoch 12 loss 6.021353721618652 val acc 0.922119140625
epoch 13 loss 5.99982213973999 val acc 0.924072265625
epoch 14 loss 5.986532211303711 val acc 0.92578125
epoch 15 loss 5.978908538818359 val acc 0.927490234375
epoch 16 loss 5.968048095703125 val acc 0.928466796875
epoch 17 loss 5.955752849578857 val acc 0.928955078125
epoch 18 loss 5.94883918762207 val acc 0.93017578125
epoch 19 loss 5.932291507720947 val acc 0.93

In [12]:
# Inspect labels left over from the training cell's loops
# (whichever batch was processed last).
y_

array([9, 0, 7, ..., 7, 5, 4])

In [16]:
# Model's predicted classes for the leftover batch `x_`, for comparison with `y_`.
jnp.argmax(mymodel.forward(x_), axis=-1)

DeviceArray([9, 0, 9, ..., 7, 8, 4], dtype=int32)

In [20]:
# Shape of the leftover input batch: (batch size, flattened pixels).
numpy.shape(x_)

(1024, 784)

In [21]:
# Raw flattened inputs of the same leftover batch.
x_

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)