In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from nb_004c import *

# Dogs and cats

## Basic data aug

In [None]:
# Root folder of the dataset; the `train` and `valid` subfolders are read from it below.
PATH = Path('data/stl10')

In [None]:
# Per-channel mean and std used for input normalization. These particular values are the widely
# used ImageNet statistics, i.e. what the pretrained torchvision backbones picked below expect.
data_mean, data_std = map(tensor, ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
# normalize_funcs comes from an earlier notebook; going by the names it returns a
# (normalize, denormalize) pair -- confirm there. Only data_norm is used in the cells visible here.
data_norm,data_denorm = normalize_funcs(data_mean,data_std)

In [None]:
# One dataset per split, read from PATH/train and PATH/valid.
# NOTE(review): FilesDataset comes from an earlier notebook; how it derives labels is not visible here.
train_ds = FilesDataset.from_folder(PATH/'train')
valid_ds = FilesDataset.from_folder(PATH/'valid')

In [None]:
# First element (index 0) of the validation item at index 2, wrapped in Image.
# `x` is not referenced again in the cells visible here -- presumably kept for interactive experiments.
x=Image(valid_ds[2][0])

In [None]:
#export
def uniform_int(low, high, size=None):
    """Draw uniformly distributed integers between `low` and `high`, both bounds included.

    With `size=None` a single Python int is returned (via `random.randint`);
    otherwise a tensor of shape `size` is returned (via `torch.randint`).
    """
    if size is None: return random.randint(low,high)
    # torch.randint excludes its upper bound while random.randint includes it;
    # the +1 makes the tensor branch cover the same range as the scalar branch.
    return torch.randint(low,high+1,size)

@TfmPixel
def dihedral(x, k:partial(uniform_int,0,7)):
    """Apply the flip/transpose combination selected by the low three bits of `k`.

    bit 0 (k&1): flip along dim 1
    bit 1 (k&2): flip along dim 2
    bit 2 (k&4): swap dims 1 and 2 (applied after the flips)

    The annotation on `k` is its random sampler (this is how the TfmPixel decorator is
    presumed to use it -- confirm in its definition). The scalar draw of `uniform_int`
    goes through `random.randint`, whose upper bound is inclusive, so the range has to be
    0..7: with 0..8 the value 8 sets none of the three bits and the identity would be
    drawn twice as often as any other combination.
    """
    # NOTE(review): dims 1 and 2 are assumed to be the spatial dims of a channels-first image tensor.
    flips=[]
    if k&1: flips.append(1)
    if k&2: flips.append(2)
    if flips: x = torch.flip(x,flips)
    if k&4: x = x.transpose(1,2)
    # transpose returns a non-contiguous view; hand back a contiguous tensor.
    return x.contiguous()

In [None]:
#export
def get_transforms(do_flip=False, flip_vert=False, max_rotate=0., max_zoom=1., max_lighting=0., max_warp=0.,
                   p_affine=0.75, p_lighting=0.5, xtra_tfms=None):
    """Build the `(train_tfms, valid_tfms)` pair of transform lists.

    The training list always starts with `rand_crop()`; further transforms are appended,
    in this order, only when the matching argument enables them:

    - `do_flip`: `dihedral()` if `flip_vert` else `flip_lr(p=0.5)`
    - `max_warp` (truthy): `symmetric_warp` with magnitude in `(-max_warp, max_warp)`
    - `max_rotate` (truthy): `rotate` with degrees in `(-max_rotate, max_rotate)`
    - `max_zoom > 1`: `rand_zoom` with scale in `(1., max_zoom)`
    - `max_lighting` (truthy): `brightness` then `contrast`

    The warp/rotate/zoom transforms get `p=p_affine`, the lighting ones `p=p_lighting`.
    `xtra_tfms` (a single transform, a list, or None) is appended last via `listify`.
    The validation list is always `[crop_pad()]`.
    """
    train_tfms = [rand_crop()]
    if do_flip:
        flip_tfm = dihedral() if flip_vert else flip_lr(p=0.5)
        train_tfms.append(flip_tfm)
    if max_warp:
        train_tfms.append(symmetric_warp(magnitude=(-max_warp,max_warp), p=p_affine))
    if max_rotate:
        train_tfms.append(rotate(degrees=(-max_rotate,max_rotate), p=p_affine))
    if max_zoom>1:
        train_tfms.append(rand_zoom(scale=(1.,max_zoom), p=p_affine))
    if max_lighting:
        # Brightness range is centred on 0.5; the contrast range is (s, 1/s) with s = 1-max_lighting.
        bright_tfm = brightness(change=(0.5*(1-max_lighting), 0.5*(1+max_lighting)), p=p_lighting)
        train_tfms.append(bright_tfm)
        contrast_tfm = contrast(scale=(1-max_lighting, 1/(1-max_lighting)), p=p_lighting)
        train_tfms.append(contrast_tfm)
    train_tfms = train_tfms + listify(xtra_tfms)
    valid_tfms = [crop_pad()]
    return (train_tfms, valid_tfms)

def transform_datasets(train_ds, valid_ds, tfms, size=None):
    """Wrap both splits in `DatasetTfm`: `tfms[0]` goes with `train_ds`, `tfms[1]` with `valid_ds`.

    `size` is forwarded unchanged to both wrappers. Returns `(train, valid)`.
    """
    wrapped_train = DatasetTfm(train_ds, tfms[0], size=size)
    wrapped_valid = DatasetTfm(valid_ds, tfms[1], size=size)
    return (wrapped_train, wrapped_valid)

In [None]:
# 14 epochs -> 96.6

In [None]:
# Forwarded as `size` to DatasetTfm for both splits (presumably the target image side in pixels -- confirm).
size=96

# tfms is a (train, valid) pair of transform lists; flip_vert stays False, so flipping is left-right only.
tfms = get_transforms(do_flip=True, max_rotate=5, max_zoom=1.1, max_lighting=0.4, max_warp=0.15)
# Wrap the raw datasets with their respective transform lists.
tds = transform_datasets(train_ds, valid_ds, tfms, size=size)
# data_norm is handed over as `tfms`; at this point DataBunch is the version from an earlier notebook.
data = DataBunch(*tds, bs=64, num_workers=8, tfms=data_norm)

## Train

In [None]:
from torchvision.models import resnet18, resnet34, resnet50
# Backbone constructor; ConvLearner calls it as arch(pretrained).
# resnet18 / resnet34 are imported but not used in the cells visible here.
arch = resnet50

In [None]:
#export
class AdaptiveConcatPool2d(nn.Module):
    """Adaptive max pooling and adaptive average pooling, concatenated along dim 1.

    `sz` is the output size given to both pooling layers; None (or any falsy value) means 1.
    The output has twice as many channels as the input, max-pooled features first.
    """
    def __init__(self, sz=None):
        super().__init__()
        out_size = sz if sz else 1
        self.ap = nn.AdaptiveAvgPool2d(out_size)
        self.mp = nn.AdaptiveMaxPool2d(out_size)

    def forward(self, x):
        pooled_max = self.mp(x)
        pooled_avg = self.ap(x)
        return torch.cat((pooled_max, pooled_avg), dim=1)

def create_skeleton(model, cut):
    """Turn `model` into a feature extractor ending in concat-pooling and flattening.

    When `cut` is truthy the last `cut` direct children of `model` are dropped;
    otherwise `model` is kept whole as a single layer (a plain `[:-0]` slice would be empty).
    `AdaptiveConcatPool2d()` and `Flatten()` are appended and everything is wrapped in `nn.Sequential`.
    """
    if cut:
        body = list(model.children())[:-cut]
    else:
        body = [model]
    body.extend([AdaptiveConcatPool2d(), Flatten()])
    return nn.Sequential(*body)

def num_features(m):
    """Return the `num_features` of the last descendant of `m` that has that attribute.

    Children are scanned from last to first; a child without the attribute is searched
    recursively before moving on to the previous child. Returns None when nothing is found,
    which includes the case where `m` itself has no children (the attribute on `m` is not checked).
    """
    for child in reversed(list(m.children())):
        if hasattr(child, 'num_features'):
            return child.num_features
        nested = num_features(child)
        if nested is not None:
            return nested
    return None

In [None]:
#export
def bn_dp_lin(n_in, n_out, bn=True, dp=0., actn=None):
    """Return the list of layers for one block: [BatchNorm1d] -> [Dropout] -> Linear -> [actn].

    BatchNorm1d(n_in) is included when `bn` is truthy, Dropout(dp) when `dp != 0`,
    and `actn` (an already constructed module) when it is not None.
    """
    block = []
    if bn:
        block.append(nn.BatchNorm1d(n_in))
    if dp != 0:
        block.append(nn.Dropout(dp))
    block.append(nn.Linear(n_in, n_out))
    if actn is not None:
        block.append(actn)
    return block

def create_head(nf, nc, lin_ftrs=None, dps=None):
    """Build the classifier head as a `nn.Sequential` of `bn_dp_lin` blocks.

    nf:       number of input features.
    nc:       number of outputs of the last linear layer.
    lin_ftrs: sizes of the hidden linear layers (any sequence); None means one hidden layer of 512.
    dps:      one dropout probability per linear layer; None means 0.25 for every hidden
              layer and 0.5 for the last one.

    Every block except the last is followed by ReLU.

    Raises ValueError when `dps` does not hold exactly one value per linear layer.
    """
    # list() lets callers pass a tuple for lin_ftrs.
    sizes = [nf, 512, nc] if lin_ftrs is None else [nf] + list(lin_ftrs) + [nc]
    n_layers = len(sizes) - 1
    dps = [0.25] * (n_layers-1) + [0.5] if dps is None else list(dps)
    # zip() below stops at the shortest input: a short `dps` used to drop the trailing layers
    # silently, returning a head without its final nc-way linear layer.
    if len(dps) != n_layers:
        raise ValueError(f'dps must have one entry per linear layer: expected {n_layers}, got {len(dps)}')
    layers = []
    for i,(ni,no,dp) in enumerate(zip(sizes[:-1], sizes[1:], dps)):
        # A fresh ReLU per block, so no module instance appears at two positions of the Sequential.
        actn = nn.ReLU(inplace=True) if i < n_layers-1 else None
        layers += bn_dp_lin(ni,no,True,dp,actn)
    return nn.Sequential(*layers)

In [None]:
class ConvLearner(Learner):
    """Learner whose model is a (possibly pretrained) backbone followed by a new head.

    The backbone is `create_skeleton(arch(pretrained), cut)`, the head comes from `create_head`.
    Both are kept as `self.skeleton` / `self.head`; the trained model is their `nn.Sequential`.
    Remaining keyword arguments are forwarded to `Learner`.
    """
    def __init__(self, data, arch, cut, pretrained=True, lin_ftrs=None, dps=None, **kwargs):
        backbone = arch(pretrained)
        self.skeleton = create_skeleton(backbone, cut)
        # The skeleton ends in AdaptiveConcatPool2d (max + avg), hence twice the backbone's features.
        n_pooled = 2 * num_features(self.skeleton)
        # XXX: better way to get num classes
        n_classes = len(data.train_ds.ds.classes)
        self.head = create_head(n_pooled, n_classes, lin_ftrs, dps)
        super().__init__(data, nn.Sequential(self.skeleton, self.head), **kwargs)

    def _set_skeleton_trainable(self, flag):
        # Toggle requires_grad on every skeleton parameter; the head is never touched.
        for param in self.skeleton.parameters():
            param.requires_grad = flag

    def freeze(self):
        "Stop gradient updates for the skeleton."
        self._set_skeleton_trainable(False)

    def unfreeze(self):
        "Re-enable gradient updates for the skeleton."
        self._set_skeleton_trainable(True)

In [None]:
# cut=2: create_skeleton drops the last two children of the torchvision model.
# wd is not a ConvLearner argument; it travels to Learner through **kwargs.
learn = ConvLearner(data, arch, 2, wd=1e-2)
learn.metrics = [accuracy]
# Stage 1: skeleton parameters get requires_grad=False, leaving only the new head trainable.
learn.freeze()

In [None]:
# lr_find comes from an earlier notebook; the recorder plot is used to pick `lr` below
# (presumably loss against learning rate -- confirm).
lr_find(learn)
learn.recorder.plot()

In [None]:
# Base learning rate for the stages below; it is also the largest entry of `lrs`.
lr = 1e-3

In [None]:
# Head-only stage. NOTE(review): the 3 is presumably the number of epochs -- confirm against fit_one_cycle.
fit_one_cycle(learn, lr, 3)

In [None]:
# Checkpoint after the head-only stage; loaded again at the start of the gradual-unfreezing section.
learn.save('0')

## Gradual unfreezing

In [None]:
#export
class ConvLearner(Learner):
    """Learner built from a (possibly pretrained) backbone plus a new head, with layer-group freezing.

    The backbone is `create_skeleton(arch(pretrained), cut)`, the head comes from `create_head`.
    Both are kept as `self.skeleton` / `self.head`; the trained model is their `nn.Sequential`.
    Remaining keyword arguments are forwarded to `Learner`.
    Freezing works on `self.layer_groups`, which this class reads but does not set.
    """
    def __init__(self, data, arch, cut, pretrained=True, lin_ftrs=None, dps=None, **kwargs):
        self.skeleton = create_skeleton(arch(pretrained), cut)
        # The skeleton ends in AdaptiveConcatPool2d (max + avg), hence twice the backbone's features.
        nf = num_features(self.skeleton) * 2
        # XXX: better way to get num classes
        self.head = create_head(nf, len(data.train_ds.ds.classes), lin_ftrs, dps)
        model = nn.Sequential(self.skeleton, self.head)
        super().__init__(data, model, **kwargs)

    def freeze_to(self, n):
        "Freeze layer groups `[:n]` and make groups `[n:]` trainable (negative `n` counts from the end)."
        for g in self.layer_groups[:n]:
            for p in g.parameters(): p.requires_grad = False
        for g in self.layer_groups[n:]:
            for p in g.parameters(): p.requires_grad = True

    def freeze(self):
        """Freeze everything except the head.

        Freezing up to `len(self.layer_groups)` would also freeze the head and leave no
        trainable parameter at all, so stop one group short of the end instead.
        """
        if len(self.layer_groups) > 1: self.freeze_to(-1)
        else:
            # A single group cannot be told apart from the head: freeze the skeleton directly.
            for p in self.skeleton.parameters(): p.requires_grad = False

    def unfreeze(self):
        "Make every layer group trainable."
        self.freeze_to(0)

In [None]:
# New learner instance built from the ConvLearner class as redefined in this section.
learn = ConvLearner(data, arch, 2, wd=1e-2)
learn.metrics = [accuracy]

In [None]:
# Restore the checkpoint saved after the head-only stage.
learn.load('0')

In [None]:
# Split points: child 6 of the skeleton (m[0][6]) and the head (m[1]).
# NOTE(review): Learner.split is defined elsewhere; together with the 3-entry `lrs` below this
# presumably yields three layer groups -- confirm.
learn.split(lambda m: (m[0][6], m[1]))

In [None]:
# Freeze layer group 0 only; every group from index 1 on is made trainable.
learn.freeze_to(1)

In [None]:
# Three learning rates, each 3x the previous one, ending at `lr` -- presumably one per layer group.
lrs = np.array([lr/9, lr/3, lr])

In [None]:
# Train with the first layer group frozen, using the array of rates instead of a single value.
fit_one_cycle(learn, lrs, 3)

In [None]:
# Checkpoint after the partially frozen stage.
learn.save('1')

In [None]:
# Reload checkpoint '1' -- presumably so the notebook can be resumed from this cell.
learn.load('1')

In [None]:
# unfreeze() is freeze_to(0): every layer group becomes trainable.
learn.unfreeze()

In [None]:
# Fine-tune the whole model at one tenth of the previous rates.
fit_one_cycle(learn, lrs/10, 3)

In [None]:
# A second cycle with exactly the same settings as the cell before.
fit_one_cycle(learn, lrs/10, 3)

In [None]:
# Last cycle, at one hundredth of the original rates.
fit_one_cycle(learn, lrs/100, 3)

In [None]:
# Checkpoint of the fully fine-tuned model; loaded again in the TTA section.
learn.save('2')

In [None]:
# TODO remove layer groups; use start_layer / start_lr

## TTA

In [None]:
def transform_datasets(train_ds, valid_ds, tfms, size=None):
    """Wrap the splits in `DatasetTfm` and return `(train, valid, augmented_valid)`.

    `tfms[0]` is applied to `train_ds`, `tfms[1]` to `valid_ds`. The third dataset is
    `valid_ds` again but with the training transforms `tfms[0]`, for test-time augmentation.
    `size` is forwarded unchanged to all three wrappers.
    """
    wrapped_train = DatasetTfm(train_ds, tfms[0], size=size)
    wrapped_valid = DatasetTfm(valid_ds, tfms[1], size=size)
    wrapped_augm  = DatasetTfm(valid_ds, tfms[0], size=size)
    return (wrapped_train, wrapped_valid, wrapped_augm)

In [None]:
class DataBunch():
    """Holds three loaders: training, validation, and validation with augmentation (`augm_dl`).

    The training loader uses batch size `bs` and shuffles; the two validation loaders use
    `2*bs` and do not shuffle, so they iterate the items in the same order.
    Extra keyword arguments are forwarded to `DeviceDataLoader.create` for every loader.
    """
    def __init__(self, train_ds, valid_ds, augm_ds, bs=64, device=None, num_workers=4, **kwargs):
        self.device = device if device is not None else default_device

        def _loader(ds, batch_size, shuffle):
            # The three loaders differ only in dataset, batch size and shuffling.
            return DeviceDataLoader.create(ds, batch_size, shuffle=shuffle, num_workers=num_workers, **kwargs)

        self.train_dl = _loader(train_ds, bs, True)
        self.valid_dl = _loader(valid_ds, bs*2, False)
        self.augm_dl  = _loader(augm_ds, bs*2, False)

    @classmethod
    def create(cls, train_ds, valid_ds, train_tfm=None, valid_tfm=None, dl_tfms=None, **kwargs):
        "Wrap the raw datasets in `DatasetTfm` (validation twice: plain and with `train_tfm`) and build the bunch."
        wrapped = (DatasetTfm(train_ds, train_tfm),
                   DatasetTfm(valid_ds, valid_tfm),
                   DatasetTfm(valid_ds, train_tfm))
        return cls(*wrapped, tfms=dl_tfms, **kwargs)

    @property
    def train_ds(self):
        "Dataset behind the training loader."
        return self.train_dl.dl.dataset

    @property
    def valid_ds(self):
        "Dataset behind the validation loader."
        return self.valid_dl.dl.dataset

In [None]:
# Rebuilt with the three-dataset versions defined in this section:
# tds[2] is the validation set wrapped with the training transforms.
tds = transform_datasets(train_ds, valid_ds, tfms, size=size)
# bs is left at the DataBunch default (64).
data = DataBunch(*tds, num_workers=8, tfms=data_norm)

In [None]:
# NOTE(review): wd=1e-3 here versus 1e-2 in the training sections. This learner is only used for
# inference in the cells visible below, so the value should not matter -- confirm.
learn = ConvLearner(data, arch, 2, wd=1e-3)
learn.metrics = [accuracy]

In [None]:
# Load the checkpoint saved after the final fine-tuning cycle.
learn.load('2')

In [None]:
_,axs = plt.subplots(2,4,figsize=(12,6))
# tds[2][1] is indexed anew for every panel, so each of the 8 axes shows item 1 of the augmented
# validation set (presumably with a different random augmentation per access -- depends on DatasetTfm).
for ax in axs.flat: show_image(tds[2][1][0], ax)

In [None]:
model = learn.model
# Switch to evaluation mode; the trailing ';' suppresses the module repr in the cell output.
model.eval();

In [None]:
# Baseline without augmentation: collect model outputs and targets batch by batch
# from the plain validation loader.
with torch.no_grad():
    preds,y = zip(*[(model(xb.detach()), yb.detach()) for xb,yb in data.valid_dl])

# zip() produced tuples of per-batch tensors; join each into one tensor.
preds = torch.cat(preds)
y = torch.cat(y)

In [None]:
# Accuracy without test-time augmentation.
accuracy(preds, y)

In [None]:
def get_preds(model, dl):
    """Run `model` on the inputs of every batch of `dl` and return the outputs concatenated along dim 0.

    Targets are ignored and everything runs under `torch.no_grad()`.
    The model's train/eval mode is not changed here.
    """
    outputs = []
    with torch.no_grad():
        for xb, _ in dl:
            outputs.append(model(xb.detach()))
        return torch.cat(outputs)

In [None]:
# TTA: 4 passes over the augmented validation loader, stacked along a new leading dim.
# augm_dl is created with shuffle=False, so the rows stay aligned with `y`.
all_preds = torch.stack([get_preds(model, data.augm_dl) for _ in range(4)])

In [None]:
# Average over the 4 passes (dim 0); the resulting shape is displayed as the cell output.
avg_preds = all_preds.mean(0)
avg_preds.shape

In [None]:
# Accuracy of the averaged augmented predictions.
accuracy(avg_preds, y)

In [None]:
# Blend of plain and TTA predictions; beta is the weight of the un-augmented predictions.
beta=0.5
accuracy(preds*beta + avg_preds*(1-beta), y)

## Fin