In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RES_Block(nn.Module):
    """Multi-branch block of factorised large-kernel convolutions.

    The input is processed by four parallel branches. Each branch applies a
    (k, 1) convolution followed by a (1, k) convolution, with
    k = 15, 13, 11 and 9 respectively; every convolution is followed by
    BatchNorm and ReLU. The branch convolutions are zero-padded so that they
    keep the spatial size of the input, which is what allows the branch
    outputs to be concatenated with the untouched input along the channel
    dimension.

    The concatenated tensor (``in_channels + 4 * out_channels`` channels) is
    then fused by a 3x3 conv, a second 3x3 conv and a final 1x1 conv. The two
    3x3 convolutions are unpadded, so the output is 4 pixels smaller than the
    input in both spatial dimensions.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by every branch and by
            the block as a whole.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()

        # Branch 1: 15x1 followed by 1x15.
        self.split_conv_x1_1 = self._conv_bn_relu(
            in_channels, out_channels, (15, 1), (7, 0))
        self.split_conv_x1_2 = self._conv_bn_relu(
            out_channels, out_channels, (1, 15), (0, 7))

        # Branch 2: 13x1 followed by 1x13.
        self.split_conv_x2_1 = self._conv_bn_relu(
            in_channels, out_channels, (13, 1), (6, 0))
        self.split_conv_x2_2 = self._conv_bn_relu(
            out_channels, out_channels, (1, 13), (0, 6))

        # Branch 3: 11x1 followed by 1x11.
        self.split_conv_x3_1 = self._conv_bn_relu(
            in_channels, out_channels, (11, 1), (5, 0))
        self.split_conv_x3_2 = self._conv_bn_relu(
            out_channels, out_channels, (1, 11), (0, 5))

        # Branch 4: 9x1 followed by 1x9.
        self.split_conv_x4_1 = self._conv_bn_relu(
            in_channels, out_channels, (9, 1), (4, 0))
        self.split_conv_x4_2 = self._conv_bn_relu(
            out_channels, out_channels, (1, 9), (0, 4))

        # Fusion head. The first conv consumes the concatenation of the raw
        # input and the four branch outputs, hence the widened input.
        self.sum_conv_x1 = self._conv_bn_relu(
            in_channels + 4 * out_channels, out_channels, 3, 0)
        self.sum_conv_x2 = self._conv_bn_relu(
            out_channels, out_channels, 3, 0)
        self.sum_conv_x3 = self._conv_bn_relu(
            out_channels, out_channels, 1, 0)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, kernel_size, padding):
        """Build a Conv2d -> BatchNorm2d -> ReLU(inplace) stack."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels,
                      kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the four branches, concatenate with the input and fuse.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)`` with ``H, W >= 5``
                (the two unpadded 3x3 convolutions need at least 5 pixels).

        Returns:
            Tensor of shape ``(N, out_channels, H - 4, W - 4)``.
        """
        branch_1 = self.split_conv_x1_2(self.split_conv_x1_1(x))
        branch_2 = self.split_conv_x2_2(self.split_conv_x2_1(x))
        branch_3 = self.split_conv_x3_2(self.split_conv_x3_1(x))
        branch_4 = self.split_conv_x4_2(self.split_conv_x4_1(x))

        # dim=1 is the channel dimension; all operands share N, H and W
        # because the branch convolutions are padded to preserve size.
        merged = torch.cat([x, branch_1, branch_2, branch_3, branch_4], dim=1)

        fused = self.sum_conv_x1(merged)
        fused = self.sum_conv_x2(fused)
        # Final 1x1 projection (previously sum_conv_x1 was re-applied here
        # and sum_conv_x3 was never used).
        return self.sum_conv_x3(fused)


WC_Block(
  (split_conv_x1_1): Sequential(
    (0): Conv2d(1024, 1024, kernel_size=(15, 1), stride=(1, 1))
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (split_conv_x1_2): Sequential(
    (0): Conv2d(1024, 1024, kernel_size=(15, 1), stride=(1, 1))
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (split_conv_x2_1): Sequential(
    (0): Conv2d(1024, 1024, kernel_size=(1, 15), stride=(1, 1))
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (split_conv_x2_2): Sequential(
    (0): Conv2d(1024, 1024, kernel_size=(1, 15), stride=(1, 1))
    (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (batch_norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
