In [1]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

In [2]:
from __future__ import absolute_import

'''Resnet for cifar dataset.
Ported from
https://github.com/facebook/fb.resnet.torch
and
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
(c) YANG, Wei
'''
import torch.nn as nn
import math


# Public names of this cell: only the `preresnet` factory is listed for export.
__all__ = ['preresnet']

def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with a padding of 1.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (defaults to 1).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)


class BasicBlock(nn.Module):
    """Pre-activation residual block made of two 3x3 convolutions.

    Each convolution is preceded by batch norm and ReLU; the block input
    (optionally passed through ``downsample``) is added to the result.

    Args:
        inplanes: channels of the incoming tensor.
        planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to form the shortcut;
            when ``None`` the input itself is used.
    """
    # Output channels are ``planes * expansion``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Registration order is kept as-is so parameter ordering does not change.
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Return ``conv2(relu(bn2(conv1(relu(bn1(x)))))) + shortcut(x)``."""
        out = self.conv1(self.relu(self.bn1(x)))
        out = self.conv2(self.relu(self.bn2(out)))

        # Identity shortcut unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return out


class Bottleneck(nn.Module):
    """Pre-activation bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions).

    Every convolution is preceded by batch norm and ReLU. The block input
    (optionally passed through ``downsample``) is added to the output of the
    last convolution.

    Args:
        inplanes: channels of the incoming tensor.
        planes: channels of the two inner convolutions; the block outputs
            ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input to form the shortcut;
            when ``None`` the input itself is used, so it must already have
            the output shape.
    """
    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(inplanes)
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        # Use the class attribute rather than a literal 4 so the output width
        # always agrees with what callers compute from ``block.expansion``.
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
                               bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three BN-ReLU-conv stages and add the shortcut."""
        residual = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv2(out)

        out = self.bn3(out)
        out = self.relu(out)
        out = self.conv3(out)

        # Replace the identity shortcut with the projection when one was given.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual

        return out


class PreResNet(nn.Module):
    """Pre-activation ResNet: a 3x3 stem, three residual stages, BN-ReLU,
    average pooling and a linear classifier.

    Args:
        depth: network depth; must be of the form ``6n + 2``. Each stage
            stacks ``n`` blocks. ``Bottleneck`` blocks are used when
            ``depth >= 44``, ``BasicBlock`` otherwise.
        num_classes: number of outputs of the final linear layer.

    Raises:
        AssertionError: if ``depth`` is not of the form ``6n + 2``.
    """

    def __init__(self, depth, num_classes=1000):
        super().__init__()
        # Model type specifies number of layers for CIFAR-10 model
        assert (depth - 2) % 6 == 0, 'depth should be 6n+2'
        blocks_per_stage = (depth - 2) // 6

        block = BasicBlock if depth < 44 else Bottleneck
        final_width = 64 * block.expansion

        # ``inplanes`` tracks the running channel count; _make_layer updates it.
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False)
        self.layer1 = self._make_layer(block, 16, blocks_per_stage)
        self.layer2 = self._make_layer(block, 32, blocks_per_stage, stride=2)
        self.layer3 = self._make_layer(block, 64, blocks_per_stage, stride=2)
        self.bn = nn.BatchNorm2d(final_width)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(final_width, num_classes)

        # Weight init: convolutions ~ N(0, sqrt(2 / (kh * kw * out_channels))),
        # batch norms start with weight 1 and bias 0.
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan_out = kernel_h * kernel_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` instances of ``block`` into one ``nn.Sequential`` stage.

        Only the first block receives ``stride`` and the optional 1x1
        projection shortcut. Side effect: ``self.inplanes`` is set to
        ``planes * block.expansion``.
        """
        out_planes = planes * block.expansion

        # A projection is required whenever the shortcut would change shape.
        shortcut = None
        if not (stride == 1 and self.inplanes == out_planes):
            shortcut = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
            )

        stage = [block(self.inplanes, planes, stride, shortcut)]
        self.inplanes = out_planes
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Map an image batch to ``num_classes`` scores per sample."""
        x = self.conv1(x)

        # Spatial size after each stage for a 32x32 input: 32x32, 16x16, 8x8.
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)

        x = self.relu(self.bn(x))

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        return self.fc(x)


def preresnet(**kwargs):
    """Create a ``PreResNet``.

    All keyword arguments are passed straight through to the ``PreResNet``
    constructor.
    """
    model = PreResNet(**kwargs)
    return model

In [3]:
from fastai.conv_learner import *
from fastai.model import fit
from fastai.core import SGD_Momentum

# Root folder of the dataset; also the prefix used for log file paths below.
PATH = "data/cifar10/"
os.makedirs(PATH,exist_ok=True)
# The training cells pass paths under f'{PATH}logs/' to LoggingCallback.
# Create that folder up front so opening a log file cannot fail on a missing
# directory (presumably the callback does not create it itself - TODO confirm).
os.makedirs(f'{PATH}logs',exist_ok=True)

In [4]:
# Class names, one per label (presumably in dataset label order - TODO confirm).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
# Two 3-element arrays handed to tfms_from_stats in get_data
# (presumably per-channel mean and standard deviation - TODO confirm).
stats = (
    np.array([ 0.4914 ,  0.48216,  0.44653]),
    np.array([ 0.24703,  0.24349,  0.26159]),
)

In [5]:
def get_data(sz,bs):
    """Build the ``ImageClassifierData`` object used for training.

    Args:
        sz: image size given to ``tfms_from_stats``; the padding is ``sz//8``.
        bs: batch size.

    Returns:
        The result of ``ImageClassifierData.from_paths`` for ``PATH``, using
        the ``train_`` and ``test_`` names for the training and validation sets.
    """
    # Transforms are built from `stats`, with RandomFlip as the only augmentation.
    transforms = tfms_from_stats(stats, sz, aug_tfms=[RandomFlip()], pad=sz//8)
    return ImageClassifierData.from_paths(
        PATH,
        trn_name='train_',
        val_name='test_',
        tfms=transforms,
        bs=bs,
    )

In [6]:
# Batch size passed to get_data.
bs=128

In [7]:
# Data object shared by every training run below (image size 32, batch size `bs`).
data = get_data(32,bs)

In [8]:
# Initial learning rate handed to get_layer_opt and to the LR-schedule callbacks.
lr=0.1

In [11]:
from fastai.sgdr import Callback

class SgdLrUpdater(Callback):
    """Per-epoch learning-rate schedule for the plain SGD runs.

    The rate stays at ``init_lr`` while the epoch counter is below
    ``budget // 2``, is lowered linearly afterwards, and is held at
    ``0.01 * init_lr`` once the counter exceeds ``0.9 * budget``.

    Args:
        layer_opt: object exposing ``set_lrs(list)``; it receives the new
            rate at the end of every epoch.
        init_lr: starting learning rate.
        budget: number of epochs the schedule is spread over.
    """
    def __init__(self, layer_opt, init_lr, budget):
        self.layer_opt=layer_opt
        self.init_lr=init_lr
        self.budget=budget

    def on_train_begin(self):
        """Reset the epoch counter at the start of training."""
        self.epoch = 0

    def on_epoch_end(self, metrics):
        """Advance the epoch counter and apply the rate for the next epoch."""
        self.epoch += 1
        self.update_lr()

    def update_lr(self):
        """Push the rate for the current epoch counter into ``layer_opt``."""
        new_lr = self.calc_lr()
        self.layer_opt.set_lrs([new_lr])

    def calc_lr(self):
        """Return the learning rate for ``self.epoch``."""
        floor_lr = 0.01 * self.init_lr
        ramp_start = self.budget // 2
        if self.epoch < ramp_start:
            return self.init_lr
        if self.epoch > 0.9 * self.budget:
            return floor_lr
        # int() truncates, so the ramp can be shorter than the epoch window it
        # covers (e.g. budget=187: 74 steps for 75 epochs). Guard against a
        # zero-length ramp and never let the rate drop below the floor,
        # which could otherwise even become negative.
        ramp_len = max(int(0.4 * self.budget), 1)
        ramp_lr = self.init_lr - (self.init_lr * 0.99 / ramp_len * (self.epoch - ramp_start))
        return max(ramp_lr, floor_lr)

In [35]:
from fastai.sgdr import LoggingCallback

In [None]:
# Second value handed to get_layer_opt (presumably the weight decay - TODO confirm).
wd = 3e-4
# Number of epochs per run; also drives the SgdLrUpdater schedule.
budget = 150

# Baseline: train 3 preresnet110 models with normal SGD with momentum
# (no SWA arguments are passed to fit_gen). Relies on `lr`, `data` and `PATH`
# defined in earlier cells.
for i in range(3):
    # Fresh model and learner for every run.
    preresnet110 = preresnet(depth=110, num_classes=10)
    learn = ConvLearner.from_model_data(preresnet110, data)
    layer_opt = learn.get_layer_opt([lr], [wd])
    learn.crit = F.cross_entropy
    learn.fit_gen(
        learn.model, 
        learn.data, 
        layer_opt, 
        budget, 
        # The same layer_opt is shared with the callback so it can change the rate.
        callbacks=[SgdLrUpdater(layer_opt, lr, budget), LoggingCallback(f'{PATH}logs/sgd_{i}.txt')]
    )
    # Save each run under its own name: sgd_0, sgd_1, sgd_2.
    learn.save(f'sgd_{i}')

epoch      trn_loss   val_loss   accuracy                   
    0      1.321136   1.399471   0.486847  
    1      1.042043   1.044005   0.632516                   
    2      0.871521   0.895404   0.688588                    
    3      0.743716   0.824015   0.723398                    
    4      0.679371   0.934189   0.70085                     
    5      0.606691   0.70843    0.767405                    
    6      0.574801   0.637944   0.782437                    
    7      0.550256   0.573916   0.802116                    
    8      0.525207   0.655205   0.77769                     
    9      0.518154   0.669004   0.772053                    
    10     0.512644   0.684118   0.777393                    
    11     0.488844   0.550262   0.810324                    
  2%|▏         | 7/391 [00:08<08:13,  1.28s/it, loss=0.483]

In [10]:
from fastai.sgdr import Callback

class SwaLrUpdater(Callback):
    """Per-epoch learning-rate schedule for the SWA runs.

    The rate stays at ``init_lr`` while the epoch counter is below
    ``swa_start // 2``, moves linearly towards ``swa_lr`` afterwards, and is
    held at ``swa_lr`` once the counter exceeds ``0.9 * swa_start``.

    Args:
        layer_opt: object exposing ``set_lrs(list)``; it receives the new
            rate at the end of every epoch.
        init_lr: starting learning rate.
        budget: total number of epochs; stored but not read by the schedule.
        swa_start: epoch the schedule is scaled to.
        swa_lr: constant rate used at the end of the schedule.
    """
    def __init__(self, layer_opt, init_lr, budget, swa_start, swa_lr):
        self.layer_opt=layer_opt
        self.init_lr=init_lr
        self.budget=budget
        self.swa_start=swa_start
        self.swa_lr=swa_lr

    def on_train_begin(self):
        """Reset the epoch counter at the start of training."""
        self.epoch = 0

    def on_epoch_end(self, metrics):
        """Advance the epoch counter and apply the rate for the next epoch."""
        self.epoch += 1
        self.update_lr()

    def update_lr(self):
        """Push the rate for the current epoch counter into ``layer_opt``."""
        new_lr = self.calc_lr()
        self.layer_opt.set_lrs([new_lr])

    def calc_lr(self):
        """Return the learning rate for ``self.epoch``."""
        ramp_start = self.swa_start // 2
        if self.epoch < ramp_start:
            return self.init_lr
        if self.epoch > 0.9 * self.swa_start:
            return self.swa_lr
        # int() truncates, so the ramp can be shorter than the epoch window it
        # covers (e.g. swa_start=129: 51 steps for 52 epochs). Guard against a
        # zero-length ramp and keep the result between the two end points so
        # the last ramp epochs cannot overshoot swa_lr.
        ramp_len = max(int(0.4 * self.swa_start), 1)
        ramp_lr = self.init_lr - ((self.init_lr - self.swa_lr) / ramp_len * (self.epoch - ramp_start))
        lowest = min(self.init_lr, self.swa_lr)
        highest = max(self.init_lr, self.swa_lr)
        return min(max(ramp_lr, lowest), highest)

In [14]:
# Hyper-parameters shared by the SWA training cells below.
lr=0.1            # initial learning rate
swa_lr = 0.01     # rate SwaLrUpdater settles on at the end of its schedule
wd = 3e-4         # second value handed to get_layer_opt (presumably weight decay - TODO confirm)
swa_start = 126   # passed to fit_gen as swa_start and used to scale the SwaLrUpdater schedule

In [None]:
# Number of epochs per run.
budget = 150

# train 3 preresnet110 models with SWA training schedule.
# Relies on `lr`, `wd`, `swa_lr`, `swa_start`, `data` and `PATH` from earlier cells.
for i in range(3):
    # Fresh model and learner for every run.
    preresnet110 = preresnet(depth=110, num_classes=10)
    learn = ConvLearner.from_model_data(preresnet110, data)
    layer_opt = learn.get_layer_opt([lr], [wd])
    learn.crit = F.cross_entropy
    learn.fit_gen(
        learn.model, 
        learn.data, 
        layer_opt, 
        budget,
        use_swa=True,
        swa_start=swa_start,
        swa_eval_freq=1,
        # The same layer_opt is shared with the callback so it can change the rate.
        callbacks=[SwaLrUpdater(layer_opt, lr, budget, swa_start, swa_lr), LoggingCallback(f'{PATH}logs/swa_{i}.txt')]
    )
    # Save each run under its own name: swa_0, swa_1, swa_2.
    learn.save(f'swa_{i}')

In [None]:
# 1.25 budgets (187 epochs versus the 150 used in the earlier cells)
budget = 187

# train 3 preresnet110 models with SWA training schedule and 1.25 budgets.
# `swa_start` is unchanged; relies on `lr`, `wd`, `swa_lr`, `swa_start`, `data`
# and `PATH` from earlier cells.
for i in range(3):
    # Fresh model and learner for every run.
    preresnet110 = preresnet(depth=110, num_classes=10)
    learn = ConvLearner.from_model_data(preresnet110, data)
    layer_opt = learn.get_layer_opt([lr], [wd])
    learn.crit = F.cross_entropy
    learn.fit_gen(
        learn.model, 
        learn.data, 
        layer_opt, 
        budget,
        use_swa=True,
        swa_start=swa_start,
        swa_eval_freq=1,
        # The same layer_opt is shared with the callback so it can change the rate.
        callbacks=[SwaLrUpdater(layer_opt, lr, budget, swa_start, swa_lr), LoggingCallback(f'{PATH}logs/swa_187_{i}.txt')]
    )
    # Save each run under its own name: swa_187_0, swa_187_1, swa_187_2.
    learn.save(f'swa_187_{i}')

In [None]:
# 1.5 budgets (225 epochs versus the 150 used in the earlier cells)
budget = 225

# train 3 preresnet110 models with SWA training schedule and 1.5 budgets.
# `swa_start` is unchanged; relies on `lr`, `wd`, `swa_lr`, `swa_start`, `data`
# and `PATH` from earlier cells.
for i in range(3):
    # Fresh model and learner for every run.
    preresnet110 = preresnet(depth=110, num_classes=10)
    learn = ConvLearner.from_model_data(preresnet110, data)
    layer_opt = learn.get_layer_opt([lr], [wd])
    learn.crit = F.cross_entropy
    learn.fit_gen(
        learn.model, 
        learn.data, 
        layer_opt, 
        budget,
        use_swa=True,
        swa_start=swa_start,
        swa_eval_freq=1,
        # NOTE(review): unlike the 'swa_187_{i}' names, there is no underscore
        # before {i} here, so the runs are named swa_2250, swa_2251, swa_2252.
        # Left unchanged in case other cells load these names - confirm intent.
        callbacks=[SwaLrUpdater(layer_opt, lr, budget, swa_start, swa_lr), LoggingCallback(f'{PATH}logs/swa_225{i}.txt')]
    )
    learn.save(f'swa_225{i}')