# Residual Block

A residual block is defined as y = σ(F(x) + G(x))

where x and y are the input and output tensors of the block, σ is the ReLU activation function, F is the residual function to be learned, and G is the shortcut connection: a projection that matches the dimensions of x to those of F(x) when they differ, and the identity otherwise.
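The order of operations matters here: σ is applied once, after the two branches are summed, not to each branch separately. A small PyTorch check with hand-picked tensors makes the difference visible. The values of `fx` and `gx` are arbitrary stand-ins for F(x) and G(x), not outputs of a real block.

```python
import torch
import torch.nn.functional as nnf

# Arbitrary stand-ins for the two branch outputs F(x) and G(x).
fx = torch.tensor([-2.0, 1.0])
gx = torch.tensor([1.0, 1.0])

# Residual block output: ReLU applied after the sum, y = relu(F(x) + G(x)).
y = nnf.relu(fx + gx)

# A different (incorrect) ordering: ReLU applied to F(x) before the sum.
y_wrong = nnf.relu(fx) + gx

print(y)        # tensor([0., 2.])
print(y_wrong)  # tensor([1., 2.])
```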

Your code needs to define a ResidualBlock class (inheriting from nn.Module) that implements a residual block. In your code, F will be implemented with two convolutional layers with a ReLU non-linearity between them, i.e. F = conv2(σ(conv1(x))). Batch normalization will also be applied right after each convolution, so in full F(x) = BN(conv2(σ(BN(conv1(x))))).

The constructor of the ResidualBlock class needs to take the following arguments as input:

* *inplanes*, the number of channels of x;

* *planes*, the number of output channels of conv1 and conv2;

* *stride*, the stride of conv1 (conv2 always uses stride 1).
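Given these three arguments, the residual function F (with batch normalization after each convolution) could be assembled as a single `nn.Sequential`. This is only a sketch: the helper name `make_residual_branch` is hypothetical, and it assumes that only conv1 receives the stride, so conv2 preserves the spatial size.

```python
import torch
from torch import nn


def make_residual_branch(inplanes: int, planes: int, stride: int) -> nn.Sequential:
    """Hypothetical helper: F(x) = BN(conv2(ReLU(BN(conv1(x)))))."""
    return nn.Sequential(
        # conv1: 3x3, carries the block's stride, maps inplanes -> planes.
        nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(planes),
        nn.ReLU(),
        # conv2: 3x3, stride 1, keeps both channels and spatial size.
        nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(planes),
    )


branch = make_residual_branch(inplanes=16, planes=32, stride=2)
x = torch.randn(4, 16, 32, 32)
print(branch(x).shape)  # torch.Size([4, 32, 16, 16])
```

With stride 2, the 32×32 input is halved by conv1 and left unchanged by conv2.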

If the shapes of F(x) and x do not match (either because inplanes != planes or because stride > 1), ResidualBlock also needs to apply a projection shortcut G, composed of a 1×1 convolutional layer with no bias, no padding and a stride equal to *stride*, followed by batch normalization. Otherwise, G is simply the identity function.
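The rule for choosing G can be written as a small factory. The sketch below (the name `make_shortcut` is hypothetical) returns `nn.Identity()` when the shapes already match and a 1×1 convolution followed by batch normalization otherwise. Deciding this once, from the constructor arguments, means the projection's parameters are only created when they are actually used.

```python
import torch
from torch import nn


def make_shortcut(inplanes: int, planes: int, stride: int) -> nn.Module:
    """Hypothetical helper returning G: identity or a 1x1 projection."""
    if inplanes != planes or stride > 1:
        return nn.Sequential(
            # 1x1 kernel, no bias, no padding, same stride as conv1.
            nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(planes),
        )
    return nn.Identity()


x = torch.randn(4, 16, 32, 32)
print(isinstance(make_shortcut(16, 16, 1), nn.Identity))  # True
print(make_shortcut(16, 32, 2)(x).shape)                   # torch.Size([4, 32, 16, 16])
```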

The forward method of ResidualBlock takes the input tensor x and returns the output tensor y, after performing all the operations of the residual block.

## Additional details
Unless otherwise specified, convolutional layers must have 3×3 kernels, stride 1, padding 1 and no bias.
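These defaults are what keep F(x) and G(x) aligned spatially. With dilation 1, `nn.Conv2d` produces ⌊(H + 2p − k)/s⌋ + 1 along each spatial dimension, so a 3×3 kernel with padding 1 and a 1×1 kernel with padding 0 both give ⌊(H − 1)/s⌋ + 1 for the same stride s. The short check below compares that formula with real layers; the helper `conv_out` is illustrative, not part of PyTorch.

```python
import torch
from torch import nn


def conv_out(h: int, k: int, s: int, p: int) -> int:
    """Spatial output size of a convolution with dilation 1."""
    return (h + 2 * p - k) // s + 1


x = torch.randn(1, 8, 31, 31)  # odd size to exercise the floor division
for s in (1, 2):
    f_side = nn.Conv2d(8, 8, kernel_size=3, stride=s, padding=1, bias=False)(x)
    g_side = nn.Conv2d(8, 8, kernel_size=1, stride=s, padding=0, bias=False)(x)
    # Both branches end up with the same width for a given stride.
    assert f_side.shape[-1] == g_side.shape[-1] == conv_out(31, 3, s, 1) == conv_out(31, 1, s, 0)
    print(s, f_side.shape[-1])  # prints "1 31", then "2 16"

# A 1x1 kernel with padding 1 would instead grow each side by 2 (at stride 1).
print(conv_out(31, 1, 1, 1))  # 33
```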

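import torch
from torch import nn


class ResidualBlock(nn.Module):
    """Residual block computing y = relu(F(x) + G(x))."""

    def __init__(self, inplanes: int, planes: int, stride: int):
        super().__init__()
        # F, first layer: 3x3 conv carrying the block's stride, then batch norm.
        self.conv1 = nn.Conv2d(
            in_channels=inplanes,
            out_channels=planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.conv1_bn = nn.BatchNorm2d(planes)

        # F, second layer: 3x3 conv with the default stride 1, then batch norm.
        self.conv2 = nn.Conv2d(
            in_channels=planes,
            out_channels=planes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.conv2_bn = nn.BatchNorm2d(planes)

        # G: when the shapes of F(x) and x differ, a 1x1 projection
        # (no bias, no padding, same stride as conv1) followed by batch norm;
        # otherwise the identity.
        if inplanes != planes or stride > 1:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels=inplanes,
                    out_channels=planes,
                    kernel_size=1,
                    stride=stride,
                    padding=0,
                    bias=False,
                ),
                nn.BatchNorm2d(planes),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F(x) = BN(conv2(relu(BN(conv1(x)))))
        fx = self.conv1(x)
        fx = self.conv1_bn(fx)
        fx = nn.functional.relu(fx)
        fx = self.conv2(fx)
        fx = self.conv2_bn(fx)

        # G(x): identity or projection, as chosen in __init__.
        gx = self.shortcut(x)

        # y = relu(F(x) + G(x))
        return nn.functional.relu(fx + gx)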

