# Goal: Implement a (Rudimentary) ResNet Block and Visualize Loss Landscapes

In this demo, we will implement a simple (and incomplete) residual block from the seminal ResNet paper. We will then compare a VGG network with a ResNet network, and finally, we will visualize the loss landscape of ResNets with and without skip connections.

- [ResNet Paper](https://arxiv.org/abs/1512.03385)
- [Visualizing the Loss Landscape of Neural Nets](https://arxiv.org/abs/1712.09913)

## Imports

Let's start by importing our favorite packages.

In [1]:
import torch
import torch.nn as nn

import torchvision
from torchvision import transforms
from torchsummary import summary

from torch.nn.utils import (
  parameters_to_vector as Params2Vec,
  vector_to_parameters as Vec2Params
)

from tqdm import tqdm

# Simple (and Incomplete) Implementation of ResNet Building Block

The image below describes the basic idea of the residual block proposed in the paper. Rather than passing an input feature $x$ only through a series of linear and non-linear transformations, the block adds a skip connection that adds the input $x$ back to the output of those transformations.

![](https://raw.githubusercontent.com/kvgarimella/dl-demos/main/imgs/residual_block.png)
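In the ResNet paper's notation, the stacked weight layers learn a residual mapping $\mathcal{F}$, and the skip connection adds $x$ back to it. For the two-layer block pictured above, with $\sigma$ denoting ReLU and biases omitted as in the paper, the block computes:

$$
\mathcal{F}(x) = W_2\,\sigma(W_1 x), \qquad \text{out} = \sigma\big(\mathcal{F}(x) + x\big)
$$

Each $W_i$ here is a generic weight layer (in the CNN setting below, a convolution followed by `BatchNorm`). Because $x$ is added element-wise, $\mathcal{F}(x)$ must have exactly the same shape as $x$.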

Residual blocks have become ubiquitous in deep networks, including transformer-based networks, and have enabled the training of very deep networks (for example, over a thousand layers deep). We will stick to the case of CNNs, where each **weight layer** in the image above is a 2D convolutional layer. Another common rule of thumb is to place `BatchNorm` directly after a `Conv` layer (as discussed in class). Since `BatchNorm` contains learned parameters, we can think of it as part of our weight layer. The image below shows residual blocks in action in a 34-layer network (note that adding an identity skip connection does not increase the number of trainable parameters).

![](https://raw.githubusercontent.com/kvgarimella/dl-demos/main/imgs/resnet34.png)

We can see that each weight layer consists of a $3 \times 3$ convolution. In this particular image, the number of channels is $64$, although this varies depending on which stage of the network you are in (see Figure 3 from the paper).
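To check the claim that an identity skip connection adds no trainable parameters, here is a minimal sketch in the setting of the figure ($3 \times 3$ convolutions, 64 channels). The `Residual` wrapper and `count_trainable` helper are hypothetical names for illustration, not part of the paper or torchvision; `padding=1` is assumed so the height and width are unchanged and the addition is well-defined.

```python
import torch
import torch.nn as nn


class Residual(nn.Module):
    """Hypothetical wrapper: computes relu(f(x) + x) for a shape-preserving f."""

    def __init__(self, f):
        super().__init__()
        self.f = f  # the stacked weight layers; the skip itself owns no weights

    def forward(self, x):
        return torch.relu(self.f(x) + x)


def count_trainable(module):
    # Sum the number of elements over all trainable parameter tensors
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


C = 64
# Two "weight layers" (Conv + BatchNorm); padding=1 preserves H and W
weight_layers = nn.Sequential(
    nn.Conv2d(C, C, kernel_size=3, padding=1), nn.BatchNorm2d(C), nn.ReLU(),
    nn.Conv2d(C, C, kernel_size=3, padding=1), nn.BatchNorm2d(C),
)
block = Residual(weight_layers)

print(count_trainable(weight_layers), count_trainable(block))  # 74112 74112
```

The counts match because `f(x) + x` is a parameter-free operation. Projection shortcuts (the $1 \times 1$ convolutions used when shapes change) would add parameters.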

To build a simplified residual block, let's place some constraints on our system:

- we will assume the number of channels is the same at the input and output of the residual block
- we will assume a weight layer means a convolutional layer followed by a batchnorm layer
- we won't worry about stride (i.e., reducing the spatial resolution of the image)

## Step 1: Filling in the `__init__` function

Let's build our residual block as a PyTorch `nn.Module`. This means we need both an `__init__` and a `forward` method. For the `__init__` method, we need to initialize two convolutional layers, two batchnorm layers, and two ReLU layers. We will have an input parameter called `num_channels`, which is the number of channels we expect our input image to have, and we will keep the number of channels constant throughout our residual block.

In [2]:
class BuildingBlock(nn.Module):
    def __init__(self, num_channels):
        super(BuildingBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(num_channels)
        self.relu2 = nn.ReLU()
        
    def forward(self, x):
        pass

Great, our `__init__` function is now (for the most part) complete.

## Step 2: A first attempt at `forward`
Let's now take a stab at the forward pass. We will model it after our image of the residual block above. We will also **print out the `shape` of the image after each weight layer**.

In [3]:
class BuildingBlock(nn.Module):
    def __init__(self, num_channels):
        super(BuildingBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(num_channels)
        self.relu2 = nn.ReLU()
        
    def forward(self, x):
        # forward through the first weight layer 
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        print("shape after first weight layer:", out.shape)

        # forward through the second weight layer
        out = self.conv2(out)
        out = self.bn2(out)

        print("shape after second weight layer:", out.shape)

        # our skip connection: adding back in x after both weight layers  
        out += x 

        out = self.relu2(out)
        return out

Okay, let's instantiate a random torch "image" and attempt to perform forward propagation through our residual block!

In [4]:
NUM_CHANNELS = 64
x = torch.randn(1, NUM_CHANNELS, 50, 50)
print(x.shape)

torch.Size([1, 64, 50, 50])


We have an "image" of $B \times C \times H \times W = 1 \times 64 \times 50 \times 50$. Let's pass this through our residual block:

In [5]:
building_block = BuildingBlock(num_channels=NUM_CHANNELS)
y = building_block(x)

shape after first weight layer: torch.Size([1, 64, 48, 48])
shape after second weight layer: torch.Size([1, 64, 46, 46])


RuntimeError: The size of tensor a (46) must match the size of tensor b (50) at non-singleton dimension 3

**We should see an error message telling us that our tensors don't have the same size.** In particular, our image has shrunk. We started with an image of height (and width) 50, but after the first Conv layer the image is $48 \times 48$, and after the second Conv layer it is $46 \times 46$: each unpadded $3 \times 3$ convolution removes one pixel from every border. To add $x$ back in after both weight layers, each Conv layer must preserve the height and width of its input. One way to handle this is with padding.
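The shrinking follows from the output-size formula given in PyTorch's `nn.Conv2d` documentation, $H_{out} = \lfloor (H_{in} + 2p - d(k - 1) - 1)/s \rfloor + 1$. A small sketch of that formula (the helper name `conv2d_out_size` is hypothetical) reproduces the shapes printed above and shows why `padding=1` fixes them:

```python
def conv2d_out_size(size, kernel_size, padding=0, stride=1, dilation=1):
    """Output height (or width) of nn.Conv2d along one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# No padding: each 3x3 conv trims one pixel from every border
h1 = conv2d_out_size(50, kernel_size=3)  # 48
h2 = conv2d_out_size(h1, kernel_size=3)  # 46

# padding=1 exactly compensates for a 3x3 kernel at stride 1
h_padded = conv2d_out_size(50, kernel_size=3, padding=1)  # 50

print(h1, h2, h_padded)  # 48 46 50
```

In general, for an odd kernel size $k$ at stride 1, `padding=(k - 1) // 2` keeps the spatial size unchanged.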

## Step 3: Fixing our Implementation with Padding
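Let's add padding to each Conv layer in our `__init__` function. Setting `padding=1` for both layers is enough, since one pixel of zero padding on each side exactly offsets the pixel that a $3 \times 3$ kernel trims from each border.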

In [6]:
class BuildingBlock(nn.Module):
    def __init__(self, num_channels):
        super(BuildingBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=3, padding=1) # padding added
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(in_channels=num_channels, out_channels=num_channels, kernel_size=3, padding=1) # padding added
        self.bn2 = nn.BatchNorm2d(num_channels)
        self.relu2 = nn.ReLU()
        
    def forward(self, x):
        # forward through the first weight layer 
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        print("shape after first weight layer:", out.shape)

        # forward through the second weight layer
        out = self.conv2(out)
        out = self.bn2(out)

        print("shape after second weight layer:", out.shape)

        # our skip connection: adding back in x after both weight layers  
        out += x 

        out = self.relu2(out)
        return out

Now, let's retry our example.

In [7]:
NUM_CHANNELS = 64
x = torch.randn(1, NUM_CHANNELS, 50, 50)
print(x.shape)
building_block = BuildingBlock(num_channels=NUM_CHANNELS)
y = building_block(x)

torch.Size([1, 64, 50, 50])
shape after first weight layer: torch.Size([1, 64, 50, 50])
shape after second weight layer: torch.Size([1, 64, 50, 50])


This time, our input was successfully processed by our residual block. Adding padding to both convolutions preserved the height and width of the image throughout the block. And there is our rudimentary implementation of a residual block! PyTorch's autograd will take care of backpropagation through the residual block for us.
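Autograd differentiates through the addition just like any other operation. A toy scalar sketch (the values of `x` and `w` are arbitrary assumptions, not taken from the notebook) shows what the skip connection contributes during backpropagation: the derivative of $\mathcal{F}(x) + x$ is $\mathcal{F}'(x) + 1$, so the identity path passes a gradient of exactly 1 back to $x$ even when the branch's own gradient is tiny.

```python
import torch

# A scalar stand-in for a residual branch F(x) = w * x with a tiny weight
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(0.01)

plain = w * x         # no skip connection:   y = F(x)
residual = w * x + x  # with skip connection: y = F(x) + x

# torch.autograd.grad returns a tuple with one gradient per input
(g_plain,) = torch.autograd.grad(plain, x)
(g_residual,) = torch.autograd.grad(residual, x)
print(g_plain.item(), g_residual.item())  # roughly 0.01 and 1.01
```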

Our implementation doesn't handle the case where the input and output of the residual block have a different number of channels, and it ignores spatial downsampling (i.e., stride). Let's compare our implementation to [torchvision's implementation below](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L59):

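```python
# Imports needed to run this excerpt on its own (conv3x3 is torchvision's 3x3-conv helper)
from typing import Callable, Optional

import torch.nn as nn
from torch import Tensor
from torchvision.models.resnet import conv3x3
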
class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out
```

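Outside of the `downsample` branch (which handles the channel and stride mismatches described above), the forward pass looks very similar to ours. One small difference: torchvision reuses a single `nn.ReLU(inplace=True)` module for both activations, whereas we create two separate `ReLU` modules.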
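To see how a `downsample` branch removes our simplifying constraints, here is a minimal sketch that extends our block with a $1 \times 1$ convolution plus `BatchNorm` on the shortcut whenever the stride or channel count changes, which is the projection-shortcut pattern torchvision passes in as `downsample`. The class name `ProjectionBlock` is hypothetical; `bias=False` mirrors torchvision's choice, since the following `BatchNorm` has its own learned shift.

```python
import torch
import torch.nn as nn


class ProjectionBlock(nn.Module):
    """Hypothetical residual block supporting channel and stride changes."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Only the first conv changes stride/channels; padding=1 keeps 3x3 convs size-preserving at stride 1
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        # Projection shortcut, only built when the identity would not match the output shape
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + identity)


block = ProjectionBlock(in_channels=64, out_channels=128, stride=2)
y = block(torch.randn(1, 64, 50, 50))
print(y.shape)  # torch.Size([1, 128, 25, 25])
```

With `stride=2`, both the $3 \times 3$ branch (with `padding=1`) and the $1 \times 1$ shortcut map $50 \times 50$ to $25 \times 25$, so the addition lines up.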

## Comparison with VGG
Let's now import [VGG13](https://pytorch.org/vision/main/models/generated/torchvision.models.vgg13.html#torchvision.models.vgg13) and [ResNet18](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html) from torchvision and use torchsummary to see some of the differences between the two networks. Both networks reach roughly the same top-1 accuracy on ImageNet, about 70%.

In [8]:
vgg = torchvision.models.vgg13()
r18 = torchvision.models.resnet18()

Let's use torchsummary's `summary` function to print some statistics from a forward pass on a single image.

In [9]:
summary(vgg, (3, 224, 224), batch_size=1, device="cpu")  # models were created on the CPU

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [1, 64, 224, 224]           1,792
              ReLU-2          [1, 64, 224, 224]               0
            Conv2d-3          [1, 64, 224, 224]          36,928
              ReLU-4          [1, 64, 224, 224]               0
         MaxPool2d-5          [1, 64, 112, 112]               0
            Conv2d-6         [1, 128, 112, 112]          73,856
              ReLU-7         [1, 128, 112, 112]               0
            Conv2d-8         [1, 128, 112, 112]         147,584
              ReLU-9         [1, 128, 112, 112]               0
        MaxPool2d-10           [1, 128, 56, 56]               0
           Conv2d-11           [1, 256, 56, 56]         295,168
             ReLU-12           [1, 256, 56, 56]               0
           Conv2d-13           [1, 256, 56, 56]         590,080
             ReLU-14           [1, 256,

In [10]:
summary(r18, (3, 224, 224), batch_size=1, device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [1, 64, 112, 112]           9,408
       BatchNorm2d-2          [1, 64, 112, 112]             128
              ReLU-3          [1, 64, 112, 112]               0
         MaxPool2d-4            [1, 64, 56, 56]               0
            Conv2d-5            [1, 64, 56, 56]          36,864
       BatchNorm2d-6            [1, 64, 56, 56]             128
              ReLU-7            [1, 64, 56, 56]               0
            Conv2d-8            [1, 64, 56, 56]          36,864
       BatchNorm2d-9            [1, 64, 56, 56]             128
             ReLU-10            [1, 64, 56, 56]               0
       BasicBlock-11            [1, 64, 56, 56]               0
           Conv2d-12            [1, 64, 56, 56]          36,864
      BatchNorm2d-13            [1, 64, 56, 56]             128
             ReLU-14            [1, 64,

Things to Note: 

- ResNet18 has $\sim 10$% of the number of parameters that VGG13 has (see `Total params`)
- ResNet18 takes up $\sim 15$% of the total size when compared to VGG13 (see `Total Size (MB)`)
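The `Param #` column can be reproduced by hand: a `Conv2d` layer has one $k \times k \times C_{in}$ filter per output channel (plus an optional bias per output channel), and an affine `BatchNorm2d` has one scale and one shift per channel. The sketch below (the helper names are hypothetical) checks several rows from the summaries above, and also counts VGG13's first fully connected layer, which maps the $512 \times 7 \times 7$ feature map to 4096 units and on its own holds most of VGG13's parameters.

```python
def conv2d_params(in_ch, out_ch, kernel_size, bias=True):
    """Trainable parameters of nn.Conv2d with groups=1."""
    return out_ch * in_ch * kernel_size * kernel_size + (out_ch if bias else 0)


def batchnorm2d_params(num_features):
    """Trainable parameters of nn.BatchNorm2d with affine=True (gamma and beta)."""
    return 2 * num_features


def linear_params(in_features, out_features):
    """Trainable parameters of nn.Linear with a bias."""
    return in_features * out_features + out_features


# VGG13 rows (its convolutions use a bias)
assert conv2d_params(3, 64, 3) == 1_792       # Conv2d-1
assert conv2d_params(64, 64, 3) == 36_928     # Conv2d-3
assert conv2d_params(64, 128, 3) == 73_856    # Conv2d-6

# ResNet18 rows (bias=False, since BatchNorm follows each conv)
assert conv2d_params(3, 64, 7, bias=False) == 9_408    # Conv2d-1 (7x7 stem)
assert conv2d_params(64, 64, 3, bias=False) == 36_864  # Conv2d-5
assert batchnorm2d_params(64) == 128                   # BatchNorm2d-2

# VGG13's first classifier layer alone
print(linear_params(512 * 7 * 7, 4096))  # 102764544
```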

# Visualizing the Loss Landscape of ResNets

In addition to enabling the training of deeper networks, the skip connections of ResNets also appear to smooth out their loss surfaces. Below is the first figure from the Visualizing the Loss Landscape paper:

![](https://raw.githubusercontent.com/kvgarimella/dl-demos/main/imgs/loss_landscape.png)

At a high level, these loss landscapes are generated by:

1. training a network on a particular dataset (CIFAR-10, in this case)
2. perturbing the trained weights along two different directions and evaluating the loss over the entire dataset at each perturbed point.
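The two steps above can be sketched with the `Params2Vec`/`Vec2Params` utilities imported at the top of the notebook: flatten the trained weights $\theta$, pick two random directions $\delta, \eta$, and evaluate $L(\theta + \alpha\delta + \beta\eta)$ on a grid of $(\alpha, \beta)$. This is a simplified sketch under stated assumptions, not the paper's code: the helper `loss_surface` is hypothetical, it uses a single batch instead of the whole dataset, and it rescales each direction to the norm of $\theta$ as a stand-in for the paper's filter-wise normalization.

```python
import torch
import torch.nn as nn
from torch.nn.utils import (
    parameters_to_vector as Params2Vec,
    vector_to_parameters as Vec2Params,
)


@torch.no_grad()
def loss_surface(model, loss_fn, inputs, targets, alphas, betas, seed=0):
    """Hypothetical helper: evaluate L(theta + a*d1 + b*d2) on an (a, b) grid."""
    gen = torch.Generator().manual_seed(seed)
    theta = Params2Vec(model.parameters()).clone()  # trained weights, flattened

    # Two random directions, each rescaled to the norm of theta
    # (a simplification of the paper's filter-wise normalization)
    d1 = torch.randn(theta.shape, generator=gen)
    d2 = torch.randn(theta.shape, generator=gen)
    d1 *= theta.norm() / d1.norm()
    d2 *= theta.norm() / d2.norm()

    surface = torch.empty(len(alphas), len(betas))
    for i, a in enumerate(alphas):
        for j, b in enumerate(betas):
            Vec2Params(theta + a * d1 + b * d2, model.parameters())
            surface[i, j] = loss_fn(model(inputs), targets).item()
    Vec2Params(theta, model.parameters())  # restore the trained weights
    return surface


# Toy usage on random data (for networks with BatchNorm, call model.eval() first)
torch.manual_seed(0)
model = nn.Linear(4, 2)
inputs, targets = torch.randn(32, 4), torch.randint(0, 2, (32,))
grid = torch.linspace(-1.0, 1.0, 5)
surface = loss_surface(model, nn.CrossEntropyLoss(), inputs, targets, grid, grid)
print(surface.shape)  # torch.Size([5, 5])
```

The center of the grid, $(\alpha, \beta) = (0, 0)$, recovers the loss at the original weights. Even this toy version needs one forward pass per grid point, which hints at why doing it for a real ResNet on all of CIFAR-10 is expensive.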

This is computationally intensive, so we will use a visualization tool provided by the authors. Navigate to the following URL: [http://www.telesens.co/loss-landscape-viz/viewer.html](http://www.telesens.co/loss-landscape-viz/viewer.html).

Try out the following configurations. Here, "no short" means the network has **no skip (shortcut) connections**.

## 1. ResNet 20 (no short) and ResNet 20 (short)

Do you see a difference in the two loss landscapes?

## 2. ResNet 56 (no short) and ResNet 56 (short)

## 3. ResNet 110 (no short) and ResNet 110 (short)

What happens to the loss landscapes as the network grows deeper?


