<a href="https://colab.research.google.com/github/hamidreza2015/CMGFNet-Building_Extraction/blob/hamidreza2015-patch-1/Model_Encoder_Compare.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision

import torch
import torch.nn as nn
import torch.nn.functional as F

!pip install ptflops

if torch.cuda.is_available():
  device = torch.device("cuda")


from ptflops import get_model_complexity_info



##Based on Resnet34

In [10]:
#version 02
class Gated_Fusion(nn.Module):
    
    def __init__(self, in_channels):
        super().__init__()
        
        self.gate = nn.Sequential(            
            nn.Conv2d(2 * in_channels, in_channels,kernel_size=1, padding=0),
            nn.Sigmoid(),
            )
        
    def forward(self, x,y):
      out = torch.cat([x,y], dim=1)
      G = self.gate(out)
      
      PG = x * G
      FG = y * (1-G)

      
      return torch.cat([FG , PG], dim=1)

class Upsample(nn.Module):
    
    def __init__(self, scale_factor, mode="nearest"):
        super(Upsample, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode,align_corners=True)
        return x

class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout,kernel_size=3, padding=1):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size = kernel_size, padding=padding, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out    

class decoder_block(nn.Module):
    def __init__(self, 
                 input_channels, 
                 output_channels):
        
        super(decoder_block, self).__init__()
        
        self.identity = nn.Sequential(
            Upsample(2, mode="bilinear"),
            nn.Conv2d(input_channels, output_channels, kernel_size=1, padding=0)
            )

        self.decode = nn.Sequential(
            Upsample(2, mode="bilinear"),
            nn.BatchNorm2d(input_channels),
            depthwise_separable_conv(input_channels,input_channels),
            nn.BatchNorm2d(input_channels),
            nn.ReLU(inplace=True),
            depthwise_separable_conv(input_channels,output_channels),
            nn.BatchNorm2d(output_channels),
            )
        
   
    def forward(self,x):
      
      residual = self.identity(x)
      
      out = self.decode(x)

      out += residual

      return out

class FuseNet(nn.Module):
    
    def __init__(self, num_classes, pretrained=False, is_deconve=False):
        
        super().__init__()
        
        self.num_classes = num_classes
        self.pretrained = pretrained
        
        # RGB Encoder Part
            
        self.resnet_features = torchvision.models.resnet34(pretrained=pretrained)
        
        self.enc_rgb1 = nn.Sequential(self.resnet_features.conv1,
                                    self.resnet_features.bn1,
                                    self.resnet_features.relu,)
        self.enc_rgb2 = nn.Sequential(self.resnet_features.maxpool,
                                    self.resnet_features.layer1)
        
        self.enc_rgb3 = self.resnet_features.layer2
        self.enc_rgb4 = self.resnet_features.layer3
        self.enc_rgb5 = self.resnet_features.layer4

               
        
        # DSM Encoder Part
        self.encoder_depth = torchvision.models.resnet34(pretrained=pretrained)

        avg = torch.mean(self.encoder_depth.conv1.weight.data,dim=1)
        avg = avg.unsqueeze(1)
        conv1d = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        conv1d.weight.data = avg
        self.encoder_depth.conv1 = conv1d
        
        self.enc_dsm1 = nn.Sequential(self.encoder_depth.conv1,
                                    self.encoder_depth.bn1,
                                    self.encoder_depth.relu,)
        self.enc_dsm2 = nn.Sequential(self.encoder_depth.maxpool,
                                    self.encoder_depth.layer1)
        
        self.enc_dsm3 = self.encoder_depth.layer2
        self.enc_dsm4 = self.encoder_depth.layer3
        self.enc_dsm5 = self.encoder_depth.layer4

        self.pool = nn.MaxPool2d(2)

        self.gate5 = Gated_Fusion(16)
        self.gate4 = Gated_Fusion(16)
        self.gate3 = Gated_Fusion(16)
        self.gate2 = Gated_Fusion(16)
        self.gate1 = Gated_Fusion(16)

        self.gate_final = Gated_Fusion(16)
       

        self.dconv6_rgb = decoder_block(16 , 16)
        self.dconv5_rgb = decoder_block(16 + 16 , 16) 
        self.dconv4_rgb = decoder_block(16 + 16 , 16) 
        self.dconv3_rgb = decoder_block(16 + 16 , 16) 
        self.dconv2_rgb = decoder_block(16 + 16 , 16) 
        self.dconv1_rgb = decoder_block(16 + 16 , 16) 

        self.side6_rgb  = nn.Conv2d(512, 16, kernel_size=1, padding=0)
        self.side5_rgb  = nn.Conv2d(512, 16, kernel_size=1, padding=0)
        self.side4_rgb  = nn.Conv2d(256, 16, kernel_size=1, padding=0)
        self.side3_rgb  = nn.Conv2d(128, 16, kernel_size=1, padding=0)
        self.side2_rgb  = nn.Conv2d(64, 16, kernel_size=1, padding=0)
        self.side1_rgb  = nn.Conv2d(64, 16, kernel_size=1, padding=0)

        
        self.dconv6_cross = decoder_block(16 , 16)
        self.dconv5_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv4_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv3_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv2_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv1_cross = decoder_block(16 + 16 + 16 , 16) 

        self.side6_cross = nn.Conv2d(512, 16, kernel_size=1, padding=0)
        self.side5_cross = nn.Conv2d(512, 16, kernel_size=1, padding=0)
        self.side4_cross = nn.Conv2d(256, 16, kernel_size=1, padding=0)
        self.side3_cross = nn.Conv2d(128, 16, kernel_size=1, padding=0)
        self.side2_cross = nn.Conv2d(64, 16, kernel_size=1, padding=0)
        self.side1_cross = nn.Conv2d(64, 16, kernel_size=1, padding=0)


        self.final = nn.Sequential(
            nn.Conv2d(32, self.num_classes, kernel_size=1, padding=0),
            nn.BatchNorm2d(self.num_classes),
            nn.ReLU(inplace=True),
            )
        
  
        self.final_dsm1 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm2 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm3 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm4 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm5 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm6 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        

    def forward(self, x_rgb , x_dsm):
        
        # dsm_encoder
        
        y1 = self.enc_dsm1(x_dsm)         # bs * 64 * W/2 * H/2
        y1_side = self.side1_cross(y1)
       

        x1 = self.enc_rgb1(x_rgb)         # bs * 64 * W/2 * H/2
        x1_side = self.side1_rgb(x1)

        ##########################################################

        y2 = self.enc_dsm2(y1)         # bs * 64 * W/4 * H/4
        y2_side = self.side2_cross(y2)
        
        

        x2 = self.enc_rgb2(x1)         # bs * 64 * W/4 * H/4
        x2_side = self.side2_rgb(x2)
        
        

        ##########################################################

        y3 = self.enc_dsm3(y2)         # bs * 128 * W/8 * H/8
        y3_side = self.side3_cross(y3)
        


        x3 = self.enc_rgb3(x2)         # bs * 128 * W/8 * H/8
        x3_side = self.side3_rgb(x3)
        
        

        ##########################################################

        y4 = self.enc_dsm4(y3)         # bs * 256 * W/16 * H/16
        y4_side = self.side4_cross(y4)
        


        x4 = self.enc_rgb4(x3)         # bs * 256 * W/16 * H/16
        x4_side = self.side4_rgb(x4)
        

        ##########################################################

        y5 = self.enc_dsm5(y4)         # bs * 512 * W/16 * H/16
        y5_side = self.side5_cross(y5)
        
        

        x5 = self.enc_rgb5(x4)         # bs * 512 * W/16 * H/16
        x5_side = self.side5_rgb(x5)
        

        ##########################################################

        y6 =  self.pool(y5)
        y6_side = self.side6_cross(y6)
        

        out_dsm1 = self.dconv6_cross(y6_side)
        out_dsm1_out = self.final_dsm1(out_dsm1)    

        x6 =  self.pool(x5)
        x6_side = self.side6_rgb(x6)
        

        out_rgb1 = self.dconv6_rgb(x6_side)   


        ##########################################################


        FG = torch.cat((x5_side , out_rgb1),dim=1)      
        out_rgb2 = self.dconv5_rgb(FG) 


        FG_cross = self.gate5(x5_side , y5_side)
        FG_dsm = torch.cat((FG_cross, out_dsm1),dim=1)
        out_dsm2 = self.dconv5_cross(FG_dsm) 
        out_dsm2_out = self.final_dsm2(out_dsm2)     


        ##########################################################


        FG = torch.cat((x4_side  ,out_rgb2),dim=1)      
        out_rgb3 = self.dconv4_rgb(FG) 


        FG_cross = self.gate4(x4_side , y4_side)
        FG_dsm = torch.cat((FG_cross, out_dsm2),dim=1)  
        out_dsm3 = self.dconv4_cross(FG_dsm) 
        out_dsm3_out = self.final_dsm3(out_dsm3)      

        ##########################################################


        FG = torch.cat((x3_side ,out_rgb3),dim=1)      
        out_rgb4 = self.dconv3_rgb(FG)  


        FG_cross = self.gate3(x3_side , y3_side )
        FG_dsm = torch.cat((FG_cross, out_dsm3),dim=1) 
        out_dsm4 = self.dconv3_cross(FG_dsm)   
        out_dsm4_out = self.final_dsm4(out_dsm4)  

        ##########################################################


        FG = torch.cat((x2_side  ,out_rgb4),dim=1)      
        out_rgb5 = self.dconv2_rgb(FG)   

        
        FG_cross = self.gate2(x2_side , y2_side )
        FG_dsm = torch.cat((FG_cross, out_dsm4),dim=1)
        out_dsm5 = self.dconv2_cross(FG_dsm)    
        out_dsm5_out = self.final_dsm5(out_dsm5)  

        ##########################################################

        FG = torch.cat((x1_side ,out_rgb5),dim=1)      
        out_rgb6 = self.dconv1_rgb(FG)   


        FG_cross = self.gate1(x1_side , y1_side)
        FG_dsm = torch.cat((FG_cross, out_dsm5),dim=1)  
        out_dsm6 = self.dconv1_cross(FG_dsm) 
        out_dsm6_out = self.final_dsm6(out_dsm6)

        ##########################################################

        final = self.gate_final(out_rgb6, out_dsm6)            
        final = self.final(final)

        f1 = F.interpolate(out_dsm1_out, scale_factor=32,mode='bilinear', align_corners=True )
        f2 = F.interpolate(out_dsm2_out, scale_factor=16,mode='bilinear', align_corners=True )
        f3 = F.interpolate(out_dsm3_out, scale_factor=8 ,mode='bilinear', align_corners=True )
        f4 = F.interpolate(out_dsm4_out, scale_factor=4 ,mode='bilinear', align_corners=True )
        f5 = F.interpolate(out_dsm5_out, scale_factor=2 ,mode='bilinear', align_corners=True )

        final_help = (f1 + f2 + f3 + f4 + f5 + out_dsm6_out)/6


        a1 = [out_dsm1_out,out_dsm2_out,out_dsm3_out,out_dsm4_out,out_dsm5_out,out_dsm6_out]
        

        return final, final_help, a1


##Based on DenseNet

In [2]:
#version 02
class Gated_Fusion(nn.Module):
    
    def __init__(self, in_channels):
        super().__init__()
        
        self.gate = nn.Sequential(            
            nn.Conv2d(2 * in_channels, in_channels,kernel_size=1, padding=0),
            nn.Sigmoid(),
            )
        
    def forward(self, x,y):
      out = torch.cat([x,y], dim=1)
      G = self.gate(out)
      
      PG = x * G
      FG = y * (1-G)

      
      return torch.cat([FG , PG], dim=1)

class Upsample(nn.Module):
    
    def __init__(self, scale_factor, mode="nearest"):
        super(Upsample, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode,align_corners=True)
        return x

class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout,kernel_size=3, padding=1):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size = kernel_size, padding=padding, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out    

class decoder_block(nn.Module):
    def __init__(self, 
                 input_channels, 
                 output_channels):
        
        super(decoder_block, self).__init__()
        
        self.identity = nn.Sequential(
            Upsample(2, mode="bilinear"),
            nn.Conv2d(input_channels, output_channels, kernel_size=1, padding=0)
            )

        self.decode = nn.Sequential(
            Upsample(2, mode="bilinear"),
            nn.BatchNorm2d(input_channels),
            depthwise_separable_conv(input_channels,input_channels),
            nn.BatchNorm2d(input_channels),
            nn.ReLU(inplace=True),
            depthwise_separable_conv(input_channels,output_channels),
            nn.BatchNorm2d(output_channels),
            )
        
   
    def forward(self,x):
      
      residual = self.identity(x)
      
      out = self.decode(x)

      out += residual

      return out

class FuseNet_Dens(nn.Module):
    
    def __init__(self, num_classes, pretrained=False, is_deconve=False):
        
        super().__init__()
        
        self.num_classes = num_classes
        self.pretrained = pretrained
        
        # RGB Encoder Part

        self.densenet_features = torchvision.models.densenet.densenet121(pretrained=pretrained)
        
        self.enc_rgb1 = nn.Sequential(self.densenet_features.features.conv0,
                         self.densenet_features.features.norm0,
                         self.densenet_features.features.relu0,)
        
        self.enc_rgb2 = nn.Sequential(self.densenet_features.features.pool0,
                        self.densenet_features.features.denseblock1)
        
        self.transition2 = self.densenet_features.features.transition1
        self.enc_rgb3 = self.densenet_features.features.denseblock2

        self.transition3 = self.densenet_features.features.transition2
        self.enc_rgb4 = self.densenet_features.features.denseblock3

        self.transition4 = self.densenet_features.features.transition3
        self.enc_rgb5 = self.densenet_features.features.denseblock4
            
      
        # DSM Encoder Part
        self.encoder_depth = torchvision.models.densenet.densenet121(pretrained=pretrained)

        avg = torch.mean(self.densenet_features.features.conv0.weight.data,dim=1)
        avg = avg.unsqueeze(1)
        conv1d = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        conv1d.weight.data = avg
        self.encoder_depth.conv1 = conv1d
        
        self.enc_dsm1 = nn.Sequential(self.encoder_depth.conv1,
                                      self.encoder_depth.features.norm0,
                                      self.encoder_depth.features.relu0,)
        
        self.enc_dsm2 = nn.Sequential(self.encoder_depth.features.pool0,
                                      self.encoder_depth.features.denseblock1)
        
        self.transition2_dsm = self.encoder_depth.features.transition1
        self.enc_dsm3 = self.encoder_depth.features.denseblock2

        self.transition3_dsm = self.encoder_depth.features.transition2
        self.enc_dsm4 = self.encoder_depth.features.denseblock3

        self.transition4_dsm = self.encoder_depth.features.transition3
        self.enc_dsm5 = self.encoder_depth.features.denseblock4

        self.pool = nn.MaxPool2d(2)

        self.gate5 = Gated_Fusion(16)
        self.gate4 = Gated_Fusion(16)
        self.gate3 = Gated_Fusion(16)
        self.gate2 = Gated_Fusion(16)
        self.gate1 = Gated_Fusion(16)

        self.gate_final = Gated_Fusion(16)
       

        self.dconv6_rgb = decoder_block(16 , 16)
        self.dconv5_rgb = decoder_block(16 + 16 , 16) 
        self.dconv4_rgb = decoder_block(16 + 16 , 16) 
        self.dconv3_rgb = decoder_block(16 + 16 , 16) 
        self.dconv2_rgb = decoder_block(16 + 16 , 16) 
        self.dconv1_rgb = decoder_block(16 + 16 , 16) 

        self.side6_rgb  = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side5_rgb  = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side4_rgb  = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side3_rgb  = nn.Conv2d(512, 16, kernel_size=1, padding=0)
        self.side2_rgb  = nn.Conv2d(256, 16, kernel_size=1, padding=0)
        self.side1_rgb  = nn.Conv2d(64, 16, kernel_size=1, padding=0)

        
        self.dconv6_cross = decoder_block(16 , 16)
        self.dconv5_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv4_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv3_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv2_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv1_cross = decoder_block(16 + 16 + 16 , 16) 

        self.side6_cross = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side5_cross = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side4_cross = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side3_cross = nn.Conv2d(512, 16, kernel_size=1, padding=0)
        self.side2_cross = nn.Conv2d(256, 16, kernel_size=1, padding=0)
        self.side1_cross = nn.Conv2d(64, 16, kernel_size=1, padding=0)


        self.final = nn.Sequential(
            nn.Conv2d(32, self.num_classes, kernel_size=1, padding=0),
            nn.BatchNorm2d(self.num_classes),
            nn.ReLU(inplace=True),
            )
        
  
        self.final_dsm1 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm2 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm3 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm4 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm5 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm6 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        

    def forward(self, x_rgb , x_dsm):
        
        # dsm_encoder
        
        y1 = self.enc_dsm1(x_dsm)         # bs * 64 * W/2 * H/2
        y1_side = self.side1_cross(y1)
       

        x1 = self.enc_rgb1(x_rgb)         # bs * 64 * W/2 * H/2
        x1_side = self.side1_rgb(x1)

        ##########################################################

        y2 = self.enc_dsm2(y1)         # bs * 256 * W/4 * H/4
        y2_side = self.side2_cross(y2)
        

        x2 = self.enc_rgb2(x1)         # bs * 256 * W/4 * H/4
        x2_side = self.side2_rgb(x2)

        ##########################################################

        y2 = self.transition2_dsm(y2)
        y3 = self.enc_dsm3(y2)         # bs * 512 * W/8 * H/8
        y3_side = self.side3_cross(y3)
        

        x2 = self.transition2(x2)
        x3 = self.enc_rgb3(x2)         # bs * 512 * W/8 * H/8
        x3_side = self.side3_rgb(x3)


        ##########################################################
        y3 = self.transition3_dsm(y3)
        y4 = self.enc_dsm4(y3)         # bs * 1024 * W/16 * H/16
        y4_side = self.side4_cross(y4)
        

        x3 = self.transition3(x3)
        x4 = self.enc_rgb4(x3)         # bs * 1024 * W/16 * H/16
        x4_side = self.side4_rgb(x4)

        
        ##########################################################
        y4 = self.transition4_dsm(y4)
        y5 = self.enc_dsm5(y4)         # bs * 512 * W/16 * H/16
        y5_side = self.side5_cross(y5)
        
        
        x4 = self.transition4(x4)
        x5 = self.enc_rgb5(x4)         # bs * 512 * W/16 * H/16
        x5_side = self.side5_rgb(x5)

        ##########################################################

        y6 =  self.pool(y5)
        y6_side = self.side6_cross(y6)
        

        out_dsm1 = self.dconv6_cross(y6_side)
        out_dsm1_out = self.final_dsm1(out_dsm1)    

        x6 =  self.pool(x5)
        x6_side = self.side6_rgb(x6)
        

        out_rgb1 = self.dconv6_rgb(x6_side)   


        ##########################################################

        

        FG = torch.cat((x5_side , out_rgb1),dim=1)      
        out_rgb2 = self.dconv5_rgb(FG) 


        FG_cross = self.gate5(x5_side , y5_side)
        FG_dsm = torch.cat((FG_cross, out_dsm1),dim=1)
        out_dsm2 = self.dconv5_cross(FG_dsm) 
        out_dsm2_out = self.final_dsm2(out_dsm2)     


        ##########################################################


        FG = torch.cat((x4_side  ,out_rgb2),dim=1)      
        out_rgb3 = self.dconv4_rgb(FG) 


        FG_cross = self.gate4(x4_side , y4_side)
        FG_dsm = torch.cat((FG_cross, out_dsm2),dim=1)  
        out_dsm3 = self.dconv4_cross(FG_dsm) 
        out_dsm3_out = self.final_dsm3(out_dsm3)      

        ##########################################################


        FG = torch.cat((x3_side ,out_rgb3),dim=1)      
        out_rgb4 = self.dconv3_rgb(FG)  


        FG_cross = self.gate3(x3_side , y3_side )
        FG_dsm = torch.cat((FG_cross, out_dsm3),dim=1) 
        out_dsm4 = self.dconv3_cross(FG_dsm)   
        out_dsm4_out = self.final_dsm4(out_dsm4)  

        ##########################################################


        FG = torch.cat((x2_side  ,out_rgb4),dim=1)      
        out_rgb5 = self.dconv2_rgb(FG)   

        
        FG_cross = self.gate2(x2_side , y2_side )
        FG_dsm = torch.cat((FG_cross, out_dsm4),dim=1)
        out_dsm5 = self.dconv2_cross(FG_dsm)    
        out_dsm5_out = self.final_dsm5(out_dsm5)  

        ##########################################################

        FG = torch.cat((x1_side ,out_rgb5),dim=1)      
        out_rgb6 = self.dconv1_rgb(FG)   


        FG_cross = self.gate1(x1_side , y1_side)
        FG_dsm = torch.cat((FG_cross, out_dsm5),dim=1)  
        out_dsm6 = self.dconv1_cross(FG_dsm) 
        out_dsm6_out = self.final_dsm6(out_dsm6)

        ##########################################################

        final = self.gate_final(out_rgb6, out_dsm6)            
        final = self.final(final)

        f1 = F.interpolate(out_dsm1_out, scale_factor=32,mode='bilinear', align_corners=True )
        f2 = F.interpolate(out_dsm2_out, scale_factor=16,mode='bilinear', align_corners=True )
        f3 = F.interpolate(out_dsm3_out, scale_factor=8 ,mode='bilinear', align_corners=True )
        f4 = F.interpolate(out_dsm4_out, scale_factor=4 ,mode='bilinear', align_corners=True )
        f5 = F.interpolate(out_dsm5_out, scale_factor=2 ,mode='bilinear', align_corners=True )

        final_help = (f1 + f2 + f3 + f4 + f5 + out_dsm6_out)/6


        a1 = [out_dsm1_out,out_dsm2_out,out_dsm3_out,out_dsm4_out,out_dsm5_out,out_dsm6_out]
        

        return final, final_help, a1



In [3]:
x = torch.rand(1,3,640,640).to(device)
y = torch.rand(1,1,640,640).to(device)

model = FuseNet_Dens(2,True).to(device)

out,_,_ = model(x,y)

out.shape

torch.Size([1, 2, 640, 640])

In [4]:
def term_construct(input_tensor):

  return {'x_rgb':input_tensor[0] , 'x_dsm': input_tensor[1]}


macs, params = get_model_complexity_info(model,
                                         input_res=(x,y),
                                         as_strings=True,
                                         input_constructor = term_construct,
                                         print_per_layer_stat=False,
                                         verbose=True)

print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))

Computational complexity:       51.76 GMac
Number of parameters:           16.13 M 


##DensNet_Paper

In [2]:
#version 02
class Gated_Fusion(nn.Module):
    
    def __init__(self, in_channels):
        super().__init__()
        
        self.gate = nn.Sequential(            
            nn.Conv2d(2 * in_channels, in_channels,kernel_size=1, padding=0),
            nn.Sigmoid(),
            )
        
    def forward(self, x,y):
      out = torch.cat([x,y], dim=1)
      G = self.gate(out)
      
      PG = x * G
      FG = y * (1-G)

      
      return torch.cat([FG , PG], dim=1)

class Upsample(nn.Module):
    
    def __init__(self, scale_factor, mode="nearest"):
        super(Upsample, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode,align_corners=True)
        return x

class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout,kernel_size=3, padding=1):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size = kernel_size, padding=padding, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out    

class decoder_block(nn.Module):
    def __init__(self, 
                 input_channels, 
                 output_channels):
        
        super(decoder_block, self).__init__()
        
        self.identity = nn.Sequential(
            Upsample(2, mode="bilinear"),
            nn.Conv2d(input_channels, output_channels, kernel_size=1, padding=0)
            )

        self.decode = nn.Sequential(
            Upsample(2, mode="bilinear"),
            nn.BatchNorm2d(input_channels),
            depthwise_separable_conv(input_channels,input_channels),
            nn.BatchNorm2d(input_channels),
            nn.ReLU(inplace=True),
            depthwise_separable_conv(input_channels,output_channels),
            nn.BatchNorm2d(output_channels),
            )
        
   
    def forward(self,x):
      
      residual = self.identity(x)
      
      out = self.decode(x)

      out += residual

      return out

class FuseNet_Dens(nn.Module):
    
    def __init__(self, num_classes, pretrained=False, is_deconve=False):
        
        super().__init__()
        
        self.num_classes = num_classes
        self.pretrained = pretrained
        
        # RGB Encoder Part

        self.densenet_features = torchvision.models.densenet.densenet161(pretrained=pretrained)
        
        self.enc_rgb1 = nn.Sequential(self.densenet_features.features.conv0,
                         self.densenet_features.features.norm0,
                         self.densenet_features.features.relu0,)
        
        self.enc_rgb2 = nn.Sequential(self.densenet_features.features.pool0,
                        self.densenet_features.features.denseblock1)
        
        self.transition2 = self.densenet_features.features.transition1
        self.enc_rgb3 = self.densenet_features.features.denseblock2

        self.transition3 = self.densenet_features.features.transition2
        self.enc_rgb4 = self.densenet_features.features.denseblock3

        self.transition4 = self.densenet_features.features.transition3
        self.enc_rgb5 = self.densenet_features.features.denseblock4
            
      
        # DSM Encoder Part
        self.encoder_depth = torchvision.models.densenet.densenet161(pretrained=pretrained)

        avg = torch.mean(self.densenet_features.features.conv0.weight.data,dim=1)
        avg = avg.unsqueeze(1)
        conv1d = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        conv1d.weight.data = avg
        self.encoder_depth.conv1 = conv1d
        
        self.enc_dsm1 = nn.Sequential(self.encoder_depth.conv1,
                                      self.encoder_depth.features.norm0,
                                      self.encoder_depth.features.relu0,)
        
        self.enc_dsm2 = nn.Sequential(self.encoder_depth.features.pool0,
                                      self.encoder_depth.features.denseblock1)
        
        self.transition2_dsm = self.encoder_depth.features.transition1
        self.enc_dsm3 = self.encoder_depth.features.denseblock2

        self.transition3_dsm = self.encoder_depth.features.transition2
        self.enc_dsm4 = self.encoder_depth.features.denseblock3

        self.transition4_dsm = self.encoder_depth.features.transition3
        self.enc_dsm5 = self.encoder_depth.features.denseblock4

        self.pool = nn.MaxPool2d(2)

        self.gate5 = Gated_Fusion(16)
        self.gate4 = Gated_Fusion(16)
        self.gate3 = Gated_Fusion(16)
        self.gate2 = Gated_Fusion(16)
        self.gate1 = Gated_Fusion(16)

        self.gate_final = Gated_Fusion(16)
       

        self.dconv6_rgb = decoder_block(16 , 16)
        self.dconv5_rgb = decoder_block(16 + 16 , 16) 
        self.dconv4_rgb = decoder_block(16 + 16 , 16) 
        self.dconv3_rgb = decoder_block(16 + 16 , 16) 
        self.dconv2_rgb = decoder_block(16 + 16 , 16) 
        self.dconv1_rgb = decoder_block(16 + 16 , 16) 

        self.side6_rgb  = nn.Conv2d(2208, 16, kernel_size=1, padding=0)
        self.side5_rgb  = nn.Conv2d(2208, 16, kernel_size=1, padding=0)
        self.side4_rgb  = nn.Conv2d(2112, 16, kernel_size=1, padding=0)
        self.side3_rgb  = nn.Conv2d(768, 16, kernel_size=1, padding=0)
        self.side2_rgb  = nn.Conv2d(384, 16, kernel_size=1, padding=0)
        self.side1_rgb  = nn.Conv2d(96, 16, kernel_size=1, padding=0)

        
        self.dconv6_cross = decoder_block(16 , 16)
        self.dconv5_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv4_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv3_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv2_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv1_cross = decoder_block(16 + 16 + 16 , 16) 

        self.side6_cross = nn.Conv2d(2208, 16, kernel_size=1, padding=0)
        self.side5_cross = nn.Conv2d(2208, 16, kernel_size=1, padding=0)
        self.side4_cross = nn.Conv2d(2112, 16, kernel_size=1, padding=0)
        self.side3_cross = nn.Conv2d(768, 16, kernel_size=1, padding=0)
        self.side2_cross = nn.Conv2d(384, 16, kernel_size=1, padding=0)
        self.side1_cross = nn.Conv2d(96, 16, kernel_size=1, padding=0)


        self.final = nn.Sequential(
            nn.Conv2d(32, self.num_classes, kernel_size=1, padding=0),
            nn.BatchNorm2d(self.num_classes),
            nn.ReLU(inplace=True),
            )
                

    def forward(self, x_rgb , x_dsm):
        
        # dsm_encoder
        
        y1 = self.enc_dsm1(x_dsm)         # bs * 64 * W/2 * H/2
        y1_side = self.side1_cross(y1)
       

        x1 = self.enc_rgb1(x_rgb)         # bs * 64 * W/2 * H/2
        x1_side = self.side1_rgb(x1)

        ##########################################################

        y2 = self.enc_dsm2(y1)         # bs * 256 * W/4 * H/4
        y2_side = self.side2_cross(y2)
        

        x2 = self.enc_rgb2(x1)         # bs * 256 * W/4 * H/4
        x2_side = self.side2_rgb(x2)

        ##########################################################

        y2 = self.transition2_dsm(y2)
        y3 = self.enc_dsm3(y2)         # bs * 512 * W/8 * H/8
        y3_side = self.side3_cross(y3)
        

        x2 = self.transition2(x2)
        x3 = self.enc_rgb3(x2)         # bs * 512 * W/8 * H/8
        x3_side = self.side3_rgb(x3)


        ##########################################################
        y3 = self.transition3_dsm(y3)
        y4 = self.enc_dsm4(y3)         # bs * 1024 * W/16 * H/16
        y4_side = self.side4_cross(y4)
        

        x3 = self.transition3(x3)
        x4 = self.enc_rgb4(x3)         # bs * 1024 * W/16 * H/16
        x4_side = self.side4_rgb(x4)

        
        ##########################################################
        y4 = self.transition4_dsm(y4)
        y5 = self.enc_dsm5(y4)         # bs * 512 * W/16 * H/16
        y5_side = self.side5_cross(y5)
        
        
        x4 = self.transition4(x4)
        x5 = self.enc_rgb5(x4)         # bs * 512 * W/16 * H/16
        x5_side = self.side5_rgb(x5)

        ##########################################################

        y6 =  self.pool(y5)
        y6_side = self.side6_cross(y6)
        

        out_dsm1 = self.dconv6_cross(y6_side)


        x6 =  self.pool(x5)
        x6_side = self.side6_rgb(x6)
        

        out_rgb1 = self.dconv6_rgb(x6_side)   


        ##########################################################

        

        FG = torch.cat((x5_side , out_rgb1),dim=1)      
        out_rgb2 = self.dconv5_rgb(FG) 


        FG_cross = self.gate5(x5_side , y5_side)
        FG_dsm = torch.cat((FG_cross, out_dsm1),dim=1)
        out_dsm2 = self.dconv5_cross(FG_dsm) 
 


        ##########################################################


        FG = torch.cat((x4_side  ,out_rgb2),dim=1)      
        out_rgb3 = self.dconv4_rgb(FG) 


        FG_cross = self.gate4(x4_side , y4_side)
        FG_dsm = torch.cat((FG_cross, out_dsm2),dim=1)  
        out_dsm3 = self.dconv4_cross(FG_dsm) 
    

        ##########################################################


        FG = torch.cat((x3_side ,out_rgb3),dim=1)      
        out_rgb4 = self.dconv3_rgb(FG)  


        FG_cross = self.gate3(x3_side , y3_side )
        FG_dsm = torch.cat((FG_cross, out_dsm3),dim=1) 
        out_dsm4 = self.dconv3_cross(FG_dsm)   

        ##########################################################


        FG = torch.cat((x2_side  ,out_rgb4),dim=1)      
        out_rgb5 = self.dconv2_rgb(FG)   

        
        FG_cross = self.gate2(x2_side , y2_side )
        FG_dsm = torch.cat((FG_cross, out_dsm4),dim=1)
        out_dsm5 = self.dconv2_cross(FG_dsm)    
  

        ##########################################################

        FG = torch.cat((x1_side ,out_rgb5),dim=1)      
        out_rgb6 = self.dconv1_rgb(FG)   


        FG_cross = self.gate1(x1_side , y1_side)
        FG_dsm = torch.cat((FG_cross, out_dsm5),dim=1)  
        out_dsm6 = self.dconv1_cross(FG_dsm) 

        ##########################################################

        final = self.gate_final(out_rgb6, out_dsm6)            
        final = self.final(final)


        return final



In [3]:
x = torch.rand(1,3,640,640).to(device)
y = torch.rand(1,1,640,640).to(device)

model = FuseNet_Dens(2,True).to(device)

out = model(x,y)

out.shape

torch.Size([1, 2, 640, 640])

In [4]:
def term_construct(input_tensor):

  return {'x_rgb':input_tensor[0] , 'x_dsm': input_tensor[1]}


macs, params = get_model_complexity_info(model,
                                         input_res=(x,y),
                                         as_strings=True,
                                         input_constructor = term_construct,
                                         print_per_layer_stat=False,
                                         verbose=True)

print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))

Computational complexity:       132.16 GMac
Number of parameters:           57.66 M 


##ResNet34_Paper

In [11]:
#version 02
class Gated_Fusion(nn.Module):
    
    def __init__(self, in_channels):
        super().__init__()
        
        self.gate = nn.Sequential(            
            nn.Conv2d(2 * in_channels, in_channels,kernel_size=1, padding=0),
            nn.Sigmoid(),
            )
        
    def forward(self, x,y):
      out = torch.cat([x,y], dim=1)
      G = self.gate(out)
      
      PG = x * G
      FG = y * (1-G)

      
      return torch.cat([FG , PG], dim=1)

class Upsample(nn.Module):
    
    def __init__(self, scale_factor, mode="nearest"):
        super(Upsample, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode,align_corners=True)
        return x

class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout,kernel_size=3, padding=1):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size = kernel_size, padding=padding, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out    

class decoder_block(nn.Module):
    def __init__(self, 
                 input_channels, 
                 output_channels):
        
        super(decoder_block, self).__init__()
        
        self.identity = nn.Sequential(
            Upsample(2, mode="bilinear"),
            nn.Conv2d(input_channels, output_channels, kernel_size=1, padding=0)
            )

        self.decode = nn.Sequential(
            Upsample(2, mode="bilinear"),
            nn.BatchNorm2d(input_channels),
            depthwise_separable_conv(input_channels,input_channels),
            nn.BatchNorm2d(input_channels),
            nn.ReLU(inplace=True),
            depthwise_separable_conv(input_channels,output_channels),
            nn.BatchNorm2d(output_channels),
            )
        
   
    def forward(self,x):
      
      residual = self.identity(x)
      
      out = self.decode(x)

      out += residual

      return out

class FuseNet_ResNet(nn.Module):
    
    def __init__(self, num_classes, pretrained=False, is_deconve=False):
        
        super().__init__()
        
        self.num_classes = num_classes
        self.pretrained = pretrained
        
        # RGB Encoder Part
            
        self.resnet_features = torchvision.models.resnet34(pretrained=pretrained)
        
        self.enc_rgb1 = nn.Sequential(self.resnet_features.conv1,
                                    self.resnet_features.bn1,
                                    self.resnet_features.relu,)
        self.enc_rgb2 = nn.Sequential(self.resnet_features.maxpool,
                                    self.resnet_features.layer1)
        
        self.enc_rgb3 = self.resnet_features.layer2
        self.enc_rgb4 = self.resnet_features.layer3
        self.enc_rgb5 = self.resnet_features.layer4

               
        
        # DSM Encoder Part
        self.encoder_depth = torchvision.models.resnet34(pretrained=pretrained)

        avg = torch.mean(self.encoder_depth.conv1.weight.data,dim=1)
        avg = avg.unsqueeze(1)
        conv1d = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        conv1d.weight.data = avg
        self.encoder_depth.conv1 = conv1d
        
        self.enc_dsm1 = nn.Sequential(self.encoder_depth.conv1,
                                    self.encoder_depth.bn1,
                                    self.encoder_depth.relu,)
        self.enc_dsm2 = nn.Sequential(self.encoder_depth.maxpool,
                                    self.encoder_depth.layer1)
        
        self.enc_dsm3 = self.encoder_depth.layer2
        self.enc_dsm4 = self.encoder_depth.layer3
        self.enc_dsm5 = self.encoder_depth.layer4

        self.pool = nn.MaxPool2d(2)

        self.gate5 = Gated_Fusion(16)
        self.gate4 = Gated_Fusion(16)
        self.gate3 = Gated_Fusion(16)
        self.gate2 = Gated_Fusion(16)
        self.gate1 = Gated_Fusion(16)

        self.gate_final = Gated_Fusion(16)
       

        self.dconv6_rgb = decoder_block(16 , 16)
        self.dconv5_rgb = decoder_block(16 + 16 , 16) 
        self.dconv4_rgb = decoder_block(16 + 16 , 16) 
        self.dconv3_rgb = decoder_block(16 + 16 , 16) 
        self.dconv2_rgb = decoder_block(16 + 16 , 16) 
        self.dconv1_rgb = decoder_block(16 + 16 , 16) 

        self.side6_rgb  = nn.Conv2d(512, 16, kernel_size=1, padding=0)
        self.side5_rgb  = nn.Conv2d(512, 16, kernel_size=1, padding=0)
        self.side4_rgb  = nn.Conv2d(256, 16, kernel_size=1, padding=0)
        self.side3_rgb  = nn.Conv2d(128, 16, kernel_size=1, padding=0)
        self.side2_rgb  = nn.Conv2d(64, 16, kernel_size=1, padding=0)
        self.side1_rgb  = nn.Conv2d(64, 16, kernel_size=1, padding=0)

        
        self.dconv6_cross = decoder_block(16 , 16)
        self.dconv5_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv4_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv3_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv2_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv1_cross = decoder_block(16 + 16 + 16 , 16) 

        self.side6_cross = nn.Conv2d(512, 16, kernel_size=1, padding=0)
        self.side5_cross = nn.Conv2d(512, 16, kernel_size=1, padding=0)
        self.side4_cross = nn.Conv2d(256, 16, kernel_size=1, padding=0)
        self.side3_cross = nn.Conv2d(128, 16, kernel_size=1, padding=0)
        self.side2_cross = nn.Conv2d(64, 16, kernel_size=1, padding=0)
        self.side1_cross = nn.Conv2d(64, 16, kernel_size=1, padding=0)


        self.final = nn.Sequential(
            nn.Conv2d(32, self.num_classes, kernel_size=1, padding=0),
            nn.BatchNorm2d(self.num_classes),
            nn.ReLU(inplace=True),
            )
        

    def forward(self, x_rgb , x_dsm):
        
        # dsm_encoder
        
        y1 = self.enc_dsm1(x_dsm)         # bs * 64 * W/2 * H/2
        y1_side = self.side1_cross(y1)
       

        x1 = self.enc_rgb1(x_rgb)         # bs * 64 * W/2 * H/2
        x1_side = self.side1_rgb(x1)

        ##########################################################

        y2 = self.enc_dsm2(y1)         # bs * 64 * W/4 * H/4
        y2_side = self.side2_cross(y2)
        
        

        x2 = self.enc_rgb2(x1)         # bs * 64 * W/4 * H/4
        x2_side = self.side2_rgb(x2)
        
        

        ##########################################################

        y3 = self.enc_dsm3(y2)         # bs * 128 * W/8 * H/8
        y3_side = self.side3_cross(y3)
        


        x3 = self.enc_rgb3(x2)         # bs * 128 * W/8 * H/8
        x3_side = self.side3_rgb(x3)
        
        

        ##########################################################

        y4 = self.enc_dsm4(y3)         # bs * 256 * W/16 * H/16
        y4_side = self.side4_cross(y4)
        


        x4 = self.enc_rgb4(x3)         # bs * 256 * W/16 * H/16
        x4_side = self.side4_rgb(x4)
        

        ##########################################################

        y5 = self.enc_dsm5(y4)         # bs * 512 * W/16 * H/16
        y5_side = self.side5_cross(y5)
        
        

        x5 = self.enc_rgb5(x4)         # bs * 512 * W/16 * H/16
        x5_side = self.side5_rgb(x5)
        

        ##########################################################

        y6 =  self.pool(y5)
        y6_side = self.side6_cross(y6)
        

        out_dsm1 = self.dconv6_cross(y6_side)
         

        x6 =  self.pool(x5)
        x6_side = self.side6_rgb(x6)
        

        out_rgb1 = self.dconv6_rgb(x6_side)   


        ##########################################################


        FG = torch.cat((x5_side , out_rgb1),dim=1)      
        out_rgb2 = self.dconv5_rgb(FG) 


        FG_cross = self.gate5(x5_side , y5_side)
        FG_dsm = torch.cat((FG_cross, out_dsm1),dim=1)
        out_dsm2 = self.dconv5_cross(FG_dsm) 
   


        ##########################################################


        FG = torch.cat((x4_side  ,out_rgb2),dim=1)      
        out_rgb3 = self.dconv4_rgb(FG) 


        FG_cross = self.gate4(x4_side , y4_side)
        FG_dsm = torch.cat((FG_cross, out_dsm2),dim=1)  
        out_dsm3 = self.dconv4_cross(FG_dsm) 
    

        ##########################################################


        FG = torch.cat((x3_side ,out_rgb3),dim=1)      
        out_rgb4 = self.dconv3_rgb(FG)  


        FG_cross = self.gate3(x3_side , y3_side )
        FG_dsm = torch.cat((FG_cross, out_dsm3),dim=1) 
        out_dsm4 = self.dconv3_cross(FG_dsm)   


        ##########################################################


        FG = torch.cat((x2_side  ,out_rgb4),dim=1)      
        out_rgb5 = self.dconv2_rgb(FG)   

        
        FG_cross = self.gate2(x2_side , y2_side )
        FG_dsm = torch.cat((FG_cross, out_dsm4),dim=1)
        out_dsm5 = self.dconv2_cross(FG_dsm)    


        ##########################################################

        FG = torch.cat((x1_side ,out_rgb5),dim=1)      
        out_rgb6 = self.dconv1_rgb(FG)   


        FG_cross = self.gate1(x1_side , y1_side)
        FG_dsm = torch.cat((FG_cross, out_dsm5),dim=1)  
        out_dsm6 = self.dconv1_cross(FG_dsm) 
        

        ##########################################################

        final = self.gate_final(out_rgb6, out_dsm6)            
        final = self.final(final)

        

        return final


In [12]:
x = torch.rand(1,3,640,640).to(device)
y = torch.rand(1,1,640,640).to(device)

model = FuseNet_ResNet(2,True).to(device)

out = model(x,y)

out.shape

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth


HBox(children=(FloatProgress(value=0.0, max=87306240.0), HTML(value='')))




torch.Size([1, 2, 640, 640])

In [13]:
def term_construct(input_tensor):

  return {'x_rgb':input_tensor[0] , 'x_dsm': input_tensor[1]}


macs, params = get_model_complexity_info(model,
                                         input_res=(x,y),
                                         as_strings=True,
                                         input_constructor = term_construct,
                                         print_per_layer_stat=False,
                                         verbose=True)

print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))

Computational complexity:       64.41 GMac
Number of parameters:           43.68 M 


##Vgg16_bn_Paper

In [2]:
#version 02
class Gated_Fusion(nn.Module):
    
    def __init__(self, in_channels):
        super().__init__()
        
        self.gate = nn.Sequential(            
            nn.Conv2d(2 * in_channels, in_channels,kernel_size=1, padding=0),
            nn.Sigmoid(),
            )
        
    def forward(self, x,y):
      out = torch.cat([x,y], dim=1)
      G = self.gate(out)
      
      PG = x * G
      FG = y * (1-G)

      
      return torch.cat([FG , PG], dim=1)

class Upsample(nn.Module):
    
    def __init__(self, scale_factor, mode="nearest"):
        super(Upsample, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode,align_corners=True)
        return x

class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout,kernel_size=3, padding=1):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size = kernel_size, padding=padding, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out    

class decoder_block(nn.Module):
    def __init__(self, 
                 input_channels, 
                 output_channels):
        
        super(decoder_block, self).__init__()
        
        self.identity = nn.Sequential(
            Upsample(2, mode="bilinear"),
            nn.Conv2d(input_channels, output_channels, kernel_size=1, padding=0)
            )

        self.decode = nn.Sequential(
            Upsample(2, mode="bilinear"),
            nn.BatchNorm2d(input_channels),
            depthwise_separable_conv(input_channels,input_channels),
            nn.BatchNorm2d(input_channels),
            nn.ReLU(inplace=True),
            depthwise_separable_conv(input_channels,output_channels),
            nn.BatchNorm2d(output_channels),
            )
        
   
    def forward(self,x):
      
      residual = self.identity(x)
      
      out = self.decode(x)

      out += residual

      return out

class FuseNet(nn.Module):
    
    def __init__(self, num_classes, pretrained=False, is_deconve=False):
        
        super().__init__()
        
        self.num_classes = num_classes
        self.pretrained = pretrained
        
        # RGB Encoder Part


        self.poold = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, 
                                 ceil_mode=False, return_indices=False)
            
        self.resnet_features = list(torchvision.models.vgg16_bn(pretrained=self.pretrained).features.children())
        
        self.enc_rgb1 = nn.Sequential(*self.resnet_features[:6])         #64
        self.enc_rgb2 = nn.Sequential(*self.resnet_features[7:13])       #128
        self.enc_rgb3 = nn.Sequential(*self.resnet_features[14:23])      #256
        self.enc_rgb4 = nn.Sequential(*self.resnet_features[24:33])      #512
        self.enc_rgb5 = nn.Sequential(*self.resnet_features[34:43])      #512

               
        
        # DSM Encoder Part
        self.encoder_depth = list(torchvision.models.vgg16_bn(pretrained=self.pretrained).features.children())

        enc1 = nn.Sequential(*self.encoder_depth[:6])

        avg = torch.mean(enc1[0].weight.data,dim=1)
        avg = avg.unsqueeze(1)
        conv1d = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        conv1d.weight.data = avg
        enc1[0] = conv1d
        
        self.enc_dsm1 = enc1

        self.enc_dsm2 = nn.Sequential(*self.encoder_depth[7:13]) 
        self.enc_dsm3 = nn.Sequential(*self.encoder_depth[14:23]) 
        self.enc_dsm4 = nn.Sequential(*self.encoder_depth[24:33]) 
        self.enc_dsm5 = nn.Sequential(*self.encoder_depth[34:43])

        self.pool = nn.MaxPool2d(2)

        self.gate5 = Gated_Fusion(16)
        self.gate4 = Gated_Fusion(16)
        self.gate3 = Gated_Fusion(16)
        self.gate2 = Gated_Fusion(16)
        self.gate1 = Gated_Fusion(16)

        self.gate_final = Gated_Fusion(16)
       

        self.dconv6_rgb = decoder_block(16 , 16)
        self.dconv5_rgb = decoder_block(16 + 16 , 16) 
        self.dconv4_rgb = decoder_block(16 + 16 , 16) 
        self.dconv3_rgb = decoder_block(16 + 16 , 16) 
        self.dconv2_rgb = decoder_block(16 + 16 , 16) 
        self.dconv1_rgb = decoder_block(16 + 16 , 16) 

        self.side6_rgb  = nn.Sequential(
            nn.Conv2d(512, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            ) 

        self.side5_rgb  = nn.Sequential(
            nn.Conv2d(512, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            ) 

        self.side4_rgb  = nn.Sequential(
            nn.Conv2d(512, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            ) 

        self.side3_rgb  = nn.Sequential(
            nn.Conv2d(256, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            ) 

        self.side2_rgb  = nn.Sequential(
            nn.Conv2d(128, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            ) 

        self.side1_rgb  = nn.Sequential(
            nn.Conv2d(64, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            ) 


        
        self.dconv6_cross = decoder_block(16 , 16)
        self.dconv5_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv4_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv3_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv2_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv1_cross = decoder_block(16 + 16 + 16 , 16) 

        self.side6_cross = nn.Sequential(
            nn.Conv2d(512, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            ) 

        self.side5_cross = nn.Sequential(
            nn.Conv2d(512, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            ) 

        self.side4_cross = nn.Sequential(
            nn.Conv2d(512, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            ) 

        self.side3_cross = nn.Sequential(
            nn.Conv2d(256, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            ) 

        self.side2_cross = nn.Sequential(
            nn.Conv2d(128, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            ) 

        self.side1_cross = nn.Sequential(
            nn.Conv2d(64, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            ) 


        self.final = nn.Sequential(
            nn.Conv2d(32, self.num_classes, kernel_size=1, padding=0),
            nn.BatchNorm2d(self.num_classes),
            nn.ReLU(inplace=True),
            )
                

    def forward(self, x_rgb , x_dsm):
        
        # dsm_encoder
        
        y1 = self.enc_dsm1(x_dsm)         # bs * 64 * W/2 * H/2
        poold_dsm  = self.poold(y1)
        y1_side = self.side1_cross(poold_dsm)
       

        x1 = self.enc_rgb1(x_rgb)         # bs * 64 * W/2 * H/2
        poold  = self.poold(x1)
        x1_side = self.side1_rgb(poold)
        

        ##########################################################

        y2 = self.enc_dsm2(poold_dsm)         # bs * 64 * W/4 * H/4
        poold_dsm  = self.poold(y2)
        y2_side = self.side2_cross(poold_dsm)
        y2_e1 = F.interpolate(y2_side, scale_factor=2, mode='bilinear',align_corners=True)
        

        x2 = self.enc_rgb2(poold)         # bs * 64 * W/4 * H/4
        poold  = self.poold(x2)
        x2_side = self.side2_rgb(poold)
        x2_e1 = F.interpolate(x2_side, scale_factor=2, mode='bilinear',align_corners=True)
        
        

        ##########################################################

        y3 = self.enc_dsm3(poold_dsm)         # bs * 128 * W/8 * H/8
        poold_dsm  = self.poold(y3)
        y3_side = self.side3_cross(poold_dsm)
        y3_e2 = F.interpolate(y3_side, scale_factor=2, mode='bilinear',align_corners=True)
        y3_e1 = F.interpolate(y3_side, scale_factor=4, mode='bilinear',align_corners=True)


        x3 = self.enc_rgb3(poold)         # bs * 128 * W/8 * H/8
        poold  = self.poold(x3)
        x3_side = self.side3_rgb(poold)
        x3_e2 = F.interpolate(x3_side, scale_factor=2, mode='bilinear',align_corners=True)
        x3_e1 = F.interpolate(x3_side, scale_factor=4, mode='bilinear',align_corners=True)
        
        

        ##########################################################

        y4 = self.enc_dsm4(poold_dsm)         # bs * 256 * W/16 * H/16
        poold_dsm  = self.poold(y4)
        y4_side = self.side4_cross(poold_dsm)
        y4_e3 = F.interpolate(y4_side, scale_factor=2, mode='bilinear',align_corners=True)
        y4_e2 = F.interpolate(y4_side, scale_factor=4, mode='bilinear',align_corners=True)


        x4 = self.enc_rgb4(poold)         # bs * 256 * W/16 * H/16
        poold   = self.poold(x4)
        x4_side = self.side4_rgb(poold)
        x4_e3 = F.interpolate(x4_side, scale_factor=2, mode='bilinear',align_corners=True)
        x4_e2 = F.interpolate(x4_side, scale_factor=4, mode='bilinear',align_corners=True)
        

        ##########################################################

        y5 = self.enc_dsm5(poold_dsm)         # bs * 512 * W/16 * H/16
        poold_dsm  = self.poold(y5)
        y5_side = self.side5_cross(poold_dsm)
        y5_e4 = F.interpolate(y5_side, scale_factor=2, mode='bilinear',align_corners=True)
        y5_e3 = F.interpolate(y5_side, scale_factor=4, mode='bilinear',align_corners=True)
        

        x5 = self.enc_rgb5(poold)         # bs * 512 * W/16 * H/16
        poold  = self.poold(x5)
        x5_side = self.side5_rgb(poold)
        x5_e4 = F.interpolate(x5_side, scale_factor=2, mode='bilinear',align_corners=True)
        x5_e3 = F.interpolate(x5_side, scale_factor=4, mode='bilinear',align_corners=True)
        
        ##########################################################

        y6 =  self.pool(poold_dsm)
        y6_side = self.side6_cross(y6)
        y6_e5 = F.interpolate(y6_side, scale_factor=2, mode='bilinear',align_corners=True)
        y6_e4 = F.interpolate(y6_side, scale_factor=4, mode='bilinear',align_corners=True)

        out_dsm1 = self.dconv6_cross(y6_side)
         

        x6 =  self.pool(poold)
        x6_side = self.side6_rgb(x6)
        x6_e5 = F.interpolate(x6_side, scale_factor=2, mode='bilinear',align_corners=True)
        x6_e4 = F.interpolate(x6_side, scale_factor=4, mode='bilinear',align_corners=True)

        out_rgb1 = self.dconv6_rgb(x6_side)   


        ##########################################################


        FG = torch.cat(((x5_side + x6_e5)/2, out_rgb1),dim=1)      
        out_rgb2 = self.dconv5_rgb(FG) 


        FG_cross = self.gate5((x5_side + x6_e5)/2 , (y5_side + y6_e5)/2)
        FG_dsm = torch.cat((FG_cross, out_dsm1),dim=1)
        out_dsm2 = self.dconv5_cross(FG_dsm) 
         


        ##########################################################


        FG = torch.cat(((x4_side + x6_e4 + x5_e4)/3 ,out_rgb2),dim=1)      
        out_rgb3 = self.dconv4_rgb(FG) 


        FG_cross = self.gate4((x4_side + x6_e4 + x5_e4)/3, (y4_side + y6_e4 + y5_e4)/3)
        FG_dsm = torch.cat((FG_cross, out_dsm2),dim=1)  
        out_dsm3 = self.dconv4_cross(FG_dsm) 
            

        ##########################################################


        FG = torch.cat(((x3_side + x5_e3 + x4_e3)/3 ,out_rgb3),dim=1)      
        out_rgb4 = self.dconv3_rgb(FG)  


        FG_cross = self.gate3((x3_side + x5_e3 + x4_e3)/3, (y3_side + y5_e3 + y4_e3)/3)
        FG_dsm = torch.cat((FG_cross, out_dsm3),dim=1) 
        out_dsm4 = self.dconv3_cross(FG_dsm)   
        

        ##########################################################


        FG = torch.cat(((x2_side + x4_e2 + x3_e2)/3 ,out_rgb4),dim=1)      
        out_rgb5 = self.dconv2_rgb(FG)   

        
        FG_cross = self.gate2((x2_side + x4_e2 + x3_e2)/3 , (y2_side + y4_e2 + y3_e2)/3)
        FG_dsm = torch.cat((FG_cross, out_dsm4),dim=1)
        out_dsm5 = self.dconv2_cross(FG_dsm)    
        

        ##########################################################

        FG = torch.cat(((x1_side + x3_e1 + x2_e1)/3 ,out_rgb5),dim=1)      
        out_rgb6 = self.dconv1_rgb(FG)   


        
        FG_cross = self.gate1((x1_side + x3_e1 + x2_e1)/3, (y1_side + y3_e1 + y2_e1)/3)
        FG_dsm = torch.cat((FG_cross, out_dsm5),dim=1)  
        out_dsm6 = self.dconv1_cross(FG_dsm) 
        

        ##########################################################

        final = self.gate_final(out_rgb6, out_dsm6)            
        final = self.final(final)

        

        return final


In [3]:
x = torch.rand(1,3,640,640).to(device)
y = torch.rand(1,1,640,640).to(device)

model = FuseNet(2,True).to(device)

out = model(x,y)

out.shape

torch.Size([1, 2, 640, 640])

In [4]:
def term_construct(input_tensor):

  return {'x_rgb':input_tensor[0] , 'x_dsm': input_tensor[1]}


macs, params = get_model_complexity_info(model,
                                         input_res=(x,y),
                                         as_strings=True,
                                         input_constructor = term_construct,
                                         print_per_layer_stat=False,
                                         verbose=True)

print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))

Computational complexity:       259.41 GMac
Number of parameters:           30.06 M 


##Shufflenet

In [None]:
        self.densenet_features = torchvision.models.densenet.densenet121(pretrained=pretrained)
        
        self.enc_rgb1 = nn.Sequential(self.densenet_features.features.conv0,
                         self.densenet_features.features.norm0,
                         self.densenet_features.features.relu0,)
        
        self.enc_rgb2 = nn.Sequential(self.densenet_features.features.pool0,
                        self.densenet_features.features.denseblock1)
        
        self.transition2 = self.densenet_features.features.transition1
        self.enc_rgb3 = self.densenet_features.features.denseblock2

        self.transition3 = self.densenet_features.features.transition2
        self.enc_rgb4 = self.densenet_features.features.denseblock3

        self.transition4 = self.densenet_features.features.transition3
        self.enc_rgb5 = self.densenet_features.features.denseblock4

In [23]:
shufflenet_features = torchvision.models.shufflenetv2.shufflenet_v2_x1_0(pretrained=False)
shufflenet_features.conv1[0]

Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

In [22]:
shufflenet_features = torchvision.models.shufflenetv2.shufflenet_v2_x1_0(pretrained=False)

avg = torch.mean(shufflenet_features.conv1[0].weight.data,dim=1)
avg = avg.unsqueeze(1)
conv1d = nn.Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
conv1d.weight.data = avg
shufflenet_features.conv1[0] = conv1d

In [24]:
shufflenet_features.conv5

Sequential(
  (0): Conv2d(464, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
)

In [None]:
self.encoder_depth = torchvision.models.shufflenetv2.shufflenet_v2_x1_0(pretrained=False)

avg = torch.mean(self.encoder_depth.conv1[0].weight.data,dim=1)
avg = avg.unsqueeze(1)
conv1d = nn.Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
conv1d.weight.data = avg
self.encoder_depth.conv1[0] = conv1d

self.enc_rgb1 = self.encoder_depth.conv1
self.enc_rgb2 = nn.Sequential(self.encoder_depth.maxpool,
                              self.encoder_depth.stage2)
self.enc_rgb3 = self.encoder_depth.stage3
self.enc_rgb4 = self.encoder_depth.stage4
self.enc_rgb5 = self.encoder_depth.conv5

In [None]:
self.shufflenet_features = torchvision.models.shufflenetv2.shufflenet_v2_x1_0(pretrained=False)
self.enc_rgb1 = self.shufflenet_features.conv1
self.enc_rgb2 = nn.Sequential(self.shufflenet_features.maxpool,
                              self.shufflenet_features.stage2)
self.enc_rgb3 = self.shufflenet_features.stage3
self.enc_rgb4 = self.shufflenet_features.stage4
self.enc_rgb5 = self.shufflenet_features.conv5

In [12]:
# DSM Encoder Part
        self.encoder_depth = torchvision.models.shufflenetv2.shufflenet_v2_x1_0(pretrained=False)

        avg = torch.mean(self.densenet_features.features.conv0.weight.data,dim=1)
        avg = avg.unsqueeze(1)
        conv1d = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        conv1d.weight.data = avg
        self.encoder_depth.conv1 = conv1d
        
        self.enc_dsm1 = nn.Sequential(self.encoder_depth.conv1,
                                      self.encoder_depth.features.norm0,
                                      self.encoder_depth.features.relu0,)
        
        self.enc_dsm2 = nn.Sequential(self.encoder_depth.features.pool0,
                                      self.encoder_depth.features.denseblock1)
        
        self.transition2_dsm = self.encoder_depth.features.transition1
        self.enc_dsm3 = self.encoder_depth.features.denseblock2

        self.transition3_dsm = self.encoder_depth.features.transition2
        self.enc_dsm4 = self.encoder_depth.features.denseblock3

        self.transition4_dsm = self.encoder_depth.features.transition3
        self.enc_dsm5 = self.encoder_depth.features.denseblock4

Downloading: "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth" to /root/.cache/torch/hub/checkpoints/shufflenetv2_x1-5666bf0f80.pth


HBox(children=(FloatProgress(value=0.0, max=9218294.0), HTML(value='')))




In [18]:
#version 02
class Gated_Fusion(nn.Module):
    
    def __init__(self, in_channels):
        super().__init__()
        
        self.gate = nn.Sequential(            
            nn.Conv2d(2 * in_channels, in_channels,kernel_size=1, padding=0),
            nn.Sigmoid(),
            )
        
    def forward(self, x,y):
      out = torch.cat([x,y], dim=1)
      G = self.gate(out)
      
      PG = x * G
      FG = y * (1-G)

      
      return torch.cat([FG , PG], dim=1)

class Upsample(nn.Module):
    
    def __init__(self, scale_factor, mode="nearest"):
        super(Upsample, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode,align_corners=True)
        return x

class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout,kernel_size=3, padding=1):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size = kernel_size, padding=padding, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out    

class decoder_block(nn.Module):
    def __init__(self, 
                 input_channels, 
                 output_channels,
                 scale=2):
        
        super(decoder_block, self).__init__()
        
        self.identity = nn.Sequential(
            Upsample(scale_factor=scale, mode="bilinear"),
            nn.Conv2d(input_channels, output_channels, kernel_size=1, padding=0)
            )

        self.decode = nn.Sequential(
            Upsample(scale_factor=scale, mode="bilinear"),
            nn.BatchNorm2d(input_channels),
            depthwise_separable_conv(input_channels,input_channels),
            nn.BatchNorm2d(input_channels),
            nn.ReLU(inplace=True),
            depthwise_separable_conv(input_channels,output_channels),
            nn.BatchNorm2d(output_channels),
            )
        
   
    def forward(self,x):
      
      residual = self.identity(x)
      
      out = self.decode(x)

      out += residual

      return out

class FuseNet_shuffle(nn.Module):
    
    def __init__(self, num_classes, pretrained=False, is_deconve=False):
        
        super().__init__()
        
        self.num_classes = num_classes
        self.pretrained = pretrained
        
        # RGB Encoder Part

        self.shufflenet_features = torchvision.models.shufflenetv2.shufflenet_v2_x1_0(pretrained=pretrained)
        self.enc_rgb1 = self.shufflenet_features.conv1
        self.enc_rgb2 = nn.Sequential(self.shufflenet_features.maxpool,
                              self.shufflenet_features.stage2)
        self.enc_rgb3 = self.shufflenet_features.stage3
        self.enc_rgb4 = self.shufflenet_features.stage4
        self.enc_rgb5 = self.shufflenet_features.conv5
      
        # DSM Encoder Part
        self.encoder_depth = torchvision.models.shufflenetv2.shufflenet_v2_x1_0(pretrained=pretrained)

        avg = torch.mean(self.encoder_depth.conv1[0].weight.data,dim=1)
        avg = avg.unsqueeze(1)
        conv1d = nn.Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        conv1d.weight.data = avg
        self.encoder_depth.conv1[0] = conv1d

        self.enc_dsm1 = self.encoder_depth.conv1
        self.enc_dsm2 = nn.Sequential(self.encoder_depth.maxpool,
                                      self.encoder_depth.stage2)
        self.enc_dsm3 = self.encoder_depth.stage3
        self.enc_dsm4 = self.encoder_depth.stage4
        self.enc_dsm5 = self.encoder_depth.conv5

        self.pool = nn.MaxPool2d(2)

        self.gate5 = Gated_Fusion(16)
        self.gate4 = Gated_Fusion(16)
        self.gate3 = Gated_Fusion(16)
        self.gate2 = Gated_Fusion(16)
        self.gate1 = Gated_Fusion(16)

        self.gate_final = Gated_Fusion(16)
       

        self.dconv6_rgb = decoder_block(16 , 16)
        self.dconv5_rgb = decoder_block(16 + 16 , 16, scale=1) 
        self.dconv4_rgb = decoder_block(16 + 16 , 16) 
        self.dconv3_rgb = decoder_block(16 + 16 , 16,scale=2) 
        self.dconv2_rgb = decoder_block(16 + 16 , 16,scale=4) 
        self.dconv1_rgb = decoder_block(16 + 16 , 16) 

        self.side6_rgb  = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side5_rgb  = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side4_rgb  = nn.Conv2d(464, 16, kernel_size=1, padding=0)
        self.side3_rgb  = nn.Conv2d(232, 16, kernel_size=1, padding=0)
        self.side2_rgb  = nn.Conv2d(116, 16, kernel_size=1, padding=0)
        self.side1_rgb  = nn.Conv2d(24, 16, kernel_size=1, padding=0)

        
        self.dconv6_cross = decoder_block(16 , 16)
        self.dconv5_cross = decoder_block(16 + 16 + 16 , 16,scale =1) 
        self.dconv4_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv3_cross = decoder_block(16 + 16 + 16 , 16,scale=2) 
        self.dconv2_cross = decoder_block(16 + 16 + 16 , 16,scale=4) 
        self.dconv1_cross = decoder_block(16 + 16 + 16 , 16) 

        self.side6_cross = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side5_cross = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side4_cross = nn.Conv2d(464, 16, kernel_size=1, padding=0)
        self.side3_cross = nn.Conv2d(232, 16, kernel_size=1, padding=0)
        self.side2_cross = nn.Conv2d(116, 16, kernel_size=1, padding=0)
        self.side1_cross = nn.Conv2d(24, 16, kernel_size=1, padding=0)


        self.final = nn.Sequential(
            nn.Conv2d(32, self.num_classes, kernel_size=1, padding=0),
            nn.BatchNorm2d(self.num_classes),
            nn.ReLU(inplace=True),
            )
        
  
        self.final_dsm1 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm2 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm3 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm4 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm5 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        self.final_dsm6 = nn.Sequential(
            nn.Conv2d(16, 2, kernel_size=1, padding=0),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True),
            )
        

    def forward(self, x_rgb , x_dsm):
        
        # dsm_encoder
        
        y1 = self.enc_dsm1(x_dsm)         
        y1_side = self.side1_cross(y1)
       

        x1 = self.enc_rgb1(x_rgb)         # bs * 24 * W/2 * H/2
        x1_side = self.side1_rgb(x1)      # bs * 16 * 320 * 320

        ##########################################################

        y2 = self.enc_dsm2(y1)         
        y2_side = self.side2_cross(y2)
        

        x2 = self.enc_rgb2(x1)         # bs * 116 * W/4 * H/4
        x2_side = self.side2_rgb(x2)   # bs * 16 * 80 * 80

        ##########################################################

        y3 = self.enc_dsm3(y2)         # bs * 512 * W/8 * H/8
        y3_side = self.side3_cross(y3)
        

        x3 = self.enc_rgb3(x2)         # bs * 232 * W/8 * H/8
        x3_side = self.side3_rgb(x3)   # bs * 16 * 40 * 40


        ##########################################################
        y4 = self.enc_dsm4(y3)         
        y4_side = self.side4_cross(y4)
        

        x4 = self.enc_rgb4(x3)         # bs * 464 * W/16 * H/16
        x4_side = self.side4_rgb(x4)   # bs * 16 * 20 * 20

        
        ##########################################################
        y5 = self.enc_dsm5(y4)         
        y5_side = self.side5_cross(y5)
        
        
        x5 = self.enc_rgb5(x4)         # bs * 1024 * W/16 * H/16
        x5_side = self.side5_rgb(x5)   # bs * 16 * 20 * 20

        ##########################################################

        y6 =  self.pool(y5)
        y6_side = self.side6_cross(y6)
        

        out_dsm1 = self.dconv6_cross(y6_side)
        out_dsm1_out = self.final_dsm1(out_dsm1)    

        x6 =  self.pool(x5)
        x6_side = self.side6_rgb(x6)
        

        out_rgb1 = self.dconv6_rgb(x6_side)   

        ##########################################################

        

        FG = torch.cat((x5_side , out_rgb1),dim=1)      
        out_rgb2 = self.dconv5_rgb(FG) 


        FG_cross = self.gate5(x5_side , y5_side)
        FG_dsm = torch.cat((FG_cross, out_dsm1),dim=1)
        out_dsm2 = self.dconv5_cross(FG_dsm) 
        out_dsm2_out = self.final_dsm2(out_dsm2)    


        ##########################################################


        FG = torch.cat((x4_side  ,out_rgb2),dim=1)      
        out_rgb3 = self.dconv4_rgb(FG) 


        FG_cross = self.gate4(x4_side , y4_side)
        FG_dsm = torch.cat((FG_cross, out_dsm2),dim=1)  
        out_dsm3 = self.dconv4_cross(FG_dsm) 
        out_dsm3_out = self.final_dsm3(out_dsm3)     

        ##########################################################


        FG = torch.cat((x3_side ,out_rgb3),dim=1)      
        out_rgb4 = self.dconv3_rgb(FG)  


        FG_cross = self.gate3(x3_side , y3_side )
        FG_dsm = torch.cat((FG_cross, out_dsm3),dim=1) 
        out_dsm4 = self.dconv3_cross(FG_dsm)   
        out_dsm4_out = self.final_dsm4(out_dsm4)  
        
        ##########################################################

   
        FG = torch.cat((x2_side  ,out_rgb4),dim=1)      
        out_rgb5 = self.dconv2_rgb(FG)   

        
        FG_cross = self.gate2(x2_side , y2_side )
        FG_dsm = torch.cat((FG_cross, out_dsm4),dim=1)
        out_dsm5 = self.dconv2_cross(FG_dsm)    
        out_dsm5_out = self.final_dsm5(out_dsm5)  

       
        ##########################################################

        FG = torch.cat((x1_side ,out_rgb5),dim=1)      
        out_rgb6 = self.dconv1_rgb(FG)   


        FG_cross = self.gate1(x1_side , y1_side)
        FG_dsm = torch.cat((FG_cross, out_dsm5),dim=1)  
        out_dsm6 = self.dconv1_cross(FG_dsm) 
        out_dsm6_out = self.final_dsm6(out_dsm6)
    
        ##########################################################

        final = self.gate_final(out_rgb6, out_dsm6)            
        final = self.final(final)

        f1 = F.interpolate(out_dsm1_out, scale_factor=32,mode='bilinear', align_corners=True )
        f2 = F.interpolate(out_dsm2_out, scale_factor=32,mode='bilinear', align_corners=True )
        f3 = F.interpolate(out_dsm3_out, scale_factor=16 ,mode='bilinear', align_corners=True )
        f4 = F.interpolate(out_dsm4_out, scale_factor=8 ,mode='bilinear', align_corners=True )
        f5 = F.interpolate(out_dsm5_out, scale_factor=2 ,mode='bilinear', align_corners=True )

        final_help = (f1 + f2 + f3 + f4 + f5 + out_dsm6_out)/6


        a1 = [out_dsm1_out,out_dsm2_out,out_dsm3_out,out_dsm4_out,out_dsm5_out,out_dsm6_out]
        

        return final, final_help, a1



In [19]:
x = torch.rand(1,3,640,640)
y = torch.rand(1,1,640,640)

model = FuseNet_shuffle(2,True)

out,_,a1= model(x,y)

out.shape

torch.Size([1, 2, 640, 640])

##Shufflenet_Paper

In [2]:
#version 02
class Gated_Fusion(nn.Module):
    
    def __init__(self, in_channels):
        super().__init__()
        
        self.gate = nn.Sequential(            
            nn.Conv2d(2 * in_channels, in_channels,kernel_size=1, padding=0),
            nn.Sigmoid(),
            )
        
    def forward(self, x,y):
      out = torch.cat([x,y], dim=1)
      G = self.gate(out)
      
      PG = x * G
      FG = y * (1-G)

      
      return torch.cat([FG , PG], dim=1)

class Upsample(nn.Module):
    
    def __init__(self, scale_factor, mode="nearest"):
        super(Upsample, self).__init__()
        self.scale_factor = scale_factor
        self.mode = mode

    def forward(self, x):
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode,align_corners=True)
        return x

class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout,kernel_size=3, padding=1):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin, kernel_size = kernel_size, padding=padding, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out    

class decoder_block(nn.Module):
    def __init__(self, 
                 input_channels, 
                 output_channels,
                 scale=2):
        
        super(decoder_block, self).__init__()
        
        self.identity = nn.Sequential(
            Upsample(scale_factor=scale, mode="bilinear"),
            nn.Conv2d(input_channels, output_channels, kernel_size=1, padding=0)
            )

        self.decode = nn.Sequential(
            Upsample(scale_factor=scale, mode="bilinear"),
            nn.BatchNorm2d(input_channels),
            depthwise_separable_conv(input_channels,input_channels),
            nn.BatchNorm2d(input_channels),
            nn.ReLU(inplace=True),
            depthwise_separable_conv(input_channels,output_channels),
            nn.BatchNorm2d(output_channels),
            )
        
   
    def forward(self,x):
      
      residual = self.identity(x)
      
      out = self.decode(x)

      out += residual

      return out

class FuseNet_shuffle(nn.Module):
    
    def __init__(self, num_classes, pretrained=False, is_deconve=False):
        
        super().__init__()
        
        self.num_classes = num_classes
        self.pretrained = pretrained
        
        # RGB Encoder Part

        self.shufflenet_features = torchvision.models.shufflenetv2.shufflenet_v2_x1_0(pretrained=pretrained)
        self.enc_rgb1 = self.shufflenet_features.conv1
        self.enc_rgb2 = nn.Sequential(self.shufflenet_features.maxpool,
                              self.shufflenet_features.stage2)
        self.enc_rgb3 = self.shufflenet_features.stage3
        self.enc_rgb4 = self.shufflenet_features.stage4
        self.enc_rgb5 = self.shufflenet_features.conv5
      
        # DSM Encoder Part
        self.encoder_depth = torchvision.models.shufflenetv2.shufflenet_v2_x1_0(pretrained=pretrained)

        avg = torch.mean(self.encoder_depth.conv1[0].weight.data,dim=1)
        avg = avg.unsqueeze(1)
        conv1d = nn.Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        conv1d.weight.data = avg
        self.encoder_depth.conv1[0] = conv1d

        self.enc_dsm1 = self.encoder_depth.conv1
        self.enc_dsm2 = nn.Sequential(self.encoder_depth.maxpool,
                                      self.encoder_depth.stage2)
        self.enc_dsm3 = self.encoder_depth.stage3
        self.enc_dsm4 = self.encoder_depth.stage4
        self.enc_dsm5 = self.encoder_depth.conv5

        self.pool = nn.MaxPool2d(2)

        self.gate5 = Gated_Fusion(16)
        self.gate4 = Gated_Fusion(16)
        self.gate3 = Gated_Fusion(16)
        self.gate2 = Gated_Fusion(16)
        self.gate1 = Gated_Fusion(16)

        self.gate_final = Gated_Fusion(16)
       

        self.dconv6_rgb = decoder_block(16 , 16)
        self.dconv5_rgb = decoder_block(16 + 16 , 16, scale=1) 
        self.dconv4_rgb = decoder_block(16 + 16 , 16) 
        self.dconv3_rgb = decoder_block(16 + 16 , 16,scale=2) 
        self.dconv2_rgb = decoder_block(16 + 16 , 16,scale=4) 
        self.dconv1_rgb = decoder_block(16 + 16 , 16) 

        self.side6_rgb  = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side5_rgb  = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side4_rgb  = nn.Conv2d(464, 16, kernel_size=1, padding=0)
        self.side3_rgb  = nn.Conv2d(232, 16, kernel_size=1, padding=0)
        self.side2_rgb  = nn.Conv2d(116, 16, kernel_size=1, padding=0)
        self.side1_rgb  = nn.Conv2d(24, 16, kernel_size=1, padding=0)

        
        self.dconv6_cross = decoder_block(16 , 16)
        self.dconv5_cross = decoder_block(16 + 16 + 16 , 16,scale =1) 
        self.dconv4_cross = decoder_block(16 + 16 + 16 , 16) 
        self.dconv3_cross = decoder_block(16 + 16 + 16 , 16,scale=2) 
        self.dconv2_cross = decoder_block(16 + 16 + 16 , 16,scale=4) 
        self.dconv1_cross = decoder_block(16 + 16 + 16 , 16) 

        self.side6_cross = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side5_cross = nn.Conv2d(1024, 16, kernel_size=1, padding=0)
        self.side4_cross = nn.Conv2d(464, 16, kernel_size=1, padding=0)
        self.side3_cross = nn.Conv2d(232, 16, kernel_size=1, padding=0)
        self.side2_cross = nn.Conv2d(116, 16, kernel_size=1, padding=0)
        self.side1_cross = nn.Conv2d(24, 16, kernel_size=1, padding=0)


        self.final = nn.Sequential(
            nn.Conv2d(32, self.num_classes, kernel_size=1, padding=0),
            nn.BatchNorm2d(self.num_classes),
            nn.ReLU(inplace=True),
            )
        


    def forward(self, x_rgb , x_dsm):
        
        # dsm_encoder
        
        y1 = self.enc_dsm1(x_dsm)         
        y1_side = self.side1_cross(y1)
       

        x1 = self.enc_rgb1(x_rgb)         # bs * 24 * W/2 * H/2
        x1_side = self.side1_rgb(x1)      # bs * 16 * 320 * 320

        ##########################################################

        y2 = self.enc_dsm2(y1)         
        y2_side = self.side2_cross(y2)
        

        x2 = self.enc_rgb2(x1)         # bs * 116 * W/4 * H/4
        x2_side = self.side2_rgb(x2)   # bs * 16 * 80 * 80

        ##########################################################

        y3 = self.enc_dsm3(y2)         # bs * 512 * W/8 * H/8
        y3_side = self.side3_cross(y3)
        

        x3 = self.enc_rgb3(x2)         # bs * 232 * W/8 * H/8
        x3_side = self.side3_rgb(x3)   # bs * 16 * 40 * 40


        ##########################################################
        y4 = self.enc_dsm4(y3)         
        y4_side = self.side4_cross(y4)
        

        x4 = self.enc_rgb4(x3)         # bs * 464 * W/16 * H/16
        x4_side = self.side4_rgb(x4)   # bs * 16 * 20 * 20

        
        ##########################################################
        y5 = self.enc_dsm5(y4)         
        y5_side = self.side5_cross(y5)
        
        
        x5 = self.enc_rgb5(x4)         # bs * 1024 * W/16 * H/16
        x5_side = self.side5_rgb(x5)   # bs * 16 * 20 * 20

        ##########################################################

        y6 =  self.pool(y5)
        y6_side = self.side6_cross(y6)
        

        out_dsm1 = self.dconv6_cross(y6_side)
        

        x6 =  self.pool(x5)
        x6_side = self.side6_rgb(x6)
        

        out_rgb1 = self.dconv6_rgb(x6_side)   

        ##########################################################

        

        FG = torch.cat((x5_side , out_rgb1),dim=1)      
        out_rgb2 = self.dconv5_rgb(FG) 


        FG_cross = self.gate5(x5_side , y5_side)
        FG_dsm = torch.cat((FG_cross, out_dsm1),dim=1)
        out_dsm2 = self.dconv5_cross(FG_dsm) 
     

        ##########################################################


        FG = torch.cat((x4_side  ,out_rgb2),dim=1)      
        out_rgb3 = self.dconv4_rgb(FG) 


        FG_cross = self.gate4(x4_side , y4_side)
        FG_dsm = torch.cat((FG_cross, out_dsm2),dim=1)  
        out_dsm3 = self.dconv4_cross(FG_dsm) 
         

        ##########################################################


        FG = torch.cat((x3_side ,out_rgb3),dim=1)      
        out_rgb4 = self.dconv3_rgb(FG)  


        FG_cross = self.gate3(x3_side , y3_side )
        FG_dsm = torch.cat((FG_cross, out_dsm3),dim=1) 
        out_dsm4 = self.dconv3_cross(FG_dsm)   
        
        
        ##########################################################

   
        FG = torch.cat((x2_side  ,out_rgb4),dim=1)      
        out_rgb5 = self.dconv2_rgb(FG)   

        
        FG_cross = self.gate2(x2_side , y2_side )
        FG_dsm = torch.cat((FG_cross, out_dsm4),dim=1)
        out_dsm5 = self.dconv2_cross(FG_dsm)    
        

       
        ##########################################################

        FG = torch.cat((x1_side ,out_rgb5),dim=1)      
        out_rgb6 = self.dconv1_rgb(FG)   


        FG_cross = self.gate1(x1_side , y1_side)
        FG_dsm = torch.cat((FG_cross, out_dsm5),dim=1)  
        out_dsm6 = self.dconv1_cross(FG_dsm) 
       
    
        ##########################################################

        final = self.gate_final(out_rgb6, out_dsm6)            
        final = self.final(final)

        

        return final



In [3]:
x = torch.rand(1,3,640,640)
y = torch.rand(1,1,640,640)

model = FuseNet_shuffle(2,True)

out = model(x,y)

out.shape

torch.Size([1, 2, 640, 640])

In [4]:
def term_construct(input_tensor):

  return {'x_rgb':input_tensor[0] , 'x_dsm': input_tensor[1]}


macs, params = get_model_complexity_info(model,
                                         input_res=(x,y),
                                         as_strings=True,
                                         input_constructor = term_construct,
                                         print_per_layer_stat=False,
                                         verbose=True)

print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
print('{:<30}  {:<8}'.format('Number of parameters: ', params))

Computational complexity:       7.06 GMac
Number of parameters:           4.7 M   
