# Week 3 — ResNet18 Finetune-All & Light Augmentations (MedMNIST-SSL)

This Colab-friendly notebook is a **starter template** for Week 3 of the `medmnist-ssl` project.

You will:
- Load a **binary** MedMNIST2D dataset (`pneumoniamnist` or `breastmnist`).
- Train **ResNet-18** in two regimes:
  - `head-only` (frozen backbone, train final classifier layer only)
  - `finetune-all` (all layers trainable)
- Add **light, modality-aware data augmentations** on the train split.
- Collect **per-epoch train/val metrics** and simple **learning curves**.

You can later copy the metrics / plots into your repo under
`results/week3/<dataset_key>/<your_name>/`.

In [None]:
%%bash
# Dependency setup. The pip command below is commented out by default, so as
# written this cell installs nothing.
# If you are running on Colab, uncomment the following line.
# pip install -q medmnist torch torchvision scikit-learn tqdm

# Reminder printed every time the cell runs.
echo "If you're on Colab: uncomment the pip install line above and run this cell once."

In [None]:
import os
import json
from typing import Dict, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.transforms as T
import torchvision.models as models

import medmnist
from medmnist import INFO

from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# -----------------------------
# Basic config (edit here)
# -----------------------------
# What to train on and where the metrics / plots are written.
DATASET_KEY = 'pneumoniamnist'  # or 'breastmnist'
RUNS_ROOT = './week3_runs'

# Optimisation settings shared by all runs.
BATCH_SIZE = 128
EPOCHS_HEAD, EPOCHS_ALL = 5, 8  # head-only run / finetune-all runs
LR = 1e-3
USE_IMAGENET_PRETRAINED = True

# Reproducibility: seed NumPy and PyTorch, and ask cuDNN for deterministic
# kernels with its autotuner switched off.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Make sure the output directory exists before any run tries to write to it.
os.makedirs(RUNS_ROOT, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print('Using device:', device)
print('Saving outputs to:', os.path.abspath(RUNS_ROOT))

In [None]:
# -----------------------------
# Dataset + transforms
# -----------------------------
# Look up the dataset class and its label set from the MedMNIST registry.
info = INFO[DATASET_KEY]
DataClass = getattr(medmnist, info['python_class'])
n_classes = len(info['label'])
print('Dataset:', DATASET_KEY, '| num_classes =', n_classes)

# Base normalization (simple 0.5/0.5 for now)
MEAN = [0.5, 0.5, 0.5]
STD = [0.5, 0.5, 0.5]

# NOTE: MedMNIST 2D datasets convert every sample to a PIL image *before*
# calling `transform`, so both pipelines start directly from a PIL image.
# A leading T.ToPILImage() must not be used here: it only accepts tensors or
# ndarrays and raises TypeError when it is handed a PIL image.

# Basic transform (no strong augmentation): upsample to the ResNet input size,
# replicate the single channel to 3 channels, convert to a tensor, normalize.
base_transform = T.Compose([
    T.Resize(224),
    T.Grayscale(num_output_channels=3),
    T.ToTensor(),
    T.Normalize(mean=MEAN, std=STD),
])

# AugA: light geometric augmentation (flip + small rotation), applied on the
# PIL image before tensor conversion; otherwise identical to base_transform.
train_transform_augA = T.Compose([
    T.Resize(224),
    T.Grayscale(num_output_channels=3),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees=10),
    T.ToTensor(),
    T.Normalize(mean=MEAN, std=STD),
])

def make_dataloaders(train_aug: str = 'basic'):
    """Build the train/val/test DataLoaders for the configured dataset.

    Only the train split uses the transform selected by ``train_aug``
    ('basic' -> ``base_transform``, 'augA' -> ``train_transform_augA``);
    the val and test splits always use ``base_transform``. Only the train
    loader shuffles.

    Returns:
        Tuple ``(train_loader, val_loader, test_loader)``.

    Raises:
        ValueError: if ``train_aug`` is neither 'basic' nor 'augA'.
    """
    # Reject unknown names up front, before any dataset is built/downloaded.
    if train_aug not in ('basic', 'augA'):
        raise ValueError(f'Unknown train_aug: {train_aug}')
    chosen_transform = train_transform_augA if train_aug == 'augA' else base_transform

    # All three datasets are created first (in train/val/test order) ...
    train_ds, val_ds, test_ds = (
        DataClass(split=split, transform=transform, download=True)
        for split, transform in (
            ('train', chosen_transform),
            ('val', base_transform),
            ('test', base_transform),
        )
    )

    # ... then wrapped in loaders that differ only in whether they shuffle.
    def _loader(dataset, shuffle):
        return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=2)

    return _loader(train_ds, True), _loader(val_ds, False), _loader(test_ds, False)


In [None]:
# -----------------------------
# ResNet-18 model helper
# -----------------------------
def create_resnet18(num_classes: int = 2,
                    finetune_mode: str = 'head',
                    use_imagenet_pretrained: bool = True) -> nn.Module:
    """Build a ResNet-18 with a fresh ``num_classes``-way linear classifier.

    Args:
        num_classes: output width of the replacement ``fc`` layer.
        finetune_mode: 'head' freezes every parameter whose name does not
            start with ``fc.``; 'all' marks every parameter as trainable.
        use_imagenet_pretrained: load ImageNet weights when True, otherwise
            start from torchvision's default initialisation.

    Returns:
        The model, moved to the module-level ``device``.

    Raises:
        ValueError: if ``finetune_mode`` is neither 'head' nor 'all'.
    """
    try:
        # Newer torchvision API: weights are chosen through an enum.
        if use_imagenet_pretrained:
            weights = models.ResNet18_Weights.IMAGENET1K_V1
        else:
            weights = None
        model = models.resnet18(weights=weights)
    except AttributeError:
        # Older torchvision without the weights enum: boolean flag instead.
        model = models.resnet18(pretrained=use_imagenet_pretrained)

    # Swap the classifier for a new layer with the same input width.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    if finetune_mode not in ('head', 'all'):
        raise ValueError(f'Unknown finetune_mode: {finetune_mode}')

    train_everything = finetune_mode == 'all'
    for param_name, param in model.named_parameters():
        if train_everything:
            param.requires_grad = True
        elif not param_name.startswith('fc.'):
            # Head-only: everything outside the new classifier is frozen.
            param.requires_grad = False

    return model.to(device)


In [None]:
# -----------------------------
# Training & evaluation helpers
# -----------------------------
def train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader`` with the model in train mode.

    Returns:
        ``(mean_loss, accuracy)`` accumulated over every sample seen, where
        the predicted class is the index of the largest output.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_images, batch_targets in loader:
        batch_images = batch_images.to(device)
        # MedMNIST labels come as shape [B, 1]; flatten to [B] integer classes.
        batch_targets = batch_targets.view(-1).long().to(device)
        batch_size = batch_images.size(0)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_targets)
        batch_loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so the final division gives a
        # per-sample mean (presumes `criterion` returns a batch-mean loss).
        loss_sum += batch_loss.item() * batch_size
        predicted = logits.max(1).indices
        n_correct += (predicted == batch_targets).sum().item()
        n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen


def evaluate(model, loader, criterion, compute_auroc: bool = True):
    """Score ``model`` on ``loader`` in eval mode without tracking gradients.

    Args:
        compute_auroc: when True, also compute AUROC using the softmax
            probability of class index 1 as the score.

    Returns:
        ``(mean_loss, accuracy, auroc)``. ``auroc`` is NaN when
        ``compute_auroc`` is False or when ``roc_auc_score`` raises
        ``ValueError``.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    prob_chunks, target_chunks = [], []

    with torch.inference_mode():
        for batch_images, batch_targets in loader:
            batch_images = batch_images.to(device)
            batch_targets = batch_targets.view(-1).long().to(device)
            batch_size = batch_images.size(0)

            logits = model(batch_images)
            batch_loss = criterion(logits, batch_targets)

            loss_sum += batch_loss.item() * batch_size
            predicted = logits.max(1).indices
            n_correct += (predicted == batch_targets).sum().item()
            n_seen += batch_size

            if not compute_auroc:
                continue
            positive_prob = torch.softmax(logits, dim=1)[:, 1]  # positive class prob
            prob_chunks.append(positive_prob.cpu().numpy())
            target_chunks.append(batch_targets.cpu().numpy())

    mean_loss = loss_sum / n_seen
    accuracy = n_correct / n_seen

    if not compute_auroc:
        return mean_loss, accuracy, float('nan')

    scores = np.concatenate(prob_chunks)
    labels = np.concatenate(target_chunks)
    try:
        auroc = roc_auc_score(labels, scores)
    except ValueError:
        # If only one class present in y_true
        auroc = float('nan')
    return mean_loss, accuracy, auroc


def run_experiment(run_name: str,
                   finetune_mode: str,
                   train_aug: str,
                   num_epochs: int,
                   lr: float = 1e-3) -> Dict:
    """Train one ResNet-18 configuration, test its best epoch, and save metrics.

    Builds the loaders and the model, trains for ``num_epochs`` epochs with
    Adam over the trainable parameters only, and records per-epoch train/val
    loss and accuracy. The weights of the epoch with the highest validation
    accuracy (the earliest one on ties) are restored before the test split is
    evaluated. The result is written to ``metrics_<run_name>.json`` under
    ``RUNS_ROOT``.

    Args:
        run_name: identifier used in logs and in the metrics file name.
        finetune_mode: forwarded to ``create_resnet18`` ('head' or 'all').
        train_aug: forwarded to ``make_dataloaders`` ('basic' or 'augA').
        num_epochs: number of training epochs.
        lr: Adam learning rate.

    Returns:
        The metrics dict that was saved: run configuration, test
        loss/accuracy/AUROC, and the per-epoch ``history``.
    """
    print(f"\n=== Run: {run_name} | finetune={finetune_mode} | aug={train_aug} ===")

    train_loader, val_loader, test_loader = make_dataloaders(train_aug=train_aug)
    model = create_resnet18(num_classes=n_classes,
                            finetune_mode=finetune_mode,
                            use_imagenet_pretrained=USE_IMAGENET_PRETRAINED)

    criterion = nn.CrossEntropyLoss()
    # Frozen parameters are excluded so the optimizer only tracks what trains.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)

    history = {
        'epoch': [],
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
    }

    # Start below any reachable accuracy so the first epoch always yields a
    # snapshot, even if its validation accuracy is 0.0.
    best_val_acc = float('-inf')
    best_state = None

    for epoch in range(1, num_epochs + 1):
        print(f"Epoch {epoch}/{num_epochs}")
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc, _ = evaluate(model, val_loader, criterion, compute_auroc=False)

        history['epoch'].append(epoch)
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f"  train_loss={train_loss:.4f} | train_acc={train_acc:.4f} | "
              f"val_loss={val_loss:.4f} | val_acc={val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns references to the live parameter/buffer
            # tensors, which later epochs keep updating in place. Clone them,
            # otherwise the "best" snapshot silently tracks the latest weights
            # and restoring it below would be a no-op.
            best_state = {
                key: tensor.detach().clone()
                for key, tensor in model.state_dict().items()
            }

    # Evaluate best model on test set (best_state stays None if num_epochs < 1)
    if best_state is not None:
        model.load_state_dict(best_state)

    test_loss, test_acc, test_auroc = evaluate(model, test_loader, criterion, compute_auroc=True)
    print(f"Test: loss={test_loss:.4f} | acc={test_acc:.4f} | auroc={test_auroc:.4f}")

    # Save metrics to JSON
    out_metrics = {
        'run_name': run_name,
        'dataset_key': DATASET_KEY,
        'finetune_mode': finetune_mode,
        'train_aug': train_aug,
        'num_epochs': num_epochs,
        'test_loss': test_loss,
        'test_acc': test_acc,
        'test_auroc': test_auroc,
        'history': history,
    }
    metrics_path = os.path.join(RUNS_ROOT, f'metrics_{run_name}.json')
    with open(metrics_path, 'w') as f:
        json.dump(out_metrics, f, indent=2)
    print('Saved metrics to:', metrics_path)

    return out_metrics


In [None]:
# -----------------------------
# Run Week 3 experiments
# -----------------------------
all_runs: List[Dict] = []


def _launch(tag: str, mode: str, aug: str, epochs: int) -> Dict:
    """Run one configuration at the shared LR and record it in ``all_runs``."""
    result = run_experiment(
        run_name=f'{DATASET_KEY}_resnet18_{tag}',
        finetune_mode=mode,
        train_aug=aug,
        num_epochs=epochs,
        lr=LR,
    )
    all_runs.append(result)
    return result


# 1) Head-only, basic transforms (Week 2 baseline re-run, shorter epochs if needed)
run1 = _launch('head_basic', 'head', 'basic', EPOCHS_HEAD)

# 2) Finetune-all, basic transforms
run2 = _launch('all_basic', 'all', 'basic', EPOCHS_ALL)

# 3) Finetune-all, AugA (flip + rotation)
run3 = _launch('all_augA', 'all', 'augA', EPOCHS_ALL)

# One line per run: name, test accuracy, test AUROC.
print('\nSummary:')
for r in all_runs:
    print(f"{r['run_name']} | acc={r['test_acc']:.4f} | auroc={r['test_auroc']:.4f}")

In [None]:
# -----------------------------
# Plot learning curves
# -----------------------------
def plot_learning_curves(run_dict: Dict, save_name: str = None):
    """Draw the loss and accuracy learning curves of one run.

    Two separate figures are shown, each plotting the train and val series
    from ``run_dict['history']`` against the epoch index. When ``save_name``
    is given, the loss figure is saved under ``RUNS_ROOT`` with that name and
    the accuracy figure with ``_acc`` inserted before the extension.
    """
    history = run_dict['history']
    epoch_axis = history['epoch']

    # (history key suffix, y-axis label, title word, file-name suffix)
    panels = (
        ('loss', 'Loss', 'loss', ''),
        ('acc', 'Accuracy', 'accuracy', '_acc'),
    )

    for metric, y_label, title_word, file_suffix in panels:
        plt.figure(figsize=(6, 4))
        for split in ('train', 'val'):
            series_key = split + '_' + metric
            plt.plot(epoch_axis, history[series_key], label=series_key)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.title(run_dict['run_name'] + ' — ' + title_word)
        plt.legend()
        plt.tight_layout()
        if save_name is not None:
            # splitext guarantees stem + ext == save_name, so the empty
            # suffix used for the loss figure reproduces save_name itself.
            stem, ext = os.path.splitext(save_name)
            out_path = os.path.join(RUNS_ROOT, stem + file_suffix + ext)
            plt.savefig(out_path, dpi=150)
            print('Saved:', out_path)
        plt.show()


# Plot and save the curves of every run; file names drop the dataset prefix.
for r in all_runs:
    short_name = r['run_name'].replace(f'{DATASET_KEY}_', '')
    plot_learning_curves(r, save_name='curves_{}.png'.format(short_name))

## Next Steps

- Move the saved `metrics_*.json` and `curves_*.png` files into your repo under:
  `results/week3/<dataset_key>/<your_name>/`.
- Use these metrics and curves to fill in **Week 3 tables and plots** in your `README_week3.md`.
- Add **dataset-specific observations**:
  - For **PneumoniaMNIST**: discuss chest X-ray patterns, false positives/negatives, and which augmentation helps.
  - For **BreastMNIST**: discuss ultrasound noise, subtle lesions, and over-/underfitting patterns.

This notebook is only a **starter** – you can adjust epochs, learning rates, and augmentation strength
as long as you clearly document what you changed in your README.