In [1]:
import numpy

In [2]:
from jax import random as jrng
from jax import numpy as jnp
import jax

In [3]:
from functools import partial

In [4]:
from layers import Linear, Conv2d, SpatialPool2d, ReLU, Softmax
from model import Model
from optimizers import SGD, Adam
from functionals import cross_entropy, weight_decay

In [5]:
from pl_bolts.datamodules import MNISTDataModule

In [6]:
rng = jrng.PRNGKey(1234)

In [7]:
mymodel = Model(rng, [Conv2d(7,7,1,256), ReLU(), 
                      Conv2d(3,3,256,256), ReLU(), 
                      SpatialPool2d(), Linear(256,10), 
                      Softmax()], 
                loss=[(cross_entropy, 1.), (weight_decay, 1e-5)])

In [8]:
optim = Adam(mymodel, lr=1e-5)

In [9]:
data = MNISTDataModule('./mnist/')
data.prepare_data()

In [10]:
train_loader = data.train_dataloader(batch_size=256)
val_loader = data.val_dataloader(batch_size=1024)

In [None]:
n_epochs = 100

loss_running = numpy.Inf

for ei in range(n_epochs):
    for x, y in train_loader:
        x_, y_ = x.data.numpy(), y.data.numpy()
        loss, grad = mymodel.loss_grad(x_, y_)
        optim.step(grad)
        
        if loss_running == numpy.Inf:
            loss_running = loss
        else:
            loss_running = 0.95 * loss_running + 0.05 * loss
            
    n_corrects = 0
    n_all = 0
    for x, y in val_loader:
        x_, y_ = x.data.numpy(), y.data.numpy()
        
        yp = jnp.argmax(mymodel.forward(x_), -1)
        
        n_all = n_all + len(y_)
        n_corrects = n_corrects + jnp.sum(y_.squeeze() == yp.squeeze())

    print(F'epoch {ei+1} loss {loss_running} val acc {n_corrects/n_all}')



epoch 1 loss 9.073610305786133 val acc 0.208251953125
epoch 2 loss 7.786298751831055 val acc 0.4248046875
epoch 3 loss 7.407053470611572 val acc 0.568359375
epoch 4 loss 7.186738014221191 val acc 0.650634765625
epoch 5 loss 7.024512767791748 val acc 0.702880859375
epoch 6 loss 6.921138763427734 val acc 0.744140625
epoch 7 loss 6.8214335441589355 val acc 0.770263671875
epoch 8 loss 6.77091121673584 val acc 0.79638671875
epoch 9 loss 6.71079158782959 val acc 0.820068359375
epoch 10 loss 6.684136390686035 val acc 0.830322265625
epoch 11 loss 6.622817516326904 val acc 0.8427734375
epoch 12 loss 6.581671237945557 val acc 0.85791015625


In [None]:
n_corrects, n_all

In [19]:
(y_.squeeze() == yp.squeeze()).shape

(1024,)

In [20]:
lax.conv(x_, jnp.transpose(mymodel.layers[0].weight,[1,0,2,3]), (1,1), 'VALID').shape

(256, 128, 26, 26)

In [20]:
x_.shape

(1024, 784)

In [21]:
x_

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)