## CBAM: Channel and Spatial Attention Modules

In [1]:
import torch
import torch.nn as nn

In [2]:
class ChannelAttentionModule(nn.Module):
    """CBAM channel-attention branch.

    Squeezes the spatial dimensions with both global max pooling and global
    average pooling, feeds each (N, C, 1, 1) descriptor through a shared
    bottleneck MLP (implemented with 1x1 convolutions), sums the two results
    and applies a sigmoid.

    Args:
        in_channels: number of channels C of the input feature map.
        ratio: reduction factor of the shared MLP's hidden layer.

    Returns (from ``forward``):
        Attention weights of shape (N, C, 1, 1) with values in (0, 1); the
        input itself is not re-weighted here.
    """

    def __init__(self, in_channels, ratio=16):
        super(ChannelAttentionModule, self).__init__()
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        # Was AdaptiveMaxPool2d, which made both branches identical and
        # dropped the average-pooled descriptor entirely.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Keep at least one hidden channel so small inputs still build.
        hidden = max(1, in_channels // ratio)
        # Attribute name kept unchanged so existing state_dict keys still load.
        self.sharedMlP = nn.Sequential(
            # kernel_size was missing here, which made nn.Conv2d raise TypeError.
            nn.Conv2d(in_channels, hidden, kernel_size=1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, in_channels, kernel_size=1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        maxpool_output = self.sharedMlP(self.maxpool(x))
        avgpool_output = self.sharedMlP(self.avgpool(x))
        return self.sigmoid(maxpool_output + avgpool_output)

In [3]:
class SpatialAttentionModule(nn.Module):
    """CBAM spatial-attention branch.

    Reduces the channel axis with a max and a mean, concatenates the two maps
    into a 2-channel tensor, convolves it down to one channel and applies a
    sigmoid.

    Args:
        kernel_size: size of the fusing convolution; must be 3 or 7.

    Returns (from ``forward``):
        Attention weights of shape (N, 1, H, W) with values in (0, 1).
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttentionModule, self).__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        # Half-kernel padding for an odd kernel keeps H and W unchanged.
        padding = kernel_size // 2

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # keepdim=True preserves the channel axis so the two maps concatenate
        # into (N, 2, H, W). Without it they were (N, H, W) and the concat ran
        # along H, producing a tensor the 2-channel conv could not consume.
        max_output, _ = torch.max(x, dim=1, keepdim=True)
        avg_output = torch.mean(x, dim=1, keepdim=True)
        x = torch.cat([max_output, avg_output], dim=1)
        return self.sigmoid(self.conv(x))

## SE attention 

In [5]:
class SEAttentionModule(nn.Module):
    """Squeeze-and-Excitation block.

    Global-average-pools each channel to an (N, C) descriptor, passes it
    through a bottleneck MLP, and rescales the input channels by the sigmoid
    of the result.

    Args:
        in_channels: number of channels C of the input feature map.
        reduction: reduction factor of the MLP's hidden layer.

    Returns (from ``forward``):
        The channel-wise re-weighted input; same shape as ``x``.
    """

    def __init__(self, in_channels, reduction=16):
        super(SEAttentionModule, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Keep at least one hidden unit when in_channels < reduction.
        hidden = max(1, in_channels // reduction)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Previously: x.view() with no arguments (TypeError), and fc/sigmoid
        # were applied to x instead of the pooled descriptor y.
        batch_size, channel = x.shape[:2]
        y = self.avgpool(x).view(batch_size, channel)   # squeeze
        y = self.sigmoid(self.fc(y))                    # excitation
        return x * y.view(batch_size, channel, 1, 1)

### SK attention

In [None]:
class SKConv(nn.Module):
    """Selective-Kernel convolution.

    Runs ``M`` parallel grouped convolutions with kernel sizes 3, 5, 7, ...,
    fuses them by summation and derives one attention vector per branch from
    the spatially averaged fused map. Those vectors are softmax-normalised
    across the branches, and the result is the attention-weighted sum of the
    branch outputs.

    Args:
        features: number of input (and output) channels.
        WH: input spatial size; unused by this implementation (global pooling
            is done with a mean over H and W) but kept in the signature.
        M: number of branches.
        G: number of convolution groups in every branch.
        r: ratio used to compute the length ``d`` of the compact vector z.
        stride: stride of every branch convolution, default 1.
        L: lower bound on ``d``, default 32.
    """

    def __init__(self, features, WH, M, G, r, stride=1, L=32):
        super(SKConv, self).__init__()
        d = max(int(features / r), L)
        self.M = M
        self.features = features

        def make_branch(k):
            # Kernel 3 + 2k with padding 1 + k keeps H and W for stride 1.
            return nn.Sequential(
                nn.Conv2d(features, features, kernel_size=3 + k * 2,
                          stride=stride, padding=1 + k, groups=G),
                nn.BatchNorm2d(features),
                nn.ReLU(inplace=False),
            )

        self.convs = nn.ModuleList([make_branch(k) for k in range(M)])
        self.fc = nn.Linear(features, d)
        self.fcs = nn.ModuleList([nn.Linear(d, features) for _ in range(M)])
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # One slice per branch along dim 1: (N, M, C, H', W').
        stacked = torch.stack([branch(x) for branch in self.convs], dim=1)
        # Fuse by summing over branches, then average over W and H -> (N, C).
        pooled = torch.sum(stacked, dim=1).mean(-1).mean(-1)
        compact = self.fc(pooled)
        # Per-branch channel logits (N, M, C), normalised over the branch axis.
        weights = self.softmax(torch.stack([head(compact) for head in self.fcs], dim=1))
        weights = weights[..., None, None]
        return (stacked * weights).sum(dim=1)

In [7]:
class SKAttentionModule(nn.Module):
    """Two-branch Selective-Kernel attention (3x3 and 5x5 branches).

    Both branch outputs are summed, globally average-pooled to an (N, C)
    descriptor and squeezed to a compact vector z. Two fully connected heads
    turn z into per-branch channel logits, which are softmax-normalised across
    the two branches. The output is the attention-weighted sum of the branches.

    Args:
        in_channels: number of input (and output) channels.
        G: number of convolution groups in each branch.
        L: lower bound on the length ``d`` of z.
        r: reduction ratio, ``d = max(int(in_channels / r), L)``; default 16.

    Returns (from ``forward``):
        Tensor of shape (N, in_channels, H, W).
    """

    def __init__(self, in_channels, G, L=32, r=16):
        super(SKAttentionModule, self).__init__()
        # `r` was previously an undefined name (NameError); it is now a
        # keyword parameter with a default.
        d = max(int(in_channels / r), L)

        self.conv3x3 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=G),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )
        self.conv5x5 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=5, padding=2, groups=G),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )
        # Adaptive pooling handles any input size; the old AvgPool2d(WH)
        # referenced an undefined `WH`.
        self.gp = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(in_channels, d)
        self.fcs_1 = nn.Linear(d, in_channels)
        self.fcs_2 = nn.Linear(d, in_channels)

    def forward(self, x):
        branch1 = self.conv3x3(x)
        branch2 = self.conv5x5(x)
        branches = torch.stack([branch1, branch2], dim=1)        # (N, 2, C, H, W)

        # Flatten the pooled map to (N, C) before the Linear layers.
        f_gp = self.gp(branch1 + branch2).flatten(1)
        f_fc = self.fc(f_gp)
        # Softmax across the branch axis, independently for every channel.
        attention_vectors = torch.stack(
            [self.fcs_1(f_fc), self.fcs_2(f_fc)], dim=1).softmax(dim=1)
        attention_vectors = attention_vectors[..., None, None]  # (N, 2, C, 1, 1)

        return (branches * attention_vectors).sum(dim=1)

## BAM Attention module

In [9]:
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

In [10]:
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) axis into one."""

    def forward(self, x):
        leading = x.size(0)
        flat = x.view(leading, -1)
        return flat

In [11]:
class ChannelGate(nn.Module):
    """BAM channel-attention branch.

    Global-average-pools the input, flattens it to (N, C), and runs it through
    an MLP. The MLP has ``num_layers`` hidden layers of width
    ``gate_channel // reduction_ratio``, each Linear -> BatchNorm1d -> ReLU,
    followed by a final Linear back to ``gate_channel``. The per-channel
    logits are broadcast over H and W.

    Args:
        gate_channel: number of channels C of the input.
        reduction_ratio: width divisor for the hidden layers.
        num_layers: number of hidden layers.

    Returns (from ``forward``):
        Un-normalised channel logits expanded to the input's shape (N, C, H, W).
    """

    def __init__(self, gate_channel, reduction_ratio=16, num_layers=1):
        super(ChannelGate, self).__init__()
        # Removed `self.gate_activation = gate_activation`: the name was never
        # defined, so constructing the module raised NameError.
        self.gate_c = nn.Sequential()
        # nn.Flatten (start_dim=1) matches x.view(x.size(0), -1) for the
        # contiguous pooled tensor and has no parameters, so state_dict
        # keys are unchanged.
        self.gate_c.add_module('flatten', nn.Flatten())
        gate_channels = ([gate_channel]
                         + [gate_channel // reduction_ratio] * num_layers
                         + [gate_channel])
        for i in range(len(gate_channels) - 2):
            self.gate_c.add_module('gate_c_fc_%d' % i, nn.Linear(gate_channels[i], gate_channels[i + 1]))
            self.gate_c.add_module('gate_c_bn_%d' % (i + 1), nn.BatchNorm1d(gate_channels[i + 1]))
            self.gate_c.add_module('gate_c_relu_%d' % (i + 1), nn.ReLU())
        self.gate_c.add_module('gate_c_fc_final', nn.Linear(gate_channels[-2], gate_channels[-1]))

    def forward(self, in_tensor):
        # Adaptive pooling to 1x1 also handles non-square inputs. The previous
        # avg_pool2d(kernel=H, stride=H) only produced a 1x1 map when H == W.
        avg_pool = F.adaptive_avg_pool2d(in_tensor, 1)
        return self.gate_c(avg_pool).unsqueeze(2).unsqueeze(3).expand_as(in_tensor)

In [13]:
class SpatialGate(nn.Module):
    """BAM spatial-attention branch.

    Stages:
    - A 1x1 convolution reduces the channels by ``reduction_ratio``.
    - ``dilation_conv_num`` dilated 3x3 convolutions (each followed by
      BatchNorm2d and ReLU) add context without changing H or W.
    - A final 1x1 convolution projects to a single map, which is broadcast
      over the channel axis.

    Args:
        gate_channel: number of channels C of the input.
        reduction_ratio: channel divisor for the intermediate layers.
        dilation_conv_num: number of dilated 3x3 convolutions.
        dilation_val: dilation (and padding) of those convolutions.

    Returns (from ``forward``):
        Un-normalised spatial logits expanded to the input's shape (N, C, H, W).
    """

    def __init__(self, gate_channel, reduction_ratio=16, dilation_conv_num=2, dilation_val=4):
        super(SpatialGate, self).__init__()
        mid = gate_channel // reduction_ratio
        seq = nn.Sequential()

        # Channel reduction.
        seq.add_module('gate_s_conv_reduce0', nn.Conv2d(gate_channel, mid, kernel_size=1))
        seq.add_module('gate_s_bn_reduce0', nn.BatchNorm2d(mid))
        seq.add_module('gate_s_relu_reduce0', nn.ReLU())

        # Dilated context; padding == dilation keeps H and W for a 3x3 kernel.
        for idx in range(dilation_conv_num):
            seq.add_module('gate_s_conv_di_%d' % idx,
                           nn.Conv2d(mid, mid, kernel_size=3,
                                     padding=dilation_val, dilation=dilation_val))
            seq.add_module('gate_s_bn_di_%d' % idx, nn.BatchNorm2d(mid))
            seq.add_module('gate_s_relu_di_%d' % idx, nn.ReLU())

        # Project down to a single spatial map.
        seq.add_module('gate_s_conv_final', nn.Conv2d(mid, 1, kernel_size=1))
        self.gate_s = seq

    def forward(self, in_tensor):
        spatial_map = self.gate_s(in_tensor)  # (N, 1, H, W)
        return spatial_map.expand_as(in_tensor)

In [14]:
class BAM(nn.Module):
    """Bottleneck Attention Module.

    Combines a channel gate and a spatial gate, both expanded to the input's
    shape, into ``1 + sigmoid(channel * spatial)`` and rescales the input by
    it. The scale factor therefore lies in (1, 2).

    Args:
        gate_channel: number of channels of the input.
    """

    def __init__(self, gate_channel):
        super(BAM, self).__init__()
        self.channel_att = ChannelGate(gate_channel)
        self.spatial_att = SpatialGate(gate_channel)

    def forward(self, in_tensor):
        # torch.sigmoid replaces the deprecated F.sigmoid (same values).
        # NOTE(review): the gates are combined by multiplication here; confirm
        # this is intended (the BAM paper describes a sum).
        att = 1 + torch.sigmoid(self.channel_att(in_tensor) * self.spatial_att(in_tensor))
        return att * in_tensor

In [20]:
class ChannelGate(nn.Module):
    """BAM channel-attention branch (second variant).

    Global-average-pools the input to an (N, C) descriptor and passes it
    through Linear -> BatchNorm1d -> ReLU -> Linear -> BatchNorm1d -> ReLU.
    The per-channel result is then broadcast over H and W.

    Args:
        in_channels: number of channels C of the input.
        r: reduction ratio of the hidden layer (default 16).

    Returns (from ``forward``):
        Channel logits expanded to the input's shape (N, C, H, W).
    """

    def __init__(self, in_channels, r=16):
        # `self` was missing from the signature, so construction raised TypeError.
        super(ChannelGate, self).__init__()
        hidden = max(1, in_channels // r)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden),
            # BatchNorm1d, not BatchNorm2d: Linear outputs are 2-D (N, features).
            nn.BatchNorm1d(hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, in_channels),
            nn.BatchNorm1d(in_channels),
            # NOTE(review): this trailing ReLU makes every channel logit
            # non-negative; confirm it is intended.
            nn.ReLU(inplace=True))

    def forward(self, x):
        batch_size, channels = x.shape[:2]
        # Previously pooled an undefined `output` (NameError) instead of x.
        output = self.avgpool(x).view(batch_size, channels)
        output = self.fc(output)
        # Restore the two spatial axes before broadcasting to x's shape.
        return output.view(batch_size, channels, 1, 1).expand_as(x)

In [19]:
class SpatialGate(nn.Module):
    """BAM spatial-attention branch (second variant).

    Stages:
    - A 1x1 convolution reduces the channels by ``ratio``.
    - Two dilated 3x3 convolutions (dilation 4, padding 4, so H and W are
      preserved), each followed by BatchNorm2d + ReLU, add context.
    - A final 1x1 convolution projects to one map, which is broadcast over
      the channel axis.

    Args:
        in_channels: number of channels C of the input.
        ratio: channel reduction ratio (default 16).

    Returns (from ``forward``):
        Spatial logits expanded to the input's shape (N, C, H, W).
    """

    def __init__(self, in_channels, ratio=16):
        super(SpatialGate, self).__init__()
        mid = max(1, in_channels // ratio)
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(in_channels, mid, kernel_size=1, padding=0),
            nn.BatchNorm2d(mid),
            nn.ReLU(inplace=True))
        self.dilation_conv = nn.Sequential(
            nn.Conv2d(mid, mid, kernel_size=3, padding=4, dilation=4),  # dilated conv
            nn.BatchNorm2d(mid),
            nn.ReLU(),
            nn.Conv2d(mid, mid, kernel_size=3, padding=4, dilation=4),  # dilated conv
            nn.BatchNorm2d(mid),
            nn.ReLU())
        self.conv1x1_final = nn.Sequential(
            nn.Conv2d(mid, 1, kernel_size=1, padding=0))

    # Previous defects in `forward`:
    # - it was nested inside __init__, so the module had no forward at all;
    # - it read `output` before assignment;
    # - it called `conv1x1_final` without `self.`;
    # - it expanded the result to itself instead of to the input.
    def forward(self, x):
        output = self.conv1x1(x)
        output = self.dilation_conv(output)
        output = self.conv1x1_final(output)  # (N, 1, H, W)
        return output.expand_as(x)

In [25]:
class BAM(nn.Module):
    """Bottleneck Attention Module (second variant).

    Rescales the input by ``1 + sigmoid(channel_gate * spatial_gate)``, where
    both gates return tensors expanded to the input's shape. The scale factor
    therefore lies in (1, 2).

    Args:
        in_channels: number of channels of the input.
        reduction_ratio: channel reduction ratio passed to both gates
            (default 16).
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super(BAM, self).__init__()
        # Pass the ratio explicitly so construction does not depend on the
        # gates' own defaults.
        self.channel_attention = ChannelGate(in_channels, reduction_ratio)
        self.spatial_attention = SpatialGate(in_channels, reduction_ratio)

    def forward(self, x):
        # Previously referenced an undefined `in_tensor` and called the
        # nn.Sigmoid class with a tensor (TypeError); use the function instead.
        bam_attention = 1 + torch.sigmoid(
            self.channel_attention(x) * self.spatial_attention(x))
        return bam_attention * x