In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# CustomDataset
class CustomDataset(Dataset):
    def __init__(self, num_samples, num_channels, height, width):
        self.num_samples = num_samples
        self.data = torch.randn(num_samples, num_channels, height, width)
        self.labels = torch.randint(0, 2, (num_samples,), dtype=torch.float32)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

batch_size = 64
dataset = CustomDataset(num_samples=1000, num_channels=64, height=64, width=64)

# Data Loader
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

class ChannelAttentionModule(nn.Module):  
    def __init__(self, F, r=16):  
        super(ChannelAttentionModule, self).__init__()  
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
           
        self.fc = nn.Sequential(nn.Conv2d(F, F // r, 1, bias=False),  
                               nn.ReLU(),
                               nn.Conv2d(F // r, F, 1, bias=False)) 
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)
    
class SpatialAttentionModule(nn.Module):  
    def __init__(self, spatial_r=7):
        super(SpatialAttentionModule, self).__init__() 
        self.conv1 = nn.Conv2d(2, 1, kernel_size=spatial_r, padding=spatial_r//2, bias=False)  
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)

class CBAMModule(nn.Module):
    def __init__(self, F, r, spatial_r=7):  
        super(CBAMModule, self).__init__()
        self.channel_attention = ChannelAttentionModule(F, r) 
        self.spatial_attention = SpatialAttentionModule(spatial_r)  
    
    def forward(self, x):
        return self.spatial_attention(self.channel_attention(x))

model = CBAMModule(F=64, r=16)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001)


num_epochs = 100
weight_decay = 0.001

for epoch in range(num_epochs):
    model.train()
    for batch_data, batch_labels in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_data)
        loss = criterion(outputs.squeeze(), batch_labels)

        # L2 Regularisation
        l2_reg = torch.tensor(0.0)
        for param in model.parameters():
            l2_reg += torch.norm(param, p=2)

        # Adding L2 regularisation term to the loss
        loss += weight_decay * l2_reg

        loss.backward()
        optimizer.step()

    if epoch % 10 == 0:
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}')


Epoch [1/100], Loss: 0.7570207118988037
Epoch [11/100], Loss: 0.7676351070404053
Epoch [21/100], Loss: 0.7045150399208069
Epoch [31/100], Loss: 0.6798362731933594
Epoch [41/100], Loss: 0.7505645155906677
Epoch [51/100], Loss: 0.7027023434638977
Epoch [61/100], Loss: 0.7361545562744141
Epoch [71/100], Loss: 0.6793601512908936
Epoch [81/100], Loss: 0.6683942079544067
Epoch [91/100], Loss: 0.7636892199516296


In [13]:
import torchinfo
from torchinfo import summary
summary(model, input_size=[1, 64, 64, 64])

Layer (type:depth-idx)                   Output Shape              Param #
CBAMModule                               [1, 1, 1, 1]              --
├─ChannelAttentionModule: 1-1            [1, 64, 1, 1]             --
│    └─AdaptiveAvgPool2d: 2-1            [1, 64, 1, 1]             --
│    └─Sequential: 2-2                   [1, 64, 1, 1]             --
│    │    └─Conv2d: 3-1                  [1, 4, 1, 1]              256
│    │    └─ReLU: 3-2                    [1, 4, 1, 1]              --
│    │    └─Conv2d: 3-3                  [1, 64, 1, 1]             256
│    └─AdaptiveMaxPool2d: 2-3            [1, 64, 1, 1]             --
│    └─Sequential: 2-4                   [1, 64, 1, 1]             (recursive)
│    │    └─Conv2d: 3-4                  [1, 4, 1, 1]              (recursive)
│    │    └─ReLU: 3-5                    [1, 4, 1, 1]              --
│    │    └─Conv2d: 3-6                  [1, 64, 1, 1]             (recursive)
│    └─Sigmoid: 2-5                      [1, 64, 1, 1]  