In [1]:
import math
import numpy as np

import torch
import torch.nn as nn
from torch.nn import functional as F

#  CBAM

In [2]:
class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


In [3]:
class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)

In [4]:
class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
            )
        self.pool_types = pool_types
    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type=='avg':
                avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( avg_pool )
            elif pool_type=='max':
                max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp( max_pool )

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
                
        scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)

        return scale # Channel attention map

In [5]:
class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )

In [6]:
class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        conv_spatial = []
        self.spatial = BasicConv(2,1,kernel_size,stride=1,padding=(kernel_size-1) // 2,relu=False, bn=False)
            
    def forward(self, x):
        """
        x1: Input vector
        """
        x_compress = self.compress(x) # input
        x_out = self.spatial(x_compress)
            
        scale = torch.sigmoid(x_out) # broadcasting
        return scale


In [7]:
class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.SpatialGate = SpatialGate()
    def forward(self,x1,x2):
        """
        x1: Query vector
        x2: Value vector
        """
        c_map = self.ChannelGate(x1) # Channel attention map
        s_map = self.SpatialGate(x1) # Spatial attention map
        x_out = x2*c_map
        x_out = x_out*s_map
        return x_out

# Self Attention module

In [8]:
class PAM_Module(nn.Module):
    #Ref from SAGAN
    def __init__(self, in_dim,norm_layer=nn.BatchNorm2d):
        """
        Position Attention Module
        Arguments
        in_dim: input channel dimension
        ratio: ratio for level channels, used for reducing number of channels and parameters
        rates: dilation rates
        
        Returns:
        output: Output with attention module with parameter gamma for tuning
        """
        super(PAM_Module, self).__init__()
        self.channel_in = in_dim
        self.query_conv = nn.Conv2d(self.channel_in,self.channel_in//8,kernel_size=1)
        self.key_conv = nn.Conv2d(self.channel_in, self.channel_in//8, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x1, x2):
        """
        Parameters:
        ----------
            inputs :
                x1 : input feature maps( B X C X H X W) ([SLF,MLF])
                x2 : multi feature maps(B X C X H X W) (MLF)
            returns :
                out : scaled attention map
        """
        m_batchsize, C, height, width = x1.size()
        proj_query = self.query_conv(x1).view(m_batchsize, -1, width*height).permute(0, 2, 1)
        proj_key = self.key_conv(x1).view(m_batchsize, -1, width*height)
        energy = torch.bmm(proj_query, proj_key)
        attention = self.softmax(energy)
        

        out = torch.bmm(x2.view(m_batchsize,-1,width*height), attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, height, width)

        out = self.gamma * out + x2
        return out


In [9]:
class CAM_Module(nn.Module):
    """ Channel attention module"""
    def __init__(self, in_dim, norm_layer=nn.BatchNorm2d):
        super(CAM_Module, self).__init__()
        self.channel_in = in_dim
        
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax  = nn.Softmax(dim=-1)
    def forward(self,x1, x2):
        """
        Parameters:
        ----------
            inputs :
                x1 : input feature maps( B X C X H X W) ([SLF,MLF])
                x2 : multi feature maps(B X C X H X W) (MLF)
            returns :
                out : attention value + input feature
                attention: B X C X C
        """
        m_batchsize, C, height, width = x1.size()
           
        # Flatten 
        proj_query = x1.view(m_batchsize,C,-1)
        proj_key = x1.view(m_batchsize,C,-1).permute(0,2,1)
        
        energy = torch.bmm(proj_query, proj_key)
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
        attention = self.softmax(energy_new)
        
        proj_value = x2.view(m_batchsize, C, -1)
        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)

#         out = self.gamma * out + x2
        return out


In [10]:
class RefineConv(nn.Module):
    """
    Helper function for Multiple Convolutions for refining.
    
    Parameters:
    ----------
    inputs:
        in_ch : input channels
        out_ch : output channels
        attn : Boolean value whether to use Softmax or PReLU
    outputs:
        returns the refined convolution tensor
    """
    def __init__(self, in_ch, out_ch,norm_layer=nn.BatchNorm2d):
        super(RefineConv, self).__init__()
        
        self.refine = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1),
            norm_layer(out_ch), 
            nn.PReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), 
            norm_layer(out_ch), 
            nn.PReLU(),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), 
            norm_layer(out_ch), 
            nn.PReLU()
        )
    
    def forward(self, x):
        return self.refine(x)

In [11]:
class DenseASPP(nn.Module):
    def __init__(self,in_channels, level_channels, rates,norm_layer=nn.BatchNorm2d):
        """
        in_channels: input channels of input
        level_channels: level channels
        rates: dilation rates
        """
        super(DenseASPP,self).__init__()
        self.rates = rates
        self.in_channels = in_channels
        self.level_channels = level_channels
        convs = []
        bns = []
        prelus = []
        c0 = in_channels // 2
        
        # Downsize of input
        self.down0 = nn.Sequential(
            nn.Conv2d(in_channels,c0,1),
            norm_layer(c0),
            nn.PReLU()
        )
        for i in range(0,len(self.rates)):
            temp_in_channels = c0 + i*level_channels
            convs.append(nn.Conv2d(temp_in_channels,level_channels,3,dilation=self.rates[i],padding=self.rates[i]))
            bns.append(norm_layer(level_channels))
            prelus.append(nn.PReLU())
        
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.prelus = nn.ModuleList(prelus)
        
    def forward(self,x):
        # Transform to lower dimension
        c0 = self.down0(x)
        final_output = x
        c_i = c0
        for i in range(len(self.rates)):
            temp_out = self.convs[i](c_i)
            temp_out = self.bns[i](temp_out)
            temp_out = self.prelus[i](temp_out)
            
            c_i = torch.cat((c_i,temp_out),dim=1)
            final_output = torch.cat((final_output,temp_out),dim=1)
            
        return final_output
    

In [12]:
if __name__ == "__main__":   
    x = torch.randn((1,64,256,256))
    y = torch.randn((1,64,256,256))
    test = out(x,y)
    total_params = sum(p.numel() for p in out.parameters())
    print(total_params)

PAM_Module(
  (query_conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
  (key_conv): Conv2d(64, 8, kernel_size=(1, 1), stride=(1, 1))
  (softmax): Softmax(dim=-1)
)


KeyboardInterrupt: 