In [1]:
# Reload imported modules automatically before each cell runs, so edits to local
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### WideResNet Model

In [2]:

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one into a single dimension."""

    def forward(self, x):
        """Return `x` viewed as `(x.size(0), -1)`."""
        batch = x.size(0)
        return x.view(batch, -1)
    
def conv_2d(ni, nf, ks, stride):
    """Build a bias-free `nn.Conv2d` mapping `ni` input channels to `nf` output channels.

    `ks` is the kernel size and `stride` the convolution stride. Padding is
    `ks // 2`, so an odd kernel at stride 1 keeps the spatial size unchanged.
    """
    half_kernel = ks // 2
    return nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=half_kernel, bias=False)

def bn(ni, init_zero=False):
    """Create a `BatchNorm2d` over `ni` channels with explicitly initialised affine terms.

    The scale (weight) is filled with 0 when `init_zero` is true and with 1
    otherwise; the shift (bias) is always filled with 0.
    """
    layer = nn.BatchNorm2d(ni)
    scale = 0. if init_zero else 1.
    nn.init.constant_(layer.weight, scale)
    nn.init.constant_(layer.bias, 0.)
    return layer

def bn_relu_conv(ni, nf, ks, stride, init_zero=False):
    """Return a `BatchNorm -> ReLU -> Conv` stack as an `nn.Sequential`.

    The batch norm runs over `ni` channels (its scale starts at 0 when
    `init_zero` is set); the convolution maps `ni` to `nf` channels with kernel
    size `ks` and the given `stride`.
    """
    stages = [
        bn(ni, init_zero=init_zero),
        nn.ReLU(inplace=True),
        conv_2d(ni, nf, ks, stride),
    ]
    return nn.Sequential(*stages)

def noop(x):
    """Identity function: hand the argument back untouched."""
    return x

class BasicBlock(nn.Module):
    """Residual block: BN-ReLU, 3x3 conv, optional dropout, then BN-ReLU-3x3 conv.

    The shortcut is taken from the activated input (after BN + ReLU). The
    residual branch is multiplied by 0.2 before being added to the shortcut.

    Args:
        ni: number of input channels.
        nf: number of output channels.
        stride: stride of the first 3x3 convolution (and of the projection shortcut).
        drop_p: dropout probability between the two convolutions; falsy disables dropout.
    """

    def __init__(self, ni, nf, stride, drop_p=0.0):
        super().__init__()
        self.bn = nn.BatchNorm2d(ni)
        self.conv1 = conv_2d(ni, nf, 3, stride)
        self.conv2 = bn_relu_conv(nf, nf, 3, 1)
        self.drop = nn.Dropout(drop_p, inplace=True) if drop_p else None
        # The identity shortcut is only valid when the residual branch keeps both
        # the channel count and the spatial size. A strided block with ni == nf
        # would otherwise fail in `add_` with mismatched shapes, so project then too.
        changes_resolution = stride not in (1, (1, 1))
        needs_projection = ni != nf or changes_resolution
        self.shortcut = conv_2d(ni, nf, 1, stride) if needs_projection else noop

    def forward(self, x):
        """Apply the block to `x` and return residual * 0.2 + shortcut."""
        x2 = F.relu(self.bn(x), inplace=True)
        r = self.shortcut(x2)
        x = self.conv1(x2)
        # Explicit None check rather than relying on module truthiness.
        if self.drop is not None: x = self.drop(x)
        x = self.conv2(x) * 0.2
        # `x` is a fresh tensor produced by the multiplication, so adding in place is safe.
        return x.add_(r)


def _make_group(N, ni, nf, block, stride, drop_p):
    return [block(ni if i == 0 else nf, nf, stride if i == 0 else 1, drop_p) for i in range(N)]

class WideResNet(nn.Module):
    """Wide residual network for 3-channel images.

    Layout: a 3x3 stem convolution, `num_groups` groups of `N` `BasicBlock`s
    (the first group at stride 1, later groups at stride 2), then
    BatchNorm, ReLU, global average pooling, flatten and a linear classifier.

    Args:
        num_groups: number of block groups.
        N: number of blocks per group.
        num_classes: size of the classifier output.
        k: widening factor; group `i` has `start_nf * 2**i * k` channels.
        drop_p: dropout probability forwarded to every block.
        start_nf: number of channels produced by the stem convolution.
    """

    def __init__(self, num_groups, N, num_classes, k=1, drop_p=0.0, start_nf=16):
        super().__init__()
        # n_channels[0] is the stem width; n_channels[i+1] is the width of group i.
        n_channels = [start_nf] + [start_nf * (2 ** i) * k for i in range(num_groups)]

        layers = [conv_2d(3, n_channels[0], 3, 1)]  # conv1
        for i in range(num_groups):
            layers += _make_group(N, n_channels[i], n_channels[i+1], BasicBlock, (1 if i == 0 else 2), drop_p)

        # Use the width of the last group instead of a hard-coded index 3, which
        # was only correct for num_groups == 3 (IndexError for fewer groups,
        # wrong width for more).
        final_nf = n_channels[-1]
        layers += [nn.BatchNorm2d(final_nf), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(1),
                   Flatten(), nn.Linear(final_nf, num_classes)]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through the whole stack and return the classifier output."""
        return self.features(x)


def wrn_22():
    """Build the `WideResNet` used in this notebook.

    Three groups of three blocks, widening factor 6, ten output classes and no dropout.
    """
    return WideResNet(num_groups=3, N=3, num_classes=10, k=6, drop_p=0.)

# FP16

In [3]:
from nb_004c import *

In [4]:
# Dataset location, relative to the notebook's working directory.
DATA_PATH = Path('data')
PATH = DATA_PATH/'cifar10'

# Per-channel mean / std handed (as half-precision tensors) to `normalize_funcs`.
data_mean = tensor([0.491, 0.482, 0.447])
data_std = tensor([0.247, 0.243, 0.261])
cifar_norm, cifar_denorm = normalize_funcs(data_mean.half(), data_std.half())

# Training-time transforms: horizontal flip with p=0.5, 4-pixel padding, 32-pixel crop.
# The validation set gets no transforms.
train_tfms = [
    flip_lr(p=0.5),
    pad(padding=4),
    crop(size=32, row_pct=(0,1.), col_pct=(0,1.)),
]
valid_tfms = []

# Batch size passed to `DataBunch.create`.
bs = 512

In [5]:
# Build the datasets from the folder layout; the `test` folder serves as the validation set.
train_ds = FilesDataset.from_folder(PATH/'train')
valid_ds = FilesDataset.from_folder(PATH/'test')
data = DataBunch.create(
    train_ds, valid_ds,
    bs=bs, num_workers=4,
    train_tfm=train_tfms, valid_tfm=valid_tfms,
    dl_tfms=cifar_norm,
)
# Length of each loader (batches per epoch).
len(data.train_dl), len(data.valid_dl)

(98, 10)

In [6]:
# Set the `half` flag on both loaders (the DeviceDataLoader repr in the saved output shows `half=True`).
# NOTE(review): presumably this makes each batch get cast to float16 — confirm in nb_004c.
data.train_dl.half = True
data.valid_dl.half = True

In [7]:
# Pull a single training batch to sanity-check the tensor shapes.
x, y = next(iter(data.train_dl))
x.shape, y.shape

HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

(torch.Size([512, 3, 32, 32]), torch.Size([512]))

Test with discriminative lrs

In [16]:
# Build the network, convert it with `model2half`, and wrap it in a Learner.
model = model2half(wrn_22())
learn = Learner(data, model)
learn.metrics = [accuracy]
# Callbacks handed to `learn.fit`: mixed-precision handling plus a one-cycle schedule.
scheds = [
    MixedPrecision(learn, flat_master=True),
    OneCycleScheduler(learn, 1.5, 30, div_factor=20, pct_end=0.2),
]

In [17]:
# Use plain SGD as the optimizer constructor for the upcoming `fit` call.
learn.opt_fn = optim.SGD

In [18]:
# Train for 30 epochs (the log below runs 0..29) with weight decay 1e-4.
# The literals 30 and 1.5 are also passed positionally to OneCycleScheduler where
# `scheds` is built, so change them together.
learn.fit(30, 1.5, wd=1e-4, callbacks=scheds)

HBox(children=(IntProgress(value=0, max=30), HTML(value='')))

HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

0 1.439251834535414 1.618476250267029 0.46419999723434446


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

1 1.053591507028088 1.1520830810546876 0.6232999984264374


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

2 0.8151500166765011 0.8019329390764236 0.7261999978065491


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

3 0.6695854059837069 0.9568809501647949 0.6800999973297119


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

4 0.5625933710590443 0.6651814011096955 0.7729999962806702


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

5 0.48748539488358567 0.5345599263191223 0.8159999981880188


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

6 0.43813652966688005 0.45673575782775877 0.8432999984741211


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

7 0.3864736071070316 0.48811323993206024 0.8412999993324279


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

8 0.34909628865934916 0.5535349521636963 0.8170999985694886


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

9 0.3300732426623422 0.45838377656936646 0.8483999974250793


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

10 0.2966116237428217 0.4485387715816498 0.8597999975204468


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

11 0.2817426322851659 0.5293854030609131 0.8387999987602234


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

12 0.2525515532812811 0.4367010248661041 0.8602999988555908


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

13 0.23294538603864678 0.48797142958641054 0.8529999999046326


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

14 0.2117564409128641 0.4194347010850906 0.8698999979972839


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

15 0.1863828925391392 0.39450054783821104 0.8774999994277954


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

16 0.16728826512524886 0.4353257716178894 0.8749999988555908


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

17 0.15090870263519393 0.3474326530814171 0.8944999962806701


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

18 0.12600242945434986 0.31673477783203124 0.9083000005722046


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

19 0.11418178220212159 0.3800867903709412 0.8993999969482422


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

20 0.09191854172026448 0.314341002368927 0.9119


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

21 0.07457368609281706 0.33580230789184573 0.9138999963760376


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

22 0.051604911687948625 0.27890865828990935 0.9278999989509582


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

23 0.035675205809762446 0.27420417280197146 0.9289999964714051


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

24 0.02412612041603711 0.2659908633708954 0.9317999990463257


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

25 0.019304512794141387 0.27977412860393525 0.9303999990463256


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

26 0.01645185911371405 0.28220183203220367 0.9309999991416931


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

27 0.013802060057560425 0.27816363427639007 0.9331999988555908


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

28 0.01295181065892819 0.28306075079441073 0.9318999989509582


HBox(children=(IntProgress(value=0, max=98), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

29 0.012273602693991838 0.280206849861145 0.9332999965667724


In [11]:
# Inspect the optimizer wrapper held by the learner.
# NOTE(review): the saved output reports Adam and this cell's execution count (11) is lower
# than the cell that sets SGD (17), so the output looks stale — re-run to refresh it.
learn.opt

OptimWrapper over Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9285714285714285, 0.99)
    eps: 1e-08
    lr: 0.38035714285714284
    weight_decay: 0
).
True weight decay: True

In [10]:
# Inspect the recorder attached to the learner.
# NOTE(review): the saved output reports an Adam optimizer and this cell's execution count (10)
# predates the SGD training run (cells 16-18), so the output looks stale — re-run to refresh it.
learn.recorder

Recorder(opt=OptimWrapper over Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9285714285714285, 0.99)
    eps: 1e-08
    lr: 0.38035714285714284
    weight_decay: 0
).
True weight decay: True, train_dl=DeviceDataLoader(dl=<torch.utils.data.dataloader.DataLoader object at 0x7f6274498048>, device=device(type='cuda'), progress_func=<function tqdm_notebook at 0x7f6204f6d158>, tfms=functools.partial(<function normalize_batch at 0x7f6200272048>, mean=tensor([0.4910, 0.4819, 0.4470], device='cuda:0', dtype=torch.float16), std=tensor([0.2469, 0.2430, 0.2610], device='cuda:0', dtype=torch.float16)), half=True))

In [None]:
# Report the tensor type of the stem convolution's weights.
# `WideResNet` keeps its stack in `features` (it defines no `layers` attribute), and the
# first entry is the stem `Conv2d` itself, so it is indexed once rather than twice.
# NOTE(review): assumes `learn.model` is the `WideResNet` instance unwrapped — confirm.
learn.model.features[0].weight.type()

In [None]:
# Print the size and tensor type of the first entry of every item in the
# MixedPrecision callback's `master_params`.
# NOTE(review): presumably these are the master copies of the weights — confirm in nb_004c.
for master in scheds[0].master_params:
    print(master[0].size(),master[0].type())