In [1]:
import torch
import torch.nn as nn
from math import sqrt


class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, stride=1, dilation=1, groups=1):
        super(ConvBlock, self).__init__()

        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding=int((kernel_size - 1) / 2) * dilation,
                              bias=True, dilation=dilation, groups=groups)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):        
        return self.relu(self.conv(x))
    

class VDSR(nn.Module):
    def __init__(self):
        super(VDSR, self).__init__()
    
        self.input_layer = ConvBlock(1, 64, kernel_size=3)
        
        layers = [ConvBlock(64, 64, kernel_size=3) for _ in range(18)]
        self.medium_layer = nn.Sequential(*layers)
        
        self.output_layer = ConvBlock(64, 1, kernel_size=3)
        
        self.relu = nn.ReLU(inplace=True)
        
        # He initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, sqrt(2. / n))


    def forward(self, x):
        
        residual = x

        x = self.relu(self.input_layer(x))
        x = self.medium_layer(x)
        x = self.output_layer(x)

        return x + residual

In [2]:
# Find total parameters and trainable parameters
model = VDSR()
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')

665,921 total parameters.
665,921 training parameters.


In [3]:
from torchsummary import summary

model = VDSR()
summary(model, (1,32,32)) 

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]             640
              ReLU-2           [-1, 64, 32, 32]               0
         ConvBlock-3           [-1, 64, 32, 32]               0
              ReLU-4           [-1, 64, 32, 32]               0
            Conv2d-5           [-1, 64, 32, 32]          36,928
              ReLU-6           [-1, 64, 32, 32]               0
         ConvBlock-7           [-1, 64, 32, 32]               0
            Conv2d-8           [-1, 64, 32, 32]          36,928
              ReLU-9           [-1, 64, 32, 32]               0
        ConvBlock-10           [-1, 64, 32, 32]               0
           Conv2d-11           [-1, 64, 32, 32]          36,928
             ReLU-12           [-1, 64, 32, 32]               0
        ConvBlock-13           [-1, 64, 32, 32]               0
           Conv2d-14           [-1, 64,