In [1]:
import numpy

In [2]:
from jax import random as jrng
from jax import numpy as jnp
import jax

In [3]:
from functools import partial

In [4]:
from layers import Linear, ReLU, Softmax
from model import Model
from optimizers import SGD, Adam
from functionals import cross_entropy

In [5]:
from pl_bolts.datamodules import MNISTDataModule

In [6]:
# Root PRNG key for the whole notebook, derived from a fixed integer seed.
rng = jax.random.PRNGKey(1234)

In [34]:
# Fully connected classifier: 784 -> 512 -> 512 -> 10, with a ReLU after each
# hidden Linear layer and a Softmax on the output; trained with cross_entropy.
mymodel = Model(
    rng,
    [
        Linear(784, 512),
        ReLU(),
        Linear(512, 512),
        ReLU(),
        Linear(512, 10),
        Softmax(),
    ],
    loss=cross_entropy,
)

In [35]:
# Adam optimiser built from mymodel with learning rate 1e-5; the training loop
# passes each batch's gradients to optim.step().
optim = Adam(mymodel, lr=1e-5)

In [36]:
# MNIST data module using ./mnist/ as its data directory.
data = MNISTDataModule('./mnist/')
# NOTE(review): presumably downloads the dataset into ./mnist/ when it is not
# already present — confirm against the pl_bolts documentation.
data.prepare_data()

In [37]:
# Batch iterators used by the training loop: 256 samples per training batch,
# 1024 samples per validation batch.
train_loader = data.train_dataloader(batch_size=256)
val_loader = data.val_dataloader(batch_size=1024)

In [38]:
n_epochs = 100

# Exponentially smoothed training loss. +inf is the "no batch seen yet" marker.
# numpy.inf is used instead of numpy.Inf: the capitalised alias was removed in
# NumPy 2.0 and raises AttributeError there, while numpy.inf exists everywhere.
loss_running = numpy.inf

for ei in range(n_epochs):
    # ---- training pass: one optimiser step per batch ----
    for x, y in train_loader:
        # Pull the batch out as NumPy arrays and flatten every sample to one row.
        x_, y_ = x.data.numpy(), y.data.numpy()
        x_ = x_.reshape(x_.shape[0], -1)

        loss, grad = mymodel.loss_grad(x_, y_)
        optim.step(grad)

        if loss_running == numpy.inf:
            # First batch: seed the running average with the raw loss.
            loss_running = loss
        else:
            # Afterwards: keep 95% of the old average, blend in 5% of the new loss.
            loss_running = 0.95 * loss_running + 0.05 * loss

    # ---- validation pass: count correct arg-max predictions ----
    n_corrects = 0
    n_all = 0
    for x, y in val_loader:
        x_, y_ = x.data.numpy(), y.data.numpy()
        x_ = x_.reshape(x_.shape[0], -1)

        # Predicted class = index of the largest model output along the last axis.
        yp = jnp.argmax(mymodel.forward(x_), -1)

        n_all = n_all + len(y_)
        n_corrects = n_corrects + jnp.sum(y_ == yp)

    print(F'epoch {ei+1} loss {loss_running} val acc {n_corrects/n_all}')

epoch 1 loss 7.421951770782471 val acc 0.73388671875
epoch 2 loss 6.81615686416626 val acc 0.8203125
epoch 3 loss 6.505246639251709 val acc 0.853515625
epoch 4 loss 6.368214130401611 val acc 0.878662109375
epoch 5 loss 6.282380104064941 val acc 0.891357421875
epoch 6 loss 6.232450485229492 val acc 0.9033203125
epoch 7 loss 6.189967155456543 val acc 0.90625
epoch 8 loss 6.150373458862305 val acc 0.91015625
epoch 9 loss 6.114304542541504 val acc 0.91650390625
epoch 10 loss 6.085427761077881 val acc 0.921142578125
epoch 11 loss 6.068504810333252 val acc 0.922119140625
epoch 12 loss 6.0504255294799805 val acc 0.92626953125
epoch 13 loss 6.030655860900879 val acc 0.927490234375
epoch 14 loss 6.018169403076172 val acc 0.93115234375
epoch 15 loss 6.001720428466797 val acc 0.93359375
epoch 16 loss 5.987714767456055 val acc 0.93603515625
epoch 17 loss 5.9766845703125 val acc 0.9365234375
epoch 18 loss 5.969396114349365 val acc 0.9384765625
epoch 19 loss 5.956642150878906 val acc 0.939697265625


KeyboardInterrupt: 

In [12]:
# Ad-hoc inspection: labels of the batch last assigned to y_ by the training cell.
# NOTE(review): this cell's execution count (12) is lower than the training
# cell's (38), so the output shown likely comes from an earlier run — re-run in order.
y_

array([9, 0, 7, ..., 7, 5, 4])

In [16]:
# Ad-hoc inspection: predicted class index for every row of x_, using the same
# arg-max expression as the validation loop.
jnp.argmax(mymodel.forward(x_), -1)

DeviceArray([9, 0, 9, ..., 7, 8, 4], dtype=int32)

In [20]:
# Ad-hoc inspection: shape of the flattened batch currently held in x_.
x_.shape

(1024, 784)

In [21]:
# Ad-hoc inspection: raw values of the flattened batch currently held in x_.
x_

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)