# In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RES_Block(nn.Module):
    """Multi-scale block built from factorized (k x 1) / (1 x k) convolutions.

    Four parallel branches each apply a (k x 1) conv followed by a (1 x k)
    conv, with k = 15, 13, 11, 9. Every conv is followed by BatchNorm + ReLU
    and is padded so that H and W are preserved. The block input and the four
    branch outputs are concatenated along the channel axis and fused by a
    3x3 -> 3x3 -> 1x1 conv stack, each conv followed by BatchNorm + ReLU.

    NOTE: the final 1x1 conv uses ``padding=1``, so the output is 2 pixels
    larger than the input in each spatial dimension. The stride-1
    ConvTranspose2d kernel sizes in BU_net are sized around this growth, so it
    is intentionally kept.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by every branch and by the block.

    Shape:
        input ``(N, in_channels, H, W)`` -> output
        ``(N, out_channels, H + 2, W + 2)``.
    """

    def __init__(self, in_channels, out_channels):
        super(RES_Block, self).__init__()

        cbr = self._conv_bn_relu

        # Branches: (k x 1) then (1 x k), "same" padding (k // 2 on the long axis).
        self.split_conv_x1_1 = cbr(in_channels, out_channels, (15, 1), (7, 0))
        self.split_conv_x1_2 = cbr(out_channels, out_channels, (1, 15), (0, 7))

        self.split_conv_x2_1 = cbr(in_channels, out_channels, (13, 1), (6, 0))
        self.split_conv_x2_2 = cbr(out_channels, out_channels, (1, 13), (0, 6))

        self.split_conv_x3_1 = cbr(in_channels, out_channels, (11, 1), (5, 0))
        self.split_conv_x3_2 = cbr(out_channels, out_channels, (1, 11), (0, 5))

        self.split_conv_x4_1 = cbr(in_channels, out_channels, (9, 1), (4, 0))
        self.split_conv_x4_2 = cbr(out_channels, out_channels, (1, 9), (0, 4))

        # The concatenation holds the untouched input plus four branch outputs,
        # i.e. in_channels + 4 * out_channels channels. This equals
        # 5 * out_channels only when in_channels == out_channels.
        self.sum_conv_x1 = cbr(in_channels + 4 * out_channels, out_channels, 3, 1)
        self.sum_conv_x2 = cbr(out_channels, out_channels, 3, 1)
        # kernel 1 with padding 1: grows H and W by 2 (see class NOTE).
        self.sum_conv_x3 = cbr(out_channels, out_channels, 1, 1)

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, kernel_size, padding):
        """Return ``Conv2d -> BatchNorm2d -> ReLU(inplace)`` as an ``nn.Sequential``."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(N, in_channels, H, W)``."""
        # The in-place ReLUs act on BatchNorm outputs, never on ``x`` itself,
        # so ``x`` is still the raw input when it is concatenated below.
        branches = [
            self.split_conv_x1_2(self.split_conv_x1_1(x)),
            self.split_conv_x2_2(self.split_conv_x2_1(x)),
            self.split_conv_x3_2(self.split_conv_x3_1(x)),
            self.split_conv_x4_2(self.split_conv_x4_1(x)),
        ]
        out = torch.cat([x, *branches], dim=1)

        out = self.sum_conv_x1(out)
        out = self.sum_conv_x2(out)
        out = self.sum_conv_x3(out)
        return out


class WC_Block(nn.Module):
    """Bottleneck block with two cross-ordered factorized 15-wide convolutions.

    Branch 1 applies a (15 x 1) conv and then a (1 x 15) conv. Branch 2
    applies them in the opposite order. No padding is used, so each branch,
    and therefore the block, shrinks H and W by 14; inputs must be at least
    15x15. Every branch conv is followed by BatchNorm + ReLU. The two branch
    outputs are concatenated on the channel axis and fused by a 3x3 conv
    (padding 1), BatchNorm and ReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of each branch and of the block output.

    Shape:
        input ``(N, in_channels, H, W)`` -> output
        ``(N, out_channels, H - 14, W - 14)``.
    """

    @staticmethod
    def _unit(cin, cout, kernel):
        """Build an unpadded ``Conv2d -> BatchNorm2d -> ReLU(inplace)`` stage."""
        return nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size=kernel),
            nn.BatchNorm2d(cout),
            nn.ReLU(inplace=True),
        )

    def __init__(self, in_channels, out_channels):
        super().__init__()
        tall, wide = (15, 1), (1, 15)

        # Branch 1: vertical kernel first, then horizontal.
        self.split_conv_x1_1 = self._unit(in_channels, out_channels, tall)
        self.split_conv_x1_2 = self._unit(out_channels, out_channels, wide)

        # Branch 2: horizontal kernel first, then vertical.
        self.split_conv_x2_1 = self._unit(in_channels, out_channels, wide)
        self.split_conv_x2_2 = self._unit(out_channels, out_channels, tall)

        # Fusion of the concatenated branches: 2 * out_channels -> out_channels.
        self.conv_sum = nn.Conv2d(2 * out_channels, out_channels, 3, padding=1)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run both branches on ``x`` and fuse their outputs."""
        vert_first = self.split_conv_x1_2(self.split_conv_x1_1(x))
        horiz_first = self.split_conv_x2_2(self.split_conv_x2_1(x))
        merged = torch.cat((vert_first, horiz_first), dim=1)
        return self.relu(self.batch_norm(self.conv_sum(merged)))


def conv(in_channels, out_channels):
    """Build two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    The first conv maps ``in_channels -> out_channels``, the second
    ``out_channels -> out_channels``. Spatial size is preserved.

    Returns:
        An ``nn.Sequential`` of six layers (indices 0-5).
    """
    layers = []
    for stage_in in (in_channels, out_channels):
        layers.extend([
            nn.Conv2d(stage_in, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ])
    return nn.Sequential(*layers)


class BU_net(nn.Module):
    """U-Net-style encoder/decoder with RES_Block skip paths and a WC_Block bottleneck.

    Encoder: four ``conv`` stages (64, 128, 256, 512 channels), each followed
    by 2x2 max-pooling. Then a WC_Block (512 -> 1024) and a 3x3
    conv/BN/ReLU form the bottleneck.

    Decoder: at each level the encoder feature is refined by a RES_Block,
    concatenated with the upsampled decoder feature, and fused by ``conv``.
    Upsampling uses stride-1 ConvTranspose2d layers with kernel sizes
    32, 31, 61, 121. These enlarge the map by a fixed 31, 30, 60, 120 pixels
    rather than by a scale factor. A 1x1 conv and a per-channel sigmoid
    produce the output.

    NOTE: because the upsampling growth is fixed, the skip concatenations
    only line up when H = W = 240. This relies on each RES_Block output being
    2 px larger than its input. For a 240x240 input the output is
    ``(N, n_classes, 242, 242)``; other sizes fail inside ``torch.cat``.

    NOTE(review): at 240x240 the bottleneck feature map is 1x1, so
    training-mode BatchNorm there presumably needs batch size > 1 — confirm.

    Args:
        n_classes: number of output channels (one sigmoid map per class).
        in_channels: channels of the input image (default 3).
    """

    def __init__(self, n_classes, in_channels=3):
        super(BU_net, self).__init__()

        # Encoder stages (two 3x3 conv/BN/ReLU each).
        self.convDown1 = conv(in_channels, 64)
        self.convDown2 = conv(64, 128)
        self.convDown3 = conv(128, 256)
        self.convDown4 = conv(256, 512)
        # Applied to the 1024-channel WC_Block output.
        self.convDown5 = nn.Sequential(
            nn.Conv2d(1024, 1024, 3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(2, stride=2)

        # Decoder fusion: input channels = upsampled channels + skip channels.
        self.convUp4 = conv(1024 + 512, 512)
        self.convUp3 = conv(512 + 256, 256)
        self.convUp2 = conv(256 + 128, 128)
        self.convUp1 = conv(128 + 64, 64)
        self.convUp_fin = nn.Conv2d(64, n_classes, kernel_size=1)

        # Stride-1 transposed convs: each grows H and W by (kernel_size - 1).
        self.upsample1 = nn.ConvTranspose2d(1024, 1024, kernel_size=32, stride=1)
        self.upsample2 = nn.ConvTranspose2d(512, 512, kernel_size=31, stride=1)
        self.upsample3 = nn.ConvTranspose2d(256, 256, kernel_size=61, stride=1)
        self.upsample4 = nn.ConvTranspose2d(128, 128, kernel_size=121, stride=1)

        # Skip-path refinement blocks and the bottleneck block.
        self.RES1 = RES_Block(64, 64)
        self.RES2 = RES_Block(128, 128)
        self.RES3 = RES_Block(256, 256)
        self.RES4 = RES_Block(512, 512)
        self.WC = WC_Block(512, 1024)

        self.sigmoid_layer = nn.Sigmoid()

    def forward(self, x):
        """Segment ``x`` of shape ``(N, in_channels, 240, 240)`` (see class NOTE)."""
        # Encoder: keep each pre-pool feature map for the skip connections.
        conv1 = self.convDown1(x)
        x = self.maxpool(conv1)
        conv2 = self.convDown2(x)
        x = self.maxpool(conv2)
        conv3 = self.convDown3(x)
        x = self.maxpool(conv3)
        conv4 = self.convDown4(x)
        x = self.maxpool(conv4)

        # Bottleneck.
        WC_5 = self.WC(x)
        conv5 = self.convDown5(WC_5)
        x = self.upsample1(conv5)

        # Decoder: refine skip, concatenate (skip first), fuse, upsample.
        RES_4 = self.RES4(conv4)
        x = torch.cat([RES_4, x], dim=1)
        x = self.convUp4(x)
        x = self.upsample2(x)

        RES_3 = self.RES3(conv3)
        x = torch.cat([RES_3, x], dim=1)
        x = self.convUp3(x)
        x = self.upsample3(x)

        RES_2 = self.RES2(conv2)
        x = torch.cat([RES_2, x], dim=1)
        x = self.convUp2(x)
        x = self.upsample4(x)

        RES_1 = self.RES1(conv1)
        x = torch.cat([RES_1, x], dim=1)
        x = self.convUp1(x)
        out = self.convUp_fin(x)

        # Independent per-channel probabilities.
        out = self.sigmoid_layer(out)

        return out