# Assignment 2: Neural Network Architecture and Optimization

In this assignment, you will implement and experiment with different neural network architectures for MNIST digit classification. Building on the concepts from Recitation 2, you'll explore how architecture choices affect performance and prepare models for adversarial testing.

**Instructions**

1. Complete all exercises in this notebook
2. Ensure all code runs without errors
3. Include written responses where requested
4. Save your best model for use in Part 4
5. Submit the completed notebook


## Setup

Run the following code cell to set up the notebook.

In [None]:
# Setup and imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
from torchvision import datasets

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from tqdm import tqdm
import time
from collections import defaultdict

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Seed NumPy and PyTorch so repeated runs start from the same random state.
np.random.seed(42)
torch.manual_seed(42)

# Load MNIST data (you may reuse code from recitation).
# Every image is converted to a tensor and then normalized with
# mean 0.1307 and standard deviation 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training split first, test split second; both are downloaded into ./data
# when they are not already present.
train_dataset, test_dataset = (
    datasets.MNIST('./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Mini-batches of 64; only the training loader shuffles its samples.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64)

print("✅ Setup complete!")


## Exercise 1: Architecture Comparison 

Compare different CNN architectures and analyze their performance trade-offs.

You will implement three different architectures:
1. **SimpleNet**: A minimal CNN (2 conv layers)
2. **MediumNet**: A balanced CNN (similar to recitation)  
3. **DeepNet**: A deeper CNN (4+ conv layers)

For each architecture, measure:
- Training time per epoch
- Final test accuracy
- Number of parameters
- Memory usage

### Part A: Implement SimpleNet 


In [None]:
class SimpleNet(nn.Module):
    """
    Simple CNN with minimal layers for MNIST classification.

    Architecture Requirements:
    - 2 convolutional layers
    - 1 fully connected layer (plus output layer)
    - Use ReLU activations and max pooling
    - Aim for < 50,000 parameters

    This class is an exercise skeleton: no layers are registered yet, and
    ``forward`` raises ``NotImplementedError`` until it is implemented.
    """

    def __init__(self):
        super().__init__()

        # TODO: Define your layers here
        # Hint: Start with nn.Conv2d(1, ...) for grayscale input

    def forward(self, x):
        """Run the forward pass on the input batch ``x``.

        Raises:
            NotImplementedError: Always, until the forward pass is written.
        """
        # TODO: Implement your forward pass here, then remove the raise below.
        # Failing loudly gives a clear message; silently returning None would
        # only surface later as a confusing error on the None output.
        raise NotImplementedError("SimpleNet.forward is not implemented yet")

# Build the network and move it to the selected device.
simple_model = SimpleNet().to(device)

# Add up the element counts of every parameter tensor on the model.
total_params = sum(param.numel() for param in simple_model.parameters())
print(f"SimpleNet - Total parameters: {total_params:,}")
print(f"Model summary:\n{simple_model}")

# Smoke test: push a single random 1x28x28 input through the model
# without tracking gradients, then report the output shape.
dummy_input = torch.randn(1, 1, 28, 28).to(device)
with torch.no_grad():
    dummy_output = simple_model(dummy_input)
print(f"Output shape: {dummy_output.shape}")

### Part B: Implement DeepNet

Create a deeper network with 4+ convolutional layers. Pay attention to managing the spatial dimensions as you add more layers.


In [None]:
class DeepNet(nn.Module):
    """
    Deep CNN with 4+ convolutional layers for MNIST classification.

    Architecture Requirements:
    - At least 4 convolutional layers
    - Use batch normalization to help training
    - Include dropout for regularization
    - Manage spatial dimensions carefully

    This class is an exercise skeleton: no layers are registered yet, and
    ``forward`` raises ``NotImplementedError`` until it is implemented.
    """

    def __init__(self):
        super().__init__()

        # TODO: Define your deep architecture here
        # Hint: Use batch normalization: nn.BatchNorm2d(channels)

    def forward(self, x):
        """Run the forward pass on the input batch ``x``.

        Raises:
            NotImplementedError: Always, until the forward pass is written.
        """
        # TODO: Implement your forward pass here, then remove the raise below.
        # Failing loudly gives a clear message; silently returning None would
        # only surface later as a confusing error on the None output.
        raise NotImplementedError("DeepNet.forward is not implemented yet")

# Build the deeper network and move it to the selected device.
deep_model = DeepNet().to(device)

# Add up the element counts of every parameter tensor on the model.
deep_params = sum(param.numel() for param in deep_model.parameters())
print(f"DeepNet - Total parameters: {deep_params:,}")

# Smoke test: one random 1x28x28 input, no gradient tracking.
dummy_input = torch.randn(1, 1, 28, 28).to(device)
with torch.no_grad():
    dummy_output = deep_model(dummy_input)
print(f"Output shape: {dummy_output.shape}")


### Part C: Architecture Comparison Analysis 

Train the models (SimpleNet, DeepNet, and optionally MediumNet from the recitation) and compare their performance. Create a comparison table and discuss the trade-offs.


In [None]:
def train_model(model, train_loader, criterion, optimizer, num_epochs=3):
    """Train ``model`` and return per-epoch metrics.

    Args:
        model: Network to optimize. It is switched to training mode and its
            parameters are updated in place.
        train_loader: Iterable yielding ``(data, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: Optimizer that steps the parameters of ``model``.
        num_epochs: Number of full passes over ``train_loader``.

    Returns:
        dict with the keys ``'loss'`` (mean of the per-batch losses),
        ``'accuracy'`` (training accuracy in percent) and ``'time'``
        (elapsed seconds), each a list holding one entry per epoch. Loss and
        accuracy are NaN for an epoch in which the loader yielded no data.
    """
    model.train()
    history = {'loss': [], 'accuracy': [], 'time': []}

    for epoch in range(num_epochs):
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by system clock adjustments.
        start_time = time.perf_counter()
        running_loss = 0.0
        correct = 0
        total = 0
        # Counted in the loop so loaders without __len__ also work.
        num_batches = 0

        for data, targets in tqdm(train_loader, desc=f'Epoch {epoch+1}'):
            # `device` is the module-level device chosen during setup.
            data, targets = data.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            num_batches += 1

        epoch_time = time.perf_counter() - start_time
        # An empty loader would otherwise raise ZeroDivisionError.
        epoch_loss = running_loss / num_batches if num_batches else float('nan')
        epoch_acc = 100 * correct / total if total else float('nan')

        history['loss'].append(epoch_loss)
        history['accuracy'].append(epoch_acc)
        history['time'].append(epoch_time)

        print(f'Epoch {epoch+1}: Loss={epoch_loss:.4f}, Acc={epoch_acc:.2f}%, Time={epoch_time:.1f}s')

    return history

# Architectures to compare, keyed by the name used in the printed output.
models = {
    'SimpleNet': simple_model,
    'DeepNet': deep_model,
    # Add MediumNet from recitation if you want
}

results = {}
criterion = nn.CrossEntropyLoss()

for name, model in models.items():
    print(f"\n{'='*50}")
    print(f"Training {name}")
    print(f"{'='*50}")

    # Every model gets its own Adam optimizer and two training epochs.
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    history = train_model(model, train_loader, criterion, optimizer, num_epochs=2)

    # Score the trained model on the test set, without gradient tracking.
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data, targets in test_loader:
            data = data.to(device)
            targets = targets.to(device)
            outputs = model(data)
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    test_acc = 100 * correct / total
    params = sum(param.numel() for param in model.parameters())
    avg_epoch_time = np.mean(history['time'])

    # One summary record per architecture for the later comparison.
    results[name] = dict(
        test_accuracy=test_acc,
        parameters=params,
        avg_epoch_time=avg_epoch_time,
        final_train_acc=history['accuracy'][-1],
    )

    print(f"Final test accuracy: {test_acc:.2f}%")


## Exercise 2: Model Analysis

Let's analyze your best model's performance, identify failure cases, and prepare for adversarial testing.


Identify which digits your model struggles with most and why.

In [None]:
# Feel free to change this to your best model, or keep the deep_model
best_model = deep_model

from sklearn.metrics import confusion_matrix, classification_report

# Fix the label set explicitly so the confusion matrix is always 10x10,
# matching the tick labels, the report names and the index loops below,
# even if some digit never occurs among the targets or predictions.
digit_labels = list(range(10))

# Minimum count for an off-diagonal cell to be reported as significant.
confusion_threshold = 20

# Collect predictions and ground truth for the whole test set.
best_model.eval()
all_predictions = []
all_targets = []

with torch.no_grad():
    for data, targets in test_loader:
        data, targets = data.to(device), targets.to(device)
        outputs = best_model(data)
        predicted = outputs.argmax(dim=1)

        all_predictions.extend(predicted.cpu().numpy())
        all_targets.extend(targets.cpu().numpy())

# Create confusion matrix (rows: true digit, columns: predicted digit)
cm = confusion_matrix(all_targets, all_predictions, labels=digit_labels)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=digit_labels, yticklabels=digit_labels)
plt.title('Confusion Matrix - Best Model')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

# Print classification report
print("Classification Report:")
print(classification_report(all_targets, all_predictions,
                            labels=digit_labels,
                            target_names=[str(d) for d in digit_labels]))

# Analyze most confused pairs, listing the largest counts first.
print("\nMost Confused Digit Pairs:")
confused_pairs = [
    (i, j)
    for i in digit_labels
    for j in digit_labels
    if i != j and cm[i][j] > confusion_threshold
]
# sorted() is stable, so pairs with equal counts keep their row-major order.
confused_pairs = sorted(confused_pairs, key=lambda pair: -cm[pair[0]][pair[1]])

for i, j in confused_pairs:
    print(f"True {i} predicted as {j}: {cm[i][j]} times ({cm[i][j]/cm[i].sum()*100:.1f}%)")

if not confused_pairs:
    # Without this the header would be followed by nothing at all.
    print(f"No digit pair was confused more than {confusion_threshold} times.")

Now let's see what the worst performing examples are.

In [None]:
def find_worst_examples(model, dataset, num_examples=12):
    """Find examples where model is most confident but wrong.

    Exercise stub: the body is not implemented yet, so the function
    currently returns None, which the plotting code that consumes the
    result treats as "nothing to show".

    Expected result once implemented, as required by that plotting code:
    a sequence (``len()`` is called on it) of dicts, each with the keys

    - 'image': image tensor that is converted to a NumPy array for display
    - 'true_label': value shown as "True: ..." in the panel title
    - 'predicted_label': value shown as "Pred: ..." in the panel title
    - 'confidence': number formatted with two decimals in the panel title
      (presumably the model's probability for the predicted class;
      confirm against the assignment text)

    Only the first 12 entries are plotted.
    """
    # TODO: Implement this function
    pass

worst_cases = find_worst_examples(best_model, test_dataset)

# Nothing is drawn while find_worst_examples returns None or an empty result.
if worst_cases:
    fig, axes = plt.subplots(3, 4, figsize=(12, 9))
    fig.suptitle('Worst Predictions (High Confidence, Wrong Answer)', fontsize=16)

    # Hide the axes of every panel up front so that panels left unused
    # (fewer than 12 cases) stay blank instead of showing empty frames.
    for ax in axes.flat:
        ax.axis('off')

    # axes.flat walks the 3x4 grid row by row; zip stops at whichever is
    # exhausted first, the 12 panels or the available cases.
    for ax, case in zip(axes.flat, worst_cases):
        # detach/cpu make the conversion work for tensors that still track
        # gradients or live on the GPU; both are no-ops otherwise.
        image_np = case['image'].detach().cpu().squeeze().numpy()
        ax.imshow(image_np, cmap='gray')
        title = f"True: {case['true_label']}, Pred: {case['predicted_label']}\nConf: {case['confidence']:.2f}"
        ax.set_title(title, color='red', fontsize=10)

    plt.tight_layout()
    plt.show()

    print(f"Found {len(worst_cases)} high-confidence wrong predictions")
    

# 🎉 Assignment Complete!

## Next Steps

Your trained model is now ready for **Adversarial Attacks**, where you will generate adversarial examples to fool your classifier.

**Save your work and prepare for the exciting world of adversarial machine learning!**
