In [1]:
import numpy

In [2]:
from jax import random as jrng
from jax import numpy as jnp
import jax

In [3]:
from functools import partial

In [4]:
from layers import Linear, Conv2d, SpatialPool2d, ReLU, Softmax
from model import Model
from optimizers import SGD, Adam
from functionals import cross_entropy

In [5]:
from pl_bolts.datamodules import MNISTDataModule

In [6]:
rng = jrng.PRNGKey(1234)

In [13]:
mymodel = Model(rng, [Conv2d(3,3,1,128), ReLU(), 
                      Conv2d(3,3,128,128), ReLU(), 
                      SpatialPool2d(), Linear(128,10), 
                      Softmax()], loss=cross_entropy)

In [14]:
optim = Adam(mymodel, lr=1e-5)

In [15]:
data = MNISTDataModule('./mnist/')
data.prepare_data()

In [16]:
train_loader = data.train_dataloader(batch_size=256)
val_loader = data.val_dataloader(batch_size=1024)

In [None]:
n_epochs = 100

loss_running = numpy.Inf

for ei in range(n_epochs):
    for x, y in train_loader:
        x_, y_ = x.data.numpy(), y.data.numpy()
        loss, grad = mymodel.loss_grad(x_, y_)
        optim.step(grad)
        
        if loss_running == numpy.Inf:
            loss_running = loss
        else:
            loss_running = 0.95 * loss_running + 0.05 * loss
            
    n_corrects = 0
    n_all = 0
    for x, y in val_loader:
        x_, y_ = x.data.numpy(), y.data.numpy()
        
        yp = jnp.argmax(mymodel.forward(x_), -1)
        
        n_all = n_all + len(y_)
        n_corrects = n_corrects + jnp.sum(y_.squeeze() == yp.squeeze())

    print(F'epoch {ei+1} loss {loss_running} val acc {n_corrects/n_all}')

epoch 1 loss 8.000088691711426 val acc 0.16455078125
epoch 2 loss 7.625463485717773 val acc 0.296875
epoch 3 loss 7.418938636779785 val acc 0.4072265625
epoch 4 loss 7.274871349334717 val acc 0.499755859375
epoch 5 loss 7.145114421844482 val acc 0.568603515625
epoch 6 loss 7.0391130447387695 val acc 0.618896484375
epoch 7 loss 6.958877086639404 val acc 0.65673828125
epoch 8 loss 6.875216484069824 val acc 0.683349609375
epoch 9 loss 6.80990743637085 val acc 0.7080078125
epoch 10 loss 6.751032829284668 val acc 0.728271484375
epoch 11 loss 6.700778007507324 val acc 0.747314453125
epoch 12 loss 6.663313865661621 val acc 0.760009765625
epoch 13 loss 6.621245861053467 val acc 0.77294921875
epoch 14 loss 6.5793867111206055 val acc 0.78369140625
epoch 15 loss 6.552623748779297 val acc 0.79736328125
epoch 16 loss 6.528267860412598 val acc 0.806396484375
epoch 17 loss 6.504374980926514 val acc 0.820068359375
epoch 18 loss 6.468982696533203 val acc 0.82373046875
epoch 19 loss 6.460025787353516 va

In [12]:
n_corrects, n_all

(DeviceArray(663, dtype=int32), 4096)

In [19]:
(y_.squeeze() == yp.squeeze()).shape

(1024,)

In [20]:
lax.conv(x_, jnp.transpose(mymodel.layers[0].weight,[1,0,2,3]), (1,1), 'VALID').shape

(256, 128, 26, 26)

In [20]:
x_.shape

(1024, 784)

In [21]:
x_

array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)