# Assignment Module 2: Pet Classification

The goal of this assignment is to implement a neural network that classifies images of 37 breeds of cats and dogs from the [Oxford-IIIT-Pet dataset](https://www.robots.ox.ac.uk/~vgg/data/pets/). The assignment is divided into two parts: first, you will be asked to implement from scratch your own neural network for image classification; then, you will fine-tune a pretrained network provided by PyTorch.


## Dataset

The following cells contain the code to download and access the dataset you will be using in this assignment. Note that, although this dataset features each and every image from [Oxford-IIIT-Pet](https://www.robots.ox.ac.uk/~vgg/data/pets/), it uses a different train-val-test split than the original authors.

In [None]:
# The dataset has already been cloned onto the cluster scratch space, so this
# cell only reports where it lives instead of cloning it again.
DATASET_LOCATION = "/scratch.hpc/matteo.preda/ipcv-assignment-2"
print(f"Dataset location: {DATASET_LOCATION}")

In [None]:
# !git clone https://github.com/CVLAB-Unibo/ipcv-assignment-2.git

In [None]:
import os
import math
from pathlib import Path
from PIL import Image
from torch import Tensor
import torch
from torch.utils.data import Dataset
from typing import List, Tuple
import matplotlib.pyplot as plt
import wandb
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import torch.nn.functional as F
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    classification_report,
)
from tqdm.auto import tqdm


# Mini-batch size passed as `batch_size` to the train/val/test DataLoaders.
BATCH_SIZE = 128

In [None]:
# Log in to Weights & Biases.
#
# SECURITY: an API key must never be committed inside a notebook. The key is
# read from the WANDB_API_KEY environment variable instead; when the variable
# is unset, `key=None` lets wandb fall back to its own credential lookup.
# The key that used to be hard-coded in this cell has been exposed and should
# be revoked from the W&B account settings.
#
# The key is passed by keyword so it can never be bound to a different
# positional parameter of `wandb.login`.
API = os.environ.get("WANDB_API_KEY")
wandb.login(key=API)

In [None]:
# Breed names in label order: the position of a name in this tuple is the
# zero-based class index it is mapped from in LABELS_TO_NAME.
_BREED_NAMES = (
    "Abyssinian",
    "american_bulldog",
    "american_pit_bull_terrier",
    "basset_hound",
    "beagle",
    "Bengal",
    "Birman",
    "Bombay",
    "boxer",
    "British_Shorthair",
    "chihuahua",
    "Egyptian_Mau",
    "english_cocker_spaniel",
    "english_setter",
    "german_shorthaired",
    "great_pyrenees",
    "havanese",
    "japanese_chin",
    "keeshond",
    "leonberger",
    "Maine_Coon",
    "miniature_pinscher",
    "newfoundland",
    "Persian",
    "pomeranian",
    "pug",
    "Ragdoll",
    "Russian_Blue",
    "saint_bernard",
    "samoyed",
    "scottish_terrier",
    "shiba_inu",
    "Siamese",
    "Sphynx",
    "staffordshire_bull_terrier",
    "wheaten_terrier",
    "yorkshire_terrier",
)

# Zero-based class index -> breed name.
LABELS_TO_NAME = dict(enumerate(_BREED_NAMES))

# Number of breeds (37).
NUM_CLASSES = len(_BREED_NAMES)


class OxfordPetDataset(Dataset):
    """One split of the Oxford-IIIT-Pet images with zero-based breed labels.

    The dataset directory is expected to contain:

    * ``annotations/<split>.txt`` -- one ``<image name> <1-based label>`` pair
      per line;
    * ``images/<image name>.jpg`` -- the corresponding pictures.

    Args:
        split: Name of the annotation file to read, without the ``.txt``
            extension (the notebook uses ``"train"``, ``"val"``, ``"test"``).
        transform: Optional callable applied to the RGB ``PIL.Image`` before it
            is returned by ``__getitem__``.
        root: Optional path of the dataset directory. Defaults to
            ``DEFAULT_ROOT`` (the copy pre-cloned on the cluster).
    """

    # Location of the pre-cloned dataset on the cluster.
    DEFAULT_ROOT = Path("/scratch.hpc/matteo.preda/ipcv-assignment-2") / "dataset"
    # When the repository is cloned next to the notebook, pass
    # root=Path("ipcv-assignment-2") / "dataset" instead.

    def __init__(self, split: str, transform=None, root: "str | Path | None" = None) -> None:
        super().__init__()

        self.root = self.DEFAULT_ROOT if root is None else Path(root)
        self.split = split
        self.names, self.labels = self._get_names_and_labels()
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of samples listed in the split file."""
        return len(self.labels)

    def __getitem__(self, idx) -> Tuple[Tensor, int]:
        """Load sample ``idx`` and return ``(image, zero-based label)``.

        The image is opened as RGB; when no transform is set it is returned as
        a ``PIL.Image`` rather than a tensor.
        """
        img_path = self.root / "images" / f"{self.names[idx]}.jpg"
        img = Image.open(img_path).convert("RGB")
        label = self.labels[idx]

        if self.transform:
            img = self.transform(img)

        return img, label

    def get_num_classes(self) -> int:
        """Return ``max(label) + 1`` over the labels present in this split."""
        return max(self.labels) + 1

    def _get_names_and_labels(self) -> Tuple[List[str], List[int]]:
        """Parse ``annotations/<split>.txt`` into image names and labels.

        Labels are stored 1-based in the file and shifted to 0-based here.
        Blank lines are skipped and surrounding whitespace is ignored, so a
        trailing newline or trailing spaces do not break parsing.

        Raises:
            ValueError: If a non-blank line does not hold exactly two fields
                or its label is not an integer.
        """
        names = []
        labels = []

        with open(self.root / "annotations" / f"{self.split}.txt") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                name, label = fields
                names.append(name)
                labels.append(int(label) - 1)

        return names, labels

## Data Inspection

In [None]:
# Every split ends with the same two steps: conversion to a tensor followed by
# per-channel normalization (the mean/std values commonly used with ImageNet
# models).
_tensor_and_normalize = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

# Training pipeline: resize, then random crop / flip / rotation / color jitter.
train_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        *_tensor_and_normalize,
    ]
)

# Validation/test pipeline: deterministic resize + center crop, no augmentation.
val_test_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        *_tensor_and_normalize,
    ]
)

train_dataset, val_dataset, test_dataset = (
    OxfordPetDataset(split_name, transform=split_transform)
    for split_name, split_transform in (
        ("train", train_transform),
        ("val", val_test_transform),
        ("test", val_test_transform),
    )
)

# Loader settings shared by all splits: worker processes load batches in
# parallel and pinned host memory is requested for the batches.
_loader_common = {"num_workers": 4, "pin_memory": True}
# The train and val loaders additionally keep their workers alive between
# epochs and prefetch two batches per worker.
_loader_persistent = {"prefetch_factor": 2, "persistent_workers": True}

train_dl = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    **_loader_common,
    **_loader_persistent,
)
val_dl = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    **_loader_common,
    **_loader_persistent,
)
test_dl = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    **_loader_common,
)

In [None]:
def denorm_image(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo a per-channel ``Normalize(mean, std)`` on an image tensor.

    Args:
        tensor: Image of shape (C, H, W) or batch of shape (B, C, H, W).
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations that were divided by.

    Returns:
        A new tensor equal to ``tensor * std + mean``, channel by channel.

    Raises:
        ValueError: If ``tensor`` is neither 3- nor 4-dimensional.
    """
    if tensor.ndim not in (3, 4):
        raise ValueError("Expected tensor of shape (C,H,W) or (B,C,H,W)")

    # A (C, 1, 1) view broadcasts over H and W and, for batched input, over the
    # leading batch dimension as well, so one shape serves both layouts.
    channel_mean = torch.tensor(mean, device=tensor.device).view(-1, 1, 1)
    channel_std = torch.tensor(std, device=tensor.device).view(-1, 1, 1)

    return tensor * channel_std + channel_mean


# Report the size of each split (each message names the split it measures).
print(f"The train dataset contains {len(train_dataset)}")
print(f"The val dataset contains {len(val_dataset)}")
print(f"The test dataset contains {len(test_dataset)}")

# Show a 3x3 grid of randomly chosen training samples.
indexes = torch.randperm(len(train_dataset))[:9]
plt.figure(figsize=(15, 15))
# .tolist() yields plain Python ints, so the dataset is indexed with ints
# rather than 0-d tensors.
for ind, i in enumerate(indexes.tolist()):
    img, label = train_dataset[i]
    plt.subplot(3, 3, ind + 1)
    # Denormalization can land marginally outside [0, 1] through float
    # round-off; clamping keeps imshow from warning about clipped input.
    plt.imshow(denorm_image(img).permute(1, 2, 0).clamp(0, 1))
    plt.title(f"Label= {LABELS_TO_NAME[label]}")
    plt.axis("off")

## Part 1: design your own network

Your goal is to implement a convolutional neural network for image classification and train it from scratch on `OxfordPetDataset`. You should consider yourselves satisfied once you obtain a classification accuracy on the test split of ~60%. You are free to achieve this however you want, except for a few rules you must follow:

- Compile this notebook by displaying the results obtained by the best model you found throughout your experimentation; then show how, by removing some of its components, its performance drops. In other words, do an _ablation study_ to prove that your design choices have a positive impact on the final result.

- Do not instantiate an off-the-shelf PyTorch network. Instead, construct your network as a composition of existing PyTorch layers. In more concrete terms, you can use e.g. `torch.nn.Linear`, but you cannot use e.g. `torchvision.models.alexnet`.

- Show your results and ablations with plots, tables, images, etc. — the clearer, the better.

Don't be too concerned with your model performance: the ~60% is just to give you an idea of when to stop. Keep in mind that a thoroughly justified model with lower accuracy will be rewarded more points than a poorly experimentally validated model with higher accuracy.

## Block definition for the net

In [None]:
class Swish(nn.Module):
    """Swish (SiLU) activation: the input multiplied by its own sigmoid."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate


class SqueezeExcitation(nn.Module):
    """Squeeze-and-Excitation channel attention.

    The feature map is average-pooled to 1x1, passed through a two-layer 1x1
    convolution bottleneck (``in_channels -> reduced_dim -> in_channels``)
    ending in a sigmoid, and the result rescales the input channel-wise.
    """

    def __init__(self, in_channels: int, reduced_dim: int):
        super().__init__()
        gate_layers = [
            nn.AdaptiveAvgPool2d(1),                 # squeeze: global average
            nn.Conv2d(in_channels, reduced_dim, 1),  # bottleneck down
            Swish(),
            nn.Conv2d(reduced_dim, in_channels, 1),  # bottleneck up
            nn.Sigmoid(),                            # per-channel gate in (0, 1)
        ]
        self.se = nn.Sequential(*gate_layers)

    def forward(self, x):
        channel_gate = self.se(x)
        return x * channel_gate


class DropConnect(nn.Module):
    """Stochastic depth (drop connect): randomly zero whole samples.

    During training each sample of the batch is dropped (set to zero) with
    probability ``drop_prob``; surviving samples are divided by
    ``1 - drop_prob``. In eval mode, or when ``drop_prob`` is 0, the input is
    returned unchanged.

    Args:
        drop_prob: Probability of dropping a sample, in ``[0, 1]``.

    Raises:
        ValueError: If ``drop_prob`` lies outside ``[0, 1]``.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0:
            return x

        keep_prob = 1 - self.drop_prob
        if keep_prob == 0:
            # Every sample is dropped; returning zeros directly avoids the
            # division by zero (and the resulting NaNs) of the general path.
            return torch.zeros_like(x)

        # One random draw per sample, broadcast over all remaining dimensions.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        # keep_prob + U[0, 1) floors to 1 with probability keep_prob, else 0.
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor


class MBConvBlock(nn.Module):
    """Mobile inverted bottleneck block (MBConv) with squeeze-and-excitation.

    Pipeline: optional 1x1 expansion -> depthwise convolution -> SE gate ->
    1x1 linear projection. When the block keeps both resolution and channel
    count (``stride == 1`` and ``in_channels == out_channels``) the input is
    added back to the output after drop connect.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the projection.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        expand_ratio: Channel multiplier of the expansion; 1 skips expansion.
        se_ratio: SE bottleneck width as a fraction of ``in_channels``.
        drop_connect_rate: Drop probability applied on the residual branch.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
        expand_ratio: int,
        se_ratio: float = 0.25,
        drop_connect_rate: float = 0.0
    ):
        super().__init__()
        self.stride = stride
        # The skip connection needs matching input/output shapes.
        self.use_residual = in_channels == out_channels and stride == 1

        hidden_channels = in_channels * expand_ratio
        self.expand = expand_ratio != 1

        # 1x1 expansion, only present when it actually widens the tensor.
        if self.expand:
            self.expansion_conv = nn.Sequential(
                nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(hidden_channels),
                Swish(),
            )

        # Depthwise convolution: one filter per channel (groups == channels);
        # "same"-style padding for odd kernel sizes.
        depthwise = nn.Conv2d(
            hidden_channels,
            hidden_channels,
            kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=hidden_channels,
            bias=False,
        )
        self.depthwise_conv = nn.Sequential(
            depthwise,
            nn.BatchNorm2d(hidden_channels),
            Swish(),
        )

        # The SE bottleneck is sized from the block input, never below 1.
        squeezed_channels = max(1, int(in_channels * se_ratio))
        self.se = SqueezeExcitation(hidden_channels, squeezed_channels)

        # Linear 1x1 projection (no activation after the batch norm).
        self.project = nn.Sequential(
            nn.Conv2d(hidden_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

        self.drop_connect = DropConnect(drop_connect_rate)

    def forward(self, x):
        out = self.expansion_conv(x) if self.expand else x
        out = self.depthwise_conv(out)
        out = self.se(out)
        out = self.project(out)

        if not self.use_residual:
            return out

        return self.drop_connect(out) + x

## Net definition and variation

In [None]:
class EfficientNetCustom(nn.Module):
    """EfficientNet-inspired classifier assembled from MBConv blocks.

    Architecture: strided 3x3 stem -> seven MBConv stages -> 1x1 head
    convolution, global average pooling, dropout and a linear classifier.

    Args:
        num_classes: Size of the output logit vector.
        width_coefficient: Multiplier on channel counts (rounded up to a
            multiple of 8).
        depth_coefficient: Multiplier on the number of blocks per stage
            (rounded up).
        dropout_rate: Dropout probability before the final linear layer.
        drop_connect_rate: Upper bound of the drop-connect probability, which
            grows linearly with the block index.
        image_size: Accepted for interface compatibility; not used by this
            implementation.
    """

    def __init__(
        self,
        num_classes: int = 37,
        width_coefficient: float = 1.0,
        depth_coefficient: float = 1.0,
        dropout_rate: float = 0.2,
        drop_connect_rate: float = 0.2,
        image_size: int = 224
    ):
        super().__init__()

        # One row per stage:
        # (expand_ratio, out_channels, num_blocks, stride, kernel_size)
        stage_specs = [
            (1, 16, 1, 1, 3),
            (6, 24, 2, 2, 3),
            (6, 40, 2, 2, 5),
            (6, 80, 3, 2, 3),
            (6, 112, 3, 1, 5),
            (6, 192, 4, 2, 5),
            (6, 320, 1, 1, 3),
        ]

        def scale_width(channels):
            # Scale, then round up to the next multiple of 8.
            return int(math.ceil(channels * width_coefficient / 8) * 8)

        def scale_depth(num_blocks):
            # Scale the repeat count, rounding up.
            return int(math.ceil(num_blocks * depth_coefficient))

        # Stem: strided 3x3 convolution on the RGB input.
        stem_channels = scale_width(32)
        self.stem = nn.Sequential(
            nn.Conv2d(3, stem_channels, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_channels),
            Swish(),
        )

        # MBConv stages.
        total_blocks = sum(scale_depth(spec[2]) for spec in stage_specs)
        mbconv_blocks = []
        in_channels = stem_channels

        for expand_ratio, channels, num_blocks, stride, kernel_size in stage_specs:
            stage_channels = scale_width(channels)

            for position in range(scale_depth(num_blocks)):
                # len(mbconv_blocks) is the global index of the block being
                # built, so the drop rate rises linearly from 0 towards
                # drop_connect_rate across the network.
                block = MBConvBlock(
                    in_channels=in_channels,
                    out_channels=stage_channels,
                    kernel_size=kernel_size,
                    # Only the first block of a stage applies the stage stride.
                    stride=stride if position == 0 else 1,
                    expand_ratio=expand_ratio,
                    se_ratio=0.25,
                    drop_connect_rate=drop_connect_rate * len(mbconv_blocks) / total_blocks,
                )
                mbconv_blocks.append(block)
                in_channels = stage_channels

        self.blocks = nn.Sequential(*mbconv_blocks)

        # Head: widen with a 1x1 convolution, pool, regularize, classify.
        head_channels = scale_width(1280)
        self.head = nn.Sequential(
            nn.Conv2d(in_channels, head_channels, 1, bias=False),
            nn.BatchNorm2d(head_channels),
            Swish(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(dropout_rate),
            nn.Linear(head_channels, num_classes),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialize every Conv2d, BatchNorm2d and Linear submodule.

        Convolutions get Kaiming-normal weights (fan-out, ReLU gain), batch
        norms are reset to weight 1 / bias 0, linear layers get N(0, 0.01)
        weights; any existing conv/linear bias is zeroed.
        """
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue

            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
            else:
                continue

            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        features = self.blocks(self.stem(x))
        return self.head(features)

## Model factory functions

In [None]:
def create_efficientnet_b0(num_classes: int = 37):
    """Build the B0 baseline: width x1.0, depth x1.0, dropout 0.2."""
    settings = {
        "width_coefficient": 1.0,
        "depth_coefficient": 1.0,
        "dropout_rate": 0.2,
        "drop_connect_rate": 0.2,
    }
    return EfficientNetCustom(num_classes=num_classes, **settings)


def create_efficientnet_b1(num_classes: int = 37):
    """Build the B1 variant: width x1.0, depth x1.1, dropout 0.2."""
    settings = {
        "width_coefficient": 1.0,
        "depth_coefficient": 1.1,
        "dropout_rate": 0.2,
        "drop_connect_rate": 0.2,
    }
    return EfficientNetCustom(num_classes=num_classes, **settings)


def create_efficientnet_b2(num_classes: int = 37):
    """Build the B2 variant: width x1.1, depth x1.2, dropout 0.4."""
    settings = {
        "width_coefficient": 1.1,
        "depth_coefficient": 1.2,
        "dropout_rate": 0.4,
        "drop_connect_rate": 0.2,
    }
    return EfficientNetCustom(num_classes=num_classes, **settings)


def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters.

    Parameters with ``requires_grad == False`` are not counted.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

## Model tests

In [None]:
print("Testing EfficientNet Custom Implementation\n")
print("=" * 60)

# Instantiate the three variants.
model_b0 = create_efficientnet_b0(NUM_CLASSES)
model_b1 = create_efficientnet_b1(NUM_CLASSES)
model_b2 = create_efficientnet_b2(NUM_CLASSES)

# Print parameter counts
models_info = {
    'EfficientNet-B0': model_b0,
    'EfficientNet-B1': model_b1,
    'EfficientNet-B2': model_b2
}

for name, model in models_info.items():
    params = count_parameters(model)
    print(f"{name:30s}: {params:,} parameters")

print("\n" + "=" * 60)

# Forward-pass smoke test on random input.
# It runs in eval mode without autograd: a train-mode forward would fold the
# random noise into the BatchNorm running statistics of a model that is trained
# afterwards, and would apply dropout / drop connect to the reported output.
test_input = torch.randn(2, 3, 224, 224)
model_b0.eval()
with torch.no_grad():
    test_output = model_b0(test_input)
model_b0.train()  # restore the default mode the model was created in

print("\nForward pass test:")
print(f"  Input shape:  {test_input.shape}")
print(f"  Output shape: {test_output.shape}")
print(f"  Output range: [{test_output.min().item():.3f}, {test_output.max().item():.3f}]")

## Train and Test function

In [None]:
def _forward_with_loss(model, criterion, images, labels, amp_enabled):
    """Run one forward pass, under CUDA autocast when ``amp_enabled``; return ``(logits, loss)``."""
    if amp_enabled:
        with torch.autocast(device_type="cuda"):
            logits = model(images)
            return logits, criterion(logits, labels)

    logits = model(images)
    return logits, criterion(logits, labels)


def _evaluate(model, dl_val, criterion, device, amp_enabled):
    """Return ``(mean per-batch loss, accuracy)`` of ``model`` on ``dl_val`` in eval mode without gradients."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in dl_val:
            images, labels = images.to(device), labels.to(device)
            logits, loss = _forward_with_loss(model, criterion, images, labels, amp_enabled)

            loss_sum += loss.item()
            correct += (logits.argmax(1) == labels).sum().item()
            total += labels.size(0)

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    return loss_sum / max(len(dl_val), 1), correct / max(total, 1)


def train_efficientnet(
    model: torch.nn.Module,
    dl_train,
    dl_val,
    criterion,
    optimizer,
    scheduler,
    epochs: int,
    name: str,
    device: str = "cpu",
    project: str = "IPCV2-EfficientNet",
    run_name: str = None,
    save_dir: str = "./checkpoints",
    use_amp: bool = True,
    grad_clip: float = 1.0
):
    """Train ``model`` with optional mixed precision and gradient clipping.

    Each epoch runs one pass over ``dl_train``, evaluates on ``dl_val``, steps
    the scheduler (if any), logs metrics to Weights & Biases and stores the
    weights with the best validation accuracy so far in
    ``<save_dir>/<name>/best_model.pt``.

    Args:
        model: Network to train; moved to ``device`` in place.
        dl_train: Iterable of ``(images, labels)`` training batches.
        dl_val: Iterable of ``(images, labels)`` validation batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer over the parameters of ``model``.
        scheduler: Learning-rate scheduler stepped once per epoch, or ``None``.
        epochs: Number of epochs.
        name: Sub-directory of ``save_dir`` that receives the checkpoint.
        device: Device string such as ``"cpu"``, ``"cuda"`` or ``"cuda:0"``.
        project: W&B project name.
        run_name: W&B run name (``None`` lets W&B pick one).
        save_dir: Root directory for checkpoints.
        use_amp: Request mixed precision; only honoured on CUDA devices.
        grad_clip: Max gradient norm; values <= 0 disable clipping.

    Returns:
        The best validation accuracy observed (0.0 when ``epochs`` is 0).
    """
    # Mixed precision is only used on CUDA. Parsing the string through
    # torch.device makes "cuda:0" count as CUDA as well.
    amp_enabled = bool(use_amp) and torch.device(device).type == "cuda"

    wandb.init(
        project=project,
        name=run_name,
        entity="mpreda01-universit-di-bologna",
        config={
            "epochs": epochs,
            "optimizer": optimizer.__class__.__name__,
            "scheduler": scheduler.__class__.__name__ if scheduler else "None",
            "criterion": criterion.__class__.__name__,
            "device": device,
            "model": model.__class__.__name__,
            "use_amp": use_amp,
            "grad_clip": grad_clip
        }
    )

    best_val_acc = 0.0

    # try/finally guarantees the W&B run is closed even if training fails.
    try:
        run_dir = os.path.join(save_dir, name)
        os.makedirs(run_dir, exist_ok=True)
        best_path = os.path.join(run_dir, "best_model.pt")

        model.to(device)
        # With enabled=False the scaler is a pass-through (scale() returns the
        # loss unchanged, step() calls optimizer.step()), so a single code path
        # serves both the AMP and the full-precision case.
        scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)

        global_step = 0

        for e in tqdm(range(epochs), desc="Training"):
            # Training phase
            model.train()
            running_loss = 0.0

            for images, labels in dl_train:
                images, labels = images.to(device), labels.to(device)

                optimizer.zero_grad()
                _, loss = _forward_with_loss(model, criterion, images, labels, amp_enabled)
                scaler.scale(loss).backward()

                if grad_clip > 0:
                    # Gradients must be unscaled before their norm is clipped.
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

                scaler.step(optimizer)
                scaler.update()

                step_loss = loss.item()
                running_loss += step_loss
                global_step += 1

                if global_step % 10 == 0:
                    wandb.log({"train_loss_step": step_loss}, step=global_step)

            avg_train_loss = running_loss / max(len(dl_train), 1)

            # Validation phase
            avg_val_loss, val_accuracy = _evaluate(
                model, dl_val, criterion, device, amp_enabled
            )

            epoch_metrics = {
                "train_loss_epoch": avg_train_loss,
                "val_loss_epoch": avg_val_loss,
                "val_accuracy": val_accuracy,
                "epoch": e + 1,
            }

            # Learning rate scheduling
            if scheduler:
                scheduler.step()
                epoch_metrics["learning_rate"] = scheduler.get_last_lr()[0]

            wandb.log(epoch_metrics, step=global_step)

            # Keep the weights with the best validation accuracy. The first
            # epoch always saves, so a checkpoint exists even when accuracy
            # never rises above 0.
            if e == 0 or val_accuracy > best_val_acc:
                best_val_acc = val_accuracy
                torch.save(model.state_dict(), best_path)
                wandb.run.summary["best_val_accuracy"] = best_val_acc
                print(f"\nSaved best model (epoch {e+1}, acc={best_val_acc:.4f})")
    finally:
        wandb.finish()

    return best_val_acc

def test(
    model: torch.nn.Module,
    dl_test,
    criterion,
    checkpoint_path: str,
    device: str = "cpu",
    project: str = "my_project",
    run_name: str = "test_run",
):
    """
    Evaluate a trained model on the test set and compute performance metrics.

    The checkpoint is loaded before the W&B run is opened, so a missing or
    incompatible file does not leave an empty run behind; once opened, the run
    is always closed, even if evaluation raises.

    Args:
        model: torch.nn.Module whose weights are replaced by the checkpoint
        dl_test: DataLoader for the test set
        criterion: Loss function
        checkpoint_path: Path to a saved model checkpoint (.pt); either a bare
            state dict or a dict holding it under "model_state_dict"
        device: Device to run inference on ('cpu' or 'cuda')
        project: W&B project name
        run_name: W&B run name for logging results

    Returns:
        Dict with keys "test_loss", "accuracy", "precision", "recall", "f1"
        (precision/recall/F1 are class-weighted averages) and
        "confusion_matrix".
    """

    # Load model checkpoint.
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)  # if saved with model.state_dict() only

    model.to(device)
    model.eval()

    # Initialize W&B
    wandb.init(
        project=project,
        name=run_name,
        entity="mpreda01-universit-di-bologna",
        job_type="test",
        config={
            "device": device,
            "model": model.__class__.__name__,
            "checkpoint": checkpoint_path,
        },
    )

    try:
        # Accumulators
        test_loss = 0.0
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for images, labels in dl_test:
                images, labels = images.to(device), labels.to(device)
                output_logits = model(images)
                loss = criterion(output_logits, labels)
                test_loss += loss.item()

                preds = output_logits.argmax(dim=1)
                # .tolist() stores plain Python ints rather than numpy scalars.
                all_preds.extend(preds.cpu().tolist())
                all_labels.extend(labels.cpu().tolist())

        # Average of the per-batch losses
        avg_test_loss = test_loss / len(dl_test)

        # Metrics
        acc = accuracy_score(all_labels, all_preds)
        prec = precision_score(all_labels, all_preds, average="weighted", zero_division=0)
        rec = recall_score(all_labels, all_preds, average="weighted", zero_division=0)
        f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0)
        cm = confusion_matrix(all_labels, all_preds)

        # Print report
        print("\n=== Test Results ===")
        print(f"Loss: {avg_test_loss:.4f}")
        print(f"Accuracy: {acc:.4f}")
        print(f"Precision: {prec:.4f}")
        print(f"Recall: {rec:.4f}")
        print(f"F1-score: {f1:.4f}")
        print("\nClassification Report:")
        print(classification_report(all_labels, all_preds))

        # Log metrics to W&B
        wandb.log(
            {
                "test_loss": avg_test_loss,
                "test_accuracy": acc,
                "test_precision": prec,
                "test_recall": rec,
                "test_f1": f1,
                "confusion_matrix": wandb.plot.confusion_matrix(
                    probs=None,
                    y_true=all_labels,
                    preds=all_preds,
                    title="Confusion Matrix",
                ),
            }
        )
    finally:
        wandb.finish()

    return {
        "test_loss": avg_test_loss,
        "accuracy": acc,
        "precision": prec,
        "recall": rec,
        "f1": f1,
        "confusion_matrix": cm,
    }

In [None]:
EPOCHS = 150
LR = 1e-3
WEIGHT_DECAY = 1e-3
CRITERION = nn.CrossEntropyLoss()
# train_efficientnet defaults to device="cpu"; without an explicit device the
# models would train on the CPU and mixed precision would be silently disabled.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

for model_name, model in models_info.items():
    print(f"Training {model_name}...")

    # A fresh optimizer and cosine schedule (annealed over all epochs) per model.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

    train_efficientnet(
        model=model,
        dl_train=train_dl,
        dl_val=val_dl,
        criterion=CRITERION,
        optimizer=optimizer,
        scheduler=scheduler,
        epochs=EPOCHS,
        name=model_name,
        device=DEVICE,
        project="IPCV2-EfficientNet",
        run_name=f"{model_name}_run",
        use_amp=True,
        grad_clip=1.0
    )

In [None]:
# test() defaults to device="cpu"; pick the GPU explicitly when one is present.
test_device = "cuda" if torch.cuda.is_available() else "cpu"

for model_name, model in models_info.items():
    print(f"\nTesting {model_name}...")
    # Evaluates the weights stored at ./checkpoints/<model_name>/best_model.pt.
    test_metrics = test(
        model,
        test_dl,
        CRITERION,
        checkpoint_path=f"./checkpoints/{model_name}/best_model.pt",
        device=test_device,
        project="IPCV2-EfficientNet",
        run_name=f"{model_name}_test",
    )

    print(f"\nTest Results for {model_name}:")
    print(f"  Accuracy:  {test_metrics['accuracy']:.4f}")
    print(f"  Precision: {test_metrics['precision']:.4f}")
    print(f"  Recall:    {test_metrics['recall']:.4f}")
    print(f"  F1 Score:  {test_metrics['f1']:.4f}")

## Definition of all the experiments to perform the ablation study