In [3]:
print("hello world")

hello world


In [4]:
import torch
import torch.nn as nn

# My code starts here

In [24]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(ResidualBlock,self).__init__()
        #convolution layer
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=1)
        self.conv3 = nn.Conv2d(out_channels, out_channels*2, kernel_size, padding=1)
        self.conv4 = nn.Conv2d(out_channels*2, out_channels*2, kernel_size, padding=1)
        self.conv5 = nn.Conv2d(out_channels*2, out_channels*4, kernel_size, padding=1)
        self.conv6 = nn.Conv2d(out_channels*4, out_channels*4, kernel_size, padding=1)
        self.conv7 = nn.Conv2d(out_channels*4, out_channels*8, kernel_size, padding=1)
        self.conv8 = nn.Conv2d(out_channels*8, out_channels*8, kernel_size, padding=1)
        #batch norm layer
        self.bn1 = nn.BatchNorm2d(num_features=out_channels)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels*2)
        self.bn3 = nn.BatchNorm2d(num_features=out_channels*4)
        self.bn4 = nn.BatchNorm2d(num_features=out_channels*8)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2)
        
        #transition the input channel for the residual path to add
        self.trans1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.trans2 = nn.Conv2d(out_channels, out_channels*2, kernel_size=1)
        self.trans3 = nn.Conv2d(out_channels*2, out_channels*4, kernel_size=1)
        self.trans4 = nn.Conv2d(out_channels*4, out_channels*8, kernel_size=1)
        #if in_channels == out_channels: no need to do transition
        self.isChanged = not in_channels == out_channels
        
    def forward(self, x):
        #block 1
        f_x = self.conv1(x)
        f_x = self.bn1(f_x)
        f_x = self.relu(f_x)
        f_x = self.conv2(f_x)
        f_x = self.bn1(f_x)
        #residual path
        #in order for the residual path to work, we need to match the channels
        if(self.isChanged):
            x = self.trans1(x)
        x = f_x + x
        x = self.relu(x)
        
        #block 1
        f_x = self.conv2(x)
        f_x = self.bn1(f_x)
        f_x = self.relu(f_x)
        f_x = self.conv2(f_x)
        f_x = self.bn1(f_x)
        #residual path
        #in order for the residual path to work, we need to match the channels
        x = f_x + x
        x = self.relu(x)
        
        #max pooling
        x = self.pool(x)
        
        #block 2
        f_x = self.conv3(x)
        f_x = self.bn2(f_x)
        f_x = self.relu(f_x)
        f_x = self.conv4(f_x)
        f_x = self.bn2(f_x)
        #residual path
        #in order for the residual path to work, we need to match the channels
        x = self.trans2(x)
        x = f_x + x
        x = self.relu(x)
        
        #block 2
        f_x = self.conv4(x)
        f_x = self.bn2(f_x)
        f_x = self.relu(f_x)
        f_x = self.conv4(f_x)
        f_x = self.bn2(f_x)
        #residual path
        #in order for the residual path to work, we need to match the channels
        x = f_x + x
        x = self.relu(x)
        
        #max pooling
        x = self.pool(x)
        
        #block 3
        f_x = self.conv5(x)
        f_x = self.bn3(f_x)
        f_x = self.relu(f_x)
        f_x = self.conv6(f_x)
        f_x = self.bn3(f_x)
        #residual path
        #in order for the residual path to work, we need to match the channels
        x = self.trans3(x)
        x = f_x + x
        x = self.relu(x)
        
        #block 3
        f_x = self.conv6(x)
        f_x = self.bn3(f_x)
        f_x = self.relu(f_x)
        f_x = self.conv6(f_x)
        f_x = self.bn3(f_x)
        #residual path
        #in order for the residual path to work, we need to match the channels
        x = f_x + x
        x = self.relu(x)
        
        #max pooling
        x = self.pool(x)
        
        #block 4
        f_x = self.conv7(x)
        f_x = self.bn4(f_x)
        f_x = self.relu(f_x)
        f_x = self.conv8(f_x)
        f_x = self.bn4(f_x)
        #residual path
        #in order for the residual path to work, we need to match the channels
        x = self.trans4(x)
        x = f_x + x
        x = self.relu(x)
        
        #block 4
        f_x = self.conv8(x)
        f_x = self.bn4(f_x)
        f_x = self.relu(f_x)
        f_x = self.conv8(f_x)
        f_x = self.bn4(f_x)
        #residual path
        #in order for the residual path to work, we need to match the channels
        x = f_x + x
        x = self.relu(x)
        
        #max pooling
        x = self.pool(x)
        
        return x

In [6]:
from torchsummary import summary

In [25]:
residualBlock = ResidualBlock(3,64)

In [26]:
summary(residualBlock, (3, 112, 112))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           1,792
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
            Conv2d-4         [-1, 64, 112, 112]          36,928
       BatchNorm2d-5         [-1, 64, 112, 112]             128
            Conv2d-6         [-1, 64, 112, 112]             256
              ReLU-7         [-1, 64, 112, 112]               0
            Conv2d-8         [-1, 64, 112, 112]          36,928
       BatchNorm2d-9         [-1, 64, 112, 112]             128
             ReLU-10         [-1, 64, 112, 112]               0
           Conv2d-11         [-1, 64, 112, 112]          36,928
      BatchNorm2d-12         [-1, 64, 112, 112]             128
             ReLU-13         [-1, 64, 112, 112]               0
        MaxPool2d-14           [-1, 64,