# CS182: Deep Neural Networks - Assignment on CNNs and Advanced Optimizers

**UC Berkeley - Fall 2025**

In this assignment, you will:
1. Build a standard training-validation loop for CNN models **(Part a - 20 points)**
2. Implement and understand the Newton-Schulz (NS) iteration used in modern optimizers **(Part b - 40 points)**
3. Implement the Lion optimizer and compare its performance with AdamW on ResNet18/CIFAR-10 **(Part c - 40 points)**

**Important Notes:**
- This notebook is designed to run on Google Colab with GPUs
- All random seeds are set to 42 for reproducibility across runs  
- PyTorch 2.9+ recommended (the setup cell below upgrades older installations)

## Setup and Installation

In [None]:
# Check PyTorch version and upgrade if necessary
import torch
print(f"PyTorch version: {torch.__version__}")

# If version is less than 2.9, upgrade
if torch.__version__ < '2.9':
    !pip install --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import copy
import random

# Set random seeds for reproducibility
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # for multi-GPU
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print(f"Random seed set to: {SEED}")

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## Data Loading and Visualization

We'll use the CIFAR-10 dataset, which consists of 60,000 32x32 color images in 10 classes (50,000 training images and 10,000 test images).

In [None]:
# Define transforms for CIFAR-10
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)

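# Split training set into train and validation
# Using 40,000 train / 10,000 validation to match test set size
train_size = 40000
val_size = 10000

# Use generator with fixed seed for reproducible split
generator = torch.Generator().manual_seed(SEED)
trainset, valset = random_split(trainset, [train_size, val_size], generator=generator)

# Validation images should not be augmented: wrap the same validation indices
# around a copy of the training data that uses the test-time transform
valset = torch.utils.data.Subset(
    torchvision.datasets.CIFAR10(root='./data', train=True,
                                 download=False, transform=transform_test),
    valset.indices)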

print(f"Training samples: {len(trainset)}")
print(f"Validation samples: {len(valset)}")
print(f"Test samples: {len(testset)}")
print(f"\nNote: Validation and test sets have equal size ({val_size} samples each)")

# CIFAR-10 classes
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# Visualize some training images
# Per-channel statistics used in transforms.Normalize above, shaped for broadcasting
CIFAR_MEAN = torch.tensor((0.4914, 0.4822, 0.4465)).view(3, 1, 1)
CIFAR_STD = torch.tensor((0.2023, 0.1994, 0.2010)).view(3, 1, 1)

def imshow(img, ax):
    img = (img * CIFAR_STD + CIFAR_MEAN).clamp(0, 1)  # undo Normalize
    ax.imshow(np.transpose(img.numpy(), (1, 2, 0)))   # CHW -> HWC
    ax.axis('off')

# Get some random training images
temp_loader = DataLoader(trainset, batch_size=8, shuffle=True)
dataiter = iter(temp_loader)
images, labels = next(dataiter)

# Show images
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for idx, ax in enumerate(axes.flat):
    imshow(images[idx], ax)
    ax.set_title(classes[labels[idx]])
plt.tight_layout()
plt.show()

---
## Part (a): Build a Standard Training-Validation Loop

Your task is to implement a general `train_validation_loop` function that can be used to train any model with any optimizer. This function will be reused in later parts of the assignment.

**Requirements:**
- The function should accept: model, train_loader, val_loader, optimizer, criterion, num_epochs, and device
- Track and return both training and validation losses and accuracies for each epoch
- Use tqdm for progress bars during training
- Save the best model based on validation accuracy
- Return a dictionary containing training history and the best model state

**Deliverables:**
1. Complete implementation of `train_validation_loop` function
2. Test it with a simple CNN (provided below) on CIFAR-10
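
One detail behind the "save the best model" requirement is worth illustrating: `model.state_dict()` returns references to the live parameter tensors, so the best state has to be copied, not just stored. The sketch below is framework-free and uses a plain dictionary as a stand-in for the state dict; the accuracies are made-up numbers chosen only for illustration.

```python
import copy

# Stand-in for model.state_dict(): values are mutated in place by "training".
state = {"weight": [0.0, 0.0]}
val_accs = [61.0, 68.5, 67.2]  # hypothetical per-epoch validation accuracies

best_val_acc = float("-inf")
best_state_ref = None    # alias to the live state (keeps changing)
best_state_copy = None   # independent snapshot

for epoch, acc in enumerate(val_accs):
    state["weight"][0] += 1.0  # pretend an epoch of training updated the weights
    if acc > best_val_acc:
        best_val_acc = acc
        best_state_ref = state
        best_state_copy = copy.deepcopy(state)

print(best_val_acc)                  # 68.5 (epoch index 1)
print(best_state_ref["weight"][0])   # 3.0, overwritten by the later epoch
print(best_state_copy["weight"][0])  # 2.0, the weights at the best epoch
```

With a real model, the same idea applies to `copy.deepcopy(model.state_dict())`.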

In [None]:
# Simple CNN for testing (provided)
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 4 * 4)
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

In [None]:
def train_validation_loop(model, train_loader, val_loader, optimizer, criterion, 
                         num_epochs, device):
    """
    General training and validation loop for PyTorch models.
    
    Args:
        model: PyTorch model to train
        train_loader: DataLoader for training data
        val_loader: DataLoader for validation data
        optimizer: PyTorch optimizer
        criterion: Loss function
        num_epochs: Number of epochs to train
        device: Device to train on (cuda/cpu)
    
    Returns:
        Dictionary containing:
            - 'train_losses': List of training losses per epoch
            - 'val_losses': List of validation losses per epoch
            - 'train_accs': List of training accuracies per epoch
            - 'val_accs': List of validation accuracies per epoch
            - 'best_model_state': State dict of best model (based on val accuracy)
            - 'best_val_acc': Best validation accuracy achieved
    """
    
    # TODO: Initialize tracking variables
    # - Lists to store losses and accuracies
    # - Best validation accuracy and best model state
    
    # YOUR CODE HERE
    raise NotImplementedError()
    
    for epoch in range(num_epochs):
        # TODO: Implement training phase
        # 1. Set model to training mode
        # 2. Iterate through train_loader with tqdm
        # 3. For each batch:
        #    - Move data to device
        #    - Zero gradients
        #    - Forward pass
        #    - Compute loss
        #    - Backward pass
        #    - Optimizer step
        #    - Track running loss and accuracy
        
        # YOUR CODE HERE
        raise NotImplementedError()
        
        # TODO: Implement validation phase
        # 1. Set model to evaluation mode
        # 2. A short step regarding gradient computation (what is it and why?)
        # 3. Iterate through val_loader
        # 4. Compute validation loss and accuracy
        # 5. Save best model if current val accuracy is best so far
        
        # YOUR CODE HERE
        raise NotImplementedError()
        
        # TODO: Print epoch statistics
        
        # YOUR CODE HERE
    
    # TODO: Return results dictionary
    # YOUR CODE HERE
    raise NotImplementedError()

In [None]:
# Test your implementation with SimpleCNN
train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
val_loader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=2)

# Initialize model, optimizer, and criterion
simple_model = SimpleCNN().to(device)
optimizer = torch.optim.Adam(simple_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Train for a few epochs to test
results = train_validation_loop(
    model=simple_model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    num_epochs=5,
    device=device
)

In [None]:
# Plot training and validation curves
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Loss curves
ax1.plot(results['train_losses'], label='Train Loss')
ax1.plot(results['val_losses'], label='Val Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()
ax1.grid(True)

# Accuracy curves
ax2.plot(results['train_accs'], label='Train Acc')
ax2.plot(results['val_accs'], label='Val Acc')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Validation Accuracy')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

print(f"Best Validation Accuracy: {results['best_val_acc']:.2f}%")

---
## Newton-Schulz (NS) Coefficients Visualization

Before we dive into implementing optimizers, let's understand the Newton-Schulz iteration and how different coefficient sets affect the convergence behavior.

### Background: What is Newton-Schulz and Why Orthogonalization Matters

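In Muon, the Newton-Schulz iteration is used to approximate the **orthogonal polar factor** of a matrix: for $G = U\Sigma V^\top$ it approximates $UV^\top = G(G^\top G)^{-1/2}$, which is where the matrix **inverse square root** enters. The iteration uses an odd polynomial:

$$X_{k+1} = X_k(\alpha I + \beta X_k^2 + \gamma X_k^4)$$

This form assumes a symmetric $X_k$, as in the test matrix used in Part (b); for a general (possibly rectangular) matrix, $X_k^2$ is replaced by $X_k^\top X_k$.

**Why do we want the inverse square root?** Because multiplying a full-rank matrix by its inverse square root, $G(G^\top G)^{-1/2}$, sets every singular value to 1, which produces an **orthogonal matrix** (or an approximately orthogonal one after a finite number of iterations).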
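
To see why the iteration only reshapes singular values, here is a short derivation. It assumes a symmetric positive semidefinite $X_k = Q\Lambda_k Q^\top$ with orthogonal $Q$, which matches the test matrix used in Part (b). For a general matrix the same argument goes through with the SVD and $X_k^\top X_k$ in place of $X_k^2$.

$$X_k^2 = Q\Lambda_k^2 Q^\top, \qquad X_k^4 = Q\Lambda_k^4 Q^\top$$

$$X_{k+1} = Q\Lambda_k Q^\top \left(\alpha I + \beta Q\Lambda_k^2 Q^\top + \gamma Q\Lambda_k^4 Q^\top\right) = Q\left(\alpha\Lambda_k + \beta\Lambda_k^3 + \gamma\Lambda_k^5\right)Q^\top$$

So $Q$ is unchanged and each singular value is mapped independently through the scalar polynomial

$$p(\sigma) = \alpha\sigma + \beta\sigma^3 + \gamma\sigma^5, \qquad p(1) = \alpha + \beta + \gamma, \qquad p'(0) = \alpha.$$

Two consequences follow directly: $\sigma = 1$ is a fixed point only if $\alpha + \beta + \gamma = 1$, and $\alpha$ sets how quickly very small singular values grow per iteration.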

### The Power of Orthogonalization in Optimization

Orthogonal matrices have special properties that are crucial for stable and efficient neural network training:

1. **Preserve Gradient Norms**: Orthogonal transformations preserve the magnitude of vectors. This prevents gradients from vanishing or exploding as they backpropagate through layers.

2. **Balanced Learning Directions**: In the weight matrix, different directions (singular vectors) can have vastly different scales. Orthogonalization "evens out" these scales, ensuring all feature dimensions learn at comparable rates.

3. **Condition Number Control**: The condition number of an orthogonal matrix is always 1 (optimal). This leads to better-conditioned optimization problems that converge more reliably.

4. **Prevents Feature Collapse**: Without orthogonalization, weight matrices can develop redundant or highly correlated features. Orthogonalization encourages diverse, independent feature representations.

### How Muon Uses Newton-Schulz

In the Muon optimizer, the Newton-Schulz iteration is applied to **precondition the gradients**:
- Instead of applying the (momentum-averaged) gradient of each weight matrix directly, Muon first orthogonalizes it
- This ensures gradient updates are well-balanced across all parameter dimensions
- The NS coefficients control how aggressively this orthogonalization is performed

### Two Coefficient Sets with Different Trade-offs

1. **Aggressive (Keller) Coefficients**: α ≈ 3.4445, β ≈ -4.7750, γ ≈ 2.0315
   - **Does not converge to exactly 1** (α + β + γ ≈ 0.701, so σ = 1 is not a fixed point of the iteration)
   - High initial slope (~3.44) aggressively inflates small singular values
   - Pushes singular values rapidly into a band around 1 rather than onto 1 itself
   - Best for: LLM pretraining and speedruns where you want maximum speed
   - Trades mathematical convergence for rapid feature amplification
   - Risk: May overshoot the orthogonal target, causing instability

2. **Stable (Taylor) Coefficients**: α = 1.875, β = -1.25, γ = 0.375
   - **Convergent** (coefficients sum to exactly 1, so σ = 1 is a fixed point of the iteration)
   - Gentle slope (1.875) for stable orthogonalization
   - Iteratively refines the matrix toward perfect orthogonality
   - Best for: Fine-tuning and stability-critical tasks
   - Focuses on convergence and preserving learned structure
   - Guarantee: singular values in (0, 1] converge to 1, so a full-rank matrix normalized this way approaches an orthogonal matrix as the number of iterations grows

### The Orthogonalization Process

The visualization below shows how these coefficients affect **singular values** over multiple NS iterations:
- **Singular values** measure the "stretch" along different directions in the matrix
- An orthogonal matrix has all singular values equal to 1
- NS iterations gradually transform the singular value spectrum toward this ideal
- Different coefficients change how aggressively this transformation happens
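
Because each NS step maps every singular value σ through the same scalar polynomial p(σ) = ασ + βσ³ + γσ⁵, the two coefficient sets can be compared without any matrices. The following is a minimal framework-free sketch of that scalar view; the helper names `ns_scalar` and `iterate` are hypothetical and are not part of the assignment's API.

```python
def ns_scalar(sigma, alpha, beta, gamma):
    """Effect of one NS step on a single singular value."""
    return alpha * sigma + beta * sigma**3 + gamma * sigma**5


def iterate(sigma, coeffs, num_iters):
    """Apply the scalar map repeatedly and record every intermediate value."""
    history = [sigma]
    for _ in range(num_iters):
        sigma = ns_scalar(sigma, *coeffs)
        history.append(sigma)
    return history


aggressive = (3.4445, -4.7750, 2.0315)  # Keller coefficients
stable = (1.875, -1.25, 0.375)          # Taylor coefficients

for name, coeffs in [("aggressive", aggressive), ("stable", stable)]:
    trajectory = iterate(0.05, coeffs, 8)  # start from a small singular value
    print(name, [round(s, 3) for s in trajectory])
    print("  p(1) =", round(ns_scalar(1.0, *coeffs), 4))
```

The printed `p(1)` values show directly whether σ = 1 is a fixed point for each set, and the trajectories show how quickly a small singular value is inflated.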

---
## Part (b): Implement Newton-Schulz Iteration

Now that you understand the theory, let's implement the Newton-Schulz iteration function.

The Newton-Schulz iteration computes: $$X_{k+1} = X_k(\alpha I + \beta X_k^2 + \gamma X_k^4)$$

Your task is to implement this iteration step and visualize how different coefficient sets affect the convergence behavior.

In [None]:
def ns_iteration(X, alpha, beta, gamma):
    """
    Perform one Newton-Schulz iteration step.
    
    Formula: X_{k+1} = X_k * (alpha * I + beta * X_k^2 + gamma * X_k^4)
    
    Args:
        X: Current matrix (torch.Tensor)
        alpha: Coefficient for identity term
        beta: Coefficient for X^2 term  
        gamma: Coefficient for X^4 term
    
    Returns:
        X_{k+1}: Updated matrix after one NS iteration
    """
    # TODO: Implement Newton-Schulz iteration
    # 1. Compute X^2 = X @ X
    # 2. Compute X^4 = X^2 @ X^2
    # 3. Create identity matrix I with same shape as X
    # 4. Compute polynomial: alpha * I + beta * X^2 + gamma * X^4
    # 5. Return X @ polynomial
    
    # YOUR CODE HERE
    raise NotImplementedError()


def track_singular_values(A, coeffs, num_iters):
    """
    Track singular values through multiple NS iterations.
    
    Args:
        A: Initial matrix (torch.Tensor)
        coeffs: Tuple of (alpha, beta, gamma) coefficients
        num_iters: Number of iterations to perform
    
    Returns:
        List of numpy arrays containing singular values at each iteration
    """
    # TODO: Implement singular value tracking
    # 1. Clone the input matrix A to X
    # 2. Create empty list sv_history to store singular values
    # 3. Loop num_iters times:
    #    a. Compute SVD of X using torch.linalg.svd
    #    b. Extract singular values S and convert to numpy, append to sv_history
    #    c. Update X using ns_iteration with the given coefficients
    # 4. Return sv_history
    
    # YOUR CODE HERE
    raise NotImplementedError()

In [None]:
# Test your implementation with both coefficient sets
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# Define coefficient sets
aggressive_coeffs = (3.4445, -4.7750, 2.0315)  # Keller coefficients
stable_coeffs = (1.875, -1.25, 0.375)  # Taylor series coefficients

# Create a random test matrix (using same seed for reproducibility)
torch.manual_seed(SEED)
n = 50
A = torch.randn(n, n)
A = A @ A.T + 0.1 * torch.eye(n)

# Normalize to have singular values in [0, 1]
U, S, Vt = torch.linalg.svd(A)
S_normalized = S / S.max()
A_normalized = U @ torch.diag(S_normalized) @ Vt

num_iterations = 10

# Track singular values for both coefficient sets
aggressive_sv = track_singular_values(A_normalized, aggressive_coeffs, num_iterations)
stable_sv = track_singular_values(A_normalized, stable_coeffs, num_iterations)

print("Newton-Schulz iterations completed!")

In [None]:
# Create animated visualization showing NS iteration convergence
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

def animate(frame):
    ax1.clear()
    ax2.clear()
    
    # Aggressive coefficients
    ax1.bar(range(len(aggressive_sv[frame])), aggressive_sv[frame], color='red', alpha=0.7)
    ax1.set_ylim([0, max(aggressive_sv[-1].max(), stable_sv[-1].max()) * 1.1])
    ax1.set_xlabel('Singular Value Index')
    ax1.set_ylabel('Magnitude')
    ax1.set_title(f'Aggressive Coefficients (Iteration {frame})')
    ax1.axhline(y=1.0, color='black', linestyle='--', linewidth=2, alpha=0.5, label='Target')
    ax1.grid(True, alpha=0.3)
    ax1.legend()
    ax1.text(0.05, 0.95, 'Rapidly inflates small values\nMay not converge', 
             transform=ax1.transAxes, verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    # Stable coefficients
    ax2.bar(range(len(stable_sv[frame])), stable_sv[frame], color='blue', alpha=0.7)
    ax2.set_ylim([0, max(aggressive_sv[-1].max(), stable_sv[-1].max()) * 1.1])
    ax2.set_xlabel('Singular Value Index')
    ax2.set_ylabel('Magnitude')
    ax2.set_title(f'Stable Coefficients (Iteration {frame})')
    ax2.axhline(y=1.0, color='black', linestyle='--', linewidth=2, alpha=0.5, label='Target')
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    ax2.text(0.05, 0.95, 'Gentle orthogonalization\nConverges to equilibrium', 
             transform=ax2.transAxes, verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5))
    
    plt.tight_layout()

anim = FuncAnimation(fig, animate, frames=num_iterations, interval=500, repeat=True)
plt.close()

HTML(anim.to_jshtml())

### Newton-Schulz Analysis Questions

Based on the visualization above, answer the following questions in the cell below:

1. **Convergence Behavior**: Describe the difference in how singular values evolve for the aggressive (Keller) vs. stable (Taylor) coefficients. Which set brings singular values closer to 1.0 more quickly? Which set appears to converge to exactly 1.0?

2. **Mathematical Connection**: The aggressive coefficients sum to α + β + γ ≈ 0.701 (not 1), while the stable coefficients sum to exactly 1.0. How does this mathematical property relate to what you observe in the convergence behavior? Why must the coefficients sum to 1 for singular values to converge to exactly 1.0?

3. **Overshoot vs. Stability**: In the aggressive coefficient visualization, do you observe any singular values that exceed 1.0 (overshoot the target)? What are the potential risks of this behavior when orthogonalizing gradient matrices in an optimizer?

**YOUR ANSWERS HERE:**

1. 

2. 

3. 

---
## Important Note: Muon's 2D Parameter Requirement

Before moving on to Part (c), it's important to understand a key limitation of the Muon optimizer:

**Muon requires EXACTLY 2D parameters**. According to the [PyTorch 2.9 documentation](https://docs.pytorch.org/docs/stable/generated/torch.optim.Muon.html):

> "Muon is an optimizer for 2D parameters of neural network hidden layers."

The implementation checks: `if p.ndim != 2: raise ValueError(...)`

**Parameter dimensionality in typical networks:**
- **4D parameters**: Conv2d weights `[out_channels, in_channels, height, width]`
- **2D parameters**: Linear layer weights `[out_features, in_features]`  
- **1D parameters**: Biases, BatchNorm weights

**Why this matters for CNNs:**
Since most CNN parameters are 4D (convolutional layers), Muon cannot optimize them directly. This significantly limits Muon's applicability to pure CNN architectures like ResNet. While Muon excels at training Transformers and MLPs (which are dominated by 2D matrix multiplications), it's not well-suited for CNNs.

**In the next section**, we'll implement the **Lion optimizer**, which works with parameters of any dimensionality and is much better suited for training CNNs like ResNet on CIFAR-10.
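
To make the dimensionality breakdown concrete, the sketch below groups the parameters of `SimpleCNN` by number of dimensions. It is framework-free: the shapes are written out by hand from the layer definitions in Part (a), and `count_by_ndim` is a hypothetical helper name.

```python
from collections import defaultdict
from math import prod

# Parameter shapes of SimpleCNN, transcribed from its layer definitions
simple_cnn_shapes = {
    "conv1.weight": (32, 3, 3, 3),
    "conv1.bias": (32,),
    "conv2.weight": (64, 32, 3, 3),
    "conv2.bias": (64,),
    "conv3.weight": (128, 64, 3, 3),
    "conv3.bias": (128,),
    "fc1.weight": (512, 128 * 4 * 4),
    "fc1.bias": (512,),
    "fc2.weight": (10, 512),
    "fc2.bias": (10,),
}


def count_by_ndim(shapes):
    """Sum the number of scalar parameters for each tensor rank."""
    counts = defaultdict(int)
    for shape in shapes.values():
        counts[len(shape)] += prod(shape)
    return dict(counts)


counts = count_by_ndim(simple_cnn_shapes)
for ndim in sorted(counts):
    print(f"{ndim}D parameters: {counts[ndim]:,}")
```

On a real model the same grouping can be done with `p.ndim` and `p.numel()` over `model.named_parameters()`. Note that in this small network the two linear layers hold most of the weights. In ResNet18, whose only linear layer is the final `fc`, the 4D convolution kernels dominate.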

---
## Part (c): Implement Lion Optimizer and Compare with AdamW

In this section, you will:
1. Implement the Lion optimizer from scratch based on the pseudocode
2. Prepare ResNet18 for CIFAR-10 transfer learning (with ALL parameters trainable)
3. Tune Lion's hyperparameters with a grid search, then train ResNet18 with the best configuration
4. Train ResNet18 with AdamW (SOTA baseline) for comparison
5. Compare the performance of Lion vs AdamW

### Lion Optimizer Background

Lion (evoLved sIgn mOmeNtum) is a simple yet effective optimizer discovered through program search. Refer to the original paper: [Symbolic Discovery of Optimization Algorithms](https://arxiv.org/abs/2302.06675)

The key update rules are:
- Use the sign of the interpolation between momentum and gradient
- Update momentum as EMA of gradients
- Apply weight decay

Unlike Muon, **Lion works with parameters of any dimensionality**, making it suitable for CNNs.
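
The three update rules can be sketched on a single scalar parameter before writing the tensor version. This is a minimal illustration in plain Python, not the `torch.optim.Optimizer` subclass the assignment asks for; `lion_step` and `sign` are hypothetical helper names, and the quadratic objective is chosen only for the demo.

```python
def sign(x):
    """Return -1, 0, or 1 depending on the sign of x."""
    return (x > 0) - (x < 0)


def lion_step(theta, m, grad, lr=1e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
    """One Lion update for a scalar parameter theta with momentum m."""
    c = beta1 * m + (1 - beta1) * grad                       # interpolation
    theta_new = theta - lr * (sign(c) + weight_decay * theta)  # signed update + decay
    m_new = beta2 * m + (1 - beta2) * grad                   # momentum EMA
    return theta_new, m_new


# Minimize f(theta) = theta^2, whose gradient is 2 * theta
theta, m = 1.0, 0.0
for step in range(5):
    theta, m = lion_step(theta, m, grad=2 * theta, lr=0.1)
    print(step, round(theta, 3))
```

Without weight decay, every step moves the parameter by exactly `lr`, regardless of how large the gradient is. That property is why Lion is usually run with a smaller learning rate than AdamW.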

### Step 1: Implement Lion Optimizer

Implement the Lion optimizer class following the algorithm from the paper.

In [None]:
class Lion(torch.optim.Optimizer):
    """
    Implements Lion optimizer.
    
    Based on the paper: https://arxiv.org/abs/2302.06675
    
    Args:
        params: iterable of parameters to optimize
        lr: learning rate (default: 1e-4)
        beta1: coefficient for interpolation with momentum (default: 0.9)
        beta2: coefficient for momentum EMA (default: 0.99)
        weight_decay: weight decay coefficient (default: 0)
    
    Hint: Refer to PyTorch's Optimizer base class documentation and SGD source code
    as examples for implementing custom optimizers.
    """
    
    def __init__(self, params, lr=1e-4, beta1=0.9, beta2=0.99, weight_decay=0.0):
        # TODO: Validate hyperparameters and call parent constructor
        # YOUR CODE HERE
        raise NotImplementedError()
    
    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.
        
        Algorithm (from paper):
        1. c_t = β₁ * m_{t-1} + (1 - β₁) * g_t  (interpolation)
        2. θ_t = θ_{t-1} - η * (sign(c_t) + λ * θ_{t-1})  (param update)
        3. m_t = β₂ * m_{t-1} + (1 - β₂) * g_t  (momentum update)
        
        Hint: Use self.state[p] to store per-parameter state (like momentum).
        Refer to torch.sign() for the sign function.
        """
        loss = None
        if closure is not None:
            loss = closure()
        
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                
                # TODO: Implement the three steps above
                # YOUR CODE HERE
                raise NotImplementedError()
        
        return loss

In [None]:
# Quick test of Lion implementation
# This will fail until you implement the Lion class above
test_model = SimpleCNN().to(device)
test_optimizer = Lion(test_model.parameters(), lr=1e-3)
print("Lion optimizer created successfully!")
print(f"Optimizer state: {test_optimizer.state_dict()}")

### Step 2: Prepare ResNet18 for CIFAR-10 Transfer Learning

Now let's prepare a pretrained ResNet18 model for CIFAR-10. Since Lion works with parameters of any dimensionality, we can train ALL parameters (unlike Muon which only works with 2D parameters).

In [None]:
from torchvision.models import resnet18, ResNet18_Weights

def prepare_resnet_for_cifar10(pretrained=True):
    """
    Prepare ResNet18 for CIFAR-10 transfer learning.
    
    TODO: Complete this function
    1. Load ResNet18 with pretrained ImageNet weights if pretrained=True
       Hint: Use resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
    2. Replace the final fully connected layer to output 10 classes instead of 1000
       Hint: The final layer is model.fc, and it's a Linear layer
       You'll need to check model.fc.in_features to get the input size
    3. All parameters should remain trainable (this is the default, so no action needed)
    
    Args:
        pretrained (bool): Whether to load ImageNet pretrained weights
    
    Returns:
        model: Modified ResNet18 ready for CIFAR-10
    """
    # YOUR CODE HERE
    raise NotImplementedError()
    
    return model

In [None]:
# Test the function
test_resnet = prepare_resnet_for_cifar10(pretrained=True).to(device)
print(f"Final layer: {test_resnet.fc}")
print(f"Total parameters: {sum(p.numel() for p in test_resnet.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in test_resnet.parameters() if p.requires_grad):,}")

### Step 3: Hyperparameter Tuning for Lion

Before comparing Lion with AdamW, we need to find the best hyperparameters for Lion. We'll perform a grid search over learning rate, batch size, and weight decay.

**Search Space:**
- Learning rates: [1e-4, 3e-4, 1e-3]
- Batch sizes: [64, 128, 256]
- Weight decays: [0.0, 0.01, 0.1]
- Total configurations: 27 (3 × 3 × 3)
- Training epochs per config: 15 epochs
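
One compact way to enumerate the search space is `itertools.product`, which avoids three levels of nested loops. The sketch below only builds the list of configurations; the training call for each configuration is left out, and the dictionary keys are chosen to match the result keys requested in the TODO below.

```python
from itertools import product

learning_rates = [1e-4, 3e-4, 1e-3]
batch_sizes = [64, 128, 256]
weight_decays = [0.0, 0.01, 0.1]

# One dictionary per (lr, batch_size, weight_decay) combination
configs = [
    {"lr": lr, "batch_size": bs, "weight_decay": wd}
    for lr, bs, wd in product(learning_rates, batch_sizes, weight_decays)
]

print(len(configs))  # 27
print(configs[0])
```

The last factor (`weight_decays`) varies fastest, so consecutive configurations share a batch size and DataLoaders can be reused across them if desired.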

In [None]:
# Hyperparameter search space
learning_rates = [1e-4, 3e-4, 1e-3]
batch_sizes = [64, 128, 256]
weight_decays = [0.0, 0.01, 0.1]

# TODO: Implement hyperparameter grid search for Lion
# 1. Create an empty list to store results: results_grid = []
# 2. Calculate total_configs = len(learning_rates) * len(batch_sizes) * len(weight_decays)
# 3. Create nested loops over learning_rates, batch_sizes, and weight_decays
# 4. For each configuration:
#    a. Print configuration number and hyperparameters
#    b. Create DataLoaders with the current batch_size
#    c. Initialize a fresh ResNet18 model using prepare_resnet_for_cifar10()
#    d. Create Lion optimizer with current lr and weight_decay (use default betas)
#    e. Train for 15 epochs using train_validation_loop
#    f. Store results in a dictionary with keys: 'lr', 'batch_size', 'weight_decay', 'best_val_acc'
#    g. Append to results_grid
# 5. After all configs, print "Hyperparameter search complete!"

criterion = nn.CrossEntropyLoss()

# YOUR CODE HERE
raise NotImplementedError()

# Note: This will take significant time (~1-2 hours on GPU)
# For testing, you can reduce the search space or num_epochs

### Analyze Hyperparameter Search Results

In [None]:
import pandas as pd

# TODO: Analyze hyperparameter search results
# 1. Create a DataFrame named df_results from results_grid
# 2. Sort by 'best_val_acc' in descending order
# 3. Display the top 5 configurations
# 4. Extract the best configuration (first row after sorting) into best_config
# 5. Print the best hyperparameters

# YOUR CODE HERE
raise NotImplementedError()

print(f"\nBest Configuration:")
print(f"Learning Rate: {best_config['lr']}")
print(f"Batch Size: {best_config['batch_size']}")
print(f"Weight Decay: {best_config['weight_decay']}")
print(f"Best Val Acc: {best_config['best_val_acc']:.2f}%")

In [None]:
# Visualize hyperparameter effects
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Learning rate effect
lr_grouped = df_results.groupby('lr')['best_val_acc'].mean()
axes[0].bar(range(len(lr_grouped)), lr_grouped.values)
axes[0].set_xticks(range(len(lr_grouped)))
axes[0].set_xticklabels([f'{lr:.0e}' for lr in lr_grouped.index])
axes[0].set_xlabel('Learning Rate')
axes[0].set_ylabel('Avg Val Accuracy (%)')
axes[0].set_title('Effect of Learning Rate')
axes[0].grid(True, alpha=0.3)

# Batch size effect
bs_grouped = df_results.groupby('batch_size')['best_val_acc'].mean()
axes[1].bar(range(len(bs_grouped)), bs_grouped.values)
axes[1].set_xticks(range(len(bs_grouped)))
axes[1].set_xticklabels(bs_grouped.index)
axes[1].set_xlabel('Batch Size')
axes[1].set_ylabel('Avg Val Accuracy (%)')
axes[1].set_title('Effect of Batch Size')
axes[1].grid(True, alpha=0.3)

# Weight decay effect
wd_grouped = df_results.groupby('weight_decay')['best_val_acc'].mean()
axes[2].bar(range(len(wd_grouped)), wd_grouped.values)
axes[2].set_xticks(range(len(wd_grouped)))
axes[2].set_xticklabels(wd_grouped.index)
axes[2].set_xlabel('Weight Decay')
axes[2].set_ylabel('Avg Val Accuracy (%)')
axes[2].set_title('Effect of Weight Decay')
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### Step 4: Train with Best Lion Configuration

Now let's train ResNet18 with the best hyperparameters we found for a full 20 epochs.

In [None]:
# Initialize model with best hyperparameters
model_lion = prepare_resnet_for_cifar10(pretrained=True).to(device)

# Create optimizer with best hyperparameters
optimizer_lion = Lion(
    model_lion.parameters(),
    lr=best_config['lr'],
    beta1=0.9,
    beta2=0.99,
    weight_decay=best_config['weight_decay']
)

# Create data loaders with best batch size
train_loader_best = DataLoader(trainset, batch_size=int(best_config['batch_size']), 
                               shuffle=True, num_workers=2)
val_loader_best = DataLoader(valset, batch_size=int(best_config['batch_size']), 
                             shuffle=False, num_workers=2)

# Train for 20 epochs
print("Training ResNet18 with best Lion hyperparameters...")
results_lion = train_validation_loop(
    model=model_lion,
    train_loader=train_loader_best,
    val_loader=val_loader_best,
    optimizer=optimizer_lion,
    criterion=criterion,
    num_epochs=20,
    device=device
)

print(f"\nLion (Best Config) - Best Validation Accuracy: {results_lion['best_val_acc']:.2f}%")

#### Evaluate Lion Model on Test Set

Now let's evaluate the trained Lion model on the held-out test set to see how well it generalizes.
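
A common pitfall when computing the average loss is to average the per-batch mean losses directly. `nn.CrossEntropyLoss` returns the mean over a batch by default, so a smaller final batch would be over-weighted. The framework-free sketch below shows the sample-weighted aggregation; `aggregate` is a hypothetical helper and the batch statistics are made-up numbers.

```python
def aggregate(batches):
    """Combine per-batch results into dataset-level loss and accuracy.

    batches: list of (mean_loss, num_correct, batch_size) tuples.
    Returns (average loss per sample, accuracy in percent).
    """
    total_loss = sum(loss * n for loss, _, n in batches)
    total_correct = sum(correct for _, correct, _ in batches)
    total = sum(n for _, _, n in batches)
    return total_loss / total, 100.0 * total_correct / total


# Hypothetical results for three batches; the last batch is smaller
batches = [(0.50, 100, 128), (0.70, 90, 128), (1.00, 8, 16)]

avg_loss, acc = aggregate(batches)
naive_loss = sum(loss for loss, _, _ in batches) / len(batches)

print(f"weighted loss: {avg_loss:.4f}, accuracy: {acc:.2f}%")
print(f"naive mean of batch losses: {naive_loss:.4f}")
```

The two loss values differ because the naive mean gives the 16-sample batch the same weight as the 128-sample batches.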

In [None]:
# Evaluate on test set
def evaluate_on_test(model, test_loader, criterion, device):
    """
    Evaluate model on test set.
    
    TODO: Implement this function to evaluate a trained model on the test set.
    
    Args:
        model: Trained model
        test_loader: DataLoader for test data
        criterion: Loss function
        device: Device to run on
    
    Returns:
        test_loss: Average test loss
        test_acc: Test accuracy (%)
    
    Hint: This should be similar to the validation phase in train_validation_loop:
    1. Set model to evaluation mode
    2. Do the same step regarding gradients that was done in validation phase
    3. Iterate through test_loader with tqdm for progress tracking
    4. For each batch:
       - Move inputs and labels to device
       - Forward pass
       - Compute loss
       - Track predictions and correct counts
    5. Compute average loss and accuracy
    6. Return test_loss and test_acc (as percentage)
    """
    # YOUR CODE HERE
    raise NotImplementedError()

# Create test loader
test_loader = DataLoader(testset, batch_size=int(best_config['batch_size']), 
                         shuffle=False, num_workers=2)

# Load best Lion model and evaluate
model_lion.load_state_dict(results_lion['best_model_state'])
lion_test_loss, lion_test_acc = evaluate_on_test(model_lion, test_loader, criterion, device)

print(f"\nLion Model - Test Set Performance:")
print(f"Test Loss: {lion_test_loss:.4f}")
print(f"Test Accuracy: {lion_test_acc:.2f}%")

### Step 5: Train ResNet18 with AdamW (Baseline)

Now let's train the same model with AdamW, one of the most widely used optimizers in deep learning. We use fixed, commonly used hyperparameters rather than a tuned configuration, so keep in mind when comparing results that Lion's hyperparameters were selected by grid search.
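
One intuition for why the two optimizers use different learning rates is the size of a single update. Lion's sign update always has magnitude `lr` per coordinate, while Adam-style updates shrink when recent gradients disagree in sign. The sketch below compares the two on one scalar coordinate. It is a simplified illustration with hypothetical helper names and made-up gradient sequences, and it omits weight decay.

```python
import math


def adam_update_size(grads, lr=1.0, beta1=0.9, beta2=0.999, eps=1e-8):
    """Magnitude of the Adam update after seeing the gradient sequence."""
    m = v = 0.0
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
    return abs(lr * m_hat / (math.sqrt(v_hat) + eps))


def lion_update_size(grads, lr=1.0, beta1=0.9, beta2=0.99):
    """Magnitude of the Lion update after seeing the gradient sequence."""
    m = 0.0
    for g in grads:
        c = beta1 * m + (1 - beta1) * g
        m = beta2 * m + (1 - beta2) * g
    return abs(lr * ((c > 0) - (c < 0)))


consistent = [1.0] * 10                                 # gradients that agree
noisy = [1.0, -0.8, 1.2, -1.0, 0.9, -1.1, 1.0, -0.7]    # gradients that flip sign

for name, grads in [("consistent", consistent), ("noisy", noisy)]:
    print(name, "adam:", round(adam_update_size(grads), 4),
          "lion:", lion_update_size(grads))
```

With consistent gradients both updates are close to `lr`. With noisy gradients the Adam update is much smaller, while Lion still takes a full-size step.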

In [None]:
# Initialize model for AdamW training
model_adamw = prepare_resnet_for_cifar10(pretrained=True).to(device)

# AdamW hyperparameters (commonly used values for fine-tuning a pretrained ResNet)
optimizer_adamw = torch.optim.AdamW(
    model_adamw.parameters(),
    lr=3e-4,  # AdamW typically uses higher learning rates than Lion
    betas=(0.9, 0.999),
    weight_decay=0.01
)

# Use same batch size as Lion's best config for fair comparison
# Note: train_loader_best and val_loader_best were created in the previous cell

# Train for 20 epochs (same as Lion for fair comparison)
print("Training ResNet18 with AdamW optimizer...")
results_adamw = train_validation_loop(
    model=model_adamw,
    train_loader=train_loader_best,
    val_loader=val_loader_best,
    optimizer=optimizer_adamw,
    criterion=criterion,
    num_epochs=20,
    device=device
)

print(f"\nAdamW - Best Validation Accuracy: {results_adamw['best_val_acc']:.2f}%")

#### Evaluate AdamW Model on Test Set

Let's also evaluate the trained AdamW model on the test set for comparison.

In [None]:
# Load best AdamW model and evaluate on test set
model_adamw.load_state_dict(results_adamw['best_model_state'])
adamw_test_loss, adamw_test_acc = evaluate_on_test(model_adamw, test_loader, criterion, device)

print(f"\nAdamW Model - Test Set Performance:")
print(f"Test Loss: {adamw_test_loss:.4f}")
print(f"Test Accuracy: {adamw_test_acc:.2f}%")

print(f"\n{'='*50}")
print(f"Test Set Comparison:")
print(f"{'='*50}")
print(f"Lion  - Test Acc: {lion_test_acc:.2f}%")
print(f"AdamW - Test Acc: {adamw_test_acc:.2f}%")
# Signed difference in percentage points: positive means Lion is ahead
print(f"Difference (Lion - AdamW): {lion_test_acc - adamw_test_acc:+.2f} percentage points")

### Comparison: Lion vs AdamW

In [None]:
# Compare Lion vs AdamW
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 10))

# Training Loss
ax1.plot(results_lion['train_losses'], label='Lion', color='purple', linewidth=2)
ax1.plot(results_adamw['train_losses'], label='AdamW', color='green', linewidth=2)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss Comparison')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Validation Loss
ax2.plot(results_lion['val_losses'], label='Lion', color='purple', linewidth=2)
ax2.plot(results_adamw['val_losses'], label='AdamW', color='green', linewidth=2)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Loss')
ax2.set_title('Validation Loss Comparison')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Training Accuracy
ax3.plot(results_lion['train_accs'], label='Lion', color='purple', linewidth=2)
ax3.plot(results_adamw['train_accs'], label='AdamW', color='green', linewidth=2)
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Accuracy (%)')
ax3.set_title('Training Accuracy Comparison')
ax3.legend()
ax3.grid(True, alpha=0.3)

# Validation Accuracy
ax4.plot(results_lion['val_accs'], label='Lion', color='purple', linewidth=2)
ax4.plot(results_adamw['val_accs'], label='AdamW', color='green', linewidth=2)
ax4.set_xlabel('Epoch')
ax4.set_ylabel('Accuracy (%)')
ax4.set_title('Validation Accuracy Comparison')
ax4.legend()
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nFinal Results:")
print(f"Lion    - Best Val Acc: {results_lion['best_val_acc']:.2f}%")
print(f"AdamW   - Best Val Acc: {results_adamw['best_val_acc']:.2f}%")
# Signed difference in percentage points: positive means Lion is ahead
print(f"\nDifference (Lion - AdamW): {results_lion['best_val_acc'] - results_adamw['best_val_acc']:+.2f} percentage points")
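
To put a number on which optimizer converged faster, one option is to count how many epochs each run needed to first reach a given validation accuracy. The helper below is a sketch: `epochs_to_reach` is a hypothetical name, and the example curves are made up for illustration and are not results from this assignment.

```python
from typing import Optional, Sequence


def epochs_to_reach(accs: Sequence[float], threshold: float) -> Optional[int]:
    """Return the 1-based epoch at which accuracy first reaches threshold, else None."""
    for epoch, acc in enumerate(accs, start=1):
        if acc >= threshold:
            return epoch
    return None


# Made-up accuracy curves, purely for illustration (not real results).
demo_a = [70.0, 82.0, 88.0, 90.5, 91.0]
demo_b = [65.0, 78.0, 85.0, 89.0, 90.2]

print(epochs_to_reach(demo_a, 90.0))  # 4
print(epochs_to_reach(demo_b, 90.0))  # 5
print(epochs_to_reach(demo_b, 95.0))  # None
```

With the real runs, pass `results_lion['val_accs']` and `results_adamw['val_accs']` together with a threshold that both runs reach.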

### Analysis Questions (Answer in the cell below):

1. **Performance Comparison**: Which optimizer achieved better validation accuracy on CIFAR-10: Lion or AdamW? By how much?

2. **Training Dynamics**: Describe the differences in training curves between Lion and AdamW. Which converged faster? Were there any notable differences in stability?

3. **Hyperparameter Sensitivity**: What was the best learning rate found for Lion? How does it compare to AdamW's learning rate (3e-4)? Based on the algorithm differences (sign-based updates vs adaptive learning rates), why might Lion have different learning rate requirements?

4. **Hyperparameter Effects**: Based on your grid search results, which hyperparameter (learning rate, batch size, or weight decay) had the most significant impact on Lion's performance? Why do you think this is? (Hint: Lion's sign-based update acts like a 1-bit quantization of the update direction, so it tends to benefit from larger batch sizes, which reduce stochastic gradient noise.)

5. **Memory and Computation**: Consider the implementation details of both optimizers. Which one is more memory-efficient? Why? (Hint: Think about what state each optimizer needs to maintain.)

6. **When to Use Each**: Based on your results and understanding of both optimizers, in what scenarios would you choose Lion over AdamW, and vice versa?
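
The hint in question 4 can be explored with a small simulation: if each per-example gradient is the true gradient plus noise, the sign of the mini-batch average matches the sign of the true gradient more often as the batch grows. The sketch below uses only the standard library. The true gradient value, the noise level, and the name `sign_agreement` are assumptions chosen for illustration, and it looks at a single coordinate with no momentum, so it simplifies Lion's actual update.

```python
import random


def sign_agreement(batch_size: int, true_grad: float = 0.1,
                   noise_std: float = 1.0, trials: int = 2000,
                   seed: int = 42) -> float:
    """Fraction of trials where sign(mini-batch mean) equals sign(true_grad)."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        # Mini-batch gradient estimate: average of noisy per-example gradients.
        batch_mean = sum(rng.gauss(true_grad, noise_std)
                         for _ in range(batch_size)) / batch_size
        if (batch_mean > 0) == (true_grad > 0):
            hits += 1
    return hits / trials


for bs in (4, 64, 256):
    print(f"batch_size={bs:>3}: sign agreement = {sign_agreement(bs):.3f}")
```

Comparing the printed agreement rates across batch sizes gives a rough feel for how much information survives the sign operation at each noise level.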

**YOUR ANSWERS HERE:**

1. 

2. 

3. 

4. 

5. 

6. 

---
## Conclusion

In this assignment, you:
1. Built a reusable training-validation loop for deep learning models **(Part a)**
2. Understood Newton-Schulz coefficients and their role in optimization **(Part b)**
3. Implemented the `ns_iteration` and `track_singular_values` functions to visualize how different NS coefficients affect convergence behavior **(Part b)**
4. Learned about Muon's 2D parameter limitation and why it's not suitable for pure CNN architectures
5. Implemented the Lion optimizer from scratch **(Part c)**
6. Prepared ResNet18 for CIFAR-10 transfer learning with all parameters trainable **(Part c)**
7. Performed systematic hyperparameter tuning for Lion optimizer using grid search **(Part c)**
8. Compared the best Lion configuration vs AdamW performance on ResNet18/CIFAR-10 **(Part c)**

**Key Takeaways:**
- Different optimizers have different strengths: Muon for Transformers/MLPs, Lion and AdamW for general use
- Understanding the mathematical foundations (like Newton-Schulz iterations) helps explain optimizer behavior
- Modern optimizers often make trade-offs between convergence speed, stability, and computational efficiency
- Hyperparameter tuning is crucial for achieving optimal performance with any optimizer
- Empirical comparison is essential for choosing the right optimizer for your specific task

**Submission Instructions:**
1. Ensure all code cells run without errors
2. Answer all analysis questions
3. Export this notebook as PDF
4. Submit both .ipynb and .pdf files