In [None]:
# default_exp vision.models.xresnet

In [None]:
#export
from fastai2.torch_basics import *
from fastai2.test import *
from torchvision.models.utils import load_state_dict_from_url

In [None]:
from nbdev.showdoc import *

# XResNet

> ResNet variant from the "Bag of Tricks for Image Classification with Convolutional Neural Networks" paper

In [None]:
#export
def init_cnn(m):
    "Recursively zero any non-None `bias` and Kaiming-normal-initialize `Conv2d`/`Linear` weights in `m`."
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0)
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight)
    # Recurse into direct children so nested containers/blocks are covered as well
    for child in m.children():
        init_cnn(child)

In [None]:
#export
class XResNet(nn.Sequential):
    "ResNet variant from the 'Bag of Tricks' paper: 3-conv stem, max-pool, residual groups, pooled linear head."
    def __init__(self, expansion, layers, c_in=3, c_out=1000, sa=False, sym=False, act_cls=defaults.activation, norm_type=NormType.Batch):
        # `expansion` is forwarded to every ResBlock; `layers[i]` is the number of ResBlocks in group i.
        # Stem: three ConvLayers, only the first one downsamples; inputs with <3 channels get a narrower first conv.
        stem = []
        sizes = [c_in,16,32,64] if c_in<3 else [c_in,32,32,64]
        for i in range(3):
            stem.append(ConvLayer(sizes[i], sizes[i+1], stride=2 if i==0 else 1, act_cls=act_cls, norm_type=norm_type))

        # Group i maps block_szs[i] -> block_szs[i+1]; groups beyond the fourth are 256 wide.
        # NOTE(review): 64//expansion presumably makes the first group's (expanded) input match the 64-channel stem
        # — assumes ResBlock scales `ni` by `expansion`; confirm against ResBlock.
        block_szs = [64//expansion,64,128,256,512] +[256]*(len(layers)-4)
        # Only the first group keeps stride 1 (the max-pool already downsampled); self-attention, if requested,
        # is only passed to the group at index len(layers)-4.
        blocks = [self._make_layer(expansion, block_szs[i], block_szs[i+1], l, 1 if i==0 else 2,
                                  sa = sa if i==len(layers)-4 else False, sym=sym, act_cls=act_cls, norm_type=norm_type)
                  for i,l in enumerate(layers)]
        # Width of the last group's output. Index with len(layers) instead of [-1]: both agree for 4+ groups,
        # but with fewer groups block_szs has trailing unused entries and [-1] would give the wrong width.
        n_feats = block_szs[len(layers)]*expansion
        super().__init__(
            *stem,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            *blocks,
            nn.AdaptiveAvgPool2d(1), Flatten(),
            nn.Linear(n_feats, c_out),
        )
        init_cnn(self)

    def _make_layer(self, expansion, ni, nf, blocks, stride, sa, sym, act_cls, norm_type):
        "Stack `blocks` ResBlocks: the first takes `ni` inputs and uses `stride`; only the last one receives `sa`."
        return nn.Sequential(
            *[ResBlock(expansion, ni if i==0 else nf, nf, stride if i==0 else 1,
                      sa if i==blocks-1 else False, sym=sym, act_cls=act_cls, norm_type=norm_type)
              for i in range(blocks)])

In [None]:
#export
def _xresnet(pretrained, expansion, layers, **kwargs):
    # TODO pretrain all sizes. Currently will fail with non-xrn50
    url = 'https://s3.amazonaws.com/fast-ai-modelzoo/xrn50_940.pth'
    res = XResNet(expansion, layers, **kwargs)
    if pretrained: res.load_state_dict(load_state_dict_from_url(url, map_location='cpu')['model'], strict=False)
    return res

def xresnet18 (pretrained=False, **kwargs): return _xresnet(pretrained, 1, [2, 2,  2, 2], **kwargs)
def xresnet34 (pretrained=False, **kwargs): return _xresnet(pretrained, 1, [3, 4,  6, 3], **kwargs)
def xresnet50 (pretrained=False, **kwargs): return _xresnet(pretrained, 4, [3, 4,  6, 3], **kwargs)
def xresnet101(pretrained=False, **kwargs): return _xresnet(pretrained, 4, [3, 4, 23, 3], **kwargs)
def xresnet152(pretrained=False, **kwargs): return _xresnet(pretrained, 4, [3, 8, 36, 3], **kwargs)
def xresnet18_deep  (pretrained=False, **kwargs): return _xresnet(pretrained, 1, [2,2,2,2,1,1], **kwargs)
def xresnet34_deep  (pretrained=False, **kwargs): return _xresnet(pretrained, 1, [3,4,6,3,1,1], **kwargs)
def xresnet50_deep  (pretrained=False, **kwargs): return _xresnet(pretrained, 4, [3,4,6,3,1,1], **kwargs)
def xresnet18_deeper(pretrained=False, **kwargs): return _xresnet(pretrained, 1, [2,2,1,1,1,1,1,1], **kwargs)
def xresnet34_deeper(pretrained=False, **kwargs): return _xresnet(pretrained, 1, [3,4,6,3,1,1,1,1], **kwargs)
def xresnet50_deeper(pretrained=False, **kwargs): return _xresnet(pretrained, 4, [3,4,6,3,1,1,1,1], **kwargs)

In [None]:
tst = xresnet18(pretrained=False)  # randomly initialized model for the smoke test below

In [None]:
x = torch.randn(64, 3, 128, 128)
y = tst(x)
# The head should produce one row of `c_out` (default 1000) outputs per image in the batch
test_eq(y.shape, [64, 1000])

## Export -

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_test.ipynb.
Converted 01_core.foundation.ipynb.
Converted 01a_core.utils.ipynb.
Converted 01b_core.dispatch.ipynb.
Converted 01c_core.transform.ipynb.
Converted 02_core.script.ipynb.
Converted 03_torch_core.ipynb.
Converted 03a_layers.ipynb.
Converted 04_data.load.ipynb.
Converted 05_data.core.ipynb.
Converted 06_data.transforms.ipynb.
Converted 07_data.block.ipynb.
Converted 08_vision.core.ipynb.
Converted 09_vision.augment.ipynb.
Converted 09a_vision.data.ipynb.
Converted 09b_vision.utils.ipynb.
Converted 10_tutorial.pets.ipynb.
Converted 11_vision.models.xresnet.ipynb.
Converted 12_optimizer.ipynb.
Converted 13_learner.ipynb.
Converted 13a_metrics.ipynb.
Converted 14_callback.schedule.ipynb.
Converted 14a_callback.data.ipynb.
Converted 15_callback.hook.ipynb.
Converted 15a_vision.models.unet.ipynb.
Converted 16_callback.progress.ipynb.
Converted 17_callback.tracker.ipynb.
Converted 18_callback.fp16.ipynb.
Converted 19_callback.mixup.ipynb.
Converted 20_interpret.ipynb.
