In [1]:
import os
# Expose only GPU 0 to CUDA; set here, before torch is imported in the next cell.
os.environ['CUDA_VISIBLE_DEVICES']='0' 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
import numpy

In [4]:
from fastai.script import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.distributed import *
from fastprogress import fastprogress
from torchvision.models import *
from fastai.vision.models.xresnet import *
from fastai.vision.models.xresnet2 import *
from fastai.vision.models.presnet import *

In [5]:
# Let cuDNN benchmark the available convolution algorithms and reuse the fastest one.
torch.backends.cudnn.benchmark = True

# XResNet with Self Attention

In [6]:
#Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
def conv1d(ni:int, no:int, ks:int=1, stride:int=1, padding:int=0, bias:bool=False):
    """Build an `nn.Conv1d` with Kaiming-normal weights, wrapped in spectral normalization.

    `ni`/`no` are the input/output channel counts; `ks`, `stride` and `padding` are forwarded
    to `nn.Conv1d`. When `bias` is true the bias vector is zeroed before wrapping.
    """
    layer = nn.Conv1d(in_channels=ni, out_channels=no, kernel_size=ks,
                      stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(layer.weight)
    if bias:
        layer.bias.data.zero_()
    return spectral_norm(layer)



# Adapted from SelfAttention layer at https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py
# Inspired by https://arxiv.org/pdf/1805.08318.pdf
class SimpleSelfAttention(nn.Module):
    """Residual self-attention layer computing `gamma * x @ (x^T @ conv(x)) + x`.

    The input is flattened to `(size[0], size[1], -1)`, so `size[1]` must equal `n_in`
    (the channel count of the internal `conv1d`). `gamma` starts at 0, so the attention
    term is initially scaled by zero. The output has the same shape as the input.

    `ks` is the kernel size of the internal 1-D convolution; it should be odd so that
    `padding=ks//2` preserves the flattened length (otherwise the final `+ x` cannot match).
    """

    def __init__(self, n_in:int, ks=1):
        super().__init__()
        self.n_in = n_in
        # Channel-mixing convolution over the flattened positions (spectrally normalized by conv1d).
        self.conv = conv1d(n_in, n_in, ks, padding=ks//2, bias=False)
        # Learnable scale of the attention term.
        self.gamma = nn.Parameter(tensor([0.]))

    def forward(self,x):
        # Symmetry hack: average the kernel with its (out, in) transpose. Transposing only
        # dims 0 and 1 (rather than viewing the kernel as a 2-D matrix) keeps this valid for
        # ks > 1; for ks == 1 it yields the same (n_in, n_in, 1) values as the 2-D form.
        w = self.conv.weight
        self.conv.weight = (w + w.transpose(0, 1)) / 2
        # NOTE(review): `conv1d` wraps the layer in spectral_norm, which is documented to
        # recompute `weight` from `weight_orig` in a hook before every forward call - so the
        # assignment above is presumably overwritten before `self.conv(x)` runs. Confirm.

        size = x.size()
        # Collapse everything after (batch, channel) into a single position axis.
        x = x.view(*size[:2],-1)
        # (positions x channels) @ (channels x positions) -> position-by-position matrix.
        o = torch.bmm(x.permute(0,2,1).contiguous(),self.conv(x))

        # Apply the attention matrix to x, scale by gamma, and add the residual input.
        o = self.gamma * torch.bmm(x,o) + x

        return o.view(*size).contiguous()
        

In [7]:
#unmodified from https://github.com/fastai/fastai/blob/9b9014b8967186dc70c65ca7dcddca1a1232d99d/fastai/vision/models/xresnet.py
# Single in-place ReLU module, shared by conv_layer and ResBlock.forward.
act_fn = nn.ReLU(inplace=True)

def init_cnn(m):
    "Recursively zero every non-None `bias` and Kaiming-normal-init the weights of Conv2d/Linear modules under `m`."
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0)
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
    # Depth-first walk over direct children, in registration order.
    for child in m.children():
        init_cnn(child)

def conv(ni, nf, ks=3, stride=1, bias=False):
    "Create an `nn.Conv2d` from `ni` to `nf` channels with kernel `ks`, the given `stride`, and padding `ks//2`."
    pad = ks // 2
    return nn.Conv2d(in_channels=ni, out_channels=nf, kernel_size=ks,
                     stride=stride, padding=pad, bias=bias)

def noop(x):
    "Identity function: hand back `x` unchanged (used where an optional layer is disabled)."
    return x

def conv_layer(ni, nf, ks=3, stride=1, zero_bn=False, act=True):
    """Return `nn.Sequential(conv, BatchNorm2d[, act_fn])`.

    The BatchNorm weight is set to 0 when `zero_bn` is true and to 1 otherwise; the shared
    `act_fn` module is appended only when `act` is true.
    """
    norm = nn.BatchNorm2d(nf)
    bn_weight = 0. if zero_bn else 1.
    nn.init.constant_(norm.weight, bn_weight)
    modules = [conv(ni, nf, ks, stride=stride), norm]
    if act:
        modules += [act_fn]
    return nn.Sequential(*modules)

In [8]:
# Modified from https://github.com/fastai/fastai/blob/9b9014b8967186dc70c65ca7dcddca1a1232d99d/fastai/vision/models/xresnet.py
# Added self attention
class ResBlock(nn.Module):
    """Residual block with an optional `SimpleSelfAttention` on the convolutional branch.

    `ni` and `nh` are given before expansion: the block consumes `ni*expansion` channels and
    produces `nh*expansion`. With `expansion == 1` the branch is two 3x3 conv layers; otherwise
    it is a 1x1 / 3x3 / 1x1 stack. The last conv layer of the branch has no activation and its
    BatchNorm weight is initialised to 0.
    """
    def __init__(self, expansion, ni, nh, stride=1,sa=False):
        super().__init__()
        nf = nh*expansion
        ni = ni*expansion

        if expansion == 1:
            layers = [
                conv_layer(ni, nh, 3, stride=stride),
                conv_layer(nh, nf, 3, zero_bn=True, act=False),
            ]
        else:
            layers = [
                conv_layer(ni, nh, 1),
                conv_layer(nh, nh, 3, stride=stride),
                conv_layer(nh, nf, 1, zero_bn=True, act=False),
            ]

        # Attention (when requested) is applied to the branch output, before the residual sum.
        if sa:
            self.sa = SimpleSelfAttention(nf,ks=1)
        else:
            self.sa = noop

        self.convs = nn.Sequential(*layers)

        # Shortcut path: 1x1 conv layer only when channel counts differ, pooling only when strided.
        if ni == nf:
            self.idconv = noop
        else:
            self.idconv = conv_layer(ni, nf, 1, act=False)

        if stride == 1:
            self.pool = noop
        else:
            self.pool = nn.AvgPool2d(2, ceil_mode=True)

    def forward(self, x):
        branch = self.sa(self.convs(x))
        shortcut = self.idconv(self.pool(x))
        return act_fn(branch + shortcut)
        

In [9]:
# Modified from https://github.com/fastai/fastai/blob/9b9014b8967186dc70c65ca7dcddca1a1232d99d/fastai/vision/models/xresnet.py
# Added self attention

class XResNet_sa(nn.Sequential):
    """XResNet-style network whose residual stages can carry a `SimpleSelfAttention` layer.

    Instances are built through `create`; self-attention is enabled only for the stage whose
    index equals `len(layers) - 4`, and within that stage only in its last block.
    """
    @classmethod
    def create(cls, expansion, layers, c_in=3, c_out=1000):
        "Assemble stem, max-pool, one residual stage per entry of `layers`, and a pooled linear head with `c_out` outputs."
        # Three-layer conv stem; only the first layer is strided.
        stem_sizes = [c_in, (c_in+1)*8, 64, 64]
        stem = []
        for idx in range(3):
            stem.append(conv_layer(stem_sizes[idx], stem_sizes[idx+1], stride=2 if idx==0 else 1))

        # Pre-expansion widths; the first entry is divided because ResBlock multiplies `ni` by `expansion`.
        widths = [64//expansion,64,128,256,512]
        sa_stage = len(layers)-4
        stages = []
        for idx,depth in enumerate(layers):
            stages.append(cls._make_layer(expansion, widths[idx], widths[idx+1],
                                          n_blocks=depth, stride=1 if idx==0 else 2,
                                          sa=(idx == sa_stage)))

        model = cls(
            *stem,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            *stages,
            nn.AdaptiveAvgPool2d(1), Flatten(),
            nn.Linear(widths[-1]*expansion, c_out),
        )
        init_cnn(model)
        return model

    @staticmethod
    def _make_layer(expansion, ni, nf, n_blocks, stride, sa = False):
        "Stack `n_blocks` ResBlocks: the first takes `ni` and `stride`, the rest take `nf` and stride 1; `sa` reaches only the last."
        blocks = []
        last = n_blocks - 1
        for idx in range(n_blocks):
            is_first = idx == 0
            blocks.append(ResBlock(expansion,
                                   ni if is_first else nf,
                                   nf,
                                   stride if is_first else 1,
                                   sa if idx == last else False))
        return nn.Sequential(*blocks)

In [10]:
def xresnet50_sa (**kwargs):
    "Build an `XResNet_sa` with expansion 4 and stage depths [3, 4, 6, 3]; `kwargs` are forwarded to `XResNet_sa.create`."
    stage_depths = [3, 4,  6, 3]
    return XResNet_sa.create(4, stage_depths, **kwargs)

# Data loading

In [11]:
#https://github.com/fastai/fastai/blob/master/examples/train_imagenette.py

def get_data(size, woof, bs, workers=None):
    """Download Imagenette/Imagewoof and return a normalized DataBunch at image size `size`.

    `woof` selects Imagewoof over Imagenette. The 160px archive is used for `size <= 128`,
    the 320px archive for `size <= 224`, and the full-size archive otherwise. `bs` is the
    batch size; `workers` defaults to `min(8, num_cpus() // n_gpus)`.
    """
    if size <= 128:
        url = URLs.IMAGEWOOF_160 if woof else URLs.IMAGENETTE_160
    elif size <= 224:
        url = URLs.IMAGEWOOF_320 if woof else URLs.IMAGENETTE_320
    else:
        url = URLs.IMAGEWOOF if woof else URLs.IMAGENETTE
    path = untar_data(url)

    n_gpus = num_distrib() or 1
    if workers is None:
        workers = min(8, num_cpus()//n_gpus)

    # Training transform: random horizontal flip; no validation transforms.
    tfms = ([flip_lr(p=0.5)], [])
    labelled = ImageList.from_folder(path).split_by_folder(valid='val').label_from_folder()
    bunch = labelled.transform(tfms, size=size).databunch(bs=bs, num_workers=workers)
    return bunch.presize(size, scale=(0.35,1)).normalize(imagenet_stats)

# Train

In [12]:
# Adam with custom betas/eps; handed to Learner as `opt_func` inside do_cycle.
opt_func = partial(optim.Adam, betas=(0.9,0.99), eps=1e-6)

In [13]:
# we use the same parameters for baseline and new model
bs = 64  # batch size passed to get_data
lr = 3e-3  # learning rate passed to fit_one_cycle in do_cycle
mixup = 0  # mixup alpha; 0 is falsy, so do_cycle skips learn.mixup(...)
num_loop = 30  # number of repeated training runs per experiment cell

##### New model

In [14]:
def do_cycle(epochs = 5):
    """Train a fresh `xresnet50_sa` for `epochs` one-cycle epochs and return validation accuracy in percent.

    Reads the module-level `data`, `opt_func`, `lr` and `mixup`; the model is built with 10 outputs.
    """
    model = xresnet50_sa(c_out=10)
    # Collect garbage left over from earlier runs before building the next Learner.
    gc.collect()
    learn = Learner(data, model, wd=1e-2, opt_func=opt_func,
                    metrics=[accuracy,top_k_accuracy],
                    bn_wd=False, true_wd=True,
                    loss_func = LabelSmoothingCrossEntropy())
    if mixup:
        learn = learn.mixup(alpha=mixup)
    learn = learn.to_fp16(dynamic=True)
    learn.fit_one_cycle(epochs, lr, div_factor=10, pct_start=0.3)

    # Targets are rebuilt from the validation items rather than taken from get_preds.
    val_preds, _ = learn.get_preds()
    val_labels = tensor(learn.data.valid_ds.y.items)
    return accuracy(val_preds, val_labels).item()*100

## Imagenette

In [None]:
# Imagenette at 128px (get_data picks the 160px archive); do_cycle reads this through the global `data`.
image_size = 128
data = get_data(image_size,woof =False,bs=bs)

In [None]:
# Repeat the 5-epoch training num_loop times, keeping one validation accuracy (in %) per run.
# Appending run by run keeps the finished results available if the loop is interrupted.
results = []
for i in range(num_loop):
    results.append(do_cycle(epochs=5))
    
# Per-run accuracies, then mean / std / min / max across runs.
# NOTE(review): `np` is never imported explicitly in this notebook (only `import numpy`);
# presumably it arrives through the fastai star imports - confirm.
print(results)
print(np.mean(results), np.std(results), np.min(results), np.max(results))

[85.79999804496765, 85.19999980926514, 85.00000238418579, 86.59999966621399, 86.40000224113464, 87.00000047683716, 86.79999709129333, 86.79999709129333, 86.19999885559082, 87.99999952316284, 86.40000224113464, 86.59999966621399, 86.79999709129333, 84.60000157356262, 85.00000238418579, 85.79999804496765, 86.00000143051147, 86.00000143051147, 85.00000238418579, 85.79999804496765, 84.3999981880188, 87.1999979019165, 87.1999979019165, 86.40000224113464, 86.40000224113464, 85.79999804496765, 85.00000238418579, 85.6000006198883, 85.00000238418579, 84.60000157356262]
85.98000009854634 0.8874675041324023 84.3999981880188 87.99999952316284


In [None]:
# Imagenette at 256px (size > 224, so get_data uses the full-size archive); replaces the global `data`.
image_size = 256
data = get_data(image_size,woof =False,bs=bs)

In [None]:
# num_loop independent 5-epoch runs on the current `data`; one validation accuracy (in %) per run.
results = []
for i in range(num_loop):
    results.append(do_cycle(epochs=5))
    
# Per-run accuracies, then mean / std / min / max across runs.
print(results)
print(np.mean(results), np.std(results), np.min(results), np.max(results))

[88.59999775886536, 88.40000033378601, 87.8000020980835, 85.79999804496765, 87.8000020980835, 87.99999952316284, 87.59999871253967, 88.59999775886536, 87.40000128746033, 88.40000033378601, 86.40000224113464, 88.40000033378601, 88.40000033378601, 87.40000128746033, 88.40000033378601, 88.20000290870667, 87.00000047683716, 88.59999775886536, 88.40000033378601, 87.40000128746033, 88.80000114440918, 88.40000033378601, 88.99999856948853, 87.99999952316284, 87.1999979019165, 87.40000128746033, 87.59999871253967, 88.59999775886536, 86.40000224113464, 88.80000114440918]
87.9066667954127 0.7758576185495241 85.79999804496765 88.99999856948853


In [None]:
# num_loop independent 20-epoch runs on the current `data`; one validation accuracy (in %) per run.
results = []
for i in range(num_loop):
    results.append(do_cycle(epochs=20))
    
# Per-run accuracies, then mean / std / min / max across runs.
print(results)
print(np.mean(results), np.std(results), np.min(results), np.max(results))

[93.99999976158142, 95.20000219345093, 93.99999976158142, 93.80000233650208, 94.59999799728394, 94.40000057220459, 94.19999718666077, 93.19999814033508, 94.59999799728394, 93.59999895095825, 94.19999718666077, 94.40000057220459, 94.80000138282776, 93.99999976158142, 94.9999988079071, 94.80000138282776, 94.19999718666077, 94.59999799728394, 94.59999799728394, 93.59999895095825, 94.19999718666077, 95.20000219345093, 94.59999799728394, 94.40000057220459, 95.39999961853027, 94.19999718666077, 93.80000233650208, 94.9999988079071, 93.80000233650208, 94.19999718666077]
94.3533327182134 0.515579715127874 93.19999814033508 95.39999961853027


## Imagewoof

In [15]:
# Imagewoof at 128px (get_data picks the 160px archive); replaces the global `data` read by do_cycle.
image_size = 128
data = get_data(image_size,woof =True,bs=bs)

In [16]:
# num_loop independent 5-epoch runs on the current `data`; one validation accuracy (in %) per run.
results = []
for i in range(num_loop):
    results.append(do_cycle(epochs=5))
    
# Per-run accuracies, then mean / std / min / max across runs.
print(results)
print(np.mean(results), np.std(results), np.min(results), np.max(results))

[63.40000033378601, 65.79999923706055, 65.6000018119812, 69.59999799728394, 66.00000262260437, 65.20000100135803, 66.00000262260437, 68.19999814033508, 65.20000100135803, 66.79999828338623, 65.39999842643738, 66.39999747276306, 64.20000195503235, 65.39999842643738, 66.60000085830688, 66.60000085830688, 64.99999761581421, 66.79999828338623, 68.00000071525574, 65.20000100135803, 67.00000166893005, 64.20000195503235, 64.99999761581421, 64.3999993801117, 66.79999828338623, 66.60000085830688, 68.59999895095825, 63.999998569488525, 66.39999747276306, 66.79999828338623]
66.03999972343445 1.391785018381755 63.40000033378601 69.59999799728394


In [17]:
# Imagewoof at 256px (size > 224, so get_data uses the full-size archive); replaces the global `data`.
image_size = 256
data = get_data(image_size,woof =True,bs=bs)

In [None]:
# num_loop independent 5-epoch runs on the current `data`; one validation accuracy (in %) per run.
results = []
for i in range(num_loop):
    results.append(do_cycle(epochs=5))
    
# Per-run accuracies, then mean / std / min / max across runs.
print(results)
print(np.mean(results), np.std(results), np.min(results), np.max(results))

[69.9999988079071, 67.79999732971191, 64.80000019073486, 66.20000004768372, 67.40000247955322, 69.19999718666077, 69.40000057220459, 68.99999976158142, 65.39999842643738, 66.20000004768372, 67.40000247955322, 69.59999799728394, 66.20000004768372, 68.99999976158142, 67.1999990940094, 68.4000015258789, 69.9999988079071, 72.2000002861023, 68.4000015258789, 68.99999976158142, 68.00000071525574, 63.999998569488525, 68.19999814033508, 64.99999761581421, 67.59999990463257, 66.60000085830688, 65.6000018119812, 64.99999761581421, 66.39999747276306, 64.99999761581421]
67.47333288192749 1.899111093933636 63.999998569488525 72.2000002861023


In [None]:
# num_loop independent 20-epoch runs on the current `data`; one validation accuracy (in %) per run.
results = []
for i in range(num_loop):
    results.append(do_cycle(epochs=20))
    
# Per-run accuracies, then mean / std / min / max across runs.
print(results)
print(np.mean(results), np.std(results), np.min(results), np.max(results))

[84.79999899864197, 85.00000238418579, 85.39999723434448, 83.79999995231628, 85.79999804496765, 86.79999709129333, 84.79999899864197, 85.6000006198883, 85.00000238418579, 85.6000006198883, 85.19999980926514, 85.19999980926514, 84.79999899864197, 85.00000238418579, 85.79999804496765, 85.39999723434448, 86.00000143051147, 84.3999981880188, 85.39999723434448, 86.00000143051147, 86.59999966621399, 85.6000006198883, 84.60000157356262, 86.40000224113464, 85.6000006198883, 84.20000076293945, 84.3999981880188, 85.6000006198883, 83.79999995231628, 85.19999980926514]
85.25999983151753 0.728743141366321 83.79999995231628 86.79999709129333


In [None]:
# num_loop independent 80-epoch runs on the current `data`; one validation accuracy (in %) per run.
results = []
for i in range(num_loop):
    results.append(do_cycle(epochs=80))
    
# Per-run accuracies, then mean / std / min / max across runs.
print(results)
print(np.mean(results), np.std(results), np.min(results), np.max(results))

[89.99999761581421, 90.79999923706055, 88.99999856948853, 89.60000276565552, 89.3999993801117, 88.40000033378601, 88.80000114440918, 89.99999761581421, 89.99999761581421, 88.99999856948853, 89.60000276565552, 89.99999761581421, 90.20000100135803, 89.80000019073486, 88.59999775886536, 89.20000195503235, 89.99999761581421, 88.59999775886536, 89.60000276565552, 90.20000100135803, 89.20000195503235, 89.80000019073486, 89.99999761581421, 89.99999761581421, 88.40000033378601, 88.99999856948853, 89.99999761581421, 89.99999761581421, 88.99999856948853, 88.20000290870667]
89.47999954223633 0.6462193032336647 88.20000290870667 90.79999923706055


In [18]:
%%time
# 400-epoch runs on the current `data`, repeated num_loop//5 times (6 with num_loop = 30).
# NOTE(review): the output recorded under this cell lists 36 values, and its first 30 match the
# 128px / 5-epoch Imagewoof results printed earlier in the notebook. It therefore looks like the
# output was produced without the `results = []` reset below, in which case the printed
# mean/std/min mix two experiments and only the last 6 values belong to this loop.
# Re-run from a fresh kernel to confirm.
results = []
for i in range(num_loop//5):
    results.append(do_cycle(epochs=400))
    
# Per-run accuracies, then mean / std / min / max across runs.
print(results)
print(np.mean(results), np.std(results), np.min(results), np.max(results))

[63.40000033378601, 65.79999923706055, 65.6000018119812, 69.59999799728394, 66.00000262260437, 65.20000100135803, 66.00000262260437, 68.19999814033508, 65.20000100135803, 66.79999828338623, 65.39999842643738, 66.39999747276306, 64.20000195503235, 65.39999842643738, 66.60000085830688, 66.60000085830688, 64.99999761581421, 66.79999828338623, 68.00000071525574, 65.20000100135803, 67.00000166893005, 64.20000195503235, 64.99999761581421, 64.3999993801117, 66.79999828338623, 66.60000085830688, 68.59999895095825, 63.999998569488525, 66.39999747276306, 66.79999828338623, 88.40000033378601, 89.3999993801117, 89.20000195503235, 89.60000276565552, 89.20000195503235, 88.80000114440918]
69.88333331214056 8.688866313784294 63.40000033378601 89.60000276565552
CPU times: user 7d 9h 35min 37s, sys: 7d 4h 39min 2s, total: 14d 14h 14min 39s
Wall time: 2d 5h 18min 4s
