# Attention Mechanisms in Convolutional Neural Networks

Attention mechanisms in convolutional neural networks enable the model to adaptively
focus on the most relevant features of the input signal, either at the channel level or
at the spatial level. These modules learn to recalibrate intermediate activations by
assigning differentiated importance weights, which increases the representational
capacity of the model without drastically increasing the number of parameters or the
computational cost.

In modern architectures, attention is integrated as a modular component into existing
convolutional blocks, such as the residual blocks of ResNet. The following sections
describe and implement two of the most influential attention mechanisms for convolutional
networks: the Squeeze-and-Excitation (SE) block and the Convolutional Block Attention
Module (CBAM).

## Squeeze-and-Excitation (SE) Block

The Squeeze-and-Excitation block, introduced in the work
[Squeeze-and-Excitation Networks](https://arxiv.org/abs/1709.01507), incorporates a
channel-wise attention mechanism. The central idea is to explicitly model the
interdependencies between feature channels so that the network learns to emphasize the
channels that are most informative for the task while suppressing less relevant or
redundant ones.

The SE mechanism consists of two conceptual stages, commonly referred to as _squeeze_
and _excitation_. In the squeeze phase, global average pooling collapses the spatial
dimensions of each feature map, so that each channel is compressed into a single scalar
that summarizes its activation across the entire image. In the excitation phase, these
aggregated values are fed into a small fully connected network that learns a channel-wise
attention function. The output of this network is a vector of weights in the interval
$(0, 1)$, which is applied multiplicatively to the original channels, recalibrating their
relative importance.

Let $X \in \mathbb{R}^{B \times C \times H \times W}$ be a feature tensor with batch size
$B$, number of channels $C$, and spatial dimensions $H \times W$. The squeeze operation
computes, for each channel $c$,

$$z_c = \frac{1}{HW} \sum_{i=1}^{H} \sum_{j=1}^{W} X_c(i, j)$$
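The squeeze formula can be checked numerically. The sketch below computes $z_c$ directly as a sum over the spatial axes divided by $HW$ and compares it with `nn.AdaptiveAvgPool2d(1)`, the layer that the implementation below uses for this step. The tensor sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary example tensor X with shape (B, C, H, W)
B, C, H, W = 2, 8, 5, 7
X = torch.randn(B, C, H, W)

# Squeeze by the formula: z_c = (1 / HW) * sum_i sum_j X_c(i, j)
z_formula = X.sum(dim=(2, 3)) / (H * W)  # shape (B, C)

# Squeeze via global average pooling, flattened from (B, C, 1, 1) to (B, C)
z_pool = nn.AdaptiveAvgPool2d(1)(X).flatten(1)

print(z_formula.shape)                                # torch.Size([2, 8])
print(torch.allclose(z_formula, z_pool, atol=1e-6))  # True
```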

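The compressed vector $z \in \mathbb{R}^{C}$ is processed by a two-layer fully connected
network whose hidden layer has $C/r$ units, where $r$ is the reduction ratio. With weights
$W_1 \in \mathbb{R}^{(C/r) \times C}$ and $W_2 \in \mathbb{R}^{C \times (C/r)}$, a ReLU
$\delta$ between the two layers, and a final sigmoid $\sigma$, this network produces the
weight vector

$$s = \sigma\big(W_2\, \delta(W_1 z)\big) \in (0, 1)^{C}$$

The recalibration is implemented as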

$$\tilde{X}_c(i, j) = s_c \cdot X_c(i, j)$$
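To make the shapes involved in excitation and recalibration visible, the sketch below uses explicit, randomly initialized weight matrices instead of `nn.Linear` layers: $W_1$ maps the $C$ squeezed values to $C/r$ hidden units, $W_2$ maps them back to $C$, and reshaping to `(B, C, 1, 1)` broadcasts each $s_c$ over its $H \times W$ map. All sizes and the weight scaling are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

B, C, r = 2, 16, 4          # batch size, channels, reduction ratio (illustrative)
C_r = max(C // r, 1)        # bottleneck width, computed as in the SE implementation

z = torch.randn(B, C)               # squeezed descriptors, shape (B, C)
W1 = torch.randn(C_r, C) * 0.1      # reduction weights, shape (C/r, C)
W2 = torch.randn(C, C_r) * 0.1      # expansion weights, shape (C, C/r)

# Excitation: s = sigmoid(W2 relu(W1 z)), evaluated for every sample in the batch
s = torch.sigmoid(torch.relu(z @ W1.T) @ W2.T)  # shape (B, C), values in (0, 1)

# Recalibration: each s_c is broadcast over the spatial positions of channel c
X = torch.randn(B, C, 6, 6)
X_tilde = X * s.view(B, C, 1, 1)

print(s.min().item() > 0, s.max().item() < 1)  # True True
print(X_tilde.shape)                           # torch.Size([2, 16, 6, 6])
```

Because every weight lies strictly between 0 and 1, recalibration can only attenuate a channel relative to its original magnitude, never amplify it.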

The following code shows an implementation of the SE block and its integration into a
basic residual block in PyTorch, where the SE module recalibrates the output of the second
convolution before the shortcut is added. The cell is self-contained and can be executed
directly.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class SqueezeExcitation(nn.Module):
    def __init__(self, in_channels: int, reduction_ratio: int = 16) -> None:
        super().__init__()
        reduced_channels = max(in_channels // reduction_ratio, 1)
        # Squeeze: Global average pooling per channel
        self.squeeze = nn.AdaptiveAvgPool2d(1)
        # Excitation: Two fully connected (implemented as Linear) layers
        self.excitation = nn.Sequential(
            nn.Linear(in_channels, reduced_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_channels, in_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, channels, _, _ = x.size()
        # Squeeze: Global average pooling per channel
        squeezed = self.squeeze(x).view(batch_size, channels)
        # Excitation: Channel-wise weights in (0, 1)
        excited = self.excitation(squeezed).view(batch_size, channels, 1, 1)
        # Channel-wise recalibration
        return x * excited


class SEResidualBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        reduction_ratio: int = 16,
    ) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.se = SqueezeExcitation(out_channels, reduction_ratio)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.se(out)
        out += identity
        out = F.relu(out)
        return out

A small functional test verifies that the SE block is built correctly: it checks that the
input and output shapes match and reports the number of parameters of the SE module.

In [None]:
def test_se_block() -> None:
    x = torch.randn(2, 64, 32, 32)
    se_block = SqueezeExcitation(in_channels=64, reduction_ratio=16)
    output = se_block(x)
    print(f"Input shape:  {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"SE parameters: {sum(p.numel() for p in se_block.parameters())}")
    assert x.shape == output.shape, "Shape mismatch"
    print("SE Block test passed")


test_se_block()

The SE block adds only a small number of parameters. Because both linear layers are
bias-free, the module contributes $2C\lfloor C/r \rfloor$ weights, where $r$ is the
hyperparameter `reduction_ratio`; for $C = 64$ and $r = 16$ this amounts to 512
parameters. The reduction ratio determines the bottleneck width of the excitation
network: larger values reduce the capacity of the module but also decrease its
computational cost. In practice, `reduction_ratio = 16`, the setting adopted in the
original SE work, usually provides a good balance between modeling capacity and
efficiency.
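The effect of `reduction_ratio` on model size can be tabulated directly. The sketch below compares the parameter count of an excitation MLP with the same layer shapes as `SqueezeExcitation` against the closed-form count $2C\lfloor C/r \rfloor$ for bias-free layers. The helper names `se_param_count` and `se_layers` are hypothetical, introduced only for this example; the pooling layer is left out because it has no parameters.

```python
import torch.nn as nn


def se_param_count(channels: int, reduction_ratio: int) -> int:
    """Hypothetical helper: closed-form parameter count of a bias-free SE block."""
    reduced = max(channels // reduction_ratio, 1)
    return 2 * channels * reduced


def se_layers(channels: int, reduction_ratio: int) -> nn.Sequential:
    """Hypothetical helper: excitation MLP with the same shapes as SqueezeExcitation."""
    reduced = max(channels // reduction_ratio, 1)
    return nn.Sequential(
        nn.Linear(channels, reduced, bias=False),
        nn.ReLU(inplace=True),
        nn.Linear(reduced, channels, bias=False),
        nn.Sigmoid(),
    )


for channels in (64, 256):
    for r in (4, 8, 16, 32):
        counted = sum(p.numel() for p in se_layers(channels, r).parameters())
        assert counted == se_param_count(channels, r)
        print(f"C={channels:3d}, r={r:2d}: {counted} parameters")
```

For $C = 64$ and $r = 16$ the count is 512, the same figure that `test_se_block` reports; halving $r$ doubles it.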

## Convolutional Block Attention Module (CBAM)

The Convolutional Block Attention Module (CBAM) extends the SE idea by sequentially
incorporating attention both in the channel domain and in the spatial domain. First, it
applies a channel attention module conceptually similar to SE, but combining information
from global average pooling and global max pooling. Subsequently, it applies a spatial
attention module that analyzes the distribution of activations across channels to
determine which regions of the image are most relevant.

The channel attention module in CBAM is built from two parallel paths. One path takes the
output of global average pooling and the other the output of global max pooling, both
computed over the spatial dimensions of each channel. Both summaries are processed by the
same small network: a two-layer MLP with a bottleneck of width $C/r$, as in SE,
implemented here with $1 \times 1$ convolutions, which act as fully connected layers on
the $1 \times 1$ pooled maps. The two resulting outputs are combined by element-wise
addition and passed through a sigmoid to obtain a channel attention map
$M_c \in (0, 1)^{C}$ that modulates the contribution of each channel:

$$M_c(X) = \sigma\big(\mathrm{MLP}(\mathrm{AvgPool}(X)) + \mathrm{MLP}(\mathrm{MaxPool}(X))\big)$$
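The claim that a $1 \times 1$ convolution behaves as a fully connected layer on the pooled descriptors can be verified numerically. The sketch below builds a shared MLP from `nn.Conv2d` layers, copies its weights into an equivalent `nn.Linear` MLP, compares the two on the average-pooled descriptor, and then forms the channel attention map from both descriptors. The sizes ($C = 32$, bottleneck width 2, which corresponds to a reduction ratio of 16) are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

C, C_r = 32, 2  # channels and bottleneck width (illustrative, reduction ratio 16)
x = torch.randn(2, C, 8, 8)

# Two spatial summaries per channel, each of shape (B, C, 1, 1)
avg_desc = nn.AdaptiveAvgPool2d(1)(x)
max_desc = nn.AdaptiveMaxPool2d(1)(x)

# Shared MLP built from 1x1 convolutions, as in ChannelAttention
conv_mlp = nn.Sequential(
    nn.Conv2d(C, C_r, kernel_size=1, bias=False),
    nn.ReLU(),
    nn.Conv2d(C_r, C, kernel_size=1, bias=False),
)

# The same MLP written with nn.Linear, reusing the convolution weights
lin1 = nn.Linear(C, C_r, bias=False)
lin2 = nn.Linear(C_r, C, bias=False)
with torch.no_grad():
    lin1.weight.copy_(conv_mlp[0].weight.view(C_r, C))
    lin2.weight.copy_(conv_mlp[2].weight.view(C, C_r))
linear_mlp = nn.Sequential(lin1, nn.ReLU(), lin2)

# On a 1x1 map, a 1x1 convolution computes exactly what a fully connected layer does
same = torch.allclose(
    conv_mlp(avg_desc).flatten(1), linear_mlp(avg_desc.flatten(1)), atol=1e-5
)
print(same)  # True

# Channel attention map: shared MLP on both descriptors, summed, then sigmoid
attention = torch.sigmoid(conv_mlp(avg_desc) + conv_mlp(max_desc))
print(attention.shape)  # torch.Size([2, 32, 1, 1])
```

Sharing one MLP between the two paths keeps the parameter count identical to that of SE with the same reduction ratio; max pooling only adds a second input, not extra weights.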

The spatial attention module is applied to the feature maps after channel recalibration,
denoted here by $X'$. Two single-channel spatial maps are computed by taking the mean and
the maximum over the channel dimension. These two maps are concatenated along the channel
axis and processed by a $k \times k$ convolution $f^{k \times k}$, typically with $k = 7$,
followed by a sigmoid activation:

$$M_s(X') = \sigma\big(f^{k \times k}([\mathrm{Mean}_c(X'); \mathrm{Max}_c(X')])\big)$$

The result is a spatial attention map $M_s \in (0, 1)^{H \times W}$ that multiplies every
channel of $X'$, modulating the importance of each spatial position $(i, j)$ in the image.
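The spatial attention computation can be traced step by step with plain tensor operations. The sketch below aggregates a random feature tensor over its channel axis, applies a $7 \times 7$ convolution with padding $(k - 1)/2$ so that $H \times W$ is preserved, and broadcasts the resulting single-channel map over all channels. The input sizes are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

k = 7                           # kernel size suggested in the text
x = torch.randn(2, 32, 16, 16)  # stands in for the channel-recalibrated features

# Aggregate over the channel axis: two single-channel maps of shape (B, 1, H, W)
avg_map = x.mean(dim=1, keepdim=True)
max_map = x.amax(dim=1, keepdim=True)
stacked = torch.cat([avg_map, max_map], dim=1)  # shape (B, 2, H, W)

# k x k convolution with "same" padding maps 2 channels to 1 without changing H, W
conv = nn.Conv2d(2, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
spatial_attention = torch.sigmoid(conv(stacked))  # shape (B, 1, H, W), values in (0, 1)

# Broadcasting applies the same weight at position (i, j) to every channel
out = x * spatial_attention

print(spatial_attention.shape)                    # torch.Size([2, 1, 16, 16])
print(sum(p.numel() for p in conv.parameters()))  # 98 = 2 * 7 * 7
```

The spatial branch therefore costs only $2k^2$ weights, independently of the number of channels.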

The following code presents the implementation of CBAM (channel and spatial attention)
and its integration into a residual block.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelAttention(nn.Module):
    def __init__(self, in_channels: int, reduction_ratio: int = 16) -> None:
        super().__init__()
        reduced_channels = max(in_channels // reduction_ratio, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared MLP implemented with 1x1 convolutions
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, reduced_channels, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(reduced_channels, in_channels, kernel_size=1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        attention = self.sigmoid(avg_out + max_out)
        return x * attention


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size: int = 7) -> None:
        super().__init__()
        padding = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            2, 1, kernel_size=kernel_size, padding=padding, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Channel-wise average and max projections
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        combined = torch.cat([avg_out, max_out], dim=1)
        attention = self.sigmoid(self.conv(combined))
        return x * attention


class CBAM(nn.Module):
    def __init__(
        self, in_channels: int, reduction_ratio: int = 16, kernel_size: int = 7
    ) -> None:
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.channel_attention(x)
        x = self.spatial_attention(x)
        return x


class CBAMResidualBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.cbam = CBAM(out_channels)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = self.shortcut(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.cbam(out)
        out += identity
        out = F.relu(out)
        return out

The following cell performs a basic check of the CBAM module, analogous to the SE block
test: it validates that the input and output have the same shape and reports the number
of parameters of the module.

In [None]:
import torch


def test_cbam() -> None:
    x = torch.randn(2, 64, 32, 32)  # Batch of 2, 64 channels, 32x32 feature map
    cbam = CBAM(in_channels=64)

    output = cbam(x)

    print(f"Input shape:  {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"CBAM parameters: {sum(p.numel() for p in cbam.parameters())}")

    assert x.shape == output.shape, "Shape mismatch"
    print("CBAM test passed")


test_cbam()

In practice, CBAM often outperforms SE because it combines channel and spatial attention
in a complementary way, although the size of the gain depends on the architecture and the
task. Spatial attention is particularly useful in tasks where localizing objects or
discriminative regions is critical, such as object detection, semantic and instance
segmentation, or recognition in images with multiple instances.
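To put the additional cost of CBAM in perspective, the sketch below counts the parameters of each attention module for the configuration used in the tests above ($C = 64$, $r = 16$, $k = 7$), next to a single $3 \times 3$ convolution of the residual block. To stay self-contained it rebuilds only the parameterized layers with matching shapes instead of importing the classes; ReLU, sigmoid, and pooling layers are omitted because they have no parameters.

```python
import torch.nn as nn

C, r, k = 64, 16, 7  # configuration used in the tests above
C_r = max(C // r, 1)


def count(module: nn.Module) -> int:
    """Total number of parameters in a module."""
    return sum(p.numel() for p in module.parameters())


# SE excitation MLP (bias-free linear layers)
se = nn.Sequential(nn.Linear(C, C_r, bias=False), nn.Linear(C_r, C, bias=False))

# CBAM: shared 1x1-conv MLP for channel attention plus a k x k conv for spatial attention
cbam_channel = nn.Sequential(
    nn.Conv2d(C, C_r, kernel_size=1, bias=False),
    nn.Conv2d(C_r, C, kernel_size=1, bias=False),
)
cbam_spatial = nn.Conv2d(2, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)

# Reference: one 3x3 convolution of the residual block (64 -> 64, no bias)
conv3x3 = nn.Conv2d(C, C, kernel_size=3, padding=1, bias=False)

se_params = count(se)
cbam_params = count(cbam_channel) + count(cbam_spatial)
print(f"SE:       {se_params}")        # 512
print(f"CBAM:     {cbam_params}")      # 610
print(f"3x3 conv: {count(conv3x3)}")   # 36864
```

Under these assumptions the spatial branch adds only 98 weights on top of the shared channel MLP, and both attention modules stay far below the 36,864 weights of one $3 \times 3$ convolution. The extra runtime cost of CBAM comes mainly from evaluating its $k \times k$ convolution at every spatial position.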