# Vanilla ResNet-18 Implementation

## Preprocess

## Define Neural Network Model

In [3]:
import torch
from torch import nn
import torch.nn.functional as F

In [None]:
class block(nn.Module):
    """
    Basic residual building block for ResNet-style networks.

    The main path is conv3x3 -> batch norm -> ReLU -> conv3x3 -> batch norm.
    Optionally the block input is added back through a parameter-free
    shortcut before the final ReLU.
    Initial code is from https://github.com/a-martyn/resnet/blob/master/resnet.py

    Args:
        filters: number of output channels of the block.
        subsample: if True, the first convolution expects ``filters // 2``
            input channels and uses stride 2, so the block doubles the
            channel count and halves the spatial size. If False, the block
            keeps both the channel count and the spatial size.
    """
    def __init__(self, filters, subsample=False):
        # nn.Module has to be initialised before sub-modules are assigned;
        # otherwise the first `self.conv1 = ...` raises AttributeError.
        super().__init__()

        # A subsampling block receives half as many channels as it emits and
        # halves H and W by striding the first convolution.
        in_channels = filters // 2 if subsample else filters
        stride = 2 if subsample else 1

        self.conv1 = nn.Conv2d(in_channels, filters, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(filters, track_running_stats=True)
        self.relu1 = nn.ReLU()
        self.conv2d = nn.Conv2d(filters, filters, kernel_size=3, stride=1,
                                padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(filters, track_running_stats=True)
        self.relu2 = nn.ReLU()

        # Shortcut subsampling, used when the block input x and the main-path
        # output z differ in shape. With a 1x1 window the "average" is just
        # the pixel itself, so this keeps every second pixel along H and W.
        # It has no learnable parameters.
        self.subsample = nn.AvgPool2d(kernel_size=1, stride=2)

        # Weight initialization based on Kaiming He et al., "Delving Deep into
        # Rectifiers: Surpassing Human-Level Performance on ImageNet
        # Classification".
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # variance is computed from the number of output channels
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def shortcut(self, z, x):
        """
        Add the block input ``x`` to the main-path output ``z``.

        When the shapes match this is a plain identity shortcut. Otherwise
        ``x`` is expected to be (B, C, H, W) while ``z`` is
        (B, 2 * C, H / 2, W / 2): ``x`` is spatially subsampled by 2 and its
        channel dimension is zero-padded to 2 * C before the addition.

        Args:
            z: output of the convolutional path.
            x: input of the block.

        Returns:
            Tensor with the shape of ``z``.
        """
        if x.shape == z.shape:
            return z + x

        # halve H and W by keeping every second pixel
        d = self.subsample(x)
        # Zero channels appended to match z's channel count. zeros_like is
        # used instead of `d * 0` so inf/NaN in x cannot leak NaN into the
        # padding, and no autograd edge is created for it.
        p = torch.zeros_like(d)
        # shapes now match: (B, 2 * C, H / 2, W / 2)
        return z + torch.cat((d, p), dim=1)

    def forward(self, x, shortcuts=False):
        """
        Run the block.

        Args:
            x: input tensor of shape (B, C_in, H, W).
            shortcuts: if True, add the residual shortcut before the final
                ReLU; if False, the block is a plain two-layer conv stack.

        Returns:
            Activated output tensor with ``filters`` channels.
        """
        z = self.conv1(x)
        z = self.bn1(z)
        z = self.relu1(z)
        z = self.conv2d(z)
        z = self.bn2(z)

        if shortcuts:
            z = self.shortcut(z, x)

        z = self.relu2(z)
        return z

In [5]:
class ResNet(nn.Module):
    """
    A general ResNet model built from three stacks of residual blocks.

    The network is an input convolution, three stacks of ``n`` blocks with
    16, 32 and 64 channels, global average pooling and a 10-way linear
    classifier. Each block holds two convolutions, so for n >= 1 the model
    has 6n + 2 weighted layers.
    Initial code is from https://github.com/a-martyn/resnet/blob/master/resnet.py

    Args:
        n: number of blocks per stack
        shortcuts: whether to use shortcut connections

    Returns:
        output: log-probabilities for each class, shape (B, 10)
    """
    def __init__(self, n, shortcuts=True):
        super().__init__()
        self.shortcuts = shortcuts

        # Shape comments below are for a (B, 3, 32, 32) input.
        # Spatial size is kept the same because of padding=1 and kernel_size=3.
        # (B, 3, 32, 32) --> (B, 16, 32, 32)
        self.convIn = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bnIn = nn.BatchNorm2d(16, track_running_stats=True)
        self.relu = nn.ReLU()

        # Stack 1: 16 channels with NO subsampling; repeated n times
        # (B, 16, 32, 32) --> (B, 16, 32, 32)
        self.stack1 = nn.ModuleList([block(16, subsample=False) for _ in range(n)])

        # Stack 2: first block doubles the channels and subsamples
        # (B, 16, 32, 32) --> (B, 32, 16, 16)
        self.stack2a = block(32, subsample=True)
        # followed by n-1 blocks that keep 32 channels and the same H, W
        # (B, 32, 16, 16) --> (B, 32, 16, 16)
        self.stack2b = nn.ModuleList([block(32, subsample=False) for _ in range(n - 1)])

        # Stack 3: first block doubles the channels and subsamples
        # (B, 32, 16, 16) --> (B, 64, 8, 8)
        self.stack3a = block(64, subsample=True)
        # followed by n-1 blocks that keep 64 channels and the same H, W
        # (B, 64, 8, 8) --> (B, 64, 8, 8)
        self.stack3b = nn.ModuleList([block(64, subsample=False) for _ in range(n - 1)])

        # global average pooling
        # (B, 64, 8, 8) --> (B, 64, 1, 1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        # final fully connected layer
        self.fcOut = nn.Linear(64, 10, bias=True)
        # log-softmax to predict log-probabilities for each class
        self.softmax = nn.LogSoftmax(dim=-1)

        # Kaiming-normal weights and zero bias for every linear layer.
        # The in-place `kaiming_normal_` replaces the deprecated
        # `kaiming_normal` spelling.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """
        Compute class log-probabilities for a batch of images.

        Args:
            x: input tensor of shape (B, 3, H, W).

        Returns:
            Tensor of shape (B, 10) holding log-probabilities.
        """
        z = self.convIn(x)
        z = self.bnIn(z)
        z = self.relu(z)

        # Run the three stacks in order; every block receives the same
        # `shortcuts` switch chosen at construction time.
        layers = (
            *self.stack1,
            self.stack2a, *self.stack2b,
            self.stack3a, *self.stack3b,
        )
        for layer in layers:
            z = layer(z, shortcuts=self.shortcuts)

        z = self.avgpool(z)
        # flatten (B, 64, 1, 1) --> (B, 64)
        z = z.view(z.size(0), -1)
        z = self.fcOut(z)
        return self.softmax(z)
