In [1]:
import torch
import torch.nn as nn

class Unit(nn.Module):
    """Residual block with a squeeze-and-excitation (SE) style channel gate.

    Pipeline (default widths in parentheses):
        conv1 -> LeakyReLU -> conv2 -> LeakyReLU      3x3, channels -> channels (64)
        conv3                                          3x3, channels -> expanded_channels (256)
        avgpool -> conv_down -> conv_up -> sigmoid     per-channel gate of size expanded_channels
        (features * gate) -> trans -> LeakyReLU        1x1, expanded_channels -> channels
        + identity skip connection

    All convolutions are bias-free with stride 1 and "same" padding, so the
    output has exactly the same shape as the input.

    Args:
        channels: Number of input/output feature channels (default 64).
        expanded_channels: Width of the gated feature map produced by conv3
            (default 256).
        reduced_channels: Bottleneck width of the SE gate (default 16).
        negative_slope: Negative slope of the LeakyReLU activations
            (default 0.2).
    """

    def __init__(self, channels=64, expanded_channels=256, reduced_channels=16, negative_slope=0.2):
        super().__init__()

        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv3 = nn.Conv2d(channels, expanded_channels, kernel_size=3, stride=1, padding=1, bias=False)

        # SE gate: squeeze each channel to a single value, then project down and back up.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.conv_down = nn.Conv2d(expanded_channels, reduced_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.conv_up = nn.Conv2d(reduced_channels, expanded_channels, kernel_size=1, stride=1, padding=0, bias=False)

        # 1x1 projection back to `channels` so the residual addition is shape-compatible.
        self.trans = nn.Conv2d(expanded_channels, channels, kernel_size=1, stride=1, padding=0, bias=False)

        # inplace=True is safe here: the activation is only ever applied to fresh
        # conv outputs, never to the block input that is kept as the residual.
        self.relu = nn.LeakyReLU(negative_slope, inplace=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Apply the block; returns a tensor with the same shape as ``x``."""
        residual = x

        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.conv3(x)

        # NOTE(review): there is no activation between conv_down and conv_up, so the
        # two 1x1 convs compose into a single rank-limited linear map. The usual SE
        # design puts a ReLU here -- confirm the omission is intentional.
        x_se = self.avgpool(x)
        x_se = self.conv_down(x_se)
        x_se = self.sigmoid(self.conv_up(x_se))

        # Broadcast the (N, C, 1, 1) gate over the spatial dimensions.
        out = x * x_se
        out = self.relu(self.trans(out))

        return out + residual
    
    
class SESR(nn.Module):
    """Progressive (x2 then x4) single-channel super-resolution network.

    The network runs two 2x stages in sequence. Each stage has two branches:
      * image branch: a learnable 1->1 channel ``ConvTranspose2d`` that
        upsamples the stage's image input by 2x;
      * feature branch: a 1->64 conv + LeakyReLU, one ``Unit`` applied
        ``num_recursions`` times (the same instance each time, so its weights
        are shared across those iterations), a 64->64 2x ``ConvTranspose2d``,
        and a 64->1 conv that produces a residual image.
    A stage's output is ``residual + upsampled image``.

    Layer attribute names are kept stable so existing ``state_dict`` keys
    still load.

    Args:
        num_recursions: Number of times each stage's ``Unit`` is applied.
            Defaults to 4. Must be >= 0.

    Input is ``(N, 1, H, W)``. ``forward`` returns ``(HR_x2, HR_x4)`` with shapes
    ``(N, 1, 2H, 2W)`` and ``(N, 1, 4H, 4W)``.
    """

    def __init__(self, num_recursions=4):
        super().__init__()
        if num_recursions < 0:
            raise ValueError(f"num_recursions must be >= 0, got {num_recursions}")
        self.num_recursions = num_recursions

        # Separate layers must be defined for x2 and x4: using common layers
        # would make the parameters shared between the two scales.
        self.conv_input_x2 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.unit_x2 = Unit()
        self.upsample_x2 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1, bias=False)
        self.upsample_outside_of_unit_x2 = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2, padding=1, bias=False)
        self.reconv_x2 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False)

        self.conv_input_x4 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.unit_x4 = Unit()
        self.upsample_x4 = nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1, bias=False)
        self.upsample_outside_of_unit_x4 = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2, padding=1, bias=False)
        self.reconv_x4 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=False)

        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def _stage(self, feat_in, img_in, conv_input, unit, upsample, upsample_outside, reconv):
        """Run one 2x stage; return ``(residual, residual + 2x-upsampled img_in)``."""
        upsampled_img = upsample_outside(img_in)

        feat = self.relu(conv_input(feat_in))
        for _ in range(self.num_recursions):
            feat = unit(feat)
        residual = reconv(upsample(feat))

        return residual, residual + upsampled_img

    def forward(self, x):
        """Return the ``(HR_x2, HR_x4)`` reconstructions of the LR input ``x``."""
        # LR -> HR_x2
        reconv_x2, HR_x2 = self._stage(
            x, x,
            self.conv_input_x2, self.unit_x2, self.upsample_x2,
            self.upsample_outside_of_unit_x2, self.reconv_x2,
        )

        # HR_x2 -> HR_x4: the image branch upsamples HR_x2.
        # NOTE(review): the feature branch is fed the x2 residual (reconv_x2) rather
        # than HR_x2 -- confirm this is intended.
        _, HR_x4 = self._stage(
            reconv_x2, HR_x2,
            self.conv_input_x4, self.unit_x4, self.upsample_x4,
            self.upsample_outside_of_unit_x4, self.reconv_x4,
        )

        return HR_x2, HR_x4

In [2]:
# Report the size of a freshly built SESR: first every parameter, then only
# the ones that will receive gradients (requires_grad=True).
model = SESR()
all_params = list(model.parameters())
trainable_params = [param for param in all_params if param.requires_grad]

total_params = sum(param.numel() for param in all_params)
print(f'{total_params:,} total parameters.')

total_trainable_params = sum(param.numel() for param in trainable_params)
print(f'{total_trainable_params:,} training parameters.')

624,928 total parameters.
624,928 training parameters.


In [3]:
from torchsummary import summary

# torchsummary's `device` argument defaults to "cuda". On a CUDA-capable machine
# that builds CUDA inputs while this freshly built model lives on the CPU, and the
# forward pass fails with a device mismatch. Pin it to the model's actual device.
model = SESR()
summary(model, (1, 32, 32), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1            [-1, 1, 64, 64]              16
            Conv2d-2           [-1, 64, 32, 32]             576
         LeakyReLU-3           [-1, 64, 32, 32]               0
            Conv2d-4           [-1, 64, 32, 32]          36,864
         LeakyReLU-5           [-1, 64, 32, 32]               0
            Conv2d-6           [-1, 64, 32, 32]          36,864
         LeakyReLU-7           [-1, 64, 32, 32]               0
            Conv2d-8          [-1, 256, 32, 32]         147,456
 AdaptiveAvgPool2d-9            [-1, 256, 1, 1]               0
           Conv2d-10             [-1, 16, 1, 1]           4,096
           Conv2d-11            [-1, 256, 1, 1]           4,096
          Sigmoid-12            [-1, 256, 1, 1]               0
           Conv2d-13           [-1, 64, 32, 32]          16,384
        LeakyReLU-14           [-1, 64,