In [1]:
import torch
import torch.nn as nn
from math import sqrt

class Conv_ReLU_Block(nn.Module):
    """A single 3x3 convolution followed by an in-place ReLU.

    stride=1 and padding=1 keep the spatial size unchanged, and the input and
    output channel counts are equal, so any number of these blocks can be
    stacked back to back without reshaping.

    Args:
        channels: number of input and output feature channels. Defaults to 64,
            the width previously hard-coded here.
    """

    def __init__(self, channels=64):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=3, stride=1, padding=1, bias=False)
        # inplace=True overwrites the conv's freshly allocated output, so the
        # tensor passed into forward() is never modified.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv then ReLU; the result has the same shape as ``x``."""
        return self.relu(self.conv(x))
        
class VDSR(nn.Module):
    """VDSR-style residual CNN (presumably Kim et al., "Very Deep Super-Resolution").

    Maps a 1-channel image to a 1-channel image of the same spatial size:
    an input conv lifts 1 -> 64 channels (followed by ReLU), a stack of
    ``Conv_ReLU_Block`` layers processes the features, an output conv projects
    64 -> 1 channel, and the input is added back (global residual connection),
    so the convolutions only have to predict the residual.

    Args:
        num_residual_layers: number of ``Conv_ReLU_Block`` layers in the middle
            stack. Defaults to 18, giving 20 convolution layers in total
            together with the input and output convs.

    Raises:
        ValueError: if ``num_residual_layers`` is negative.
    """

    def __init__(self, num_residual_layers=18):
        super().__init__()
        if num_residual_layers < 0:
            raise ValueError(f"num_residual_layers must be >= 0, got {num_residual_layers}")
        # Registration order (residual_layer, input, output, relu) is kept as-is:
        # it fixes the state_dict keys and the order in which the seeded RNG is
        # consumed by the default inits and the loop below.
        self.residual_layer = self.make_layer(Conv_ReLU_Block, num_residual_layers)
        self.input = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.output = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU(inplace=True)

        # He-style normal init with fan = k_h * k_w * out_channels,
        # i.e. std = sqrt(2 / fan). nn.init.normal_ runs under no_grad and
        # replaces the discouraged direct `weight.data` mutation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                nn.init.normal_(m.weight, mean=0.0, std=sqrt(2. / n))

    def make_layer(self, block, num_of_layer):
        """Return an ``nn.Sequential`` of ``num_of_layer`` fresh ``block()`` instances."""
        return nn.Sequential(*(block() for _ in range(num_of_layer)))

    def forward(self, x):
        """Predict a residual from ``x`` and add ``x`` back.

        Args:
            x: input tensor; the input conv expects 1 channel, so presumably
                (N, 1, H, W) — TODO confirm for unbatched use.

        Returns:
            Tensor with the same shape as ``x``.
        """
        residual = x
        # The in-place ReLU acts on the conv output, not on `x`, so `residual`
        # still holds the untouched input.
        out = self.relu(self.input(x))
        out = self.residual_layer(out)
        out = self.output(out)
        return out + residual

In [5]:
from torch.autograd import Variable

def test_net():
    net = VDSR()
    y = net(Variable(torch.randn(1,1,128,128)))
    print(y.size())

In [6]:
# Smoke test: run a fresh VDSR on a random (1, 1, 128, 128) input and print the output size.
test_net()

torch.Size([1, 1, 128, 128])


In [8]:
from torchsummary import summary
# Layer-by-layer output shapes and parameter counts for a 1x128x128 input.
model = VDSR()
# torchsummary's `summary` defaults to device="cuda": on a GPU machine it would
# feed CUDA inputs to this CPU-resident model and fail. Pin the device to where
# the freshly built model actually lives.
summary(model, (1, 128, 128), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]             576
              ReLU-2         [-1, 64, 128, 128]               0
            Conv2d-3         [-1, 64, 128, 128]          36,864
              ReLU-4         [-1, 64, 128, 128]               0
   Conv_ReLU_Block-5         [-1, 64, 128, 128]               0
            Conv2d-6         [-1, 64, 128, 128]          36,864
              ReLU-7         [-1, 64, 128, 128]               0
   Conv_ReLU_Block-8         [-1, 64, 128, 128]               0
            Conv2d-9         [-1, 64, 128, 128]          36,864
             ReLU-10         [-1, 64, 128, 128]               0
  Conv_ReLU_Block-11         [-1, 64, 128, 128]               0
           Conv2d-12         [-1, 64, 128, 128]          36,864
             ReLU-13         [-1, 64, 128, 128]               0
  Conv_ReLU_Block-14         [-1, 64, 1

In [9]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a per-tensor table of trainable parameter counts and return their sum.

    Parameters whose ``requires_grad`` is False are skipped entirely: they are
    neither listed in the table nor included in the total.

    Args:
        model: an ``nn.Module``; its ``named_parameters()`` are inspected.

    Returns:
        int: total number of trainable scalar parameters.
    """
    trainable = [
        (param_name, tensor.numel())
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for param_name, count in trainable:
        table.add_row([param_name, count])

    total_params = sum(count for _, count in trainable)
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params

# Per-tensor breakdown of trainable parameters; the returned total is the cell's displayed output.
model = VDSR()
count_parameters(model)

+-------------------------------+------------+
|            Modules            | Parameters |
+-------------------------------+------------+
|  residual_layer.0.conv.weight |   36864    |
|  residual_layer.1.conv.weight |   36864    |
|  residual_layer.2.conv.weight |   36864    |
|  residual_layer.3.conv.weight |   36864    |
|  residual_layer.4.conv.weight |   36864    |
|  residual_layer.5.conv.weight |   36864    |
|  residual_layer.6.conv.weight |   36864    |
|  residual_layer.7.conv.weight |   36864    |
|  residual_layer.8.conv.weight |   36864    |
|  residual_layer.9.conv.weight |   36864    |
| residual_layer.10.conv.weight |   36864    |
| residual_layer.11.conv.weight |   36864    |
| residual_layer.12.conv.weight |   36864    |
| residual_layer.13.conv.weight |   36864    |
| residual_layer.14.conv.weight |   36864    |
| residual_layer.15.conv.weight |   36864    |
| residual_layer.16.conv.weight |   36864    |
| residual_layer.17.conv.weight |   36864    |
|          in

664704

In [10]:
# Compare the overall parameter count with the trainable subset
# (parameters with requires_grad=True); the two differ only if something is frozen.
model = VDSR()
total_params = sum(map(torch.Tensor.numel, model.parameters()))
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(
    p.numel() for p in filter(lambda p: p.requires_grad, model.parameters())
)
print(f'{total_trainable_params:,} training parameters.')

664,704 total parameters.
664,704 training parameters.
