### Imports

In [1]:
from inspect import signature
from collections import namedtuple
import time
import torch
from torch import nn

# cuDNN benchmark mode: cuDNN auto-tunes convolution algorithms for the input
# shapes it encounters (warmup_cudnn runs one batch before the timed training)
torch.backends.cudnn.benchmark = True
import numpy as np

import torchvision

import pandas as pd
# display settings for the per-epoch training log table rendered by show()
pd.options.display.precision = 4
pd.options.display.width = 180

from IPython.display import display, clear_output, HTML
import ipywidgets as widgets

# global device used by Batches (data transfer) and network() (model placement);
# falls back to CPU when no GPU is visible
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Function defs

In [2]:
#####################
# utils
#####################

class Timer:
    """Stopwatch that records a timestamp on creation and on every call."""

    def __init__(self):
        # wall-clock timestamps (seconds, from time.time()), oldest first
        self.times = [time.time()]

    def __call__(self):
        """Record a new timestamp and return the seconds elapsed since the previous one."""
        now = time.time()
        previous = self.times[-1]
        self.times.append(now)
        return now - previous

    def total_time(self):
        """Seconds between the first and the most recent recorded timestamp."""
        first, last = self.times[0], self.times[-1]
        return last - first
    
def curry(func):
    """Turn ``func(x, **kwargs)`` into a namedtuple type that stores the keyword arguments.

    Every parameter of ``func`` that has a default value becomes a field of the
    returned namedtuple type (named after ``func``). An instance of that type is a
    callable: ``obj(x)`` evaluates ``func(x, **fields)``. The defaults of ``func``
    are carried over to the namedtuple, so both ``random_crop()`` and
    ``random_crop(32, 32)`` build a usable transform.

    Args:
        func: function whose first positional argument is the data to transform.

    Returns:
        A namedtuple subclass whose instances call ``func``.
    """
    params = [p for p in signature(func).parameters.values() if p.default is not p.empty]
    fields = [p.name for p in params]
    defaults = tuple(p.default for p in params)
    curried = namedtuple(func.__name__, fields)
    # Carry func's defaults over; setting __new__.__defaults__ also works on Python
    # versions that predate namedtuple(..., defaults=...) (added in 3.7).
    curried.__new__.__defaults__ = defaults
    curried.__call__ = lambda self, x: func(x, **self._asdict())
    return curried

def localtime():
    """Return the current local wall-clock time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    fmt = '%Y-%m-%d %H:%M:%S'
    # time.strftime formats the current time.localtime() when no struct is passed
    return time.strftime(fmt)

def show(logs):
    """Replace the current cell output with ``logs`` rendered as an HTML table.

    ``logs`` is a list of dicts (one row each); column order follows the keys of the
    first dict and the DataFrame index is not shown.
    """
    clear_output(wait=True)
    column_order = list(logs[0].keys())
    html = pd.DataFrame(logs, columns=column_order).to_html(index=False)
    display(HTML(html))

def warmup_cudnn(model, batch_size, input_shape=(3, 32, 32), num_classes=10):
    """Run one forward/backward pass on random data before the timed training run.

    With ``torch.backends.cudnn.benchmark`` enabled, the first pass at a given input
    shape is where cuDNN selects its convolution algorithms; doing it here keeps that
    one-off cost out of the training timings.

    Args:
        model: module called with a ``{'input', 'target'}`` dict and returning a dict
            that contains a scalar ``'loss'`` tensor.
        batch_size: number of random samples in the warm-up batch.
        input_shape: per-sample input shape (default: CIFAR-10 ``(3, 32, 32)``).
        num_classes: targets are drawn uniformly from ``range(num_classes)``.

    Side effects: puts ``model`` in training mode and zeroes its gradients.
    """
    # Use the device the model actually lives on instead of hard-coding .cuda(),
    # so this also works on the CPU fallback device.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        dev = first_param.device
    else:
        dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    batch = {
        'input': torch.rand((batch_size,) + tuple(input_shape), device=dev).half(),
        'target': torch.randint(0, num_classes, (batch_size,), device=dev, dtype=torch.long),
    }
    model.train(True)
    output = model(batch)
    output['loss'].backward()
    model.zero_grad()
    if dev.type == 'cuda':
        # CUDA kernels run asynchronously; wait so callers timing this call see real work
        torch.cuda.synchronize()
    
    
#####################
## data preprocessing
#####################

def normalise(x, mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)):
    """Standardise pixel data channel-wise and return a new float32 array.

    ``mean`` and ``std`` are on the 0-1 scale and are multiplied by 255 here, so ``x``
    is expected to hold 0-255 pixel values. Both broadcast against the last axis of
    ``x`` (channels-last, e.g. NHWC). The input array is never modified.
    """
    data = np.array(x, np.float32)
    offset = np.array(mean, np.float32) * 255
    inv_scale = 1.0 / (255 * np.array(std, np.float32))
    return (data - offset) * inv_scale

def pad(x, border=4):
    """Reflect-pad axes 1 and 2 (H and W of an NHWC batch) by ``border`` on each side."""
    untouched = (0, 0)
    spatial = (border, border)
    return np.pad(x, (untouched, spatial, spatial, untouched), mode='reflect')

def transpose(x, source='NHWC', target='NCHW'):
    """Permute the axes of ``x`` from the ``source`` layout string to the ``target`` one."""
    axes = tuple(map(source.index, target))
    return x.transpose(axes)


#####################
## data augmentation
#####################

def _random_window(x, h, w):
    C,H,W = x.shape
    y0, x0 = np.random.randint(H+1-h), np.random.randint(W+1-w)
    return x[:,y0:y0+h,x0:x0+w]

@curry
def random_crop(x, h=32, w=32):
    """Crop a random ``h`` x ``w`` window from a CHW array (the result is a view of ``x``).

    Because of ``@curry``, ``random_crop(32, 32)`` builds the transform and the
    transform is then applied as ``crop(x)``.
    """
    window = _random_window(x, h, w)
    return window

def flip_lr(x):
    """Mirror a CHW array along its width axis with probability 1/2.

    Always returns a fresh C-ordered copy, so callers may modify the result freely.
    """
    coin = np.random.choice((True, False))
    if coin:
        x = x[:, :, ::-1]
    return x.copy()

@curry
def cutout(x, h=8, w=8, prob=1.0):
    """With probability ``prob``, zero a random ``h`` x ``w`` patch of a CHW array.

    Returns ``x`` unchanged when no patch is applied, otherwise a modified copy.
    Because of ``@curry``, ``cutout(8, 8, 1.0)`` builds the transform, which is
    then applied as ``transform(x)``.
    """
    if np.random.uniform() < prob:
        # Copy before zeroing: x may be a view into the stored dataset (random_crop
        # returns a view), and zeroing it in place would permanently corrupt the
        # training data unless an earlier transform happened to copy it.
        x = x.copy()
        _random_window(x, h, w).fill(0.0)
    return x

class Transform:
    """Dataset wrapper that pipes each sample's data through a list of transforms.

    ``dataset[index]`` must yield a ``(data, label)`` pair; labels pass through untouched.
    """

    def __init__(self, dataset, transforms):
        self.dataset = dataset
        self.transforms = transforms

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Fetch ``dataset[index]`` and apply every transform to its data, in list order."""
        sample, target = self.dataset[index]
        for transform in self.transforms:
            sample = transform(sample)
        return sample, target

#####################
## data loading
#####################

def _seed_numpy_worker(worker_id):
    """DataLoader ``worker_init_fn``: reseed NumPy's global RNG from the worker's torch seed.

    The augmentations (random_crop, flip_lr, cutout) draw from ``np.random``. Worker
    processes forked from the notebook can otherwise start from identical NumPy RNG
    state; ``torch.initial_seed()`` differs per worker and per epoch.
    """
    np.random.seed(torch.initial_seed() % 2**32)


class Batches():
    """Iterate a dataset as ``{'input', 'target'}`` dicts already moved to ``device``.

    Inputs are converted to float16 and targets to int64 after the transfer.

    Args:
        dataset: indexable dataset of ``(data, label)`` pairs.
        batch_size: samples per batch (the final batch may be smaller).
        shuffle: whether the DataLoader shuffles the sample order.
        num_workers: number of DataLoader worker processes (default 1).
    """

    def __init__(self, dataset, batch_size, shuffle, num_workers=1):
        self.dataloader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, num_workers=num_workers,
            # pinned host memory only helps host-to-GPU copies
            pin_memory=device.type == 'cuda', shuffle=shuffle,
            worker_init_fn=_seed_numpy_worker,
        )

    def __iter__(self):
        for x, y in self.dataloader:
            # non_blocking lets the copy from pinned memory overlap with other work
            yield {
                'input': x.to(device, non_blocking=True).half(),
                'target': y.to(device, non_blocking=True).long(),
            }

    def __len__(self):
        return len(self.dataloader)


#####################
## torch stuff
#####################

class Mul:
    """Callable that scales its input by a fixed ``weight`` (a plain object, not an ``nn.Module``)."""

    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x):
        scaled = x * self.weight
        return scaled
    
class Flatten(nn.Module):
    """Collapse every non-batch dimension: ``(N, d1, d2, ...) -> (N, d1 * d2 * ...)``.

    After ``AdaptiveMaxPool2d(1)`` the input is ``(N, C, 1, 1)`` and the output is
    ``(N, C)``; other trailing shapes are flattened as well.
    """

    def forward(self, x):
        features = 1
        for size in x.shape[1:]:
            features *= size
        # explicit feature count (not -1) so an empty batch still reshapes
        return x.view(x.size(0), features)

class Add(nn.Module):
    """Element-wise sum of two inputs (used as the residual join in ``network()``)."""

    def forward(self, x, y):
        total = x + y
        return total
    
class Correct(nn.Module):
    """Per-sample boolean mask: True where the arg-max of ``classifier`` over dim 1 equals ``target``."""

    def forward(self, classifier, target):
        _, predicted = classifier.max(dim=1)
        return predicted == target

class TorchGraph(nn.Module):
    """Run a flat list of named nodes as a DAG, caching every intermediate result.

    ``net`` is a list of ``(name, fn, input_names)`` triples. A node written as a
    ``(name, fn)`` pair takes its single input from the node listed just before it,
    so the first node must list its inputs explicitly. Each ``fn`` is stored as an
    attribute called ``name`` (``nn.Module`` values therefore become submodules).

    ``forward`` returns a dict holding the original inputs plus the output of every
    node, keyed by node name.
    """

    def __init__(self, net):
        # nn.Module must be initialised before any attribute is set on self.
        super().__init__()
        self.graph = []
        for idx, node in enumerate(net):
            if len(node) == 3:
                self.graph.append(node)
                continue
            if idx == 0:
                # net[idx - 1] would silently wrap around to the LAST node
                raise ValueError(f'first node {node[0]!r} must list its inputs explicitly')
            self.graph.append((node[0], node[1], [net[idx - 1][0]]))
        for name, fn, _ in self.graph:
            setattr(self, name, fn)

    def forward(self, inputs):
        """Evaluate every node in order; return the dict of inputs and all node outputs."""
        # the cache is kept on self, so intermediate results remain inspectable after a call
        self.cache = dict(inputs)
        for name, _, input_names in self.graph:
            self.cache[name] = getattr(self, name)(*[self.cache[key] for key in input_names])
        return self.cache

    def half(self):
        """Convert direct child modules to float16, leaving ``nn.BatchNorm2d`` children as they are."""
        # NOTE(review): BatchNorm presumably stays in float32 for numerical stability
        for module in self.children():
            if type(module) is not nn.BatchNorm2d:
                module.half()
        return self
    
# Shared parameter-free layers; network() reuses these same instances for every
# node that names them.
pool = nn.AvgPool2d(2)
relu = nn.ReLU(inplace=True)

def bn(num_channels, bias_init=-0.25, weight_init=0.5):
    """Build an ``nn.BatchNorm2d`` with a constant, frozen scale and an optional constant bias.

    Args:
        num_channels: number of feature channels.
        bias_init: value every bias entry is set to; ``None`` keeps PyTorch's default.
        weight_init: value every (frozen) weight entry is set to.

    The weight gets ``requires_grad=False``, so only the bias is trained (``nesterov``
    only optimises parameters that require grad).
    """
    # Note: the weight initialisation scale is largely irrelevant to the forward
    # computation but changes the effective scale and thus learning rate of biases.
    layer = nn.BatchNorm2d(num_channels)
    layer.weight.data.fill_(weight_init)
    layer.weight.requires_grad = False
    if bias_init is not None:
        layer.bias.data.fill_(bias_init)
    return layer

def conv(in_channels, out_channels):
    """3x3 convolution with padding 1 (spatial size preserved at stride 1) and no bias."""
    return nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        padding=1,
        bias=False,
    )

def to_numpy(x):
    """Detach a tensor from autograd, move it to host memory and return it as a NumPy array."""
    host = x.detach().cpu()
    return host.numpy()

#####################
## training
#####################
@curry
def piecewise_linear(t, knots=(), vals=()):
    """Linearly interpolate ``vals`` at position ``t`` between ``knots`` (via ``np.interp``).

    Because of ``@curry``, ``piecewise_linear(knots, vals)`` builds the schedule,
    which is then evaluated as ``schedule(t)``.
    """
    (value,) = np.interp([t], knots, vals)
    return value

def nesterov(model, momentum=None, weight_decay=None):
    """Build an SGD optimiser with Nesterov momentum over ``model``'s trainable parameters.

    The learning rate starts at 0.0; the training loop is expected to overwrite
    ``param_groups[0]['lr']`` before each step. Parameters with
    ``requires_grad=False`` (e.g. the frozen BatchNorm weights made by ``bn``) are
    left out.

    Args:
        model: module whose trainable parameters are optimised.
        momentum: momentum factor; required and must be positive, because
            ``torch.optim.SGD`` rejects ``nesterov=True`` without momentum.
        weight_decay: L2 penalty; ``None`` means no weight decay.

    Raises:
        ValueError: if ``momentum`` is ``None`` or not positive.
    """
    if momentum is None or momentum <= 0:
        raise ValueError(f'nesterov() needs a positive momentum, got {momentum!r}')
    trainable = [p for p in model.parameters() if p.requires_grad]
    return torch.optim.SGD(
        trainable,
        lr=0.0,
        momentum=momentum,
        weight_decay=0.0 if weight_decay is None else weight_decay,
        nesterov=True,
    )

def avg(xs):
    """Mean of a non-empty list of NumPy arrays, computed in float64.

    If the first element is 0-d (e.g. per-batch scalar losses) the elements are
    stacked; otherwise they are concatenated along axis 0 (e.g. per-sample masks).

    Raises:
        ValueError: if ``xs`` is empty.
    """
    if len(xs) == 0:
        raise ValueError('avg() needs at least one array')
    # np.ndim instead of `shape is ()` (identity test on a literal); np.float64
    # instead of np.float, which NumPy 1.24 removed.
    values = np.array(xs) if np.ndim(xs[0]) == 0 else np.concatenate(xs)
    return np.mean(values, dtype=np.float64)

def _record(stats, output):
    """Append detached NumPy copies of each tracked model output (``stats`` keys) to ``stats``."""
    for key, values in stats.items():
        values.append(to_numpy(output[key]))


def _mean_loss(stats):
    """Per-example loss: sum of all batch losses divided by the number of examples seen.

    Assumes ``'loss'`` is summed over each batch (as in ``network()``) and that
    ``'correct'`` holds one entry per example. Unlike dividing the mean batch loss by
    the nominal batch size, this stays exact when the final batch is smaller.
    """
    n_examples = sum(len(mask) for mask in stats['correct'])
    return sum(float(loss) for loss in stats['loss']) / n_examples


def train(model, epochs, lr_schedule, batch_size, optimizer, train_batches, test_batches):
    """Train ``model`` for ``epochs`` epochs, evaluating on ``test_batches`` after each one.

    Args:
        model: module called with a batch dict and returning a dict that contains
            ``'loss'`` (a scalar summed over the batch) and ``'correct'`` (per-sample mask).
        epochs: number of passes over ``train_batches``.
        lr_schedule: callable mapping the fractional epoch to the value written into
            ``optimizer.param_groups[0]['lr']`` before every step.
        batch_size: only used to report ``lr * batch_size`` in the log.
        optimizer: optimiser stepping ``model``'s parameters.
        train_batches, test_batches: re-iterable batch sources supporting ``len()``.

    Returns:
        A list with one dict per epoch (times, lr, per-example loss and accuracy for
        train and test). The table is re-rendered in an output widget after each epoch.
    """
    logs = []
    out = widgets.Output()
    print(f'Starting training at {localtime()}')
    display(out)
    t = Timer()
    for epoch in range(epochs):
        model.train(True)
        train_stats = {'loss': [], 'correct': []}
        for i, batch in enumerate(train_batches):
            lr = lr_schedule(epoch + (i + 1) / len(train_batches))
            optimizer.param_groups[0]['lr'] = lr
            model.zero_grad()
            output = model(batch)
            _record(train_stats, output)
            output['loss'].backward()
            optimizer.step()
        train_time = t()

        model.train(False)
        test_stats = {'loss': [], 'correct': []}
        # evaluation needs no autograd graph; skipping it saves memory and time
        with torch.no_grad():
            for batch in test_batches:
                _record(test_stats, model(batch))
        test_time = t()

        logs.append({
            'epoch': epoch + 1, 'lr': lr * batch_size,
            'train time': train_time, 'train loss': _mean_loss(train_stats), 'train acc': avg(train_stats['correct']),
            'test time': test_time, 'test loss': _mean_loss(test_stats), 'test acc': avg(test_stats['correct']),
            'total time': t.total_time(),
        })
        with out:
            show(logs)
    print(f'Finished training at {localtime()}')
    return logs


#####################
## network definition
#####################

def network(c=64, weight=0.25, num_classes=10):
    """Build the residual convnet as a ``TorchGraph``, moved to ``device`` and converted with ``.half()``.

    Args:
        c: base channel width; stage widths are multiples of ``c`` (up to ``8 * c``).
        weight: constant the classifier logits are multiplied by (via ``Mul``).
        num_classes: number of outputs of the final linear layer.

    Returns:
        A ``TorchGraph`` whose forward takes ``{'input', 'target'}`` and whose output
        dict also contains ``'loss'`` (cross-entropy summed over the batch) and
        ``'correct'`` (per-sample boolean mask).
    """
    net = [
        ('prep_conv', conv(3, c), ['input']),
        ('prep_bn', bn(c)),
        ('prep_relu', relu),

        ('layer1_conv', conv(c, c*2)),
        ('layer1_bn', bn(c*2)),
        ('layer1_relu', relu),
        ('layer1_pool', pool),
        ('layer1_res1_conv', conv(c*2, c*2)),
        ('layer1_res1_bn', bn(c*2)),
        ('layer1_res1_relu', relu),
        ('layer1_res2_conv', conv(c*2, c*2)),
        ('layer1_res2_bn', bn(c*2)),
        ('layer1_res2_relu', relu),
        # residual join: pooled features + output of the two-conv branch
        ('layer1_add', Add(), ['layer1_pool', 'layer1_res2_relu']),

        ('layer2_conv', conv(c*2, c*4)),
        ('layer2_bn', bn(c*4)),
        ('layer2_relu', relu),
        ('layer2_pool', pool),

        ('layer3_conv', conv(c*4, c*4)),
        ('layer3_bn', bn(c*4)),
        ('layer3_relu', relu),

        ('layer4_conv', conv(c*4, c*8)),
        ('layer4_bn', bn(c*8)),
        ('layer4_relu', relu),

        ('pool', nn.AdaptiveMaxPool2d(1)),
        ('flatten', Flatten()),
        ('classifier_fc', nn.Linear(c*8, num_classes, bias=True)),
        ('classifier', Mul(weight)),

        # reduction='sum' replaces the deprecated size_average=False (same result)
        ('loss', nn.CrossEntropyLoss(reduction='sum'), ['classifier', 'target']),
        ('correct', Correct(), ['classifier', 'target']),
    ]
    return TorchGraph(net).to(device).half()

### Download and preprocess data

In [3]:
DATA_DIR = './data'

train_raw = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True)
test_raw = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False, download=True)


def _images_and_labels(dataset, legacy_prefix):
    """Return ``(images, labels)`` from a torchvision CIFAR10 dataset.

    Newer torchvision releases expose ``.data`` / ``.targets``; older ones used
    ``<prefix>_data`` / ``<prefix>_labels`` (``train_*`` or ``test_*``), which
    later releases no longer provide.
    """
    if hasattr(dataset, 'data') and hasattr(dataset, 'targets'):
        return dataset.data, dataset.targets
    return getattr(dataset, legacy_prefix + '_data'), getattr(dataset, legacy_prefix + '_labels')


t = Timer()
print('Preprocessing training data')
train_images, train_labels = _images_and_labels(train_raw, 'train')
# training images are reflect-padded by 4 so random 32x32 crops can be taken later
train_set = list(zip(transpose(normalise(pad(train_images, 4))), train_labels))
print(f'Finished in {t():.2} seconds')
print('Preprocessing test data')
test_images, test_labels = _images_and_labels(test_raw, 'test')
test_set = list(zip(transpose(normalise(test_images)), test_labels))
print(f'Finished in {t():.2} seconds')

Files already downloaded and verified
Files already downloaded and verified
Preprocessing training data
Finished in 2.4 seconds
Preprocessing test data
Finished in 0.093 seconds


### Training

In [4]:
epochs = 25
batch_size = 512
# Linear ramp from 0 at epoch 0 to 0.1/128 at epoch 6, then linearly back to 0 at the
# final epoch. The final knot is tied to `epochs` so the two cannot drift apart.
# train() logs lr * batch_size, i.e. this schedule is on a per-example scale.
lr_schedule = piecewise_linear([0, 6, epochs], [0, 0.1/128, 0])

train_batches = Batches(Transform(train_set, [random_crop(32, 32), flip_lr, cutout(8, 8, 1.0)]), batch_size, shuffle=True)
test_batches = Batches(test_set, batch_size, shuffle=False)

model = network()

warmup_timer = Timer()
print('Warming up cudnn on a batch of random inputs')
# warm up on a separate, throwaway network() instance rather than on `model`
warmup_cudnn(network(), batch_size)
print(f'Finished in {warmup_timer():.3} seconds')

# NOTE(review): weight decay is presumably scaled by batch_size to match the
# per-example learning-rate scale above — confirm.
optimizer = nesterov(model, momentum=0.9, weight_decay=batch_size*5e-4)
logs = train(model, epochs, lr_schedule, batch_size, optimizer, train_batches, test_batches)

Warming up cudnn on a batch of random inputs
Finished in 1.17 seconds
Starting training at 2018-10-01 10:52:58


Output()

Finished training at 2018-10-01 10:54:26
