In [0]:
# Environment setup: install/upgrade the runtime dependencies, then fastprogress and
# fastai_dev straight from their GitHub repositories.
# NOTE(review): nothing here is version-pinned and the git installs track the default
# branch, so a fresh run may pull different code than the run recorded below — confirm.
!pip install torch torchvision feather-format kornia pyarrow Pillow wandb --upgrade 
!pip install git+https://github.com/fastai/fastprogress  --upgrade
!pip install git+https://github.com/fastai/fastai_dev  

In [0]:
from fastai2.basics import *
from fastai2.vision.all import *
from fastai2.callback.all import *
from fastai2.basics import defaults

In [0]:
# Fetch the Imagewoof dataset through fastai's `untar_data`; `src` is whatever that
# call returns for it (presumably the local dataset root — confirm).
src = untar_data(URLs.IMAGEWOOF)
# Every image file found under `src`.
items = get_image_files(src)
# Train/validation split over `items`, with 'val' named as the validation folder
# (presumably matched against each file's grandparent directory — confirm).
split_idx = GrandparentSplitter(valid_name='val')(items)

In [0]:
# Imagewoof WordNet synset id -> human-readable breed name.
# The two terrier entries follow the ImageNet synset list: n02093754 is the
# Border terrier and n02096294 is the Australian terrier (they were previously swapped,
# which mislabelled both classes in any output that shows class names).
lbl_dict = dict(
  n02086240= 'Shih-Tzu',
  n02087394= 'Rhodesian ridgeback',
  n02088364= 'Beagle',
  n02089973= 'English foxhound',
  n02093754= 'Border terrier',
  n02096294= 'Australian terrier',
  n02099601= 'Golden retriever',
  n02105641= 'Old English sheepdog',
  n02111889= 'Samoyed',
  n02115641= 'Dingo'
)

In [0]:
# Data pipeline: builds the DataSource and the databunch (`dbch`) that the Learner consumes.

# Worker count: at most 8, and no more than the CPUs available per GPU
# (`num_distrib()` falls back to 1 when it returns a falsy value).
n_gpus = num_distrib() or 1
nw = min(8, num_cpus()//n_gpus)

# NOTE(review): same expression as the earlier cell that first defines `split_idx`;
# looks redundant as long as `items` has not changed in between — confirm before removing.
split_idx = GrandparentSplitter(valid_name='val')(items)
# Two transform pipelines per item: x = image opened with PILImage.create,
# y = parent folder name -> breed name via lbl_dict -> Categorize.
tfms = [[PILImage.create], [parent_label, lbl_dict.__getitem__, Categorize()]]

dsrc = DataSource(items, tfms, splits=split_idx)

# Applied per batch, in this order.
batch_tfms = [Cuda(), IntToFloatTensor(), Normalize(*imagenet_stats)]

# Per-item transforms: tensor conversion, random resized crop to 128 (min_scale=0.35), flip;
# batch size 64.
dbch = dsrc.databunch(after_item=[ToTensor(),RandomResizedCrop(128, min_scale=0.35), FlipItem()], 
                      after_batch=batch_tfms, 
                      bs=64, num_workers=nw)

In [0]:
class myXResNet(nn.Sequential):
    """XResNet-style classifier assembled as one flat `nn.Sequential`.

    Layout: three-`ConvLayer` stem -> 3x3 max-pool (stride 2) -> one residual stage per
    entry of `layers` -> adaptive average pool to 1x1 -> `Flatten` -> linear head.

    Args:
        expansion: passed to every `ResBlock`; also multiplies the head's input width.
        layers: number of `ResBlock`s in each residual stage, one entry per stage.
        c_in: number of input channels (selects the stem widths).
        c_out: number of outputs of the final `nn.Linear`.
        sa: flag forwarded only to the last block of stage `len(layers)-4`
            (presumably enables self-attention in `ResBlock` — confirm).
        sym: forwarded unchanged to every `ResBlock`.
        act_cls: activation class forwarded to every `ConvLayer` and `ResBlock`.
    """

    def __init__(self, expansion, layers, c_in=3, c_out=1000, sa=False, sym=False, act_cls=defaults.activation):
        # Stem: three ConvLayers, only the first one uses stride 2.
        sizes = [c_in,16,32,64] if c_in<3 else [c_in,32,32,64]
        stem = [ConvLayer(sizes[i], sizes[i+1], stride=2 if i==0 else 1, act_cls=act_cls)
                for i in range(3)]

        n_stages = len(layers)
        # Stage i maps block_szs[i] -> block_szs[i+1]; stages beyond the fourth use 256.
        block_szs = [64//expansion,64,128,256,512] + [256]*(n_stages-4)
        # First stage keeps resolution (stride 1), every later stage uses stride 2.
        blocks = [self._make_layer(expansion, block_szs[i], block_szs[i+1], l, 1 if i==0 else 2,
                                   sa=sa if i == (n_stages-4) else False, sym=sym, act_cls=act_cls)
                  for i,l in enumerate(layers)]
        super().__init__(
            *stem,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            *blocks,
            nn.AdaptiveAvgPool2d(1), Flatten(),
            # Head width comes from the stage that actually runs last. For 4+ stages this is
            # the same value as block_szs[-1]; with fewer stages block_szs[-1] would be the
            # unused 512 entry and the Linear layer would not match its input.
            nn.Linear(block_szs[n_stages]*expansion, c_out),
        )
        init_cnn(self)

    def _make_layer(self, expansion, ni, nf, blocks, stride, sa, sym, act_cls):
        """Build one residual stage of `blocks` `ResBlock`s as an `nn.Sequential`.

        Only the first block gets `ni` and the requested `stride`; the others are called
        with `nf`, `nf` and stride 1. Only the last block receives the caller's `sa` flag.
        """
        return nn.Sequential(
            *[ResBlock(expansion, ni if i==0 else nf, nf, stride if i==0 else 1,
                       sa if i == (blocks-1) else False, sym=sym, act_cls=act_cls)
              for i in range(blocks)])
        
def _myxresnet(pretrained, expansion, layers, **kwargs):
    """Construct a `myXResNet` and, when `pretrained` is truthy, load checkpoint weights.

    `expansion`, `layers` and `**kwargs` are passed straight to `myXResNet`. The weights
    are read from the 'model' entry of the downloaded checkpoint and loaded non-strictly.
    """
    model = myXResNet(expansion, layers, **kwargs)
    if not pretrained:
        return model

    # TODO pretrain all sizes. Currently will fail with non-xrn50
    url = 'https://s3.amazonaws.com/fast-ai-modelzoo/xrn50_940.pth'
    checkpoint = load_state_dict_from_url(url, map_location='cpu')
    model.load_state_dict(checkpoint['model'], strict=False)
    return model

def myxresnet50(pretrained=False, **kwargs):
    """Build the 50-layer variant: expansion 4 with stage depths [3, 4, 6, 3].

    Extra keyword arguments are forwarded to `myXResNet` through `_myxresnet`.
    """
    stage_depths = [3, 4, 6, 3]
    return _myxresnet(pretrained, 4, stage_depths, **kwargs)

In [0]:
def opt_func(ps, lr=defaults.lr):
    """Optimizer factory for the Learner: RAdam over `ps`, wrapped in Lookahead.

    RAdam is configured with wd=1e-2, mom=0.95, eps=1e-6 and the given `lr`.
    """
    inner_opt = RAdam(ps, lr=lr, mom=0.95, eps=1e-6, wd=1e-2)
    return Lookahead(inner_opt)
# Learner over `dbch`: a myxresnet50 with 10 outputs (one per entry of lbl_dict), sa=True
# and MishJit as the activation class, optimised by `opt_func` with a label-smoothing
# cross-entropy loss and accuracy as the reported metric.
learn = Learner(dbch, myxresnet50(sa=True, c_out=10, act_cls=MishJit), opt_func=opt_func, loss_func=LabelSmoothingCrossEntropy(),
                metrics=accuracy)

In [20]:
# Train for 5 epochs (rows 0-4 in the log below) with 4e-3 as the learning-rate argument;
# presumably a flat phase followed by cosine annealing, per the method name — confirm.
learn.fit_flat_cos(5, 4e-3)

epoch,train_loss,valid_loss,accuracy,time
0,1.946718,2.031813,0.346,01:20
1,1.705437,2.231701,0.338,01:19
2,1.544824,1.539842,0.534,01:19
3,1.40836,1.399769,0.618,01:19
4,1.218001,1.238784,0.694,01:18
