## ResNet

We start with a fundamental question: Is learning better networks as easy as stacking more layers?

Stacking more layers adds expressive power, so one might expect deeper networks to perform better. In practice, however, deeper neural networks are more difficult to train.

In practice, once deeper networks start converging, a degradation problem appears: as the depth of the network increases, the accuracy saturates and then degrades rapidly.
At first glance this is not surprising. Two explanations were commonly offered:
1. Vanishing/exploding gradients: the deeper the network, the higher the chance that gradients vanish or explode as they are propagated through the layers.
2. Overfitting: with more parameters, the network is more prone to overfitting.

However, the authors of the ResNet paper show that the degradation problem is not due to overfitting: adding more layers leads to a higher training error. If the network were overfitting, we would expect the training error to be lower while the validation/test error is higher. They also inspect the gradients during training to verify that they are healthy, i.e. neither vanishing nor exploding.


How do we solve the degradation problem?

Let us consider a shallower architecture with $n$ layers and a deeper counterpart that adds $m$ more layers on top of it ($n+m$ layers). The deeper architecture should be able to achieve a loss no higher than the shallow one: a trivial solution is to copy the $n$ layers of the shallow architecture and let the additional $m$ layers compute the identity function. The fact that this doesn't happen in practice suggests that stacks of neural network layers have a hard time learning the identity function. The paper therefore proposes "shortcut/skip connections", which make it easy for the layers to learn the identity function. This “identity shortcut connection” is the core idea of ResNet.
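The construction argument above can be checked numerically. The sketch below is only an illustration: `shallow` and `deeper` are hypothetical toy networks, and `nn.Identity` stands in for extra layers that have learned the identity mapping perfectly.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical shallow network (the "n layers").
shallow = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

# Deeper counterpart: the same layers (shared weights) followed by m = 3 extra layers.
# nn.Identity stands in for extra layers that have learned the identity function.
deeper = nn.Sequential(*shallow, nn.Identity(), nn.Identity(), nn.Identity())

x = torch.rand(5, 8)
same_output = torch.equal(shallow(x), deeper(x))
```

Since the two networks produce identical outputs, they incur the same loss, so the deeper network is in principle never worse than the shallow one.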

ResNet is a key architecture for deep learning models. It has inspired several variants, and is one of the most popular architectures in use. Details of ResNet can be found in https://arxiv.org/pdf/1512.03385.pdf.

The remaining sections of this notebook explain the key ideas behind ResNet.

## Residual Learning

Let $g(x)$ be the function learned by a stack of layers (not necessarily the entire network).
Let’s consider $h(x)=g(x)+x$, i.e. the output of the stack of layers with a skip connection. Here the $+x$ term denotes the skip connection: the input $x$ is added directly to the output of the stack of layers. It is called a skip connection because the input bypasses the intermediate layers and reaches the deeper layer through a direct path.

The output $h(x)$ already contains the input $x$, so the layers only need to learn the function $g(x)=h(x)-x$, i.e. the change (the delta, or residual) relative to the input. Hence the name residual networks.

Now let us revisit the earlier problem of degradation. We posited that regular neural network layers generally have a hard time learning the identity function. With residual learning, if the identity function is optimal, then to obtain $h(x)=x$ the layers only need to learn $g(x)=0$. This can easily be done by driving all the weights of the layers to 0.

Another way to think about it: if we initialized a regular neural network’s weights and biases to 0, every layer would start out as the “zero” function, i.e. $g(x)=0$. The output of every stack of layers with a shortcut connection, $h(x)=g(x)+x$, would then already be the identity function, since $h(x)=x$ when $g(x)=0$.
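A quick numerical check of this argument is sketched below. The layer stack `g` is a hypothetical stand-in (any stack of layers would do), and the sketch leaves out the final ReLU that a full residual block applies after the addition.

```python
import torch
from torch import nn

# g: a small stack of conv layers (hypothetical stand-in for any layer stack).
g = nn.Sequential(
    nn.Conv2d(4, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 4, kernel_size=3, padding=1),
)

# Drive every weight and bias to zero, so that g(x) = 0 for all x.
for p in g.parameters():
    nn.init.zeros_(p)

x = torch.rand(2, 4, 8, 8)
with torch.no_grad():
    g_out = g(x)
    h_out = g_out + x  # h(x) = g(x) + x
```

With all parameters at zero, `g_out` is all zeros and `h_out` equals `x`, i.e. the stack with a skip connection is exactly the identity function.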

In real cases, identity mappings are unlikely to be optimal, i.e. the network layers will want to learn actual features. The reformulation doesn't prevent the network layers from doing so: they can still learn other functions, just like a regular stack of layers. We can think of the reformulation as a preconditioning that makes learning the identity function easier when it is needed.


Additionally, skip connections provide a direct path for the gradient to flow from layer to layer, i.e. the deeper layer has a direct path to $x$. This helps learning, as information from the lower layers passes directly into the higher layers.
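The gradient argument follows from differentiating $h(x)=g(x)+x$: the derivative with respect to $x$ is the derivative of $g$ plus the identity, so the gradient reaches $x$ even when nothing flows through $g$. The sketch below illustrates this in an extreme, assumed setting where all the weights of a hypothetical stack `g` are zero, so no gradient passes through the layers themselves.

```python
import torch
from torch import nn

def make_zero_stack():
    # Hypothetical two-layer stack with all parameters set to zero.
    stack = nn.Sequential(nn.Linear(6, 6), nn.ReLU(), nn.Linear(6, 6))
    for p in stack.parameters():
        nn.init.zeros_(p)
    return stack

g = make_zero_stack()

# Plain stack: the output is g(x), so the gradient must pass through the layers.
x_plain = torch.rand(3, 6, requires_grad=True)
g(x_plain).sum().backward()

# Residual stack: the output is h(x) = g(x) + x, so the skip path carries gradient.
x_res = torch.rand(3, 6, requires_grad=True)
(g(x_res) + x_res).sum().backward()
```

Here `x_plain.grad` is all zeros, while `x_res.grad` is all ones: the skip connection alone delivers the gradient to the input.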

Let us now consider how the shortcut connection works. We will do so by implementing the basic skip connection block.

In [1]:
import torch
import torchvision

from torch import nn
from torchsummary import summary

In [2]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, num_features, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Sequential(
                        nn.Conv2d(in_channels, num_features, kernel_size=3, stride=stride, padding=1, bias=False),
                        nn.BatchNorm2d(num_features, eps=0.001),
                        nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(
                        nn.Conv2d(num_features, num_features, kernel_size=3, stride=1, padding=1, bias=False),
                        nn.BatchNorm2d(num_features, eps=0.001))
        self.downsample = downsample
        self.relu = nn.ReLU(inplace=True)
        
        
    def forward(self, x):
        conv_out = self.conv2(self.conv1(x))
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        assert identity.shape == conv_out.shape, f"Identity {identity.shape} and conv out {conv_out.shape} have different shapes"
        
        # Skip connection
        out = self.relu(conv_out+identity)
        return out

In [3]:
x = torch.rand([1, 64, 28, 28])
residual_block = BasicBlock(in_channels=64, num_features=64, stride=1)

out = residual_block(x)
assert out.shape == torch.Size([1, 64, 28, 28])

Notice how the output of the residual block is a function of both the input and the conv layer output, i.e. $ReLU(conv\_out+x)$. This assumes that $x$ and $conv\_out$ have the same shape. We will see later what needs to be done when this isn't the case.

Also, note that adding an identity skip connection does not increase the number of parameters: the shortcut is just an element-wise addition, so it is parameter free. This makes the solution cheap from a computational point of view, which is one of the charms of shortcut connections.
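The parameter count can be verified directly. In the sketch below, `WithSkip`, `make_layers` and `count_params` are hypothetical helpers introduced only for this comparison. The layer stack mirrors the two 3x3 convolutions of the basic block.

```python
from torch import nn

class WithSkip(nn.Module):
    """Hypothetical wrapper that adds the input to the output of `layers`."""
    def __init__(self, layers):
        super().__init__()
        self.layers = layers

    def forward(self, x):
        return self.layers(x) + x  # identity shortcut: a plain addition

def make_layers():
    # Two 3x3 conv + batch norm layers, as in the basic block.
    return nn.Sequential(
        nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm2d(64),
    )

def count_params(module):
    return sum(p.numel() for p in module.parameters())

plain_params = count_params(make_layers())
skip_params = count_params(WithSkip(make_layers()))
```

Both counts come to 73,984 (two convolutions with 36,864 weights each, plus two batch norm layers with 128 parameters each), so the identity shortcut adds nothing.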

## ResNet Architecture

Now that we have studied the basic building block, i.e. a stack of conv layers with a skip connection, let us delve deeper into the architecture of ResNet.

ResNet architectures are constructed by stacking multiple building blocks on top of each other. They follow the same design rules as VGG:
1. The convolutional layers mostly have 3×3 filters.
2. The layers have the same number of filters for a given output feature map size.
3. If the feature map size is halved, the number of filters is doubled so as to preserve the time complexity per layer.

Unlike VGG, which relies on multiple max pooling layers, ResNet uses conv layers with stride=2 to downsample.

The core architecture consists of the following components:

1. 5 Convolutional Layer blocks:

   The first convolutional block consists of a convolution with a 7x7 kernel, stride=2, padding=3 and num_features=64, followed by a max pooling layer with a 3x3 kernel, stride=2 and padding=1. Together they reduce the feature map size from (224, 224) to (56, 56).
   
   The remaining convolutional blocks (ResidualConvBlock) are built by stacking multiple basic shortcut blocks together. Each basic block uses 3x3 filters as described above.
   
  
2. Classifier: An average pooling layer which runs on top of the conv block output, followed by a fully connected layer, which is used for classification.
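The feature map sizes quoted for the first convolutional block follow from the standard output-size formula for convolution and pooling layers. The helper `conv_output_size` below is a hypothetical name used only for this arithmetic check. It assumes a dilation of 1 and floor rounding.

```python
def conv_output_size(size, kernel_size, stride, padding):
    """Spatial output size of a conv or pooling layer (dilation 1, floor rounding)."""
    return (size + 2 * padding - kernel_size) // stride + 1

# 7x7 conv, stride=2, padding=3 on a 224x224 input.
after_conv = conv_output_size(224, kernel_size=7, stride=2, padding=3)
# 3x3 max pooling, stride=2, padding=1 on the result.
after_pool = conv_output_size(after_conv, kernel_size=3, stride=2, padding=1)
```

This gives 112 after the convolution and 56 after the max pooling layer.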



Let us now implement ResNet-34 from scratch. In practice, this is seldom done, since `torchvision.models` already provides ready-made implementations of all the ResNet architectures. However, by building the network from scratch, we will gain a deeper understanding of the architecture.


We have already looked at the BasicBlock which implements the shortcut connection. Now we will implement a residual conv block which consists of a number of basic blocks stacked on top of each other.

We have to handle 2 cases when it comes to basic blocks.
 
#### Case 1: Output feature map size = Input Feature map size &  Number of output features = Number of input features

This is the most common case. Since there is no change in either the num_features or the feature_map size, we can easily add the input and output via shortcut connections.

#### Case 2: Output feature map size =  1/2 *  Input Feature map size &  Number of output features = 2 * Number of input features

Remember that ResNet uses conv layers with stride=2 to downsample. Additionally, the number of features is doubled.

This is done by the first basic block of every conv block, except the second conv block, which keeps the feature map size and number of features it receives. The output feature map size is reduced by using a 3x3 convolution with stride=2.

In this case the input and output are not of the same shape. So how do we add them together as part of the skip connection? 1x1 convs are the answer.
The input is passed through a 1x1 conv with stride=2 and num_features = 2 * number of input features, which halves the feature map size and doubles the number of features so that the shortcut matches the output of the conv layers.
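A minimal shape check of this idea is sketched below, assuming a 64-channel 56x56 input. `main_conv` and `projection` are illustrative names for the first 3x3 conv of the block and the 1x1 shortcut conv.

```python
import torch
from torch import nn

x = torch.rand(1, 64, 56, 56)

# Main path: a 3x3 conv with stride=2 halves the feature map and doubles the channels.
main_conv = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False)
# Shortcut path: a 1x1 conv with stride=2 projects x to the same shape.
projection = nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False)

main_shape = main_conv(x).shape
shortcut_shape = projection(x).shape
```

Both paths produce a tensor of shape (1, 128, 28, 28), so they can be added element-wise. Unlike the identity shortcut, this projection shortcut does carry learnable weights.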

In [4]:
class ResidualConvBlock(nn.Module):
    def __init__(self, in_channels, num_blocks, reduce_fm_size=True):
        super(ResidualConvBlock, self).__init__()
        
        num_features = in_channels * 2 if reduce_fm_size else in_channels
        modules = []
    
        for i in range(num_blocks):
            if i == 0 and reduce_fm_size:
                # Case 2
                stride = 2 
                downsample = nn.Sequential(
                    nn.Conv2d(in_channels, num_features, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(num_features, eps=0.001),
                )
                basic_block = BasicBlock(in_channels=in_channels, num_features=num_features, stride=stride,
                                        downsample=downsample)
            else:
                # Case 1
                basic_block = BasicBlock(in_channels=num_features, num_features=num_features, stride=1)
            modules.append(basic_block)
        
        self.conv_block = nn.Sequential(*modules)
    
    
    def forward(self, x):
        return self.conv_block(x)
            

In [5]:
# Case 1
x = torch.rand([1, 64, 56, 56])
conv_block = ResidualConvBlock(64, 3, reduce_fm_size=False)

y = conv_block(x)
assert y.shape == torch.Size([1, 64, 56, 56])


# Case 2
x = torch.rand([1, 64, 56, 56])
conv_block = ResidualConvBlock(64, 3, reduce_fm_size=True)

y = conv_block(x)
assert y.shape == torch.Size([1, 64*2, 56//2, 56//2])

Now we are ready to implement ResNet-34. 

In [6]:
class ResNet(nn.Module):
    def __init__(self, num_basic_blocks, num_classes):
        super(ResNet, self).__init__()
        conv1 = nn.Sequential(
                              nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
                              nn.BatchNorm2d(64, eps=0.001),
                              nn.ReLU(inplace=True),
                              nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
                
        assert len(num_basic_blocks) == 4
        # num_basic_blocks is a list of size 4, which specifies the number of basic blocks per ResidualConvBlock
        conv2 = ResidualConvBlock(in_channels=64, num_blocks=num_basic_blocks[0], reduce_fm_size=False)
        conv3 = ResidualConvBlock(in_channels=64, num_blocks=num_basic_blocks[1], reduce_fm_size=True)
        conv4 = ResidualConvBlock(in_channels=128, num_blocks=num_basic_blocks[2], reduce_fm_size=True)
        conv5 = ResidualConvBlock(in_channels=256, num_blocks=num_basic_blocks[3], reduce_fm_size=True)
        
        self.conv_backbone = nn.Sequential(*[conv1, conv2, conv3, conv4, conv5])
        
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(512, num_classes)
        
    
    def forward(self, x):
        conv_out = self.conv_backbone(x)
        conv_out = self.avg_pool(conv_out)
        # We need to flatten the conv features before passing it to the classifier
        logits = self.classifier(conv_out.view(conv_out.shape[0], -1)) 
        return logits

In [7]:
num_classes = 1000
resnet34 = ResNet([3, 4, 6, 3], num_classes)

x = torch.rand([1, 3, 224, 224])
logits = resnet34(x)
assert logits.shape == torch.Size([1, num_classes])
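As a side note, the "34" in ResNet-34 can be recovered from the block configuration. The sketch below assumes the usual counting convention: one stem convolution, two convolutions per basic block and one fully connected layer, with the 1x1 projection convs on the shortcuts not counted. `resnet_depth` is a hypothetical helper.

```python
def resnet_depth(num_basic_blocks, convs_per_basic_block=2):
    """Number of weighted layers on the main path of a basic-block ResNet."""
    stem_convs = 1
    block_convs = convs_per_basic_block * sum(num_basic_blocks)
    fc_layers = 1
    return stem_convs + block_convs + fc_layers

depth_34 = resnet_depth([3, 4, 6, 3])
depth_18 = resnet_depth([2, 2, 2, 2])
```

The configuration `[3, 4, 6, 3]` gives 34 layers, and `[2, 2, 2, 2]` gives the 18 layers of ResNet-18.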

In [8]:
resnet34

ResNet(
  (conv_backbone): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    )
    (1): ResidualConvBlock(
      (conv_block): Sequential(
        (0): BasicBlock(
          (conv1): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU(inplace=True)
          )
          (conv2): Sequential(
            (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
          )
          (relu): ReLU(inplace=True)
        )
 

In [9]:
# We can now take a look at the summary to visualize the output shape, number of parameters and the layers
summary(resnet34, input_size=(3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]          36,864
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
       BasicBlock-11           [-1, 64, 56, 56]               0
           Conv2d-12           [-1, 64, 56, 56]          36,864
      BatchNorm2d-13           [-1, 64, 56, 56]             128
             ReLU-14           [-1, 64,

In [10]:
# As an elementary check, let us compare the number of parameters between our implementation
# and the official torchvision implementation and assert that they are equal

num_resnet_params = sum(p.numel() for p in resnet34.parameters() if p.requires_grad)

torch_resnet34 = torchvision.models.resnet34()
num_torch_resnet_params = sum(p.numel() for p in torch_resnet34.parameters() if p.requires_grad)

assert num_resnet_params == num_torch_resnet_params

And there we have our own simple implementation of ResNet-34. Note that this is a barebones implementation to get a sense of the broad architecture.

Deeper ResNets (ResNet50, ResNet101, ResNet152) use a different type of building block called the bottleneck block. Similarly, there are several other variants inspired by ResNet, like ResNeXt, Wide ResNet etc. However, the core idea behind them remains the same.
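For reference, a bottleneck block can be sketched as follows. This is a simplified illustration, not the implementation used by `torchvision`: it covers only the identity-shortcut case (no downsampling), and the names `Bottleneck`, `mid_channels` and `expansion` are choices made for this sketch. The block squeezes the channels with a 1x1 conv, applies a 3x3 conv on the reduced representation, and expands back with another 1x1 conv before the skip connection is added.

```python
import torch
from torch import nn

class Bottleneck(nn.Module):
    """Sketch of a bottleneck block with an identity shortcut (no downsampling)."""
    expansion = 4  # output channels = mid_channels * expansion

    def __init__(self, in_channels, mid_channels):
        super().__init__()
        out_channels = mid_channels * self.expansion
        assert in_channels == out_channels, "this sketch only covers the identity-shortcut case"
        self.layers = nn.Sequential(
            # 1x1 conv: reduce the number of channels
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # 3x3 conv on the reduced representation
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            # 1x1 conv: expand the channels back
            nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Skip connection, exactly as in the basic block
        return self.relu(self.layers(x) + x)

block = Bottleneck(in_channels=256, mid_channels=64)
out = block(torch.rand(1, 256, 14, 14))
```

The output has the same shape as the input, (1, 256, 14, 14), and the skip connection works exactly as in the basic block.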