In [1]:
%matplotlib widget

In [2]:
import numpy

In [3]:
from jax import random as jrng
from jax import numpy as jnp
import jax

In [4]:
from functools import partial

In [5]:
from layers import Linear, Conv2d, SpatialPool2d, ReLU, Tanh, Softmax
from model import Model
from optimizers import SGD, Adam
from functionals import cross_entropy, weight_decay, clip_norm
from utils import flatten_dict, apply_dict

In [6]:
from pl_bolts.datamodules import MNISTDataModule

In [7]:
# Root JAX PRNG key for the notebook, derived from a fixed seed so runs are repeatable.
rng = jrng.PRNGKey(seed=1234)

In [25]:
# Network: three (Conv2d -> ReLU) stages, then SpatialPool2d, a Linear(256, 10)
# head and a Softmax, trained with cross-entropy at weight 1.0.
# NOTE(review): Conv2d arguments are presumably
# (kernel_h, kernel_w, in_channels, out_channels) — confirm against layers.Conv2d.
mymodel = Model(
    rng,
    [
        Conv2d(7, 7, 1, 256), ReLU(),
        Conv2d(7, 7, 256, 256), ReLU(),
        Conv2d(7, 7, 256, 256), ReLU(),
        SpatialPool2d(),
        Linear(256, 10),
        Softmax(),
    ],
    loss=[(cross_entropy, 1.0)],
)
# A weight-decay term was tried previously by appending (weight_decay, 1e-5)
# to the `loss` list; it is currently disabled.

In [26]:
# Adam optimizer (project-local `optimizers.Adam`) constructed for `mymodel`
# with a learning rate of 1e-5; the training loop feeds it gradients via `optim.step`.
optim = Adam(mymodel, lr=1e-5)

In [27]:
# MNIST data module from pl_bolts, pointed at the local ./mnist/ directory.
data = MNISTDataModule('./mnist/')
# NOTE(review): presumably downloads the dataset into that directory when it is
# missing — confirm against the installed pl_bolts version.
data.prepare_data()

In [28]:
# Batch iterators used by the training loop: 256 samples per training batch,
# 1024 samples per validation batch.
train_loader = data.train_dataloader(batch_size=256)
val_loader = data.val_dataloader(batch_size=1024)

In [29]:
# Train for `n_epochs` passes over `train_loader`, reporting after every epoch
# the smoothed training loss and the accuracy on `val_loader`.
n_epochs = 30

# Exponentially smoothed training loss; infinity marks "no batch seen yet".
# `numpy.inf` is the canonical spelling — the `numpy.Inf` alias was removed in
# NumPy 2.0 and raises AttributeError there.
loss_running = numpy.inf

for ei in range(n_epochs):
    # ---- training pass ----
    for x, y in train_loader:
        # The loader's batches are converted to plain numpy arrays before being
        # handed to the model (presumably torch tensors — confirm).
        x_, y_ = x.data.numpy(), y.data.numpy()
        loss, grad = mymodel.loss_grad(x_, y_)
        # NOTE(review): presumably bounds the gradient norm at 1 — confirm in functionals.clip_norm.
        grad = clip_norm(grad, thr=1.)
        optim.step(grad)

        # Seed the running average with the first loss, afterwards blend
        # 95% of the previous value with 5% of the new loss.
        if loss_running == numpy.inf:
            loss_running = loss
        else:
            loss_running = 0.95 * loss_running + 0.05 * loss

    # ---- validation pass ----
    n_corrects = 0
    n_all = 0
    for x, y in val_loader:
        x_, y_ = x.data.numpy(), y.data.numpy()

        # Predicted class = index of the largest model output along the last axis.
        yp = jnp.argmax(mymodel.forward(x_), -1)

        n_all = n_all + len(y_)
        n_corrects = n_corrects + jnp.sum(y_.squeeze() == yp.squeeze())

    print(F'epoch {ei+1} loss {loss_running} val acc {n_corrects/n_all}')

epoch 1 loss 6.013945579528809 val acc 0.96533203125
epoch 2 loss 5.871925354003906 val acc 0.975830078125
epoch 3 loss 5.807745456695557 val acc 0.97998046875
epoch 4 loss 5.775405406951904 val acc 0.982421875
epoch 5 loss 5.750779628753662 val acc 0.985107421875
epoch 6 loss 5.734845161437988 val acc 0.98681640625
epoch 7 loss 5.726753234863281 val acc 0.98681640625
epoch 8 loss 5.723051071166992 val acc 0.988525390625
epoch 9 loss 5.707808494567871 val acc 0.989013671875
epoch 10 loss 5.698883533477783 val acc 0.99072265625
epoch 11 loss 5.690040588378906 val acc 0.991455078125
epoch 12 loss 5.683486461639404 val acc 0.99072265625
epoch 13 loss 5.685825347900391 val acc 0.99072265625
epoch 14 loss 5.673786163330078 val acc 0.9921875
epoch 15 loss 5.674094200134277 val acc 0.9921875
epoch 16 loss 5.669402599334717 val acc 0.9921875
epoch 17 loss 5.670246601104736 val acc 0.992431640625
epoch 18 loss 5.66935920715332 val acc 0.992919921875
epoch 19 loss 5.658464431762695 val acc 0.992

In [30]:
from matplotlib import pyplot as plot
from matplotlib import cm

In [31]:
# Kernel tensor of the network's first layer, looked up in the model's
# parameter dict under that layer's name.
first_layer = mymodel.layers[0]
weight = mymodel.params[first_layer.name]['weight']
# (alternative access path tried earlier: mymodel.layers[0].weight)

In [32]:
# Tile the first `fn` first-layer kernels into one 2-D image (`filter_canvas`),
# leaving a one-pixel zero gap after each tile.
# NOTE(review): indexing `weight[fid][0]` assumes the leading axis enumerates
# filters and the second axis input channels (only channel 0 is shown) — confirm
# against the Conv2d parameter layout.
fn = 256

# Grid layout: `w` tiles per row, `h` rows. Ceil-division guarantees enough rows
# when `fn` is not a perfect square (floor division would drop the last filters).
w = int(numpy.ceil(numpy.sqrt(fn)))
h = int(numpy.ceil(fn / w))

# Height and width of a single kernel slice, i.e. of weight[fid][0].
kh, kw = weight.shape[-2], weight.shape[-1]

# Rows are indexed by the grid row (h of them, each kh + 1 pixels tall) and
# columns by the grid column (w of them, each kw + 1 pixels wide).
filter_canvas = numpy.zeros((h * (kh + 1), w * (kw + 1)))

for fid in range(fn):
    ri, ci = divmod(fid, w)  # grid row / grid column of this filter
    r0 = ri * (kh + 1)       # top-left pixel of the tile
    c0 = ci * (kw + 1)
    filter_canvas[r0:r0 + kh, c0:c0 + kw] = weight[fid][0]

In [33]:
# Show the tiled filter canvas as a grayscale image in a new figure,
# with the axis lines, ticks and labels switched off.
plot.figure()
plot.imshow(filter_canvas, cmap='gray')
plot.axis('off')
plot.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
# Debug inspection of the grid height used for the filter canvas.
# NOTE(review): the saved output (16.0, from execution In [18]) looks stale —
# `h` is cast to int where it is assigned; re-run or drop this cell.
h

16.0