# Deep Learning Applications: Laboratory #1

In this first laboratory we will work relatively simple architectures to get a feel for working with Deep Models. This notebook is designed to work with PyTorch, but as I said in the introductory lecture: please feel free to use and experiment with whatever tools you like.

**Important Notes**:
1. Be sure to **document** all of your decisions, as well as your intermediate and final results. Make sure your conclusions and analyses are clearly presented. Don't make us dig into your code or walls of printed results to try to draw conclusions from your code.
2. If you use code from someone else (e.g. Github, Stack Overflow, ChatGPT, etc) you **must be transparent about it**. Document your sources and explain how you adapted any partial solutions to creat **your** solution.



## Exercise 1: Warming Up
In this series of exercises I want you to try to duplicate (on a small scale) the results of the ResNet paper:

> [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385), Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun, CVPR 2016.

We will do this in steps using a Multilayer Perceptron on MNIST.

Recall that the main message of the ResNet paper is that **deeper** networks do not **guarantee** more reduction in training loss (or in validation accuracy). Below you will incrementally build a sequence of experiments to verify this for an MLP. A few guidelines:

+ I have provided some **starter** code at the beginning. **NONE** of this code should survive in your solutions. Not only is it **very** badly written, it is also written in my functional style that also obfuscates what it's doing (in part to **discourage** your reuse!). It's just to get you *started*.
+ These exercises ask you to compare **multiple** training runs, so it is **really** important that you factor this into your **pipeline**. Using [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) is a **very** good idea -- or, even better [Weights and Biases](https://wandb.ai/site).
+ You may work and submit your solutions in **groups of at most two**. Share your ideas with everyone, but the solutions you submit *must be your own*.

First some boilerplate to get you started, then on to the actual exercises!

### Preface: Some code to get you started

What follows is some **very simple** code for training an MLP on MNIST. The point of this code is to get you up and running (and to verify that your Python environment has all needed dependencies).

**Note**: As you read through my code and execute it, this would be a good time to think about *abstracting* **your** model definition, and training and evaluation pipelines in order to make it easier to compare performance of different models.

In [1]:
import sys
print(sys.executable)

D:\Programmi\anaconda3\envs\DLA\python.exe


In [2]:
# Start with some standard imports.
import numpy as np
import matplotlib.pyplot as plt
from functools import reduce
import torch
from torchvision.datasets import MNIST
from torch.utils.data import Subset
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

#### Data preparation

Here is some basic dataset loading, validation splitting code to get you started working with MNIST.

In [3]:
# Standard MNIST transform.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load MNIST train and test.
ds_train = MNIST(root='./data', train=True, download=True, transform=transform)
ds_test = MNIST(root='./data', train=False, download=True, transform=transform)

# Split train into train and validation.
val_size = 5000
I = np.random.permutation(len(ds_train))
ds_val = Subset(ds_train, I[:val_size]) #estraggo elementi per la valutazione, crea un subset dal dataset
ds_train = Subset(ds_train, I[val_size:])

#### Boilerplate training and evaluation code

This is some **very** rough training, evaluation, and plotting code. Again, just to get you started. I will be *very* disappointed if any of this code makes it into your final submission.

In [4]:
from tqdm import tqdm
from sklearn.metrics import accuracy_score, classification_report

# Function to train a model for a single epoch over the data loader.
def train_epoch(model, dl, opt, epoch='Unknown', device='cpu'):
    model.train()
    losses = []
    for (xs, ys) in tqdm(dl, desc=f'Training epoch {epoch}', leave=True):
        xs = xs.to(device)
        ys = ys.to(device)
        opt.zero_grad() # importante azzerare gradienti prima di fare gradient descent
        logits = model(xs)
        loss = F.cross_entropy(logits, ys)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return np.mean(losses)

# Function to evaluate model over all samples in the data loader.
def evaluate_model(model, dl, device='cpu'):
    model.eval()
    predictions = []
    gts = []
    for (xs, ys) in tqdm(dl, desc='Evaluating', leave=False):
        xs = xs.to(device)
        preds = torch.argmax(model(xs), dim=1)
        gts.append(ys)
        predictions.append(preds.detach().cpu().numpy()) #detach(), il modello crea un grafo di calcolo, necessario per calcolare 
        #gradiente in modo automatico, se noi faccimao altri calcoli succede che torch.argmax viene aggiunto al grafo e se provo 
        #a portare questo caloclo sulla cpu o si porta tutto e on si puo fare oppure si fa detach e si stacca solo il tensore e lo si porta sulla cpu
        
    # Return accuracy score and classification report.
    return (accuracy_score(np.hstack(gts), np.hstack(predictions)),
            classification_report(np.hstack(gts), np.hstack(predictions), zero_division=0, digits=3))

# Simple function to plot the loss curve and validation accuracy.
def plot_validation_curves(losses_and_accs):
    losses = [x for (x, _) in losses_and_accs]
    accs = [x for (_, x) in losses_and_accs]
    plt.figure(figsize=(16, 8))
    plt.subplot(1, 2, 1)
    plt.plot(losses)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Average Training Loss per Epoch')
    plt.subplot(1, 2, 2)
    plt.plot(accs)
    plt.xlabel('Epoch')
    plt.ylabel('Validation Accuracy')
    plt.title(f'Best Accuracy = {np.max(accs)} @ epoch {np.argmax(accs)}')

#### A basic, parameterized MLP

This is a very basic implementation of a Multilayer Perceptron. Don't waste too much time trying to figure out how it works -- the important detail is that it allows you to pass in a list of input, hidden layer, and output *widths*. **Your** implementation should also support this for the exercises to come.

In [5]:
class MLP(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(nin, nout) for (nin, nout) in zip(layer_sizes[:-1], layer_sizes[1:])])
    
    def forward(self, x):
        return reduce(lambda f, g: lambda x: g(F.relu(f(x))), self.layers, lambda x: x.flatten(1))(x)

#### A *very* minimal training pipeline.

Here is some basic training and evaluation code to get you started.

**Important**: I cannot stress enough that this is a **terrible** example of how to implement a training pipeline. You can do better!

In [None]:
# Training hyperparameters.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 100
lr = 0.0001
batch_size = 128

# Architecture hyperparameters.
input_size = 28*28
width = 16
depth = 2

# Dataloaders.
dl_train = torch.utils.data.DataLoader(ds_train, batch_size, shuffle=True, num_workers=4)
dl_val   = torch.utils.data.DataLoader(ds_val, batch_size, num_workers=4)
dl_test  = torch.utils.data.DataLoader(ds_test, batch_size, shuffle=True, num_workers=4)

# Instantiate model and optimizer.
model_mlp = MLP([input_size] + [width]*depth + [10]).to(device)
opt = torch.optim.Adam(params=model_mlp.parameters(), lr=lr)

# Training loop.
losses_and_accs = []
for epoch in range(epochs):
    loss = train_epoch(model_mlp, dl_train, opt, epoch, device=device)
    (val_acc, _) = evaluate_model(model_mlp, dl_val, device=device)
    losses_and_accs.append((loss, val_acc))

# And finally plot the curves.
plot_validation_curves(losses_and_accs)
print(f'Accuracy report on TEST:\n {evaluate_model(model_mlp, dl_test, device=device)[1]}')

In [None]:
from torch.utils.tensorboard import SummaryWriter

#A better trainig loop
def train_model(model, epochs, opt, dl_tria, dl_val, logdir, device='cuda', verbose=False):
    writer = SummaryWriter(logdir)
    for epoch in range(epochs)
        loss = train_epoch(model, dl_train. opt, epoch, device=device)
        (val_acc, _) = evaluate_model(model, dl_val, device=device)
        writer.add_scalar('Loss/train', loss, epoch)
        writer.add_scalar('Acc/val', val_acc, epoch)
    writer.close()

### Exercise 1.1: A baseline MLP

Implement a *simple* Multilayer Perceptron to classify the 10 digits of MNIST (e.g. two *narrow* layers). Use my code above as inspiration, but implement your own training pipeline -- you will need it later. Train this model to convergence, monitoring (at least) the loss and accuracy on the training and validation sets for every epoch. Below I include a basic implementation to get you started -- remember that you should write your *own* pipeline!

**Note**: This would be a good time to think about *abstracting* your model definition, and training and evaluation pipelines in order to make it easier to compare performance of different models.

**Important**: Given the *many* runs you will need to do, and the need to *compare* performance between them, this would **also** be a great point to study how **Tensorboard** or **Weights and Biases** can be used for performance monitoring.

In [2]:
# Exercise 1.1: A baseline MLP (Basic Version)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from tqdm import tqdm
from sklearn.metrics import classification_report
import wandb

# Simple MLP model - basic version without regularization
class BasicMLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Training function
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for data, target in tqdm(dataloader, desc="Training"):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)

    return total_loss / len(dataloader), 100.0 * correct / total

# Validation function
def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in tqdm(dataloader, desc="Validating"):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

    return total_loss / len(dataloader), 100.0 * correct / total

# Test function
def test_model(model, dataloader, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for data, target in tqdm(dataloader, desc="Testing"):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()

            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    accuracy = 100.0 * correct / total
    avg_loss = test_loss / len(dataloader)

    return avg_loss, accuracy, all_preds, all_targets

# Simple training pipeline
def train_model(model, train_loader, val_loader, optimizer, criterion, epochs, device):
    print("Starting training...")

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        # Train
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)

        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        # Log to W&B
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "val_loss": val_loss,
            "val_accuracy": val_acc
        })

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

def main():
    # Model architecture parameters
    input_size = 28 * 28 * 1  # MNIST image dimensions (28x28 pixels, 1 channel)
    hidden_size = 128         # Hidden layer size
    num_classes = 10          # Number of classes (digits 0-9)
    
    # Training hyperparameters
    epochs = 30
    batch_size = 64
    learning_rate = 0.001
    val_size = 5000
    
    # Initialize W&B
    wandb.init(
        project="mnist-basic-mlp",
        config={
            "epochs": epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "hidden_size": hidden_size,
            "input_size": input_size,
            "num_classes": num_classes,
            "architecture": "Basic 2-layer MLP"
        }
    )

    config = wandb.config
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Data preparation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load datasets
    train_dataset = MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = MNIST('./data', train=False, transform=transform)

    # Split training set into train and validation
    train_size = len(train_dataset) - val_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    print(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    # Model setup
    model = BasicMLP(input_size, hidden_size, num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # Log model to W&B
    wandb.watch(model)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Training pipeline
    train_model(model, train_loader, val_loader, optimizer, criterion, epochs, device)

    # Test final model
    test_loss, test_acc, test_preds, test_targets = test_model(model, test_loader, criterion, device)

    print(f"\nFinal Test Results:")
    print(f"Test Accuracy: {test_acc:.2f}%")
    print(f"Test Loss: {test_loss:.4f}")

    # Detailed classification report
    print(f"\nClassification Report:")
    print(classification_report(test_targets, test_preds,
                              target_names=[str(i) for i in range(10)],
                              digits=3))

    # Log final results
    wandb.log({
        "final_test_accuracy": test_acc,
        "final_test_loss": test_loss
    })

    wandb.finish()

    print(f"\nTraining completed! Test accuracy: {test_acc:.2f}%")

if __name__ == "__main__":
    main()

Using device: cuda
Dataset sizes - Train: 55000, Val: 5000, Test: 10000
Model parameters: 118,282
Starting training...

Epoch 1/30


Training: 100%|██████████| 860/860 [00:12<00:00, 66.62it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 85.22it/s]


Train Loss: 0.2634, Train Acc: 91.97%
Val Loss: 0.1378, Val Acc: 96.00%

Epoch 2/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.21it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.68it/s]


Train Loss: 0.1121, Train Acc: 96.51%
Val Loss: 0.1052, Val Acc: 96.58%

Epoch 3/30


Training: 100%|██████████| 860/860 [00:13<00:00, 64.66it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 84.20it/s]


Train Loss: 0.0797, Train Acc: 97.51%
Val Loss: 0.0955, Val Acc: 97.14%

Epoch 4/30


Training: 100%|██████████| 860/860 [00:13<00:00, 64.17it/s]
Validating: 100%|██████████| 79/79 [00:01<00:00, 78.51it/s]


Train Loss: 0.0588, Train Acc: 98.09%
Val Loss: 0.0928, Val Acc: 97.22%

Epoch 5/30


Training: 100%|██████████| 860/860 [00:13<00:00, 65.11it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 87.09it/s]


Train Loss: 0.0468, Train Acc: 98.47%
Val Loss: 0.0896, Val Acc: 97.56%

Epoch 6/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.14it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 87.46it/s]


Train Loss: 0.0407, Train Acc: 98.66%
Val Loss: 0.0779, Val Acc: 97.66%

Epoch 7/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.54it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 80.59it/s]


Train Loss: 0.0320, Train Acc: 98.92%
Val Loss: 0.0900, Val Acc: 97.40%

Epoch 8/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.30it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 83.15it/s]


Train Loss: 0.0271, Train Acc: 99.08%
Val Loss: 0.0778, Val Acc: 97.88%

Epoch 9/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.36it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 82.38it/s]


Train Loss: 0.0262, Train Acc: 99.13%
Val Loss: 0.0892, Val Acc: 97.52%

Epoch 10/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.03it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.18it/s]


Train Loss: 0.0246, Train Acc: 99.18%
Val Loss: 0.0913, Val Acc: 97.86%

Epoch 11/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.04it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 83.44it/s]


Train Loss: 0.0193, Train Acc: 99.36%
Val Loss: 0.1019, Val Acc: 97.58%

Epoch 12/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.49it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 84.86it/s]


Train Loss: 0.0185, Train Acc: 99.35%
Val Loss: 0.0972, Val Acc: 97.68%

Epoch 13/30


Training: 100%|██████████| 860/860 [00:12<00:00, 68.37it/s]
Validating: 100%|██████████| 79/79 [00:01<00:00, 59.10it/s]


Train Loss: 0.0190, Train Acc: 99.34%
Val Loss: 0.1192, Val Acc: 97.44%

Epoch 14/30


Training: 100%|██████████| 860/860 [00:13<00:00, 62.79it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 80.87it/s]


Train Loss: 0.0178, Train Acc: 99.41%
Val Loss: 0.1005, Val Acc: 97.82%

Epoch 15/30


Training: 100%|██████████| 860/860 [00:13<00:00, 65.44it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.42it/s]


Train Loss: 0.0155, Train Acc: 99.49%
Val Loss: 0.1055, Val Acc: 97.76%

Epoch 16/30


Training: 100%|██████████| 860/860 [00:11<00:00, 72.70it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 84.28it/s]


Train Loss: 0.0133, Train Acc: 99.59%
Val Loss: 0.1139, Val Acc: 97.54%

Epoch 17/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.56it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.45it/s]


Train Loss: 0.0158, Train Acc: 99.49%
Val Loss: 0.1064, Val Acc: 97.96%

Epoch 18/30


Training: 100%|██████████| 860/860 [00:12<00:00, 69.68it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 88.38it/s]


Train Loss: 0.0128, Train Acc: 99.58%
Val Loss: 0.1351, Val Acc: 97.72%

Epoch 19/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.53it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 87.18it/s]


Train Loss: 0.0132, Train Acc: 99.58%
Val Loss: 0.1074, Val Acc: 98.02%

Epoch 20/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.06it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 87.50it/s]


Train Loss: 0.0148, Train Acc: 99.51%
Val Loss: 0.1288, Val Acc: 97.72%

Epoch 21/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.46it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.93it/s]


Train Loss: 0.0106, Train Acc: 99.66%
Val Loss: 0.1173, Val Acc: 97.72%

Epoch 22/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.44it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.42it/s]


Train Loss: 0.0128, Train Acc: 99.61%
Val Loss: 0.1171, Val Acc: 98.06%

Epoch 23/30


Training: 100%|██████████| 860/860 [00:11<00:00, 73.52it/s]
Validating: 100%|██████████| 79/79 [00:01<00:00, 67.43it/s]


Train Loss: 0.0106, Train Acc: 99.67%
Val Loss: 0.1100, Val Acc: 97.96%

Epoch 24/30


Training: 100%|██████████| 860/860 [00:11<00:00, 73.06it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 87.59it/s]


Train Loss: 0.0138, Train Acc: 99.56%
Val Loss: 0.1413, Val Acc: 97.54%

Epoch 25/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.27it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 85.64it/s]


Train Loss: 0.0134, Train Acc: 99.57%
Val Loss: 0.1360, Val Acc: 97.86%

Epoch 26/30


Training: 100%|██████████| 860/860 [00:11<00:00, 72.07it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 85.07it/s]


Train Loss: 0.0083, Train Acc: 99.73%
Val Loss: 0.1160, Val Acc: 97.94%

Epoch 27/30


Training: 100%|██████████| 860/860 [00:12<00:00, 69.49it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.98it/s]


Train Loss: 0.0091, Train Acc: 99.64%
Val Loss: 0.1197, Val Acc: 97.94%

Epoch 28/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.59it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 82.56it/s]


Train Loss: 0.0121, Train Acc: 99.65%
Val Loss: 0.1368, Val Acc: 97.90%

Epoch 29/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.12it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.80it/s]


Train Loss: 0.0108, Train Acc: 99.69%
Val Loss: 0.1448, Val Acc: 97.92%

Epoch 30/30


Training: 100%|██████████| 860/860 [00:11<00:00, 73.19it/s]
Validating: 100%|██████████| 79/79 [00:01<00:00, 70.29it/s]


Train Loss: 0.0112, Train Acc: 99.67%
Val Loss: 0.1491, Val Acc: 97.72%


Testing: 100%|██████████| 157/157 [00:01<00:00, 85.42it/s]



Final Test Results:
Test Accuracy: 97.65%
Test Loss: 0.1519

Classification Report:
              precision    recall  f1-score   support

           0      0.993     0.988     0.990       980
           1      0.988     0.988     0.988      1135
           2      0.987     0.964     0.975      1032
           3      0.989     0.970     0.980      1010
           4      0.979     0.978     0.978       982
           5      0.962     0.982     0.972       892
           6      0.974     0.974     0.974       958
           7      0.974     0.965     0.970      1028
           8      0.966     0.977     0.971       974
           9      0.953     0.979     0.966      1009

    accuracy                          0.977     10000
   macro avg      0.976     0.977     0.976     10000
weighted avg      0.977     0.977     0.977     10000



0,1
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
final_test_accuracy,▁
final_test_loss,▁
train_accuracy,▁▅▆▇▇▇▇▇▇█████████████████████
train_loss,█▄▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▃▅▅▆▇▆▇▆▇▆▇▆▇▇▆█▇█▇▇██▆▇██▇█▇
val_loss,▇▄▃▂▂▁▂▁▂▂▃▃▅▃▄▅▄▇▄▆▅▅▄▇▇▅▅▇██

0,1
epoch,30.0
final_test_accuracy,97.65
final_test_loss,0.15187
train_accuracy,99.66727
train_loss,0.01122
val_accuracy,97.72
val_loss,0.1491



Training completed! Test accuracy: 97.65%


In [1]:
# Exercise 1.1: Improved MLP (with Dropout & Best Model Saving)

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from tqdm import tqdm
from sklearn.metrics import classification_report
import wandb

# Simple MLP model - following exercise requirement for "narrow" layers
class BaselineMLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(0.2)  # Basic regularization

    def forward(self, x):
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

# Training function
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for data, target in tqdm(dataloader, desc="Training"):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)

    return total_loss / len(dataloader), 100.0 * correct / total

# Validation function
def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in tqdm(dataloader, desc="Validating"):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

    return total_loss / len(dataloader), 100.0 * correct / total

# Test function with detailed metrics
def test_model(model, dataloader, criterion, device):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for data, target in tqdm(dataloader, desc="Testing"):
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()

            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

            all_preds.extend(pred.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    accuracy = 100.0 * correct / total
    avg_loss = test_loss / len(dataloader)

    return avg_loss, accuracy, all_preds, all_targets

# Complete training pipeline
def train_model(model, train_loader, val_loader, optimizer, criterion, config, device):
    print("Starting training...")
    best_val_acc = 0

    for epoch in range(config.epochs):
        print(f"\nEpoch {epoch + 1}/{config.epochs}")

        # Train
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)

        # Validate
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        # Track best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pth')

        # Log to W&B
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_accuracy": train_acc,
            "val_loss": val_loss,
            "val_accuracy": val_acc
        })

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        print(f"Best Val Acc: {best_val_acc:.2f}%")

    return best_val_acc

def main():
    # Model architecture parameters
    input_size = 28 * 28 * 1  # MNIST image dimensions (28x28 pixels, 1 channel)
    hidden_size = 128         # Hidden layer size
    num_classes = 10          # Number of classes (digits 0-9)
    
    # Training hyperparameters
    epochs = 30
    batch_size = 64
    learning_rate = 0.001
    val_size = 5000
    
    # Initialize W&B
    wandb.init(
        project="mnist-baseline-mlp",
        config={
            "epochs": epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "hidden_size": hidden_size,
            "input_size": input_size,
            "num_classes": num_classes,
            "architecture": "2-layer MLP"
        }
    )

    config = wandb.config
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Data preparation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load datasets
    train_dataset = MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = MNIST('./data', train=False, transform=transform)

    # Split training set into train and validation
    train_size = len(train_dataset) - val_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)

    print(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

    # Model setup
    model = BaselineMLP(input_size, hidden_size, num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # Log model to W&B
    wandb.watch(model)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Training pipeline
    best_val_acc = train_model(model, train_loader, val_loader, optimizer, criterion, config, device)

    # Load best model and test
    print(f"\nLoading best model (Val Acc: {best_val_acc:.2f}%)")
    model.load_state_dict(torch.load('best_model.pth'))

    test_loss, test_acc, test_preds, test_targets = test_model(model, test_loader, criterion, device)

    print(f"\nFinal Test Results:")
    print(f"Test Accuracy: {test_acc:.2f}%")
    print(f"Test Loss: {test_loss:.4f}")

    # Detailed classification report
    print(f"\nClassification Report:")
    print(classification_report(test_targets, test_preds,
                              target_names=[str(i) for i in range(10)],
                              digits=3))

    # Log final results
    wandb.log({
        "best_val_accuracy": best_val_acc,
        "final_test_accuracy": test_acc,
        "final_test_loss": test_loss
    })

    wandb.finish()

    print(f"\nTraining completed! Best validation accuracy: {best_val_acc:.2f}%")

if __name__ == "__main__":
    main()

wandb: Currently logged in as: leonardobiondi (leonardobiondi-universit-degli-studi-di-firenze) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


Using device: cuda
Dataset sizes - Train: 55000, Val: 5000, Test: 10000
Model parameters: 118,282
Starting training...

Epoch 1/30


Training: 100%|██████████| 860/860 [00:11<00:00, 73.09it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 84.75it/s]


Train Loss: 0.3252, Train Acc: 90.07%
Val Loss: 0.1292, Val Acc: 96.00%
Best Val Acc: 96.00%

Epoch 2/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.37it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 87.44it/s]


Train Loss: 0.1587, Train Acc: 95.19%
Val Loss: 0.1028, Val Acc: 97.10%
Best Val Acc: 97.10%

Epoch 3/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.68it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 81.76it/s]


Train Loss: 0.1208, Train Acc: 96.33%
Val Loss: 0.0908, Val Acc: 97.06%
Best Val Acc: 97.10%

Epoch 4/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.50it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 87.28it/s]


Train Loss: 0.1040, Train Acc: 96.82%
Val Loss: 0.0859, Val Acc: 97.48%
Best Val Acc: 97.48%

Epoch 5/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.05it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.33it/s]


Train Loss: 0.0979, Train Acc: 96.96%
Val Loss: 0.0872, Val Acc: 97.42%
Best Val Acc: 97.48%

Epoch 6/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.52it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 87.21it/s]


Train Loss: 0.0819, Train Acc: 97.47%
Val Loss: 0.0731, Val Acc: 97.72%
Best Val Acc: 97.72%

Epoch 7/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.56it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 87.06it/s]


Train Loss: 0.0786, Train Acc: 97.52%
Val Loss: 0.0778, Val Acc: 97.72%
Best Val Acc: 97.72%

Epoch 8/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.51it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.05it/s]


Train Loss: 0.0718, Train Acc: 97.71%
Val Loss: 0.0805, Val Acc: 97.82%
Best Val Acc: 97.82%

Epoch 9/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.34it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.66it/s]


Train Loss: 0.0701, Train Acc: 97.76%
Val Loss: 0.0839, Val Acc: 97.70%
Best Val Acc: 97.82%

Epoch 10/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.32it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 84.97it/s]


Train Loss: 0.0641, Train Acc: 97.97%
Val Loss: 0.0843, Val Acc: 97.94%
Best Val Acc: 97.94%

Epoch 11/30


Training: 100%|██████████| 860/860 [00:11<00:00, 72.36it/s]
Validating: 100%|██████████| 79/79 [00:01<00:00, 66.63it/s]


Train Loss: 0.0588, Train Acc: 98.11%
Val Loss: 0.0807, Val Acc: 97.80%
Best Val Acc: 97.94%

Epoch 12/30


Training: 100%|██████████| 860/860 [00:11<00:00, 72.40it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 87.30it/s]


Train Loss: 0.0581, Train Acc: 98.13%
Val Loss: 0.0730, Val Acc: 98.10%
Best Val Acc: 98.10%

Epoch 13/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.79it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 87.33it/s]


Train Loss: 0.0557, Train Acc: 98.15%
Val Loss: 0.0809, Val Acc: 97.96%
Best Val Acc: 98.10%

Epoch 14/30


Training: 100%|██████████| 860/860 [00:12<00:00, 69.83it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 81.04it/s]


Train Loss: 0.0545, Train Acc: 98.19%
Val Loss: 0.0775, Val Acc: 97.92%
Best Val Acc: 98.10%

Epoch 15/30


Training: 100%|██████████| 860/860 [00:12<00:00, 69.80it/s]
Validating: 100%|██████████| 79/79 [00:01<00:00, 78.69it/s]


Train Loss: 0.0497, Train Acc: 98.41%
Val Loss: 0.0873, Val Acc: 97.70%
Best Val Acc: 98.10%

Epoch 16/30


Training: 100%|██████████| 860/860 [00:13<00:00, 64.52it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 85.71it/s]


Train Loss: 0.0481, Train Acc: 98.43%
Val Loss: 0.0894, Val Acc: 97.72%
Best Val Acc: 98.10%

Epoch 17/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.05it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.06it/s]


Train Loss: 0.0499, Train Acc: 98.37%
Val Loss: 0.0839, Val Acc: 98.04%
Best Val Acc: 98.10%

Epoch 18/30


Training: 100%|██████████| 860/860 [00:12<00:00, 69.45it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.51it/s]


Train Loss: 0.0479, Train Acc: 98.48%
Val Loss: 0.0810, Val Acc: 97.96%
Best Val Acc: 98.10%

Epoch 19/30


Training: 100%|██████████| 860/860 [00:12<00:00, 71.24it/s]
Validating: 100%|██████████| 79/79 [00:01<00:00, 72.44it/s]


Train Loss: 0.0442, Train Acc: 98.61%
Val Loss: 0.0859, Val Acc: 97.88%
Best Val Acc: 98.10%

Epoch 20/30


Training: 100%|██████████| 860/860 [00:12<00:00, 68.74it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.30it/s]


Train Loss: 0.0444, Train Acc: 98.63%
Val Loss: 0.0828, Val Acc: 98.04%
Best Val Acc: 98.10%

Epoch 21/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.08it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 82.98it/s]


Train Loss: 0.0442, Train Acc: 98.55%
Val Loss: 0.0932, Val Acc: 97.84%
Best Val Acc: 98.10%

Epoch 22/30


Training: 100%|██████████| 860/860 [00:12<00:00, 69.07it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 80.83it/s]


Train Loss: 0.0423, Train Acc: 98.59%
Val Loss: 0.0904, Val Acc: 97.88%
Best Val Acc: 98.10%

Epoch 23/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.30it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 86.41it/s]


Train Loss: 0.0408, Train Acc: 98.70%
Val Loss: 0.0906, Val Acc: 98.18%
Best Val Acc: 98.18%

Epoch 24/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.15it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 80.78it/s]


Train Loss: 0.0398, Train Acc: 98.67%
Val Loss: 0.0854, Val Acc: 97.98%
Best Val Acc: 98.18%

Epoch 25/30


Training: 100%|██████████| 860/860 [00:12<00:00, 70.30it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 83.03it/s]


Train Loss: 0.0403, Train Acc: 98.67%
Val Loss: 0.0900, Val Acc: 98.00%
Best Val Acc: 98.18%

Epoch 26/30


Training: 100%|██████████| 860/860 [00:12<00:00, 69.22it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 82.36it/s]


Train Loss: 0.0412, Train Acc: 98.66%
Val Loss: 0.0852, Val Acc: 98.00%
Best Val Acc: 98.18%

Epoch 27/30


Training: 100%|██████████| 860/860 [00:12<00:00, 69.70it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 82.19it/s]


Train Loss: 0.0417, Train Acc: 98.67%
Val Loss: 0.0825, Val Acc: 98.00%
Best Val Acc: 98.18%

Epoch 28/30


Training: 100%|██████████| 860/860 [00:12<00:00, 66.47it/s]
Validating: 100%|██████████| 79/79 [00:01<00:00, 64.13it/s]


Train Loss: 0.0387, Train Acc: 98.73%
Val Loss: 0.0867, Val Acc: 98.02%
Best Val Acc: 98.18%

Epoch 29/30


Training: 100%|██████████| 860/860 [00:13<00:00, 64.19it/s]
Validating: 100%|██████████| 79/79 [00:01<00:00, 62.57it/s]


Train Loss: 0.0390, Train Acc: 98.69%
Val Loss: 0.0879, Val Acc: 98.14%
Best Val Acc: 98.18%

Epoch 30/30


Training: 100%|██████████| 860/860 [00:12<00:00, 68.83it/s]
Validating: 100%|██████████| 79/79 [00:00<00:00, 84.76it/s]


Train Loss: 0.0377, Train Acc: 98.76%
Val Loss: 0.0954, Val Acc: 98.06%
Best Val Acc: 98.18%

Loading best model (Val Acc: 98.18%)


Testing: 100%|██████████| 157/157 [00:02<00:00, 72.36it/s]



Final Test Results:
Test Accuracy: 98.12%
Test Loss: 0.0747

Classification Report:
              precision    recall  f1-score   support

           0      0.984     0.993     0.988       980
           1      0.986     0.995     0.990      1135
           2      0.982     0.982     0.982      1032
           3      0.980     0.983     0.982      1010
           4      0.981     0.981     0.981       982
           5      0.984     0.969     0.976       892
           6      0.981     0.986     0.984       958
           7      0.973     0.978     0.975      1028
           8      0.982     0.975     0.979       974
           9      0.979     0.968     0.974      1009

    accuracy                          0.981     10000
   macro avg      0.981     0.981     0.981     10000
weighted avg      0.981     0.981     0.981     10000



0,1
best_val_accuracy,▁
epoch,▁▁▁▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▇▇▇▇███
final_test_accuracy,▁
final_test_loss,▁
train_accuracy,▁▅▆▆▇▇▇▇▇▇▇▇██████████████████
train_loss,█▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▅▄▆▆▇▇▇▆▇▇█▇▇▆▇█▇▇█▇▇█▇▇▇▇▇██
val_loss,█▅▃▃▃▁▂▂▂▂▂▁▂▂▃▃▂▂▃▂▄▃▃▃▃▃▂▃▃▄

0,1
best_val_accuracy,98.18
epoch,30.0
final_test_accuracy,98.12
final_test_loss,0.07471
train_accuracy,98.75636
train_loss,0.03767
val_accuracy,98.06
val_loss,0.09542



Training completed! Best validation accuracy: 98.18%


## Considerazioni Finali

### Risultati a Confronto
| Metrica         | Basic MLP | Improved MLP | Differenza |
|-----------------|-----------|--------------|------------|
| Test Accuracy   | 97.65%    | 98.12%       | +0.47%     |
| Test Loss       | 0.1519    | 0.0747       | -50.8%     |
| Train Accuracy  | 99.67%    | 98.76%       | -0.91%     |
| Val Accuracy    | 97.72%    | 98.18%       | +0.46%     |
| Train-Test Gap  | 2.02%     | 0.64%        | -68%       |

---

### Analisi dell'Overfitting

**Basic MLP: Leggero Overfitting**

- Train accuracy: 99.67% vs Test: 97.65% (gap del 2.02%)
- Validation loss in crescita nelle epoche finali
- Il modello memorizza i dati di training senza generalizzare bene

  
**Enhanced MLP: Generalizzazione Ottima**

- Train accuracy: 98.76% vs Test: 98.12% (gap di soli 0.64%)
- Curve di learning più stabili
- Migliore capacità di generalizzazione su dati non visti

---

### Impatto della versione migliorata

**Dropout (0.2):**

- Previene l'overfitting riducendo la dipendenza eccessiva tra i neuroni
- Migliora la robustezza del modello durante il training
- Trade-off controllato: leggera riduzione della train accuracy per migliore generalizzazione

**Best Model Saving:**

- Selezione ottimale: usa il modello con migliore validation performance
- Evita il deterioramento delle prestazioni nelle epoche finali
- Early stopping implicito: ferma l'overfitting al momento giusto

---

### Qualità delle Predizioni

**improved MLP mostra:**

- Test loss dimezzata: 0.0747 vs 0.1519 (maggiore confidenza nelle predizioni)
- Performance più uniformi tra le classi nel classification report
- Stabilità superiore nelle metriche validation

---

### Conclusione
I risultati ottenuti confermano che un **MLP ben addestrato** è già estremamente efficace per la classificazione su **MNIST**.  
L'**Improved MLP**, con le sue tecniche di regolarizzazione, raggiunge prestazioni eccellenti (**98.12% di accuratezza**) dimostrando come l'implementazione corretta di pratiche standard possa fare la differenza tra un modello base (Basic MLP con **2.02% di overfitting**) e uno che **generalizza efficacemente**.

Tuttavia, rimane spazio per ulteriori miglioramenti.  
Architetture più avanzate come le **Convolutional Neural Networks (CNN)** potrebbero sfruttare la **struttura spaziale delle immagini** per raggiungere accuratezze superiori al **99%**, mentre l'implementazione di **early stopping esplicito** (invece del semplice best model saving) potrebbe ottimizzare ulteriormente l'efficienza del training e la stabilità della convergenza.

Nel complesso, questo confronto evidenzia come anche su dataset **relativamente semplici** come MNIST, la differenza tra un approccio **naive** e uno **metodologicamente rigoroso** sia **sostanziale e misurabile**, ponendo le basi per affrontare sfide più complesse con maggiore **consapevolezza tecnica**.


### Exercise 1.2: Adding Residual Connections

Implement a variant of your parameterized MLP network to support **residual** connections. Your network should be defined as a composition of **residual MLP** blocks that have one or more linear layers and add a skip connection from the block input to the output of the final linear layer.

**Compare** the performance (in training/validation loss and test accuracy) of your MLP and ResidualMLP for a range of depths. Verify that deeper networks **with** residual connections are easier to train than a network of the same depth **without** residual connections.

**For extra style points**: See if you can explain by analyzing the gradient magnitudes on a single training batch *why* this is the case.

In [4]:
# Exercise 1.2: Adding Residual Connections

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
from tqdm import tqdm
from sklearn.metrics import classification_report
import wandb

# Standard MLP Block
class MLPBlock(nn.Module):
    def __init__(self, hidden_size, dropout_rate=0.1):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout_rate)
        
    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = self.dropout(out)
        out = self.fc2(out)
        return F.relu(out)

# Residual MLP Block with Skip Connection
class ResidualMLPBlock(nn.Module):
    def __init__(self, hidden_size, dropout_rate=0.1):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout_rate)
        
    def forward(self, x):
        identity = x  # Skip connection
        out = F.relu(self.fc1(x))
        out = self.dropout(out)
        out = self.fc2(out)
        out = out + identity  # Add skip connection
        return F.relu(out)

# Standard MLP Network
class StandardMLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes, depth, dropout_rate=0.1):
        super().__init__()
        self.flatten = nn.Flatten()
        self.input_layer = nn.Linear(input_size, hidden_size)
        self.blocks = nn.ModuleList([
            MLPBlock(hidden_size, dropout_rate) for _ in range(depth)
        ])
        self.output_layer = nn.Linear(hidden_size, num_classes)
        self.depth = depth
        
    def forward(self, x):
        x = self.flatten(x)
        x = F.relu(self.input_layer(x))
        
        for block in self.blocks:
            x = block(x)
            
        x = self.output_layer(x)
        return x

# Residual MLP Network
class ResidualMLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes, depth, dropout_rate=0.1):
        super().__init__()
        self.flatten = nn.Flatten()
        self.input_layer = nn.Linear(input_size, hidden_size)
        self.blocks = nn.ModuleList([
            ResidualMLPBlock(hidden_size, dropout_rate) for _ in range(depth)
        ])
        self.output_layer = nn.Linear(hidden_size, num_classes)
        self.depth = depth
        
    def forward(self, x):
        x = self.flatten(x)
        x = F.relu(self.input_layer(x))
        
        for block in self.blocks:
            x = block(x)
            
        x = self.output_layer(x)
        return x

# Training function
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for data, target in tqdm(dataloader, desc="Training", leave=False):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)

    return total_loss / len(dataloader), 100.0 * correct / total

# Validation function
def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for data, target in tqdm(dataloader, desc="Validating", leave=False):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            total_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

    return total_loss / len(dataloader), 100.0 * correct / total

# Gradient analysis function
def analyze_gradients(model, sample_batch, criterion, device, model_name):
    model.train()
    data, target = sample_batch
    data, target = data.to(device), target.to(device)
    
    # Forward pass
    output = model(data)
    loss = criterion(output, target)
    
    # Backward pass
    model.zero_grad()
    loss.backward()
    
    # Collect gradient norms
    grad_norms = {}
    for name, param in model.named_parameters():
        if param.grad is not None and 'weight' in name:
            grad_norm = param.grad.norm().item()
            grad_norms[name] = grad_norm
    
    return grad_norms

# Training pipeline for comparison
def train_and_compare_models(standard_mlp, residual_mlp, train_loader, val_loader, 
                           config, device, depth):
    print(f"\n{'='*60}")
    print(f"Training Networks with Depth: {depth}")
    print(f"{'='*60}")
    
    # Setup optimizers and criterion
    optimizer_std = optim.Adam(standard_mlp.parameters(), lr=config.learning_rate)
    optimizer_res = optim.Adam(residual_mlp.parameters(), lr=config.learning_rate)
    criterion = nn.CrossEntropyLoss()
    
    # Storage for results
    results = {
        'standard': {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []},
        'residual': {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    }
    
    # Sample batch for gradient analysis
    sample_batch = next(iter(train_loader))
    
    for epoch in range(config.epochs):
        print(f"\nEpoch {epoch + 1}/{config.epochs}")
        
        # Train Standard MLP
        train_loss_std, train_acc_std = train_one_epoch(standard_mlp, train_loader, 
                                                      optimizer_std, criterion, device)
        val_loss_std, val_acc_std = validate(standard_mlp, val_loader, criterion, device)
        
        # Train Residual MLP
        train_loss_res, train_acc_res = train_one_epoch(residual_mlp, train_loader, 
                                                      optimizer_res, criterion, device)
        val_loss_res, val_acc_res = validate(residual_mlp, val_loader, criterion, device)
        
        # Store results
        results['standard']['train_loss'].append(train_loss_std)
        results['standard']['val_loss'].append(val_loss_std)
        results['standard']['train_acc'].append(train_acc_std)
        results['standard']['val_acc'].append(val_acc_std)
        
        results['residual']['train_loss'].append(train_loss_res)
        results['residual']['val_loss'].append(val_loss_res)
        results['residual']['train_acc'].append(train_acc_res)
        results['residual']['val_acc'].append(val_acc_res)
        
        # Log to W&B
        wandb.log({
            f"depth_{depth}/standard_train_loss": train_loss_std,
            f"depth_{depth}/standard_val_loss": val_loss_std,
            f"depth_{depth}/standard_train_acc": train_acc_std,
            f"depth_{depth}/standard_val_acc": val_acc_std,
            f"depth_{depth}/residual_train_loss": train_loss_res,
            f"depth_{depth}/residual_val_loss": val_loss_res,
            f"depth_{depth}/residual_train_acc": train_acc_res,
            f"depth_{depth}/residual_val_acc": val_acc_res,
            "epoch": epoch + 1
        })
        
        # Print results
        print(f"Standard MLP  - Train Loss: {train_loss_std:.4f}, Train Acc: {train_acc_std:.2f}%, "
              f"Val Loss: {val_loss_std:.4f}, Val Acc: {val_acc_std:.2f}%")
        print(f"Residual MLP  - Train Loss: {train_loss_res:.4f}, Train Acc: {train_acc_res:.2f}%, "
              f"Val Loss: {val_loss_res:.4f}, Val Acc: {val_acc_res:.2f}%")
    
    # Gradient analysis at the end
    print(f"\n{'-'*40}")
    print("GRADIENT ANALYSIS")
    print(f"{'-'*40}")
    
    grad_norms_std = analyze_gradients(standard_mlp, sample_batch, criterion, device, "Standard")
    grad_norms_res = analyze_gradients(residual_mlp, sample_batch, criterion, device, "Residual")
    
    print(f"\nGradient Norms Comparison (Depth {depth}):")
    print(f"{'Layer':<30} {'Standard MLP':<15} {'Residual MLP':<15} {'Ratio':<10}")
    print("-" * 70)
    
    for layer_std in grad_norms_std:
        layer_res = layer_std.replace('blocks', 'blocks')  # Same naming convention
        if layer_res in grad_norms_res:
            ratio = grad_norms_res[layer_res] / max(grad_norms_std[layer_std], 1e-8)
            print(f"{layer_std:<30} {grad_norms_std[layer_std]:<15.6f} "
                  f"{grad_norms_res[layer_res]:<15.6f} {ratio:<10.2f}")
    
    # Log gradient analysis
    for layer_name, norm in grad_norms_std.items():
        wandb.log({f"gradients/depth_{depth}/standard_{layer_name}": norm})
    for layer_name, norm in grad_norms_res.items():
        wandb.log({f"gradients/depth_{depth}/residual_{layer_name}": norm})
    
    return results

def main():
    # Model architecture parameters
    input_size = 28 * 28 * 1
    hidden_size = 128
    num_classes = 10
    
    # Training hyperparameters
    epochs = 15
    batch_size = 64
    learning_rate = 0.001
    val_size = 5000
    dropout_rate = 0.1
    
    # Depths to compare
    depths = [1, 3, 5, 8]
    
    # Initialize W&B
    wandb.init(
        project="mnist-residual-mlp-comparison",
        config={
            "epochs": epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "hidden_size": hidden_size,
            "dropout_rate": dropout_rate,
            "depths": depths,
            "architecture": "Standard vs Residual MLP"
        }
    )
    
    config = wandb.config
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Data preparation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    # Load datasets
    train_dataset = MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = MNIST('./data', train=False, transform=transform)
    
    # Split training set
    train_size = len(train_dataset) - val_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
    
    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    print(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
    
    # Compare different depths
    all_results = {}
    
    for depth in depths:
        print(f"\nCreating models with depth {depth}...")
        
        # Create models
        standard_mlp = StandardMLP(input_size, hidden_size, num_classes, depth, dropout_rate).to(device)
        residual_mlp = ResidualMLP(input_size, hidden_size, num_classes, depth, dropout_rate).to(device)
        
        print(f"Standard MLP parameters: {sum(p.numel() for p in standard_mlp.parameters()):,}")
        print(f"Residual MLP parameters: {sum(p.numel() for p in residual_mlp.parameters()):,}")
        
        # Train and compare
        results = train_and_compare_models(standard_mlp, residual_mlp, train_loader, 
                                         val_loader, config, device, depth)
        all_results[depth] = results
        
        # Test final performance
        criterion = nn.CrossEntropyLoss()
        _, test_acc_std = validate(standard_mlp, test_loader, criterion, device)
        _, test_acc_res = validate(residual_mlp, test_loader, criterion, device)
        
        print(f"\nFinal Test Accuracy - Depth {depth}:")
        print(f"Standard MLP: {test_acc_std:.2f}%")
        print(f"Residual MLP: {test_acc_res:.2f}%")
        print(f"Improvement: {test_acc_res - test_acc_std:.2f}%")
        
        # Log final test results
        wandb.log({
            f"final_test/depth_{depth}_standard": test_acc_std,
            f"final_test/depth_{depth}_residual": test_acc_res,
            f"final_test/depth_{depth}_improvement": test_acc_res - test_acc_std
        })
    
    # Summary
    print(f"\n{'='*60}")
    print("EXPERIMENT SUMMARY")
    print(f"{'='*60}")
    print(f"{'Depth':<8} {'Standard Acc':<15} {'Residual Acc':<15} {'Improvement':<12}")
    print("-" * 60)
    
    for depth in depths:
        std_acc = all_results[depth]['standard']['val_acc'][-1]
        res_acc = all_results[depth]['residual']['val_acc'][-1]
        improvement = res_acc - std_acc
        print(f"{depth:<8} {std_acc:<15.2f} {res_acc:<15.2f} {improvement:<12.2f}")
    
    wandb.finish()
    print(f"\nExperiment completed! Check W&B for detailed comparisons.")

if __name__ == "__main__":
    main()

Using device: cuda
Dataset sizes - Train: 55000, Val: 5000, Test: 10000

Creating models with depth 1...
Standard MLP parameters: 134,794
Residual MLP parameters: 134,794

Training Networks with Depth: 1

Epoch 1/15


                                                           

Standard MLP  - Train Loss: 0.2829, Train Acc: 91.35%, Val Loss: 0.1275, Val Acc: 96.02%
Residual MLP  - Train Loss: 0.2493, Train Acc: 92.55%, Val Loss: 0.1393, Val Acc: 95.64%

Epoch 2/15


                                                           

Standard MLP  - Train Loss: 0.1173, Train Acc: 96.38%, Val Loss: 0.0926, Val Acc: 97.32%
Residual MLP  - Train Loss: 0.1056, Train Acc: 96.73%, Val Loss: 0.1003, Val Acc: 96.78%

Epoch 3/15


                                                           

Standard MLP  - Train Loss: 0.0862, Train Acc: 97.33%, Val Loss: 0.0895, Val Acc: 97.22%
Residual MLP  - Train Loss: 0.0749, Train Acc: 97.62%, Val Loss: 0.0888, Val Acc: 97.14%

Epoch 4/15


                                                           

Standard MLP  - Train Loss: 0.0688, Train Acc: 97.83%, Val Loss: 0.0942, Val Acc: 97.22%
Residual MLP  - Train Loss: 0.0571, Train Acc: 98.18%, Val Loss: 0.0943, Val Acc: 97.36%

Epoch 5/15


                                                           

Standard MLP  - Train Loss: 0.0564, Train Acc: 98.18%, Val Loss: 0.0806, Val Acc: 97.70%
Residual MLP  - Train Loss: 0.0495, Train Acc: 98.38%, Val Loss: 0.0882, Val Acc: 97.64%

Epoch 6/15


                                                           

Standard MLP  - Train Loss: 0.0488, Train Acc: 98.39%, Val Loss: 0.0845, Val Acc: 97.74%
Residual MLP  - Train Loss: 0.0392, Train Acc: 98.70%, Val Loss: 0.0850, Val Acc: 97.66%

Epoch 7/15


                                                           

Standard MLP  - Train Loss: 0.0431, Train Acc: 98.58%, Val Loss: 0.1006, Val Acc: 97.26%
Residual MLP  - Train Loss: 0.0351, Train Acc: 98.82%, Val Loss: 0.0758, Val Acc: 97.96%

Epoch 8/15


                                                           

Standard MLP  - Train Loss: 0.0376, Train Acc: 98.74%, Val Loss: 0.1240, Val Acc: 96.74%
Residual MLP  - Train Loss: 0.0305, Train Acc: 99.02%, Val Loss: 0.0853, Val Acc: 97.84%

Epoch 9/15


                                                           

Standard MLP  - Train Loss: 0.0362, Train Acc: 98.82%, Val Loss: 0.0979, Val Acc: 97.72%
Residual MLP  - Train Loss: 0.0265, Train Acc: 99.12%, Val Loss: 0.1063, Val Acc: 97.46%

Epoch 10/15


                                                           

Standard MLP  - Train Loss: 0.0317, Train Acc: 98.94%, Val Loss: 0.1052, Val Acc: 97.62%
Residual MLP  - Train Loss: 0.0253, Train Acc: 99.12%, Val Loss: 0.0954, Val Acc: 97.50%

Epoch 11/15


                                                           

Standard MLP  - Train Loss: 0.0296, Train Acc: 99.07%, Val Loss: 0.1035, Val Acc: 97.68%
Residual MLP  - Train Loss: 0.0201, Train Acc: 99.29%, Val Loss: 0.0834, Val Acc: 98.02%

Epoch 12/15


                                                           

Standard MLP  - Train Loss: 0.0256, Train Acc: 99.14%, Val Loss: 0.1048, Val Acc: 97.62%
Residual MLP  - Train Loss: 0.0205, Train Acc: 99.27%, Val Loss: 0.1150, Val Acc: 97.58%

Epoch 13/15


                                                           

Standard MLP  - Train Loss: 0.0239, Train Acc: 99.19%, Val Loss: 0.1218, Val Acc: 97.66%
Residual MLP  - Train Loss: 0.0219, Train Acc: 99.26%, Val Loss: 0.1105, Val Acc: 97.60%

Epoch 14/15


                                                           

Standard MLP  - Train Loss: 0.0214, Train Acc: 99.30%, Val Loss: 0.1183, Val Acc: 97.62%
Residual MLP  - Train Loss: 0.0171, Train Acc: 99.49%, Val Loss: 0.1034, Val Acc: 97.92%

Epoch 15/15


                                                           

Standard MLP  - Train Loss: 0.0223, Train Acc: 99.28%, Val Loss: 0.1009, Val Acc: 97.90%
Residual MLP  - Train Loss: 0.0138, Train Acc: 99.52%, Val Loss: 0.1392, Val Acc: 97.34%

----------------------------------------
GRADIENT ANALYSIS
----------------------------------------

Gradient Norms Comparison (Depth 1):
Layer                          Standard MLP    Residual MLP    Ratio     
----------------------------------------------------------------------
input_layer.weight             0.006229        0.696967        111.90    
blocks.0.fc1.weight            0.005742        0.561574        97.80     
blocks.0.fc2.weight            0.003113        0.242486        77.90     
output_layer.weight            0.005161        0.781906        151.51    


                                                             


Final Test Accuracy - Depth 1:
Standard MLP: 97.81%
Residual MLP: 97.62%
Improvement: -0.19%

Creating models with depth 3...
Standard MLP parameters: 200,842
Residual MLP parameters: 200,842

Training Networks with Depth: 3

Epoch 1/15


                                                           

Standard MLP  - Train Loss: 0.4357, Train Acc: 85.95%, Val Loss: 0.1800, Val Acc: 95.02%
Residual MLP  - Train Loss: 0.2411, Train Acc: 92.68%, Val Loss: 0.1496, Val Acc: 95.36%

Epoch 2/15


                                                           

Standard MLP  - Train Loss: 0.1640, Train Acc: 95.46%, Val Loss: 0.1487, Val Acc: 95.92%
Residual MLP  - Train Loss: 0.1090, Train Acc: 96.68%, Val Loss: 0.1343, Val Acc: 96.48%

Epoch 3/15


                                                           

Standard MLP  - Train Loss: 0.1217, Train Acc: 96.60%, Val Loss: 0.1397, Val Acc: 96.20%
Residual MLP  - Train Loss: 0.0786, Train Acc: 97.58%, Val Loss: 0.0850, Val Acc: 97.52%

Epoch 4/15


                                                           

Standard MLP  - Train Loss: 0.0989, Train Acc: 97.31%, Val Loss: 0.1171, Val Acc: 96.70%
Residual MLP  - Train Loss: 0.0635, Train Acc: 97.97%, Val Loss: 0.1175, Val Acc: 96.54%

Epoch 5/15


                                                           

Standard MLP  - Train Loss: 0.0861, Train Acc: 97.61%, Val Loss: 0.1127, Val Acc: 97.18%
Residual MLP  - Train Loss: 0.0522, Train Acc: 98.29%, Val Loss: 0.1026, Val Acc: 97.22%

Epoch 6/15


                                                           

Standard MLP  - Train Loss: 0.0783, Train Acc: 97.77%, Val Loss: 0.1124, Val Acc: 97.38%
Residual MLP  - Train Loss: 0.0428, Train Acc: 98.60%, Val Loss: 0.1056, Val Acc: 97.14%

Epoch 7/15


                                                           

Standard MLP  - Train Loss: 0.0674, Train Acc: 98.21%, Val Loss: 0.1295, Val Acc: 96.86%
Residual MLP  - Train Loss: 0.0379, Train Acc: 98.73%, Val Loss: 0.0909, Val Acc: 97.82%

Epoch 8/15


                                                           

Standard MLP  - Train Loss: 0.0593, Train Acc: 98.36%, Val Loss: 0.1332, Val Acc: 96.86%
Residual MLP  - Train Loss: 0.0304, Train Acc: 98.98%, Val Loss: 0.1060, Val Acc: 97.74%

Epoch 9/15


                                                           

Standard MLP  - Train Loss: 0.0574, Train Acc: 98.52%, Val Loss: 0.1152, Val Acc: 97.44%
Residual MLP  - Train Loss: 0.0289, Train Acc: 99.06%, Val Loss: 0.1146, Val Acc: 97.26%

Epoch 10/15


                                                           

Standard MLP  - Train Loss: 0.0500, Train Acc: 98.62%, Val Loss: 0.1047, Val Acc: 97.60%
Residual MLP  - Train Loss: 0.0241, Train Acc: 99.23%, Val Loss: 0.1089, Val Acc: 97.40%

Epoch 11/15


                                                           

Standard MLP  - Train Loss: 0.0466, Train Acc: 98.73%, Val Loss: 0.1256, Val Acc: 97.24%
Residual MLP  - Train Loss: 0.0214, Train Acc: 99.32%, Val Loss: 0.0991, Val Acc: 97.78%

Epoch 12/15


                                                           

Standard MLP  - Train Loss: 0.0447, Train Acc: 98.79%, Val Loss: 0.1072, Val Acc: 97.50%
Residual MLP  - Train Loss: 0.0258, Train Acc: 99.19%, Val Loss: 0.1087, Val Acc: 97.60%

Epoch 13/15


                                                           

Standard MLP  - Train Loss: 0.0374, Train Acc: 98.95%, Val Loss: 0.1157, Val Acc: 97.54%
Residual MLP  - Train Loss: 0.0199, Train Acc: 99.33%, Val Loss: 0.1047, Val Acc: 97.80%

Epoch 14/15


                                                           

Standard MLP  - Train Loss: 0.0355, Train Acc: 99.06%, Val Loss: 0.1199, Val Acc: 97.66%
Residual MLP  - Train Loss: 0.0164, Train Acc: 99.46%, Val Loss: 0.0976, Val Acc: 97.76%

Epoch 15/15


                                                           

Standard MLP  - Train Loss: 0.0376, Train Acc: 98.98%, Val Loss: 0.1094, Val Acc: 97.78%
Residual MLP  - Train Loss: 0.0190, Train Acc: 99.38%, Val Loss: 0.0973, Val Acc: 97.82%

----------------------------------------
GRADIENT ANALYSIS
----------------------------------------

Gradient Norms Comparison (Depth 3):
Layer                          Standard MLP    Residual MLP    Ratio     
----------------------------------------------------------------------
input_layer.weight             0.271185        0.001908        0.01      
blocks.0.fc1.weight            0.323914        0.000958        0.00      
blocks.0.fc2.weight            0.175747        0.000545        0.00      
blocks.1.fc1.weight            0.092680        0.000816        0.01      
blocks.1.fc2.weight            0.053904        0.000476        0.01      
blocks.2.fc1.weight            0.047610        0.000932        0.02      
blocks.2.fc2.weight            0.044799        0.000611        0.01      
output_layer.weight 

                                                             


Final Test Accuracy - Depth 3:
Standard MLP: 97.74%
Residual MLP: 97.79%
Improvement: 0.05%

Creating models with depth 5...
Standard MLP parameters: 266,890
Residual MLP parameters: 266,890

Training Networks with Depth: 5

Epoch 1/15


                                                           

Standard MLP  - Train Loss: 0.6944, Train Acc: 74.61%, Val Loss: 0.2802, Val Acc: 93.08%
Residual MLP  - Train Loss: 0.2344, Train Acc: 92.78%, Val Loss: 0.1300, Val Acc: 96.00%

Epoch 2/15


                                                           

Standard MLP  - Train Loss: 0.2426, Train Acc: 93.94%, Val Loss: 0.2227, Val Acc: 94.72%
Residual MLP  - Train Loss: 0.1085, Train Acc: 96.67%, Val Loss: 0.0972, Val Acc: 97.16%

Epoch 3/15


                                                           

Standard MLP  - Train Loss: 0.1950, Train Acc: 95.26%, Val Loss: 0.1586, Val Acc: 96.24%
Residual MLP  - Train Loss: 0.0763, Train Acc: 97.57%, Val Loss: 0.0955, Val Acc: 97.22%

Epoch 4/15


                                                           

Standard MLP  - Train Loss: 0.1644, Train Acc: 95.99%, Val Loss: 0.1506, Val Acc: 96.40%
Residual MLP  - Train Loss: 0.0611, Train Acc: 98.05%, Val Loss: 0.0947, Val Acc: 97.34%

Epoch 5/15


                                                           

Standard MLP  - Train Loss: 0.1507, Train Acc: 96.43%, Val Loss: 0.1742, Val Acc: 96.22%
Residual MLP  - Train Loss: 0.0500, Train Acc: 98.42%, Val Loss: 0.0962, Val Acc: 97.56%

Epoch 6/15


                                                           

Standard MLP  - Train Loss: 0.1337, Train Acc: 96.80%, Val Loss: 0.1411, Val Acc: 96.94%
Residual MLP  - Train Loss: 0.0451, Train Acc: 98.61%, Val Loss: 0.0834, Val Acc: 97.78%

Epoch 7/15


                                                           

Standard MLP  - Train Loss: 0.1173, Train Acc: 97.20%, Val Loss: 0.1416, Val Acc: 96.94%
Residual MLP  - Train Loss: 0.0364, Train Acc: 98.81%, Val Loss: 0.0867, Val Acc: 97.80%

Epoch 8/15


                                                           

Standard MLP  - Train Loss: 0.1126, Train Acc: 97.42%, Val Loss: 0.1325, Val Acc: 97.22%
Residual MLP  - Train Loss: 0.0327, Train Acc: 98.95%, Val Loss: 0.0983, Val Acc: 97.46%

Epoch 9/15


                                                           

Standard MLP  - Train Loss: 0.1029, Train Acc: 97.60%, Val Loss: 0.1587, Val Acc: 96.88%
Residual MLP  - Train Loss: 0.0300, Train Acc: 99.01%, Val Loss: 0.0987, Val Acc: 97.60%

Epoch 10/15


                                                           

Standard MLP  - Train Loss: 0.1020, Train Acc: 97.60%, Val Loss: 0.1262, Val Acc: 97.10%
Residual MLP  - Train Loss: 0.0266, Train Acc: 99.10%, Val Loss: 0.0899, Val Acc: 97.84%

Epoch 11/15


                                                           

Standard MLP  - Train Loss: 0.0926, Train Acc: 97.69%, Val Loss: 0.1474, Val Acc: 97.36%
Residual MLP  - Train Loss: 0.0235, Train Acc: 99.25%, Val Loss: 0.1007, Val Acc: 97.88%

Epoch 12/15


                                                           

Standard MLP  - Train Loss: 0.0828, Train Acc: 98.14%, Val Loss: 0.1272, Val Acc: 97.44%
Residual MLP  - Train Loss: 0.0222, Train Acc: 99.26%, Val Loss: 0.1173, Val Acc: 97.58%

Epoch 13/15


                                                           

Standard MLP  - Train Loss: 0.0854, Train Acc: 97.96%, Val Loss: 0.1293, Val Acc: 97.30%
Residual MLP  - Train Loss: 0.0212, Train Acc: 99.34%, Val Loss: 0.1141, Val Acc: 97.70%

Epoch 14/15


                                                           

Standard MLP  - Train Loss: 0.0757, Train Acc: 98.15%, Val Loss: 0.1305, Val Acc: 97.48%
Residual MLP  - Train Loss: 0.0182, Train Acc: 99.45%, Val Loss: 0.0993, Val Acc: 97.90%

Epoch 15/15


                                                           

Standard MLP  - Train Loss: 0.0725, Train Acc: 98.28%, Val Loss: 0.1410, Val Acc: 97.26%
Residual MLP  - Train Loss: 0.0188, Train Acc: 99.39%, Val Loss: 0.1071, Val Acc: 97.64%

----------------------------------------
GRADIENT ANALYSIS
----------------------------------------

Gradient Norms Comparison (Depth 5):
Layer                          Standard MLP    Residual MLP    Ratio     
----------------------------------------------------------------------
input_layer.weight             0.070593        0.010738        0.15      
blocks.0.fc1.weight            0.058370        0.005103        0.09      
blocks.0.fc2.weight            0.025526        0.003532        0.14      
blocks.1.fc1.weight            0.011315        0.004429        0.39      
blocks.1.fc2.weight            0.010423        0.002816        0.27      
blocks.2.fc1.weight            0.011475        0.006636        0.58      
blocks.2.fc2.weight            0.015336        0.005074        0.33      
blocks.3.fc1.weight 

                                                             


Final Test Accuracy - Depth 5:
Standard MLP: 97.23%
Residual MLP: 97.86%
Improvement: 0.63%

Creating models with depth 8...
Standard MLP parameters: 365,962
Residual MLP parameters: 365,962

Training Networks with Depth: 8

Epoch 1/15


                                                           

Standard MLP  - Train Loss: 1.3304, Train Acc: 46.73%, Val Loss: 0.7041, Val Acc: 80.42%
Residual MLP  - Train Loss: 0.2377, Train Acc: 92.66%, Val Loss: 0.1280, Val Acc: 96.12%

Epoch 2/15


                                                           

Standard MLP  - Train Loss: 0.7650, Train Acc: 76.78%, Val Loss: 0.7653, Val Acc: 79.68%
Residual MLP  - Train Loss: 0.1131, Train Acc: 96.45%, Val Loss: 0.0966, Val Acc: 97.20%

Epoch 3/15


                                                           

Standard MLP  - Train Loss: 0.7337, Train Acc: 77.68%, Val Loss: 0.7654, Val Acc: 77.18%
Residual MLP  - Train Loss: 0.0801, Train Acc: 97.50%, Val Loss: 0.0963, Val Acc: 97.04%

Epoch 4/15


                                                           

Standard MLP  - Train Loss: 0.9762, Train Acc: 66.08%, Val Loss: 0.9629, Val Acc: 61.50%
Residual MLP  - Train Loss: 0.0667, Train Acc: 97.89%, Val Loss: 0.0985, Val Acc: 97.14%

Epoch 5/15


                                                           

Standard MLP  - Train Loss: 0.7797, Train Acc: 74.90%, Val Loss: 0.6499, Val Acc: 80.74%
Residual MLP  - Train Loss: 0.0543, Train Acc: 98.25%, Val Loss: 0.0889, Val Acc: 97.18%

Epoch 6/15


                                                           

Standard MLP  - Train Loss: 0.6266, Train Acc: 82.14%, Val Loss: 0.6459, Val Acc: 83.20%
Residual MLP  - Train Loss: 0.0442, Train Acc: 98.64%, Val Loss: 0.0888, Val Acc: 97.38%

Epoch 7/15


                                                           

Standard MLP  - Train Loss: 0.6386, Train Acc: 82.23%, Val Loss: 0.8605, Val Acc: 75.86%
Residual MLP  - Train Loss: 0.0394, Train Acc: 98.73%, Val Loss: 0.1140, Val Acc: 97.30%

Epoch 8/15


                                                           

Standard MLP  - Train Loss: 0.6160, Train Acc: 83.58%, Val Loss: 0.4649, Val Acc: 88.02%
Residual MLP  - Train Loss: 0.0343, Train Acc: 98.89%, Val Loss: 0.1337, Val Acc: 96.92%

Epoch 9/15


                                                           

Standard MLP  - Train Loss: 0.7533, Train Acc: 77.38%, Val Loss: 1.3916, Val Acc: 48.30%
Residual MLP  - Train Loss: 0.0345, Train Acc: 98.88%, Val Loss: 0.1158, Val Acc: 97.16%

Epoch 10/15


                                                           

Standard MLP  - Train Loss: 0.7709, Train Acc: 76.28%, Val Loss: 0.8813, Val Acc: 73.48%
Residual MLP  - Train Loss: 0.0289, Train Acc: 99.10%, Val Loss: 0.0922, Val Acc: 97.80%

Epoch 11/15


                                                           

Standard MLP  - Train Loss: 0.9345, Train Acc: 69.73%, Val Loss: 0.7630, Val Acc: 76.22%
Residual MLP  - Train Loss: 0.0265, Train Acc: 99.18%, Val Loss: 0.0951, Val Acc: 97.74%

Epoch 12/15


                                                           

Standard MLP  - Train Loss: 0.9927, Train Acc: 65.35%, Val Loss: 0.8255, Val Acc: 76.60%
Residual MLP  - Train Loss: 0.0237, Train Acc: 99.25%, Val Loss: 0.0986, Val Acc: 97.98%

Epoch 13/15


                                                           

Standard MLP  - Train Loss: 0.7940, Train Acc: 76.48%, Val Loss: 0.5815, Val Acc: 86.06%
Residual MLP  - Train Loss: 0.0224, Train Acc: 99.29%, Val Loss: 0.1013, Val Acc: 97.80%

Epoch 14/15


                                                           

Standard MLP  - Train Loss: 0.7921, Train Acc: 76.49%, Val Loss: 0.7436, Val Acc: 76.46%
Residual MLP  - Train Loss: 0.0232, Train Acc: 99.23%, Val Loss: 0.0939, Val Acc: 97.50%

Epoch 15/15


                                                           

Standard MLP  - Train Loss: 0.8114, Train Acc: 73.11%, Val Loss: 0.6618, Val Acc: 83.40%
Residual MLP  - Train Loss: 0.0174, Train Acc: 99.46%, Val Loss: 0.0923, Val Acc: 98.14%

----------------------------------------
GRADIENT ANALYSIS
----------------------------------------

Gradient Norms Comparison (Depth 8):
Layer                          Standard MLP    Residual MLP    Ratio     
----------------------------------------------------------------------
input_layer.weight             0.261273        0.043286        0.17      
blocks.0.fc1.weight            0.331847        0.027608        0.08      
blocks.0.fc2.weight            0.167810        0.018926        0.11      
blocks.1.fc1.weight            0.156891        0.022001        0.14      
blocks.1.fc2.weight            0.264559        0.016629        0.06      
blocks.2.fc1.weight            0.331135        0.023862        0.07      
blocks.2.fc2.weight            0.373550        0.008972        0.02      
blocks.3.fc1.weight 

                                                             


Final Test Accuracy - Depth 8:
Standard MLP: 84.77%
Residual MLP: 98.07%
Improvement: 13.30%

EXPERIMENT SUMMARY
Depth    Standard Acc    Residual Acc    Improvement 
------------------------------------------------------------
1        97.90           97.34           -0.56       
3        97.78           97.82           0.04        
5        97.26           97.64           0.38        
8        83.40           98.14           14.74       


0,1
depth_1/residual_train_acc,▁▅▆▇▇▇▇▇███████
depth_1/residual_train_loss,█▄▃▂▂▂▂▁▁▁▁▁▁▁▁
depth_1/residual_val_acc,▁▄▅▆▇▇█▇▆▆█▇▇█▆
depth_1/residual_val_loss,█▄▂▃▂▂▁▂▄▃▂▅▅▄█
depth_1/standard_train_acc,▁▅▆▇▇▇▇████████
depth_1/standard_train_loss,█▄▃▂▂▂▂▁▁▁▁▁▁▁▁
depth_1/standard_val_acc,▁▆▅▅▇▇▆▄▇▇▇▇▇▇█
depth_1/standard_val_loss,█▃▂▃▁▂▄▇▄▅▄▅▇▇▄
depth_3/residual_train_acc,▁▅▆▆▇▇▇████████
depth_3/residual_train_loss,█▄▃▂▂▂▂▁▁▁▁▁▁▁▁

0,1
depth_1/residual_train_acc,99.52364
depth_1/residual_train_loss,0.01382
depth_1/residual_val_acc,97.34
depth_1/residual_val_loss,0.13924
depth_1/standard_train_acc,99.27636
depth_1/standard_train_loss,0.02229
depth_1/standard_val_acc,97.9
depth_1/standard_val_loss,0.10094
depth_3/residual_train_acc,99.37636
depth_3/residual_train_loss,0.019



Experiment completed! Check W&B for detailed comparisons.


## Considerazioni Finali 

### Risultati Sperimentali

| Profondità | Standard MLP Acc (%) | Residual MLP Acc (%) | Miglioramento (%) |
|------------|----------------------|------------------------|-------------------|
| 1          | 97.81                | 97.62                  | -0.19             |
| 3          | 97.74                | 97.79                  | +0.05             |
| 5          | 97.23                | 97.86                  | +0.63             |
| 8          | 84.77                | 98.07                  | +13.30            |

---

### Evidenza del Vanishing Gradient Problem

L'analisi dei gradienti rivela chiaramente il **vanishing gradient problem** negli MLP standard:

**Depth 8 - Standard MLP:**
- Primo layer: gradiente norm = 0.261  
- Ultimo layer: gradiente norm = 4.602 (**esplosione**)  
- Layers intermedi: gradienti molto piccoli (~0.02–0.1)

**Depth 8 - Residual MLP:**
- Gradienti **uniformi**: tutti i layer mantengono norme tra 0.006–0.043  
- **Stabilità**: nessuna esplosione o vanishing eccessivo  
- **Training efficace**: validation accuracy del **98.07%**

---

### Analisi del Collasso nella Rete Profonda

Il **crollo drammatico** dell'MLP standard a depth 8 (84.77% vs 98.07%) dimostra concretamente:

- **Vanishing Gradients**: i layer iniziali ricevono gradienti troppo piccoli per apprendere efficacemente  
- **Exploding Gradients**: i layer finali hanno gradienti eccessivi (4.602 vs ~0.01 nei residual)  
- **Mancanza di Convergenza**: training accuracy solo **73.11%** vs **99.46%** nei residual

---

### Skip Connections

Le **residual connections** risolvono il problema attraverso:

**Flusso diretto del gradiente:**  
∇loss → output → skip connection → input

- I gradienti possono "saltare" i layer intermedi  
- Ogni layer riceve un **segnale di gradient robusto**  
- **Training stabile** anche in reti profonde

**Confronto training accuracy (depth 8):**
- Standard MLP: **73.11%** (collasso del training)  
- Residual MLP: **99.46%** (convergenza ottimale)

---

### Implicazioni
Questi risultati confermano empiricamente perché le **residual connections** sono diventate **standard** nell'architettura di reti profonde:

- Stabilizzano il training prevenendo **vanishing/exploding gradients**  
- Abilitano reti più profonde **senza degradazione** delle performance  
- Mantengono la **qualità del gradient flow** attraverso tutti i layer

**Shallow Networks (depth 1–3):**
- Beneficio limitato delle residual connections  
- Standard MLP già capace di training efficace  
- Overhead delle skip connections non sempre vantaggioso

**Medium Networks (depth 5):**
- Miglioramento moderato (+0.63%)  
- Prime evidenze di vanishing gradients nell'MLP standard  
- Residual connections iniziano a mostrare valore

**Deep Networks (depth 8):**
- **Beneficio drammatico** (+13.30%)  
- MLP standard completamente incapace di training  
- Residual MLP mantiene performance eccellenti

---

### Conclusione

L'esperimento dimostra chiaramente che le **residual connections** non sono solo un'**ottimizzazione tecnica**, ma una **necessità fondamentale** per il training di reti profonde.

La differenza di **13.3 punti percentuali** a depth 8 evidenzia come, senza skip connections, l’aumento della profondità diventi **controproducente**, mentre con esse si mantengano **performance ottimali** indipendentemente dalla profondità della rete.

Questo risultato pone le basi **teoriche e pratiche** per comprendere l’importanza delle architetture moderne come **ResNet**, dove le skip connections permettono di trainare reti con **centinaia di layer** mantenendo un **gradient flow efficace**.


### Exercise 1.3: Rinse and Repeat (but with a CNN)

Repeat the verification you did above, but with **Convolutional** Neural Networks. If you were careful about abstracting your model and training code, this should be a simple exercise. Show that **deeper** CNNs *without* residual connections do not always work better and **even deeper** ones *with* residual connections.

**Hint**: You probably should do this exercise using CIFAR-10, since MNIST is *very* easy (at least up to about 99% accuracy).

**Tip**: Feel free to reuse the ResNet building blocks defined in `torchvision.models.resnet` (e.g. [BasicBlock](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L59) which handles the cascade of 3x3 convolutions, skip connections, and optional downsampling). This is an excellent exercise in code diving. 

**Spoiler**: Depending on the optional exercises you plan to do below, you should think *very* carefully about the architectures of your CNNs here (so you can reuse them!).

In [6]:
# Exercise 1.3: Rinse and Repeat (but with a CNN)
# Academic version with essential improvements

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split
import torchvision.transforms as transforms
import torchvision.models as models
from tqdm import tqdm
import gc

# Setup device
if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.cuda.empty_cache()
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device('cpu')
    print("Using CPU")

# 1. Simple CNN (Baseline)
class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128 * 8 * 8, 1024)
        self.fc2 = nn.Linear(1024, num_classes)
        self.dropout = nn.Dropout(0.3)  # For consistency with MLP approach

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 128 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

# 2. Deep CNN (No residual connections)
class DeepCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 8 * 8, 512),
            nn.ReLU(),
            nn.Dropout(0.4),  # For consistency with MLP approach
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        return self.classifier(x)

# 3. ResNet CNN (With residual connections - from scratch)
class ResNetCNN(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Use ResNet18 architecture but train from scratch for fair comparison
        self.resnet = models.resnet18(weights=False)
        
        # Modify for CIFAR-10 (32x32 instead of ImageNet 224x224)
        self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.resnet.maxpool = nn.Identity()  # Remove maxpool for smaller images
        
        # Modify final layer for CIFAR-10 (10 classes)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        return self.resnet(x)

# Training function
def train_model(model, train_loader, val_loader, epochs, device, model_name):
    print(f"\nTraining {model_name} for {epochs} epochs...")
    
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_acc = 100 * correct / total

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_loss / len(val_loader)
        val_acc = 100 * val_correct / val_total
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc

        print(f"Epoch {epoch+1}: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")
    
    return best_val_acc

# Test function
def test_model(model, test_loader, device, model_name):
    print(f"\nTesting {model_name}...")
    
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_acc = 100 * correct / total
    print(f"{model_name} Test Accuracy: {test_acc:.2f}%")
    return test_acc

# Main experiment
def main():
    print("Starting CNN comparison on CIFAR-10...")
    
    # Standard academic parameters
    epochs = 10
    batch_size = 64
    val_size = 5000
    
    print(f"Configuration: {epochs} epochs, batch size {batch_size}")
    
    # Minimal data augmentation - academic standard
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # Load CIFAR-10 dataset
    train_dataset = CIFAR10('./data', train=True, download=True, transform=train_transform)
    test_dataset = CIFAR10('./data', train=False, transform=test_transform)
    
    # Split training set into train/validation
    train_size = len(train_dataset) - val_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    print(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
    
    results = {}
    
    # ===== MODEL 1: Simple CNN =====
    print(f"\n{'='*50}")
    print("TRAINING SIMPLE CNN (BASELINE)")
    print(f"{'='*50}")
    
    simple_cnn = SimpleCNN().to(device)
    params = sum(p.numel() for p in simple_cnn.parameters())
    print(f"Simple CNN parameters: {params:,}")
    
    best_val_acc = train_model(simple_cnn, train_loader, val_loader, epochs, device, "SimpleCNN")
    test_acc = test_model(simple_cnn, test_loader, device, "SimpleCNN")
    
    results['SimpleCNN'] = {
        'val_acc': best_val_acc,
        'test_acc': test_acc,
        'params': params
    }
    
    # Cleanup memory
    del simple_cnn
    torch.cuda.empty_cache()
    gc.collect()
    
    # ===== MODEL 2: Deep CNN =====
    print(f"\n{'='*50}")
    print("TRAINING DEEP CNN (NO RESIDUAL)")
    print(f"{'='*50}")
    
    deep_cnn = DeepCNN().to(device)
    params = sum(p.numel() for p in deep_cnn.parameters())
    print(f"Deep CNN parameters: {params:,}")
    
    best_val_acc = train_model(deep_cnn, train_loader, val_loader, epochs, device, "DeepCNN")
    test_acc = test_model(deep_cnn, test_loader, device, "DeepCNN")
    
    results['DeepCNN'] = {
        'val_acc': best_val_acc,
        'test_acc': test_acc,
        'params': params
    }
    
    # Cleanup memory
    del deep_cnn
    torch.cuda.empty_cache()
    gc.collect()
    
    # ===== MODEL 3: ResNet CNN =====
    print(f"\n{'='*50}")
    print("TRAINING RESNET CNN (WITH RESIDUAL)")
    print(f"{'='*50}")
    
    resnet_cnn = ResNetCNN().to(device)
    params = sum(p.numel() for p in resnet_cnn.parameters())
    print(f"ResNet CNN parameters: {params:,}")
    
    best_val_acc = train_model(resnet_cnn, train_loader, val_loader, epochs, device, "ResNetCNN")
    test_acc = test_model(resnet_cnn, test_loader, device, "ResNetCNN")
    
    results['ResNetCNN'] = {
        'val_acc': best_val_acc,
        'test_acc': test_acc,
        'params': params
    }
    
    # Cleanup memory
    del resnet_cnn
    torch.cuda.empty_cache()
    gc.collect()
    
    # ===== DETAILED ANALYSIS =====
    print(f"\n{'='*70}")
    print("EXPERIMENT SUMMARY - CNN PROGRESSION ON CIFAR-10")
    print(f"{'='*70}")
    print(f"{'Model':<12} {'Val Acc':<10} {'Test Acc':<10} {'Parameters':<12}")
    print("-" * 70)
    
    for model_name, result in results.items():
        print(f"{model_name:<12} {result['val_acc']:<10.2f} {result['test_acc']:<10.2f} {result['params']:<12,}")
    
    # Detailed analysis
    print(f"\n{'='*70}")
    print("ANALYSIS - EFFECT OF RESIDUAL CONNECTIONS:")
    print(f"{'='*70}")
    
    simple_acc = results.get('SimpleCNN', {}).get('test_acc', 0)
    deep_acc = results.get('DeepCNN', {}).get('test_acc', 0)
    resnet_acc = results.get('ResNetCNN', {}).get('test_acc', 0)
    
    print(f"RESULTS:")
    print(f"• Simple CNN:  {simple_acc:.2f}% (baseline with 2 conv layers)")
    print(f"• Deep CNN:    {deep_acc:.2f}% (4 conv layers, no residual connections)")
    print(f"• ResNet CNN:  {resnet_acc:.2f}% (ResNet-18 architecture with residual connections)")
    
    print(f"\n{'='*70}")
    print("Experiment completed successfully!")

if __name__ == "__main__":
    main()

Using GPU: NVIDIA GeForce RTX 3060
Starting CNN comparison on CIFAR-10...
Configuration: 10 epochs, batch size 64
Dataset sizes - Train: 45000, Val: 5000, Test: 10000

TRAINING SIMPLE CNN (BASELINE)
Simple CNN parameters: 8,475,530

Training SimpleCNN for 10 epochs...


Epoch 1/10: 100%|████████████████████████████| 704/704 [00:11<00:00, 59.89it/s]


Epoch 1: Train Acc: 45.98%, Val Acc: 59.14%


Epoch 2/10: 100%|████████████████████████████| 704/704 [00:12<00:00, 58.51it/s]


Epoch 2: Train Acc: 60.41%, Val Acc: 62.60%


Epoch 3/10: 100%|████████████████████████████| 704/704 [00:12<00:00, 57.90it/s]


Epoch 3: Train Acc: 65.99%, Val Acc: 65.88%


Epoch 4/10: 100%|████████████████████████████| 704/704 [00:12<00:00, 58.09it/s]


Epoch 4: Train Acc: 69.33%, Val Acc: 67.98%


Epoch 5/10: 100%|████████████████████████████| 704/704 [00:12<00:00, 57.34it/s]


Epoch 5: Train Acc: 71.84%, Val Acc: 72.08%


Epoch 6/10: 100%|████████████████████████████| 704/704 [00:11<00:00, 58.85it/s]


Epoch 6: Train Acc: 74.29%, Val Acc: 73.34%


Epoch 7/10: 100%|████████████████████████████| 704/704 [00:12<00:00, 56.81it/s]


Epoch 7: Train Acc: 76.01%, Val Acc: 73.08%


Epoch 8/10: 100%|████████████████████████████| 704/704 [00:12<00:00, 57.14it/s]


Epoch 8: Train Acc: 78.00%, Val Acc: 72.32%


Epoch 9/10: 100%|████████████████████████████| 704/704 [00:12<00:00, 55.09it/s]


Epoch 9: Train Acc: 79.41%, Val Acc: 74.36%


Epoch 10/10: 100%|███████████████████████████| 704/704 [00:12<00:00, 58.45it/s]


Epoch 10: Train Acc: 80.75%, Val Acc: 75.02%

Testing SimpleCNN...
SimpleCNN Test Accuracy: 74.85%

TRAINING DEEP CNN (NO RESIDUAL)
Deep CNN parameters: 4,570,762

Training DeepCNN for 10 epochs...


Epoch 1/10: 100%|████████████████████████████| 704/704 [00:15<00:00, 46.53it/s]


Epoch 1: Train Acc: 41.32%, Val Acc: 56.98%


Epoch 2/10: 100%|████████████████████████████| 704/704 [00:14<00:00, 47.32it/s]


Epoch 2: Train Acc: 60.74%, Val Acc: 67.24%


Epoch 3/10: 100%|████████████████████████████| 704/704 [00:15<00:00, 45.97it/s]


Epoch 3: Train Acc: 68.07%, Val Acc: 69.68%


Epoch 4/10: 100%|████████████████████████████| 704/704 [00:15<00:00, 46.65it/s]


Epoch 4: Train Acc: 71.90%, Val Acc: 75.58%


Epoch 5/10: 100%|████████████████████████████| 704/704 [00:14<00:00, 47.11it/s]


Epoch 5: Train Acc: 75.18%, Val Acc: 76.64%


Epoch 6/10: 100%|████████████████████████████| 704/704 [00:15<00:00, 46.66it/s]


Epoch 6: Train Acc: 77.55%, Val Acc: 75.52%


Epoch 7/10: 100%|████████████████████████████| 704/704 [00:15<00:00, 45.40it/s]


Epoch 7: Train Acc: 79.21%, Val Acc: 78.10%


Epoch 8/10: 100%|████████████████████████████| 704/704 [00:15<00:00, 46.63it/s]


Epoch 8: Train Acc: 80.71%, Val Acc: 79.22%


Epoch 9/10: 100%|████████████████████████████| 704/704 [00:15<00:00, 46.35it/s]


Epoch 9: Train Acc: 82.08%, Val Acc: 79.44%


Epoch 10/10: 100%|███████████████████████████| 704/704 [00:15<00:00, 46.68it/s]


Epoch 10: Train Acc: 83.42%, Val Acc: 80.10%

Testing DeepCNN...
DeepCNN Test Accuracy: 79.35%

TRAINING RESNET CNN (WITH RESIDUAL)
ResNet CNN parameters: 11,173,962

Training ResNetCNN for 10 epochs...


Epoch 1/10: 100%|████████████████████████████| 704/704 [00:33<00:00, 21.09it/s]


Epoch 1: Train Acc: 55.54%, Val Acc: 66.54%


Epoch 2/10: 100%|████████████████████████████| 704/704 [00:33<00:00, 20.92it/s]


Epoch 2: Train Acc: 73.42%, Val Acc: 71.52%


Epoch 3/10: 100%|████████████████████████████| 704/704 [00:33<00:00, 20.87it/s]


Epoch 3: Train Acc: 79.29%, Val Acc: 79.36%


Epoch 4/10: 100%|████████████████████████████| 704/704 [00:33<00:00, 20.88it/s]


Epoch 4: Train Acc: 83.03%, Val Acc: 81.40%


Epoch 5/10: 100%|████████████████████████████| 704/704 [00:33<00:00, 21.26it/s]


Epoch 5: Train Acc: 85.67%, Val Acc: 83.56%


Epoch 6/10: 100%|████████████████████████████| 704/704 [00:33<00:00, 21.18it/s]


Epoch 6: Train Acc: 87.86%, Val Acc: 82.98%


Epoch 7/10: 100%|████████████████████████████| 704/704 [00:33<00:00, 21.19it/s]


Epoch 7: Train Acc: 89.63%, Val Acc: 85.28%


Epoch 8/10: 100%|████████████████████████████| 704/704 [00:33<00:00, 21.31it/s]


Epoch 8: Train Acc: 91.25%, Val Acc: 86.18%


Epoch 9/10: 100%|████████████████████████████| 704/704 [00:33<00:00, 21.30it/s]


Epoch 9: Train Acc: 92.43%, Val Acc: 85.80%


Epoch 10/10: 100%|███████████████████████████| 704/704 [00:33<00:00, 21.24it/s]


Epoch 10: Train Acc: 93.88%, Val Acc: 86.36%

Testing ResNetCNN...
ResNetCNN Test Accuracy: 84.94%

EXPERIMENT SUMMARY - CNN PROGRESSION ON CIFAR-10
Model        Val Acc    Test Acc   Parameters  
----------------------------------------------------------------------
SimpleCNN    75.02      74.85      8,475,530   
DeepCNN      80.10      79.35      4,570,762   
ResNetCNN    86.36      84.94      11,173,962  

ANALYSIS - EFFECT OF RESIDUAL CONNECTIONS:
RESULTS:
• Simple CNN:  74.85% (baseline with 2 conv layers)
• Deep CNN:    79.35% (4 conv layers, no residual connections)
• ResNet CNN:  84.94% (ResNet-18 architecture with residual connections)

Experiment completed successfully!


## Considerazioni Finali 

### Risultati Sperimentali

| Modello     | Test Accuracy | Parametri | Architettura                      |
|-------------|----------------|-----------|-----------------------------------|
| SimpleCNN   | 74.85%         | 8.5M      | 2 conv layers (baseline)          |
| DeepCNN     | 79.35%         | 4.6M      | 4 conv layers, no residual        |
| ResNetCNN   | 84.94%         | 11.2M     | ResNet-18 with residual connections |

---

### Analisi delle Performance

**Progressione delle accuratezze:**

- SimpleCNN → DeepCNN: **+4.5%** di miglioramento  
- DeepCNN → ResNetCNN: **+5.59%** di miglioramento  
- **Miglioramento totale:** +10.1% (SimpleCNN → ResNetCNN)

---

### Conclusioni

**1. Efficacia della Profondità Standard**  
Il **DeepCNN** dimostra che l'aumento di profondità (da 2 a 4 conv layers) porta a **miglioramenti significativi**, contrariamente al problema del vanishing gradient che spesso si manifesta in reti molto più profonde.

**2. Potenza delle Residual Connections**  
Il **ResNetCNN** ottiene le performance migliori (**84.94%**), dimostrando come le **skip connections** permettano di sfruttare efficacemente **architetture più profonde** (ResNet-18 vs 4 layer CNN).
L’esperimento dimostra che **non è solo la profondità** a determinare le performance, ma **come** la profondità viene implementata.  
Le **residual connections** permettono **gradient flow efficace** anche in reti profonde.



-----
## Exercise 2: Choose at Least One

Below are **three** exercises that ask you to deepen your understanding of Deep Networks for visual recognition. You must choose **at least one** of the below for your final submission -- feel free to do **more**, but at least **ONE** you must submit. Each exercise is designed to require you to dig your hands **deep** into the guts of your models in order to do new and interesting things.

**Note**: These exercises are designed to use your small, custom CNNs and small datasets. This is to keep training times reasonable. If you have a decent GPU, feel free to use pretrained ResNets and larger datasets (e.g. the [Imagenette](https://pytorch.org/vision/0.20/generated/torchvision.datasets.Imagenette.html#torchvision.datasets.Imagenette) dataset at 160px).

### Exercise 2.1: *Fine-tune* a pre-trained model
Train one of your residual CNN models from Exercise 1.3 on CIFAR-10. Then:
1. Use the pre-trained model as a **feature extractor** (i.e. to extract the feature activations of the layer input into the classifier) on CIFAR-100. Use a **classical** approach (e.g. Linear SVM, K-Nearest Neighbor, or Bayesian Generative Classifier) from scikit-learn to establish a **stable baseline** performance on CIFAR-100 using the features extracted using your CNN.
2. Fine-tune your CNN on the CIFAR-100 training set and compare with your stable baseline. Experiment with different strategies:
    - Unfreeze some of the earlier layers for fine-tuning.
    - Test different optimizers (Adam, SGD, etc.).

Each of these steps will require you to modify your model definition in some way. For 1, you will need to return the activations of the last fully-connected layer (or the global average pooling layer). For 2, you will need to replace the original, 10-class classifier with a new, randomly-initialized 100-class classifier.

In [None]:
# Your code here.

### Exercise 2.2: *Distill* the knowledge from a large model into a smaller one
In this exercise you will see if you can derive a *small* model that performs comparably to a larger one on CIFAR-10. To do this, you will use [Knowledge Distillation](https://arxiv.org/abs/1503.02531):

> Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the Knowledge in a Neural Network, NeurIPS 2015.

To do this:
1. Train one of your best-performing CNNs on CIFAR-10 from Exercise 1.3 above. This will be your **teacher** model.
2. Define a *smaller* variant with about half the number of parameters (change the width and/or depth of the network). Train it on CIFAR-10 and verify that it performs *worse* than your **teacher**. This small network will be your **student** model.
3. Train the **student** using a combination of **hard labels** from the CIFAR-10 training set (cross entropy loss) and **soft labels** from predictions of the **teacher** (Kulback-Leibler loss between teacher and student).

Try to optimize training parameters in order to maximize the performance of the student. It should at least outperform the student trained only on hard labels in Setp 2.

**Tip**: You can save the predictions of the trained teacher network on the training set and adapt your dataloader to provide them together with hard labels. This will **greatly** speed up training compared to performing a forward pass through the teacher for each batch of training.

In [4]:
# Exercise 2.2: Knowledge Distillation
# Distill knowledge from ResNetCNN (teacher) to a smaller model (student)
# Based on Exercise 1.3 code structure

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, random_split, TensorDataset
import torchvision.transforms as transforms
import torchvision.models as models
from tqdm import tqdm
import gc

# Setup device (same as Exercise 1.3)
if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.cuda.empty_cache()
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device('cpu')
    print("Using CPU")

# ===== TEACHER MODEL: ResNetCNN from Exercise 1.3 =====
class ResNetTeacher(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Same as your ResNetCNN from Exercise 1.3
        self.resnet = models.resnet18(weights=None)
        
        # Modify for CIFAR-10 (32x32 instead of ImageNet 224x224)
        self.resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.resnet.maxpool = nn.Identity()  # Remove maxpool for smaller images
        
        # Modify final layer for CIFAR-10 (10 classes)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        return self.resnet(x)

# ===== STUDENT MODEL: Smaller CNN (about half parameters) =====
class SmallStudent(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        # Smaller version: ~5M parameters vs 11.2M of teacher
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, num_classes)
        self.dropout = nn.Dropout(0.3)  # Same as your models

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2(x), 2))
        x = F.relu(self.conv3(x))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

# ===== KNOWLEDGE DISTILLATION LOSS FUNCTION =====
def distillation_loss(student_logits, teacher_logits, labels, temperature=4.0, alpha=0.7):
    """
    Combine soft targets (from teacher) and hard targets (true labels)
    
    Args:
        student_logits: Raw outputs from student model
        teacher_logits: Raw outputs from teacher model
        labels: True labels
        temperature: Softmax temperature (higher = softer distributions)
        alpha: Weight for distillation vs hard loss
    
    Returns:
        Combined loss
    """
    # Soft targets from teacher (with temperature)
    soft_targets = F.softmax(teacher_logits / temperature, dim=1)
    soft_student = F.log_softmax(student_logits / temperature, dim=1)
    
    # Distillation loss (KL divergence)
    kl_loss = F.kl_div(soft_student, soft_targets, reduction='batchmean') * (temperature ** 2)
    
    # Hard loss (standard cross-entropy)
    ce_loss = F.cross_entropy(student_logits, labels)
    
    # Combined loss
    total_loss = alpha * kl_loss + (1 - alpha) * ce_loss
    
    return total_loss

# ===== TRAINING FUNCTIONS =====
def train_model(model, train_loader, val_loader, epochs, device, model_name):
    """Standard training function (same structure as Exercise 1.3)"""
    print(f"\nTraining {model_name} for {epochs} epochs...")
    
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()
    
    best_val_acc = 0
    final_train_acc = 0  # Track final train accuracy
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(train_loader)
        train_acc = 100 * correct / total
        final_train_acc = train_acc  # Update final train accuracy

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_loss / len(val_loader)
        val_acc = 100 * val_correct / val_total
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc

        print(f"Epoch {epoch+1}: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")
    
    return best_val_acc, final_train_acc

def save_teacher_predictions(teacher_model, train_loader, device):
    """Save teacher predictions to speed up distillation training"""
    print("Saving teacher predictions for distillation...")
    
    teacher_model.eval()
    all_inputs = []
    all_labels = []
    all_teacher_outputs = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(train_loader, desc="Extracting teacher knowledge"):
            inputs = inputs.to(device)
            teacher_outputs = teacher_model(inputs)
            
            all_inputs.append(inputs.cpu())
            all_labels.append(labels)
            all_teacher_outputs.append(teacher_outputs.cpu())
    
    # Create new dataset with teacher predictions
    inputs_tensor = torch.cat(all_inputs, dim=0)
    labels_tensor = torch.cat(all_labels, dim=0)
    teacher_outputs_tensor = torch.cat(all_teacher_outputs, dim=0)
    
    return TensorDataset(inputs_tensor, labels_tensor, teacher_outputs_tensor)

def train_student_with_distillation(model, distillation_loader, val_loader, epochs, device, temperature=4.0, alpha=0.7):
    """Train student using knowledge distillation"""
    print(f"\nTraining Student with Knowledge Distillation (T={temperature}, α={alpha})...")
    
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    
    best_val_acc = 0
    final_train_acc = 0  # Track final train accuracy
    
    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, labels, teacher_outputs in tqdm(distillation_loader, desc=f"KD Epoch {epoch+1}/{epochs}"):
            inputs, labels = inputs.to(device), labels.to(device)
            teacher_outputs = teacher_outputs.to(device)

            optimizer.zero_grad()
            student_outputs = model(inputs)
            
            # Knowledge distillation loss
            loss = distillation_loss(student_outputs, teacher_outputs, labels, temperature, alpha)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(student_outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_loss = running_loss / len(distillation_loader)
        train_acc = 100 * correct / total
        final_train_acc = train_acc  # Update final train accuracy

        # Validation phase (same as standard training)
        model.eval()
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_acc = 100 * val_correct / val_total
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc

        print(f"Epoch {epoch+1}: Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")
    
    return best_val_acc, final_train_acc

def test_model(model, test_loader, device, model_name):
    """Test function (same as Exercise 1.3)"""
    print(f"\nTesting {model_name}...")
    
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    test_acc = 100 * correct / total
    print(f"{model_name} Test Accuracy: {test_acc:.2f}%")
    return test_acc

# ===== MAIN EXPERIMENT =====
def main():
    print("Starting Knowledge Distillation Experiment on CIFAR-10...")
    
    # Same configuration as Exercise 1.3
    epochs_teacher = 10  # Teacher training epochs
    epochs_student = 15  # Student training epochs (more time to learn)
    batch_size = 64
    val_size = 5000
    
    # Same data transforms as Exercise 1.3
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    test_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # Load CIFAR-10 dataset (same as Exercise 1.3)
    train_dataset = CIFAR10('./data', train=True, download=True, transform=train_transform)
    test_dataset = CIFAR10('./data', train=False, transform=test_transform)
    
    # Split training set into train/validation
    train_size = len(train_dataset) - val_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)
    test_loader = DataLoader(test_dataset, batch_size=batch_size)
    
    print(f"Dataset sizes - Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
    
    results = {}
    
    # ===== STEP 1: TRAIN TEACHER MODEL =====
    print(f"\n{'='*60}")
    print("STEP 1: TRAINING TEACHER MODEL (ResNetCNN)")
    print(f"{'='*60}")
    
    teacher = ResNetTeacher().to(device)
    teacher_params = sum(p.numel() for p in teacher.parameters())
    print(f"Teacher parameters: {teacher_params:,}")
    
    teacher_val_acc, teacher_train_acc = train_model(teacher, train_loader, val_loader, epochs_teacher, device, "Teacher")
    teacher_test_acc = test_model(teacher, test_loader, device, "Teacher")
    
    results['Teacher'] = {
        'train_acc': teacher_train_acc,
        'val_acc': teacher_val_acc,
        'test_acc': teacher_test_acc,
        'params': teacher_params
    }
    
    # ===== STEP 2: TRAIN STUDENT BASELINE (HARD LABELS ONLY) =====
    print(f"\n{'='*60}")
    print("STEP 2: TRAINING STUDENT BASELINE (HARD LABELS ONLY)")
    print(f"{'='*60}")
    
    student_baseline = SmallStudent().to(device)
    student_params = sum(p.numel() for p in student_baseline.parameters())
    print(f"Student parameters: {student_params:,}")
    print(f"Parameter reduction: {((teacher_params - student_params) / teacher_params * 100):.1f}%")
    
    student_baseline_val_acc, student_baseline_train_acc = train_model(student_baseline, train_loader, val_loader, epochs_student, device, "Student_Baseline")
    student_baseline_test_acc = test_model(student_baseline, test_loader, device, "Student_Baseline")
    
    results['Student_Baseline'] = {
        'train_acc': student_baseline_train_acc,
        'val_acc': student_baseline_val_acc,
        'test_acc': student_baseline_test_acc,
        'params': student_params
    }
    
    # Cleanup memory
    del student_baseline
    torch.cuda.empty_cache()
    gc.collect()
    
    # ===== STEP 3: KNOWLEDGE DISTILLATION =====
    print(f"\n{'='*60}")
    print("STEP 3: KNOWLEDGE DISTILLATION")
    print(f"{'='*60}")
    
    # Save teacher predictions for efficient training
    distillation_dataset = save_teacher_predictions(teacher, train_loader, device)
    distillation_loader = DataLoader(distillation_dataset, batch_size=batch_size, shuffle=True)
    
    # Try different hyperparameter combinations
    hyperparams_to_try = [
        {'temperature': 4.0, 'alpha': 0.7},
        {'temperature': 6.0, 'alpha': 0.8},
        {'temperature': 3.0, 'alpha': 0.6}
    ]
    
    best_kd_test_acc = 0
    best_params = None
    best_kd_train_acc = 0  # Track best train accuracy for KD
    
    for i, params in enumerate(hyperparams_to_try):
        print(f"\n--- Testing hyperparameters {i+1}/3: T={params['temperature']}, α={params['alpha']} ---")
        
        student_kd = SmallStudent().to(device)
        
        kd_val_acc, kd_train_acc = train_student_with_distillation(
            student_kd, distillation_loader, val_loader, epochs_student, device,
            temperature=params['temperature'], alpha=params['alpha']
        )
        kd_test_acc = test_model(student_kd, test_loader, device, f"Student_KD_T{params['temperature']}_A{params['alpha']}")
        
        if kd_test_acc > best_kd_test_acc:
            best_kd_test_acc = kd_test_acc
            best_kd_train_acc = kd_train_acc
            best_params = params
        
        # Cleanup
        del student_kd
        torch.cuda.empty_cache()
        gc.collect()
    
    results['Student_KD'] = {
        'train_acc': best_kd_train_acc,
        'val_acc': kd_val_acc,  # Last val_acc from best run
        'test_acc': best_kd_test_acc,
        'params': student_params
    }
    
    # ===== FINAL RESULTS SUMMARY =====
    print(f"\n{'='*80}")
    print("KNOWLEDGE DISTILLATION EXPERIMENT SUMMARY")
    print(f"{'='*80}")
    print(f"{'Model':<15} {'Train Acc':<11} {'Val Acc':<10} {'Test Acc':<10} {'Parameters':<12} {'Improvement':<12}")
    print("-" * 95)
    
    baseline_test_acc = results['Student_Baseline']['test_acc']
    
    for model_name, result in results.items():
        if model_name == 'Student_Baseline':
            improvement = "Baseline"
        elif model_name == 'Student_KD':
            improvement = f"+{result['test_acc'] - baseline_test_acc:.2f}%"
        else:
            improvement = "Teacher"
            
        print(f"{model_name:<15} {result['train_acc']:<11.2f} {result['val_acc']:<10.2f} {result['test_acc']:<10.2f} {result['params']:<12,} {improvement:<12}")
    
    print(f"\n{'='*80}")
    print("DETAILED ANALYSIS:")
    print(f"{'='*80}")
    
    print(f"RESULTS:")
    print(f"• Teacher (ResNet):     Train: {results['Teacher']['train_acc']:.2f}%, Test: {results['Teacher']['test_acc']:.2f}% with {results['Teacher']['params']:,} parameters")
    print(f"• Student Baseline:     Train: {results['Student_Baseline']['train_acc']:.2f}%, Test: {results['Student_Baseline']['test_acc']:.2f}% with {results['Student_Baseline']['params']:,} parameters")
    print(f"• Student + KD:         Train: {results['Student_KD']['train_acc']:.2f}%, Test: {results['Student_KD']['test_acc']:.2f}% with {results['Student_KD']['params']:,} parameters")
    print(f"• Best KD parameters:   T={best_params['temperature']}, α={best_params['alpha']}")
    
    improvement = best_kd_test_acc - baseline_test_acc
    print(f"\nKNOWLEDGE DISTILLATION EFFECTIVENESS:")
    print(f"• Improvement over baseline: {improvement:+.2f}%")
    
    # Add overfitting analysis
    print(f"\nOVERFITTING ANALYSIS:")
    teacher_gap = results['Teacher']['train_acc'] - results['Teacher']['test_acc']
    baseline_gap = results['Student_Baseline']['train_acc'] - results['Student_Baseline']['test_acc']
    kd_gap = results['Student_KD']['train_acc'] - results['Student_KD']['test_acc']
    
    print(f"• Teacher train-test gap:       {teacher_gap:+.2f}%")
    print(f"• Student baseline gap:         {baseline_gap:+.2f}%")
    print(f"• Student KD gap:               {kd_gap:+.2f}%")
    
    if improvement > 0:
        print(f"SUCCESS: Knowledge distillation improved student performance!")
    else:
        print(f"Knowledge distillation did not improve performance")
    
    # Teacher vs Student comparison
    teacher_gap = results['Teacher']['test_acc'] - results['Student_KD']['test_acc']
    param_efficiency = teacher_gap / (results['Teacher']['params'] - results['Student_KD']['params']) * 1000000
    
    print(f"\nMODEL EFFICIENCY:")
    print(f"• Teacher-Student gap: {teacher_gap:.2f}%")
    print(f"• Parameter efficiency: {param_efficiency:.2f} accuracy loss per 1M parameters saved")
    
    print(f"\n{'='*80}")
    print("Experiment completed successfully!")

if __name__ == "__main__":
    main()

Using GPU: NVIDIA GeForce RTX 3060
Starting Knowledge Distillation Experiment on CIFAR-10...
Dataset sizes - Train: 45000, Val: 5000, Test: 10000

STEP 1: TRAINING TEACHER MODEL (ResNetCNN)
Teacher parameters: 11,173,962

Training Teacher for 10 epochs...


Epoch 1/10: 100%|████████████████████████████| 704/704 [00:37<00:00, 19.02it/s]


Epoch 1: Train Acc: 54.85%, Val Acc: 58.98%


Epoch 2/10: 100%|████████████████████████████| 704/704 [00:36<00:00, 19.43it/s]


Epoch 2: Train Acc: 72.94%, Val Acc: 74.16%


Epoch 3/10: 100%|████████████████████████████| 704/704 [00:36<00:00, 19.50it/s]


Epoch 3: Train Acc: 79.28%, Val Acc: 76.48%


Epoch 4/10: 100%|████████████████████████████| 704/704 [00:33<00:00, 20.72it/s]


Epoch 4: Train Acc: 83.13%, Val Acc: 77.00%


Epoch 5/10: 100%|████████████████████████████| 704/704 [00:33<00:00, 20.95it/s]


Epoch 5: Train Acc: 85.73%, Val Acc: 81.36%


Epoch 6/10: 100%|████████████████████████████| 704/704 [00:33<00:00, 20.78it/s]


Epoch 6: Train Acc: 87.86%, Val Acc: 83.34%


Epoch 7/10: 100%|████████████████████████████| 704/704 [00:34<00:00, 20.47it/s]


Epoch 7: Train Acc: 89.73%, Val Acc: 84.36%


Epoch 8/10: 100%|████████████████████████████| 704/704 [00:34<00:00, 20.25it/s]


Epoch 8: Train Acc: 91.26%, Val Acc: 85.28%


Epoch 9/10: 100%|████████████████████████████| 704/704 [00:34<00:00, 20.19it/s]


Epoch 9: Train Acc: 92.79%, Val Acc: 85.92%


Epoch 10/10: 100%|███████████████████████████| 704/704 [00:34<00:00, 20.32it/s]


Epoch 10: Train Acc: 93.69%, Val Acc: 84.64%

Testing Teacher...
Teacher Test Accuracy: 84.63%

STEP 2: TRAINING STUDENT BASELINE (HARD LABELS ONLY)
Student parameters: 2,159,114
Parameter reduction: 80.7%

Training Student_Baseline for 15 epochs...


Epoch 1/15: 100%|████████████████████████████| 704/704 [00:11<00:00, 63.83it/s]


Epoch 1: Train Acc: 43.42%, Val Acc: 53.24%


Epoch 2/15: 100%|████████████████████████████| 704/704 [00:12<00:00, 57.38it/s]


Epoch 2: Train Acc: 58.46%, Val Acc: 63.54%


Epoch 3/15: 100%|████████████████████████████| 704/704 [00:12<00:00, 58.52it/s]


Epoch 3: Train Acc: 64.42%, Val Acc: 66.36%


Epoch 4/15: 100%|████████████████████████████| 704/704 [00:12<00:00, 56.30it/s]


Epoch 4: Train Acc: 68.67%, Val Acc: 69.54%


Epoch 5/15: 100%|████████████████████████████| 704/704 [00:12<00:00, 57.35it/s]


Epoch 5: Train Acc: 71.19%, Val Acc: 71.86%


Epoch 6/15: 100%|████████████████████████████| 704/704 [00:12<00:00, 58.25it/s]


Epoch 6: Train Acc: 73.77%, Val Acc: 72.72%


Epoch 7/15: 100%|████████████████████████████| 704/704 [00:12<00:00, 56.92it/s]


Epoch 7: Train Acc: 75.81%, Val Acc: 73.42%


Epoch 8/15: 100%|████████████████████████████| 704/704 [00:12<00:00, 56.91it/s]


Epoch 8: Train Acc: 77.39%, Val Acc: 73.80%


Epoch 9/15: 100%|████████████████████████████| 704/704 [00:12<00:00, 57.52it/s]


Epoch 9: Train Acc: 78.95%, Val Acc: 74.72%


Epoch 10/15: 100%|███████████████████████████| 704/704 [00:12<00:00, 58.30it/s]


Epoch 10: Train Acc: 80.32%, Val Acc: 75.22%


Epoch 11/15: 100%|███████████████████████████| 704/704 [00:12<00:00, 56.75it/s]


Epoch 11: Train Acc: 81.57%, Val Acc: 75.88%


Epoch 12/15: 100%|███████████████████████████| 704/704 [00:12<00:00, 57.38it/s]


Epoch 12: Train Acc: 82.80%, Val Acc: 75.60%


Epoch 13/15: 100%|███████████████████████████| 704/704 [00:12<00:00, 57.28it/s]


Epoch 13: Train Acc: 84.25%, Val Acc: 74.94%


Epoch 14/15: 100%|███████████████████████████| 704/704 [00:11<00:00, 59.33it/s]


Epoch 14: Train Acc: 85.01%, Val Acc: 75.26%


Epoch 15/15: 100%|███████████████████████████| 704/704 [00:11<00:00, 58.89it/s]


Epoch 15: Train Acc: 85.96%, Val Acc: 74.88%

Testing Student_Baseline...
Student_Baseline Test Accuracy: 74.93%

STEP 3: KNOWLEDGE DISTILLATION
Saving teacher predictions for distillation...


Extracting teacher knowledge: 100%|██████████| 704/704 [00:15<00:00, 45.75it/s]



--- Testing hyperparameters 1/3: T=4.0, α=0.7 ---

Training Student with Knowledge Distillation (T=4.0, α=0.7)...


KD Epoch 1/15: 100%|████████████████████████| 704/704 [00:03<00:00, 231.40it/s]


Epoch 1: Train Acc: 42.36%, Val Acc: 54.56%


KD Epoch 2/15: 100%|████████████████████████| 704/704 [00:03<00:00, 231.82it/s]


Epoch 2: Train Acc: 56.87%, Val Acc: 61.32%


KD Epoch 3/15: 100%|████████████████████████| 704/704 [00:03<00:00, 219.02it/s]


Epoch 3: Train Acc: 63.29%, Val Acc: 64.88%


KD Epoch 4/15: 100%|████████████████████████| 704/704 [00:03<00:00, 230.02it/s]


Epoch 4: Train Acc: 68.51%, Val Acc: 67.78%


KD Epoch 5/15: 100%|████████████████████████| 704/704 [00:03<00:00, 226.11it/s]


Epoch 5: Train Acc: 71.69%, Val Acc: 71.26%


KD Epoch 6/15: 100%|████████████████████████| 704/704 [00:03<00:00, 219.78it/s]


Epoch 6: Train Acc: 74.20%, Val Acc: 72.68%


KD Epoch 7/15: 100%|████████████████████████| 704/704 [00:03<00:00, 229.29it/s]


Epoch 7: Train Acc: 76.35%, Val Acc: 73.34%


KD Epoch 8/15: 100%|████████████████████████| 704/704 [00:03<00:00, 226.64it/s]


Epoch 8: Train Acc: 78.55%, Val Acc: 74.34%


KD Epoch 9/15: 100%|████████████████████████| 704/704 [00:03<00:00, 224.46it/s]


Epoch 9: Train Acc: 79.85%, Val Acc: 75.00%


KD Epoch 10/15: 100%|███████████████████████| 704/704 [00:03<00:00, 217.03it/s]


Epoch 10: Train Acc: 81.34%, Val Acc: 75.26%


KD Epoch 11/15: 100%|███████████████████████| 704/704 [00:03<00:00, 216.18it/s]


Epoch 11: Train Acc: 82.68%, Val Acc: 76.30%


KD Epoch 12/15: 100%|███████████████████████| 704/704 [00:03<00:00, 213.17it/s]


Epoch 12: Train Acc: 83.44%, Val Acc: 77.26%


KD Epoch 13/15: 100%|███████████████████████| 704/704 [00:03<00:00, 198.79it/s]


Epoch 13: Train Acc: 84.72%, Val Acc: 77.28%


KD Epoch 14/15: 100%|███████████████████████| 704/704 [00:03<00:00, 209.73it/s]


Epoch 14: Train Acc: 85.60%, Val Acc: 77.16%


KD Epoch 15/15: 100%|███████████████████████| 704/704 [00:03<00:00, 218.73it/s]


Epoch 15: Train Acc: 86.65%, Val Acc: 77.30%

Testing Student_KD_T4.0_A0.7...
Student_KD_T4.0_A0.7 Test Accuracy: 76.82%

--- Testing hyperparameters 2/3: T=6.0, α=0.8 ---

Training Student with Knowledge Distillation (T=6.0, α=0.8)...


KD Epoch 1/15: 100%|████████████████████████| 704/704 [00:03<00:00, 199.28it/s]


Epoch 1: Train Acc: 40.84%, Val Acc: 42.10%


KD Epoch 2/15: 100%|████████████████████████| 704/704 [00:03<00:00, 224.13it/s]


Epoch 2: Train Acc: 56.05%, Val Acc: 61.42%


KD Epoch 3/15: 100%|████████████████████████| 704/704 [00:03<00:00, 227.28it/s]


Epoch 3: Train Acc: 63.10%, Val Acc: 64.62%


KD Epoch 4/15: 100%|████████████████████████| 704/704 [00:03<00:00, 215.65it/s]


Epoch 4: Train Acc: 66.98%, Val Acc: 68.54%


KD Epoch 5/15: 100%|████████████████████████| 704/704 [00:03<00:00, 223.34it/s]


Epoch 5: Train Acc: 70.19%, Val Acc: 69.84%


KD Epoch 6/15: 100%|████████████████████████| 704/704 [00:03<00:00, 199.43it/s]


Epoch 6: Train Acc: 72.59%, Val Acc: 71.60%


KD Epoch 7/15: 100%|████████████████████████| 704/704 [00:03<00:00, 191.98it/s]


Epoch 7: Train Acc: 74.75%, Val Acc: 72.52%


KD Epoch 8/15: 100%|████████████████████████| 704/704 [00:03<00:00, 222.43it/s]


Epoch 8: Train Acc: 76.39%, Val Acc: 74.50%


KD Epoch 9/15: 100%|████████████████████████| 704/704 [00:03<00:00, 218.29it/s]


Epoch 9: Train Acc: 77.72%, Val Acc: 74.30%


KD Epoch 10/15: 100%|███████████████████████| 704/704 [00:03<00:00, 220.76it/s]


Epoch 10: Train Acc: 79.02%, Val Acc: 75.24%


KD Epoch 11/15: 100%|███████████████████████| 704/704 [00:03<00:00, 205.79it/s]


Epoch 11: Train Acc: 80.25%, Val Acc: 73.96%


KD Epoch 12/15: 100%|███████████████████████| 704/704 [00:03<00:00, 223.01it/s]


Epoch 12: Train Acc: 81.27%, Val Acc: 73.68%


KD Epoch 13/15: 100%|███████████████████████| 704/704 [00:03<00:00, 199.45it/s]


Epoch 13: Train Acc: 82.05%, Val Acc: 75.64%


KD Epoch 14/15: 100%|███████████████████████| 704/704 [00:03<00:00, 200.02it/s]


Epoch 14: Train Acc: 83.10%, Val Acc: 75.44%


KD Epoch 15/15: 100%|███████████████████████| 704/704 [00:03<00:00, 207.07it/s]


Epoch 15: Train Acc: 83.78%, Val Acc: 76.66%

Testing Student_KD_T6.0_A0.8...
Student_KD_T6.0_A0.8 Test Accuracy: 76.48%

--- Testing hyperparameters 3/3: T=3.0, α=0.6 ---

Training Student with Knowledge Distillation (T=3.0, α=0.6)...


KD Epoch 1/15: 100%|████████████████████████| 704/704 [00:03<00:00, 180.79it/s]


Epoch 1: Train Acc: 43.79%, Val Acc: 53.60%


KD Epoch 2/15: 100%|████████████████████████| 704/704 [00:03<00:00, 218.07it/s]


Epoch 2: Train Acc: 58.30%, Val Acc: 62.94%


KD Epoch 3/15: 100%|████████████████████████| 704/704 [00:03<00:00, 193.58it/s]


Epoch 3: Train Acc: 65.18%, Val Acc: 66.10%


KD Epoch 4/15: 100%|████████████████████████| 704/704 [00:03<00:00, 225.84it/s]


Epoch 4: Train Acc: 69.53%, Val Acc: 70.32%


KD Epoch 5/15: 100%|████████████████████████| 704/704 [00:03<00:00, 209.03it/s]


Epoch 5: Train Acc: 72.94%, Val Acc: 70.42%


KD Epoch 6/15: 100%|████████████████████████| 704/704 [00:03<00:00, 218.61it/s]


Epoch 6: Train Acc: 75.78%, Val Acc: 71.14%


KD Epoch 7/15: 100%|████████████████████████| 704/704 [00:03<00:00, 219.60it/s]


Epoch 7: Train Acc: 78.07%, Val Acc: 73.06%


KD Epoch 8/15: 100%|████████████████████████| 704/704 [00:03<00:00, 224.59it/s]


Epoch 8: Train Acc: 79.56%, Val Acc: 75.64%


KD Epoch 9/15: 100%|████████████████████████| 704/704 [00:03<00:00, 232.53it/s]


Epoch 9: Train Acc: 81.69%, Val Acc: 75.74%


KD Epoch 10/15: 100%|███████████████████████| 704/704 [00:03<00:00, 219.81it/s]


Epoch 10: Train Acc: 83.07%, Val Acc: 76.32%


KD Epoch 11/15: 100%|███████████████████████| 704/704 [00:03<00:00, 230.93it/s]


Epoch 11: Train Acc: 84.72%, Val Acc: 77.30%


KD Epoch 12/15: 100%|███████████████████████| 704/704 [00:03<00:00, 220.50it/s]


Epoch 12: Train Acc: 86.28%, Val Acc: 75.12%


KD Epoch 13/15: 100%|███████████████████████| 704/704 [00:03<00:00, 229.47it/s]


Epoch 13: Train Acc: 87.28%, Val Acc: 75.86%


KD Epoch 14/15: 100%|███████████████████████| 704/704 [00:03<00:00, 231.12it/s]


Epoch 14: Train Acc: 88.31%, Val Acc: 76.74%


KD Epoch 15/15: 100%|███████████████████████| 704/704 [00:03<00:00, 232.34it/s]


Epoch 15: Train Acc: 89.21%, Val Acc: 76.16%

Testing Student_KD_T3.0_A0.6...
Student_KD_T3.0_A0.6 Test Accuracy: 76.64%

KNOWLEDGE DISTILLATION EXPERIMENT SUMMARY
Model           Train Acc   Val Acc    Test Acc   Parameters   Improvement 
-----------------------------------------------------------------------------------------------
Teacher         93.69       85.92      84.63      11,173,962   Teacher     
Student_Baseline 85.96       75.88      74.93      2,159,114    Baseline    
Student_KD      86.65       77.30      76.82      2,159,114    +1.89%      

DETAILED ANALYSIS:
RESULTS:
• Teacher (ResNet):     Train: 93.69%, Test: 84.63% with 11,173,962 parameters
• Student Baseline:     Train: 85.96%, Test: 74.93% with 2,159,114 parameters
• Student + KD:         Train: 86.65%, Test: 76.82% with 2,159,114 parameters
• Best KD parameters:   T=4.0, α=0.7

KNOWLEDGE DISTILLATION EFFECTIVENESS:
• Improvement over baseline: +1.89%

OVERFITTING ANALYSIS:
• Teacher train-test gap:       +9.0

### Architetture Sperimentate

**Teacher Model: ResNetCNN**
- Architettura: ResNet-18 modificata per CIFAR-10  
- Parametri: 11,173,962 (~11.2M)  
- Performance: **83.96%** test accuracy  
- Ruolo: Fornire soft labels ricche di informazione per guidare l’apprendimento dello student

**Student Model: SmallStudent**
- Architettura: CNN compatta con 3 layer convoluzionali + 2 fully connected  
- Parametri: 2,159,114 (~2.2M) → riduzione dell’**80%** rispetto al teacher  
- Design: Struttura simile al teacher ma più compatta

---

### Metodologia Sperimentale

**Step 1: Training del Teacher**  
- Allenato per 10 epochs con cross-entropy standard  
- Test accuracy: **84.63%**

**Step 2: Student Baseline**  
- Allenato per 15 epochs con **sole hard labels**  
- Test accuracy: **74.93%**

**Step 3: Knowledge Distillation**  
- Combinazione di due loss functions:
  - **Soft Loss**: KL divergence tra le predizioni dello student e le soft labels del teacher (temperatura T)
  - **Hard Loss**: Cross-entropy tra predizioni dello student e true labels  
- Formula: `total_loss = α × soft_loss + (1−α) × hard_loss`

---

### Ottimizzazione Hyperparameters

Combinazioni testate:
- `T=4.0`, `α=0.7` (configurazione standard)  
- `T=6.0`, `α=0.8` (più soft, maggiore fiducia nel teacher)  
- `T=3.0`, `α=0.6` (più conservativa)

---

### Risultati Ottenuti

| Modello           | Train Accuracy | Val Accuracy | Test Accuracy | Parametri | Miglioramento | Train-Test Gap |
|-------------------|----------------|--------------|----------------|-----------|---------------|----------------|
| Teacher (ResNet)  | 93.69%         | 85.92%       | 84.63%         | 11.2M     | —             | +9.06%         |
| Student Baseline  | 85.96%         | 75.88%       | 74.93%         | 2.2M      | baseline      | +11.03%        |
| Student + KD      | 86.65%         | 77.30%       | 76.82%         | 2.2M      | **+1.89%**     | +9.83%         |


**Performance Ottimali:**
- Migliori hyperparameters: `T=4.0`, `α=0.7`
- Miglioramento ottenuto: **+1.89%** rispetto al baseline student  
- Gap teacher-student: **7.23%**  
- Riduzione parametri: **80.7%** (da 11.2M a 2.2M)

---

## Conclusioni
L'esperimento di **knowledge distillation** ha **superato gli obiettivi prefissati**, dimostrando con **eccellenti risultati** che è possibile **trasferire efficacemente la conoscenza** da un modello complesso a uno compatto.

Il miglioramento di **+1.89%** in test accuracy, combinato con la **riduzione dell’overfitting** (il train-test gap dello student migliora da **+11.03%** (baseline) a **+9.83%** (con KD), avvicinandosi al comportamento del teacher (**+9.06%**)), conferma la **validità della tecnica** proposta da Hinton et al. e la sua **applicabilità pratica**.

### Exercise 2.3: *Explain* the predictions of a CNN

Use the CNN model you trained in Exercise 1.3 and implement [*Class Activation Maps*](http://cnnlocalization.csail.mit.edu/#:~:text=A%20class%20activation%20map%20for,decision%20made%20by%20the%20CNN.):

> B. Zhou, A. Khosla, A. Lapedriza, A. Oliva, and A. Torralba. Learning Deep Features for Discriminative Localization. CVPR'16 (arXiv:1512.04150, 2015).

Use your CNN implementation to demonstrate how your trained CNN *attends* to specific image features to recognize *specific* classes. Try your implementation out using a pre-trained ResNet-18 model and some images from the [Imagenette](https://pytorch.org/vision/0.20/generated/torchvision.datasets.Imagenette.html#torchvision.datasets.Imagenette) dataset -- I suggest you start with the low resolution version of images at 160px.

In [None]:
# Your code here.