In [2]:
import numpy as np
import torch 
from torch import nn
from torch.nn import init

In [3]:
class ChannelAttention(nn.Module):
    """Channel-attention gate built from pooled descriptors and a shared bottleneck.

    The input is reduced to two per-channel descriptors (adaptive max pooling
    and adaptive average pooling, both to a 1x1 spatial size). Each descriptor
    goes through the same two-layer 1x1-convolution bottleneck, the two results
    are summed, and a sigmoid is applied to the sum.

    Args:
        channel: Number of channels of the incoming feature map.
        reduction: Divisor used to size the bottleneck. The hidden width is
            ``channel // reduction``, clamped to a minimum of 1.
    """

    def __init__(self, channel, reduction=16) -> None:
        super().__init__()
        self.maxpool = nn.AdaptiveMaxPool2d(1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # channel // reduction is 0 whenever channel < reduction, which would
        # create a zero-width bottleneck that cannot carry any signal. Keep at
        # least one hidden channel; larger channel counts are unaffected.
        hidden = max(1, channel // reduction)
        self.se = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channel, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Compute the channel gate for ``x``.

        Args:
            x: Feature map with ``channel`` channels; it is pooled over its
                last two dimensions.

        Returns:
            Sigmoid-activated tensor with ``channel`` channels and a 1x1
            spatial size.
        """
        # Both pooled descriptors share the same bottleneck weights.
        max_out = self.se(self.maxpool(x))
        avg_out = self.se(self.avgpool(x))
        return self.sigmoid(max_out + avg_out)

In [4]:
class SpatialAttention(nn.Module):
    """Spatial-attention gate computed from channel-wise max and mean maps.

    Args:
        kernel_size: Side length (int) of the square convolution kernel applied
            to the stacked max/mean maps. Both odd and even sizes produce a
            gate with the same height and width as the input.
    """

    def __init__(self, kernel_size=7) -> None:
        super().__init__()
        # Symmetric padding of kernel_size // 2 preserves H and W only for odd
        # kernels; with an even kernel each spatial dim would grow by one and
        # the gate could no longer be multiplied with the feature map. "same"
        # padding covers the even case (it requires stride 1, which is the
        # default used here). Odd kernels keep the original integer padding.
        padding = kernel_size // 2 if kernel_size % 2 else "same"
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Compute the spatial gate for ``x``.

        Args:
            x: Feature map; dimension 1 is reduced by max and by mean.

        Returns:
            Sigmoid-activated single-channel map with the spatial size of ``x``.
        """
        # torch.max over a dim returns (values, indices); only values are used.
        channel_max, _ = torch.max(x, dim=1, keepdim=True)
        channel_avg = torch.mean(x, dim=1, keepdim=True)
        stacked = torch.cat([channel_max, channel_avg], 1)
        return self.sigmoid(self.conv(stacked))
        

In [12]:
class CBAMBlock(nn.Module):
    """Channel attention followed by spatial attention, with a residual add.

    Args:
        channel: Channel count forwarded to ``ChannelAttention``.
        reduction: Bottleneck divisor forwarded to ``ChannelAttention``.
        kernel_size: Convolution kernel size forwarded to ``SpatialAttention``.
    """

    def __init__(self, channel=512, reduction=16, kernel_size=49) -> None:
        super().__init__()
        self.ca = ChannelAttention(channel=channel, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=kernel_size)

    def forward(self, x):
        """Apply both attention gates to ``x`` and add the input back.

        Args:
            x: 4-D feature map.

        Returns:
            ``x * ca(x) * sa(x * ca(x)) + x``, same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        # Explicit rank check: the spatial gate reduces over dim 1, which is
        # only the channel axis for batched 4-D input.
        if x.dim() != 4:
            raise ValueError(
                f"CBAMBlock expects a 4-D input, got {x.dim()}-D"
            )
        out = x * self.ca(x)       # re-weight channels
        out = out * self.sa(out)   # re-weight spatial positions
        return out + x             # residual connection

In [13]:
# Smoke test: run a random feature map through CBAMBlock and check the shape.
# The tensor is not called `input`, since that would shadow the Python builtin
# for the rest of the kernel session.
feature_map = torch.randn(50, 512, 7, 7)
# The spatial-attention kernel spans the full feature-map height (7 here).
kernel_size = feature_map.shape[2]
cbam = CBAMBlock(channel=feature_map.shape[1], reduction=16, kernel_size=kernel_size)
output = cbam(feature_map)
# The block is expected to be shape-preserving.
assert output.shape == feature_map.shape
print(output.shape)

torch.Size([50, 512, 7, 7])


In [11]:
# Shape check for the channel-wise max / mean reductions over dim 1.
x = torch.randn(16, 4, 8, 8)
# Take `.values` instead of unpacking into `_`: at notebook top level `_` is
# IPython's "last output" variable, and assigning to it shadows that.
max_res = torch.max(x, dim=1, keepdim=True).values
avg_result = torch.mean(x, dim=1, keepdim=True)
print(max_res.size())
print(avg_result.size())

torch.Size([16, 1, 8, 8])
torch.Size([16, 1, 8, 8])
