In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch import nn
from torchsummary import summary

In [2]:

class ChannelAttentionModule(nn.Module):
    def __init__(self, F, r):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AvgPool2d(1)  # 1 for Global Average Pooling. Can Use AdaptiveAvgPool2d for variable input sizes
        self.max_pool = nn.MaxPool2d(1)  # 1 for Gloabal Max Pooling. Can Use AdaptiveMaxPool2d for variable input sizes
        self.fc = nn.Sequential(
            nn.Conv2d(F, F // r, kernel_size=1),  # 1x1 convolution to reduce the number of channels
            nn.ReLU(),  # ReLU as Activation Function
            nn.Conv2d(F // r, F, kernel_size=1)  # 1x1 convolution to restore the original number of channels
        )
        self.sigmoid = nn.Sigmoid()  # Sigmoid Function

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        sum_out = avg_out + max_out #Summation
        return self.sigmoid(sum_out)


In [3]:
# Printing ChannelAttentionModule architecture
model_CAM = ChannelAttentionModule(F=3, r=3)
summary(model_CAM, (3, 1, 1))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         AvgPool2d-1              [-1, 3, 1, 1]               0
            Conv2d-2              [-1, 1, 1, 1]               4
              ReLU-3              [-1, 1, 1, 1]               0
            Conv2d-4              [-1, 3, 1, 1]               6
         MaxPool2d-5              [-1, 3, 1, 1]               0
            Conv2d-6              [-1, 1, 1, 1]               4
              ReLU-7              [-1, 1, 1, 1]               0
            Conv2d-8              [-1, 3, 1, 1]               6
           Sigmoid-9              [-1, 3, 1, 1]               0
Total params: 20
Trainable params: 20
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
-----------------------------------------------------

In [4]:
class SpatialAttentionModule(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttentionModule, self).__init__()
        self.fc = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=kernel_size//2)  #Symmetric zero-padding
        )
        self.sigmoid = nn.Sigmoid() #Sigmoid Function

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True) #Average across Channel Axis. keepdim retains channel dimensions
        max_out, _ = torch.max(x, dim=1, keepdim=True) #Maximum across Channel Axis. keepdim retains channel dimensions
        combined_features = torch.cat([avg_out, max_out], dim=1)  #Concatenation
        f_out = self.fc(combined_features)
        return self.sigmoid(f_out)

In [5]:
# Printing SpatialAttentionModule architecture
model_SAM = SpatialAttentionModule(kernel_size=7)
summary(model_SAM, (3, 1, 1)) 

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1              [-1, 1, 1, 1]              99
           Sigmoid-2              [-1, 1, 1, 1]               0
Total params: 99
Trainable params: 99
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.00
Estimated Total Size (MB): 0.00
----------------------------------------------------------------


In [6]:
class CBAMModule(nn.Module):
    def __init__(self, F, r, spatial_r=7):
        super(CBAMModule, self).__init__()
        self.channel_attention = ChannelAttentionModule(F, r) #Channel Attention
        self.spatial_attention = SpatialAttentionModule(spatial_r) #Spatial Attention
    
    def forward(self, x):
        return self.spatial_attention(self.channel_attention(x)) #Sequential Arrangement

In [7]:
model_CBAM = CBAMModule(F=3, r=1, spatial_r=7)
summary(model_CBAM, (3, 1, 1))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         AvgPool2d-1              [-1, 3, 1, 1]               0
            Conv2d-2              [-1, 3, 1, 1]              12
              ReLU-3              [-1, 3, 1, 1]               0
            Conv2d-4              [-1, 3, 1, 1]              12
         MaxPool2d-5              [-1, 3, 1, 1]               0
            Conv2d-6              [-1, 3, 1, 1]              12
              ReLU-7              [-1, 3, 1, 1]               0
            Conv2d-8              [-1, 3, 1, 1]              12
           Sigmoid-9              [-1, 3, 1, 1]               0
ChannelAttentionModule-10              [-1, 3, 1, 1]               0
           Conv2d-11              [-1, 1, 1, 1]              99
          Sigmoid-12              [-1, 1, 1, 1]               0
SpatialAttentionModule-13              [-1, 1, 1, 1]               0
Total params: 147
Trainable p