# CBAM to CNN

Please embed CBAM into a CNN model or another image classification model.

CNN + CBAM as one stage of the pipeline (CNN + CBAM 在流程中的一個過程)
MNIST dataset

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets.mnist import read_image_file, read_label_file

class CustomMNIST(Dataset):
    """MNIST dataset read from raw IDX files and standardized per instance.

    Every image is scaled to ``[0, 1]`` and then standardized with the mean
    and standard deviation computed over all pixels of the image file given
    to this instance (so each split uses its own statistics).

    Args:
        images_path: Path of the IDX image file.
        labels_path: Path of the IDX label file.
        transform: Optional callable applied to each scaled image tensor
            (values in ``[0, 1]``, with a leading channel dimension) before
            standardization. ``None`` disables it.
    """

    def __init__(self, images_path, labels_path, transform=None):
        self.images = read_image_file(images_path)
        self.labels = read_label_file(labels_path)
        self.transform = transform

        # Compute the mean and standard deviation once, up front.
        self.mean, self.std = self.calculate_mean_std()

    def calculate_mean_std(self):
        """Return ``(mean, std)`` of all pixels after scaling to ``[0, 1]``."""
        images_flat = self.images.view(self.images.size(0), -1).float() / 255.0
        mean = images_flat.mean()
        std = images_flat.std()
        return mean.item(), std.item()

    def __len__(self):
        """Return the number of samples (one label per image)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        # Add the channel dimension and scale the raw pixel values to [0, 1].
        image = self.images[idx].unsqueeze(0).float() / 255.0
        # Apply the user transform on the [0, 1] image so that augmentations
        # see the natural pixel range; previously the transform was stored
        # but silently ignored.
        if self.transform is not None:
            image = self.transform(image)
        image = (image - self.mean) / self.std
        label = self.labels[idx]
        return image, label

# Locations of the raw MNIST IDX files (images and labels) for both splits.
train_images_path = './data/train-images-idx3-ubyte/train-images-idx3-ubyte'
train_labels_path = './data/train-labels-idx1-ubyte/train-labels-idx1-ubyte'
test_images_path = './data/t10k-images-idx3-ubyte/t10k-images-idx3-ubyte'
test_labels_path = './data/t10k-labels-idx1-ubyte/t10k-labels-idx1-ubyte'

# One dataset per split; each computes its own normalization statistics.
trainset = CustomMNIST(
    images_path=train_images_path,
    labels_path=train_labels_path,
)
testset = CustomMNIST(
    images_path=test_images_path,
    labels_path=test_labels_path,
)

# Shuffled mini-batches for training, large fixed-order batches for testing.
trainloader = DataLoader(dataset=trainset, batch_size=64, shuffle=True)
testloader = DataLoader(dataset=testset, batch_size=1000, shuffle=False)


In [2]:
class CBAM(nn.Module):
    """Attention block that gates a feature map by channel, then by position.

    Args:
        channels: Number of channels of the incoming feature map.
        reduction: Reduction ratio forwarded to the channel-attention module.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.channel_attention = ChannelAttention(channels, reduction)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        """Return ``x`` re-weighted by channel and then spatial attention."""
        # Step 1: scale the input by its channel-attention weights.
        channel_refined = x * self.channel_attention(x)
        # Step 2: scale the channel-refined map by its spatial weights.
        return channel_refined * self.spatial_attention(channel_refined)

class ChannelAttention(nn.Module):
    """Channel-attention gate: one weight in ``(0, 1)`` per input channel.

    The input is reduced to 1x1 by average pooling and by max pooling, both
    descriptors go through the same two-layer bottleneck of 1x1 convolutions,
    and the sigmoid of their sum is returned.

    Args:
        in_planes: Number of channels of the input feature map.
        ratio: Reduction factor of the bottleneck. The hidden width is
            ``in_planes // ratio``, but never less than 1.
    """

    def __init__(self, in_planes, ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # in_planes // ratio is 0 whenever in_planes < ratio, which would
        # leave the bottleneck without any channel; keep at least one.
        hidden_planes = max(1, in_planes // ratio)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, hidden_planes, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden_planes, in_planes, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the per-channel attention weights for ``x``."""
        # The same bottleneck (shared weights) processes both descriptors.
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    """Spatial-attention gate: one weight in ``(0, 1)`` per spatial position.

    Args:
        kernel_size: Size of the convolution applied to the two-channel
            (mean, max) descriptor. Padding is ``kernel_size // 2``, which
            preserves the spatial size for odd kernel sizes.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv1 = nn.Conv2d(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            padding=half_kernel,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the single-channel spatial attention map for ``x``."""
        # Collapse the channel axis two ways: mean and max over channels.
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(descriptor))

In [3]:
class CNNWithCBAM(nn.Module):
    """Two-stage CNN classifier with a CBAM block after each convolution.

    Each stage is ``conv -> CBAM -> 2x2 max-pool``. The first fully connected
    layer expects ``128 * 7 * 7`` features, which corresponds to 28x28
    single-channel inputs after the two poolings.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Stage 1: 1 -> 64 channels, followed by attention.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1)
        self.cbam1 = CBAM(64)
        # Stage 2: 64 -> 128 channels, followed by attention.
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.cbam2 = CBAM(128)
        # Pooling layer shared by both stages.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Classifier head.
        self.fc1 = nn.Linear(in_features=128 * 7 * 7, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        stages = ((self.conv1, self.cbam1), (self.conv2, self.cbam2))
        for conv, attention in stages:
            x = self.pool(attention(conv(x)))
        # Flatten everything except the batch dimension.
        flat = x.view(x.size(0), -1)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)

In [4]:
# Network, loss function and optimizer shared by the cells below.
model = CNNWithCBAM(num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [5]:
def train(model, trainloader, criterion, optimizer, epochs=5):
    """Train ``model`` for ``epochs`` passes over ``trainloader``.

    Args:
        model: Module to optimize; it is switched to training mode.
        trainloader: Iterable yielding ``(inputs, labels)`` batches. It does
            not need to support ``len()``.
        criterion: Callable computing the loss from ``(outputs, labels)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        epochs: Number of passes over ``trainloader``.

    Returns:
        A list with the mean batch loss of every epoch. An epoch in which
        the loader yields no batch contributes ``nan``.
    """
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Divide by the batches actually seen rather than len(trainloader):
        # this supports loaders without __len__ and avoids a
        # ZeroDivisionError when the loader is empty.
        mean_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(mean_loss)
        print(f"Epoch {epoch+1}, Loss: {mean_loss}")
    return epoch_losses

def test(model, testloader):
    """Evaluate ``model`` on ``testloader`` and report top-1 accuracy.

    Args:
        model: Module to evaluate; it is switched to evaluation mode.
        testloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        The accuracy in percent, or ``nan`` when the loader yields no
        samples. The value is also printed.
    """
    model.eval()
    correct = 0
    total = 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for inputs, labels in testloader:
            # The predicted class is the index of the largest logit.
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against an empty loader instead of raising ZeroDivisionError.
    accuracy = 100 * correct / total if total else float("nan")
    print(f"Accuracy: {accuracy}%")
    return accuracy


# This is an evaluation helper, not a unit test: keep pytest from collecting
# it (and failing on unknown fixtures) if this module is ever imported by one.
test.__test__ = False

In [6]:
# Run the experiment: five training epochs, then evaluate on the test split.
train(
    model=model,
    trainloader=trainloader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=5,
)
test(model=model, testloader=testloader)

Epoch 1, Loss: 0.1377631417398196
Epoch 2, Loss: 0.04089171191621788
Epoch 3, Loss: 0.02469020376833726
Epoch 4, Loss: 0.01740437371867115
Epoch 5, Loss: 0.014836040282766017
Accuracy: 99.0%
