In [0]:
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

class conv_block(nn.Module):
    """Two stacked ``Conv3x3 -> BatchNorm -> ReLU`` stages.

    Channels go ``in_ch -> out_ch -> out_ch``; stride 1 with padding 1 keeps
    the spatial size unchanged.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels of both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()

        stages = []
        # First stage reads in_ch channels, second stage reads out_ch channels.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        # Stored as a single Sequential named ``conv`` so parameter names are
        # ``conv.0.*``, ``conv.1.*``, ``conv.3.*``, ``conv.4.*``.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.conv(x)
 
class up_conv(nn.Module):
    """Upsample by 2 then ``Conv3x3 -> BatchNorm -> ReLU``.

    ``nn.Upsample(scale_factor=2)`` uses its default (nearest) mode, so an
    ``(N, in_ch, H, W)`` input becomes ``(N, out_ch, 2H, 2W)``.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True)
        norm = nn.BatchNorm2d(out_ch)
        act = nn.ReLU(inplace=True)
        # Keep the Sequential layout (up.0 .. up.3) so checkpoint keys match.
        self.up = nn.Sequential(upsample, conv, norm, act)

    def forward(self, x):
        """Upsample ``x`` and refine it with the conv stage."""
        return self.up(x)

class conv_block_nested(nn.Module):
    """Nested convolution unit used by the UNet++ nodes.

    ``Conv3x3(in_ch -> mid_ch) -> BN -> ReLU -> Conv3x3(mid_ch -> out_ch)
    -> BN -> ReLU -> Dropout(p=0.3)``. Padding 1 keeps the spatial size.

    Args:
        in_ch: number of input channels.
        mid_ch: channels produced by the first convolution.
        out_ch: channels produced by the second convolution.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        # Parameter-free modules shared by both stages.
        self.activation = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.3)
        # Stage 1: in_ch -> mid_ch.
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        # Stage 2: mid_ch -> out_ch.
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        """Run both conv/BN/ReLU stages, then dropout."""
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = self.activation(norm(conv(x)))
        return self.dropout(x)

class DUpsampling(nn.Module):
    """Data-dependent upsampling (DUpsampling) layer.

    A 1x1 convolution (``conv_w``, the "W matrix") expands every pixel to
    ``num_class * scale * scale`` channels. Those channels are then
    rearranged into a ``scale x scale`` spatial patch, so an input whose
    ``conv_w`` output is ``(N, num_class*scale**2, H, W)`` becomes
    ``(N, num_class, H*scale, W*scale)``. H and W need not be equal.

    Channel ``(i*scale + j) * num_class + k`` at pixel ``(h, w)`` lands in
    output channel ``k`` at position ``(h*scale + i, w*scale + j)``.

    Args:
        inplanes: number of input channels.
        scale: upsampling factor for both height and width.
        num_class: number of output channels (default 64).
        pad: padding of the 1x1 convolutions (default 0; a non-zero value
            grows H and W of the ``conv_w`` output by ``2*pad``).
    """

    def __init__(self, inplanes, scale, num_class=64, pad=0):
        super(DUpsampling, self).__init__()
        ## W matrix
        self.conv_w = nn.Conv2d(inplanes, num_class * scale * scale, kernel_size=1, padding = pad,bias=False)
        ## P matrix -- not used by forward(); it still owns parameters, so it
        ## is part of state_dict and must stay for checkpoint compatibility.
        self.conv_p = nn.Conv2d(num_class * scale * scale, inplanes, kernel_size=1, padding = pad,bias=False)

        self.scale = scale

    def forward(self, x):
        """Project ``x`` with ``conv_w`` and rearrange channels into space."""
        x = self.conv_w(x)
        N, C, H, W = x.size()
        s = self.scale

        # N, W, H, C
        x_permuted = x.permute(0, 3, 2, 1)

        # N, W, H*scale, C/scale
        x_permuted = x_permuted.contiguous().view(N, W, H * s, C // s)

        # N, H*scale, W, C/scale
        x_permuted = x_permuted.permute(0, 2, 1, 3)

        # N, H*scale, W*scale, C/(scale**2)
        # The row dimension (H*scale) must come first to match the memory
        # layout produced by the permute above; swapping it with W*scale would
        # scramble pixels (and transpose the shape) for non-square maps.
        x_permuted = x_permuted.contiguous().view(N, H * s, W * s, C // (s * s))

        # N, C/(scale**2), H*scale, W*scale
        return x_permuted.permute(0, 3, 1, 2)

### Modified NestedUnet_V2
class Modified_NestedUNet_V2(nn.Module):
    """UNet++-style segmentation network on a pre-trained ResNet-50 encoder.

    Based on UNet++ (https://arxiv.org/pdf/1807.10165.pdf), modified to:
      * take its five multi-scale feature maps from
        ``get_encoder('resnet50', encoder_weights='imagenet')`` (defined in
        another notebook cell, "CELL 3"), and
      * fuse the nested full-resolution skip path (``x0_3``) with a second
        branch that decodes the deepest encoder feature and brings it back to
        full resolution with ``DUpsampling(scale=32)``.

    Cross-level upsampling inside the nested path resizes to the spatial size
    of the feature map it is concatenated with, so the network is not tied to
    224x224 inputs. NOTE(review): the fixed x2 upsamplings and the x32
    DUpsampling still assume the encoder outputs have strides 2/4/8/16/32
    (true for the 224 input documented in forward()), i.e. H and W should be
    divisible by 32 -- confirm against ``get_encoder``.

    Args:
        in_ch: NOTE(review): not used -- ``get_encoder`` is called without it.
        out_ch: number of output channels of the final 1x1 convolution.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(Modified_NestedUNet_V2, self).__init__()

        ### This is a function from CELL 3 that we take the pre-trained weights.
        self.encoder = get_encoder('resnet50', encoder_weights='imagenet')

        n1 = 64
        # [64, 256, 512, 1024, 2048]: channel widths of encoder outputs
        # e[4], e[3], e[2], e[1], e[0] respectively (see forward()).
        filters = [n1, n1 * 4, n1 * 8, n1 * 16, n1 * 32]

        # Not used by forward(); parameter-free.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Fixed x2 bilinear upsampling between adjacent encoder levels.
        self.Up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Nested (UNet++) nodes; trailing comments give the input channel count.
        self.conv0_1 = conv_block_nested(filters[0] + filters[1], filters[0], filters[0])  # 64+256 = 320
        self.conv0_2 = conv_block_nested(filters[0]*2 + filters[1] + filters[2], filters[0], filters[0])  # 64*2+256+512 = 896
        self.conv0_3 = conv_block_nested(filters[0]*3 + filters[1] + filters[2] + filters[3], filters[0], filters[0])  # 64*3+256+512+1024 = 1984

        self.conv1_1 = conv_block_nested(filters[1] + filters[2], filters[1], filters[1])  # 256+512 = 768
        self.conv1_2 = conv_block_nested(filters[1]*2 + filters[2] + filters[3], filters[1], filters[1])  # 256*2+512+1024 = 2048
        self.conv2_1 = conv_block_nested(filters[2] + filters[3], filters[2], filters[2])  # 512+1024 = 1536

        # Deep decoder branch: 2048 -> 1024 (1x1) -> 512 -> 256 -> 64 (3x3 each).
        self.conv1 = nn.Conv2d(filters[4], filters[3], kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(filters[3])
        self.relu = nn.ReLU()

        self.conv2 = nn.Conv2d(filters[3], filters[2], kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(filters[2])
        self.dropout2 = nn.Dropout(0.5)

        self.conv3 = nn.Conv2d(filters[2], filters[1], kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(filters[1])
        self.dropout3 = nn.Dropout(0.5)

        self.conv4 = nn.Conv2d(filters[1], filters[0], kernel_size=3, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(filters[0])
        self.dropout4 = nn.Dropout(0.5)

        # x32 learned upsampling of the 64-channel deep feature.
        self.dupsample4 = DUpsampling(filters[0], 32, num_class=64)

        # Fusion of the two full-resolution branches: 64 + 64 -> 64.
        self.conv6 = nn.Conv2d(filters[0]*2, filters[0], kernel_size=3, padding=1, bias=False)
        self.bn6 = nn.BatchNorm2d(filters[0])
        self.dropout6 = nn.Dropout(0.5)

        self.final = nn.Conv2d(filters[0], out_ch, kernel_size=1)

    @staticmethod
    def _resize_to(t, ref):
        """Bilinearly resize ``t`` (align_corners=True) to ``ref``'s spatial size."""
        return F.interpolate(t, size=ref.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        """Return the ``out_ch``-channel map (no activation on the final conv)."""
        # self.encoder(x) runs the ResNet encoder from CELL 3 and returns 5 feature
        # maps, deepest first. For a 1x3x224x224 input the debug prints were:
        #   e[0]: 1 x 2048 x 7 x 7
        #   e[1]: 1 x 1024 x 14 x 14
        #   e[2]: 1 x 512 x 28 x 28
        #   e[3]: 1 x 256 x 56 x 56
        #   e[4]: 1 x 64 x 112 x 112   (upsampled x2 below -> 64 x 224 x 224)
        e = self.encoder(x)

        x0_0 = self.Up(e[4])
        x1_0 = e[3]
        x2_0 = e[2]
        x3_0 = e[1]
        x4_0 = e[0]

        # Nested skip path. Inputs from other levels are resized to the node's
        # own level: x0_* nodes to x0_0's size, x1_* nodes to x1_0's size.
        x0_1 = self.conv0_1(torch.cat([x0_0, self._resize_to(x1_0, x0_0)], 1))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.Up(x2_0)], 1))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.Up(x3_0)], 1))
        x1_2 = self.conv1_2(torch.cat([
            x1_0, x1_1,
            self._resize_to(x2_1, x1_0),
            self._resize_to(x3_0, x1_0),
        ], 1))
        x0_2 = self.conv0_2(torch.cat([
            x0_0, x0_1,
            self._resize_to(x1_1, x0_0),
            self._resize_to(x2_0, x0_0),
        ], 1))
        x0_3 = self.conv0_3(torch.cat([
            x0_0, x0_1, x0_2,
            self._resize_to(x1_2, x0_0),
            self._resize_to(x3_0, x0_0),
            self._resize_to(x2_1, x0_0),
        ], 1))
        # x0_3: N x 64 x (spatial size of x0_0), e.g. 224 x 224

        # Deep branch on the 2048-channel encoder output.
        x3_1 = self.conv1(x4_0)
        x3_1 = self.bn1(x3_1)
        x3_1 = self.relu(x3_1)

        x3_1 = self.conv2(x3_1)
        x3_1 = self.bn2(x3_1)
        x3_1 = self.relu(x3_1)
        x3_1 = self.dropout2(x3_1)

        x3_1 = self.conv3(x3_1)
        x3_1 = self.bn3(x3_1)
        x3_1 = self.relu(x3_1)
        x3_1 = self.dropout3(x3_1)

        x3_1 = self.conv4(x3_1)
        x3_1 = self.bn4(x3_1)
        x3_1 = self.relu(x3_1)
        x3_1 = self.dropout4(x3_1)

        # x32 learned upsampling: 7x7 -> 224x224 for a 224 input, matching x0_3.
        x3_1_up = self.dupsample4(x3_1)

        # Fuse both full-resolution branches (64 + 64 channels).
        x0_4_cat = torch.cat((x3_1_up, x0_3), dim=1)

        x3_1_final = self.conv6(x0_4_cat)
        x3_1_final = self.bn6(x3_1_final)
        x3_1_final = self.relu(x3_1_final)
        x3_1_final = self.dropout6(x3_1_final)

        output = self.final(x3_1_final)

        return output
