# Deep Learning Applications: Laboratory #1

In this first laboratory we will work with relatively simple architectures to get a feel for working with Deep Models. This notebook is designed to work with PyTorch, but as I said in the introductory lecture: please feel free to use and experiment with whatever tools you like.

**Important Notes**:
1. Be sure to **document** all of your decisions, as well as your intermediate and final results. Make sure your conclusions and analyses are clearly presented. Don't make us dig into your code or walls of printed results to try to draw conclusions from your code.
2. If you use code from someone else (e.g. GitHub, Stack Overflow, ChatGPT, etc.) you **must be transparent about it**. Document your sources and explain how you adapted any partial solutions to create **your** solution.



## Exercise 1: Warming Up
In this series of exercises I want you to try to duplicate (on a small scale) the results of the ResNet paper:

> [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385), Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun, CVPR 2016.

We will do this in steps using a Multilayer Perceptron on MNIST.

Recall that the main message of the ResNet paper is that **deeper** networks do not **guarantee** a lower training loss (or a higher validation accuracy). Below you will incrementally build a sequence of experiments to verify this for an MLP. A few guidelines:

+ I have provided some **starter** code at the beginning. **NONE** of this code should survive in your solutions. Not only is it **very** badly written, it is also written in my functional style that also obfuscates what it's doing (in part to **discourage** your reuse!). It's just to get you *started*.
+ These exercises ask you to compare **multiple** training runs, so it is **really** important that you factor this into your **pipeline**. Using [Tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) is a **very** good idea -- or, even better [Weights and Biases](https://wandb.ai/site).
+ You may work and submit your solutions in **groups of at most two**. Share your ideas with everyone, but the solutions you submit *must be your own*.
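
As a rough illustration of what "factoring multiple runs into your pipeline" can mean, here is a minimal sketch that keeps one record per run and compares them afterwards. `RunLog` and its methods are hypothetical names invented for this example (they are not part of TensorBoard or Weights and Biases), and the metric values are made-up placeholders.

```python
import json
from dataclasses import dataclass, field


@dataclass
class RunLog:
    """Hypothetical per-run record: a name, its hyperparameters, and per-epoch metrics."""
    name: str
    config: dict
    history: list = field(default_factory=list)

    def log(self, epoch, **metrics):
        # Store one row per epoch so that runs can be compared later.
        self.history.append({"epoch": epoch, **metrics})

    def best(self, key, mode="max"):
        # Return the logged row with the best value of `key`.
        pick = max if mode == "max" else min
        return pick(self.history, key=lambda row: row[key])

    def to_json(self):
        # Serialize the whole run so it can be saved next to the checkpoints.
        return json.dumps({"name": self.name, "config": self.config, "history": self.history})


# Two runs with placeholder numbers, only to show the bookkeeping.
runs = [RunLog("mlp-depth2", {"depth": 2}), RunLog("mlp-depth8", {"depth": 8})]
runs[0].log(0, val_acc=0.90)
runs[0].log(1, val_acc=0.93)
runs[1].log(0, val_acc=0.88)
runs[1].log(1, val_acc=0.91)
summary = {run.name: run.best("val_acc")["val_acc"] for run in runs}
```

A real tracker gives you the same idea (one named run, one config, a stream of metrics) plus plots and persistence, which is why it pays off once the number of runs grows.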

First some boilerplate to get you started, then on to the actual exercises!

### Preface: Some code to get you started

What follows is some **very simple** code for training an MLP on MNIST. The point of this code is to get you up and running (and to verify that your Python environment has all needed dependencies).

**Note**: As you read through my code and execute it, this would be a good time to think about *abstracting* **your** model definition, and training and evaluation pipelines in order to make it easier to compare performance of different models.

In [1]:
# Start with some standard imports.
import numpy as np
import matplotlib.pyplot as plt
from functools import reduce
import torch
from torchvision.datasets import MNIST
from torch.utils.data import Subset
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

#### Data preparation

Here is some basic dataset loading, validation splitting code to get you started working with MNIST.

In [None]:
# Standard MNIST transform.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load MNIST train and test.
ds_train = MNIST(root='./data', train=True, download=True, transform=transform)
ds_test = MNIST(root='./data', train=False, download=True, transform=transform)

# Split train into train and validation.
val_size = 5000
I = np.random.permutation(len(ds_train))
ds_val = Subset(ds_train, I[:val_size])
ds_train = Subset(ds_train, I[val_size:])
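
Note that the split above draws a fresh, unseeded permutation, so the validation set changes every time the cell is re-run. If you want runs to be comparable, one option is a seeded split. The sketch below assumes a random-tensor stand-in for MNIST (so it runs without downloading anything); `split_train_val` is a hypothetical helper name.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset with MNIST-like shapes (random tensors, not real data).
dummy = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))


def split_train_val(dataset, val_size, seed=0):
    # A seeded generator makes the split identical across runs.
    gen = torch.Generator().manual_seed(seed)
    return random_split(dataset, [len(dataset) - val_size, val_size], generator=gen)


train_a, val_a = split_train_val(dummy, val_size=20)
train_b, val_b = split_train_val(dummy, val_size=20)
```

Calling the helper twice with the same seed returns the same indices, so every model you compare sees the same validation samples.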

#### Boilerplate training and evaluation code

This is some **very** rough training, evaluation, and plotting code. Again, just to get you started. I will be *very* disappointed if any of this code makes it into your final submission.

In [None]:
from tqdm import tqdm
from sklearn.metrics import accuracy_score, classification_report

# Function to train a model for a single epoch over the data loader.
def train_epoch(model, dl, opt, epoch='Unknown', device='cpu'):
    model.train()
    losses = []
    for (xs, ys) in tqdm(dl, desc=f'Training epoch {epoch}', leave=True):
        xs = xs.to(device)
        ys = ys.to(device)
        opt.zero_grad()
        logits = model(xs)
        loss = F.cross_entropy(logits, ys)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return np.mean(losses)

# Function to evaluate model over all samples in the data loader.
def evaluate_model(model, dl, device='cpu'):
    model.eval()
    predictions = []
    gts = []
    for (xs, ys) in tqdm(dl, desc='Evaluating', leave=False):
        xs = xs.to(device)
        preds = torch.argmax(model(xs), dim=1)
        gts.append(ys)
        predictions.append(preds.detach().cpu().numpy())
        
    # Return accuracy score and classification report.
    return (accuracy_score(np.hstack(gts), np.hstack(predictions)),
            classification_report(np.hstack(gts), np.hstack(predictions), zero_division=0, digits=3))

# Simple function to plot the loss curve and validation accuracy.
def plot_validation_curves(losses_and_accs):
    losses = [x for (x, _) in losses_and_accs]
    accs = [x for (_, x) in losses_and_accs]
    plt.figure(figsize=(16, 8))
    plt.subplot(1, 2, 1)
    plt.plot(losses)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Average Training Loss per Epoch')
    plt.subplot(1, 2, 2)
    plt.plot(accs)
    plt.xlabel('Epoch')
    plt.ylabel('Validation Accuracy')
    plt.title(f'Best Accuracy = {np.max(accs)} @ epoch {np.argmax(accs)}')
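
The evaluation function above builds a computation graph it never uses and reports accuracy only. A possible alternative, sketched under the assumption that you want both loss and accuracy per epoch, is shown below; `evaluate` is a hypothetical helper and the data is random, purely for a smoke test.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate(model, loader, device="cpu"):
    # Hypothetical helper: returns (mean loss, accuracy) over the whole loader.
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    for xs, ys in loader:
        xs, ys = xs.to(device), ys.to(device)
        logits = model(xs)
        # Sum per-sample losses so the mean is exact even with a smaller final batch.
        total_loss += F.cross_entropy(logits, ys, reduction="sum").item()
        correct += (logits.argmax(dim=1) == ys).sum().item()
        total += ys.size(0)
    return total_loss / total, correct / total


# Tiny smoke test on random data (a stand-in for MNIST).
torch.manual_seed(0)
data = TensorDataset(torch.randn(50, 1, 28, 28), torch.randint(0, 10, (50,)))
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
val_loss, val_acc = evaluate(model, DataLoader(data, batch_size=16))
```

Because the function is decorated with `torch.no_grad()`, no gradients are tracked during evaluation, which saves memory and time.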

#### A basic, parameterized MLP

This is a very basic implementation of a Multilayer Perceptron. Don't waste too much time trying to figure out how it works -- the important detail is that it allows you to pass in a list of input, hidden layer, and output *widths*. **Your** implementation should also support this for the exercises to come.

In [None]:
class MLP(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(nin, nout) for (nin, nout) in zip(layer_sizes[:-1], layer_sizes[1:])])
    
    def forward(self, x):
        return reduce(lambda f, g: lambda x: g(F.relu(f(x))), self.layers, lambda x: x.flatten(1))(x)
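
If you do want to see what the `reduce` expression in `forward` computes, here is an unrolled reading of it, offered as a sketch rather than as recommended style. `forward_reduce` copies the starter expression; `forward_explicit` is a hypothetical rewrite. One thing the unrolled form makes visible: ReLU is applied to the *input* of every layer, including the flattened image itself, while the output of the last layer is left as raw logits.

```python
from functools import reduce

import torch
import torch.nn as nn
import torch.nn.functional as F


def forward_reduce(layers, x):
    # Same expression as in the starter MLP.forward.
    return reduce(lambda f, g: lambda x: g(F.relu(f(x))), layers, lambda x: x.flatten(1))(x)


def forward_explicit(layers, x):
    # Unrolled version: flatten, then for each layer apply ReLU to its input.
    h = x.flatten(1)
    for layer in layers:
        h = layer(F.relu(h))
    return h


torch.manual_seed(0)
sizes = [28 * 28, 16, 16, 10]
layers = nn.ModuleList([nn.Linear(a, b) for a, b in zip(sizes[:-1], sizes[1:])])
x = torch.randn(4, 1, 28, 28)
out_reduce = forward_reduce(layers, x)
out_explicit = forward_explicit(layers, x)
```

Both functions return the same logits, so the explicit loop is a drop-in, more readable starting point for your own implementation.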

#### A *very* minimal training pipeline.

Here is some basic training and evaluation code to get you started.

**Important**: I cannot stress enough that this is a **terrible** example of how to implement a training pipeline. You can do better!

In [None]:
# Training hyperparameters.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 100
lr = 0.0001
batch_size = 128

# Architecture hyperparameters.
input_size = 28*28
width = 16
depth = 2

# Dataloaders.
dl_train = torch.utils.data.DataLoader(ds_train, batch_size, shuffle=True, num_workers=4)
dl_val   = torch.utils.data.DataLoader(ds_val, batch_size, num_workers=4)
dl_test  = torch.utils.data.DataLoader(ds_test, batch_size, shuffle=True, num_workers=4)

# Instantiate model and optimizer.
model_mlp = MLP([input_size] + [width]*depth + [10]).to(device)
opt = torch.optim.Adam(params=model_mlp.parameters(), lr=lr)

# Training loop.
losses_and_accs = []
for epoch in range(epochs):
    loss = train_epoch(model_mlp, dl_train, opt, epoch, device=device)
    (val_acc, _) = evaluate_model(model_mlp, dl_val, device=device)
    losses_and_accs.append((loss, val_acc))

# And finally plot the curves.
plot_validation_curves(losses_and_accs)
print(f'Accuracy report on TEST:\n {evaluate_model(model_mlp, dl_test, device=device)[1]}')

### Exercise 1.1: A baseline MLP

Implement a *simple* Multilayer Perceptron to classify the 10 digits of MNIST (e.g. two *narrow* layers). Use my code above as inspiration, but implement your own training pipeline -- you will need it later. Train this model to convergence, monitoring (at least) the loss and accuracy on the training and validation sets for every epoch. Below I include a basic implementation to get you started -- remember that you should write your *own* pipeline!

**Note**: This would be a good time to think about *abstracting* your model definition, and training and evaluation pipelines in order to make it easier to compare performance of different models.

**Important**: Given the *many* runs you will need to do, and the need to *compare* performance between them, this would **also** be a great point to study how **Tensorboard** or **Weights and Biases** can be used for performance monitoring.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm

class MLP_Block(nn.Module):
    def __init__(self, in_size, out_size, activation="ReLU", dropout=0.0, batch_norm=False):
        super().__init__()
        self.activation = getattr(nn, activation)
        layers = [nn.Linear(in_size, out_size)]
        if batch_norm:
            layers.append(nn.BatchNorm1d(out_size))
        layers += [self.activation(), nn.Dropout(dropout)]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

class MLP(nn.Module):
    def __init__(self, input_size, layers_dim, class_num, hidden_layers_num, residual=False, 
                 activation="ReLU", dropout=0.0, batch_norm=False):
        super().__init__()
        self.residual = residual
        self.input_layer = MLP_Block(input_size, layers_dim, activation, dropout, batch_norm)
        self.hidden_layers = nn.ModuleList([
            MLP_Block(layers_dim, layers_dim, activation, dropout, batch_norm)
            for _ in range(hidden_layers_num)
        ])
        self.output_layer = nn.Linear(layers_dim, class_num)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = self.input_layer(x)
        if self.residual:
            for layer in self.hidden_layers:
                res = x
                x = layer(x) + res
        else:
            for layer in self.hidden_layers:
                x = layer(x)
        return self.output_layer(x)

class Runner:
    def __init__(self, model, logger: "Logger", criterion="CrossEntropyLoss", optimizer="Adam",
                 lr=0.001, device="cpu", scheduler=None, max_iter=100,
                 distill=False, T=1., alpha=1, teacher_model=None, topk=5):
        self.model = model.to(device)
        self.device = device
        self.alpha = alpha
        self.T = T
        self.distill = distill
        self.teacher_model = teacher_model
        self.topk = topk
        self.logger = logger

        self.criterion = getattr(nn, criterion)()
        self.optimizer = getattr(optim, optimizer)(model.parameters(), lr)
        self.scheduler = None
        if scheduler:
            self.scheduler = getattr(lr_scheduler, scheduler)(self.optimizer, max_iter)
        if distill and teacher_model:
            self.distillation_criterion = nn.KLDivLoss(reduction="batchmean")

    def distillation_loss(self, logits, labels, teacher_probs):
        ce_loss = self.criterion(logits, labels)
        log_student = F.log_softmax(logits / self.T, dim=1)
        soft_loss = self.distillation_criterion(log_student, teacher_probs) * (self.T * self.T)
        return self.alpha * soft_loss + (1 - self.alpha) * ce_loss

    def train(self, dl_train, dl_val, train_epochs_num):
        train_metrics = {'loss': [], 'acc1': [], 'acc5': []}
        val_metrics = {'loss': [], 'acc1': [], 'acc5': []}
        global_step = 0
        for epoch in range(1, train_epochs_num + 1):
            loss, acc1, acc5 = self.process_epoch(dl_train, train=True)
            train_metrics['loss'].append(loss)
            train_metrics['acc1'].append(acc1)
            train_metrics['acc5'].append(acc5)
            self.logger.log_metrics({'train_loss': loss, 'train_acc1': acc1, 'train_acc5': acc5}, global_step, epoch)

            loss, acc1, acc5 = self.process_epoch(dl_val, train=False)
            val_metrics['loss'].append(loss)
            val_metrics['acc1'].append(acc1)
            val_metrics['acc5'].append(acc5)
            self.logger.log_metrics({'val_loss': loss, 'val_acc1': acc1, 'val_acc5': acc5}, global_step, epoch)

            global_step += 1
            if self.scheduler:
                self.scheduler.step()
        return train_metrics, val_metrics

    def test(self, dl_test, test_epochs_num):
        """
        Detailed test: computes top1, top5 accuracies, average loss,
        and returns raw predictions, ground truths, and losses.
        """
        self.model.eval()
        gts = []
        predictions = []
        losses = []
        top5_correct = 0
        total_samples = 0

        for data, labels in tqdm(dl_test, desc="[Test/Validation]", leave=False):
            data, labels = data.to(self.device), labels.to(self.device)
            with torch.no_grad():
                logits = self.model(data)
                loss = self.criterion(logits, labels)
                prob = F.softmax(logits, dim=1)
                preds1 = torch.argmax(prob, dim=1)

            # top-5 accuracy
            top5 = torch.topk(prob, k=self.topk, dim=1).indices
            top5_correct += (top5 == labels.view(-1,1)).any(dim=1).sum().item()
            batch_size = labels.size(0)
            total_samples += batch_size

            gts.append(labels.cpu())
            predictions.append(preds1.cpu())
            losses.append(loss.item())

        # compute metrics
        top5_accuracy = top5_correct / total_samples
        top1_accuracy = (torch.cat(predictions) == torch.cat(gts)).float().mean().item()
        avg_loss = sum(losses) / len(losses)

        # log metrics
        self.logger.log_metrics({'loss': avg_loss,
                                 'acc1': top1_accuracy,
                                 'acc5': top5_accuracy},
                                 step=0, epoch=0, prefix='test_')

        return {
            'top1_accuracy': top1_accuracy,
            'top5_accuracy': top5_accuracy,
            'average_loss': avg_loss,
            'ground_truths': torch.cat(gts),
            'predictions': torch.cat(predictions),
            'losses': losses
        }

    def process_epoch(self, dataloader, train=True):
        mode = "Training" if train else "Evaluation"
        self.model.train() if train else self.model.eval()

        running_loss = 0.0
        correct1 = 0
        correctk = 0
        total = 0
        with torch.set_grad_enabled(train):
            for inputs, targets in tqdm(dataloader, desc=mode):
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                logits = self.model(inputs)

                if train and self.distill and self.teacher_model:
                    with torch.no_grad():
                        teacher_logits = self.teacher_model(inputs)
                        teacher_probs = F.softmax(teacher_logits / self.T, dim=1)
                    loss = self.distillation_loss(logits, targets, teacher_probs)
                else:
                    loss = self.criterion(logits, targets)

                if train:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                batch_size = targets.size(0)
                running_loss += loss.item() * batch_size
                preds1 = logits.argmax(dim=1)
                correct1 += (preds1 == targets).sum().item()
                topk_preds = logits.topk(self.topk, dim=1).indices
                correctk += (topk_preds == targets.view(-1,1)).any(dim=1).sum().item()
                total += batch_size

        return running_loss / total, correct1 / total, correctk / total

    def gradient_norm(self, dl):
        self.model.train()
        inputs, targets = next(iter(dl))
        inputs, targets = inputs.to(self.device), targets.to(self.device)
        logits = self.model(inputs)
        loss = self.criterion(logits, targets)
        self.optimizer.zero_grad()
        loss.backward()
        # Only inspect the gradients here: no optimizer step, so the weights stay unchanged.
        return [param.grad.norm(2).item() for _, param in self.model.named_parameters() if param.grad is not None]


### Exercise 1.2: Adding Residual Connections

Implement a variant of your parameterized MLP network to support **residual** connections. Your network should be defined as a composition of **residual MLP** blocks that have one or more linear layers and add a skip connection from the block input to the output of the final linear layer.
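
To make the block structure concrete, here is one possible reading of it as a minimal sketch. The class name `ResidualMLPBlock`, the choice of a constant width, and the final ReLU after the addition are assumptions made for this example; other placements of the activation are equally defensible.

```python
import torch
import torch.nn as nn


class ResidualMLPBlock(nn.Module):
    # Hypothetical block: `num_layers` linear layers of equal width plus a skip connection.
    def __init__(self, width, num_layers=2):
        super().__init__()
        layers = []
        for i in range(num_layers):
            layers.append(nn.Linear(width, width))
            if i < num_layers - 1:
                # Non-linearity between linear layers, but not after the last one.
                layers.append(nn.ReLU())
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        # Skip connection: add the block input to the output of the final linear layer.
        return torch.relu(self.body(x) + x)


block = ResidualMLPBlock(width=16, num_layers=2)
x = torch.randn(8, 16)
y = block(x)
```

Keeping the width constant inside the block means the addition needs no projection; if input and output widths differ, the skip path needs its own linear layer.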

**Compare** the performance (in training/validation loss and test accuracy) of your MLP and ResidualMLP for a range of depths. Verify that deeper networks **with** residual connections are easier to train than a network of the same depth **without** residual connections.

**For extra style points**: See if you can explain by analyzing the gradient magnitudes on a single training batch *why* this is the case. 
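
One way to approach the gradient analysis is sketched below: run a single forward/backward pass on one batch and record the gradient norm of each hidden layer's weights, once without and once with skip connections. `DeepMLP` and `hidden_grad_norms` are hypothetical names, the batch is random data standing in for MNIST, and the depth and width are arbitrary choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepMLP(nn.Module):
    # Hypothetical toy model: `depth` hidden layers of equal width, optional skips.
    def __init__(self, in_dim, width, depth, num_classes, residual):
        super().__init__()
        self.inp = nn.Linear(in_dim, width)
        self.hidden = nn.ModuleList([nn.Linear(width, width) for _ in range(depth)])
        self.out = nn.Linear(width, num_classes)
        self.residual = residual

    def forward(self, x):
        h = F.relu(self.inp(x.flatten(1)))
        for layer in self.hidden:
            z = F.relu(layer(h))
            h = h + z if self.residual else z
        return self.out(h)


def hidden_grad_norms(model, xs, ys):
    # One forward/backward pass on a single batch; no optimizer step is taken.
    model.zero_grad()
    F.cross_entropy(model(xs), ys).backward()
    return [layer.weight.grad.norm().item() for layer in model.hidden]


torch.manual_seed(0)
xs, ys = torch.randn(32, 1, 28, 28), torch.randint(0, 10, (32,))
plain = hidden_grad_norms(DeepMLP(784, 32, 20, 10, residual=False), xs, ys)
skip = hidden_grad_norms(DeepMLP(784, 32, 20, 10, residual=True), xs, ys)
```

Plotting `plain` and `skip` against layer index (on a log scale) is a natural next step; what the curves look like on your models and data is for you to find out and explain.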

In [None]:
# Your code here.

### Exercise 1.3: Rinse and Repeat (but with a CNN)

Repeat the verification you did above, but with **Convolutional** Neural Networks. If you were careful about abstracting your model and training code, this should be a simple exercise. Show that **deeper** CNNs *without* residual connections do not always work better, whereas **even deeper** ones *with* residual connections do.

**Hint**: You probably should do this exercise using CIFAR-10, since MNIST is *very* easy (at least up to about 99% accuracy).

**Tip**: Feel free to reuse the ResNet building blocks defined in `torchvision.models.resnet` (e.g. [BasicBlock](https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py#L59) which handles the cascade of 3x3 convolutions, skip connections, and optional downsampling). This is an excellent exercise in code diving. 

**Spoiler**: Depending on the optional exercises you plan to do below, you should think *very* carefully about the architectures of your CNNs here (so you can reuse them!).

In [1]:
import torch
import torch.nn as nn

class CNN_Block(nn.Module):
    expansion: int = 1

    def __init__(self, in_planes, out_planes, stride=1, downsample=None, dilation=1, norm_layer=None, residual=True):
        super().__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = getattr(nn, norm_layer)

        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
        self.bn1 = norm_layer(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1, padding=dilation, bias=False, dilation=dilation)
        self.bn2 = norm_layer(out_planes)
        self.downsample = downsample
        self.stride = stride
        self.residual = residual

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.residual:
            if self.downsample is not None:
                identity = self.downsample(x)
            out += identity

        out = self.relu(out)
        return out

class CNN(nn.Module):
    def __init__(
        self,
        block_type="basic",
        layers=[2, 2, 2, 2],
        num_classes=1000,
        residual=True
    ):
        super(CNN, self).__init__()

        self.inplanes = 64
        self.residual = residual

        # Initial convolution
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Block selection
        if block_type == "basic":
            block = CNN_Block
            
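        elif block_type == "bottleneck":
            # NOTE: Bottleneck is not defined in this cell; define a block that
            # accepts the `residual` argument before selecting this option.
            block = Bottleneck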
        
        else:
            raise ValueError("must be basic or bottleneck")

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None

        if stride != 1 or self.inplanes != planes * block.expansion:
            if self.residual:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                    nn.BatchNorm2d(planes * block.expansion),
                )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, residual=self.residual))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, residual=self.residual))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        
        return x

-----
## Exercise 2: Choose at Least One

Below are **three** exercises that ask you to deepen your understanding of Deep Networks for visual recognition. You must choose **at least one** of the below for your final submission -- feel free to do **more**, but at least **ONE** you must submit. Each exercise is designed to require you to dig your hands **deep** into the guts of your models in order to do new and interesting things.

**Note**: These exercises are designed to use your small, custom CNNs and small datasets. This is to keep training times reasonable. If you have a decent GPU, feel free to use pretrained ResNets and larger datasets (e.g. the [Imagenette](https://pytorch.org/vision/0.20/generated/torchvision.datasets.Imagenette.html#torchvision.datasets.Imagenette) dataset at 160px).

### Exercise 2.1: *Fine-tune* a pre-trained model
Train one of your residual CNN models from Exercise 1.3 on CIFAR-10. Then:
1. Use the pre-trained model as a **feature extractor** (i.e. to extract the feature activations of the layer input into the classifier) on CIFAR-100. Use a **classical** approach (e.g. Linear SVM, K-Nearest Neighbor, or Bayesian Generative Classifier) from scikit-learn to establish a **stable baseline** performance on CIFAR-100 using the features extracted using your CNN.
2. Fine-tune your CNN on the CIFAR-100 training set and compare with your stable baseline. Experiment with different strategies:
    - Unfreeze some of the earlier layers for fine-tuning.
    - Test different optimizers (Adam, SGD, etc.).

Each of these steps will require you to modify your model definition in some way. For 1, you will need to return the activations of the last fully-connected layer (or the global average pooling layer). For 2, you will need to replace the original, 10-class classifier with a new, randomly-initialized 100-class classifier.
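
The two modifications can be sketched as follows. `TinyCNN` is a hypothetical stand-in for your trained CIFAR-10 model (it is randomly initialized here), and the attribute names `features` and `fc` are assumptions about how your model is organized.

```python
import torch
import torch.nn as nn


class TinyCNN(nn.Module):
    # Hypothetical stand-in for a CIFAR-10 model from Exercise 1.3.
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
        )
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.fc(self.features(x))


model = TinyCNN(num_classes=10)  # assume this was trained on CIFAR-10
x = torch.randn(4, 3, 32, 32)    # random batch with CIFAR-like shape

# Step 1: use the network up to global average pooling as a feature extractor.
model.eval()
with torch.no_grad():
    feats = model.features(x)

# Step 2: freeze the backbone and attach a new, randomly-initialized 100-class head.
for p in model.features.parameters():
    p.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, 100)
logits = model(x)
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
```

The `feats` array (converted to NumPy) is what you would hand to a scikit-learn classifier for the baseline; unfreezing more layers later is a matter of setting `requires_grad = True` on the parameters you want to fine-tune.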

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
import wandb

class Logger:
    def __init__(self, project_name, run_name=None, config=None):
        """
        Wrapper over wandb for experiment logging.
        """
        self.run = wandb.init(project=project_name, name=run_name, config=config or {})

    def log_metrics(self, metrics: dict, step: int, epoch: int, prefix: str = ""):
        """
        Log a dictionary of metrics to wandb.
        metrics: dict of metric_name: value
        step: global step or batch index
        epoch: current epoch
        prefix: optional prefix for metric names
        """
        log_data = {f"{prefix}{k}": v for k, v in metrics.items()}
        log_data.update({'epoch': epoch, 'step': step})
        wandb.log(log_data)

### Exercise 2.2: *Distill* the knowledge from a large model into a smaller one
In this exercise you will see if you can derive a *small* model that performs comparably to a larger one on CIFAR-10. To do this, you will use [Knowledge Distillation](https://arxiv.org/abs/1503.02531):

> Geoffrey Hinton, Oriol Vinyals, and Jeff Dean. Distilling the Knowledge in a Neural Network, NeurIPS 2015.

To do this:
1. Train one of your best-performing CNNs on CIFAR-10 from Exercise 1.3 above. This will be your **teacher** model.
2. Define a *smaller* variant with about half the number of parameters (change the width and/or depth of the network). Train it on CIFAR-10 and verify that it performs *worse* than your **teacher**. This small network will be your **student** model.
3. Train the **student** using a combination of **hard labels** from the CIFAR-10 training set (cross entropy loss) and **soft labels** from predictions of the **teacher** (Kullback-Leibler loss between teacher and student).

Try to optimize training parameters in order to maximize the performance of the student. It should at least outperform the student trained only on hard labels in Step 2.
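
A common way to combine the two terms is a weighted sum, sketched below. The weighting scheme with `alpha`, the default values `T=4.0` and `alpha=0.5`, and the function name are assumptions for this example; the factor `T**2` follows the scaling suggested in the distillation paper so that the soft term keeps a comparable gradient magnitude as the temperature changes.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    # Hard-label term: ordinary cross entropy on the student's logits.
    hard = F.cross_entropy(student_logits, labels)
    # Soft-label term: KL(teacher || student) on temperature-softened distributions.
    # F.kl_div expects log-probabilities as input and probabilities as target.
    soft = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    )
    return alpha * (T * T) * soft + (1.0 - alpha) * hard


torch.manual_seed(0)
student = torch.randn(8, 10)
labels = torch.randint(0, 10, (8,))
loss_same = distillation_loss(student, student.clone(), labels)
loss_diff = distillation_loss(student, torch.randn(8, 10), labels)
```

When student and teacher logits coincide, the KL term vanishes and only the weighted cross entropy remains, which is a handy sanity check for your own implementation.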

**Tip**: You can save the predictions of the trained teacher network on the training set and adapt your dataloader to provide them together with hard labels. This will **greatly** speed up training compared to performing a forward pass through the teacher for each batch of training.
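
A minimal sketch of this caching idea, assuming a map-style dataset that returns `(input, label)` pairs: run the teacher once over the dataset in a fixed order, store its logits, and return them alongside each sample. `WithTeacherLogits` is a hypothetical wrapper name, and the teacher and data here are random stand-ins.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset


class WithTeacherLogits(Dataset):
    # Hypothetical wrapper: yields (input, hard label, cached teacher logits).
    def __init__(self, base, teacher, batch_size=256, device="cpu"):
        self.base = base
        teacher = teacher.to(device).eval()
        chunks = []
        with torch.no_grad():
            # shuffle=False keeps the cached rows aligned with dataset indices.
            for xs, _ in DataLoader(base, batch_size=batch_size, shuffle=False):
                chunks.append(teacher(xs.to(device)).cpu())
        self.teacher_logits = torch.cat(chunks)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x, y = self.base[idx]
        return x, y, self.teacher_logits[idx]


torch.manual_seed(0)
base = TensorDataset(torch.randn(20, 3, 32, 32), torch.randint(0, 10, (20,)))
teacher = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
cached = WithTeacherLogits(base, teacher, batch_size=8)
xs, ys, ts = next(iter(DataLoader(cached, batch_size=5, shuffle=True)))
```

One caveat: if the base dataset applies random augmentation, the cached logits were computed on a different view of each image than the one the student sees, so caching fits best with deterministic transforms.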

In [None]:
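import os

import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

class DatasetLoader: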
    """
    A class for loading and preparing MNIST and CIFAR10 datasets.
    """
    def __init__(self, dataset_name, batch_size=64, val_split=0.1, num_workers=4, data_dir='./data', 
                 augmentation=False, download=True):
        """
        Initialize DatasetLoader with the specified dataset.
        
        Args:
            dataset_name (str): Name of dataset ('mnist' or 'cifar10')
            batch_size (int): Batch size for dataloaders
            val_split (float): Proportion of training data to use for validation
            num_workers (int): Number of workers for data loading
            data_dir (str): Directory to store datasets
            augmentation (bool): Whether to use data augmentation
            download (bool): Whether to download the dataset if not available locally
        """
        self.dataset_name = dataset_name.lower()
        self.batch_size = batch_size
        self.val_split = val_split
        self.num_workers = num_workers
        self.data_dir = data_dir
        self.augmentation = augmentation
        self.download = download
        
        # Make sure dataset name is valid
        if self.dataset_name not in ['mnist', 'cifar10']:
            raise ValueError("Dataset must be 'mnist' or 'cifar10'")
        
        # Create data directory if it doesn't exist
        os.makedirs(data_dir, exist_ok=True)
        
        # Prepare transformations
        self._prepare_transforms()
        
        # Load datasets
        self._load_datasets()
        
        # Create dataloaders
        self._create_dataloaders()
        
    def _prepare_transforms(self):
        """Prepare transformations for the datasets."""
        if self.dataset_name == 'mnist':
            # MNIST transformations
            self.test_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])
            
            if self.augmentation:
                self.train_transform = transforms.Compose([
                    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.1307,), (0.3081,))
                ])
            else:
                self.train_transform = self.test_transform
                
        elif self.dataset_name == 'cifar10':
            # CIFAR10 transformations
            self.test_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
            ])
            
            if self.augmentation:
                self.train_transform = transforms.Compose([
                    transforms.RandomCrop(32, padding=4),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
                ])
            else:
                self.train_transform = self.test_transform
    
    def _load_datasets(self):
        """Load the specified dataset."""
        if self.dataset_name == 'mnist':
            # Load MNIST dataset
            self.full_train_dataset = datasets.MNIST(
                root=self.data_dir,
                train=True,
                download=self.download,
                transform=self.train_transform
            )
            
            self.test_dataset = datasets.MNIST(
                root=self.data_dir,
                train=False,
                download=self.download,
                transform=self.test_transform
            )
            
        elif self.dataset_name == 'cifar10':
            # Load CIFAR10 dataset
            self.full_train_dataset = datasets.CIFAR10(
                root=self.data_dir,
                train=True,
                download=self.download,
                transform=self.train_transform
            )
            
            self.test_dataset = datasets.CIFAR10(
                root=self.data_dir,
                train=False,
                download=self.download,
                transform=self.test_transform
            )
        
        # Split training set into training and validation
        val_size = int(len(self.full_train_dataset) * self.val_split)
        train_size = len(self.full_train_dataset) - val_size
        
        self.train_dataset, self.val_dataset = random_split(
            self.full_train_dataset, 
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42)  # For reproducibility
        )
        
        print(f"Dataset: {self.dataset_name.upper()}")
        print(f"Training samples: {len(self.train_dataset)}")
        print(f"Validation samples: {len(self.val_dataset)}")
        print(f"Test samples: {len(self.test_dataset)}")
        print(f"Augmentation: {'Enabled' if self.augmentation else 'Disabled'}")
    
    def _create_dataloaders(self):
        """Create DataLoader objects for training, validation, and test sets."""
        self.train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )
        
        self.val_loader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )
        
        self.test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=True
        )
    
    def get_loaders(self):
        """
        Return the DataLoader objects.
        
        Returns:
            tuple: (train_loader, val_loader, test_loader)
        """
        return self.train_loader, self.val_loader, self.test_loader
    
    def get_input_size(self):
        """
        Get the input size for the dataset.
        
        Returns:
            tuple: Shape of a single input sample
        """
        if self.dataset_name == 'mnist':
            return 1, 28, 28  # MNIST: 1 channel, 28x28
        else:
            return 3, 32, 32  # CIFAR10: 3 channels, 32x32
    
    def get_num_classes(self):
        """
        Get the number of classes in the dataset.
        
        Returns:
            int: Number of classes (10 for both MNIST and CIFAR10)
        """
        return 10
    
    def get_class_names(self):
        """
        Get the class names for the dataset.
        
        Returns:
            list: List of class names
        """
        if self.dataset_name == 'mnist':
            return [str(i) for i in range(10)]  # MNIST classes are 0-9
        else:
            return [
                'airplane', 'automobile', 'bird', 'cat', 'deer',
                'dog', 'frog', 'horse', 'ship', 'truck'
            ]  # CIFAR10 class names


class DistillationDatasetLoader(DatasetLoader):
    """
    Extended DatasetLoader class that supports knowledge distillation by
    storing and providing teacher model predictions.
    """
    def __init__(self, dataset_name, teacher_model=None, temperature=1.0, batch_size=64, 
                 val_split=0.1, num_workers=4, data_dir='./data', augmentation=False, download=True):
        """
        Initialize DistillationDatasetLoader with the specified dataset and teacher model.
        
        Args:
            dataset_name (str): Name of dataset ('mnist' or 'cifar10')
            teacher_model: Pretrained teacher model for distillation
            temperature (float): Temperature for softening probabilities
            batch_size (int): Batch size for dataloaders
            val_split (float): Proportion of training data to use for validation
            num_workers (int): Number of workers for data loading
            data_dir (str): Directory to store datasets
            augmentation (bool): Whether to use data augmentation
            download (bool): Whether to download the dataset if not available locally
        """
        super().__init__(dataset_name, batch_size, val_split, num_workers, data_dir, augmentation, download)
        
        self.teacher_model = teacher_model
        self.temperature = temperature
        self.teacher_predictions = None
        
        if teacher_model is not None:
            self._generate_teacher_predictions()
    
    def _generate_teacher_predictions(self):
        """Generate and store predictions from the teacher model on the training dataset."""
        print("Generating teacher predictions for distillation...")
        
        # Set teacher model to evaluation mode
        self.teacher_model.eval()
        device = next(self.teacher_model.parameters()).device
        
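        # Create a dataloader for the full training set.
        # NOTE: if the training transform is random (augmentation enabled), the stored
        # predictions correspond to a single random view of each image.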
        full_train_loader = DataLoader(
            self.full_train_dataset,
            batch_size=self.batch_size,
            shuffle=False,  # Important: keep the same order as the dataset
            num_workers=self.num_workers,
            pin_memory=True
        )
        
        # Generate predictions
        all_predictions = []
        
        with torch.no_grad():
            for inputs, _ in tqdm(full_train_loader, desc="Generating teacher predictions"):
                inputs = inputs.to(device)
                outputs = self.teacher_model(inputs)
                
                # Apply temperature scaling and softmax
                soft_targets = F.softmax(outputs / self.temperature, dim=1)
                all_predictions.append(soft_targets.cpu())
        
        # Concatenate all predictions
        self.teacher_predictions = torch.cat(all_predictions, dim=0)
        print(f"Generated {len(self.teacher_predictions)} teacher predictions")
    
    def create_distillation_loaders(self):
        """
        Create specialized DataLoaders that provide both hard labels and teacher's soft predictions.
        
        Returns:
            tuple: (distill_train_loader, val_loader, test_loader)
        """
        if self.teacher_predictions is None:
            raise ValueError("Teacher predictions not available. Either provide a teacher model or call set_teacher_predictions.")
        
        # Custom Dataset that yields (input, hard label, soft label) triples
        class DistillationDataset(torch.utils.data.Dataset):
            def __init__(self, original_dataset, teacher_preds, indices=None):
                self.original_dataset = original_dataset
                self.teacher_preds = teacher_preds
                self.indices = indices if indices is not None else list(range(len(original_dataset)))
                
            def __len__(self):
                return len(self.indices)
                
            def __getitem__(self, idx):
                # Get the actual index in the original dataset
                original_idx = self.indices[idx]
                
                # Get the input and hard label from the original dataset
                input_data, hard_label = self.original_dataset[original_idx]
                
                # Get the corresponding soft label
                soft_label = self.teacher_preds[original_idx]
                
                return input_data, hard_label, soft_label
        
        # Create the distillation dataset over the training split only.
        # Subset.indices maps positions in the split back to the full dataset.
        distill_train_dataset = DistillationDataset(
            self.full_train_dataset,
            self.teacher_predictions,
            indices=list(self.train_dataset.indices)
        )
        
        # Create distillation dataloader
        distill_train_loader = DataLoader(
            distill_train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=True
        )
        
        return distill_train_loader, self.val_loader, self.test_loader
    
    def set_teacher_predictions(self, predictions):
        """
        Set pre-computed teacher predictions.
        
        Args:
            predictions (torch.Tensor): Tensor of teacher predictions
        """
        if len(predictions) != len(self.full_train_dataset):
            raise ValueError(f"Number of predictions ({len(predictions)}) does not match dataset size ({len(self.full_train_dataset)})")
        
        self.teacher_predictions = predictions
        print(f"Set {len(self.teacher_predictions)} teacher predictions")
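
The loaders returned by `create_distillation_loaders` yield `(input, hard_label, soft_label)` triples, where the soft labels are teacher probabilities already softened with `temperature`. One way a student could consume them is a Hinton-style loss that mixes a KL term on the soft targets with the usual cross-entropy on the hard labels. This is a minimal sketch, not part of the loader: the function name `distillation_loss`, the mixing weight `alpha`, and the `temperature ** 2` rescaling are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def distillation_loss(student_logits, hard_labels, soft_targets,
                      temperature=1.0, alpha=0.5):
    """Hypothetical helper: weighted sum of soft-target KL and hard-label CE.

    student_logits: (B, num_classes) raw student outputs
    hard_labels:    (B,) integer class labels
    soft_targets:   (B, num_classes) teacher probabilities at `temperature`
    """
    # Soften the student with the same temperature used for the teacher
    log_student = F.log_softmax(student_logits / temperature, dim=1)
    # KL(teacher || student), averaged over the batch; the T^2 factor
    # (an assumption here) keeps gradient magnitudes comparable across T
    soft_loss = F.kl_div(log_student, soft_targets, reduction='batchmean')
    soft_loss = soft_loss * temperature ** 2
    # Standard cross-entropy on the unsoftened student logits
    hard_loss = F.cross_entropy(student_logits, hard_labels)
    return alpha * soft_loss + (1.0 - alpha) * hard_loss
```

The `temperature` passed here should match the one given to `DistillationDatasetLoader`, since the stored soft labels were computed with it.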

### Exercise 2.3: *Explain* the predictions of a CNN

Use the CNN model you trained in Exercise 1.3 and implement [*Class Activation Maps*](http://cnnlocalization.csail.mit.edu/#:~:text=A%20class%20activation%20map%20for,decision%20made%20by%20the%20CNN.):

> B. Zhou, A. Khosla, A. Lapedriza, A. Oliva, and A. Torralba. Learning Deep Features for Discriminative Localization. CVPR'16 (arXiv:1512.04150, 2015).

Use your CAM implementation to demonstrate how your trained CNN *attends* to specific image features to recognize *specific* classes. Try your implementation out using a pre-trained ResNet-18 model and some images from the [Imagenette](https://pytorch.org/vision/0.20/generated/torchvision.datasets.Imagenette.html#torchvision.datasets.Imagenette) dataset -- I suggest you start with the low-resolution (160px) version of the images.
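
As a starting point, the core CAM computation can be sketched as follows. For a network that ends in global average pooling followed by a single linear layer, the map for class $c$ is the sum of the last convolutional feature maps weighted by that class's row of the linear layer. The function name `class_activation_map`, its arguments, and the min-max normalization are illustrative assumptions, not a reference solution.

```python
import torch
import torch.nn.functional as F


def class_activation_map(features, fc_weight, class_idx, out_size=None):
    """Hypothetical helper computing a CAM for one image.

    features:  (C, H, W) activations of the last conv layer
    fc_weight: (num_classes, C) weights of the linear layer after pooling
    class_idx: index of the class to explain
    out_size:  optional (height, width) to upsample the map to
    """
    # Weighted sum over channels: M_c(x, y) = sum_k w_k^c * f_k(x, y)
    cam = torch.einsum('c,chw->hw', fc_weight[class_idx], features)
    # Min-max normalize to [0, 1] for display (a presentation choice)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)
    if out_size is not None:
        # Upsample the coarse map to the input resolution for overlaying
        cam = F.interpolate(cam[None, None], size=out_size,
                            mode='bilinear', align_corners=False)[0, 0]
    return cam
```

With torchvision's ResNet-18, `features` could be the output of `model.layer4` for one image (captured with a forward hook, for example) and `fc_weight` could be `model.fc.weight`. For a 160px input that feature map is 5x5, which is why upsampling is useful before overlaying it on the image.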

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import argparse
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import wandb
from models import MLP, CNN

# Your code here.