In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Module
from torch.nn import init


In [23]:
class Att_Module(Module):
    """Attention module is implemented inspired by the following article:
       https://arxiv.org/abs/1807.06521. 
       CBAM attention module incorporates both dorsal (where) and ventral (what) stram of attention mechanism.
       base_n_filter (int): Number of base filters in the U-Net.
       r_ratio (int): reduction ratio that influences the hidden layer neuron number. it reduces the number of neurons 
                      by a factor of r_ration, the unreduced number of Neurons is equal to the number of channels in 
                      the batch
       att_kernel_size (int):  indicated the kernel size fo the convolution filtering in spatial attention (size in paper=7)
       """
    
    def __init__(self,base_n_filter = 64, r_ratio=16, att_kernel_size=7):
        
        super(Att_Module, self).__init__()
        
        self.in_Feat = base_n_filter # input feature tensor
        self.R_ratio = r_ratio # reduction ratio for FCN
        self.Att_K = att_kernel_size    
        self.GAP = nn.AdaptiveAvgPool2d((1,1)) # for global average pooling
        self.GMP =  nn.AdaptiveMaxPool2d((1,1)) # for global max pooling
        self.FCN_1 = nn.Linear(self.in_Feat,self.in_Feat//self.R_ratio)
        self.FCN_2 = nn.Linear(self.in_Feat//self.R_ratio,self.in_Feat)        
        self.CNN = nn.Conv2d(2,1,kernel_size=self.Att_K,padding=(self.Att_K-1)//2)        
        
        
    def forward(self, feat_tensor):
        
        ################## CHANNLE ATTENTION #################
        x = self.GAP(feat_tensor)
        x = self.FCN_1(x.view(-1,x.size()[1])) # shape match for FCN
        x = self.FCN_2(x)
        
        y = self.GMP(feat_tensor)
        y = self.FCN_1(y.view(-1,y.size()[1])) # shape match for FCN
        y = self.FCN_2(y)
        
        ch_out = F.sigmoid(x+y)
        ch_out = ch_out.unsqueeze(dim=2).unsqueeze(dim=3) # expected dimension fixed
        ch_out = ch_out.repeat(1,1,feat_tensor.size()[2],feat_tensor.size()[3]) # repeat a vector to expected batch size   
              
        feat_tensor_ch = ch_out*feat_tensor # elemenwise multiplication of input to channel att
        
        ####################### Spatial attention ###############
        max_spat, indices = torch.max(feat_tensor_ch.view(-1,feat_tensor_ch.size()[1]), dim=1) # INPUT IS 
        max_spat = max_spat.view(-1,feat_tensor_ch.size()[2],feat_tensor_ch.size()[3]) # max_pooling on channels
        
        avg_spat = torch.mean(feat_tensor_ch.view(-1,feat_tensor_ch.size()[1]),
                            dim=1).view(-1,feat_tensor_ch.size()[2],feat_tensor_ch.size()[3])# avg_pool on channles
        
        avg_spat = avg_spat.unsqueeze(dim=1) 
        max_spat = max_spat.unsqueeze(dim=1)
        
        z = torch.cat([max_spat,avg_spat], dim=1)
        
        spat_att = F.sigmoid(self.CNN(z))
        spat_att = spat_att.repeat(1,feat_tensor.size()[1],1,1) # repeat a tensor to expected batch size (num samples) 
        feat_tensor_ch_spat = spat_att*feat_tensor # elemenwise multiplication of input to channel-spatial att
        
        ######################## FINAL ATTENTION MODULE ############
        Final_att = feat_tensor_ch_spat + feat_tensor    
        
        print(f"The input batch size is {feat_tensor.size()}")  

        print(f"The output attention batch size is {Final_att.size()}")  
        
        return Final_att
         

In [24]:
def double_conv(in_c, out_c, kernel_size):
    ## DEFINING DOUBLE CONVOLUTION
    conv = nn.Sequential(nn.Conv2d(in_c,out_c,kernel_size, padding=(kernel_size-1)//2),
              nn.ReLU(inplace=True),
              nn.Conv2d(out_c,out_c,kernel_size, padding=(kernel_size-1)//2),
              nn.ReLU(inplace=True))
    return conv

In [25]:
class ENCODER(Module):
    def __init__(self,in_channel=1, depth=3, n_filters=64, kernel_size=3, attention=True, r_ratio=16, att_kernel_size=7):
        
        super(ENCODER, self).__init__()
        
        self.Depth =depth
        self.base_n_filter =n_filters
        self.R_ratio = r_ratio
        self.Att_K = att_kernel_size 
        maxpool2D = nn.MaxPool2d(2,2)                                     
        self.Down_path = nn.ModuleList() # CREATING A LIST FOR ENCODER LAYERS
        self.Down_path.append(double_conv(in_channel,n_filters,kernel_size)) # FIRST CONV2d LAYER
        self.Down_path.append(maxpool2D) #FIRST MAXPOOL LAYER
        self.attention = attention
        
        for i in range(depth-1):
            
            self.Down_path.append(double_conv(n_filters*(2**i),n_filters*(2**(i+1)),kernel_size))
            self.Down_path.append(maxpool2D)
            
        self.conv_bottom = double_conv(n_filters*(2**(self.Depth-1)),n_filters*(2**(self.Depth)),kernel_size)
        
    def forward(self, image):
                
        feature = []
        #FIRST BLOCK                                                            
        x = self.Down_path[0](image)
        print(x.size())
        
        if self.attention:
            self.attention = Att_Module(base_n_filter =self.base_n_filter, r_ratio=self.R_ratio, att_kernel_size=self.Att_K)
            x = self.attention(x)+x
            
        else:
            pass
            
        feature.append(x)
        x = self.Down_path[1](x)
        
        # Down-sampling path
        for j in range(1,self.Depth):
            x = self.Down_path[2*j](x)
            if self.attention:
                
                self.attention = Att_Module(base_n_filter =self.base_n_filter*(2**j),r_ratio=self.R_ratio*(2**j),
                                            att_kernel_size=self.Att_K)
                x = self.attention(x)+x
              
            else:
                pass
            feature.append(x)
            print(x.size())
            x = self.Down_path[2*j+1](x)
            
        # Bottom level
        x = self.conv_bottom(x)
        feature.append(x) # save the bottom layer in feature list
        print(x.size())
        return feature
      
    

In [26]:
class DECODER(Module):
    def __init__(self,out_channel=1, depth=3, n_filters=64, kernel_size=3):
        super(DECODER,self).__init__()
        
        self.Up_path = nn.ModuleList() # CREATING A LIST FOR DECODER LAYERS
        self.Depth = depth
        self.Out = out_channel
        
        for h in range(0,self.Depth):
            
            self.Up_path.append(nn.ConvTranspose2d((n_filters*(2**self.Depth))//(2**h), 
                                                   (n_filters*(2**(self.Depth-1)))//(2**h),kernel_size,stride=2,
                                                   padding=(kernel_size-1)//2))
                                
            self.Up_path.append(double_conv(n_filters*(2**self.Depth)//(2**h),
                                            n_filters*(2**(self.Depth-1))//(2**h),kernel_size))
            
        self.conv_top = double_conv(n_filters,self.Out,kernel_size)
        
    def forward(self, feature):
        
        x = feature[-1]
        
        for i in range(0,self.Depth):
                        
            x = self.Up_path[2*i](x)
          
            feature_T = torch.tensor(feature[(self.Depth-1)-i])
            x = F.interpolate(x,size=(feature_T.size()[2],feature_T.size()[3]),mode='bilinear')
            x = torch.cat([x,feature[(self.Depth-1)-i]], dim =1)
            x = self.Up_path[2*i+1](x)
        
        out = F.relu(self.conv_top(x)) # different kinds of activation function can be used
        print(f"the ouput size is {out.size()}")
        return out
        

In [27]:
class UNet_M(Module):
    def __init__(self, in_channel=1, out_channel=2, depth=3, n_filters=64, kernel_size=3, attention=True, 
                r_ratio=16, att_kernel_size=7):
        
        """
        in_channel (int): Number of channels in the input image.
        out_channel (int): Number of channels in the output image.
        depth (int): Number of down convolutions minus bottom down convolution.
        n_filters (int): Number of base filters in the U-Net.
        kernel_size (int):kernel size for convolution process in the U-Net (paper kernel_size=3)
        attention (bool):If attention=True, attention Module is added in the decoder to help focus 
        attention on important features. Code related to the attentions gates is inspired from:
        https://github.com/ozan-oktay/Attention-Gated-Networks
        r_ratio (int): reduction ratio that influences the hidden layer neuron number. it reduces the 
        number of neurons by factor of r_ratio 
        att_kernel_size (int):  indicated the kernel size fo the convolution filtering in spatial 
        attention (size in paper=7)

        
        
        """
        
        super(UNet_M,self).__init__()
        self.encoder = ENCODER(in_channel=in_channel, depth=depth, n_filters=n_filters, kernel_size=kernel_size
                               ,attention= attention, r_ratio=r_ratio, att_kernel_size=att_kernel_size)
        self.decoder = DECODER(out_channel=out_channel, depth=depth, n_filters=n_filters, kernel_size=kernel_size)
        
    def forward(self,x):
        feat = self.encoder(x)
        pred = self.decoder(feat)
        
        print(pred.size())
        return pred
        
    
if __name__=="__main__":
    image = torch.rand((2,1,322,322)) # batch, Channle, W, H as input
    model = UNet_M()
    print(model(image))
        

torch.Size([2, 64, 322, 322])
The input batch size is torch.Size([2, 64, 322, 322])
The output attention batch size is torch.Size([2, 64, 322, 322])
The input batch size is torch.Size([2, 128, 161, 161])
The output attention batch size is torch.Size([2, 128, 161, 161])
torch.Size([2, 128, 161, 161])
The input batch size is torch.Size([2, 256, 80, 80])
The output attention batch size is torch.Size([2, 256, 80, 80])
torch.Size([2, 256, 80, 80])
torch.Size([2, 512, 40, 40])




the ouput size is torch.Size([2, 2, 322, 322])
torch.Size([2, 2, 322, 322])
tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.2304, 0.2376, 0.2439,  ..., 0.2433, 0.2395, 0.2341],
          [0.2303, 0.2441, 0.2528,  ..., 0.2479, 0.2440, 0.2394],
          [0.2312, 0.2402, 0.2471,  ..., 0.2515, 0.2471, 0.2350],
          ...,
          [0.2308, 0.2380, 0.2476,  ..., 0.2475, 0.2442, 0.2384],
          [0.2293, 0.2320, 0.2371,  ..., 0.2419, 0.2415, 0.2332],
          [0.2266, 0.2235, 0.2261,  ..., 0.2305, 0.2336, 0.2289]]],


        [[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0