# ResNet-50 Baseline
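Use this notebook to fine-tune an ImageNet-pretrained ResNet-50 classifier on the ISIC 2018 Task 3 skin-lesion dataset (7 classes), then evaluate it on the validation and test splits.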

In [30]:
import os
import time
import copy
from pathlib import Path
from typing import Tuple

import pandas as pd

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import ResNet50_Weights
from sklearn.metrics import classification_report, confusion_matrix

import ssl

# SECURITY: HTTPS certificate verification stays enabled by default. The pretrained
# weights are downloaded and unpickled, so fetching them over an unverified connection
# is unsafe. If the local certificate store is broken, fix it if possible; as a last
# resort, set ALLOW_INSECURE_SSL=1 before starting the kernel to opt in to the
# process-wide override below.
if os.environ.get("ALLOW_INSECURE_SSL", "").strip().lower() in {"1", "true", "yes"}:
    print("Warning: HTTPS certificate verification is disabled (ALLOW_INSECURE_SSL is set).")
    ssl._create_default_https_context = ssl._create_unverified_context

In [31]:
# Configuration
# The dataset root defaults to the original local path. Set the LESION_DATA_ROOT
# environment variable to point the notebook at another copy of the dataset without
# editing this cell.
DEFAULT_DATA_ROOT = "/Users/enricotazzer/Desktop/multi-task-learning-for-classification-and-segmentation-of-skin-lesions/dataset/classification"
data_root = Path(os.environ.get("LESION_DATA_ROOT", DEFAULT_DATA_ROOT)).expanduser()
train_dir = data_root / "train"
val_dir = data_root / "val"
test_dir = data_root / "test"

train_img_dir = train_dir / "input"
val_img_dir = val_dir / "input"
test_img_dir = test_dir / "input"

# Train and validation images are mandatory; the test split is optional and is
# checked where it is used.
required_paths = [train_img_dir, val_img_dir]
for path in required_paths:
    if not path.exists():
        raise FileNotFoundError(f"Missing required directory: {path}")

train_csv = train_dir / "ground_truth" / "ISIC2018_Task3_Training_GroundTruth.csv"
val_csv = val_dir / "ground_truth" / "ISIC2018_Task3_Validation_GroundTruth.csv"
test_csv = test_dir / "ground_truth" / "ISIC2018_Task3_Test_GroundTruth.csv"

for csv_path in [train_csv, val_csv]:
    if not csv_path.exists():
        raise FileNotFoundError(f"Missing ground-truth CSV: {csv_path}")

df_train = pd.read_csv(train_csv)
# Fail with a clear message instead of a KeyError later on if the identifier column is absent.
if "image" not in df_train.columns:
    raise ValueError(f"Column 'image' is missing from the training annotations: {train_csv}")
# Every column other than the image identifier is treated as a class label.
label_columns = [col for col in df_train.columns if col != "image"]

def prepare_split(df: pd.DataFrame, split_name: str, columns=None) -> pd.DataFrame:
    """Return a copy of ``df`` restricted to ``image`` plus the label columns, in order.

    Args:
        df: Annotation table for one split.
        split_name: Human-readable split name, used only in the error message.
        columns: Label columns to keep, in the desired order. When omitted, the
            module-level ``label_columns`` (derived from the training CSV) is used.

    Returns:
        A new DataFrame whose columns are ``["image", *columns]``.

    Raises:
        ValueError: If ``image`` or any label column is missing from ``df``.
    """
    expected = list(label_columns if columns is None else columns)
    # The identifier column is validated with the labels so that a malformed CSV
    # produces the same descriptive error rather than a bare KeyError.
    required = ["image"] + expected
    missing = set(required) - set(df.columns)
    if missing:
        raise ValueError(f"Columns {missing} are missing from the {split_name} annotations")
    return df[required].copy()

# Align the validation (and optional test) annotations with the training label order.
df_val = prepare_split(pd.read_csv(val_csv), "validation")
df_test = prepare_split(pd.read_csv(test_csv), "test") if test_csv.exists() else None

class_names = label_columns
num_classes = len(class_names)

# Show one annotation row as a sanity check of the label encoding.
example_row = df_train.iloc[0]
example_vector = example_row[label_columns].to_numpy(dtype=np.float32)
print(f"Detected {num_classes} classes: {class_names}")
print(f"Example: {example_row['image']} -> {example_vector}")

# Hyper-parameters
batch_size = 32
num_epochs = 25
learning_rate = 1e-3
weight_decay = 1e-4
label_smoothing = 0.1
# os.cpu_count() may return None when the CPU count cannot be determined; fall back
# to 2 so the integer division below never sees None.
num_workers = max(2, (os.cpu_count() or 2) // 2)

# Device preference: CUDA, then Apple MPS, then CPU. The MPS backend module is looked
# up defensively because PyTorch builds without it do not expose `torch.backends.mps`.
mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif mps_backend is not None and mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Training on {device}")

Detected 7 classes: ['MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC']
Example: ISIC_0024306 -> [0. 1. 0. 0. 0. 0. 0.]
Training on mps


In [32]:
# Data pipeline
# Pretrained weight set: used below to derive the normalization statistics and later
# passed to `models.resnet50` when the model is created.
weights = ResNet50_Weights.DEFAULT


def resolve_normalization_params(weights_enum):
    """Look up the (mean, std) normalization statistics attached to a weights entry.

    Lookup order:
      1. ``mean``/``std`` attributes on the object returned by ``weights_enum.transforms()``;
      2. ``mean``/``std`` attributes on any element of that object's ``transforms`` sequence;
      3. ``"mean"``/``"std"`` entries of a dict-valued ``weights_enum.meta``;
      4. the hard-coded ImageNet statistics (a warning is printed).

    Returns:
        A pair ``(mean, std)`` of float tuples.
    """

    def _as_floats(values):
        return tuple(float(value) for value in values)

    preset = weights_enum.transforms()

    def _candidates():
        # Lazy on purpose: nested transforms are only inspected when the preset
        # itself does not carry both statistics.
        yield preset
        yield from getattr(preset, "transforms", [])

    for candidate in _candidates():
        found_mean = getattr(candidate, "mean", None)
        found_std = getattr(candidate, "std", None)
        if found_mean is None or found_std is None:
            continue
        return _as_floats(found_mean), _as_floats(found_std)

    meta = getattr(weights_enum, "meta", None)
    if isinstance(meta, dict):
        meta_mean, meta_std = meta.get("mean"), meta.get("std")
        if not (meta_mean is None or meta_std is None):
            return _as_floats(meta_mean), _as_floats(meta_std)

    print("Warning: falling back to ImageNet mean/std defaults.")
    return (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)


# Normalization statistics matching the selected pretrained weights.
mean, std = resolve_normalization_params(weights)

# Training pipeline: random crop/flip/colour-jitter augmentation, followed by tensor
# conversion and normalization.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Evaluation pipeline: no random augmentation, only resize + centre crop, with the
# same tensor conversion and normalization as training.
eval_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

class LesionDataset(Dataset):
    """Dataset pairing image files on disk with per-image label vectors.

    Every column of ``annotations`` other than ``"image"`` is treated as a label
    column. Construction fails fast when the image directory, or any image listed in
    the annotations, does not exist.
    """

    def __init__(
        self,
        image_dir: Path,
        annotations: pd.DataFrame,
        transform=None,
        image_ext: str = ".jpg",
    ):
        """Index the images and convert the label columns to a float32 tensor.

        Args:
            image_dir: Directory containing the image files.
            annotations: Table with an ``image`` column plus one column per label.
            transform: Optional callable applied to each RGB image in ``__getitem__``.
            image_ext: Extension appended to image ids that carry no suffix.

        Raises:
            FileNotFoundError: If ``image_dir`` or any listed image is missing.
            RuntimeError: If the number of images and targets disagree.
        """
        self.image_dir = Path(image_dir)
        if not self.image_dir.exists():
            raise FileNotFoundError(f"Image directory does not exist: {self.image_dir}")

        self.transform = transform
        self.image_ext = image_ext
        self.label_columns = [name for name in annotations.columns if name != "image"]
        self.image_ids = annotations["image"].tolist()
        self.targets = torch.as_tensor(
            annotations[self.label_columns].to_numpy(dtype=np.float32),
            dtype=torch.float32,
        )

        # Resolve every id once and record whether its file exists (one check per path).
        candidates = [self._resolve_image_path(image_id) for image_id in self.image_ids]
        found_flags = [candidate.exists() for candidate in candidates]
        self.samples = [path for path, found in zip(candidates, found_flags) if found]
        missing_files = [path.name for path, found in zip(candidates, found_flags) if not found]

        if missing_files:
            raise FileNotFoundError(
                f"{len(missing_files)} images listed in annotations were not found in {self.image_dir}. "
                f"First few missing files: {missing_files[:5]}"
            )
        if len(self.samples) != len(self.targets):
            raise RuntimeError("Mismatch between images and targets after validation.")

    def _resolve_image_path(self, image_id: str) -> Path:
        """Map an image id to a path, appending ``image_ext`` when the id has no suffix."""
        if Path(image_id).suffix:
            return self.image_dir / image_id
        return self.image_dir / f"{image_id}{self.image_ext}"

    def __len__(self) -> int:
        """Number of indexed images."""
        return len(self.samples)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load image ``index`` as RGB, apply the transform if any, and return it with its target."""
        image_path = self.samples[index]
        target = self.targets[index]
        with Image.open(image_path) as handle:
            image = handle.convert("RGB")
        if self.transform is None:
            return image, target
        return self.transform(image), target

train_dataset = LesionDataset(train_img_dir, df_train, transform=train_transforms)
val_dataset = LesionDataset(val_img_dir, df_val, transform=eval_transforms)
# The test split is optional: it is only built when both its CSV and image folder exist.
test_dataset = (
    LesionDataset(test_img_dir, df_test, transform=eval_transforms)
    if df_test is not None and test_img_dir.exists()
    else None
)

# Pinned host memory only speeds up host-to-GPU copies on CUDA. It was previously also
# requested on MPS, where it is unsupported, so it is now limited to CUDA.
pin_memory = device.type == "cuda"


def _make_loader(dataset, shuffle):
    """Build a DataLoader for ``dataset`` using the notebook-wide batch and worker settings."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )


# Only the training loader shuffles; evaluation order stays deterministic.
train_loader = _make_loader(train_dataset, shuffle=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False) if test_dataset is not None else None

print(f"Train samples: {len(train_dataset)} | Val samples: {len(val_dataset)}")
if test_dataset is not None:
    print(f"Test samples: {len(test_dataset)}")

Train samples: 10015 | Val samples: 193
Test samples: 1512


In [34]:
# Model setup
model = models.resnet50(weights=weights)

# Freeze the whole backbone, then re-enable gradients for the last residual stage only.
model.requires_grad_(False)
model.layer4.requires_grad_(True)

# Swap the classification head for a dropout + linear layer sized to this task. It is
# created after the freeze, so its parameters are trainable by default.
head_in_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(p=0.2),
    nn.Linear(head_in_features, num_classes),
)
model.to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

# Only parameters that still require gradients (layer4 + new head) are optimised.
trainable = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=learning_rate, weight_decay=weight_decay)
# Cosine decay from `learning_rate` down to 10% of it over the full training run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=learning_rate * 0.1)

trainable_params = sum(param.numel() for param in trainable)
print(f"Trainable parameters: {trainable_params:,}")

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /Users/enricotazzer/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


URLError: <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1018)>

In [None]:
# Training utilities
def run_epoch(model, dataloader, criterion, optimizer=None, device=None):
    """Run one pass over ``dataloader`` in training or evaluation mode.

    Args:
        model: Network mapping a batch of inputs to per-class logits.
        dataloader: Iterable yielding ``(inputs, targets)`` batches; ``targets`` holds
            one score vector per sample and is reduced to class indices with argmax.
        criterion: Loss called as ``criterion(logits, class_indices)``.
        optimizer: When given, the pass trains the model (backward pass, gradient
            clipping, optimizer step). When ``None``, the pass only evaluates.
        device: Device the batches are moved to. Defaults to the device of the model's
            first parameter (CPU for a parameterless model), so the function does not
            depend on a module-level ``device`` variable.

    Returns:
        ``(mean_loss, accuracy, elapsed_seconds)``; loss and accuracy are averaged
        over samples and are ``0.0`` for an empty dataloader.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    is_train = optimizer is not None
    model.train(is_train)
    running_loss = 0.0
    running_corrects = 0
    total_samples = 0
    start = time.time()

    for inputs, targets in dataloader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        target_indices = torch.argmax(targets, dim=1)

        if is_train:
            optimizer.zero_grad()

        # Gradients are tracked only when training.
        with torch.set_grad_enabled(is_train):
            outputs = model(inputs)
            loss = criterion(outputs, target_indices)
            if is_train:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
                optimizer.step()

        preds = torch.argmax(outputs, dim=1)
        # Named `n_batch` so the module-level `batch_size` setting is not shadowed.
        n_batch = inputs.size(0)
        # The criterion returns a batch mean; re-weight by batch size so the epoch
        # average is per-sample even when the last batch is smaller.
        running_loss += loss.item() * n_batch
        running_corrects += torch.sum(preds == target_indices).item()
        total_samples += n_batch

    epoch_loss = running_loss / max(total_samples, 1)
    epoch_acc = running_corrects / max(total_samples, 1)
    elapsed = time.time() - start
    return epoch_loss, epoch_acc, elapsed


@torch.no_grad()
def evaluate(model, dataloader, criterion, device=None):
    """Evaluate ``model`` on ``dataloader`` and collect per-sample predictions.

    Args:
        model: Network mapping a batch of inputs to per-class logits.
        dataloader: Iterable yielding ``(inputs, targets)`` batches; ``targets`` holds
            one score vector per sample and is reduced to class indices with argmax.
        criterion: Loss called as ``criterion(logits, class_indices)``.
        device: Device the batches are moved to. Defaults to the device of the model's
            first parameter (CPU for a parameterless model), so the function does not
            depend on a module-level ``device`` variable.

    Returns:
        ``(mean_loss, accuracy, predictions, targets)`` where the last two are NumPy
        arrays of class indices (empty arrays for an empty dataloader).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    running_loss = 0.0
    running_corrects = 0
    total_samples = 0

    pred_chunks = []
    target_chunks = []

    for inputs, targets in dataloader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        target_indices = torch.argmax(targets, dim=1)

        outputs = model(inputs)
        loss = criterion(outputs, target_indices)
        preds = torch.argmax(outputs, dim=1)

        # Named `n_batch` so the module-level `batch_size` setting is not shadowed.
        n_batch = inputs.size(0)
        # Re-weight the batch-mean loss by batch size for a per-sample epoch average.
        running_loss += loss.item() * n_batch
        running_corrects += torch.sum(preds == target_indices).item()
        total_samples += n_batch

        pred_chunks.append(preds.detach().cpu())
        target_chunks.append(target_indices.detach().cpu())

    epoch_loss = running_loss / max(total_samples, 1)
    epoch_acc = running_corrects / max(total_samples, 1)

    if pred_chunks:
        all_preds = torch.cat(pred_chunks).numpy()
        all_targets = torch.cat(target_chunks).numpy()
    else:
        all_preds = np.array([])
        all_targets = np.array([])

    return epoch_loss, epoch_acc, all_preds, all_targets

In [None]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler=None, num_epochs=10, checkpoint_dir="checkpoints"):
    """Train ``model`` and keep the weights with the best validation accuracy.

    Each epoch runs one training pass (``run_epoch``) and one validation pass
    (``evaluate``), then steps the scheduler if one is given. Whenever validation
    accuracy improves, the weights are snapshotted in memory and written to
    ``<checkpoint_dir>/resnet50_best.pt`` with the accuracy, the 1-based epoch and the
    module-level ``class_names``.

    Args:
        model: Network to train (modified in place).
        train_loader: Loader for the training split.
        val_loader: Loader for the validation split.
        criterion: Loss function passed to ``run_epoch`` / ``evaluate``.
        optimizer: Optimizer used for the training passes.
        scheduler: Optional LR scheduler, stepped once per epoch.
        num_epochs: Number of epochs to run.
        checkpoint_dir: Directory for the best checkpoint (created if needed).

    Returns:
        ``(model, history)``: the model with the best weights loaded, and a dict of
        per-epoch ``train_loss`` / ``train_acc`` / ``val_loss`` / ``val_acc`` lists.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_epoch = None  # None until the first validation pass has been recorded
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    best_ckpt_path = checkpoint_dir / "resnet50_best.pt"

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        print("-" * 20)

        train_loss, train_acc, train_time = run_epoch(model, train_loader, criterion, optimizer)
        # The per-sample predictions returned by evaluate() are not needed here.
        val_loss, val_acc, _, _ = evaluate(model, val_loader, criterion)
        if scheduler is not None:
            scheduler.step()

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"train loss: {train_loss:.4f} | train acc: {train_acc:.4f} | time: {train_time:.1f}s")
        print(f"val   loss: {val_loss:.4f} | val   acc: {val_acc:.4f}")

        # The first epoch always counts as an improvement, so trained weights are
        # checkpointed even if validation accuracy never rises above 0.0; without this
        # the function would restore the untrained weights.
        if best_epoch is None or val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch + 1
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save({
                "model_state_dict": best_model_wts,
                "val_acc": best_acc,
                "epoch": best_epoch,
                "class_names": class_names,
            }, best_ckpt_path)
            print(f"\n✅ Saved new best checkpoint to {best_ckpt_path}\n")

    print(f"Best val acc: {best_acc:.4f}")
    model.load_state_dict(best_model_wts)
    return model, history


In [None]:
# Train the model
if __name__ == "__main__":
    if len(train_dataset) == 0 or len(val_dataset) == 0:
        raise RuntimeError("Training/validation datasets are empty. Check the data directory structure.")

    # Timed explicitly: a line magic such as `%time` only receives the rest of its own
    # line, so it cannot wrap this multi-line call.
    training_start = time.perf_counter()
    trained_model, history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=num_epochs,
        checkpoint_dir="artifacts",
    )
    print(f"Wall time: {time.perf_counter() - training_start:.1f}s")

In [None]:
# Learning curves
# Plots the per-epoch training and validation curves recorded in `history`: loss on
# the left panel, accuracy on the right.
if 'history' not in locals():
    print("Run the training cell first to generate history.")
else:
    import matplotlib.pyplot as plt

    epochs = range(1, len(history['train_loss']) + 1)
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    # (train key, val key, panel title, y-axis label) for each panel, left to right.
    panels = (
        ('train_loss', 'val_loss', 'Loss', 'Cross-Entropy'),
        ('train_acc', 'val_acc', 'Accuracy', 'Accuracy'),
    )
    for ax, (train_key, val_key, title, ylabel) in zip(axes, panels):
        ax.plot(epochs, history[train_key], label='Train')
        ax.plot(epochs, history[val_key], label='Val')
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
    plt.tight_layout()

In [None]:
# Validation metrics
# Re-evaluates the trained model on the validation split, then prints a per-class
# report and draws a confusion matrix.
if 'trained_model' in locals():
    val_loss, val_acc, val_preds, val_targets = evaluate(trained_model, val_loader, criterion)
    print(f"Validation loss: {val_loss:.4f}")
    print(f"Validation acc : {val_acc:.4f}")
    if val_targets.size > 0:
        # Pass the full label set explicitly. Without it scikit-learn infers the labels
        # from the data, so a class absent from both targets and predictions makes
        # `target_names` mismatch and shifts the confusion-matrix rows/columns away
        # from the tick labels.
        label_ids = list(range(len(class_names)))
        print(classification_report(
            val_targets,
            val_preds,
            labels=label_ids,
            target_names=class_names,
            digits=4,
            zero_division=0,  # classes never predicted score 0 without a warning
        ))
        try:
            import matplotlib.pyplot as plt
            import seaborn as sns
        except ImportError:
            print("Install seaborn and matplotlib to visualize the confusion matrix.")
        else:
            cm = confusion_matrix(val_targets, val_preds, labels=label_ids)
            fig, ax = plt.subplots(figsize=(6, 6))
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names, ax=ax)
            ax.set_xlabel('Predicted')
            ax.set_ylabel('True')
            ax.set_title('Validation Confusion Matrix')
            plt.tight_layout()
    else:
        print("No validation predictions captured. This can happen if the dataset is empty.")
else:
    print("Train the model before running this cell.")

In [None]:
# Test set evaluation
# Runs only when a test loader was built and training has produced `trained_model`.
if test_loader is None or 'trained_model' not in locals():
    print("No test set detected or model has not been trained yet.")
else:
    test_loss, test_acc, test_preds, test_targets = evaluate(trained_model, test_loader, criterion)
    print(f"Test loss: {test_loss:.4f}")
    print(f"Test acc : {test_acc:.4f}")