In [None]:
# GPU setup via shell escapes (both commands need sudo).
# Start the NVIDIA persistence daemon — presumably so the driver stays initialised
# between runs and benchmark timings are not skewed by re-initialisation.
!sudo nvidia-persistenced
# Pin GPU application clocks to 877,1530 (presumably <memory MHz>,<graphics MHz>).
# NOTE(review): these values look specific to one GPU model — confirm they are
# supported on the target card before running.
!sudo nvidia-smi -ac 877,1530

In [1]:
from utils import *

### Training loop

In [2]:
def avg(xs, N):
    """Sum all per-batch values in ``xs`` and divide by ``N``.

    Args:
        xs: list of per-batch arrays, joined with ``concat`` before summing.
        N: number of examples to normalise by (e.g. ``len(dataset)``).

    Returns:
        ``float`` mean over ``N`` examples, or ``0.0`` when ``N`` is zero.
    """
    # Use ``==`` rather than ``is``: identity comparison with an int literal is a
    # SyntaxWarning and fails for numpy integer zeros.
    if N == 0:
        return 0.0
    # Accumulate in float64 (``np.float`` was removed in NumPy 1.24) so that
    # half-precision per-batch values do not overflow or lose precision.
    return np.sum(concat(xs), dtype=np.float64) / N
        
def collect(stats, output):
    """Append ``to_numpy(output[key])`` to ``stats[key]`` for every key in ``stats``.

    Keys of ``output`` that are not in ``stats`` are ignored; a key of ``stats``
    missing from ``output`` raises ``KeyError``. ``stats`` is mutated in place.
    """
    for key in stats:
        stats[key].append(to_numpy(output[key]))

def train_epoch(model, batches, optimizer, lrs, monitors=('loss', 'correct')):
    """Run one optimisation pass over ``batches``.

    Each step:
    1. runs the model forward;
    2. records the ``monitors`` outputs via ``collect``;
    3. back-propagates ``output['loss']``;
    4. writes this step's learning rate (the next value of ``lrs``) into the
       optimizer's first param group;
    5. steps the optimizer;
    6. clears gradients.

    Iteration stops as soon as either ``lrs`` or ``batches`` runs out (``zip``
    semantics, ``lrs`` is advanced first).

    Returns:
        dict mapping each monitor name to the list of per-batch values that
        ``collect`` appended.
    """
    stats = dict((name, []) for name in monitors)
    model.train(True)
    for step_lr, batch in zip(lrs, batches):
        out = model(batch)
        collect(stats, out)
        out['loss'].backward()
        # The schedule is applied per batch: overwrite the LR right before stepping.
        optimizer.param_groups[0]['lr'] = step_lr
        optimizer.step()
        model.zero_grad()
    return stats

def test_epoch(model, batches, monitors=('loss', 'correct')):
    stats = {k:[] for k in monitors}
    model.train(False)
    for batch in batches:
        output = model(batch)
        collect(stats, output)
    return stats

def train(model, lr_schedule, optimizer, train_set, test_set, batch_size=512, loggers=(), num_workers=0):
    """Train ``model`` for ``lr_schedule.knots[-1]`` epochs, evaluating after each epoch.

    Args:
        model: graph model whose output dict contains 'loss' and 'correct'.
        lr_schedule: callable mapping a (fractional) epoch to a learning rate and
            exposing ``knots``; its last knot is the number of epochs.
        optimizer: optimizer whose first param group's 'lr' is overwritten each step.
        train_set: training dataset; ``set_random_choices()`` is called on
            ``train_batches.dataset`` at the start of every epoch.
        test_set: evaluation dataset.
        batch_size: examples per batch. Per-step LRs are divided by it, presumably
            because the loss is summed rather than averaged over the batch.
        loggers: objects with an ``append(summary)`` method, called once per epoch.
        num_workers: forwarded to ``Batches`` for both train and test loaders.

    Returns:
        The summary dict of the last epoch, or ``None`` if the schedule spans zero
        epochs (previously this case raised ``UnboundLocalError``).
    """
    monitors = ('loss', 'correct')
    t = Timer()
    train_batches = Batches(train_set, batch_size, shuffle=True, num_workers=num_workers)
    test_batches = Batches(test_set, batch_size, shuffle=False, num_workers=num_workers)
    summary = None

    for epoch in range(lr_schedule.knots[-1]):
        train_batches.dataset.set_random_choices()
        steps_per_epoch = len(train_batches)
        # One LR per step, sampled at epoch, epoch + 1/steps, ...; if floating-point
        # arange yields an extra value, zip() in train_epoch drops it.
        lrs = (lr_schedule(x) / batch_size for x in np.arange(epoch, epoch + 1, 1 / steps_per_epoch))
        train_stats, train_time = train_epoch(model, train_batches, optimizer, lrs, monitors), t()
        test_stats, test_time = test_epoch(model, test_batches, monitors), t()
        summary = {
            'epoch': epoch + 1,
            # Unscaled schedule value at the end of this epoch (not divided by batch_size).
            'lr': lr_schedule(epoch + 1),
            'train time': train_time,
            'train loss': avg(train_stats['loss'], len(train_set)),
            'train acc': avg(train_stats['correct'], len(train_set)),
            'test time': test_time,
            'test loss': avg(test_stats['loss'], len(test_set)),
            'test acc': avg(test_stats['correct'], len(test_set)),
            'total time': t.total_time(),
        }
        for logger in loggers:
            logger.append(summary)
    return summary

### Network definitions

In [3]:
def res_block(c_in, c_out, stride, **kw):
    """Build a pre-activation residual block as a nested graph-spec dict.

    Node order: bn1 -> relu1 -> branch{conv1 -> bn2 -> relu2 -> conv2}
    -> [conv3] -> add. Presumably TorchGraph feeds each node the previous
    node's output unless explicit inputs are given, so insertion order matters.

    The shortcut into 'add' is 'relu1' itself, or a 1x1 strided 'conv3' of
    'relu1' when the stride or channel count changes.
    ``kw`` is forwarded to both ``batch_norm`` calls.
    """
    use_projection = stride != 1 or c_in != c_out
    # Modules are constructed in the same order as before so that parameter
    # initialisation consumes the RNG identically.
    layers = {}
    layers['bn1'] = batch_norm(c_in, **kw)
    layers['relu1'] = nn.ReLU(True)
    layers['branch'] = {
        'conv1': nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1, bias=False),
        'bn2': batch_norm(c_out, **kw),
        'relu2': nn.ReLU(True),
        'conv2': nn.Conv2d(c_out, c_out, kernel_size=3, stride=1, padding=1, bias=False),
    }
    shortcut_src = 'relu1'
    if use_projection:
        layers['conv3'] = (nn.Conv2d(c_in, c_out, kernel_size=1, stride=stride, padding=0, bias=False), [rel_path('relu1')])
        shortcut_src = 'conv3'
    layers['add'] = (Add(), [rel_path(shortcut_src), rel_path('branch', 'conv2')])
    return layers

def DAWN_net(c=64, **kw):
    """Build the graph-spec dict for the DAWNBench residual network.

    Structure:
    - 'prep': a 3x3 conv from 3 input channels to ``c`` channels.
    - 'layer1'..'layer4': two ``res_block``s each, with widths c, 2c, 4c, 4c.
      The first block of layers 2-4 uses stride 2.
    - 'classifier': concatenates a 4x4 max-pool (presumably of the preceding
      node's output) and a 4x4 avg-pool of layer4/block1/add, flattens, and
      maps 8*c features to 10 logits.

    ``kw`` is forwarded to every ``res_block`` call.
    """
    # (input channels, output channels, stride of block0) for layer1..layer4.
    stage_widths = [(c, c, 1), (c, 2*c, 2), (2*c, 4*c, 2), (4*c, 4*c, 2)]
    net = {'prep': {'conv': nn.Conv2d(3, c, kernel_size=3, stride=1, padding=1, bias=False)}}
    for idx, (c_in, c_out, stride) in enumerate(stage_widths, start=1):
        net[f'layer{idx}'] = {
            'block0': res_block(c_in, c_out, stride, **kw),
            'block1': res_block(c_out, c_out, 1, **kw),
        }
    net['classifier'] = {
        'maxpool': nn.MaxPool2d(4),
        'avgpool': (nn.AvgPool2d(4), [('layer4', 'block1', 'add')]),
        'concat': (Concat(), [rel_path('maxpool'), rel_path('avgpool')]),
        'flatten': Flatten(),
        'linear': nn.Linear(8*c, 10, bias=True),
        'logits': Identity(),
    }
    return net

# Loss/metric nodes merged into the network graph via ``union(n, losses)``.
# Both read the classifier logits and the 'target' input.
# 'loss' is summed over the batch (``reduction='sum'``, the non-deprecated
# spelling of ``size_average=False``), which is presumably why ``train``
# divides the per-step LR by the batch size.
losses = {
    'loss': (nn.CrossEntropyLoss(reduction='sum'), [('classifier', 'logits'), ('target',)]),
    'correct': (Correct(), [('classifier', 'logits'), ('target',)]),
}

### Download and preprocess data

In [5]:
DATA_DIR = './data'

train_set_raw = torchvision.datasets.CIFAR10(root=DATA_DIR, train=True, download=True)
test_set_raw = torchvision.datasets.CIFAR10(root=DATA_DIR, train=False, download=True)


def _images_and_labels(dataset, legacy_prefix):
    """Return ``(images, labels)`` from a torchvision CIFAR10 dataset.

    Newer torchvision exposes ``.data`` / ``.targets``. Older releases used
    ``.<prefix>_data`` / ``.<prefix>_labels`` (``train_*`` or ``test_*``), which
    were later removed. Both layouts are supported here.
    """
    if hasattr(dataset, 'data'):
        return dataset.data, dataset.targets
    return getattr(dataset, legacy_prefix + '_data'), getattr(dataset, legacy_prefix + '_labels')


train_images, train_labels = _images_and_labels(train_set_raw, 'train')
test_images, test_labels = _images_and_labels(test_set_raw, 'test')

t = Timer()
print('Preprocessing training data')
# Training images are padded by 4 pixels, presumably so the Crop(32, 32)
# augmentation used below can take random crops.
train_set = list(zip(transpose(normalise(pad(train_images, 4))), train_labels))
# '.2f' (fixed-point): the previous '.2' printed significant digits, e.g. '1.2e+01'.
print(f'Finished in {t():.2f} seconds')
print('Preprocessing test data')
test_set = list(zip(transpose(normalise(test_images)), test_labels))
print(f'Finished in {t():.2f} seconds')

Files already downloaded and verified
Files already downloaded and verified
Preprocessing training data
Finished in 4.1 seconds
Preprocessing test data
Finished in 0.12 seconds


### [Post 1: Baseline](https://www.myrtle.ai/2018/09/24/how_to_train_your_resnet_1/) - bulk random choices (300s)

In [None]:
# Baseline LR: linear ramp 0 -> 0.1 by epoch 15, -> 0.005 by epoch 30, -> 0 at epoch 35.
lr_schedule = PiecewiseLinear([0, 15, 30, 35], [0, 0.1, 0.005, 0])
batch_size = 128

n = DAWN_net()
display(DotGraph(n))
model = TorchGraph(union(n, losses)).to(device)
# Deliberately cast every child module, batch norms included, to half precision.
# This is the slow codepath that the later "speed up batch norms" run avoids.
for v in model.children():
    v.half()
# Weight decay is multiplied by batch_size while train() divides the LR by it,
# so their product (lr * wd) does not depend on the batch size.
opt = nesterov(
    trainable_params(model),
    momentum=0.9,
    weight_decay=5e-4 * batch_size,
)
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR()])
summary = train(
    model, lr_schedule, opt, train_set_x, test_set,
    batch_size=batch_size,
    loggers=(TableLogger(),),
    num_workers=1,  # forwarded to Batches; this run differs from the next only here
)

### [Post 1: Baseline](https://www.myrtle.ai/2018/09/24/how_to_train_your_resnet_1/) - final (297s)

In [None]:
# Baseline LR: linear ramp 0 -> 0.1 by epoch 15, -> 0.005 by epoch 30, -> 0 at epoch 35.
lr_schedule = PiecewiseLinear([0, 15, 30, 35], [0, 0.1, 0.005, 0])
batch_size = 128

n = DAWN_net()
display(DotGraph(n))
model = TorchGraph(union(n, losses)).to(device)
# Deliberately cast every child module, batch norms included, to half precision.
# This is the slow codepath that the later "speed up batch norms" run avoids.
for v in model.children():
    v.half()
# Weight decay is multiplied by batch_size while train() divides the LR by it,
# so their product (lr * wd) does not depend on the batch size.
opt = nesterov(
    trainable_params(model),
    momentum=0.9,
    weight_decay=5e-4 * batch_size,
)
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR()])
summary = train(
    model, lr_schedule, opt, train_set_x, test_set,
    batch_size=batch_size,
    loggers=(TableLogger(),),
    num_workers=0,  # forwarded to Batches
)

### [Post 2: Mini-batches](https://www.myrtle.ai/2018/09/24/how_to_train_your_resnet_2/) - batch size 512 (time: TBD)

In [None]:
# Same schedule shape as the baseline, but the peak LR is raised to 0.44 for batch size 512.
lr_schedule = PiecewiseLinear([0, 15, 30, 35], [0, 0.44, 0.005, 0])
batch_size = 512

n = DAWN_net()
display(DotGraph(n))
model = TorchGraph(union(n, losses)).to(device)
# Still casting every child (batch norms included) to half precision: the slow codepath.
for v in model.children():
    v.half()
# Weight decay scales with batch_size, offsetting the LR division in train().
opt = nesterov(
    trainable_params(model),
    momentum=0.9,
    weight_decay=5e-4 * batch_size,
)
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR()])
summary = train(
    model, lr_schedule, opt, train_set_x, test_set,
    batch_size=batch_size,
    loggers=(TableLogger(),),
    num_workers=0,
)

### [Post 3: Regularisation](https://www.myrtle.ai/2018/09/24/how_to_train_your_resnet_3/) - speed up batch norms (time: TBD)

In [None]:
# Same schedule and batch size as the batch-size-512 run.
lr_schedule = PiecewiseLinear([0, 15, 30, 35], [0, 0.44, 0.005, 0])
batch_size = 512

n = DAWN_net()
display(DotGraph(n))
# .half() is called on the TorchGraph as a whole instead of on each child.
# Per this section's title it is meant to keep batch norms off the slow
# half-precision path. NOTE(review): TorchGraph.half lives in utils; confirm.
model = TorchGraph(union(n, losses)).to(device).half()
opt = nesterov(
    trainable_params(model),
    momentum=0.9,
    weight_decay=5e-4 * batch_size,
)
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR()])
summary = train(
    model, lr_schedule, opt, train_set_x, test_set,
    batch_size=batch_size,
    loggers=(TableLogger(),),
    num_workers=0,
)

### [Post 3: Regularisation](https://www.myrtle.ai/2018/09/24/how_to_train_your_resnet_3/) - final (time: TBD)

In [None]:
# Shorter schedule: ramp 0 -> 0.40 by epoch 8, then linearly down to 0 at epoch 30.
lr_schedule = PiecewiseLinear([0, 8, 30], [0, 0.40, 0])
batch_size = 512

n = DAWN_net()
display(DotGraph(n))
model = TorchGraph(union(n, losses)).to(device).half()
opt = nesterov(
    trainable_params(model),
    momentum=0.9,
    weight_decay=5e-4 * batch_size,
)
# Adds Cutout(8, 8) on top of the crop + flip augmentation used in earlier runs.
train_set_x = Transform(train_set, [Crop(32, 32), FlipLR(), Cutout(8, 8)])
summary = train(
    model, lr_schedule, opt, train_set_x, test_set,
    batch_size=batch_size,
    loggers=(TableLogger(),),
    num_workers=0,
)