In [1]:
%matplotlib widget

In [2]:
import numpy

In [3]:
from tqdm.notebook import trange, tqdm

In [4]:
from jax import random as jrng
from jax import numpy as jnp
import jax

In [5]:
from functools import partial

In [6]:
from layers import Linear, Conv2d, SpatialPool2d, ReLU, Tanh, Softmax, MaxPool2d
from model import Model
from optimizers import SGD, Adam
from functionals import cross_entropy, weight_decay, clip_norm
from utils import flatten_dict, apply_dict

In [7]:
from pl_bolts.datamodules import MNISTDataModule, CIFAR10DataModule

In [8]:
# Fixed seed: the key below is what Model(...) receives, so reruns start from
# the same parameter initialisation.
SEED = 1234
rng = jrng.PRNGKey(SEED)

In [9]:
# Select the dataset in one place. Previously an MNISTDataModule was built and
# then immediately overwritten by CIFAR10DataModule, so only CIFAR-10 was ever
# used. The model cell's first layer is Conv2d(3,3,3,256); presumably the third
# argument is the input-channel count (the commented MNIST variant uses 1 there)
# — TODO confirm. Switching to 'mnist' therefore also needs a matching model.
DATASET = 'cifar10'
_DATAMODULES = {
    'mnist': (MNISTDataModule, './mnist/'),
    'cifar10': (CIFAR10DataModule, './cifar-10/'),
}
datamodule_cls, data_dir = _DATAMODULES[DATASET]
data = datamodule_cls(data_dir)
data.prepare_data()

Files already downloaded and verified
Files already downloaded and verified


In [10]:
# Validation only runs forward passes, so it can use a larger batch than training.
TRAIN_BATCH_SIZE = 256
VAL_BATCH_SIZE = 1024

train_loader = data.train_dataloader(batch_size=TRAIN_BATCH_SIZE)
val_loader = data.val_dataloader(batch_size=VAL_BATCH_SIZE)

In [11]:
# Three conv/ReLU stages (the first two followed by MaxPool2d), then spatial
# pooling, a Linear(256, 10) head and Softmax. Argument order of Conv2d is
# presumably (kh, kw, c_in, c_out) — TODO confirm against layers.py.
conv_net_layers = [
    Conv2d(3, 3, 3, 256, mode='SAME'), ReLU(), MaxPool2d(2, 2),
    Conv2d(3, 3, 256, 256, mode='SAME'), ReLU(), MaxPool2d(2, 2),
    Conv2d(3, 3, 256, 256, mode='SAME'), ReLU(),
    SpatialPool2d(),
    Linear(256, 10),
    Softmax(),
]
mymodel = Model(rng, conv_net_layers, loss=[(cross_entropy, 1.)])

# Alternative stack kept for reference (1 input channel, no pooling):
#   [Conv2d(7,7,1,256), ReLU(), Conv2d(7,7,256,256), ReLU(),
#    Conv2d(7,7,256,256), ReLU(), SpatialPool2d(), Linear(256,10), Softmax()]
# Optional extra loss term: loss=[(cross_entropy, 1.), (weight_decay, 1e-5)]

In [12]:
# Adam optimiser over the model's parameters.
LEARNING_RATE = 1e-5
optim = Adam(mymodel, lr=LEARNING_RATE)

In [13]:
# Number of accelerator devices JAX can see (not used further in this notebook yet).
n_devices = jax.device_count()

In [None]:
n_epochs = 30
EMA_DECAY = 0.95  # smoothing factor for the running loss shown in the progress bar
CLIP_THR = 1.     # gradient-norm threshold passed to clip_norm


def evaluate_accuracy(model, loader):
    """Return the fraction of samples in ``loader`` whose argmax prediction equals the label.

    Returns NaN when ``loader`` yields no samples (instead of dividing by zero).
    """
    n_corrects = 0
    n_all = 0
    for x, y in loader:
        x_, y_ = x.data.numpy(), y.data.numpy()
        yp = jnp.argmax(model.forward(x_), -1)
        n_all += len(y_)
        # Convert to a Python int so the accumulator is not a growing chain of
        # device arrays.
        n_corrects += int(jnp.sum(y_.squeeze() == yp.squeeze()))
    if n_all == 0:
        return float('nan')
    return n_corrects / n_all


# `None` marks "no loss seen yet". The previous sentinel `numpy.Inf` was removed
# in NumPy 2.0 (AttributeError) and would be ambiguous if a loss itself became inf.
# The running loss deliberately carries over between epochs.
loss_running = None

for ei in range(n_epochs):
    tloader = tqdm(train_loader, desc=f'epoch {ei+1}/{n_epochs}')
    for x, y in tloader:
        x_, y_ = x.data.numpy(), y.data.numpy()
        loss, grad = mymodel.loss_grad(x_, y_)
        grad = clip_norm(grad, thr=CLIP_THR)
        optim.step(grad)

        # assumes loss_grad returns a scalar (size-1) loss — TODO confirm
        loss = float(loss)
        if loss_running is None:
            loss_running = loss
        else:
            loss_running = EMA_DECAY * loss_running + (1. - EMA_DECAY) * loss

        tloader.set_postfix(loss=loss_running)

    val_acc = evaluate_accuracy(mymodel, val_loader)
    print(F'epoch {ei+1} loss {loss_running} val acc {val_acc}')

HBox(children=(FloatProgress(value=0.0, max=175.0), HTML(value='')))



In [None]:
from matplotlib import pyplot as plot
from matplotlib import cm

In [None]:
# Kernels of the first layer, looked up by layer name in the model's params dict.
first_layer_name = mymodel.layers[0].name
weight = mymodel.params[first_layer_name]['weight']

In [19]:
def tile_filters(weight, n_filters, n_cols, channel=0):
    """Tile 2-D kernels into one grayscale image, separated by 1-pixel gaps.

    ``weight[i][channel]`` is taken as the 2-D kernel of filter ``i``, so its
    shape is ``(weight.shape[-2], weight.shape[-1])``. This presumes the filter
    axis comes first in the weight layout — TODO confirm against Conv2d.

    Returns a float array of shape
    ``(ceil(n_filters / n_cols) * (kh + 1), n_cols * (kw + 1))``.
    """
    weight = numpy.asarray(weight)
    kh, kw = weight.shape[-2], weight.shape[-1]
    # Ceil so that a partially filled last row still gets space. The original
    # used floor (fn // w), which breaks whenever n_filters is not a perfect square.
    n_rows = -(-n_filters // n_cols)
    canvas = numpy.zeros((n_rows * (kh + 1), n_cols * (kw + 1)))
    for fid in range(n_filters):
        ri, ci = divmod(fid, n_cols)
        top, left = ri * (kh + 1), ci * (kw + 1)
        canvas[top:top + kh, left:left + kw] = weight[fid][channel]
    return canvas


fn = 256                             # number of filters to display
w = int(numpy.ceil(numpy.sqrt(fn)))  # grid columns
h = int(numpy.ceil(fn / w))          # grid rows (enough for a partial last row)
filter_canvas = tile_filters(weight, fn, w)

NameError: name 'weight' is not defined

In [52]:
# Show the tiled first-layer filters on a fresh figure, without axis ticks.
fig, ax = plot.subplots()
ax.imshow(filter_canvas, cmap=cm.gray)
ax.axis(False)
plot.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [18]:
h  # quick inspection of the filter-grid size computed in the filter-canvas cell

16.0