In [None]:
#default_exp vision.models.xsenet

In [None]:
#export
from fastai2.torch_basics import *
from fastai2.test import *

In [None]:
# export
class ProdLayer(Module):
    "Multiply the incoming activation by the shortcut it carries in its `orig` attribute."
    def forward(self, x):
        # The shortcut is read from the input itself (`x.orig`), not passed as a second argument
        shortcut = x.orig
        return x * shortcut

In [None]:
#export
# Activation factory: calling `inplace_relu()` builds `nn.ReLU(inplace=True)`
inplace_relu = partial(nn.ReLU, inplace=True)

In [None]:
# export
def SEModule(ch, reduction):
    "Gate for `ch` channels: pool to 1x1, squeeze to `ch//reduction`, expand back through a sigmoid, then multiply with the input."
    squeezed = ch//reduction
    # Layers are created in the order they run; the final `ProdLayer` multiplies the gate with the original input
    gate = [
        nn.AdaptiveAvgPool2d(1),
        ConvLayer(ch, squeezed, ks=1, norm_type=None, act_cls=inplace_relu),
        ConvLayer(squeezed, ch, ks=1, norm_type=None, act_cls=nn.Sigmoid),
        ProdLayer(),
    ]
    return SequentialEx(*gate)

In [None]:
# Check `SEModule` against a manual computation of the same gate
tst = SEModule(64, 16)
x = torch.randn(32, 64, 16, 16)
# Recompute the gate by hand: pool, first conv + ReLU, second conv + sigmoid
# (`[0]` presumably selects the conv inside each `ConvLayer` -- verify against `ConvLayer`)
z = F.adaptive_avg_pool2d(x, 1)
z = F.relu(tst.layers[1][0](z))
z = torch.sigmoid(tst.layers[2][0](z))
# The module output must equal the input scaled by that gate
test_eq(tst(x), x*z)

In [None]:
#export
class SEResNetBlock(Module):
    "Residual block gated by an `SEModule`, from `ni*expansion` to `nf*expansion` channels with `stride`"
    def __init__(self, expansion, ni, nf, groups, reduction, nh1=None, nh2=None, stride=1,
                  sa=False, sym=False, act_cls=inplace_relu):
        # Hidden widths default in a chain: `nh2` falls back to `nf`, `nh1` falls back to `nh2`
        # (both are resolved before `nf` is scaled by `expansion`)
        nh2 = nf if nh2 is None else nh2
        nh1 = nh2 if nh1 is None else nh1
        ni,nf = ni*expansion,nf*expansion
        if expansion == 1:
            # Two 3x3 convs; `stride` and `groups` are applied on the first one
            layers = [ConvLayer(ni,  nh2, 3, act_cls=act_cls, stride=stride, groups=groups),
                      ConvLayer(nh2, nf,  3, act_cls=None, norm_type=NormType.BatchZero)]
        else:
            # 1x1 -> 3x3 -> 1x1; `stride` and `groups` are applied on the middle 3x3 conv
            layers = [ConvLayer(ni,  nh1, 1, act_cls=act_cls),
                      ConvLayer(nh1, nh2, 3, act_cls=act_cls, stride=stride, groups=groups),
                      ConvLayer(nh2, nf,  1, act_cls=None, norm_type=NormType.BatchZero)]
        self.convs = nn.Sequential(*layers)
        # Optional self-attention after the SE gate; `noop` otherwise
        if sa: self.sa = SimpleSelfAttention(nf,ks=1,sym=sym)
        else:  self.sa = noop
        # Shortcut path: 1x1 conv only when channel counts differ, average pooling only when `stride` is not 1
        self.idconv = ConvLayer(ni, nf, 1, act_cls=None) if ni!=nf else noop
        self.pool = nn.AvgPool2d(2, ceil_mode=True) if stride!=1 else noop
        self.se = SEModule(nf, reduction=reduction)
        self.act = act_cls()

    def forward(self, x):
        # Main path: convs -> SE gate -> (optional) self-attention
        main = self.sa(self.se(self.convs(x)))
        # Shortcut path: pool first, then match channels
        shortcut = self.idconv(self.pool(x))
        return self.act(main + shortcut)

In [None]:
#export
def SEBlock(expansion, ni, nf, groups, reduction, stride=1, **kwargs):
    "`SEResNetBlock` with hidden widths fixed to `nh1=nf*2` and `nh2=nf*expansion`; `kwargs` are forwarded."
    hidden = dict(nh1=nf*2, nh2=nf*expansion)
    return SEResNetBlock(expansion, ni, nf, groups, reduction, stride=stride, **hidden, **kwargs)

In [None]:
#export
def SEResNeXtBlock(expansion, ni, nf, groups, reduction, stride=1, base_width=4, **kwargs):
    "`SEResNetBlock` whose `nh2` is `floor(nf * (base_width/64)) * groups`; `kwargs` are forwarded."
    width_per_group = math.floor(nf * (base_width / 64))
    return SEResNetBlock(expansion, ni, nf, groups, reduction,
                         nh2=width_per_group * groups, stride=stride, **kwargs)

In [None]:
#export
class XSENet(nn.Sequential):
    """Sequential net: 3-conv stem, max pool, one stage of `block`s per entry of `layers`, pooled linear head.

    `block` is called as `block(expansion, ni, nf, groups, reduction, stride=, sa=, sym=, act_cls=)`.
    `layers` gives the number of blocks per stage; `p` is the dropout probability before the
    final linear layer (`None` disables dropout); `c_in`/`c_out` are input channels and outputs.
    When `sa` is true, self-attention is requested on the last block of stage `len(layers)-4`.
    """
    def __init__(self, block, expansion, layers, groups, reduction, p=0.2, c_in=3, c_out=1000,
                 sa=False, sym=False, act_cls=defaults.activation):
        stem = []
        # Narrower first stem conv when there are fewer than 3 input channels
        sizes = [c_in,16,32,64] if c_in<3 else [c_in,32,32,64]
        for i in range(3):
            # Only the first stem conv downsamples
            stem.append(ConvLayer(sizes[i], sizes[i+1], stride=2 if i==0 else 1))

        # Per-stage widths (before `expansion`); any stage beyond the fourth uses 256
        block_szs = [64//expansion,64,128,256,512] +[256]*(len(layers)-4)
        # First stage keeps resolution, every later stage starts with a stride-2 block
        blocks = [self._make_layer(block, expansion, block_szs[i], block_szs[i+1], l, groups, reduction,
                                   stride=1 if i==0 else 2, sa=sa if i==len(layers)-4 else False, sym=sym, act_cls=act_cls)
                  for i,l in enumerate(layers)]
        drop = [] if p is None else [nn.Dropout(p)]
        super().__init__(
            *stem,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            *blocks,
            nn.AdaptiveAvgPool2d(1), Flatten(), *drop,
            init_default(nn.Linear(block_szs[-1]*expansion, c_out)),
        )

    def _make_layer(self, block, expansion, ni, nf, blocks, groups, reduction, stride, sa, sym, act_cls):
        "Stack `blocks` blocks: only the first takes `ni` and `stride`, only the last may get `sa`."
        # `i` runs over range(blocks), so the last block is `blocks-1`; comparing against
        # `blocks` itself would never match and `sa` would be silently dropped.
        return nn.Sequential(
            *[block(expansion, ni if i==0 else nf, nf, groups, reduction, stride=stride if i==0 else 1,
                   sa=sa if i==blocks-1 else False, sym=sym, act_cls=act_cls)
              for i in range(blocks)])

In [None]:
#export
# Shared `XSENet` keyword arguments: ungrouped (groups=1) vs. grouped (groups=32) convs.
# `p=None` means no dropout layer is added before the final linear layer.
se_kwargs1 = {'groups': 1,  'reduction': 16, 'p': None}
se_kwargs2 = {'groups': 32, 'reduction': 16, 'p': None}
# Number of blocks per stage, passed as `layers` to `XSENet`
g0 = [2, 2, 2, 2]
g1 = [3, 4, 6, 3]
g2 = [3, 4, 23, 3]
g3 = [3, 8, 36, 3]

# Factory functions: each builds an `XSENet` from a block type, an expansion and a `layers` list.
# `pretrained` is accepted by every factory but is not used.
def xse_resnet18(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNetBlock`, expansion 1, layers `g0`"
    return XSENet(SEResNetBlock, 1, g0, c_out=c_out, **se_kwargs1, **kwargs)

def xse_resnext18_32x4d(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNeXtBlock`, expansion 1, layers `g0`"
    return XSENet(SEResNeXtBlock, 1, g0, c_out=c_out, **se_kwargs2, **kwargs)

def xse_resnet34(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNetBlock`, expansion 1, layers `g1`"
    return XSENet(SEResNetBlock, 1, g1, c_out=c_out, **se_kwargs1, **kwargs)

def xse_resnext34_32x4d(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNeXtBlock`, expansion 1, layers `g1`"
    return XSENet(SEResNeXtBlock, 1, g1, c_out=c_out, **se_kwargs2, **kwargs)

def xse_resnet50(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNetBlock`, expansion 4, layers `g1`"
    return XSENet(SEResNetBlock, 4, g1, c_out=c_out, **se_kwargs1, **kwargs)

def xse_resnext50_32x4d(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNeXtBlock`, expansion 4, layers `g1`"
    return XSENet(SEResNeXtBlock, 4, g1, c_out=c_out, **se_kwargs2, **kwargs)

def xse_resnet101(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNetBlock`, expansion 4, layers `g2`"
    return XSENet(SEResNetBlock, 4, g2, c_out=c_out, **se_kwargs1, **kwargs)

def xse_resnext101_32x4d(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNeXtBlock`, expansion 4, layers `g2`"
    return XSENet(SEResNeXtBlock, 4, g2, c_out=c_out, **se_kwargs2, **kwargs)

def xse_resnet152(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNetBlock`, expansion 4, layers `g3`"
    return XSENet(SEResNetBlock, 4, g3, c_out=c_out, **se_kwargs1, **kwargs)
def xsenet154(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEBlock`, layers `g3`, 64 groups, dropout 0.2; `pretrained` is accepted but unused"
    # `XSENet` takes `expansion` as its second positional argument, before `layers`.
    # NOTE(review): expansion=4 mirrors `xse_resnet152`, which uses the same `g3` layout -- confirm.
    # `kwargs` are forwarded like in the other factories instead of being dropped.
    return XSENet(SEBlock, 4, g3, groups=64, reduction=16, p=0.2, c_out=c_out, **kwargs)
# "deep" variants append two single-block stages to the base layout;
# "deeper" variants use explicit eight-stage layouts.
def xse_resnext18_deep(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNeXtBlock`, expansion 1, layers `g0+[1,1]`"
    return XSENet(SEResNeXtBlock, 1, g0+[1,1], c_out=c_out, **se_kwargs2, **kwargs)

def xse_resnext34_deep(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNeXtBlock`, expansion 1, layers `g1+[1,1]`"
    return XSENet(SEResNeXtBlock, 1, g1+[1,1], c_out=c_out, **se_kwargs2, **kwargs)

def xse_resnext50_deep(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNeXtBlock`, expansion 4, layers `g1+[1,1]`"
    return XSENet(SEResNeXtBlock, 4, g1+[1,1], c_out=c_out, **se_kwargs2, **kwargs)

def xse_resnext18_deeper(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNeXtBlock`, expansion 1, layers `[2,2,1,1,1,1,1,1]`"
    return XSENet(SEResNeXtBlock, 1, [2,2,1,1,1,1,1,1], c_out=c_out, **se_kwargs2, **kwargs)

def xse_resnext34_deeper(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNeXtBlock`, expansion 1, layers `[3,4,4,2,2,1,1,1]`"
    return XSENet(SEResNeXtBlock, 1, [3,4,4,2,2,1,1,1], c_out=c_out, **se_kwargs2, **kwargs)

def xse_resnext50_deeper(c_out=1000, pretrained=False, **kwargs):
    "`XSENet` of `SEResNeXtBlock`, expansion 4, layers `[3,4,4,2,2,1,1,1]`"
    return XSENet(SEResNeXtBlock, 4, [3,4,4,2,2,1,1,1], c_out=c_out, **se_kwargs2, **kwargs)

## Export -

In [None]:
#hide
# Run nbdev's `notebook2script`; the cell output lists the notebooks it converted
from nbdev.export import notebook2script
notebook2script()

Converted 18_callback_fp16.ipynb.
Converted 03_torchcore.ipynb.
Converted 32_text_models_awdlstm.ipynb.
Converted 03a_layers.ipynb.
Converted 14a_callback_data.ipynb.
Converted 07_data_block.ipynb.
Converted 12_optimizer.ipynb.
Converted 17_callback_tracker.ipynb.
Converted 90_xse_resnext.ipynb.
Converted 71_callback_tensorboard.ipynb.
Converted 60_medical_imaging.ipynb.
Converted 05_data_core.ipynb.
Converted 06_data_transforms.ipynb.
Converted 08_vision_core.ipynb.
Converted 00_test.ipynb.
Converted 01b_core_dispatch.ipynb.
Converted 96_data_external.ipynb.
Converted 15a_vision_models_unet.ipynb.
Converted 01c_core_transform.ipynb.
Converted 13_learner.ipynb.
Converted 36_text_models_qrnn.ipynb.
Converted 97_utils_test.ipynb.
Converted 10_pets_tutorial.ipynb.
Converted 34_callback_rnn.ipynb.
Converted 42_tabular_rapids.ipynb.
Converted 16_callback_progress.ipynb.
Converted 70_callback_wandb.ipynb.
Converted 09a_vision_data.ipynb.
Converted 38_tutorial_ulmfit.ipynb.
Converted 95_index