In [58]:
import torch 
from torch import nn, optim 
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt 
from tqdm import tqdm
from torchsummary import summary

In [59]:
def ConvLayer(in_channels, out_channels):
    """Build a 3x3, padding-1 ``nn.Conv2d`` whose stride depends on the channel counts.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels the convolution produces.

    Returns:
        An ``nn.Conv2d`` with stride 1 when both channel counts are equal
        (no downsampling) and stride 2 when they differ (downsampling).
    """
    # A change in channel count is this notebook's signal to downsample.
    stride = 1 if in_channels == out_channels else 2
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)


class ResidualBlock(nn.Module):
    """Conv -> BatchNorm -> ReLU -> Conv -> BatchNorm main path with a skip connection.

    The first convolution of the main path is ``ConvLayer(in_channels, out_channels)``
    and the second is ``ConvLayer(out_channels, out_channels)``. When the channel
    counts differ, the skip path is also a ``ConvLayer(in_channels, out_channels)``
    so it can be added to the main path; otherwise the skip path is an empty
    ``nn.Sequential`` that hands its input back unchanged.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the feature map the block returns.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Main path; built before the skip path so sub-modules are created in
        # the same order as before.
        main_path = [
            ConvLayer(in_channels, out_channels),   # first conv layer
            nn.BatchNorm2d(out_channels),           # batch norm before activation
            nn.ReLU(),                              # activation
            ConvLayer(out_channels, out_channels),  # second conv layer
            nn.BatchNorm2d(out_channels),           # second batch norm
        ]
        self.net = nn.Sequential(*main_path)

        if in_channels == out_channels:
            # Same channel count: the skip path passes the input through untouched.
            self.identity_downsample = nn.Sequential()
        else:
            # Channel count changes: transform the skip path with its own conv layer.
            self.identity_downsample = ConvLayer(in_channels, out_channels)

    def forward(self, x):
        """Return ``relu(net(x) + identity_downsample(x))``."""
        shortcut = self.identity_downsample(x)  # skip-path tensor
        merged = self.net(x) + shortcut         # main path plus skip path
        return torch.relu(merged)               # final ReLU


class ResNet(nn.Module):
    """Residual image classifier for square, 3-channel inputs.

    Layout: a 7x7 stride-2 convolution and a 2x2 max pool (``net_top``), sixteen
    ``ResidualBlock`` modules in stages of 3, 4, 6 and 3 blocks with 64, 128, 256
    and 512 channels (``net_mid``), a 2x2 average pool (``net_bottom``) and one
    linear layer (``output_layer``). ``forward`` returns softmax probabilities.

    Args:
        num_classes: number of output classes.
        input_dimensions: side length, in pixels, of the square input images.
            It is used only to size the linear layer, so it must match the
            images that are later passed to ``forward``.

    Raises:
        ValueError: if ``input_dimensions`` is so small that the feature map
            vanishes before it reaches the linear layer.
    """

    def __init__(self, num_classes, input_dimensions):
        super().__init__()

        self.net_top = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3), # size = in_size / 2
            nn.MaxPool2d(kernel_size=2) # size = in_size / 4
        )

        self.net_mid = nn.Sequential(
            ResidualBlock(64, 64), # size = in_size / 4
            ResidualBlock(64, 64), # size = in_size / 4
            ResidualBlock(64, 64), # size = in_size / 4

            ResidualBlock(64, 128), # size = in_size / 8

            ResidualBlock(128, 128), # size = in_size / 8
            ResidualBlock(128, 128), # size = in_size / 8
            ResidualBlock(128, 128), # size = in_size / 8

            ResidualBlock(128, 256), # size = in_size / 16

            ResidualBlock(256, 256), # size = in_size / 16
            ResidualBlock(256, 256), # size = in_size / 16
            ResidualBlock(256, 256), # size = in_size / 16
            ResidualBlock(256, 256), # size = in_size / 16
            ResidualBlock(256, 256), # size = in_size / 16

            ResidualBlock(256, 512), # size = in_size / 32

            ResidualBlock(512, 512), # size = in_size / 32
            ResidualBlock(512, 512), # size = in_size / 32
        )

        self.net_bottom = nn.Sequential(
            nn.AvgPool2d(kernel_size=2) # size = in_size / 64
        )

        # The "in_size / 64" shorthand is exact only for multiples of 64, so the
        # linear layer is sized from the layers' real output-shape arithmetic.
        output_size = self._output_spatial_size(input_dimensions)
        if output_size < 1:
            raise ValueError(
                f"input_dimensions={input_dimensions} is too small: the feature map "
                "is empty before it reaches the output layer"
            )

        self.output_layer = nn.Linear(512 * output_size ** 2, num_classes)

    @staticmethod
    def _output_spatial_size(input_dimensions):
        """Return the side length of the feature map that is fed to ``output_layer``.

        Applies ``floor((size + 2 * padding - kernel) / stride) + 1`` for every
        size-changing layer: the 7x7 stride-2 stem convolution, the 2x2 max pool,
        the three channel-changing residual blocks (3x3, stride 2, padding 1) and
        the final 2x2 average pool. All other layers keep the size unchanged.
        """
        def shrink(size, kernel, stride, padding):
            return (size + 2 * padding - kernel) // stride + 1

        size = int(input_dimensions)
        size = shrink(size, 7, 2, 3)      # net_top convolution
        size = shrink(size, 2, 2, 0)      # net_top max pool
        for _ in range(3):                # 64->128, 128->256 and 256->512 blocks
            size = shrink(size, 3, 2, 1)
        return shrink(size, 2, 2, 0)      # net_bottom average pool

    def forward(self, x):
        """Return per-class probabilities (each row sums to 1) for a batch of images."""
        x = self.net_top(x)
        x = self.net_mid(x)
        x = self.net_bottom(x).flatten(1)  # flatten everything except the batch axis
        x = self.output_layer(x)
        # NOTE(review): the output is already softmax-normalised. If training uses a
        # loss that expects raw logits (e.g. nn.CrossEntropyLoss), softmax would be
        # applied twice - confirm against the training cell.
        return torch.softmax(x, dim=1)

In [60]:
# Build the classifier for 10 classes and 256x256 inputs.
model = ResNet(10, 256)
# The model is still on the CPU here, so request a CPU dummy input explicitly;
# torchsummary's default device is "cuda", which would not match CPU weights
# on a machine where a GPU is available.
summary(model, (3, 256, 256), device="cpu") # params = (model, input_shape)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           9,472
         MaxPool2d-2           [-1, 64, 64, 64]               0
            Conv2d-3           [-1, 64, 64, 64]          36,928
       BatchNorm2d-4           [-1, 64, 64, 64]             128
              ReLU-5           [-1, 64, 64, 64]               0
            Conv2d-6           [-1, 64, 64, 64]          36,928
       BatchNorm2d-7           [-1, 64, 64, 64]             128
     ResidualBlock-8           [-1, 64, 64, 64]               0
            Conv2d-9           [-1, 64, 64, 64]          36,928
      BatchNorm2d-10           [-1, 64, 64, 64]             128
             ReLU-11           [-1, 64, 64, 64]               0
           Conv2d-12           [-1, 64, 64, 64]          36,928
      BatchNorm2d-13           [-1, 64, 64, 64]             128
    ResidualBlock-14           [-1, 64,