In [1]:
%reload_ext autoreload
%autoreload 2

from fastai import *
from fastai.vision import *
from fastai.vision.models.wrn import wrn_22
from fastai.docs import *
from fastai.docs import CIFAR_PATH

# Enable cuDNN's benchmark mode: PyTorch documents this flag as letting cuDNN
# time several convolution algorithms and reuse the fastest one.
# NOTE(review): this pays off when input shapes stay constant between batches;
# the data cell below crops to a fixed size, but confirm for the last batch.
torch.backends.cudnn.benchmark = True

### Model Definition

In [2]:
# --
# Model definition
# Derived from models in `https://github.com/kuangliu/pytorch-cifar`

class PreActBlock(nn.Module):
    """Pre-activation residual block (BN -> ReLU -> conv, twice) with a skip path.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both 3x3 convolutions.
        stride: stride of the first 3x3 convolution (and of the projection,
            when one is created).

    A 1x1 projection (``self.shortcut``) is only registered when the block
    changes the spatial resolution or the channel count; otherwise the raw
    input is added back unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # First pre-activation stage: may down-sample and/or widen.
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )

        # Second pre-activation stage: keeps shape.
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )

        shape_changes = (stride != 1) or (in_channels != out_channels)
        if shape_changes:
            # The attribute is deliberately left undefined for identity blocks;
            # forward() detects its absence.
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            self.shortcut = nn.Sequential(projection)

    def forward(self, x):
        """Return ``conv2(relu(bn2(conv1(relu(bn1(x)))))) + skip``."""
        activated = F.relu(self.bn1(x))

        # The projection consumes the pre-activated tensor, whereas the
        # identity skip uses the untouched input.
        projection = getattr(self, 'shortcut', None)
        if projection is None:
            skip = x
        else:
            skip = projection(activated)

        residual = self.conv1(activated)
        residual = self.conv2(F.relu(self.bn2(residual)))
        return residual + skip
    
class ResNet18(nn.Module):
    """Pre-activation ResNet-18 style classifier for 3-channel images.

    Architecture:
        * ``prep``: 3x3 conv (3 -> 64) + BatchNorm + ReLU stem.
        * ``layers``: four stages of ``PreActBlock`` with widths
          64, 128, 256, 256 and strides 1, 2, 2, 2.
        * Global average pooling and global max pooling are concatenated
          (256 + 256 = 512 features) and fed to a linear ``classifier``.

    Args:
        num_blocks: sequence of four ints, the number of blocks per stage.
        num_classes: size of the output logits vector.
    """

    def __init__(self, num_blocks=(2, 2, 2, 2), num_classes=10):
        # The default is a tuple rather than a list so it cannot be mutated
        # and shared between instances; any indexable sequence still works.
        super().__init__()
        # Not read anywhere in this class; kept so external code that
        # inspects the attribute keeps working.
        self.in_channels = 64
        self.prep = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.layers = nn.Sequential(
            self._make_layer(64, 64, num_blocks[0], stride=1),
            self._make_layer(64, 128, num_blocks[1], stride=2),
            self._make_layer(128, 256, num_blocks[2], stride=2),
            self._make_layer(256, 256, num_blocks[3], stride=2),
        )
        # 512 = 256 avg-pooled + 256 max-pooled channels (see forward()).
        self.classifier = nn.Linear(512, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest use stride 1."""
        blocks = []
        # A distinct loop variable avoids shadowing the ``stride`` argument.
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(PreActBlock(in_channels=in_channels, out_channels=out_channels, stride=block_stride))
            # Every block after the first maps out_channels -> out_channels.
            in_channels = out_channels

        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        # Cast the input to whatever precision the stem convolution currently
        # uses instead of hard-coding fp16: identical to ``x.half()`` once the
        # model has been converted to half precision, but the model no longer
        # raises a dtype mismatch when run with fp32 weights (e.g. on CPU).
        x = x.to(self.prep[0].weight.dtype)
        x = self.prep(x)

        x = self.layers(x)

        # Concatenated global average + max pooling.
        x_avg = F.adaptive_avg_pool2d(x, (1, 1))
        x_avg = x_avg.view(x_avg.size(0), -1)

        x_max = F.adaptive_max_pool2d(x, (1, 1))
        x_max = x_max.view(x_max.size(0), -1)

        x = torch.cat([x_avg, x_max], dim=-1)

        x = self.classifier(x)

        return x

In [3]:
# Make the CIFAR data available under CIFAR_PATH before building the data bunch.
# NOTE(review): presumably this extracts the dataset archive only when needed —
# confirm against the fastai.docs helper.
untar_data(CIFAR_PATH)

In [6]:
# Two transform lists are passed as a pair; the second one is empty.
# NOTE(review): presumably (training transforms, validation transforms) —
# confirm against the fastai version in use.
ds_tfms = (
    [
        pad(padding=4),                               # pad by 4 ...
        crop(size=32, row_pct=(0,1), col_pct=(0,1)),  # ... then crop back to 32
        flip_lr(p=0.5),                               # left-right flip with p=0.5
    ],
    [],
)

# Images come from the CIFAR folder tree, with the 'test' sub-folder as the
# validation set, cifar_norm passed as the batch-level transform and a batch
# size of 1024.
data = image_data_from_folder(
    CIFAR_PATH,
    valid='test',
    ds_tfms=ds_tfms,
    tfms=cifar_norm,
    bs=1024,
)

In [7]:
# Run with peak LR 8e-3 and div_factor 20.
# A freshly initialised ResNet18 is wrapped in a Learner, then switched to
# mixed precision and mixup.
learn = Learner(data, ResNet18(), metrics=accuracy)
learn = learn.to_fp16().mixup()

# NOTE(review): presumably turns off weight decay for batch-norm parameters.
learn.bn_wd = False

# 30 epochs of one-cycle training.
learn.fit_one_cycle(
    30,
    8e-3,
    wd=0.2,
    div_factor=20,
    pct_start=0.5,
)

VBox(children=(HBox(children=(IntProgress(value=0, max=30), HTML(value='0.00% [0/30 00:00<00:00]'))), HTML(val…

Total time: 04:33
epoch  train loss  valid loss  accuracy
0      1.863839    1.428895    0.494100  (00:17)
1      1.686357    1.292163    0.555500  (00:08)
2      1.545858    1.143956    0.614600  (00:08)
3      1.453602    0.974595    0.686900  (00:08)
4      1.363806    0.967794    0.683500  (00:08)
5      1.296154    1.037295    0.648200  (00:08)
6      1.252518    0.708865    0.791900  (00:08)
7      1.227285    0.759126    0.763200  (00:08)
8      1.205061    0.805400    0.752900  (00:08)
9      1.174518    0.646748    0.812500  (00:09)
10     1.153771    0.847638    0.750900  (00:08)
11     1.137956    0.817254    0.748400  (00:08)
12     1.130720    0.676482    0.796100  (00:08)
13     1.115529    0.622713    0.814500  (00:08)
14     1.108972    0.630571    0.830100  (00:08)
15     1.094643    0.619395    0.826500  (00:08)
16     1.085728    0.677245    0.807700  (00:08)
17     1.068075    0.560595    0.850900  (00:08)
18     1.052747    0.566342    0.841600  (00:09)
19     1.03

In [8]:
# Run with peak LR 1e-2 and div_factor 20.
# A freshly initialised ResNet18 is wrapped in a Learner, then switched to
# mixed precision and mixup.
learn = Learner(data, ResNet18(), metrics=accuracy)
learn = learn.to_fp16().mixup()

# NOTE(review): presumably turns off weight decay for batch-norm parameters.
learn.bn_wd = False

# 30 epochs of one-cycle training.
learn.fit_one_cycle(
    30,
    1e-2,
    wd=0.2,
    div_factor=20,
    pct_start=0.5,
)

VBox(children=(HBox(children=(IntProgress(value=0, max=30), HTML(value='0.00% [0/30 00:00<00:00]'))), HTML(val…

Total time: 04:18
epoch  train loss  valid loss  accuracy
0      1.850430    1.401118    0.508300  (00:08)
1      1.661367    1.192342    0.584200  (00:08)
2      1.552455    1.109327    0.624200  (00:08)
3      1.441292    1.005031    0.671800  (00:08)
4      1.360136    1.068661    0.647400  (00:08)
5      1.299787    0.871593    0.731400  (00:08)
6      1.267726    0.787459    0.754300  (00:08)
7      1.235394    0.963737    0.704900  (00:08)
8      1.205380    0.777847    0.759100  (00:08)
9      1.177572    0.898474    0.725700  (00:08)
10     1.162290    0.695065    0.794000  (00:08)
11     1.147247    0.643912    0.819000  (00:08)
12     1.136665    0.650594    0.826300  (00:08)
13     1.125513    0.782817    0.751200  (00:08)
14     1.119277    0.617884    0.826900  (00:08)
15     1.104281    0.730081    0.782700  (00:08)
16     1.093953    0.590395    0.834500  (00:08)
17     1.080366    0.598041    0.832600  (00:08)
18     1.067626    0.527368    0.861200  (00:08)
19     1.04

In [9]:
# Run with peak LR 1e-2 and div_factor 25.
# A freshly initialised ResNet18 is wrapped in a Learner, then switched to
# mixed precision and mixup.
learn = Learner(data, ResNet18(), metrics=accuracy)
learn = learn.to_fp16().mixup()

# NOTE(review): presumably turns off weight decay for batch-norm parameters.
learn.bn_wd = False

# 30 epochs of one-cycle training.
learn.fit_one_cycle(
    30,
    1e-2,
    wd=0.2,
    div_factor=25,
    pct_start=0.5,
)

VBox(children=(HBox(children=(IntProgress(value=0, max=30), HTML(value='0.00% [0/30 00:00<00:00]'))), HTML(val…

Total time: 04:18
epoch  train loss  valid loss  accuracy
0      1.873983    1.790770    0.371000  (00:08)
1      1.678811    1.212362    0.577800  (00:08)
2      1.537629    1.090717    0.620000  (00:08)
3      1.435880    1.212508    0.579300  (00:08)
4      1.352413    0.849448    0.725900  (00:08)
5      1.299714    0.909035    0.722500  (00:08)
6      1.249889    0.721878    0.787200  (00:08)
7      1.228680    0.781462    0.758700  (00:08)
8      1.198071    0.823750    0.735500  (00:08)
9      1.180463    0.874900    0.727800  (00:08)
10     1.164486    0.875520    0.717600  (00:08)
11     1.148277    0.679772    0.807200  (00:08)
12     1.138567    0.663854    0.806200  (00:08)
13     1.134031    0.945333    0.701200  (00:08)
14     1.118752    0.663251    0.801000  (00:08)
15     1.111242    1.069927    0.651500  (00:08)
16     1.095738    0.713565    0.781000  (00:08)
17     1.081970    0.638267    0.810600  (00:08)
18     1.062134    0.503238    0.876300  (00:08)
19     1.04

In [5]:
class fp16cb(Callback):
    "Callback that hands the loss function an FP32 copy of the model output."

    def on_train_begin(self, n_epochs:int, **kwargs:Any)->None:
        "Nothing to prepare when training starts."
        return None

    def on_loss_begin(self, last_output:Tensor, **kwargs:Any) -> Tensor:
        "Convert half precision output to FP32 to avoid reduction overflow."
        output_fp32 = last_output.float()
        return output_fp32
