# CNN for Multi-Class Image Classification

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/maheshghanta/Codes/blob/master/PyTorch_Tutorials/4.MultiClass_Classification.ipynb)

This tutorial demonstrates building a **Convolutional Neural Network (CNN)** for CIFAR-10 image classification with:
- CNN architecture with convolutional and pooling layers
- Complete training pipeline
- TensorBoard logging
- Comprehensive metrics and visualization

## Overview

We'll build a **Convolutional Neural Network (CNN)** - designed specifically for image data - to classify CIFAR-10 images.

**Why CNN over MLP?**
- **Spatial awareness**: Preserves 2D structure of images
- **Translation invariance**: Recognizes patterns anywhere in image
- **Parameter efficiency**: Shared weights reduce parameters
- **Feature hierarchy**: Builds from edges → shapes → objects

**Architecture Components:**
- Convolutional layers: Extract spatial features
- Pooling layers: Reduce spatial dimensions
- Batch normalization: Stabilize training
- Fully connected layers: Final classification

## Setup and Imports

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import os

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.9.0
CUDA available: False


## 1. Load and Prepare Data with Augmentation

CNNs benefit greatly from data augmentation!

In [2]:
# Data augmentation for training
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # Flip horizontally
    transforms.RandomCrop(32, padding=4),     # Random crop with padding
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # Color variations
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# No augmentation for validation/test
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load CIFAR-10 dataset
train_dataset_full = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform
)

test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=test_transform
)

# Split training data into train and validation
train_size = int(0.8 * len(train_dataset_full))
val_size = len(train_dataset_full) - train_size

train_dataset, val_dataset = random_split(
    train_dataset_full,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Classes: {test_dataset.classes}")

Training samples: 40000
Validation samples: 10000
Test samples: 10000
Classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [3]:
# Create DataLoaders
batch_size = 128

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

# Check data shape
sample_image, sample_label = next(iter(train_loader))
print(f"Batch shape: {sample_image.shape}")  # (batch_size, channels, height, width)
print(f"Label shape: {sample_label.shape}")
print(f"Image size: {sample_image[0].shape} = (C, H, W)")



Batch shape: torch.Size([128, 3, 32, 32])
Label shape: torch.Size([128])
Image size: torch.Size([3, 32, 32]) = (C, H, W)


## 2. Define CNN Model

We'll create a CNN with multiple convolutional blocks followed by fully connected layers.

In [4]:
class CIFAR10_CNN(nn.Module):
    """
    Convolutional Neural Network for CIFAR-10 classification
    
    Architecture:
    - Conv Block 1: 3 -> 64 channels
    - Conv Block 2: 64 -> 128 channels  
    - Conv Block 3: 128 -> 256 channels
    - Fully Connected: 256 -> 10 classes
    """
    
    def __init__(self, num_classes=10):
        super(CIFAR10_CNN, self).__init__()
        
        # Convolutional Block 1
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 32x32 -> 16x16
        )
        
        # Convolutional Block 2
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 16x16 -> 8x8
        )
        
        # Convolutional Block 3
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 8x8 -> 4x4
        )
        
        # Global Average Pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        
        # Fully Connected Layers
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )
    
    def forward(self, x):
        # Convolutional feature extraction
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        
        # Global pooling
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # Flatten
        
        # Classification
        x = self.classifier(x)
        
        return x

# Create model
model = CIFAR10_CNN()
print(model)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

CIFAR10_CNN(
  (conv1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=Fals

## 3. Setup Training Components

In [6]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Move model to device
model = model.to(device)

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer (Adam with weight decay)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

# Learning rate scheduler (ReduceLROnPlateau - reduces LR when validation loss plateaus)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

print(f"Loss function: {criterion}")
print(f"Optimizer: {optimizer.__class__.__name__}")
print(f"Initial learning rate: {optimizer.param_groups[0]['lr']}")
print(f"Scheduler: ReduceLROnPlateau")

Using device: cpu
Loss function: CrossEntropyLoss()
Optimizer: Adam
Initial learning rate: 0.001
Scheduler: ReduceLROnPlateau


## 4. Setup TensorBoard Logging

In [7]:
# Create TensorBoard writer with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = f'runs/cnn_cifar10_{timestamp}'
writer = SummaryWriter(log_dir)

print(f"TensorBoard logs saved to: {log_dir}")
print(f"To view: tensorboard --logdir=runs")

# Log model architecture
sample_input = torch.randn(1, 3, 32, 32).to(device)
writer.add_graph(model, sample_input)
print("Model graph added to TensorBoard")

TensorBoard logs saved to: runs/cnn_cifar10_20251029_194730
To view: tensorboard --logdir=runs
Model graph added to TensorBoard


## 5. Training Functions

In [8]:
def train_one_epoch(model, loader, criterion, optimizer, device, epoch, writer):
    """
    Train the model for one epoch
    
    Returns:
        average_loss: Average training loss
        accuracy: Training accuracy (%)
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (images, labels) in enumerate(loader):
        # Move data to device
        images = images.to(device)
        labels = labels.to(device)
        
        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        # Log batch metrics to TensorBoard
        if batch_idx % 50 == 0:
            writer.add_scalar('Train/BatchLoss', loss.item(), epoch * len(loader) + batch_idx)
            batch_acc = 100. * correct / total
            writer.add_scalar('Train/BatchAccuracy', batch_acc, epoch * len(loader) + batch_idx)
            
            if batch_idx % 100 == 0:
                print(f'  Batch [{batch_idx}/{len(loader)}] | '
                      f'Loss: {loss.item():.4f} | Acc: {batch_acc:.2f}%')
    
    avg_loss = running_loss / len(loader)
    accuracy = 100. * correct / total
    
    return avg_loss, accuracy


@torch.no_grad()
def validate(model, loader, criterion, device, epoch, writer, phase='Validation'):
    """
    Validate the model
    
    Returns:
        average_loss: Average validation loss
        accuracy: Validation accuracy (%)
        all_preds: All predictions
        all_labels: All true labels
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    
    avg_loss = running_loss / len(loader)
    accuracy = 100. * correct / total
    
    # Log epoch metrics to TensorBoard
    writer.add_scalar(f'{phase}/Loss', avg_loss, epoch)
    writer.add_scalar(f'{phase}/Accuracy', accuracy, epoch)
    
    return avg_loss, accuracy, all_preds, all_labels

print("Training functions defined successfully!")

Training functions defined successfully!


## 6. Training Loop

In [None]:
# Training configuration
num_epochs = 30
best_val_acc = 0.0

# Store metrics for plotting
train_losses = []
train_accs = []
val_losses = []
val_accs = []
learning_rates = []

print(f"Starting training for {num_epochs} epochs...")
print("=" * 80)

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print("-" * 80)
    
    # Train
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer, device, epoch, writer
    )
    
    # Validate
    val_loss, val_acc, val_preds, val_labels = validate(
        model, val_loader, criterion, device, epoch, writer, 'Validation'
    )
    
    # Update learning rate based on validation loss
    scheduler.step(val_loss)
    current_lr = optimizer.param_groups[0]['lr']
    writer.add_scalar('Train/LearningRate', current_lr, epoch)
    
    # Store metrics
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    learning_rates.append(current_lr)
    
    # Print epoch summary
    print(f"\nTrain Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")
    print(f"Learning Rate: {current_lr:.6f}")
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss,
        }, 'best_cnn_model.pth')
        print(f"✓ Best model saved! (Val Acc: {val_acc:.2f}%)")
    
    # Early stopping check
    if current_lr < 1e-6:
        print("\nLearning rate too small. Early stopping...")
        break

print("\n" + "=" * 80)
print(f"Training completed! Best validation accuracy: {best_val_acc:.2f}%")
writer.close()

Starting training for 30 epochs...

Epoch 1/30
--------------------------------------------------------------------------------
  Batch [0/313] | Loss: 2.3665 | Acc: 7.03%
  Batch [100/313] | Loss: 1.6686 | Acc: 29.04%
  Batch [200/313] | Loss: 1.4363 | Acc: 35.20%
  Batch [300/313] | Loss: 1.4569 | Acc: 39.78%

Train Loss: 1.5749 | Train Acc: 40.34%
Val Loss: 1.4776 | Val Acc: 47.62%
Learning Rate: 0.001000
✓ Best model saved! (Val Acc: 47.62%)

Epoch 2/30
--------------------------------------------------------------------------------
  Batch [0/313] | Loss: 1.3115 | Acc: 53.12%
  Batch [100/313] | Loss: 1.1527 | Acc: 55.31%
  Batch [200/313] | Loss: 1.2211 | Acc: 57.32%
  Batch [300/313] | Loss: 0.9770 | Acc: 58.28%

Train Loss: 1.1580 | Train Acc: 58.38%
Val Loss: 1.1390 | Val Acc: 60.70%
Learning Rate: 0.001000
✓ Best model saved! (Val Acc: 60.70%)

Epoch 3/30
--------------------------------------------------------------------------------
  Batch [0/313] | Loss: 1.0177 | Acc: 66.

## 7. Plot Training Metrics

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

epochs_range = range(1, len(train_losses)+1)

# Loss plot
axes[0].plot(epochs_range, train_losses, 'b-', label='Train Loss', marker='o', markersize=4)
axes[0].plot(epochs_range, val_losses, 'r-', label='Val Loss', marker='s', markersize=4)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)

# Accuracy plot
axes[1].plot(epochs_range, train_accs, 'b-', label='Train Acc', marker='o', markersize=4)
axes[1].plot(epochs_range, val_accs, 'r-', label='Val Acc', marker='s', markersize=4)
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('Accuracy (%)', fontsize=12)
axes[1].set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)

# Learning rate plot
axes[2].plot(epochs_range, learning_rates, 'g-', marker='d', markersize=4)
axes[2].set_xlabel('Epoch', fontsize=12)
axes[2].set_ylabel('Learning Rate', fontsize=12)
axes[2].set_title('Learning Rate Schedule', fontsize=14, fontweight='bold')
axes[2].set_yscale('log')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('cnn_training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Training curves saved to 'cnn_training_curves.png'")

## 8. Test Set Evaluation

In [None]:
# Load best model
checkpoint = torch.load('best_cnn_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"Best validation accuracy: {checkpoint['val_acc']:.2f}%")

# Evaluate on test set
test_loss, test_acc, test_preds, test_labels = validate(
    model, test_loader, criterion, device, 0, writer, 'Test'
)

print("\n" + "=" * 80)
print("TEST SET RESULTS")
print("=" * 80)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")
print("=" * 80)

# Compare with previous MLP
print("\n📊 Performance Comparison:")
print(f"  CNN (this model): ~{test_acc:.1f}%")
print(f"  MLP (previous):   ~50-55%")
print(f"  Improvement:      ~{test_acc - 52.5:.1f}%")

## 9. Per-Class Accuracy

In [None]:
# Calculate per-class accuracy
from collections import defaultdict

class_correct = defaultdict(int)
class_total = defaultdict(int)

for pred, label in zip(test_preds, test_labels):
    if pred == label:
        class_correct[label] += 1
    class_total[label] += 1

classes = test_dataset.classes

print("\nPer-Class Accuracy:")
print("-" * 50)
class_accs = []
for i, class_name in enumerate(classes):
    acc = 100.0 * class_correct[i] / class_total[i]
    class_accs.append(acc)
    print(f"{class_name:12s}: {acc:6.2f}% ({class_correct[i]}/{class_total[i]})")

# Plot per-class accuracy
plt.figure(figsize=(14, 6))
bars = plt.bar(range(len(classes)), class_accs, color='steelblue', edgecolor='navy', linewidth=1.5)
plt.xlabel('Class', fontsize=12)
plt.ylabel('Accuracy (%)', fontsize=12)
plt.title('Per-Class Accuracy on Test Set (CNN)', fontsize=14, fontweight='bold')
plt.xticks(range(len(classes)), classes, rotation=45, ha='right')
plt.ylim([0, 100])
plt.axhline(y=test_acc, color='r', linestyle='--', label=f'Overall Acc: {test_acc:.2f}%')
plt.grid(axis='y', alpha=0.3)
plt.legend(fontsize=10)

# Add value labels on bars
for bar, acc in zip(bars, class_accs):
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height + 1,
             f'{acc:.1f}%', ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.savefig('cnn_per_class_accuracy.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"\nPer-class accuracy plot saved to 'cnn_per_class_accuracy.png'")

## 10. Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns

# Compute confusion matrix
cm = confusion_matrix(test_labels, test_preds)

# Plot confusion matrix
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='YlOrRd', 
            xticklabels=classes, yticklabels=classes,
            cbar_kws={'label': 'Count'}, linewidths=0.5)
plt.xlabel('Predicted Label', fontsize=12)
plt.ylabel('True Label', fontsize=12)
plt.title('Confusion Matrix - CNN on CIFAR-10 Test Set', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('cnn_confusion_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

print("Confusion matrix saved to 'cnn_confusion_matrix.png'")

# Print classification report
print("\nClassification Report:")
print("=" * 80)
print(classification_report(test_labels, test_preds, target_names=classes, digits=4))

## 11. Sample Predictions with Confidence

In [None]:
# Get a batch of test images
model.eval()
test_images, test_labels_batch = next(iter(test_loader))
test_images_device = test_images.to(device)

with torch.no_grad():
    outputs = model(test_images_device)
    probabilities = torch.softmax(outputs, dim=1)
    confidences, predictions = probabilities.max(1)

# Move to CPU for visualization
test_images = test_images.cpu()
predictions = predictions.cpu()
confidences = confidences.cpu()

# Visualize first 16 predictions
fig, axes = plt.subplots(4, 4, figsize=(14, 14))

for i, ax in enumerate(axes.flat):
    # Denormalize image
    img = test_images[i].permute(1, 2, 0).numpy()
    img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img = np.clip(img, 0, 1)
    
    ax.imshow(img)
    
    # Color: green if correct, red if wrong
    true_label = classes[test_labels_batch[i]]
    pred_label = classes[predictions[i]]
    confidence = confidences[i].item() * 100
    
    is_correct = test_labels_batch[i] == predictions[i]
    color = 'green' if is_correct else 'red'
    
    ax.set_title(f'True: {true_label}\nPred: {pred_label}\nConf: {confidence:.1f}%', 
                 color=color, fontsize=9, fontweight='bold')
    ax.axis('off')

plt.suptitle('CNN Predictions (Green=Correct, Red=Wrong)', fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig('cnn_sample_predictions.png', dpi=150, bbox_inches='tight')
plt.show()

print("Sample predictions saved to 'cnn_sample_predictions.png'")

## 12. Feature Maps Visualization

In [None]:
# Visualize feature maps from first convolutional layer
def visualize_feature_maps(model, image, layer_name='conv1'):
    """Visualize feature maps from a specific layer"""
    
    # Get the layer
    if layer_name == 'conv1':
        layer = model.conv1[0]  # First conv layer
    elif layer_name == 'conv2':
        layer = model.conv2[0]
    elif layer_name == 'conv3':
        layer = model.conv3[0]
    
    # Register hook to capture feature maps
    activations = {}
    def hook_fn(module, input, output):
        activations['features'] = output.detach()
    
    handle = layer.register_forward_hook(hook_fn)
    
    # Forward pass
    model.eval()
    with torch.no_grad():
        _ = model(image.unsqueeze(0).to(device))
    
    handle.remove()
    
    return activations['features'].squeeze().cpu()

# Get a sample image
sample_img, sample_label = test_dataset[0]

# Visualize feature maps from conv1
feature_maps = visualize_feature_maps(model, sample_img, 'conv1')

# Plot first 32 feature maps
fig, axes = plt.subplots(4, 8, figsize=(16, 8))

for i, ax in enumerate(axes.flat):
    if i < feature_maps.shape[0]:
        ax.imshow(feature_maps[i], cmap='viridis')
        ax.set_title(f'Filter {i+1}', fontsize=8)
    ax.axis('off')

plt.suptitle('Feature Maps from First Convolutional Layer', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('cnn_feature_maps.png', dpi=150, bbox_inches='tight')
plt.show()

print("Feature maps saved to 'cnn_feature_maps.png'")

## Summary

### What We Accomplished:

**Built a CNN** - Convolutional Neural Network with multiple conv blocks

**Data Augmentation** - RandomFlip, RandomCrop, ColorJitter

**Complete Training Pipeline** - Train, validation, and test

**TensorBoard Logging** - Comprehensive metrics tracking

**Advanced Features** - BatchNorm, Dropout, Adaptive LR

**Comprehensive Evaluation** - Confusion matrix, per-class metrics

**Feature Visualization** - Understanding learned representations

### Performance Comparison:

| Model | Test Accuracy | Parameters |
|-------|--------------|------------|
| **MLP** (Tutorial 3) | ~50-55% | ~1.7M |
| **CNN** (This tutorial) | ~75-85% | ~1.5M |
| **Improvement** | **+25-30%** | Fewer! |

### Why CNN Performs Better:

1. **Spatial Structure**: Preserves 2D image structure
2. **Parameter Sharing**: Same filters applied across image
3. **Translation Invariance**: Recognizes patterns anywhere
4. **Hierarchical Features**: Low-level → Mid-level → High-level
5. **Data Augmentation**: Artificially increases training data

### Architecture Insights:

```
Input (3x32x32)
    ↓
Conv Block 1: 3→64 channels (32x32 → 16x16)
    ↓
Conv Block 2: 64→128 channels (16x16 → 8x8)
    ↓
Conv Block 3: 128→256 channels (8x8 → 4x4)
    ↓
Global Avg Pool → Flatten
    ↓
FC Layers: 256 → 512 → 10 classes
```

### Key Techniques Used:

- **Batch Normalization**: Stabilizes training
- **Dropout**: Prevents overfitting
- **Data Augmentation**: Improves generalization
- **Adaptive LR**: ReduceLROnPlateau scheduler
- **Global Average Pooling**: Reduces parameters

### Next Steps:

- Try ResNet, VGG, or other architectures
- Experiment with transfer learning
- Try different augmentation strategies
- Implement attention mechanisms
- Use mixed precision training for speed

### TensorBoard Commands:

```bash
# View training logs
tensorboard --logdir=runs

# Compare multiple runs
tensorboard --logdir=runs --port=6006
```