In [1]:
# Interactive matplotlib backend (ipympl): figures below render as notebook widgets.
%matplotlib widget

In [2]:
import numpy

In [3]:
from tqdm.notebook import trange, tqdm

In [4]:
from jax import random as jrng
from jax import numpy as jnp
import jax

In [5]:
from functools import partial

In [6]:
from layers import Linear, Conv2d, SpatialPool2d, FakeResConv2d, MaxPool2d
from layers import ReLU, Tanh, Softmax, LeakyReLU
from layers import LayerNorm, BatchNorm2d, BatchNorm
from model import Model
from optimizers import SGD, Adam
from functionals import cross_entropy, weight_decay, clip_norm
from utils import flatten_dict, apply_dict, split_and_sample

In [7]:
from pl_bolts.datamodules import MNISTDataModule, CIFAR10DataModule

In [8]:
# Fixed seed so parameter initialisation is reproducible; rng is handed to Model below.
SEED = 1234
rng = jrng.PRNGKey(SEED)

In [9]:
# Choose the data set explicitly instead of constructing one DataModule and then
# silently overwriting it with another. The model defined below starts with
# Conv2d(3,3,3,64), which appears to expect 3 input channels (CIFAR-10 RGB), so
# switching to MNIST likely needs a model change too — NOTE(review): confirm.
DATASET = 'cifar10'  # 'cifar10' or 'mnist'

if DATASET == 'cifar10':
    data = CIFAR10DataModule('./cifar-10/')
elif DATASET == 'mnist':
    data = MNISTDataModule('./mnist/')
else:
    raise ValueError(f'unknown DATASET {DATASET!r}')
# Fetches the data set into the directory above if needed.
data.prepare_data()

Files already downloaded and verified
Files already downloaded and verified


In [10]:
# Same batch size for training and validation.
BATCH_SIZE = 64
train_loader, val_loader = (data.train_dataloader(batch_size=BATCH_SIZE),
                            data.val_dataloader(batch_size=BATCH_SIZE))

In [11]:
# Conv classifier: conv/BN/LeakyReLU stages with two 2x2 max-pools, a spatial pool,
# and an MLP head ending in a softmax over 10 classes.
# NOTE(review): conv args look like (kh, kw, c_in, c_out), judging by how channel
# counts chain into BatchNorm2d — confirm against layers.Conv2d / FakeResConv2d.
# Set WEIGHT_DECAY > 0 (e.g. 1e-5) to add a weight_decay term to the loss.
WEIGHT_DECAY = 0.

mymodel = Model(rng, [Conv2d(3,3,3,64, mode='SAME'), BatchNorm2d(64), LeakyReLU(),
                      FakeResConv2d(5,5,64,64, mode='SAME'), BatchNorm2d(64), LeakyReLU(),
                      FakeResConv2d(5,5,64,64, mode='SAME'), BatchNorm2d(64), LeakyReLU(),
                      MaxPool2d(2,2),
                      Conv2d(1,1,64,128, mode='SAME'), BatchNorm2d(128), LeakyReLU(),
                      # NOTE(review): unlike every other conv here, this one is not followed
                      # by BatchNorm2d/LeakyReLU, so (unless FakeResConv2d applies one
                      # internally) it feeds MaxPool2d and the next conv linearly — intentional?
                      FakeResConv2d(5,5,128,128, mode='SAME'),
                      MaxPool2d(2,2),
                      Conv2d(1,1,128,256, mode='SAME'), BatchNorm2d(256), LeakyReLU(),
                      # presumably pools over the spatial axes, leaving 256 features for Linear
                      SpatialPool2d(),
                      Linear(256,256),
                      BatchNorm(256), LeakyReLU(),
                      Linear(256,10),
                      Softmax()],
                loss=[(cross_entropy, 1.)]
                     + ([(weight_decay, WEIGHT_DECAY)] if WEIGHT_DECAY else []))

In [12]:
# Adam over all of mymodel's parameters.
LEARNING_RATE = 1e-2
optim = Adam(mymodel, lr=LEARNING_RATE)

In [13]:
# Number of devices visible to JAX (same as len(jax.devices())).
n_devices = jax.device_count()

In [None]:
n_epochs = 100
GRAD_CLIP_THR = 1.  # threshold passed to clip_norm before every optimizer step
LOSS_EMA = 0.95     # smoothing factor of the running loss shown in the progress bar


def train_one_epoch(model, optimizer, loader, loss_running=None):
    """Run one training epoch and return the updated running (EMA) loss.

    Args:
        model: network exposing train() and loss_grad(x, y) -> (loss, grad).
        optimizer: object whose step(grad) applies one update.
        loader: iterable of (x, y) batches with .data.numpy()
            (presumably torch tensors from the pl_bolts loaders).
        loss_running: running loss carried over from the previous epoch, or
            None to initialise it from the first batch of this epoch.

    Returns:
        The running loss as a Python float. Returns the incoming value
        unchanged if ``loader`` yields no batches.
    """
    model.train()
    progress = tqdm(loader)
    for x, y in progress:
        x_, y_ = x.data.numpy(), y.data.numpy()
        loss, grad = model.loss_grad(x_, y_)
        optimizer.step(clip_norm(grad, thr=GRAD_CLIP_THR))
        # Concrete scalar for the EMA and the progress bar. A None sentinel
        # replaces the old `== numpy.Inf` check; numpy.Inf was removed in NumPy 2.0.
        loss = float(loss)
        if loss_running is None:
            loss_running = loss
        else:
            loss_running = LOSS_EMA * loss_running + (1. - LOSS_EMA) * loss
        progress.set_postfix(loss=loss_running)
    return loss_running


def evaluate(model, loader):
    """Return the argmax accuracy of ``model`` on ``loader``.

    Returns nan when the loader yields no samples, instead of dividing by zero.
    """
    model.eval()
    n_corrects, n_all = 0, 0
    progress = tqdm(loader)
    for x, y in progress:
        x_, y_ = x.data.numpy(), y.data.numpy()
        yp = jnp.argmax(model.forward(x_), -1)
        n_all += len(y_)
        n_corrects += int(jnp.sum(y_.squeeze() == yp.squeeze()))
        progress.set_postfix(acc=n_corrects / max(n_all, 1))
    return n_corrects / n_all if n_all else float('nan')


loss_running = None
for ei in range(n_epochs):
    loss_running = train_one_epoch(mymodel, optim, train_loader, loss_running)
    val_acc = evaluate(mymodel, val_loader)
    print(f'epoch {ei+1} loss {loss_running} val acc {val_acc}')

HBox(children=(FloatProgress(value=0.0, max=703.0), HTML(value='')))



In [None]:
from matplotlib import pyplot as plot
from matplotlib import cm

In [48]:
# mymodel.params is keyed by layer name; layer 0 is the first Conv2d in the model.
first_layer = mymodel.layers[0]
weight = mymodel.params[first_layer.name]['weight']

In [49]:
def make_filter_canvas(weight, pad=1, normalize=True):
    """Tile convolution filters into one RGB image for display.

    Args:
        weight: array of shape (n_filters, n_channels, kh, kw). That layout is
            implied by the (1, 2, 0) transpose below. n_channels must be 1 or 3
            so that each tile is a valid image; 1 broadcasts to gray RGB.
        pad: width in pixels of the zero-valued gap after each tile.
        normalize: if True, min-max scale all weights jointly into [0, 1]. That
            is the range imshow expects for float RGB data; it clips other values.

    Returns:
        float array of shape (n_rows * (kh + pad), n_cols * (kw + pad), 3).
        Tiles are laid out row-major on an n_rows x n_cols grid, and the last
        row may be partial.
    """
    weight = numpy.asarray(weight, dtype=float)
    n_filters, _, kh, kw = weight.shape
    n_cols = max(1, int(numpy.ceil(numpy.sqrt(n_filters))))
    # ceil, not floor: a partial last row still needs space
    n_rows = int(numpy.ceil(n_filters / n_cols))
    if normalize:
        lo, hi = weight.min(initial=0.), weight.max(initial=0.)
        weight = (weight - lo) / (hi - lo) if hi > lo else numpy.zeros_like(weight)
    canvas = numpy.zeros((n_rows * (kh + pad), n_cols * (kw + pad), 3))
    for fid in range(n_filters):
        r0 = (fid // n_cols) * (kh + pad)
        c0 = (fid % n_cols) * (kw + pad)
        # (C, kh, kw) -> (kh, kw, C)
        canvas[r0:r0 + kh, c0:c0 + kw, :] = numpy.transpose(weight[fid], (1, 2, 0))
    return canvas


filter_canvas = make_filter_canvas(weight)

In [50]:
# filter_canvas is an (H, W, 3) RGB image. imshow ignores cmap for RGB input, so
# none is passed, and it expects float RGB values in [0, 1].
fig, ax = plot.subplots()
ax.imshow(filter_canvas)
ax.set_title('First-layer conv filters')
ax.set_axis_off()
plot.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



In [51]:
# Display the first-layer weight layout.
# NOTE(review): the saved output (64, 3, 1, 1) describes a 1x1 kernel. Layer 0 is
# Conv2d(3,3,3,64), which (if args are kh, kw, c_in, c_out) should give
# (64, 3, 3, 3), so the output likely comes from an earlier run. Re-run to refresh it.
numpy.shape(weight)

(64, 3, 1, 1)