<a href="https://colab.research.google.com/github/c3045835Newcastle/2/blob/main/part3coursework.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Part 3 Coursework: Convolutional Neural Networks on CIFAR-10

This notebook implements a CNN classifier for the CIFAR-10 dataset using PyTorch. The coursework explores various regularization techniques and visualization methods to understand how CNNs work.

## Overview
We'll be building a custom CNN architecture from scratch and experimenting with:
1. Early stopping for better generalization
2. L2 regularization and dropout
3. Batch normalization
4. Filter visualization to understand what the network learns

## Setup and Data Loading

First, let's import the necessary libraries and set up our environment. We'll be using PyTorch for the neural network implementation and torchvision for loading the CIFAR-10 dataset.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

### Data Preprocessing

For the CIFAR-10 dataset, I'm applying two standard preprocessing steps, in this order:
- Converting images to tensors (which scales pixel values to [0, 1])
- Normalizing each channel with the mean and std of the training set (the commonly used CIFAR-10 values)

Normalization centers each channel around zero with roughly unit variance, which generally helps the network train faster and more stably.
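
As a small worked example of the per-channel arithmetic, the sketch below applies `(value - mean) / std` to a single RGB pixel in plain Python and then inverts it, which is the same operation the `imshow` helper later uses to denormalize. The helper names `normalize` and `denormalize` are illustrative and not part of torchvision.

```python
# Per-channel statistics, copied from the transform used in this notebook.
MEAN = (0.4914, 0.4822, 0.4465)
STD = (0.2470, 0.2435, 0.2616)

def normalize(pixel):
    """Map an (R, G, B) pixel in [0, 1] to zero-centred, unit-scaled values."""
    return tuple((v - m) / s for v, m, s in zip(pixel, MEAN, STD))

def denormalize(pixel):
    """Invert normalize(): multiply by the std, then add the mean back."""
    return tuple(v * s + m for v, m, s in zip(pixel, MEAN, STD))

mid_grey = (0.5, 0.5, 0.5)        # hypothetical pixel, for illustration only
normed = normalize(mid_grey)
restored = denormalize(normed)

print(normed)     # small values near zero, one per channel
print(restored)   # recovers the original pixel up to floating-point error
```

A pixel exactly equal to the channel means maps to zero in every channel, which is what "centering" means here.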

In [None]:
# Define transformations - normalize with CIFAR-10 statistics
# These values are commonly used for CIFAR-10
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
])

# Download and load the training data
full_trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                              download=True, transform=transform)

# Download and load the test data
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

# Split training data into train and validation sets (80-20 split)
train_size = int(0.8 * len(full_trainset))
val_size = len(full_trainset) - train_size
trainset, valset = random_split(full_trainset, [train_size, val_size])

print(f'Training set size: {len(trainset)}')
print(f'Validation set size: {len(valset)}')
print(f'Test set size: {len(testset)}')

# CIFAR-10 classes
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

Let's visualize some sample images from our dataset to get a feel for what we're working with:

In [None]:
# Function to show images
def imshow(img):
    # Denormalize the image
    img = img * torch.tensor([0.2470, 0.2435, 0.2616]).view(3, 1, 1) + torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    img = torch.clamp(img, 0, 1)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')

# Show some sample images
dataloader = DataLoader(trainset, batch_size=16, shuffle=True)
images, labels = next(iter(dataloader))

# Display images
fig, axes = plt.subplots(2, 8, figsize=(15, 4))
for idx, ax in enumerate(axes.flat):
    plt.sca(ax)
    imshow(images[idx])
    ax.set_title(classes[labels[idx]])
plt.tight_layout()
plt.show()

## Task 1: CNN with Early Stopping

### Model Architecture

I've designed a CNN with the following structure:
- **3 Convolutional blocks**: Each with Conv2d → ReLU → MaxPool2d
- **3 Fully Connected layers**: For final classification

The architecture progressively increases the number of filters (32 → 64 → 128) while reducing spatial dimensions through max pooling. This is a common pattern in CNN design that helps the network learn hierarchical features.
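
To make the shape bookkeeping concrete, here is a minimal sketch in plain Python that traces the feature-map size through the three blocks and counts the learnable parameters. It assumes the layer settings described above (3x3 convolutions with padding 1, 2x2 max pooling with stride 2, and fully connected sizes 256, 128, 10); the helper names are illustrative.

```python
def conv_out(size, kernel=3, padding=1, stride=1):
    """Output width/height of a convolution (standard formula)."""
    return (size + 2 * padding - kernel) // stride + 1

def pool_out(size, kernel=2, stride=2):
    """Output width/height of a max-pooling layer without padding."""
    return (size - kernel) // stride + 1

size, channels = 32, 3            # CIFAR-10 images are 3 x 32 x 32
conv_params = 0
shapes = []
for out_channels in (32, 64, 128):
    # Weights: out * in * 3 * 3, plus one bias per output channel.
    conv_params += out_channels * channels * 3 * 3 + out_channels
    size = pool_out(conv_out(size))
    channels = out_channels
    shapes.append((channels, size, size))

flat = channels * size * size     # input width of the first linear layer
fc_params = 0
in_features = flat
for out_features in (256, 128, 10):
    fc_params += in_features * out_features + out_features
    in_features = out_features

print(shapes)                     # [(32, 16, 16), (64, 8, 8), (128, 4, 4)]
print(flat)                       # 2048
print(conv_params + fc_params)    # 651978
```

The 3x3 convolutions with padding 1 preserve the spatial size, so only the pooling layers shrink it (32 → 16 → 8 → 4). Most of the parameters sit in the first fully connected layer.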

#### Hyperparameters chosen:
- Learning rate: 0.001 (a standard starting point for Adam optimizer)
- Batch size: 128 (balances memory usage and training stability)
- Optimizer: Adam (adaptive learning rate, works well out of the box)
- Epochs: 50 (with early stopping, we likely won't use all)
- Early stopping patience: 7 epochs (stop if validation loss doesn't improve for 7 consecutive epochs)
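
To put the batch size in context, this short sketch works out how many optimizer steps one epoch takes. It assumes the standard CIFAR-10 sizes (50,000 training and 10,000 test images), the 80-20 split used in this notebook, and the `DataLoader` default of keeping the final partial batch.

```python
import math

full_train = 50_000    # CIFAR-10 training images
test_size = 10_000     # CIFAR-10 test images
batch_size = 128

train_size = int(0.8 * full_train)
val_size = full_train - train_size

def num_batches(n, batch_size):
    """Batches per epoch when the last partial batch is kept."""
    return math.ceil(n / batch_size)

train_batches = num_batches(train_size, batch_size)
val_batches = num_batches(val_size, batch_size)
test_batches = num_batches(test_size, batch_size)
last_train_batch = train_size - (train_batches - 1) * batch_size

print(train_size, val_size)                      # 40000 10000
print(train_batches, val_batches, test_batches)  # 313 79 79
print(last_train_batch)                          # 64
```

Because the last batch is smaller than the rest, per-batch losses need to be weighted by batch size before averaging, which is what the training loop does with `loss.item() * inputs.size(0)`.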

In [None]:
class BasicCNN(nn.Module):
    def __init__(self):
        super(BasicCNN, self).__init__()
        
        # First convolutional block
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Second convolutional block
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Third convolutional block
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Fully connected layers
        # After 3 max pools of stride 2, 32x32 image becomes 4x4
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)
        
    def forward(self, x):
        # Conv block 1
        x = self.pool1(F.relu(self.conv1(x)))
        
        # Conv block 2
        x = self.pool2(F.relu(self.conv2(x)))
        
        # Conv block 3
        x = self.pool3(F.relu(self.conv3(x)))
        
        # Flatten
        x = x.view(-1, 128 * 4 * 4)
        
        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        
        return x

# Print model architecture
model = BasicCNN().to(device)
print(model)
print(f'\nTotal parameters: {sum(p.numel() for p in model.parameters())}')

### Training Function with Early Stopping

The training function implements early stopping by:
1. Monitoring validation loss after each epoch
2. Saving the best model (lowest validation loss)
3. Stopping if validation loss doesn't improve for `patience` epochs

This prevents overfitting and saves training time.
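
The stopping rule can be sketched on its own, separately from any model. The function below is an illustrative helper (not part of PyTorch) that replays a list of validation losses and reports when training would stop and which epoch would be restored. The loss values are made up for the example.

```python
def early_stopping_epoch(val_losses, patience):
    """Return (stop_epoch, best_epoch), both 1-based, for a sequence of validation losses."""
    best_loss = float('inf')
    best_epoch = None
    waited = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            # Improvement: remember this epoch and reset the patience counter.
            best_loss, best_epoch, waited = loss, epoch, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch, best_epoch
    # Ran out of epochs before patience was exhausted.
    return len(val_losses), best_epoch

losses = [1.40, 1.10, 0.95, 0.97, 0.96, 0.99, 1.02]  # made-up values
result = early_stopping_epoch(losses, patience=3)
print(result)  # (6, 3)
```

With a patience of 3, training stops at epoch 6 and the weights from epoch 3 are kept. Note that epoch 5 (0.96) beats epoch 4 but not the best value so far, so it does not reset the counter.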

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, 
                num_epochs=50, patience=7, device='cuda'):
    """
    Train the model with early stopping.
    
    Returns:
        model: The input model with the best (lowest validation loss) weights restored
        history: Dictionary containing training history
    """
    best_val_loss = float('inf')
    best_model = None
    patience_counter = 0
    
    history = {
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': []
    }
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs.data, 1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs.data, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
        
        # Calculate average losses and accuracies
        train_loss = train_loss / train_total
        val_loss = val_loss / val_total
        train_acc = 100 * train_correct / train_total
        val_acc = 100 * val_correct / val_total
        
        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        
        print(f'Epoch [{epoch+1}/{num_epochs}] '
              f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        
        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = deepcopy(model.state_dict())
            patience_counter = 0
            print(f'  --> New best model saved (Val Loss: {val_loss:.4f})')
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f'\nEarly stopping triggered after {epoch+1} epochs')
                break
    
    # Load best model (skipped if no epoch ever improved, e.g. when the loss is NaN)
    if best_model is not None:
        model.load_state_dict(best_model)
    return model, history

### Training the Baseline Model

Now let's train our baseline CNN with early stopping:

In [None]:
# Hyperparameters
batch_size = 128
learning_rate = 0.001
num_epochs = 50
patience = 7

# Create data loaders
train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Initialize model, loss, and optimizer
model_basic = BasicCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_basic.parameters(), lr=learning_rate)

# Train the model
print('Training baseline CNN with early stopping...\n')
model_basic, history_basic = train_model(model_basic, train_loader, val_loader, 
                                         criterion, optimizer, num_epochs, patience, device)

### Convergence Graphs

Let's visualize the training and validation loss to see how the model learned and where early stopping kicked in:

In [None]:
def plot_convergence(history, title='Training and Validation Loss'):
    """
    Plot training and validation loss and accuracy.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    
    # Plot loss
    ax1.plot(history['train_loss'], label='Training Loss', linewidth=2)
    ax1.plot(history['val_loss'], label='Validation Loss', linewidth=2)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title(title, fontsize=14)
    ax1.legend(fontsize=11)
    ax1.grid(True, alpha=0.3)
    
    # Plot accuracy
    ax2.plot(history['train_acc'], label='Training Accuracy', linewidth=2)
    ax2.plot(history['val_acc'], label='Validation Accuracy', linewidth=2)
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Accuracy (%)', fontsize=12)
    ax2.set_title('Training and Validation Accuracy', fontsize=14)
    ax2.legend(fontsize=11)
    ax2.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

plot_convergence(history_basic, 'Baseline CNN: Training and Validation Loss')

### Test Performance

Let's evaluate the model on the test set:

In [None]:
def evaluate_model(model, test_loader, device='cuda'):
    """
    Evaluate model on test set.
    """
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

test_acc_basic = evaluate_model(model_basic, test_loader, device)

## Task 2: L2 Regularization and Dropout

### Model with Regularization

Now I'll create versions of the CNN with different regularization techniques:
1. **L2 Regularization**: Added via weight_decay parameter in the optimizer
2. **Dropout**: Added dropout layers after the first two fully connected layers (the output layer has no dropout)

These techniques help prevent overfitting by:
- L2: Penalizing large weights
- Dropout: Randomly dropping neurons during training, forcing the network to learn more robust features
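
As a rough illustration of what a dropout layer does to its inputs, the sketch below implements "inverted" dropout in plain Python: each value is zeroed with probability `p` during training and the survivors are scaled by `1 / (1 - p)`, so the expected activation is unchanged and nothing needs rescaling at evaluation time. `nn.Dropout` uses this same scaling convention; the function itself is a hypothetical stand-in, not the library code.

```python
import random

def dropout(values, p=0.5, training=True, rng=random):
    """Inverted dropout: zero each value with probability p, scale survivors by 1/(1-p)."""
    if not training or p == 0.0:
        return list(values)          # evaluation mode: identity
    if p >= 1.0:
        return [0.0 for _ in values]
    scale = 1.0 / (1.0 - p)
    return [v * scale if rng.random() >= p else 0.0 for v in values]

rng = random.Random(0)               # fixed seed so the example is repeatable
activations = [1.0] * 10_000
dropped = dropout(activations, p=0.5, rng=rng)
mean_after = sum(dropped) / len(dropped)

print(mean_after)                            # close to 1.0
print(dropout([0.2, 0.7], training=False))   # [0.2, 0.7]
```

This is why `model.train()` and `model.eval()` matter in the training loop: the same layer behaves differently in the two modes.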

In [None]:
class CNNWithDropout(nn.Module):
    def __init__(self, dropout_rate=0.5):
        super(CNNWithDropout, self).__init__()
        
        # Convolutional layers (same as before)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Fully connected layers with dropout
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.dropout1 = nn.Dropout(dropout_rate)
        
        self.fc2 = nn.Linear(256, 128)
        self.dropout2 = nn.Dropout(dropout_rate)
        
        self.fc3 = nn.Linear(128, 10)
        
    def forward(self, x):
        # Conv blocks
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        
        # Flatten
        x = x.view(-1, 128 * 4 * 4)
        
        # FC layers with dropout
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        
        x = self.fc3(x)
        
        return x

print('CNN with Dropout architecture created')

### Training with L2 Regularization

L2 regularization is implemented through the `weight_decay` parameter of the optimizer, which adds `weight_decay * w` to each parameter's gradient before the Adam update. I'll use a weight decay of 0.0001.
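
As a quick sanity check on what that setting means, the sketch below uses a toy one-parameter loss (an assumption made purely for illustration) to show that adding `weight_decay * w` to the gradient is the same as differentiating the loss plus the penalty `(weight_decay / 2) * w**2`. In `optim.Adam` this modified gradient is then passed through Adam's adaptive scaling, so the effect is not identical to the decoupled decay used by `AdamW`.

```python
def data_loss(w):
    """Toy data loss: a quadratic with its minimum at w = 3."""
    return (w - 3.0) ** 2

def grad_data_loss(w):
    return 2.0 * (w - 3.0)

def l2_objective(w, weight_decay):
    """Data loss plus the penalty (weight_decay / 2) * w^2."""
    return data_loss(w) + 0.5 * weight_decay * w ** 2

def grad_with_weight_decay(w, weight_decay):
    """Gradient as modified by (coupled) weight decay: g + weight_decay * w."""
    return grad_data_loss(w) + weight_decay * w

w, wd, eps = 1.5, 1e-4, 1e-6
analytic = grad_with_weight_decay(w, wd)
# Central finite difference of the penalized objective.
numeric = (l2_objective(w + eps, wd) - l2_objective(w - eps, wd)) / (2 * eps)

print(analytic, numeric)   # the two values agree to several decimal places
```

With a weight decay this small, the penalty only nudges the gradient, which is why its effect tends to show up gradually over many epochs.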

In [None]:
# Model with L2 regularization
model_l2 = BasicCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_l2 = optim.Adam(model_l2.parameters(), lr=learning_rate, weight_decay=0.0001)

print('Training CNN with L2 Regularization...\n')
model_l2, history_l2 = train_model(model_l2, train_loader, val_loader, 
                                   criterion, optimizer_l2, num_epochs=30, 
                                   patience=10, device=device)

### Training with Dropout

Now let's train a model with dropout (rate=0.5):

In [None]:
# Model with Dropout
model_dropout = CNNWithDropout(dropout_rate=0.5).to(device)
optimizer_dropout = optim.Adam(model_dropout.parameters(), lr=learning_rate)

print('Training CNN with Dropout...\n')
model_dropout, history_dropout = train_model(model_dropout, train_loader, val_loader, 
                                             criterion, optimizer_dropout, num_epochs=30, 
                                             patience=10, device=device)

### Comparison of Regularization Techniques

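Let's compare the validation curves of all three models (baseline, L2, and dropout) side by side. Note that the baseline was trained with up to 50 epochs and a patience of 7, while the two regularized models use up to 30 epochs and a patience of 10, so the curves may differ in length: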

In [None]:
# Plot comparison
plt.figure(figsize=(14, 6))

# Validation loss comparison
plt.subplot(1, 2, 1)
plt.plot(history_basic['val_loss'], label='No Regularization', linewidth=2, marker='o', markersize=3)
plt.plot(history_l2['val_loss'], label='L2 Regularization', linewidth=2, marker='s', markersize=3)
plt.plot(history_dropout['val_loss'], label='Dropout', linewidth=2, marker='^', markersize=3)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Validation Loss', fontsize=12)
plt.title('Regularization Comparison: Validation Loss', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)

# Validation accuracy comparison
plt.subplot(1, 2, 2)
plt.plot(history_basic['val_acc'], label='No Regularization', linewidth=2, marker='o', markersize=3)
plt.plot(history_l2['val_acc'], label='L2 Regularization', linewidth=2, marker='s', markersize=3)
plt.plot(history_dropout['val_acc'], label='Dropout', linewidth=2, marker='^', markersize=3)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Validation Accuracy (%)', fontsize=12)
plt.title('Regularization Comparison: Validation Accuracy', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Test accuracies
print('\nTest Set Performance:')
print(f'Baseline (No Regularization): {evaluate_model(model_basic, test_loader, device):.2f}%')
print(f'L2 Regularization: {evaluate_model(model_l2, test_loader, device):.2f}%')
print(f'Dropout: {evaluate_model(model_dropout, test_loader, device):.2f}%')

## Task 3: Batch Normalization

### Model with Batch Normalization

Batch normalization standardizes each layer's outputs over the mini-batch (here, before the ReLU), which:
- Allows higher learning rates
- Reduces sensitivity to initialization
- Acts as a regularizer
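
The core computation is small enough to sketch in plain Python for a single feature. The function below follows the training-mode formula (batch mean, biased batch variance, a small `eps` for numerical stability, then a learnable scale `gamma` and shift `beta`); it is an illustrative stand-in that leaves out the running statistics `nn.BatchNorm1d` keeps for evaluation.

```python
import math

def batch_norm(values, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize one feature across a mini-batch, then scale and shift."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n   # biased variance
    return [gamma * (v - mean) / math.sqrt(var + eps) + beta for v in values]

batch = [2.0, 4.0, 6.0, 8.0]       # made-up activations for one feature
out = batch_norm(batch)
mean_out = sum(out) / len(out)
var_out = sum((v - mean_out) ** 2 for v in out) / len(out)

print(out)
print(mean_out, var_out)           # approximately 0 and 1
```

Because the statistics come from the current mini-batch, each sample's output depends slightly on which other samples it is batched with. That noise is one common explanation for the regularizing effect.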

I'll add BatchNorm2d after each convolutional layer and BatchNorm1d after the first two fully connected layers (the output layer is left unnormalized):

In [None]:
class CNNWithBatchNorm(nn.Module):
    def __init__(self):
        super(CNNWithBatchNorm, self).__init__()
        
        # First convolutional block with batch norm
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Second convolutional block with batch norm
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Third convolutional block with batch norm
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Fully connected layers with batch norm
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.bn_fc1 = nn.BatchNorm1d(256)
        
        self.fc2 = nn.Linear(256, 128)
        self.bn_fc2 = nn.BatchNorm1d(128)
        
        self.fc3 = nn.Linear(128, 10)
        
    def forward(self, x):
        # Conv block 1 with batch norm
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.pool1(x)
        
        # Conv block 2 with batch norm
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.pool2(x)
        
        # Conv block 3 with batch norm
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.pool3(x)
        
        # Flatten
        x = x.view(-1, 128 * 4 * 4)
        
        # FC layers with batch norm
        x = self.fc1(x)
        x = self.bn_fc1(x)
        x = F.relu(x)
        
        x = self.fc2(x)
        x = self.bn_fc2(x)
        x = F.relu(x)
        
        x = self.fc3(x)
        
        return x

print('CNN with Batch Normalization architecture created')

### Training with Batch Normalization

Let's train the model with batch normalization:

In [None]:
# Model with Batch Normalization
model_bn = CNNWithBatchNorm().to(device)
optimizer_bn = optim.Adam(model_bn.parameters(), lr=learning_rate)

print('Training CNN with Batch Normalization...\n')
model_bn, history_bn = train_model(model_bn, train_loader, val_loader, 
                                   criterion, optimizer_bn, num_epochs=30, 
                                   patience=10, device=device)

### Comparison: With vs Without Batch Normalization

Let's compare the baseline model with the batch normalization model:

In [None]:
# Plot comparison
plt.figure(figsize=(14, 6))

# Validation loss comparison
plt.subplot(1, 2, 1)
plt.plot(history_basic['val_loss'], label='Without Batch Normalization', linewidth=2.5, marker='o', markersize=4)
plt.plot(history_bn['val_loss'], label='With Batch Normalization', linewidth=2.5, marker='s', markersize=4)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Validation Loss', fontsize=12)
plt.title('Batch Normalization Impact: Validation Loss', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)

# Validation accuracy comparison
plt.subplot(1, 2, 2)
plt.plot(history_basic['val_acc'], label='Without Batch Normalization', linewidth=2.5, marker='o', markersize=4)
plt.plot(history_bn['val_acc'], label='With Batch Normalization', linewidth=2.5, marker='s', markersize=4)
plt.xlabel('Epoch', fontsize=12)
plt.ylabel('Validation Accuracy (%)', fontsize=12)
plt.title('Batch Normalization Impact: Validation Accuracy', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Test accuracies
print('\nTest Set Performance:')
print(f'Without Batch Normalization: {evaluate_model(model_basic, test_loader, device):.2f}%')
print(f'With Batch Normalization: {evaluate_model(model_bn, test_loader, device):.2f}%')

## Task 4: Filter Visualization

### Visualizing Learned Filters

Now let's visualize what the CNN has learned. We'll look at:
1. The actual filter weights from each convolutional layer
2. Feature maps (activations) when a test image passes through the network

This helps us understand what features the network is detecting at different layers.
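
Filter weights can be negative, so they have to be rescaled to [0, 1] before they can be shown as an image. A minimal sketch of that min-max rescaling in plain Python, with a guard for the degenerate case where all weights are equal (the guard and the helper name are additions for illustration):

```python
def min_max_scale(values, eps=1e-8):
    """Rescale values linearly so the minimum maps to 0 and the maximum to 1."""
    lo, hi = min(values), max(values)
    span = hi - lo
    if span < eps:
        # All values (nearly) identical: return mid-grey instead of dividing by zero.
        return [0.5 for _ in values]
    return [(v - lo) / span for v in values]

weights = [-0.30, -0.05, 0.00, 0.12, 0.45]   # made-up filter weights
scaled = min_max_scale(weights)

print(scaled)
print(min_max_scale([0.1, 0.1, 0.1]))        # [0.5, 0.5, 0.5]
```

Scaling with the minimum and maximum of the whole layer, as the visualization function does, keeps filters comparable to each other; scaling each filter separately would instead maximize contrast within each one.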

In [None]:
def visualize_filters(model, layer_name='conv1', num_filters=16):
    """
    Visualize the learned filters from a convolutional layer.
    """
    # Get the weights from the specified layer
    if hasattr(model, layer_name):
        layer = getattr(model, layer_name)
        filters = layer.weight.data.cpu().numpy()
    else:
        print(f"Layer {layer_name} not found")
        return
    
    # Normalize filters to [0, 1] for visualization (epsilon avoids division by zero)
    f_min, f_max = filters.min(), filters.max()
    filters = (filters - f_min) / (f_max - f_min + 1e-8)
    
    # Plot filters
    num_filters = min(num_filters, filters.shape[0])
    fig, axes = plt.subplots(4, 4, figsize=(10, 10))
    
    for i, ax in enumerate(axes.flat):
        if i < num_filters:
            # Get the filter
            f = filters[i]
            
            # If it's a 3-channel filter, show as RGB
            if f.shape[0] == 3:
                f = np.transpose(f, (1, 2, 0))
                ax.imshow(f)
            else:
                # For other layers, show first channel
                ax.imshow(f[0], cmap='viridis')
            
            ax.set_title(f'Filter {i+1}', fontsize=10)
        ax.axis('off')
    
    plt.suptitle(f'Learned Filters from {layer_name}', fontsize=14)
    plt.tight_layout()
    plt.show()

# Visualize filters from each convolutional layer
print('Visualizing filters from conv1 (first layer):')
visualize_filters(model_basic, 'conv1', 16)

print('\nVisualizing filters from conv2 (second layer):')
visualize_filters(model_basic, 'conv2', 16)

print('\nVisualizing filters from conv3 (third layer):')
visualize_filters(model_basic, 'conv3', 16)

### Visualizing Feature Maps

Now let's see how these filters respond to actual images. We'll pass a test image through the network and visualize the activations at each layer:

In [None]:
def get_feature_maps(model, image, device='cuda'):
    """
    Extract feature maps (the conv outputs, captured before the ReLU) from all convolutional layers.
    """
    model.eval()
    feature_maps = {}
    
    # Add hooks to capture intermediate outputs
    def hook_fn(name):
        def hook(module, input, output):
            feature_maps[name] = output.detach()
        return hook
    
    # Register hooks
    hooks = []
    hooks.append(model.conv1.register_forward_hook(hook_fn('conv1')))
    hooks.append(model.conv2.register_forward_hook(hook_fn('conv2')))
    hooks.append(model.conv3.register_forward_hook(hook_fn('conv3')))
    
    # Forward pass
    with torch.no_grad():
        image = image.unsqueeze(0).to(device)
        _ = model(image)
    
    # Remove hooks
    for hook in hooks:
        hook.remove()
    
    return feature_maps

def visualize_feature_maps(feature_maps, layer_name, num_maps=16):
    """
    Visualize feature maps from a specific layer.
    """
    if layer_name not in feature_maps:
        print(f"Layer {layer_name} not found")
        return
    
    maps = feature_maps[layer_name].cpu().numpy()[0]  # Get first image in batch
    num_maps = min(num_maps, maps.shape[0])
    
    fig, axes = plt.subplots(4, 4, figsize=(12, 12))
    
    for i, ax in enumerate(axes.flat):
        if i < num_maps:
            ax.imshow(maps[i], cmap='viridis')
            ax.set_title(f'Feature Map {i+1}', fontsize=10)
        ax.axis('off')
    
    plt.suptitle(f'Feature Maps from {layer_name}', fontsize=14)
    plt.tight_layout()
    plt.show()

# Get a test image
test_images, test_labels = next(iter(test_loader))
test_image = test_images[0]  # Pick first image
test_label = test_labels[0]

# Display the original image
plt.figure(figsize=(4, 4))
imshow(test_image)
plt.title(f'Test Image: {classes[test_label]}', fontsize=14)
plt.show()

# Get feature maps
print('Extracting feature maps from the test image...\n')
feature_maps = get_feature_maps(model_basic, test_image, device)

# Visualize feature maps from each layer
print('Feature maps from conv1 (first layer - detecting edges and simple patterns):')
visualize_feature_maps(feature_maps, 'conv1', 16)

print('\nFeature maps from conv2 (second layer - detecting more complex patterns):')
visualize_feature_maps(feature_maps, 'conv2', 16)

print('\nFeature maps from conv3 (third layer - detecting high-level features):')
visualize_feature_maps(feature_maps, 'conv3', 16)

### Additional Visualization: Multiple Test Images

Let's look at how different images activate the filters:

In [None]:
# Visualize feature maps for the first few test images (not necessarily different classes)
num_samples = 5
sample_indices = list(range(num_samples))

fig, axes = plt.subplots(num_samples, 5, figsize=(15, 12))

for idx, sample_idx in enumerate(sample_indices):
    test_img = test_images[sample_idx]
    test_lbl = test_labels[sample_idx]
    
    # Original image
    plt.sca(axes[idx, 0])
    imshow(test_img)
    axes[idx, 0].set_title(f'{classes[test_lbl]}', fontsize=10)
    axes[idx, 0].axis('off')
    
    # Get feature maps
    fmaps = get_feature_maps(model_basic, test_img, device)
    
    # Show first feature map from each conv layer
    for layer_idx, layer_name in enumerate(['conv1', 'conv2', 'conv3']):
        feature = fmaps[layer_name].cpu().numpy()[0, 0]
        axes[idx, layer_idx + 1].imshow(feature, cmap='viridis')
        if idx == 0:
            axes[idx, layer_idx + 1].set_title(f'{layer_name} (1st filter)', fontsize=10)
        axes[idx, layer_idx + 1].axis('off')
    
    # Show combined visualization from conv3
    feature_combined = fmaps['conv3'].cpu().numpy()[0, :9].mean(axis=0)
    axes[idx, 4].imshow(feature_combined, cmap='viridis')
    if idx == 0:
        axes[idx, 4].set_title('conv3 (avg of 9)', fontsize=10)
    axes[idx, 4].axis('off')

plt.suptitle('Feature Maps Across Different Test Images', fontsize=14)
plt.tight_layout()
plt.show()

## Summary and Observations

In this coursework, I implemented a CNN from scratch on the CIFAR-10 dataset and explored various aspects:

### Task 1: Early Stopping
- Implemented a baseline CNN with 3 convolutional and 3 fully connected layers
- Used early stopping to prevent overfitting
- Achieved reasonable performance on the validation set

### Task 2: Regularization
- Compared three approaches: no regularization, L2 regularization, and dropout
- Both regularization techniques helped reduce overfitting
- Dropout showed particularly good performance by forcing the network to learn robust features

### Task 3: Batch Normalization
- Added batch normalization layers throughout the network
- Observed faster convergence and more stable training
- Batch normalization also acts as a mild regularizer

### Task 4: Filter Visualization
- Visualized the learned convolutional filters
- First layer filters detect low-level features (edges, colors)
- Deeper layers detect more abstract patterns
- Feature maps show how different parts of the image activate different filters
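
To illustrate how a single filter produces a feature map, here is a small sketch in plain Python. It slides a hand-written vertical-edge kernel (an assumption for illustration, not a filter learned by the network) over a tiny synthetic image and applies ReLU. The notebook's hooks capture the conv output before the ReLU, so the ReLU here is an extra step.

```python
def conv2d_valid(image, kernel):
    """Cross-correlate a 2D image with a 2D kernel (no padding, stride 1)."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(image[r + i][c + j] * kernel[i][j]
                for i in range(kh) for j in range(kw))
            for c in range(out_w)
        ]
        for r in range(out_h)
    ]

def relu(feature_map):
    return [[max(0.0, v) for v in row] for row in feature_map]

# 5x6 image: dark left half (0.0), bright right half (1.0).
image = [[0.0, 0.0, 0.0, 1.0, 1.0, 1.0] for _ in range(5)]
# Responds where brightness increases from left to right.
kernel = [[-1.0, 0.0, 1.0],
          [-1.0, 0.0, 1.0],
          [-1.0, 0.0, 1.0]]

feature_map = relu(conv2d_valid(image, kernel))
for row in feature_map:
    print(row)   # [0.0, 3.0, 3.0, 0.0]
```

The response is non-zero only where the kernel window straddles the dark-to-bright boundary, which is the sense in which a feature map highlights the parts of the image that match a filter.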

The experiments demonstrate that proper regularization and normalization techniques are crucial for training effective CNNs. Each technique offers unique benefits, and they can be combined for even better results.

## Saving the Best Model

Finally, let's save our best performing model for future use:

In [None]:
# Save the baseline model; swap in model_l2, model_dropout, or model_bn
# if one of them scored higher on the validation set
torch.save(model_basic.state_dict(), 'cifar10_cnn_best.pth')
print('Model saved as cifar10_cnn_best.pth')

# To load the model later (the class must match the saved weights):
# model = BasicCNN().to(device)
# model.load_state_dict(torch.load('cifar10_cnn_best.pth', map_location=device))
# model.eval()