### PyTorch AlexNet Exercises

Welcome to the PyTorch AlexNet exercise template notebook.

There are several questions in this notebook and it's your goal to answer them by writing Python and PyTorch code.






In [2]:
import torch

# Pick the compute device: Apple GPU (MPS) first, then NVIDIA CUDA, then CPU.
# Each candidate pairs an availability check with the device name and the
# message printed when it is chosen; the first available candidate wins and
# later checks are not evaluated.
_device_candidates = (
    (torch.backends.mps.is_available, 'mps', "✓ Using MPS (Metal Performance Shaders) - Apple GPU"),
    (torch.cuda.is_available, 'cuda', "✓ Using CUDA - NVIDIA GPU"),
    (lambda: True, 'cpu', "✓ Using CPU"),
)
for _is_available, _device_name, _device_message in _device_candidates:
    if _is_available():
        device = torch.device(_device_name)
        print(_device_message)
        break

✓ Using MPS (Metal Performance Shaders) - Apple GPU


## Estrai dataset da zip

In [None]:
import os
import zipfile
import shutil
from pathlib import Path

# Location of the downloaded archive and of the folder it unpacks to.
zip_path = '../data/tiny-imagenet-200.zip'
extract_path = '../data/tiny-imagenet-200'

train_path = os.path.join(extract_path, 'train')
val_path = os.path.join(extract_path, 'val')

# Extract only when the dataset folders are missing, so that re-running this
# cell does not unpack the whole archive again. Delete `extract_path` to force
# a fresh extraction (e.g. after an interrupted one).
if os.path.isdir(train_path) and os.path.isdir(val_path):
    print("Dataset already extracted, skipping extraction.")
else:
    print("Extracting Tiny ImageNet dataset...")
    # The archive is unpacked into the parent of `extract_path` ('../data').
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(os.path.dirname(extract_path))
    print("Dataset extracted successfully!")

# Tiny ImageNet ships the validation set as one flat 'images' folder plus an
# annotation file; ImageFolder needs one sub-folder per class instead.
val_dir = val_path
val_images_dir = os.path.join(val_dir, 'images')
val_annotations_file = os.path.join(val_dir, 'val_annotations.txt')

# image file name -> class id, filled from the annotation file
val_img_dict = {}

# The annotation file is deleted at the end of the reorganization, so its
# presence tells us whether the work still has to be done (or resumed).
if os.path.exists(val_annotations_file):
    print("Reorganizing validation set...")
    with open(val_annotations_file, 'r') as f:
        for line in f:
            parts = line.strip().split('\t')
            if len(parts) < 2:
                # Blank or malformed line: nothing to map, skip it.
                continue
            val_img_dict[parts[0]] = parts[1]

    # Create one directory per class in the validation set
    for class_id in set(val_img_dict.values()):
        os.makedirs(os.path.join(val_dir, class_id), exist_ok=True)

    # Move every image into the folder of its class
    for img_name, class_id in val_img_dict.items():
        src = os.path.join(val_images_dir, img_name)
        dst = os.path.join(val_dir, class_id, img_name)
        if os.path.exists(src):
            shutil.move(src, dst)

    # Remove the now-empty images folder. os.rmdir raises if anything is left
    # in it, which is intentional: a leftover 'images' folder would otherwise
    # be picked up by ImageFolder as an extra class.
    if os.path.exists(val_images_dir):
        os.rmdir(val_images_dir)

    # Remove the annotation file last, so an interrupted run can be resumed
    os.remove(val_annotations_file)

    print("Validation set reorganized successfully!")
else:
    print("Validation set already reorganized, skipping.")

# Check the resulting structure: count the class folders in each split
num_train_classes = len([d for d in os.listdir(train_path) if os.path.isdir(os.path.join(train_path, d))])
num_val_classes = len([d for d in os.listdir(val_path) if os.path.isdir(os.path.join(val_path, d))])

print("\nDataset ready!")
print(f"Train classes: {num_train_classes}")
print(f"Validation classes: {num_val_classes}")
print(f"\nTrain path: {train_path}")
print(f"Validation path: {val_path}")
print("\nUpdate your code paths to:")
print(f"  train_dataset = datasets.ImageFolder('{train_path}', transform=transform)")
print(f"  val_dataset = datasets.ImageFolder('{val_path}', transform=transform)")

Extracting Tiny ImageNet dataset...
Dataset extracted successfully!
Reorganizing validation set...
Validation set reorganized successfully!

Dataset ready!
Train classes: 200
Validation classes: 200

Train path: tiny-imagenet-200/train
Validation path: tiny-imagenet-200/val

Update your code paths to:
  train_dataset = datasets.ImageFolder('tiny-imagenet-200/train', transform=transform)
  val_dataset = datasets.ImageFolder('tiny-imagenet-200/val', transform=transform)


## AlexNet

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import time
import wandb
from tabulate import tabulate

# Log in to W&B without embedding a secret in the notebook. With no explicit
# key, wandb.login() uses the WANDB_API_KEY environment variable or the
# credentials saved by a previous login (netrc), and otherwise prompts.
# NOTE: an API key was previously hard-coded on this line; since it has been
# stored in the notebook it should be revoked in the W&B account settings.
wandb.login()

# Define the AlexNet architecture
class AlexNet(nn.Module):
    """AlexNet-style CNN: five convolutions, a fixed 6x6 pooled map, three linear layers.

    Args:
        num_classes: number of output logits produced by the last linear layer.

    The spatial sizes quoted in the comments assume a 3x224x224 input.
    """

    def __init__(self, num_classes=200):
        super().__init__()

        def conv_relu(in_ch, out_ch, **conv_args):
            # A convolution immediately followed by an in-place ReLU.
            return [nn.Conv2d(in_ch, out_ch, **conv_args), nn.ReLU(inplace=True)]

        def max_pool():
            # 3x3 max pooling with stride 2.
            return [nn.MaxPool2d(kernel_size=3, stride=2)]

        # Layers are created in the same order (and therefore get the same
        # Sequential indices / state_dict keys) as a flat listing would give.
        self.features = nn.Sequential(
            *conv_relu(3, 96, kernel_size=11, stride=4, padding=2),  # 3x224x224 -> 96x55x55
            *max_pool(),                                             # -> 96x27x27
            *conv_relu(96, 256, kernel_size=5, padding=2),           # -> 256x27x27
            *max_pool(),                                             # -> 256x13x13
            *conv_relu(256, 384, kernel_size=3, padding=1),          # -> 384x13x13
            *conv_relu(384, 384, kernel_size=3, padding=1),          # -> 384x13x13
            *conv_relu(384, 256, kernel_size=3, padding=1),          # -> 256x13x13
            *max_pool(),                                             # -> 256x6x6
        )

        # Forces a 6x6 map whatever the input resolution, so the first linear
        # layer always receives 256 * 6 * 6 features.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))

        hidden_units = 4096
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256 * 6 * 6, hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return the class logits, shape (batch, num_classes), for a batch of images."""
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))

# Training function
def train_epoch(net, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over `train_loader`.

    Args:
        net: model to train; switched to training mode.
        train_loader: sized iterable yielding (inputs, labels) batches.
        criterion: loss called as criterion(outputs, labels).
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.

    Returns:
        (epoch_loss, epoch_acc): the sum of the per-batch losses divided by the
        number of batches, and the percentage of samples whose arg-max
        prediction equals the label.
    """
    net.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for step, (inputs, labels) in enumerate(train_loader, start=1):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        predicted = outputs.max(1).indices
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        # Progress line every 100 batches, with running averages so far.
        if step % 100 == 0:
            print(f'  Batch [{step}/{len(train_loader)}], Loss: {loss_sum/step:.3f}, Acc: {100.*n_correct/n_seen:.2f}%')

    return loss_sum / len(train_loader), 100. * n_correct / n_seen

# Validation function
def validate(net, val_loader, criterion, device):
    """Evaluate `net` on `val_loader` without computing gradients.

    Args:
        net: model to evaluate; switched to evaluation mode.
        val_loader: sized iterable yielding (inputs, labels) batches.
        criterion: loss called as criterion(outputs, labels).
        device: device the batches are moved to.

    Returns:
        (val_loss, val_acc): the sum of the per-batch losses divided by the
        number of batches, and the percentage of correctly classified samples.
    """
    net.eval()
    batch_losses = []
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = net(inputs)

            batch_losses.append(criterion(outputs, labels).item())
            predicted = outputs.max(1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return sum(batch_losses) / len(val_loader), 100. * n_correct / n_seen

# Hyperparameter grid: one training run is launched for every combination of
# learning rate, batch size and weight decay (3 x 3 x 3 = 27 runs).
# NOTE(review): the saved output of this cell for LR=0.1 shows the training
# loss flat at ~5.31 and accuracy at ~0.5% over the first 2200 batches;
# confirm whether 0.1 is a useful value for this model.
learning_rates = [1e-1, 1e-3, 1e-4]
batch_sizes = [16, 32, 64]
weight_decays = [5e-4, 1e-3, 1e-4]
num_epochs = 10  # epochs per run; adjust based on your needs

# Input pipeline: resize to 224x224, convert to a tensor, then normalise each
# channel with the mean / std given below.
_resize = transforms.Resize((224, 224))
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose([_resize, transforms.ToTensor(), _normalize])

# Tiny ImageNet splits, read from one sub-folder per class; both splits share
# the same transform.
train_dataset = datasets.ImageFolder('../data/tiny-imagenet-200/train', transform=transform)
val_dataset = datasets.ImageFolder('../data/tiny-imagenet-200/val', transform=transform)

# Store results for final comparison (one dict per completed experiment)
results = []

# Grid search: one full training run per (lr, batch size, weight decay)
experiment_id = 0
for lr in learning_rates:
    for batch_size in batch_sizes:
        for weight_decay in weight_decays:
            experiment_id += 1
            print(f"\n{'='*80}")
            print(f"Experiment {experiment_id}: LR={lr}, Batch Size={batch_size}, Weight Decay={weight_decay}")
            print(f"{'='*80}\n")

            # Initialize W&B run
            run = wandb.init(
                project="alexnet-tiny-imagenet",
                name=f"lr{lr}_bs{batch_size}_wd{weight_decay}",
                config={
                    "learning_rate": lr,
                    "batch_size": batch_size,
                    "weight_decay": weight_decay,
                    "epochs": num_epochs,
                    "architecture": "AlexNet",
                    "dataset": "Tiny ImageNet"
                },
                reinit=True
            )

            try:
                # Data loaders
                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
                val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

                # Fresh network for every experiment
                net = AlexNet(num_classes=200).to(device)

                # Loss and optimizer
                criterion = nn.CrossEntropyLoss()
                optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)

                # Learning rate scheduler: multiply the LR by 0.1 every 5 epochs
                scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

                # Track best validation accuracy
                best_val_acc = 0.0
                # Metrics of the most recent validation pass (None until one has run)
                val_loss = val_acc = None
                start_time = time.time()

                # Train the network
                for epoch in range(num_epochs):
                    print(f"\nEpoch [{epoch+1}/{num_epochs}]")

                    # Learning rate in effect during this epoch. It is read
                    # before scheduler.step(), which already switches the
                    # optimizer to the next epoch's value.
                    epoch_lr = optimizer.param_groups[0]['lr']

                    # Training
                    train_loss, train_acc = train_epoch(net, train_loader, criterion, optimizer, device)

                    # Validation
                    val_loss, val_acc = validate(net, val_loader, criterion, device)

                    # Update learning rate
                    scheduler.step()

                    # Log metrics to W&B
                    wandb.log({
                        "epoch": epoch + 1,
                        "train_loss": train_loss,
                        "train_acc": train_acc,
                        "val_loss": val_loss,
                        "val_acc": val_acc,
                        "learning_rate": epoch_lr
                    })

                    print(f"Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%")
                    print(f"Val Loss: {val_loss:.3f}, Val Acc: {val_acc:.2f}%")

                    # Save best model
                    if val_acc > best_val_acc:
                        best_val_acc = val_acc
                        torch.save(net.state_dict(), f'best_alexnet_lr{lr}_bs{batch_size}_wd{weight_decay}.pth')

                training_time = time.time() - start_time

                # Final evaluation: the weights have not changed since the last
                # epoch's validation pass, so reuse its metrics instead of
                # running the whole validation set again. Only evaluate here
                # when no epoch ran at all (num_epochs == 0).
                if val_acc is None:
                    final_val_loss, final_val_acc = validate(net, val_loader, criterion, device)
                else:
                    final_val_loss, final_val_acc = val_loss, val_acc

                # Store results
                results.append({
                    'Learning Rate': lr,
                    'Batch Size': batch_size,
                    'Weight Decay': weight_decay,
                    'Best Val Acc': f"{best_val_acc:.2f}%",
                    'Final Val Acc': f"{final_val_acc:.2f}%",
                    'Training Time (min)': f"{training_time/60:.2f}"
                })

                # Log final metrics to W&B
                wandb.log({
                    "best_val_acc": best_val_acc,
                    "final_val_acc": final_val_acc,
                    "training_time_minutes": training_time/60
                })

                print(f"\nExperiment completed in {training_time/60:.2f} minutes")
                print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
            except BaseException:
                # Close the run as failed on any error or manual interrupt, so
                # it is not left open, then let the exception propagate.
                wandb.finish(exit_code=1)
                raise
            else:
                # Finish W&B run
                wandb.finish()

# Print final comparison table
print("\n" + "="*100)
print("FINAL RESULTS COMPARISON")
print("="*100 + "\n")

if not results:
    # Nothing finished (e.g. empty hyperparameter grid): results[0] below
    # would raise IndexError, so report it instead.
    print("No completed experiments to report.")
else:
    # One table row per experiment; the column names come from the result dicts
    headers = results[0].keys()
    rows = [list(r.values()) for r in results]
    print(tabulate(rows, headers=headers, tablefmt="grid"))

    # Find best configuration. 'Best Val Acc' is stored as a string such as
    # "12.34%", so strip the percent sign before comparing numerically.
    best_result = max(results, key=lambda x: float(x['Best Val Acc'].rstrip('%')))
    print("\n" + "="*100)
    print("BEST CONFIGURATION:")
    print("="*100)
    print(f"Learning Rate: {best_result['Learning Rate']}")
    print(f"Batch Size: {best_result['Batch Size']}")
    print(f"Weight Decay: {best_result['Weight Decay']}")
    print(f"Best Validation Accuracy: {best_result['Best Val Acc']}")
    print(f"Training Time: {best_result['Training Time (min)']} minutes")
    print("="*100)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/nicolotermine/.netrc
[34m[1mwandb[0m: Currently logged in as: [33ms338680[0m ([33ms338680-politecnico-di-torino[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



Experiment 1: LR=0.1, Batch Size=16, Weight Decay=0.0005






Epoch [1/10]
  Batch [100/6250], Loss: 5.307, Acc: 0.56%
  Batch [200/6250], Loss: 5.309, Acc: 0.41%
  Batch [300/6250], Loss: 5.308, Acc: 0.44%
  Batch [400/6250], Loss: 5.311, Acc: 0.41%
  Batch [500/6250], Loss: 5.311, Acc: 0.47%
  Batch [600/6250], Loss: 5.311, Acc: 0.53%
  Batch [700/6250], Loss: 5.311, Acc: 0.51%
  Batch [800/6250], Loss: 5.312, Acc: 0.51%
  Batch [900/6250], Loss: 5.311, Acc: 0.50%
  Batch [1000/6250], Loss: 5.312, Acc: 0.49%
  Batch [1100/6250], Loss: 5.312, Acc: 0.54%
  Batch [1200/6250], Loss: 5.312, Acc: 0.54%
  Batch [1300/6250], Loss: 5.312, Acc: 0.50%
  Batch [1400/6250], Loss: 5.312, Acc: 0.52%
  Batch [1500/6250], Loss: 5.312, Acc: 0.52%
  Batch [1600/6250], Loss: 5.312, Acc: 0.50%
  Batch [1700/6250], Loss: 5.313, Acc: 0.49%
  Batch [1800/6250], Loss: 5.313, Acc: 0.49%
  Batch [1900/6250], Loss: 5.313, Acc: 0.49%
  Batch [2000/6250], Loss: 5.313, Acc: 0.48%
  Batch [2100/6250], Loss: 5.313, Acc: 0.47%
  Batch [2200/6250], Loss: 5.313, Acc: 0.47%
  Bat

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/spawn.py", line 132, in _main
    self = reduction.pickle.load(from_parent)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nicolotermine/Library/Caches/pypoetry/virtualenvs/lab-04-RLYAEv7G-py3.12/lib/python3.12/site-packages/torch/__init__.py", line 2680, in <module>
    from torch import _meta_registrations
  File "/Users/nicolotermine/Library/Caches/pypoetry/virtualenvs/lab-04-RLYAEv7G-py3.12/lib/python3.12/site-packages/torch/_meta_registrations.py", line 12, in <module>
    from torch._decomp import (
  File "/Users/nicolotermine/Library/Caches/pypoetry/virtualenvs/lab-04-RLYAEv7G-py3.12/lib/python3.12

KeyboardInterrupt: 

socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.s

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x13e00df10>> (for post_run_cell), with arguments args (<ExecutionResult object at 1109694c0, execution_count=3 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 110f28a40, raw_cell="import torch
import torch.nn as nn
import torch.op.." transformed_cell="import torch
import torch.nn as nn
import torch.op.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Users/nicolotermine/zMellow/GitHub-Poli/Polito/polito-aml/Lab_04/AML_lab04.ipynb#W3sZmlsZQ%3D%3D> result=None>,),kwargs {}:


socket.send() raised exception.
socket.send() raised exception.


BrokenPipeError: [Errno 32] Broken pipe

socket.send() raised exception.
socket.send() raised exception.


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.s

Questa sezione è composta da tre artifact:

## 1. **ResNet Training** (primo artifact)
Codice equivalente a quello di AlexNet ma con l'architettura ResNet18, che include:
- Implementazione completa di ResNet con blocchi residuali
- Stessa configurazione di hyperparameter tuning
- Integrazione W&B per tracking
- Salvataggio dei risultati in `resnet18_results.pkl`

## 2. **Confronto Tabellare** (secondo artifact)
Script che carica i risultati di entrambe le architetture e produce:
- **Tabelle complete** di tutti gli esperimenti
- **Statistiche aggregate** (media, max, min, std)
- **Top 5 configurazioni** per ogni architettura
- **Confronto head-to-head** delle migliori configurazioni
- **Analisi per iperparametro** (effetto di LR, BS, WD)
- **Key findings** con conclusioni automatiche

## 3. **Visualizzazioni Grafiche** (terzo artifact)
Script che genera 9 grafici in 2 immagini PNG:

**architecture_comparison.png** contiene:
1. Box plot distribuzione accuracy
2. Scatter plot accuracy vs training time
3. Bar chart effetto learning rate
4. Bar chart effetto batch size
5. Bar chart effetto weight decay
6. Heatmap differenza di performance

**detailed_analysis.png** contiene:
7. Top 10 configurazioni a confronto
8. Istogramma distribuzione tempi di training
9. Error bar per robustezza delle configurazioni

### Come usare:
1. Esegui prima il training di AlexNet (già fornito)
2. Esegui il training di ResNet18 (primo nuovo artifact)
3. Esegui l'analisi tabellare (secondo artifact)
4. Esegui la visualizzazione grafica (terzo artifact)

Tutti i risultati verranno salvati e visualizzati sia su W&B che localmente!

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import time
import wandb
from tabulate import tabulate

# Log in to W&B without embedding a secret in the notebook. With no explicit
# key, wandb.login() uses the WANDB_API_KEY environment variable or the
# credentials saved by a previous login (netrc), and otherwise prompts.
# NOTE: an API key was previously hard-coded on this line; since it has been
# stored in the notebook it should be revoked in the W&B account settings.
wandb.login()

# Define the ResNet architecture with residual blocks
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers whose output is added to a skip connection.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution (the second always uses 1).
        downsample: optional module applied to the input before the addition;
            when None the input itself is added.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Settings shared by both convolutions; no bias because each one is
        # followed by batch normalisation.
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x))."""
        # Main path: conv -> bn -> relu -> conv -> bn
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Skip path: the raw input, or its projection when one was supplied
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

class ResNet(nn.Module):
    """ResNet backbone: a 7x7 stem, four residual stages, global pooling and a linear head.

    Args:
        block: class used for every residual block; it is called as
            block(in_channels, out_channels, stride, downsample) for the first
            block of a stage and block(out_channels, out_channels) afterwards.
        layers: four integers, the number of blocks in each stage.
        num_classes: number of output logits.
    """

    def __init__(self, block, layers, num_classes=200):
        super().__init__()
        # Channel count entering the next stage; updated by _make_layer.
        self.in_channels = 64

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2 max pooling
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages: every stage after the first uses stride 2
        n1, n2, n3, n4 = layers[0], layers[1], layers[2], layers[3]
        self.layer1 = self._make_layer(block, 64, n1)
        self.layer2 = self._make_layer(block, 128, n2, stride=2)
        self.layer3 = self._make_layer(block, 256, n3, stride=2)
        self.layer4 = self._make_layer(block, 512, n4, stride=2)

        # Head: global average pooling followed by the classifier
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of `blocks` residual blocks producing `out_channels` channels."""
        # The first block needs a 1x1 projection on its skip path whenever it
        # changes the resolution or the number of channels.
        needs_projection = stride != 1 or self.in_channels != out_channels
        downsample = None
        if needs_projection:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        first_block = block(self.in_channels, out_channels, stride, downsample)
        self.in_channels = out_channels
        following_blocks = [block(out_channels, out_channels) for _ in range(1, blocks)]

        return nn.Sequential(first_block, *following_blocks)

    def forward(self, x):
        """Return the class logits, shape (batch, num_classes), for a batch of images."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = self.avgpool(x)
        return self.fc(torch.flatten(x, 1))

def ResNet18(num_classes=200):
    """Build a ResNet made of ResidualBlock with 2 blocks in each of the four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(ResidualBlock, blocks_per_stage, num_classes)

def ResNet34(num_classes=200):
    """Build a ResNet made of ResidualBlock with 3, 4, 6 and 3 blocks in its four stages."""
    blocks_per_stage = [3, 4, 6, 3]
    return ResNet(ResidualBlock, blocks_per_stage, num_classes)

# Training function (same logic as the train_epoch defined in the AlexNet cell)
def train_epoch(net, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over `train_loader`.

    Args:
        net: model to train; switched to training mode.
        train_loader: sized iterable yielding (inputs, labels) batches.
        criterion: loss called as criterion(outputs, labels).
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.

    Returns:
        (epoch_loss, epoch_acc): the sum of the per-batch losses divided by the
        number of batches, and the percentage of samples whose arg-max
        prediction equals the label.
    """
    net.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for step, (inputs, labels) in enumerate(train_loader, start=1):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        predicted = outputs.max(1).indices
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        # Progress line every 100 batches, with running averages so far.
        if step % 100 == 0:
            print(f'  Batch [{step}/{len(train_loader)}], Loss: {loss_sum/step:.3f}, Acc: {100.*n_correct/n_seen:.2f}%')

    return loss_sum / len(train_loader), 100. * n_correct / n_seen

# Validation function (same logic as the validate defined in the AlexNet cell)
def validate(net, val_loader, criterion, device):
    """Evaluate `net` on `val_loader` without computing gradients.

    Args:
        net: model to evaluate; switched to evaluation mode.
        val_loader: sized iterable yielding (inputs, labels) batches.
        criterion: loss called as criterion(outputs, labels).
        device: device the batches are moved to.

    Returns:
        (val_loss, val_acc): the sum of the per-batch losses divided by the
        number of batches, and the percentage of correctly classified samples.
    """
    net.eval()
    batch_losses = []
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = net(inputs)

            batch_losses.append(criterion(outputs, labels).item())
            predicted = outputs.max(1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return sum(batch_losses) / len(val_loader), 100. * n_correct / n_seen

# Hyperparameter grid: one training run is launched for every combination of
# learning rate, batch size and weight decay (3 x 3 x 3 = 27 runs).
learning_rates = [1e-1, 1e-3, 1e-4]
batch_sizes = [16, 32, 64]
weight_decays = [5e-4, 1e-3, 1e-4]
num_epochs = 10  # epochs per run

# Input pipeline: resize to 224x224, convert to a tensor, then normalise each
# channel with the mean / std given below.
_resize = transforms.Resize((224, 224))
_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform = transforms.Compose([_resize, transforms.ToTensor(), _normalize])

# Tiny ImageNet splits, read from one sub-folder per class; both splits share
# the same transform.
train_dataset = datasets.ImageFolder('../data/tiny-imagenet-200/train', transform=transform)
val_dataset = datasets.ImageFolder('../data/tiny-imagenet-200/val', transform=transform)

# Store results for final comparison (one dict per completed experiment)
results = []

# Grid search: one full training run per (lr, batch size, weight decay)
experiment_id = 0
for lr in learning_rates:
    for batch_size in batch_sizes:
        for weight_decay in weight_decays:
            experiment_id += 1
            print(f"\n{'='*80}")
            print(f"Experiment {experiment_id}: LR={lr}, Batch Size={batch_size}, Weight Decay={weight_decay}")
            print(f"{'='*80}\n")

            # Initialize W&B run
            run = wandb.init(
                project="resnet-tiny-imagenet",
                name=f"resnet18_lr{lr}_bs{batch_size}_wd{weight_decay}",
                config={
                    "learning_rate": lr,
                    "batch_size": batch_size,
                    "weight_decay": weight_decay,
                    "epochs": num_epochs,
                    "architecture": "ResNet18",
                    "dataset": "Tiny ImageNet"
                },
                reinit=True
            )

            try:
                # Data loaders
                train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
                val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

                # Fresh network for every experiment (using ResNet18)
                net = ResNet18(num_classes=200).to(device)

                # Loss and optimizer
                criterion = nn.CrossEntropyLoss()
                optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=weight_decay)

                # Learning rate scheduler: multiply the LR by 0.1 every 5 epochs
                scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

                # Track best validation accuracy
                best_val_acc = 0.0
                # Metrics of the most recent validation pass (None until one has run)
                val_loss = val_acc = None
                start_time = time.time()

                # Train the network
                for epoch in range(num_epochs):
                    print(f"\nEpoch [{epoch+1}/{num_epochs}]")

                    # Learning rate in effect during this epoch. It is read
                    # before scheduler.step(), which already switches the
                    # optimizer to the next epoch's value.
                    epoch_lr = optimizer.param_groups[0]['lr']

                    # Training
                    train_loss, train_acc = train_epoch(net, train_loader, criterion, optimizer, device)

                    # Validation
                    val_loss, val_acc = validate(net, val_loader, criterion, device)

                    # Update learning rate
                    scheduler.step()

                    # Log metrics to W&B
                    wandb.log({
                        "epoch": epoch + 1,
                        "train_loss": train_loss,
                        "train_acc": train_acc,
                        "val_loss": val_loss,
                        "val_acc": val_acc,
                        "learning_rate": epoch_lr
                    })

                    print(f"Train Loss: {train_loss:.3f}, Train Acc: {train_acc:.2f}%")
                    print(f"Val Loss: {val_loss:.3f}, Val Acc: {val_acc:.2f}%")

                    # Save best model
                    if val_acc > best_val_acc:
                        best_val_acc = val_acc
                        torch.save(net.state_dict(), f'best_resnet18_lr{lr}_bs{batch_size}_wd{weight_decay}.pth')

                training_time = time.time() - start_time

                # Final evaluation: the weights have not changed since the last
                # epoch's validation pass, so reuse its metrics instead of
                # running the whole validation set again. Only evaluate here
                # when no epoch ran at all (num_epochs == 0).
                if val_acc is None:
                    final_val_loss, final_val_acc = validate(net, val_loader, criterion, device)
                else:
                    final_val_loss, final_val_acc = val_loss, val_acc

                # Store results
                results.append({
                    'Architecture': 'ResNet18',
                    'Learning Rate': lr,
                    'Batch Size': batch_size,
                    'Weight Decay': weight_decay,
                    'Best Val Acc': f"{best_val_acc:.2f}%",
                    'Final Val Acc': f"{final_val_acc:.2f}%",
                    'Training Time (min)': f"{training_time/60:.2f}"
                })

                # Log final metrics to W&B
                wandb.log({
                    "best_val_acc": best_val_acc,
                    "final_val_acc": final_val_acc,
                    "training_time_minutes": training_time/60
                })

                print(f"\nExperiment completed in {training_time/60:.2f} minutes")
                print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
            except BaseException:
                # Close the run as failed on any error or manual interrupt, so
                # it is not left open, then let the exception propagate.
                wandb.finish(exit_code=1)
                raise
            else:
                # Finish W&B run
                wandb.finish()

# Print final comparison table
print("\n" + "="*100)
print("FINAL RESULTS COMPARISON - RESNET18")
print("="*100 + "\n")

if not results:
    # Nothing finished (e.g. empty hyperparameter grid): results[0] below
    # would raise IndexError, and saving would overwrite any earlier results
    # file with an empty list, so only report it.
    print("No completed experiments to report; nothing saved.")
else:
    # One table row per experiment; the column names come from the result dicts
    headers = results[0].keys()
    rows = [list(r.values()) for r in results]
    print(tabulate(rows, headers=headers, tablefmt="grid"))

    # Find best configuration. 'Best Val Acc' is stored as a string such as
    # "12.34%", so strip the percent sign before comparing numerically.
    best_result = max(results, key=lambda x: float(x['Best Val Acc'].rstrip('%')))
    print("\n" + "="*100)
    print("BEST CONFIGURATION:")
    print("="*100)
    print(f"Architecture: {best_result['Architecture']}")
    print(f"Learning Rate: {best_result['Learning Rate']}")
    print(f"Batch Size: {best_result['Batch Size']}")
    print(f"Weight Decay: {best_result['Weight Decay']}")
    print(f"Best Validation Accuracy: {best_result['Best Val Acc']}")
    print(f"Training Time: {best_result['Training Time (min)']} minutes")
    print("="*100)

    # Save results to file for later comparison
    import pickle
    with open('resnet18_results.pkl', 'wb') as f:
        pickle.dump(results, f)
    print("\nResults saved to 'resnet18_results.pkl'")