### Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

### Basic block

In [3]:
# Basic residual block used to build the ResNet below.
# nn.Module is the base class for all neural network modules in PyTorch.
class BasicBlock(nn.Module):
    """Two-convolution residual block.

    Main path: conv3x3 -> BN -> ReLU -> conv3x3 -> BN.
    The shortcut path (identity, or a 1x1 conv + BN projection) is added to the
    main path, and a final ReLU is applied.

    Args:
        in_channels: number of channels of the block's input feature map
            (this is 3 only when the input is an RGB image).
        out_channels: number of filters in each 3x3 convolution, i.e. the number
            of channels the block outputs.
        stride: stride of the first 3x3 convolution and of the shortcut
            projection. With stride=1 the spatial size is preserved; with
            stride=2 it becomes ceil(H/2) x ceil(W/2).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        # Initialise nn.Module internals so submodules/parameters get registered.
        super(BasicBlock, self).__init__()
        # 3x3 conv with padding=1. It preserves H and W when stride=1 and
        # downsamples when stride>1.
        # bias=False because the BatchNorm that follows subtracts the per-channel
        # mean (cancelling any constant bias) and then adds its own learnable
        # shift (beta), so a conv bias would be redundant.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        # Normalises conv1's outputs per channel, then scales them by a learned
        # gamma and shifts them by a learned beta; this stabilises training.
        self.bn1 = nn.BatchNorm2d(out_channels)
        # Second 3x3 conv always uses stride=1/padding=1, so it keeps conv1's
        # spatial size.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut path: an empty nn.Sequential simply returns its input
        # (identity).
        self.shortcut = nn.Sequential()
        # The input can only be added to the main-path output if the shapes
        # match. If the main path changes the spatial size (stride != 1) or the
        # channel count (in_channels != out_channels), project the input with a
        # 1x1 conv using the same stride, followed by BatchNorm.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Residual connection: add the (possibly projected) input to the main path.
        out += self.shortcut(x)
        # Final activation, applied after the addition.
        return F.relu(out)


### ResNet architecture

In [None]:
class ResNet(nn.Module):
    """ResNet made of a 3x3 stem convolution followed by residual stages.

    Args:
        block: residual block class, instantiated as
            ``block(in_channels, out_channels, stride)``.
        num_blocks: per-stage block counts; entries [0] and [1] are used here.
        num_classes: accepted by the constructor (default 3).
    """

    def __init__(self, block, num_blocks, num_classes=3):
        super().__init__()
        # NOTE(review): num_classes is not referenced in __init__; presumably
        # meant for a classifier head that is not built here yet — confirm.

        # Number of channels that the next block created by _make_layer will
        # receive. _make_layer advances it as it appends blocks.
        self.in_channels = 64

        # Stem: 3 input channels -> 64 channels. Kernel 3, stride 1, padding 1,
        # so H and W are unchanged.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Stage 1: 64 channels, first block stride 1.
        # Stage 2: 128 channels, first block stride 2.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage as an ``nn.Sequential`` of ``block`` instances.

        Only the first block receives ``stride``; every later block uses
        stride 1. On return, ``self.in_channels`` equals ``out_channels``.
        Note: any ``num_blocks`` <= 1 produces exactly one block.
        """
        block_strides = [stride]
        block_strides.extend([1] * (num_blocks - 1))

        stage = []
        for block_stride in block_strides:
            stage.append(block(self.in_channels, out_channels, block_stride))
            # Blocks after the first consume this block's out_channels.
            self.in_channels = out_channels
        return nn.Sequential(*stage)