# Model Implementation


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

## Helper Functions


In [7]:
def test_forward_pass(model, input_shape=(1, 3, 64, 64), device='cuda'):
    """
    Test a model with dummy data to ensure the forward pass works.
    
    Args:
        model: The PyTorch model to test
        input_shape: The shape of the input tensor (batch_size, channels, height, width)
        device: The device to run the test on
        
    Returns:
        output: The model output
        output_shape: The shape of the output
    """
    # Create a random tensor with the correct shape
    x = torch.randn(input_shape).to(device)
    
    # Move model to the device
    model = model.to(device)
    model.eval()
    
    # Perform forward pass
    with torch.no_grad():
        try:
            output = model(x)
            print(f"Forward pass successful!")
            print(f"Input shape: {x.shape}")
            print(f"Output shape: {output.shape}")
            return output, output.shape
        except Exception as e:
            print(f"Forward pass failed with error: {e}")
            return None, None

In [8]:
def count_parameters(model):
    """
    Return how many scalar weights of a model are trainable.

    Args:
        model: PyTorch model

    Returns:
        int: Total element count over all parameters with ``requires_grad`` set
    """
    total = 0
    for param in model.parameters():
        # Frozen parameters are excluded from the count
        if param.requires_grad:
            total += param.numel()
    return total

# Baseline Model


In [12]:
#=============================================
# 1. Basic Residual Block
#=============================================
class ResidualBlock(nn.Module):
    """Two-convolution residual block with an optional attention stage.

    The residual branch is conv3x3 -> BN -> ReLU -> conv3x3 -> BN, optionally
    followed by an attention module, and is then added to the skip tensor
    before a final ReLU.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to build the skip
            tensor; when None the input itself is used.
        attention_type: 'se' or 'cbam' to enable attention; any other value
            disables it.
        reduction: Reduction factor forwarded to the attention module.

    Attributes:
        attention_map: Second value returned by the CBAM module on the most
            recent forward pass; stays None for other attention types.
    """
    def __init__(self, in_channels, out_channels, stride=1, downsample=None,
                 attention_type=None, reduction=16):
        super().__init__()

        # Residual branch: only the first convolution carries the stride
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Projection for the skip path (None means identity)
        self.downsample = downsample

        # Optional attention on the residual branch
        self.attention_type = attention_type
        self.attention = None
        if attention_type == 'se':
            self.attention = SEBlock(out_channels, reduction)
        elif attention_type == 'cbam':
            self.attention = CBAMBlock(out_channels, reduction)

        # Filled in by forward() when CBAM is active
        self.attention_map = None

    def forward(self, x):
        # Residual branch
        residual = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))

        # Attention, when configured; CBAM also hands back a map to keep
        if self.attention_type == 'se':
            residual = self.attention(residual)
        elif self.attention_type == 'cbam':
            residual, self.attention_map = self.attention(residual)

        # Skip path: identity unless a projection was supplied
        shortcut = x if self.downsample is None else self.downsample(x)

        # Merge (in place on the residual tensor) and activate
        residual += shortcut
        return F.relu(residual)

#=============================================
# 2. Squeeze and Excitation Block
#=============================================
class SEBlock(nn.Module):
    """Squeeze-and-Excitation Block.

    Global-average-pools each channel, passes the pooled vector through a
    two-layer bottleneck MLP ending in a sigmoid, and rescales the input
    channels by the resulting per-channel gate.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Bottleneck ratio; the hidden width is
            ``channels // reduction``, clamped to at least 1.

    Raises:
        ValueError: If ``reduction`` is smaller than 1.
    """
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")

        # channels // reduction is 0 whenever channels < reduction; a
        # zero-width bottleneck makes the gate a constant 0.5, so keep at
        # least one hidden unit.
        hidden = max(1, channels // reduction)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze operation: (b, c, h, w) -> (b, c)
        y = self.avg_pool(x).view(b, c)
        # Excitation operation: per-channel gate reshaped for broadcasting
        y = self.fc(y).view(b, c, 1, 1)
        # Scale the input
        return x * y.expand_as(x)

#=============================================
# 3. CBAM Components
#=============================================
class ChannelAttention(nn.Module):
    """Channel attention module for CBAM.

    Average-pooled and max-pooled channel descriptors are passed through the
    same 1x1-convolution bottleneck MLP; the two results are summed and
    squashed with a sigmoid to give one gate per channel.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Bottleneck ratio; the hidden width is
            ``channels // reduction``, clamped to at least 1.

    Raises:
        ValueError: If ``reduction`` is smaller than 1.
    """
    def __init__(self, channels, reduction=16):
        super(ChannelAttention, self).__init__()
        if reduction < 1:
            raise ValueError(f"reduction must be >= 1, got {reduction}")

        # channels // reduction is 0 whenever channels < reduction, which
        # would request a convolution with no output channels; keep at least
        # one hidden channel.
        hidden = max(1, channels // reduction)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Shared MLP (applied to both pooled descriptors)
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1, bias=False)
        )

    def forward(self, x):
        avg_out = self.mlp(self.avg_pool(x))
        max_out = self.mlp(self.max_pool(x))
        # Gate of shape (batch, channels, 1, 1)
        out = torch.sigmoid(avg_out + max_out)
        return out

class SpatialAttention(nn.Module):
    """Spatial attention module for CBAM.

    Collapses the channel axis with a mean and a max, stacks the two maps,
    and turns them into a single-channel sigmoid mask with one convolution.

    Args:
        kernel_size: Size of the convolution kernel; only 3 and 7 are accepted.
    """
    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), "Kernel size must be 3 or 7"
        # Padding that keeps the spatial size unchanged for each allowed kernel
        same_padding = {3: 1, 7: 3}[kernel_size]

        self.conv = nn.Conv2d(
            2, 1, kernel_size=kernel_size, padding=same_padding, bias=False
        )

    def forward(self, x):
        # Two (batch, 1, h, w) summaries of the channel axis
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        return torch.sigmoid(self.conv(stacked))

class CBAMBlock(nn.Module):
    """Convolutional Block Attention Module.

    Applies channel attention followed by spatial attention to a feature map.

    Args:
        channels: Channel count handed to the channel-attention module.
        reduction: Reduction factor handed to the channel-attention module.
        kernel_size: Kernel size handed to the spatial-attention module.
    """
    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel_att = ChannelAttention(channels, reduction)
        self.spatial_att = SpatialAttention(kernel_size)

    def forward(self, x):
        # Stage 1: reweight the channels
        channel_refined = x * self.channel_att(x)

        # Stage 2: the spatial mask is computed from the channel-refined features
        spatial_mask = self.spatial_att(channel_refined)

        # The mask is returned alongside the output so callers can visualize it
        return channel_refined * spatial_mask, spatial_mask

#=============================================
# 4. Deep CNN Base Network
#=============================================
class DeepCNN(nn.Module):
    """Four-stage residual CNN assembled from a configurable block class.

    A 3x3 stem convolution on 3-channel input is followed by four stages of
    widths 64, 128, 256 and 512 (the last three start with a stride-2 block),
    global average pooling and a linear classifier.

    Args:
        block: Block class; called as
            ``block(in_channels, out_channels, stride, downsample, attention_type=...)``.
        num_blocks: Sequence whose first four entries give the number of
            blocks in each stage.
        attention_type: Forwarded to every block; the value 'cbam' also turns
            on attention-map collection in ``forward``.
        num_classes: Size of the classifier output.
    """
    def __init__(self, block, num_blocks, attention_type=None, num_classes=10):
        super().__init__()
        self.attention_type = attention_type
        # Running channel count, advanced by _make_layer as stages are built
        self.in_channels = 64

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Residual stages, registered as layer1 .. layer4 in this order
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (width, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, width, num_blocks[idx], stride=stride)
            setattr(self, f"layer{idx + 1}", stage)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

        # Re-initialize every conv / batch-norm / linear layer built above
        self._initialize_weights()

        # Attention maps gathered by the most recent forward pass
        self.attention_maps = []

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage: a first block (possibly strided) plus same-shape blocks."""
        # A 1x1 projection is needed whenever the first block changes the
        # resolution or the channel count of its input.
        needs_projection = stride != 1 or self.in_channels != out_channels
        downsample = None
        if needs_projection:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        # The first block is always created, even if num_blocks is below 1
        layers = [
            block(self.in_channels, out_channels, stride, downsample,
                  attention_type=self.attention_type)
        ]

        # Every later block (and the next stage) sees the new width
        self.in_channels = out_channels

        layers.extend(
            block(self.in_channels, out_channels, 1, None,
                  attention_type=self.attention_type)
            for _ in range(1, num_blocks)
        )

        return nn.Sequential(*layers)

    def forward(self, x):
        # Start from an empty collection on every call
        self.attention_maps = []

        # Stem
        x = F.relu(self.bn1(self.conv1(x)))

        # Residual stages
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        # With CBAM, pick up the map each residual block stored during this pass
        if self.attention_type == 'cbam':
            for module in self.modules():
                if not isinstance(module, ResidualBlock):
                    continue
                stored_map = getattr(module, 'attention_map', None)
                if stored_map is not None:
                    self.attention_maps.append(stored_map)

        # Classification head
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)

    def _initialize_weights(self):
        """Apply the per-layer-type initialization to every submodule."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.BatchNorm2d):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def get_attention_maps(self):
        """Return the list of attention maps collected by the last forward pass."""
        return self.attention_maps

#=============================================
# 5. Model Instantiation Functions
#=============================================
def DeepResNet18(num_classes=10):
    """Build the ResNet-18-style network (two blocks per stage) without attention."""
    blocks_per_stage = [2, 2, 2, 2]
    return DeepCNN(
        ResidualBlock,
        blocks_per_stage,
        attention_type=None,
        num_classes=num_classes,
    )

def DeepSENet18(num_classes=10):
    """Build the ResNet-18-style network (two blocks per stage) with SE attention."""
    blocks_per_stage = [2, 2, 2, 2]
    return DeepCNN(
        ResidualBlock,
        blocks_per_stage,
        attention_type='se',
        num_classes=num_classes,
    )

def DeepCBAMNet18(num_classes=10):
    """Build the ResNet-18-style network (two blocks per stage) with CBAM attention."""
    blocks_per_stage = [2, 2, 2, 2]
    return DeepCNN(
        ResidualBlock,
        blocks_per_stage,
        attention_type='cbam',
        num_classes=num_classes,
    )
#=============================================
# 6. Helper Functions
#=============================================
def count_parameters(model):
    """Return the total element count of the model's trainable parameters."""
    # Only parameters that still require gradients contribute
    trainable = filter(lambda weight: weight.requires_grad, model.parameters())
    return sum(weight.numel() for weight in trainable)

def test_model_architectures():
    """Smoke-test the three ResNet-18 variants and compare parameter counts.

    Builds the plain, SE and CBAM variants with their default arguments, runs
    one random 1x3x32x32 batch through each in evaluation mode, and prints
    the input/output shapes followed by each model's trainable parameter
    count and the relative overhead of the attention variants.

    Raises:
        AssertionError: If a model does not return logits of shape
            (batch, 10), 10 being the factories' default ``num_classes``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    # One row per variant:
    # (heading printed before its forward pass, name used in the report, model)
    variants = [
        ("Basic ResNet-18:", "ResNet-18", DeepResNet18()),
        ("\nSE-ResNet-18:", "SE-ResNet-18", DeepSENet18()),
        ("\nCBAM-ResNet-18:", "CBAM-ResNet-18", DeepCBAMNet18()),
    ]

    # Shared dummy input, created after the models as before
    x = torch.randn(1, 3, 32, 32).to(device)
    expected_shape = (x.shape[0], 10)

    # Forward passes (inference only, so no gradient tracking)
    print("\nTesting ResNet-18 variants:")
    with torch.no_grad():
        for heading, name, model in variants:
            model.to(device).eval()
            print(heading)
            out = model(x)
            print(f"Input shape: {x.shape}, Output shape: {out.shape}")
            # Turn the printout into an actual check
            assert tuple(out.shape) == expected_shape, (
                f"{name} returned shape {tuple(out.shape)}, expected {expected_shape}"
            )

    # Parameter counts
    print("\nParameter counts:")
    param_counts = {name: count_parameters(model) for _, name, model in variants}
    for name, total in param_counts.items():
        print(f"{name}: {total:,} parameters")

    # Overhead of each attention variant relative to the plain network
    base_params = param_counts["ResNet-18"]
    for name in ("SE-ResNet-18", "CBAM-ResNet-18"):
        overhead = (param_counts[name] - base_params) / base_params * 100
        print(f"{name} has {overhead:.2f}% more parameters than base")
    
# Run the architecture smoke test; it prints the output shapes and parameter counts shown below.
test_model_architectures()

Using device: cuda

Testing ResNet-18 variants:
Basic ResNet-18:
Input shape: torch.Size([1, 3, 32, 32]), Output shape: torch.Size([1, 10])

SE-ResNet-18:
Input shape: torch.Size([1, 3, 32, 32]), Output shape: torch.Size([1, 10])

CBAM-ResNet-18:
Input shape: torch.Size([1, 3, 32, 32]), Output shape: torch.Size([1, 10])

Parameter counts:
ResNet-18: 11,173,962 parameters
SE-ResNet-18: 11,261,002 parameters
CBAM-ResNet-18: 11,261,786 parameters
SE-ResNet-18 has 0.78% more parameters than base
CBAM-ResNet-18 has 0.79% more parameters than base


In [13]:
import torch
from torchviz import make_dot

# Build one instance of each variant: no attention, SE, CBAM
base_model, se_model, cbam_model = DeepResNet18(), DeepSENet18(), DeepCBAMNet18()

# One random input batch shared by all three graphs
x = torch.randn(1, 3, 32, 32)

# Base model: trace a forward pass and render its graph to "base_model" as PNG
y_base = base_model(x)
base_graph = make_dot(y_base, params=dict(base_model.named_parameters()))
base_graph.render("base_model", format="png")

# SE model: same procedure, rendered to "se_model"
y_se = se_model(x)
se_graph = make_dot(y_se, params=dict(se_model.named_parameters()))
se_graph.render("se_model", format="png")

# CBAM model: rendered last so the value returned by render() is the cell output
y_cbam = cbam_model(x)
cbam_graph = make_dot(y_cbam, params=dict(cbam_model.named_parameters()))
cbam_graph.render("cbam_model", format="png")

'cbam_model.png'