# Model Implementation


In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

## Helper Functions


In [19]:
def test_forward_pass(model, input_shape=(1, 3, 64, 64), device='cuda'):
    """
    Test a model with dummy data to ensure the forward pass works.

    The model is moved to the target device and switched to eval mode; it is
    left in that state when the function returns.

    Args:
        model: The PyTorch model to test
        input_shape: The shape of the input tensor (batch_size, channels, height, width)
        device: The device to run the test on. If a CUDA device is requested
            but CUDA is not available, the test runs on the CPU instead.

    Returns:
        output: The model output, or None if the forward pass raised
        output_shape: The shape of the output, or None if the forward pass raised
    """
    # The tensor/model transfer below happens outside the try block, so asking
    # for CUDA on a CPU-only machine would crash before the forward pass is
    # even attempted. Degrade to the CPU instead.
    if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        print("CUDA is not available; running the forward pass test on CPU instead.")
        device = 'cpu'

    # Create a random tensor with the correct shape
    x = torch.randn(input_shape).to(device)

    # Move model to the device
    model = model.to(device)
    model.eval()

    # Perform forward pass (no gradients needed for a smoke test)
    with torch.no_grad():
        try:
            output = model(x)
            print("Forward pass successful!")
            print(f"Input shape: {x.shape}")
            print(f"Output shape: {output.shape}")
            return output, output.shape
        except Exception as e:
            # Deliberate best-effort: report the failure instead of raising
            print(f"Forward pass failed with error: {e}")
            return None, None


# The ``test_`` prefix would make pytest collect this helper as a test (and
# fail looking for a ``model`` fixture) if it is imported into a test module.
test_forward_pass.__test__ = False

In [20]:
def count_parameters(model):
    """
    Tally the trainable parameters of a model.

    Only parameters whose ``requires_grad`` flag is set contribute to the
    total; frozen parameters are skipped.

    Args:
        model: PyTorch model

    Returns:
        int: Total number of elements across all trainable parameters
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Baseline, SE, and CBAM Models


In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F

#=============================================
# 1. Baseline CNN for CIFAR-10
#=============================================
class BaselineCNN(nn.Module):
    """
    Plain four-stage convolutional classifier used as the CIFAR-10 baseline.

    Every stage is conv(3x3, padding 1) -> batch norm -> ReLU, and the first
    three stages are followed by 2x2 max pooling. The resulting features are
    globally average pooled and fed to a two-layer head with dropout.
    """
    def __init__(self, num_classes=10):
        super().__init__()

        # Convolutional stages: 3 -> 32 -> 64 -> 128 -> 128 channels
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)

        # Classification head
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(128, 128)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(128, num_classes)

        self._initialize_weights()

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch."""
        # (conv, norm, whether a 2x2 max pool follows the activation)
        stages = (
            (self.conv1, self.bn1, True),
            (self.conv2, self.bn2, True),
            (self.conv3, self.bn3, True),
            (self.conv4, self.bn4, False),
        )
        for conv, norm, downsample in stages:
            x = F.relu(norm(conv(x)))
            if downsample:
                x = F.max_pool2d(x, 2)

        # Global average pool -> flatten -> MLP head
        features = torch.flatten(self.pool(x), 1)
        hidden = self.dropout(F.relu(self.fc1(features)))
        return self.fc2(hidden)

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero batch norms, N(0, 0.01) linears, zero biases."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
                continue

            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
            else:
                continue

            # Shared by conv and linear layers
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)


#=============================================
# 2. Squeeze and Excitation CNN for CIFAR-10
#=============================================
class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation Block.

    Globally average-pools each channel ("squeeze"), passes the resulting
    per-channel vector through a two-layer bottleneck MLP ending in a sigmoid
    ("excitation"), and rescales the input channels by the resulting gates.

    Args:
        channels: Number of channels in the input feature map
        reduction: Bottleneck ratio; the hidden width is
            ``channels // reduction``, clamped to at least 1
    """
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        # With channels < reduction the integer division yields 0, which would
        # create a zero-width bottleneck that cannot carry any signal. Keep at
        # least one hidden unit; for channels >= reduction nothing changes.
        hidden = max(1, channels // reduction)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` with every channel scaled by its learned gate."""
        b, c, _, _ = x.size()
        # Squeeze operation: (b, c, h, w) -> (b, c)
        y = self.avg_pool(x).view(b, c)
        # Excitation operation: per-channel gates reshaped for broadcasting
        y = self.fc(y).view(b, c, 1, 1)
        # Scale the input (the gates broadcast over the spatial dimensions)
        return x * y


class SECNN(nn.Module):
    """
    CNN with Squeeze-and-Excitation blocks for CIFAR-10.

    Four conv -> batch norm -> SE -> ReLU stages (the first three followed by
    2x2 max pooling), then global average pooling and a two-layer classifier
    head with dropout.
    """
    def __init__(self, num_classes=10):
        super(SECNN, self).__init__()

        # Feature extraction layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.se1 = SEBlock(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.se2 = SEBlock(64)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.se3 = SEBlock(128)

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.se4 = SEBlock(128)

        # Classifier
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(128, 128)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(128, num_classes)

        # Initialize weights (this also covers the Linear layers inside the SE blocks)
        self._initialize_weights()

        # Reset on every forward pass but never filled by forward(); the SE
        # gates are collected through the hooks set up by get_attention_maps().
        self.attention_maps = []

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch."""
        # Clear previous attention maps
        self.attention_maps = []

        # (conv, norm, SE block, whether a 2x2 max pool follows the activation)
        stages = (
            (self.conv1, self.bn1, self.se1, True),
            (self.conv2, self.bn2, self.se2, True),
            (self.conv3, self.bn3, self.se3, True),
            (self.conv4, self.bn4, self.se4, False),
        )
        for conv, norm, se, downsample in stages:
            # SE recalibration is applied to the normalized features, before the ReLU
            x = F.relu(se(norm(conv(x))))
            if downsample:
                x = F.max_pool2d(x, 2)

        # Classifier
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero batch norms, N(0, 0.01) linears, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def get_attention_maps(self):
        """
        Register forward hooks that record the SE channel gates.

        Every call registers a NEW hook on each SE block, so the caller should
        call ``.remove()`` on the returned handles once done; otherwise hooks
        accumulate and each forward pass records duplicate entries.

        Returns:
            attention_maps: An initially empty list. During each subsequent
                forward pass, one CPU tensor of shape (batch, channels) is
                appended per SE block, in execution order.
            hooks: The hook handles, one per SE block
        """
        attention_maps = []

        def hook_fn(module, inputs, output):
            feats = inputs[0]
            b, c = feats.size(0), feats.size(1)
            # SEBlock.forward only returns the rescaled features, so the gates
            # are recomputed here. This is for inspection only: skip autograd
            # so the extra pass does not build a graph during training.
            with torch.no_grad():
                weights = module.fc(module.avg_pool(feats).view(b, c))
            attention_maps.append(weights.detach().cpu())

        # Register hooks for all SE blocks
        hooks = [
            module.register_forward_hook(hook_fn)
            for module in self.modules()
            if isinstance(module, SEBlock)
        ]

        return attention_maps, hooks


#=============================================
# 3. Convolutional Block Attention Module (CBAM) CNN
#=============================================
class ChannelAttention(nn.Module):
    """
    Channel attention module for CBAM.

    Average- and max-pools the input over its spatial dimensions, sends both
    descriptors through the same bottleneck MLP (built from 1x1 convolutions),
    sums the results, and squashes them with a sigmoid.

    Args:
        channels: Number of channels in the input feature map
        reduction: Bottleneck ratio; the hidden width is
            ``channels // reduction``, clamped to at least 1
    """
    def __init__(self, channels, reduction=16):
        super(ChannelAttention, self).__init__()
        # With channels < reduction the integer division yields 0, i.e. a
        # bottleneck with no channels at all. Keep at least one hidden channel;
        # for channels >= reduction nothing changes.
        hidden = max(1, channels // reduction)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Shared MLP (applied to both pooled descriptors)
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1, bias=False)
        )

    def forward(self, x):
        """Return per-channel gates of shape (batch, channels, 1, 1)."""
        avg_out = self.mlp(self.avg_pool(x))
        max_out = self.mlp(self.max_pool(x))
        out = torch.sigmoid(avg_out + max_out)
        return out


class SpatialAttention(nn.Module):
    """
    Spatial attention module for CBAM.

    Collapses the channel dimension with a mean and a max, stacks the two
    maps, and convolves them into a single-channel sigmoid mask with the same
    height and width as the input.

    Args:
        kernel_size: Side length of the square convolution kernel. Must be a
            positive odd integer so that ``kernel_size // 2`` padding keeps
            the spatial size unchanged (3 and 7 give padding 1 and 3).

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer
    """
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, after which an unsupported size would silently get
        # the wrong padding and produce a mask that no longer matches the input.
        if not isinstance(kernel_size, int) or kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"Kernel size must be a positive odd integer, got {kernel_size!r}"
            )
        # "Same" padding for an odd kernel with stride 1
        padding = kernel_size // 2

        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)

    def forward(self, x):
        """Return a spatial mask of shape (batch, 1, height, width)."""
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = torch.sigmoid(self.conv(out))
        return out


class CBAMBlock(nn.Module):
    """
    Convolutional Block Attention Module.

    Applies channel attention to the input, then spatial attention to the
    channel-refined result.

    Args:
        channels: Number of channels in the input feature map
        reduction: Reduction ratio forwarded to the channel attention module
        kernel_size: Kernel size forwarded to the spatial attention module
    """
    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel_att = ChannelAttention(channels, reduction)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        """
        Returns:
            A tuple ``(refined, spatial_mask)``: the input after both
            attention stages, and the spatial mask that was applied in the
            second stage (returned so callers can visualize it).
        """
        # Stage 1: reweight the channels
        channel_refined = x * self.channel_att(x)

        # Stage 2: reweight the spatial positions of the channel-refined features
        spatial_mask = self.spatial_att(channel_refined)
        refined = channel_refined * spatial_mask

        return refined, spatial_mask


class CBAMCNN(nn.Module):
    """
    CNN with CBAM for CIFAR-10.

    Four conv -> batch norm -> CBAM -> ReLU stages (the first three followed
    by 2x2 max pooling), then global average pooling and a two-layer
    classifier head with dropout. The spatial attention mask of every stage
    from the most recent forward pass is kept in ``self.attention_maps``.
    """
    def __init__(self, num_classes=10):
        super(CBAMCNN, self).__init__()

        # Feature extraction layers
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.cbam1 = CBAMBlock(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.cbam2 = CBAMBlock(64)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.cbam3 = CBAMBlock(128)

        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.cbam4 = CBAMBlock(128)

        # Classifier
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(128, 128)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(128, num_classes)

        # Initialize weights (this also covers the convolutions inside the CBAM blocks)
        self._initialize_weights()

        # For saving attention maps
        self.attention_maps = []

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for an image batch."""
        # Clear previous attention maps
        self.attention_maps = []

        # (conv, norm, CBAM block, whether a 2x2 max pool follows the activation)
        stages = (
            (self.conv1, self.bn1, self.cbam1, True),
            (self.conv2, self.bn2, self.cbam2, True),
            (self.conv3, self.bn3, self.cbam3, True),
            (self.conv4, self.bn4, self.cbam4, False),
        )
        for conv, norm, cbam, downsample in stages:
            x_att, spatial_map = cbam(norm(conv(x)))
            # Store a detached copy: the maps are kept for visualization only,
            # and holding the attached tensor on the module would keep the
            # autograd graph of this forward pass alive until the next call.
            self.attention_maps.append(spatial_map.detach())
            x = F.relu(x_att)
            if downsample:
                x = F.max_pool2d(x, 2)

        # Classifier
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero batch norms, N(0, 0.01) linears, zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def get_attention_maps(self):
        """
        Return the spatial attention maps from the most recent forward pass.

        The list holds one detached tensor per CBAM stage, in execution order,
        and is empty until ``forward`` has been called.
        """
        return self.attention_maps


#=============================================
# Helper Functions
#=============================================
def count_parameters(model):
    """Return the total element count of all parameters of ``model`` that require gradients."""
    sizes = [weight.numel() for weight in model.parameters() if weight.requires_grad]
    return sum(sizes)


def test_forward_pass(model, input_shape=(1, 3, 32, 32), device='cpu'):
    """
    Test a model with dummy data to ensure the forward pass works.

    The model is moved to the target device and switched to eval mode; it is
    left in that state when the function returns.

    Args:
        model: The PyTorch model to test
        input_shape: The shape of the input tensor (batch_size, channels, height, width)
        device: The device to run the test on. If a CUDA device is requested
            but CUDA is not available, the test runs on the CPU instead.

    Returns:
        output: The model output, or None if the forward pass raised
        output_shape: The shape of the output, or None if the forward pass raised
    """
    # The tensor/model transfer below happens outside the try block, so asking
    # for CUDA on a CPU-only machine would crash before the forward pass is
    # even attempted. Degrade to the CPU instead.
    if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
        print("CUDA is not available; running the forward pass test on CPU instead.")
        device = 'cpu'

    # Create a random tensor with the correct shape
    x = torch.randn(input_shape).to(device)

    # Move model to the device
    model = model.to(device)
    model.eval()

    # Perform forward pass (no gradients needed for a smoke test)
    with torch.no_grad():
        try:
            output = model(x)
            print("Forward pass successful!")
            print(f"Input shape: {x.shape}")
            print(f"Output shape: {output.shape}")
            return output, output.shape
        except Exception as e:
            # Deliberate best-effort: report the failure instead of raising
            print(f"Forward pass failed with error: {e}")
            return None, None


# The ``test_`` prefix would make pytest collect this helper as a test (and
# fail looking for a ``model`` fixture) if it is imported into a test module.
test_forward_pass.__test__ = False

In [30]:
# Instantiate the three architectures under comparison
baseline_model = BaselineCNN(num_classes=10)
se_model = SECNN(num_classes=10)
cbam_model = CBAMCNN(num_classes=10)

# Smoke-test each network with a dummy batch
print("Testing Baseline CNN:")
test_forward_pass(baseline_model)

print("\nTesting SE-CNN:")
test_forward_pass(se_model)

print("\nTesting CBAM-CNN:")
test_forward_pass(cbam_model)

# Tally trainable parameters once per model and reuse the totals below
baseline_params = count_parameters(baseline_model)
se_params = count_parameters(se_model)
cbam_params = count_parameters(cbam_model)

# Absolute sizes
print("\nParameter Counts:")
print(f"Baseline CNN: {baseline_params:,} parameters")
print(f"SE-CNN: {se_params:,} parameters")
print(f"CBAM-CNN: {cbam_params:,} parameters")

# Relative overhead of each attention variant compared with the baseline
print(f"\nSE-CNN has {(se_params - baseline_params) / baseline_params * 100:.2f}% more parameters than baseline")
print(f"CBAM-CNN has {(cbam_params - baseline_params) / baseline_params * 100:.2f}% more parameters than baseline")

Testing Baseline CNN:
Forward pass successful!
Input shape: torch.Size([1, 3, 32, 32])
Output shape: torch.Size([1, 10])

Testing SE-CNN:
Forward pass successful!
Input shape: torch.Size([1, 3, 32, 32])
Output shape: torch.Size([1, 10])

Testing CBAM-CNN:
Forward pass successful!
Input shape: torch.Size([1, 3, 32, 32])
Output shape: torch.Size([1, 10])

Parameter Counts:
Baseline CNN: 259,338 parameters
SE-CNN: 264,074 parameters
CBAM-CNN: 264,466 parameters

SE-CNN has 1.83% more parameters than baseline
CBAM-CNN has 1.98% more parameters than baseline
