In [1]:
import math
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models.resnet import ResNet, Bottleneck, BasicBlock
from torch.nn import functional as F
from Attention import *
# from ResNets import resnet34

# ResNet-34 feature extractor

In [7]:
class custom_ResNet34(ResNet):
    """ResNet-34 backbone that returns the output of each residual stage.

    Built on torchvision's ``ResNet`` with ``BasicBlock`` and the
    ``[3, 4, 6, 3]`` stage layout. ``forward`` stops after ``layer4``; no
    pooling or classification is applied here.
    """

    def __init__(self):
        super().__init__(BasicBlock, [3, 4, 6, 3])

    def forward(self, x):
        """Run the stem and the four stages, collecting every stage output.

        Args:
            x: input tensor fed to ``conv1``.

        Returns:
            tuple: ``(layer1, layer2, layer3, layer4)`` outputs, in that order.
        """
        # Stem: conv -> norm -> ReLU -> max-pool.
        feat = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        # Each stage consumes the previous stage's output.
        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feat = stage(feat)
            stage_outputs.append(feat)

        return tuple(stage_outputs)

# Dense Attention Network

In [8]:
class GroupNorm_32(nn.GroupNorm):
    """``nn.GroupNorm`` whose first argument is the channel count.

    ``nn.GroupNorm`` takes ``(num_groups, num_channels)``; this wrapper swaps
    the order and defaults to 32 groups, so it can be passed anywhere a
    ``norm_layer(channels)`` factory is expected.

    Args:
        in_channels: number of channels to normalise.
        num_groups: number of groups (default 32).
    """

    def __init__(self,in_channels,num_groups=32):
        super().__init__(num_groups=num_groups, num_channels=in_channels)

In [9]:
class FPN_Refine(nn.Module):
    """Decoder: FPN top-down path, per-level dual attention, refinement, DenseASPP.

    ``PAM_Module``, ``CAM_Module``, ``RefineConv`` and ``DenseASPP`` are not
    defined in this notebook; presumably they come from
    ``from Attention import *`` (TODO confirm).

    Args:
        _norm_layer: factory called as ``_norm_layer(channels)`` to build each
            normalisation layer. It is also handed to ``PAM_Module``,
            ``RefineConv`` and ``DenseASPP``.
    """

    def __init__(self,_norm_layer=nn.BatchNorm2d):
        super().__init__()

        def conv_unit(c_in, c_out, k):
            # conv -> norm -> PReLU; padding k // 2 is 0 for 1x1 and 1 for 3x3.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2),
                _norm_layer(c_out),
                nn.PReLU(),
            ]

        def lateral(c_in):
            # 1x1 projection of a backbone map to 128 channels.
            return nn.Sequential(*conv_unit(c_in, 128, 1))

        def combine_block():
            # Two 3x3 units applied to a (level, fused) concatenation.
            return nn.Sequential(*conv_unit(256, 128, 3), *conv_unit(128, 128, 3))

        def head():
            # Single-channel 1x1 prediction layer.
            return nn.Conv2d(128, 1, kernel_size=1)

        # Lateral projections, registered deepest level first.
        self.down4 = lateral(512)
        self.down3 = lateral(256)
        self.down2 = lateral(128)
        self.down1 = lateral(64)

        # MLF: fuses the four concatenated pyramid levels (4 * 128 channels).
        self.fuse1 = nn.Sequential(
            *conv_unit(512, 128, 1),
            *conv_unit(128, 128, 3),
            *conv_unit(128, 128, 3),
        )

        # Fused attention: mixes the four refined maps before DenseASPP.
        self.fuse2 = nn.Sequential(*conv_unit(512, 512, 1))

        # Dual attention layers, one (combine, PAM, CAM) triple per level.
        self.combine_1_1 = combine_block()
        self.pam_attention_1_1 = PAM_Module(128, norm_layer=_norm_layer)
        self.cam_attention_1_1 = CAM_Module(128)

        self.combine_1_2 = combine_block()
        self.pam_attention_1_2 = PAM_Module(128, norm_layer=_norm_layer)
        self.cam_attention_1_2 = CAM_Module(128)

        self.combine_1_3 = combine_block()
        self.pam_attention_1_3 = PAM_Module(128, norm_layer=_norm_layer)
        self.cam_attention_1_3 = CAM_Module(128)

        self.combine_1_4 = combine_block()
        self.pam_attention_1_4 = PAM_Module(128, norm_layer=_norm_layer)
        self.cam_attention_1_4 = CAM_Module(128)

        # Refinement layers
        self.refine4 = RefineConv(256, 128, _norm_layer)
        self.refine3 = RefineConv(256, 128, _norm_layer)
        self.refine2 = RefineConv(256, 128, _norm_layer)
        self.refine1 = RefineConv(256, 128, _norm_layer)

        # DenseASPP module followed by a 1x1 reduction from 768 to 128 channels.
        self.dense_aspp_layers = nn.Sequential(
            DenseASPP(512, 64, (3, 6, 12, 18), _norm_layer),
            *conv_unit(768, 128, 1),
        )

        # Prediction layers on the raw pyramid levels.
        self.predict4 = head()
        self.predict3 = head()
        self.predict2 = head()
        self.predict1 = head()

        # Prediction layers on the refined levels.
        self.predict4_2 = head()
        self.predict3_2 = head()
        self.predict2_2 = head()
        self.predict1_2 = head()

        # Main prediction layer on the DenseASPP output.
        self.predict = head()

        # Re-initialise every conv weight with std sqrt(2 / (kh * kw * out_channels))
        # and reset norm layers to weight 1, bias 0. Conv biases are left untouched.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                k_h, k_w = module.kernel_size
                fan = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    @staticmethod
    def _resize(tensor, hw):
        """Bilinearly resize ``tensor`` to spatial size ``hw`` (align_corners=True)."""
        return F.interpolate(tensor, size=hw, mode="bilinear", align_corners=True)

    @staticmethod
    def _attend(level, fused, combine, pam, cam, refine):
        """Run one level through combine -> (PAM + CAM) -> refine."""
        mixed = combine(torch.cat((level, fused), 1))
        pam_out = pam(mixed, fused)
        cam_out = cam(mixed, fused)
        return refine(torch.cat((level, (pam_out + cam_out)), dim=1))

    def forward(self, x, size):
        """Decode four backbone feature maps into nine prediction maps.

        Args:
            x: sequence unpacked as ``(layer1, layer2, layer3, layer4)``; the
                lateral convs expect 64, 128, 256 and 512 channels respectively.
            size: spatial size every returned map is interpolated to.

        Returns:
            tuple: ``(predict1, predict2, predict3, predict4, predict1_2,
            predict2_2, predict3_2, predict4_2, predict)``.
        """
        # Bottom up network
        c1, c2, c3, c4 = x

        # Top down network (FPN): each level adds the resized coarser level.
        p4 = self.down4(c4)
        p3 = torch.add(self._resize(p4, c3.size()[2:]), self.down3(c3))
        p2 = torch.add(self._resize(p3, c2.size()[2:]), self.down2(c2))
        p1 = torch.add(self._resize(p2, c1.size()[2:]), self.down1(c1))

        # Bring every level to layer1's resolution so they can be concatenated.
        base_hw = c1.size()[2:]
        p4 = self._resize(p4, base_hw)
        p3 = self._resize(p3, base_hw)
        p2 = self._resize(p2, base_hw)

        # Deep supervision of top down network
        coarse4 = self.predict4(p4)
        coarse3 = self.predict3(p3)
        coarse2 = self.predict2(p2)
        coarse1 = self.predict1(p1)

        fused = self.fuse1(torch.cat((p4, p3, p2, p1), 1))

        # Attention layers
        r4 = self._attend(p4, fused, self.combine_1_4,
                          self.pam_attention_1_4, self.cam_attention_1_4, self.refine4)
        r3 = self._attend(p3, fused, self.combine_1_3,
                          self.pam_attention_1_3, self.cam_attention_1_3, self.refine3)
        r2 = self._attend(p2, fused, self.combine_1_2,
                          self.pam_attention_1_2, self.cam_attention_1_2, self.refine2)
        r1 = self._attend(p1, fused, self.combine_1_1,
                          self.pam_attention_1_1, self.cam_attention_1_1, self.refine1)

        # Fuse refined attention layers for Dense ASPP layer
        dense = self.dense_aspp_layers(self.fuse2(torch.cat((r1, r2, r3, r4), 1)))

        # Main prediction
        main = self.predict(dense)

        # Deep supervision layers
        fine4 = self.predict4_2(r4)
        fine3 = self.predict3_2(r3)
        fine2 = self.predict2_2(r2)
        fine1 = self.predict1_2(r1)

        # Upsample every map to the requested size, in return order.
        ordered = (coarse1, coarse2, coarse3, coarse4, fine1, fine2, fine3, fine4, main)
        return tuple(self._resize(pred, size) for pred in ordered)

In [10]:
class DANet(nn.Module):
    """Full model: ImageNet-pretrained ResNet-34 encoder plus ``FPN_Refine`` decoder.

    Args:
        _norm_layer: normalisation-layer factory used by the decoder. The
            encoder is unaffected: it is built by ``custom_ResNet34()`` and
            loaded from torchvision's pretrained ``resnet34`` weights.
    """

    def __init__(self,_norm_layer=nn.BatchNorm2d):
        super(DANet, self).__init__()

        # Encoder group
        self.encoder = custom_ResNet34()
        # NOTE(review): `pretrained=True` is the legacy torchvision flag; newer
        # releases prefer `weights=...`. Kept so older installs still work.
        self.encoder.load_state_dict(models.resnet34(pretrained=True).state_dict())
        # Replace the classifier only after the pretrained weights are loaded;
        # the encoder's forward never calls `fc`.
        self.encoder.fc = nn.Identity()

        # Decoder group. Pass the requested norm layer through; previously it
        # was dropped and the decoder always used its own default.
        self.decoder = FPN_Refine(_norm_layer)

    def forward(self, x):
        """Encode ``x`` and decode to nine maps at the input's spatial size.

        Args:
            x: input batch; its last two dimensions give the output size.

        Returns:
            tuple: ``(predict1, predict2, predict3, predict4, predict1_2,
            predict2_2, predict3_2, predict4_2, predict)`` from the decoder.
        """
        # Bottom up network
        layer1, layer2, layer3, layer4 = self.encoder(x)
        layers = [layer1, layer2, layer3, layer4]
        return self.decoder(layers, x.size()[2:])

In [11]:
if __name__ == "__main__":
    # Smoke test: build DANet, report its parameter count, and run one random
    # 256x256 batch through it to check the size of every output.
    model = DANet()
    pytorch_total_params = sum(map(torch.numel, model.parameters()))
    print(pytorch_total_params)

    # Test function
    x = torch.randn(1, 3, 256, 256)
    outputs = model(x)
    for output in outputs:
        print(output.shape)

    # Dump the module tree for inspection.
    print(model)

26182259
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
DANet(
  (encoder): custom_ResNet34(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNo

)


In [10]:
class DANet_R(nn.Module):
    """Encoder + FPN + dual-attention refinement, without a DenseASPP stage.

    The main prediction is a 1x1 conv over the four concatenated refined maps.

    NOTE(review): ``resnet34`` is not imported in the visible cells (the
    ``from ResNets import resnet34`` line is commented out); confirm it is in
    scope. ``PAM_Module``, ``CAM_Module`` and ``RefineConv`` presumably come
    from ``from Attention import *`` (TODO confirm).

    Args:
        _norm_layer: factory called as ``_norm_layer(channels)``; also passed
            to the encoder, ``PAM_Module`` and ``RefineConv``.
        _dilated: forwarded to ``resnet34`` as ``dilated``.
    """

    def __init__(self,_norm_layer=nn.BatchNorm2d,_dilated=True):
        super().__init__()
        self.encoder = resnet34(norm_layer=_norm_layer, dilated=_dilated)

        def conv_unit(c_in, c_out, k):
            # conv -> norm -> PReLU; padding k // 2 is 0 for 1x1 and 1 for 3x3.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=k, padding=k // 2),
                _norm_layer(c_out),
                nn.PReLU(),
            ]

        def lateral(c_in):
            # 1x1 projection of a backbone map to 128 channels.
            return nn.Sequential(*conv_unit(c_in, 128, 1))

        def combine_block():
            # Two 3x3 units applied to a (level, fused) concatenation.
            return nn.Sequential(*conv_unit(256, 128, 3), *conv_unit(128, 128, 3))

        def head():
            # Single-channel 1x1 prediction layer.
            return nn.Conv2d(128, 1, kernel_size=1)

        # Lateral projections, registered deepest level first.
        self.down4 = lateral(512)
        self.down3 = lateral(256)
        self.down2 = lateral(128)
        self.down1 = lateral(64)

        # MLF: fuses the four concatenated pyramid levels (4 * 128 channels).
        self.fuse1 = nn.Sequential(
            *conv_unit(512, 128, 1),
            *conv_unit(128, 128, 3),
            *conv_unit(128, 128, 3),
        )

        # Dual attention layers, one (combine, PAM, CAM) triple per level.
        self.combine_1_1 = combine_block()
        self.pam_attention_1_1 = PAM_Module(128, norm_layer=_norm_layer)
        self.cam_attention_1_1 = CAM_Module(128)

        self.combine_1_2 = combine_block()
        self.pam_attention_1_2 = PAM_Module(128, norm_layer=_norm_layer)
        self.cam_attention_1_2 = CAM_Module(128)

        self.combine_1_3 = combine_block()
        self.pam_attention_1_3 = PAM_Module(128, norm_layer=_norm_layer)
        self.cam_attention_1_3 = CAM_Module(128)

        self.combine_1_4 = combine_block()
        self.pam_attention_1_4 = PAM_Module(128, norm_layer=_norm_layer)
        self.cam_attention_1_4 = CAM_Module(128)

        # Refinement layers
        self.refine4 = RefineConv(256, 128, _norm_layer)
        self.refine3 = RefineConv(256, 128, _norm_layer)
        self.refine2 = RefineConv(256, 128, _norm_layer)
        self.refine1 = RefineConv(256, 128, _norm_layer)

        # Prediction layers on the raw pyramid levels.
        self.predict4 = head()
        self.predict3 = head()
        self.predict2 = head()
        self.predict1 = head()

        # Prediction layers on the refined levels.
        self.predict4_2 = head()
        self.predict3_2 = head()
        self.predict2_2 = head()
        self.predict1_2 = head()

        # Main prediction over the four concatenated refined maps.
        self.predict = nn.Conv2d(512, 1, kernel_size=1)

        # Re-initialise every conv weight with std sqrt(2 / (kh * kw * out_channels))
        # and reset norm layers to weight 1, bias 0.
        # NOTE(review): self.modules() includes the encoder, so its conv and norm
        # weights are overwritten here too - confirm that is intended.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                k_h, k_w = module.kernel_size
                fan = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    @staticmethod
    def _resize(tensor, hw):
        """Bilinearly resize ``tensor`` to spatial size ``hw`` (align_corners=True)."""
        return F.interpolate(tensor, size=hw, mode="bilinear", align_corners=True)

    @staticmethod
    def _attend(level, fused, combine, pam, cam, refine):
        """Run one level through combine -> (PAM + CAM) -> refine."""
        mixed = combine(torch.cat((level, fused), 1))
        pam_out = pam(mixed, fused)
        cam_out = cam(mixed, fused)
        return refine(torch.cat((level, (pam_out + cam_out)), dim=1))

    def forward(self, x):
        """Produce nine prediction maps at the spatial size of ``x``.

        Args:
            x: input batch; the encoder is expected to return five feature
                maps, of which the first is not used here.

        Returns:
            tuple: ``(predict1, predict2, predict3, predict4, predict1_2,
            predict2_2, predict3_2, predict4_2, predict)``.
        """
        # Bottom up network (the first encoder output is discarded).
        _, c1, c2, c3, c4 = self.encoder(x)

        # Top down network (FPN): each level adds the resized coarser level.
        p4 = self.down4(c4)
        p3 = torch.add(self._resize(p4, c3.size()[2:]), self.down3(c3))
        p2 = torch.add(self._resize(p3, c2.size()[2:]), self.down2(c2))
        p1 = torch.add(self._resize(p2, c1.size()[2:]), self.down1(c1))

        # Bring every level to layer1's resolution so they can be concatenated.
        base_hw = c1.size()[2:]
        p4 = self._resize(p4, base_hw)
        p3 = self._resize(p3, base_hw)
        p2 = self._resize(p2, base_hw)

        # Deep supervision of top down network
        coarse4 = self.predict4(p4)
        coarse3 = self.predict3(p3)
        coarse2 = self.predict2(p2)
        coarse1 = self.predict1(p1)

        fused = self.fuse1(torch.cat((p4, p3, p2, p1), 1))

        # Attention layers
        r4 = self._attend(p4, fused, self.combine_1_4,
                          self.pam_attention_1_4, self.cam_attention_1_4, self.refine4)
        r3 = self._attend(p3, fused, self.combine_1_3,
                          self.pam_attention_1_3, self.cam_attention_1_3, self.refine3)
        r2 = self._attend(p2, fused, self.combine_1_2,
                          self.pam_attention_1_2, self.cam_attention_1_2, self.refine2)
        r1 = self._attend(p1, fused, self.combine_1_1,
                          self.pam_attention_1_1, self.cam_attention_1_1, self.refine1)

        # Main prediction
        main = self.predict(torch.cat((r1, r2, r3, r4), 1))

        # Deep supervision layers
        fine4 = self.predict4_2(r4)
        fine3 = self.predict3_2(r3)
        fine2 = self.predict2_2(r2)
        fine1 = self.predict1_2(r1)

        # Upsample every map to the input resolution, in return order.
        out_hw = x.size()[2:]
        ordered = (coarse1, coarse2, coarse3, coarse4, fine1, fine2, fine3, fine4, main)
        return tuple(self._resize(pred, out_hw) for pred in ordered)

In [9]:
class DANet_Base(nn.Module):
    """Baseline: encoder plus FPN top-down path with one prediction per level.

    NOTE(review): ``resnet34`` is not imported in the visible cells (the
    ``from ResNets import resnet34`` line is commented out); confirm it is in
    scope.

    Args:
        _norm_layer: factory called as ``_norm_layer(channels)``; also passed
            to the encoder.
        _dilated: forwarded to ``resnet34`` as ``dilated``.
    """

    def __init__(self,_norm_layer=nn.BatchNorm2d,_dilated=True):
        super().__init__()
        self.encoder = resnet34(norm_layer=_norm_layer, dilated=_dilated)

        def lateral(c_in):
            # 1x1 conv -> norm -> PReLU projection to 128 channels.
            return nn.Sequential(
                nn.Conv2d(c_in, 128, kernel_size=1), _norm_layer(128), nn.PReLU()
            )

        def head():
            # Single-channel 1x1 prediction layer.
            return nn.Conv2d(128, 1, kernel_size=1)

        # down layers, registered deepest level first
        self.down4 = lateral(512)
        self.down3 = lateral(256)
        self.down2 = lateral(128)
        self.down1 = lateral(64)

        # Prediction layers
        self.predict4 = head()
        self.predict3 = head()
        self.predict2 = head()
        self.predict1 = head()

        # Re-initialise every conv weight with std sqrt(2 / (kh * kw * out_channels))
        # and reset norm layers to weight 1, bias 0.
        # NOTE(review): self.modules() includes the encoder, so its conv and norm
        # weights are overwritten here too - confirm that is intended.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                k_h, k_w = module.kernel_size
                fan = k_h * k_w * module.out_channels
                module.weight.data.normal_(0, math.sqrt(2. / fan))
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    @staticmethod
    def _resize(tensor, hw):
        """Bilinearly resize ``tensor`` to spatial size ``hw`` (align_corners=True)."""
        return F.interpolate(tensor, size=hw, mode="bilinear", align_corners=True)

    def forward(self, x):
        """Produce four prediction maps at the spatial size of ``x``.

        Args:
            x: input batch; the encoder is expected to return five feature
                maps, of which the first is not used here.

        Returns:
            tuple: ``(predict1, predict2, predict3, predict4)``.
        """
        # Bottom up network (the first encoder output is discarded).
        _, c1, c2, c3, c4 = self.encoder(x)

        # Top down network (FPN): each level adds the resized coarser level.
        p4 = self.down4(c4)
        p3 = torch.add(self._resize(p4, c3.size()[2:]), self.down3(c3))
        p2 = torch.add(self._resize(p3, c2.size()[2:]), self.down2(c2))
        p1 = torch.add(self._resize(p2, c1.size()[2:]), self.down1(c1))

        # Bring the coarser levels to layer1's resolution.
        base_hw = c1.size()[2:]
        p4 = self._resize(p4, base_hw)
        p3 = self._resize(p3, base_hw)
        p2 = self._resize(p2, base_hw)

        # Deep supervision of top down network
        pred4 = self.predict4(p4)
        pred3 = self.predict3(p3)
        pred2 = self.predict2(p2)
        pred1 = self.predict1(p1)

        # Upsample every map to the input resolution, in return order.
        out_hw = x.size()[2:]
        return tuple(self._resize(pred, out_hw) for pred in (pred1, pred2, pred3, pred4))


In [12]:
class DANet_DASPP(nn.Module):
    """Encoder + FPN top-down path with a DenseASPP head on the finest level.

    NOTE(review): ``resnet34`` is not imported in the visible cells (the
    ``from ResNets import resnet34`` line is commented out); confirm it is in
    scope. ``DenseASPP`` presumably comes from ``from Attention import *``
    (TODO confirm).

    Args:
        _norm_layer: factory called as ``_norm_layer(channels)``; also passed
            to the encoder and to ``DenseASPP``.
        _dilated: forwarded to ``resnet34`` as ``dilated``.
    """

    def __init__(self,_norm_layer=nn.BatchNorm2d,_dilated=False):
        super(DANet_DASPP, self).__init__()
        self.encoder = resnet34(norm_layer=_norm_layer,dilated=_dilated)

        # Lateral 1x1 projections of each backbone level to 128 channels.
        self.down4 = nn.Sequential(
            nn.Conv2d(512, 128, kernel_size=1), _norm_layer(128), nn.PReLU()
        )
        self.down3 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=1), _norm_layer(128), nn.PReLU()
        )
        self.down2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=1), _norm_layer(128), nn.PReLU()
        )
        self.down1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=1), _norm_layer(128), nn.PReLU()
        )

        # Prediction layers. These were previously constructed a second time
        # after the DenseASPP block; the second set only replaced these and
        # discarded the first, so the heads are now built once, here.
        self.predict4 = nn.Conv2d(128, 1, kernel_size=1)
        self.predict3 = nn.Conv2d(128, 1, kernel_size=1)
        self.predict2 = nn.Conv2d(128, 1, kernel_size=1)
        self.predict1 = nn.Conv2d(128, 1, kernel_size=1)

        # DenseASPP on the 128-channel finest level, then a 1x1 reduction
        # from 384 to 128 channels.
        self.dense_aspp_layers = nn.Sequential(
            DenseASPP(128,64,(3,6,12,18),_norm_layer),
            nn.Conv2d(384,128,kernel_size=1),
            _norm_layer(128),
            nn.PReLU()
        )

        # Main prediction layer on the DenseASPP output.
        self.predict = nn.Conv2d(128,1,kernel_size=1)

        # Re-initialise every conv weight with std sqrt(2 / (kh * kw * out_channels))
        # and reset norm layers to weight 1, bias 0.
        # NOTE(review): self.modules() includes the encoder, so its conv and norm
        # weights are overwritten here too - confirm that is intended.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, x):
        """Produce five prediction maps at the spatial size of ``x``.

        Args:
            x: input batch; the encoder is expected to return five feature
                maps, of which the first (``layer0``) is not used here.

        Returns:
            tuple: ``(predict1, predict2, predict3, predict4, predict)`` where
            ``predict`` is the DenseASPP-based main prediction.
        """
        # Bottom up network
        layer0, layer1, layer2, layer3, layer4 = self.encoder(x)
        # Top down network (FPN): each level adds the resized coarser level.
        down4 = self.down4(layer4)
        down3 = torch.add(
            F.interpolate(down4,size=layer3.size()[2:],mode="bilinear",align_corners=True),
            self.down3(layer3)
        )
        down2 = torch.add(
            F.interpolate(down3,size=layer2.size()[2:],mode="bilinear",align_corners=True),
            self.down2(layer2)
        )
        down1 = torch.add(
            F.interpolate(down2,size=layer1.size()[2:],mode="bilinear",align_corners=True),
            self.down1(layer1)
        )

        # Bring the coarser levels to layer1's resolution.
        down4 = F.interpolate(down4,size=layer1.size()[2:],mode="bilinear",align_corners=True)
        down3 = F.interpolate(down3,size=layer1.size()[2:],mode="bilinear",align_corners=True)
        down2 = F.interpolate(down2,size=layer1.size()[2:],mode="bilinear",align_corners=True)

        # Deep supervision of top down network
        predict4 = self.predict4(down4)
        predict3 = self.predict3(down3)
        predict2 = self.predict2(down2)
        predict1 = self.predict1(down1)

        # Main prediction: DenseASPP runs on the finest level only.
        dense_aspp = self.dense_aspp_layers(down1)
        predict = self.predict(dense_aspp)

        # Upsample every map to the input resolution.
        predict1 = F.interpolate(predict1, size=x.size()[2:], mode='bilinear',align_corners=True)
        predict2 = F.interpolate(predict2, size=x.size()[2:], mode='bilinear',align_corners=True)
        predict3 = F.interpolate(predict3, size=x.size()[2:], mode='bilinear',align_corners=True)
        predict4 = F.interpolate(predict4, size=x.size()[2:], mode='bilinear',align_corners=True)

        predict = F.interpolate(predict, size=x.size()[2:], mode='bilinear',align_corners=True)
        return predict1, predict2, predict3, predict4, predict

In [13]:
if __name__ == "__main__":
    # Smoke test: build DANet_DASPP with GroupNorm_32 as the norm layer, report
    # its parameter count, and run one random 256x256 batch through it to
    # check the size of every output.
    model = DANet_DASPP(GroupNorm_32)
    pytorch_total_params = sum(map(torch.numel, model.parameters()))
    print(pytorch_total_params)

    # Test function
    x = torch.randn(1, 3, 256, 256)
    outputs = model(x)
    for output in outputs:
        print(output.shape)

    # Dump the module tree for inspection.
    print(model)

21837135
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
torch.Size([1, 1, 256, 256])
DANet_DASPP(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): GroupNorm_32(32, 64, eps=1e-05, affine=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): GroupNorm_32(32, 64, eps=1e-05, affine=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): GroupNorm_32(32, 64, e