In [1]:
import copy
from functools import partial

import torch
import torch.functional as F
from torch import nn

In [17]:
def conv(in_channels, out_channels, kernel_size=3, *args, **kwargs):
    """Build an ``nn.Conv2d`` that defaults to no bias and ``kernel_size // 2`` padding.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: int, or a tuple/list of ints (padding is then halved per axis).
        *args: extra positionals forwarded to ``nn.Conv2d`` after ``kernel_size``,
            i.e. in Conv2d's own order ``(stride, padding, dilation, groups, bias, ...)``.
        **kwargs: extra keywords forwarded to ``nn.Conv2d``; ``padding`` and ``bias``
            given here override the defaults chosen by this helper.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    if isinstance(kernel_size, (tuple, list)):
        default_padding = tuple(k // 2 for k in kernel_size)
    else:
        default_padding = kernel_size // 2

    # kernel_size is passed positionally so that *args lines up with Conv2d's
    # remaining positional parameters instead of colliding with a keyword.
    # Defaults are only injected when the caller has not already supplied the
    # value positionally (padding is args[1], bias is args[4]).
    if len(args) < 2:
        kwargs.setdefault("padding", default_padding)
    if len(args) < 5:
        kwargs.setdefault("bias", False)
    return nn.Conv2d(in_channels, out_channels, kernel_size, *args, **kwargs)

def conv_bn(in_channels, out_channels, *args, **kwargs):
    """Return ``nn.Sequential(conv(...), nn.BatchNorm2d(out_channels))``.

    ``in_channels``, ``out_channels`` and every extra positional/keyword
    argument are handed unchanged to ``conv``; the batch-norm layer is sized
    by ``out_channels``.
    """
    convolution = conv(in_channels, out_channels, *args, **kwargs)
    normalisation = nn.BatchNorm2d(out_channels)
    return nn.Sequential(convolution, normalisation)

In [139]:
class ResnetResidualBlock(nn.Module):
    """Residual wrapper computing ``activation(blocks(x) + shortcut_or_identity(x))``.

    ``blocks`` is ``nn.Identity()`` here; subclasses replace it with their own
    layer stack after calling ``__init__``.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: base channel count; the block emits
            ``expansion * out_channels`` channels (see ``expanded_channels``).
        activation: applied after the residual addition. A module is
            deep-copied so each block owns its own instance; ``None`` falls
            back to ``nn.ReLU(inplace=True)``.
        expansion: multiplier applied to ``out_channels``.
        downsampling: stride used by the 1x1 shortcut convolution.

    NOTE(review): the shortcut is only created when the channel counts differ
    (see ``apply_shortcut``). With ``downsampling != 1`` and equal channel
    counts the identity path is left un-strided -- confirm no caller relies on
    that combination.
    """

    def __init__(self, in_channels, out_channels, activation=nn.ReLU(inplace=True),
                 expansion=1, downsampling=1):
        super(ResnetResidualBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.blocks = nn.Identity()
        self.expansion = expansion
        # Honour the caller's choice (it used to be discarded in favour of a
        # hard-coded ReLU). Copying keeps the default-argument instance from
        # being shared between blocks.
        if activation is None:
            self.activation = nn.ReLU(inplace=True)
        else:
            self.activation = copy.deepcopy(activation)
        self.downsampling = downsampling
        if self.apply_shortcut:
            self.shortcut = conv_bn(in_channels, self.expanded_channels,
                                    stride=self.downsampling, kernel_size=1)

    @property
    def expanded_channels(self):
        """Number of channels this block outputs."""
        return self.expansion * self.out_channels

    @property
    def apply_shortcut(self):
        """True when input and output channel counts differ."""
        return self.in_channels != self.expanded_channels

    def forward(self, x):
        """Add the (possibly projected) input to ``blocks(x)`` and activate."""
        residual = self.shortcut(x) if self.apply_shortcut else x
        # Out-of-place add: with Identity blocks and no shortcut, ``x`` and
        # ``residual`` are the same tensor, and ``+=`` would overwrite the
        # caller's input (and fail on leaf tensors that require grad).
        x = self.blocks(x) + residual
        return self.activation(x)
        

In [140]:
class BasicBlock(ResnetResidualBlock):
    """Residual block whose body is two 3x3 conv+batch-norm stages with a ReLU between.

    Extra positional/keyword arguments are passed through to
    ``ResnetResidualBlock.__init__``.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, **kwargs)
        # The first convolution carries the stride, so any spatial reduction
        # happens at the entrance of the block.
        strided_stage = conv_bn(in_channels, out_channels,
                                kernel_size=3, stride=self.downsampling)
        refine_stage = conv_bn(out_channels, out_channels, kernel_size=3)
        self.blocks = nn.Sequential(strided_stage, nn.ReLU(), refine_stage)

In [141]:
class BottleneckBlock(ResnetResidualBlock):
    """Residual block with a 1x1 -> 3x3 -> 1x1 conv+batch-norm body.

    The last 1x1 stage widens the output to ``expanded_channels``
    (``4 * out_channels``). Extra positional/keyword arguments are passed
    through to ``ResnetResidualBlock.__init__``.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__(in_channels, out_channels, *args, expansion=4, **kwargs)
        reduce_stage = conv_bn(in_channels, out_channels, kernel_size=1)
        # The stride (downsampling) is applied by the middle 3x3 convolution.
        spatial_stage = conv_bn(out_channels, out_channels,
                                kernel_size=3, stride=self.downsampling)
        expand_stage = conv_bn(out_channels, self.expanded_channels, kernel_size=1)
        self.blocks = nn.Sequential(
            reduce_stage,
            nn.ReLU(),
            spatial_stage,
            nn.ReLU(),
            expand_stage,
        )

In [145]:
class ResNet(nn.Module):
    """Classifier built as: input gate -> stages of residual blocks -> pooled linear head.

    Args:
        in_channels: channels of the input tensor fed to the gate convolution.
        out_channels: output size of the final ``nn.Linear``.
        block_sizes: ``out_channels`` given to the blocks of each stage.
        depths: number of blocks in each stage. Stages are produced by zipping
            ``block_sizes`` with ``depths``, so surplus entries of the longer
            sequence are silently ignored.
        block: residual block class. It must expose a class-level ``expansion``
            and accept ``(in_channels, out_channels, downsampling=..., activation=...)``.
        activation: used in the input gate (when it is an ``nn.Module``) and
            forwarded to every block.
        *args, **kwargs: accepted for call compatibility and not used.
    """

    def __init__(self, in_channels=3, out_channels=1000,
                 block_sizes=[64, 128, 256, 512], depths=[2, 2, 2, 2],
                 block=BasicBlock, activation=nn.ReLU(inplace=True), *args, **kwargs):
        super(ResNet, self).__init__()
        # The gate used to hard-code nn.ReLU() and ignore ``activation``.
        # A copy is taken so the gate never shares a module instance with the
        # default argument; non-module values keep the previous ReLU.
        if isinstance(activation, nn.Module):
            gate_activation = copy.deepcopy(activation)
        else:
            gate_activation = nn.ReLU()
        # input gate
        self.gate = nn.Sequential(
            conv_bn(in_channels, block_sizes[0], kernel_size=7, stride=2),
            gate_activation,
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        self.block_type = block
        # middle blocks
        self.blocks_list = self._make_layers(block_sizes, activation, depths)
        # end decoder: its input width is whatever the very last block emits
        self.decoder = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(self.blocks_list[-1][-1].expanded_channels, out_channels)
        )

    def _make_layers(self, block_sizes, activation, depths):
        """Build every stage and return them as an ``nn.ModuleList``.

        The first stage maps ``block_sizes[0]`` to ``block_sizes[0]``. Each
        later stage receives the previous size multiplied by the block's
        ``expansion`` as its input channel count.
        """
        expansion = self.block_type.expansion
        stages = [self._make_layer(block_sizes[0], block_sizes[0], activation, n=depths[0])]
        size_pairs = zip(block_sizes, block_sizes[1:])
        for (previous_size, current_size), n in zip(size_pairs, depths[1:]):
            stages.append(self._make_layer(previous_size * expansion, current_size,
                                           activation=activation, n=n))
        return nn.ModuleList(stages)

    def _make_layer(self, in_channels, out_channels, activation, n=1):
        """Build one stage: a first block followed by ``n - 1`` stride-1 blocks.

        The first block downsamples by 2 whenever ``in_channels`` differs from
        ``out_channels``.
        """
        downsampling = 2 if in_channels != out_channels else 1
        return nn.Sequential(
            self.block_type(in_channels, out_channels, downsampling=downsampling,
                            activation=activation),
            *[self.block_type(out_channels * self.block_type.expansion, out_channels,
                              downsampling=1, activation=activation) for _ in range(n - 1)]
        )

    def forward(self, x):
        """Run the gate, every stage in order, then the decoder head."""
        x = self.gate(x)
        for block in self.blocks_list:
            x = block(x)
        x = self.decoder(x)
        return x

In [151]:
# `summary` comes from torchsummary; with the import commented out this cell
# only worked because of leftover kernel state and raised NameError on a
# fresh "Restart & Run All".
from torchsummary import summary

# Bottleneck blocks with stage depths 3/4/6/3, summarised for a 3x224x224 input.
model = ResNet(3, 1000, depths=[3, 4, 6, 3], block=BottleneckBlock)
summary(model, (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5          [-1, 256, 56, 56]          16,384
       BatchNorm2d-6          [-1, 256, 56, 56]             512
            Conv2d-7           [-1, 64, 56, 56]           4,096
       BatchNorm2d-8           [-1, 64, 56, 56]             128
              ReLU-9           [-1, 64, 56, 56]               0
           Conv2d-10           [-1, 64, 56, 56]          36,864
      BatchNorm2d-11           [-1, 64, 56, 56]             128
             ReLU-12           [-1, 64, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,