# Implementation of XResNet

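This notebook dynamically tests the XResNet implementation (forward passes, output shapes and state-dict sizes) against the architecture described in the following paper: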

> [Bag of Tricks for Image Classification with Convolutional Neural Networks](https://arxiv.org/abs/1812.01187)


In [1]:
# Reload modules before executing each cell so edits to the project
# package (e.g. model.backbone.xresnet) are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from importlib.util import find_spec

# If the project package `model` is not importable from the current working
# directory (presumably the notebook sits one level below the repo root),
# fall back to the parent directory.
if find_spec("model") is None:
    import sys

    sys.path.append("..")

In [3]:
import torch
from torchvision.models import resnet50, resnet101, resnet152
from fastai.vision.all import *

In [24]:
from model.backbone.xresnet import XResNet, RESNET50_LAYERS, RESNET101_LAYERS, RESNET152_LAYERS

In [25]:
# Dummy input batch laid out as (batch, channels, height, width):
# 16 random 3-channel images of 224x224.
data = torch.randn(16, 3, 224, 224)

In [27]:
# XResNet-50 returning the feature maps of all four residual stages.
model = XResNet(
    RESNET50_LAYERS,
    out_features=["res2", "res3", "res4", "res5"],
    num_classes=1000,
)

In [28]:
# Shape-only check: disable autograd so the forward pass does not build a
# graph or retain activations for a backward pass that never happens.
with torch.no_grad():
    outputs = model(data)

In [29]:
# Expected (N, C, H, W) per stage for a 224x224 input: channels double and the
# spatial size halves from stage to stage (overall strides 4, 8, 16, 32).
expected_shapes = {
    "res2": (16, 256, 56, 56),
    "res3": (16, 512, 28, 28),
    "res4": (16, 1024, 14, 14),
    "res5": (16, 2048, 7, 7),
}
for name, shape in expected_shapes.items():
    actual = tuple(outputs[name].shape)
    assert actual == shape, f"{name}: expected {shape}, got {actual}"

In [30]:
# Number of state-dict entries (parameters + persistent buffers) of the model above.
len(model.state_dict())

332

In [31]:
# XResNet-101 returning the feature maps of all four residual stages.
model = XResNet(
    RESNET101_LAYERS,
    out_features=["res2", "res3", "res4", "res5"],
    num_classes=1000,
)

In [32]:
# Shape-only check: disable autograd so the forward pass does not build a
# graph or retain activations for a backward pass that never happens.
with torch.no_grad():
    outputs = model(data)

In [33]:
# Expected (N, C, H, W) per stage for a 224x224 input: channels double and the
# spatial size halves from stage to stage (overall strides 4, 8, 16, 32).
expected_shapes = {
    "res2": (16, 256, 56, 56),
    "res3": (16, 512, 28, 28),
    "res4": (16, 1024, 14, 14),
    "res5": (16, 2048, 7, 7),
}
for name, shape in expected_shapes.items():
    actual = tuple(outputs[name].shape)
    assert actual == shape, f"{name}: expected {shape}, got {actual}"

In [34]:
# Number of state-dict entries (parameters + persistent buffers) of the model above.
len(model.state_dict())

638

In [35]:
# XResNet-152 returning the feature maps of all four residual stages.
model = XResNet(
    RESNET152_LAYERS,
    out_features=["res2", "res3", "res4", "res5"],
    num_classes=1000,
)

In [36]:
# Shape-only check: disable autograd so the forward pass does not build a
# graph or retain activations for a backward pass that never happens.
with torch.no_grad():
    outputs = model(data)

In [37]:
# Expected (N, C, H, W) per stage for a 224x224 input: channels double and the
# spatial size halves from stage to stage (overall strides 4, 8, 16, 32).
expected_shapes = {
    "res2": (16, 256, 56, 56),
    "res3": (16, 512, 28, 28),
    "res4": (16, 1024, 14, 14),
    "res5": (16, 2048, 7, 7),
}
for name, shape in expected_shapes.items():
    actual = tuple(outputs[name].shape)
    assert actual == shape, f"{name}: expected {shape}, got {actual}"

In [38]:
# Number of state-dict entries (parameters + persistent buffers) of the model above.
len(model.state_dict())

944

In [41]:
# Baseline: state-dict size of torchvision's ResNet-50, for comparison with XResNet.
p_model = resnet50()
len(p_model.state_dict())

320

In [42]:
# Baseline: state-dict size of torchvision's ResNet-101, for comparison with XResNet.
p_model = resnet101()
len(p_model.state_dict())

626

In [43]:
# Baseline: state-dict size of torchvision's ResNet-152, for comparison with XResNet.
p_model = resnet152()
len(p_model.state_dict())

932

In [19]:
# Reference: fastai's xresnet50 (presumably provided by the fastai star import),
# whose state-dict size the custom XResNet-50 should match.
p_model = xresnet50()
len(p_model.state_dict())

332

In [39]:
# Reference: fastai's xresnet101 (presumably provided by the fastai star import),
# whose state-dict size the custom XResNet-101 should match.
p_model = xresnet101()
len(p_model.state_dict())

638

In [40]:
# Reference: fastai's xresnet152 (presumably provided by the fastai star import),
# whose state-dict size the custom XResNet-152 should match.
p_model = xresnet152()
len(p_model.state_dict())

944