In [10]:
import torch
import torch.nn as nn
from torchinfo import summary

def residual_function(in_channels, out_channels, stride, bottleneck_ratio):
    """Build the residual branch F(x) of a ResNet block.

    Args:
        in_channels: Number of channels of the block input.
        out_channels: Base width of the block (channels of the 3x3 convolutions).
        stride: Stride of the convolution that performs the spatial downsampling.
        bottleneck_ratio: 1 selects the basic (two 3x3 convs) block; any other value
            selects the bottleneck (1x1 -> 3x3 -> 1x1) block whose output has
            ``out_channels * bottleneck_ratio`` channels.

    Returns:
        An ``nn.Sequential`` ending in a BatchNorm2d (no final activation).
    """
    if bottleneck_ratio != 1:
        expanded_channels = out_channels * bottleneck_ratio
        layers = [
            # 1x1 reduce
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            # 3x3 conv; the stride is applied here
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            # 1x1 expand
            nn.Conv2d(out_channels, expanded_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(expanded_channels),
        ]
    else:
        layers = [
            # first 3x3 conv; the stride is applied here
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
        ]

    return nn.Sequential(*layers)


def shortcut_function(in_channels, out_channels, stride, bottleneck_ratio):
    """Build the shortcut (skip) branch of a ResNet block.

    Args:
        in_channels: Number of channels of the block input.
        out_channels: Base width of the block.
        stride: Stride used by the residual branch of the same block.
        bottleneck_ratio: Channel expansion factor of the block; the residual branch
            outputs ``out_channels * bottleneck_ratio`` channels.

    Returns:
        An ``nn.Module``: a strided 1x1 convolution followed by BatchNorm2d when the
        residual branch changes the spatial size or the channel count, otherwise
        ``nn.Identity()``.
    """
    expanded_channels = out_channels * bottleneck_ratio

    if stride != 1 or in_channels != expanded_channels:
        # Projection shortcut: match both resolution and channel count of F(x).
        return nn.Sequential(
            nn.Conv2d(in_channels, expanded_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(expanded_channels),
        )

    # Use a real module rather than a lambda so the shortcut is registered as a
    # sub-module and the enclosing model stays picklable (e.g. torch.save(model)).
    return nn.Identity()
            

<img src='img/resnet_blocks.png' width='50%' />

In [11]:
class ResidualBlock(nn.Module):
    """Residual block computing ``ReLU(F(x) + shortcut(x))``.

    ``F`` is the branch built by ``residual_function`` and ``shortcut`` is the
    branch built by ``shortcut_function``; both receive the same four arguments.

    Args:
        in_channels: Number of channels of the block input.
        out_channels: Base width of the block.
        stride: Stride forwarded to both branches.
        bottleneck_ratio: Channel expansion factor forwarded to both branches.
    """

    def __init__(self, in_channels, out_channels, stride, bottleneck_ratio):
        super().__init__()

        branch_args = (in_channels, out_channels, stride, bottleneck_ratio)
        self.F = residual_function(*branch_args)
        self.shortcut = shortcut_function(*branch_args)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Add the residual and shortcut branches, then apply the final ReLU."""
        residual = self.F(x)
        skip = self.shortcut(x)
        return self.relu(residual + skip)

<img src='img/resnet_architectures.png' width=90% />

In [12]:
# Every variant shares the same four stages (width, stride); only the number of
# blocks per stage and the block type (basic: ratio 1, bottleneck: ratio 4) differ.
# Each layer config is a (out_channels, stride, num_blocks) tuple.
resnet_configs = {
    name: {
        'layer_configs': [
            (width, stride, num_blocks)
            for (width, stride), num_blocks in zip(((64, 1), (128, 2), (256, 2), (512, 2)), blocks_per_stage)
        ],
        'bottleneck_ratio': ratio,
        'first_in_channels': 64,
    }
    for name, blocks_per_stage, ratio in (
        ('resnet18', (2, 2, 2, 2), 1),
        ('resnet34', (3, 4, 6, 3), 1),
        ('resnet50', (3, 4, 6, 3), 4),
        ('resnet101', (3, 4, 23, 3), 4),
        ('resnet152', (3, 8, 36, 3), 4),
    )
}

In [13]:
class ResNet(nn.Module):
    """ResNet classifier assembled from ``ResidualBlock`` stages.

    Args:
        layer_configs: Iterable of ``(out_channels, stride, num_blocks)`` tuples, one
            per stage. Only the first block of a stage receives ``stride``; the
            remaining blocks use stride 1.
        bottleneck_ratio: Channel expansion factor passed to every block.
        num_classes: Size of the final linear layer's output.
        first_in_channels: Number of output channels of the stem convolution.
    """

    def __init__(self, layer_configs, bottleneck_ratio, num_classes=1000, first_in_channels=64):
        super().__init__()

        # Stem: 7x7/2 convolution on a 3-channel input, BN, ReLU, then 3x3/2 max-pool.
        stem = [
            nn.Conv2d(3, first_in_channels, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(first_in_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        self.conv1 = nn.Sequential(*stem)

        blocks = []
        current_channels = first_in_channels
        for width, stage_stride, num_blocks in layer_configs:
            for block_idx in range(num_blocks):
                block_stride = stage_stride if block_idx == 0 else 1
                blocks.append(ResidualBlock(current_channels, width, block_stride, bottleneck_ratio))
                # Every block outputs width * bottleneck_ratio channels.
                current_channels = width * bottleneck_ratio

        self.convs = nn.Sequential(*blocks)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(current_channels, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal (fan_out, ReLU) for convolutions; BatchNorm weight=1, bias=0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Run stem, residual stages, global average pooling and the linear head."""
        features = self.convs(self.conv1(x))
        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc(pooled)

In [14]:
n = 1
first_in_channels = 32

model = ResNet(layer_configs=[(16, 1, n), (32, 2, n), (64, 2, n)], bottleneck_ratio=1, first_in_channels=first_in_channels, num_classes=10)

# Replace the default stem (7x7 stride-2 convolution followed by a stride-2 max-pool)
# with a single 3x3 stride-1 convolution that keeps the input resolution.
model.conv1 = nn.Sequential(nn.Conv2d(3, first_in_channels, kernel_size=3, stride=1, padding=1, bias=False),
                            nn.BatchNorm2d(first_in_channels),
                            nn.ReLU(inplace=True))

# The new stem is created after ResNet.__init__ has already run _init_weights(), so its
# convolution would otherwise keep PyTorch's default initialisation. Apply the same
# Kaiming-normal (fan_out, ReLU) init that the model uses for every other convolution.
# (A fresh BatchNorm2d already starts with weight=1 and bias=0.)
nn.init.kaiming_normal_(model.conv1[0].weight, mode='fan_out', nonlinearity='relu')

In [15]:
# Layer-by-layer report for one 3-channel 32x32 input (batch size 1).
summary(model, input_size=(1, 3, 32, 32))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   --                        --
├─Sequential: 1-1                        [1, 32, 32, 32]           --
│    └─Conv2d: 2-1                       [1, 32, 32, 32]           864
│    └─BatchNorm2d: 2-2                  [1, 32, 32, 32]           64
│    └─ReLU: 2-3                         [1, 32, 32, 32]           --
├─Sequential: 1-2                        [1, 64, 8, 8]             --
│    └─ResidualBlock: 2-4                [1, 16, 32, 32]           --
│    │    └─Sequential: 3-1              [1, 16, 32, 32]           6,976
│    │    └─Sequential: 3-2              [1, 16, 32, 32]           544
│    │    └─ReLU: 3-3                    [1, 16, 32, 32]           --
│    └─ResidualBlock: 2-5                [1, 32, 16, 16]           --
│    │    └─Sequential: 3-4              [1, 32, 16, 16]           13,952
│    │    └─Sequential: 3-5              [1, 32, 16, 16]           576
│    