# ResNet18 Transfer Learning on PCam

In this notebook I explore transfer learning with ResNet18 on the PatchCamelyon (PCam) dataset.

Goals:
- Use the same PCam dataloaders and preprocessing as for the SmallCNN.
- Compare different transfer learning modes:
  - **frozen**: only the final classification head is trainable
  - **partial**: last ResNet block + head are trainable
- Evaluate models using **AUROC** and **AUPRC** on the validation set.
- Prepare results for later comparison with the SmallCNN notebook.


## 1) Setup: imports, project root, device, seed

In [41]:
import os
import sys
import random
from pathlib import Path

import numpy as np
import torch

# -------------------------------------------------------------------
# Make sure we can import from the project src/ folder
# -------------------------------------------------------------------
ROOT = Path().resolve()
if not (ROOT / "src").exists():
    # If the notebook is inside notebooks/, go one level up
    ROOT = ROOT.parent
    os.chdir(ROOT)

if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

print("Project root:", ROOT)


# -------------------------------------------------------------------
# Reproducibility helpers
# -------------------------------------------------------------------
def set_seed(seed: int = 42) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_device() -> torch.device:
    """Select best available device: CUDA, MPS (Apple), or CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


set_seed(42)
DEVICE = get_device()
print("Using device:", DEVICE)


Project root: /Users/ellenaspiess/Documents/PCam_Project/PCam
Using device: cpu


## 2) Load Data

In [42]:
from src.datasets.dataloaders import get_pcam_dataloaders

# You can adjust batch_size and crop size if needed
BATCH_SIZE = 64
CENTER_CROP = 64

loaders = get_pcam_dataloaders(
    data_root="data/raw",
    batch_size=BATCH_SIZE,
    center_crop_size=CENTER_CROP,
    num_workers=0,  # 0 is safer on macOS / some environments
)

print("Number of batches:")
for split in ["train", "val", "test"]:
    print(f"  {split}: {len(loaders[split])}")


Number of batches:
  train: 4096
  val: 512
  test: 512


## 3) Sanity Check

In [43]:
images, labels = next(iter(loaders["train"]))
print("Images shape:", images.shape)   # [B, 3, 64, 64]
print("Labels shape:", labels.shape)
print("First 10 labels:", labels[:10])

unique, counts = torch.unique(labels, return_counts=True)
print("Label distribution in this batch:", dict(zip(unique.tolist(), counts.tolist())))


Images shape: torch.Size([64, 3, 64, 64])
Labels shape: torch.Size([64])
First 10 labels: tensor([0, 1, 1, 0, 0, 1, 1, 0, 1, 0])
Label distribution in this batch: {0: 26, 1: 38}


## 4) Import ResNet Model

In [None]:
from src.models.resnet_pcam_gpu import ResNetPCam, ResNetConfig

def count_trainable_params(model: torch.nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

for mode in ["frozen", "partial", "full"]:
    cfg = ResNetConfig(tl_mode=mode, pretrained=True)
    m = ResNetPCam(cfg)
    print(
        f"Mode='{mode}': trainable params = {count_trainable_params(m):,}"
    )


Mode='frozen': trainable params = 513
Mode='partial': trainable params = 8,394,241
Mode='full': trainable params = 11,177,025


## 5) Tuning Hyperparameters with Optuna

In [5]:
from torch.utils.data import DataLoader, Subset
import torch


def get_tuning_dataloaders(
    batch_size: int,
    max_train_samples: int = 2000,
    max_val_samples: int = 500,
    center_crop_size: int = 64,
):
    """
    Build smaller train/val dataloaders for hyperparameter tuning.

    This uses a subset of the full train/val sets to keep each trial reasonably cheap.
    """
    # We reuse the same underlying datasets that get_pcam_dataloaders uses.
    full_loaders = get_pcam_dataloaders(
        data_root="data/raw",
        batch_size=batch_size,
        center_crop_size=center_crop_size,
        num_workers=0,
    )

    full_train_ds = full_loaders["train"].dataset
    full_val_ds = full_loaders["val"].dataset

    # Create deterministic random subsets for reproducibility
    g = torch.Generator().manual_seed(42)

    train_size = min(len(full_train_ds), max_train_samples)
    val_size = min(len(full_val_ds), max_val_samples)

    train_indices = torch.randperm(len(full_train_ds), generator=g)[:train_size]
    val_indices = torch.randperm(len(full_val_ds), generator=g)[:val_size]

    train_subset = Subset(full_train_ds, train_indices.tolist())
    val_subset = Subset(full_val_ds, val_indices.tolist())

    train_loader = DataLoader(
        train_subset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_subset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )

    print(
        f"Tuning loaders: train_samples={len(train_subset)}, "
        f"val_samples={len(val_subset)}, batch_size={batch_size}"
    )

    return train_loader, val_loader


In [6]:
import optuna
from torch import nn, optim
from src.training.utils_training import evaluate_binary_classifier

def build_resnet_and_optimizer(
    tl_mode: str,
    lr: float,
    weight_decay: float,
):
    """
    Helper to construct a ResNet model + optimizer + loss for a given trial.
    """
    cfg = ResNetConfig(tl_mode=tl_mode, pretrained=True)
    model = ResNetPCam(cfg).to(DEVICE)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(
        trainable_params,
        lr=lr,
        weight_decay=weight_decay,
    )
    criterion = nn.BCEWithLogitsLoss()
    return model, optimizer, criterion



# Fixed TL mode for the current study.
# We will set this to "frozen" or "partial" before running a study.
FIXED_TL_MODE = "frozen"  # will be overwritten below as needed


def objective(trial: optuna.Trial) -> float:
    """
    Optuna objective function for ResNet18 on PCam.

    We optimize validation AUROC (higher is better).
    The transfer learning mode (tl_mode) is fixed per study (frozen / partial).
    """
    tl_mode = FIXED_TL_MODE  # <-- NOT sampled, fixed for this study

    # Hyperparameter search space (per mode)
    lr = trial.suggest_float("lr", 1e-4, 8e-4, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32])

    # Smaller dataloaders for tuning
    train_loader, val_loader = get_tuning_dataloaders(
        batch_size=batch_size,
        max_train_samples=2000,
        max_val_samples=500,
        center_crop_size=CENTER_CROP,
    )

    model, optimizer, criterion = build_resnet_and_optimizer(
        tl_mode=tl_mode,
        lr=lr,
        weight_decay=weight_decay,
    )

    num_epochs = 2  # keep tuning epochs small
    best_val_auroc = 0.0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for images, labels in train_loader:
            images = images.to(DEVICE)
            labels = labels.float().to(DEVICE)

            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * images.size(0)

        train_loss /= len(train_loader.dataset)

        val_loss, val_auroc, val_auprc = evaluate_binary_classifier(
            model, val_loader, criterion, DEVICE
        )

        trial.report(val_auroc, step=epoch)

        if trial.should_prune():
            raise optuna.TrialPruned()

        best_val_auroc = max(best_val_auroc, val_auroc)

        print(
            f"[trial {trial.number:02d} | mode={tl_mode}] "
            f"epoch={epoch+1} lr={lr:.1e} wd={weight_decay:.1e} bs={batch_size} | "
            f"val_AUROC={val_auroc:.3f}"
        )

    return best_val_auroc

In [7]:
FIXED_TL_MODE = "frozen"
print("Running Optuna study for tl_mode='frozen'")

sampler = optuna.samplers.TPESampler(seed=42)
pruner = optuna.pruners.MedianPruner(n_startup_trials=2, n_warmup_steps=1)

study_frozen = optuna.create_study(
    study_name="resnet_pcam_frozen",
    direction="maximize",
    sampler=sampler,
    pruner=pruner,
)

study_frozen.optimize(objective, n_trials=12)  # increase later if you want

print("Finished trials (frozen):", len(study_frozen.trials))
print("Best AUROC (frozen):", study_frozen.best_trial.value)
print("Best params (frozen):")
for k, v in study_frozen.best_trial.params.items():
    print(f"  {k}: {v}")

best_params_frozen = study_frozen.best_trial.params
best_params_frozen


[I 2025-12-14 09:36:24,934] A new study created in memory with name: resnet_pcam_frozen


Running Optuna study for tl_mode='frozen'
Tuning loaders: train_samples=2000, val_samples=500, batch_size=16
[trial 00 | mode=frozen] epoch=1 lr=2.2e-04 wd=8.0e-04 bs=16 | val_AUROC=0.583


[I 2025-12-14 09:36:52,522] Trial 0 finished with value: 0.7126940106067241 and parameters: {'lr': 0.0002178930765977573, 'weight_decay': 0.0007969454818643932, 'batch_size': 16}. Best is trial 0 with value: 0.7126940106067241.


[trial 00 | mode=frozen] epoch=2 lr=2.2e-04 wd=8.0e-04 bs=16 | val_AUROC=0.713
Tuning loaders: train_samples=2000, val_samples=500, batch_size=32
[trial 01 | mode=frozen] epoch=1 lr=1.4e-04 wd=2.1e-05 bs=32 | val_AUROC=0.636


[I 2025-12-14 09:37:19,188] Trial 1 finished with value: 0.6933943138876725 and parameters: {'lr': 0.00013832442448846117, 'weight_decay': 2.0511104188433963e-05, 'batch_size': 32}. Best is trial 0 with value: 0.7126940106067241.


[trial 01 | mode=frozen] epoch=2 lr=1.4e-04 wd=2.1e-05 bs=32 | val_AUROC=0.693
Tuning loaders: train_samples=2000, val_samples=500, batch_size=32
[trial 02 | mode=frozen] epoch=1 lr=3.5e-04 wd=2.6e-04 bs=32 | val_AUROC=0.667


[I 2025-12-14 09:37:42,888] Trial 2 finished with value: 0.7603431778004833 and parameters: {'lr': 0.0003490285460630004, 'weight_decay': 0.0002607024758370766, 'batch_size': 32}. Best is trial 2 with value: 0.7603431778004833.


[trial 02 | mode=frozen] epoch=2 lr=3.5e-04 wd=2.6e-04 bs=32 | val_AUROC=0.760
Tuning loaders: train_samples=2000, val_samples=500, batch_size=32
[trial 03 | mode=frozen] epoch=1 lr=5.6e-04 wd=2.7e-05 bs=32 | val_AUROC=0.753


[I 2025-12-14 09:38:05,857] Trial 3 finished with value: 0.8078463809014094 and parameters: {'lr': 0.0005646386642932582, 'weight_decay': 2.6587543983272695e-05, 'batch_size': 32}. Best is trial 3 with value: 0.8078463809014094.


[trial 03 | mode=frozen] epoch=2 lr=5.6e-04 wd=2.7e-05 bs=32 | val_AUROC=0.808
Tuning loaders: train_samples=2000, val_samples=500, batch_size=16
[trial 04 | mode=frozen] epoch=1 lr=1.9e-04 wd=1.1e-04 bs=16 | val_AUROC=0.683


[I 2025-12-14 09:38:31,086] Trial 4 finished with value: 0.7515042410678084 and parameters: {'lr': 0.00018826002986047262, 'weight_decay': 0.00011207606211860574, 'batch_size': 16}. Best is trial 3 with value: 0.8078463809014094.


[trial 04 | mode=frozen] epoch=2 lr=1.9e-04 wd=1.1e-04 bs=16 | val_AUROC=0.752
Tuning loaders: train_samples=2000, val_samples=500, batch_size=32
[trial 05 | mode=frozen] epoch=1 lr=3.6e-04 wd=1.9e-05 bs=32 | val_AUROC=0.668


[I 2025-12-14 09:38:53,866] Trial 5 finished with value: 0.7554777080393779 and parameters: {'lr': 0.0003569095943771273, 'weight_decay': 1.9010245319870364e-05, 'batch_size': 32}. Best is trial 3 with value: 0.8078463809014094.


[trial 05 | mode=frozen] epoch=2 lr=3.6e-04 wd=1.9e-05 bs=32 | val_AUROC=0.755
Tuning loaders: train_samples=2000, val_samples=500, batch_size=32
[trial 06 | mode=frozen] epoch=1 lr=2.6e-04 wd=3.7e-04 bs=32 | val_AUROC=0.672


[I 2025-12-14 09:39:16,898] Trial 6 pruned. 


Tuning loaders: train_samples=2000, val_samples=500, batch_size=16
[trial 07 | mode=frozen] epoch=1 lr=3.4e-04 wd=1.2e-05 bs=16 | val_AUROC=0.749


[I 2025-12-14 09:39:41,475] Trial 7 finished with value: 0.7963800904977376 and parameters: {'lr': 0.0003427706793941003, 'weight_decay': 1.2385137298860926e-05, 'batch_size': 16}. Best is trial 3 with value: 0.8078463809014094.


[trial 07 | mode=frozen] epoch=2 lr=3.4e-04 wd=1.2e-05 bs=16 | val_AUROC=0.796
Tuning loaders: train_samples=2000, val_samples=500, batch_size=16
[trial 08 | mode=frozen] epoch=1 lr=1.1e-04 wd=7.9e-04 bs=16 | val_AUROC=0.567


[I 2025-12-14 09:40:06,893] Trial 8 pruned. 


Tuning loaders: train_samples=2000, val_samples=500, batch_size=16
[trial 09 | mode=frozen] epoch=1 lr=1.9e-04 wd=1.6e-05 bs=16 | val_AUROC=0.630


[I 2025-12-14 09:40:32,452] Trial 9 pruned. 


Tuning loaders: train_samples=2000, val_samples=500, batch_size=32
[trial 10 | mode=frozen] epoch=1 lr=7.1e-04 wd=5.3e-05 bs=32 | val_AUROC=0.791


[I 2025-12-14 09:40:55,696] Trial 10 finished with value: 0.8155662595890301 and parameters: {'lr': 0.0007137454446124092, 'weight_decay': 5.3389437745561126e-05, 'batch_size': 32}. Best is trial 10 with value: 0.8155662595890301.


[trial 10 | mode=frozen] epoch=2 lr=7.1e-04 wd=5.3e-05 bs=32 | val_AUROC=0.816
Tuning loaders: train_samples=2000, val_samples=500, batch_size=32
[trial 11 | mode=frozen] epoch=1 lr=7.2e-04 wd=4.9e-05 bs=32 | val_AUROC=0.731


[I 2025-12-14 09:41:18,646] Trial 11 finished with value: 0.7960070711493862 and parameters: {'lr': 0.0007249712432705719, 'weight_decay': 4.9211947176147695e-05, 'batch_size': 32}. Best is trial 10 with value: 0.8155662595890301.


[trial 11 | mode=frozen] epoch=2 lr=7.2e-04 wd=4.9e-05 bs=32 | val_AUROC=0.796
Finished trials (frozen): 12
Best AUROC (frozen): 0.8155662595890301
Best params (frozen):
  lr: 0.0007137454446124092
  weight_decay: 5.3389437745561126e-05
  batch_size: 32


{'lr': 0.0007137454446124092,
 'weight_decay': 5.3389437745561126e-05,
 'batch_size': 32}

In [8]:
FIXED_TL_MODE = "partial"
print("\nRunning Optuna study for tl_mode='partial'")

sampler = optuna.samplers.TPESampler(seed=42)
pruner = optuna.pruners.MedianPruner(n_startup_trials=2, n_warmup_steps=1)

study_partial = optuna.create_study(
    study_name="resnet_pcam_partial",
    direction="maximize",
    sampler=sampler,
    pruner=pruner,
)

study_partial.optimize(objective, n_trials=12)  # same budget as frozen

print("Finished trials (partial):", len(study_partial.trials))
print("Best AUROC (partial):", study_partial.best_trial.value)
print("Best params (partial):")
for k, v in study_partial.best_trial.params.items():
    print(f"  {k}: {v}")

best_params_partial = study_partial.best_trial.params
best_params_partial

[I 2025-12-14 09:58:33,918] A new study created in memory with name: resnet_pcam_partial



Running Optuna study for tl_mode='partial'
Tuning loaders: train_samples=2000, val_samples=500, batch_size=16
[trial 00 | mode=partial] epoch=1 lr=2.2e-04 wd=8.0e-04 bs=16 | val_AUROC=0.827


[I 2025-12-14 09:59:33,898] Trial 0 finished with value: 0.862145023435346 and parameters: {'lr': 0.0002178930765977573, 'weight_decay': 0.0007969454818643932, 'batch_size': 16}. Best is trial 0 with value: 0.862145023435346.


[trial 00 | mode=partial] epoch=2 lr=2.2e-04 wd=8.0e-04 bs=16 | val_AUROC=0.862
Tuning loaders: train_samples=2000, val_samples=500, batch_size=32
[trial 01 | mode=partial] epoch=1 lr=1.4e-04 wd=2.1e-05 bs=32 | val_AUROC=0.846


[I 2025-12-14 10:00:18,683] Trial 1 finished with value: 0.8780551095541609 and parameters: {'lr': 0.00013832442448846117, 'weight_decay': 2.0511104188433963e-05, 'batch_size': 32}. Best is trial 1 with value: 0.8780551095541609.


[trial 01 | mode=partial] epoch=2 lr=1.4e-04 wd=2.1e-05 bs=32 | val_AUROC=0.878
Tuning loaders: train_samples=2000, val_samples=500, batch_size=32
[trial 02 | mode=partial] epoch=1 lr=3.5e-04 wd=2.6e-04 bs=32 | val_AUROC=0.842


[I 2025-12-14 10:01:01,726] Trial 2 finished with value: 0.8887591430285927 and parameters: {'lr': 0.0003490285460630004, 'weight_decay': 0.0002607024758370766, 'batch_size': 32}. Best is trial 2 with value: 0.8887591430285927.


[trial 02 | mode=partial] epoch=2 lr=3.5e-04 wd=2.6e-04 bs=32 | val_AUROC=0.889
Tuning loaders: train_samples=2000, val_samples=500, batch_size=32
[trial 03 | mode=partial] epoch=1 lr=5.6e-04 wd=2.7e-05 bs=32 | val_AUROC=0.867


[I 2025-12-14 10:01:47,523] Trial 3 finished with value: 0.8963979305535283 and parameters: {'lr': 0.0005646386642932582, 'weight_decay': 2.6587543983272695e-05, 'batch_size': 32}. Best is trial 3 with value: 0.8963979305535283.


[trial 03 | mode=partial] epoch=2 lr=5.6e-04 wd=2.7e-05 bs=32 | val_AUROC=0.896
Tuning loaders: train_samples=2000, val_samples=500, batch_size=16
[trial 04 | mode=partial] epoch=1 lr=1.9e-04 wd=1.1e-04 bs=16 | val_AUROC=0.872


[I 2025-12-14 10:02:48,473] Trial 4 finished with value: 0.8859533887996885 and parameters: {'lr': 0.00018826002986047262, 'weight_decay': 0.00011207606211860574, 'batch_size': 16}. Best is trial 3 with value: 0.8963979305535283.


[trial 04 | mode=partial] epoch=2 lr=1.9e-04 wd=1.1e-04 bs=16 | val_AUROC=0.886
Tuning loaders: train_samples=2000, val_samples=500, batch_size=32
[trial 05 | mode=partial] epoch=1 lr=3.6e-04 wd=1.9e-05 bs=32 | val_AUROC=0.844


[I 2025-12-14 10:03:31,730] Trial 5 pruned. 


Tuning loaders: train_samples=2000, val_samples=500, batch_size=32
[trial 06 | mode=partial] epoch=1 lr=2.6e-04 wd=3.7e-04 bs=32 | val_AUROC=0.889


[I 2025-12-14 10:04:15,033] Trial 6 finished with value: 0.8893592176324625 and parameters: {'lr': 0.00025815006344207555, 'weight_decay': 0.00037183641805732076, 'batch_size': 32}. Best is trial 3 with value: 0.8963979305535283.


[trial 06 | mode=partial] epoch=2 lr=2.6e-04 wd=3.7e-04 bs=32 | val_AUROC=0.886
Tuning loaders: train_samples=2000, val_samples=500, batch_size=16
[trial 07 | mode=partial] epoch=1 lr=3.4e-04 wd=1.2e-05 bs=16 | val_AUROC=0.861


[I 2025-12-14 10:05:15,169] Trial 7 pruned. 


Tuning loaders: train_samples=2000, val_samples=500, batch_size=16
[trial 08 | mode=partial] epoch=1 lr=1.1e-04 wd=7.9e-04 bs=16 | val_AUROC=0.874


[I 2025-12-14 10:06:15,112] Trial 8 finished with value: 0.8896187093530548 and parameters: {'lr': 0.00011448469784568747, 'weight_decay': 0.000790261954970823, 'batch_size': 16}. Best is trial 3 with value: 0.8963979305535283.


[trial 08 | mode=partial] epoch=2 lr=1.1e-04 wd=7.9e-04 bs=16 | val_AUROC=0.890
Tuning loaders: train_samples=2000, val_samples=500, batch_size=16
[trial 09 | mode=partial] epoch=1 lr=1.9e-04 wd=1.6e-05 bs=16 | val_AUROC=0.873


[I 2025-12-14 10:07:15,328] Trial 9 pruned. 


Tuning loaders: train_samples=2000, val_samples=500, batch_size=32
[trial 10 | mode=partial] epoch=1 lr=7.1e-04 wd=5.3e-05 bs=32 | val_AUROC=0.878


[I 2025-12-14 10:07:59,208] Trial 10 finished with value: 0.8870724468447428 and parameters: {'lr': 0.0007137454446124092, 'weight_decay': 5.3389437745561126e-05, 'batch_size': 32}. Best is trial 3 with value: 0.8963979305535283.


[trial 10 | mode=partial] epoch=2 lr=7.1e-04 wd=5.3e-05 bs=32 | val_AUROC=0.887
Tuning loaders: train_samples=2000, val_samples=500, batch_size=16
[trial 11 | mode=partial] epoch=1 lr=1.0e-04 wd=5.6e-05 bs=16 | val_AUROC=0.839


[I 2025-12-14 10:09:00,537] Trial 11 pruned. 


Finished trials (partial): 12
Best AUROC (partial): 0.8963979305535283
Best params (partial):
  lr: 0.0005646386642932582
  weight_decay: 2.6587543983272695e-05
  batch_size: 32


{'lr': 0.0005646386642932582,
 'weight_decay': 2.6587543983272695e-05,
 'batch_size': 32}

## 5) Training

In [None]:
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from src.training.utils_training import evaluate_binary_classifier

NUM_WORKERS = min(8, os.cpu_count()-1)
BATCH_FINAL = 64 
EPOCHS_FINAL = 20
use_amp = True
patience = 2
BATCH_EVAL = 256


def train_resnet_one_mode(
    tl_mode: str,
    num_epochs: int = 2,
    lr: float = 1e-4,
    weight_decay: float = 1e-5,
    batch_size: int = 64,
    max_train_batches: int | None = None,
    dropout_p: float = 0.0,
    patience: int = 2,
    use_scheduler: bool = True,
    use_amp: bool = True,
    num_workers: int = 0,
    ckpt_dir: str | Path = "experiments/runs",
    save_every: int = 1,
):
    """
    Train a ResNetPCam model for a given transfer learning mode (GPUâ€‘ready).

    Adds AMP (if available), configurable num_workers, periodic checkpointing, and
    saving of both best-by-val-loss and best-by-AUROC model states.
    """
    ckpt_dir = Path(ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    print(f"=== Training ResNet18 with tl_mode='{tl_mode}' (num_workers={num_workers}, AMP={use_amp}) ===")
    # reload dataloaders in case batch_size differs
    loaders_local = get_pcam_dataloaders(
        data_root="data/raw",
        batch_size=batch_size,
        center_crop_size=CENTER_CROP,
        num_workers=num_workers,
    )

    cfg = ResNetConfig(tl_mode=tl_mode, pretrained=True, dropout_p=dropout_p)
    model = ResNetPCam(cfg).to(DEVICE)

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    print("Trainable parameters:", f"{sum(p.numel() for p in trainable_params):,}")

    criterion = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(
        trainable_params,
        lr=lr,
        weight_decay=weight_decay,
    )
    scheduler = (
        ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=1, verbose=True)
        if use_scheduler
        else None
    )

    scaler = torch.cuda.amp.GradScaler() if (use_amp and DEVICE.type == "cuda") else None

    history = []
    best_val_loss = float("inf")
    best_state = None
    best_val_auroc = float("-inf")
    best_state_auroc = None
    bad_epochs = 0

    for epoch in range(1, num_epochs + 1):
        model.train()
        train_loss = 0.0

        for batch_idx, (images, labels) in enumerate(loaders_local["train"]):
            images = images.to(DEVICE)
            labels = labels.float().to(DEVICE)

            optimizer.zero_grad()

            if scaler is not None:
                with torch.cuda.amp.autocast():
                    logits = model(images)
                    loss = criterion(logits, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                logits = model(images)
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()

            train_loss += loss.item() * images.size(0)

            if max_train_batches is not None and (batch_idx + 1) >= max_train_batches:
                break

        # If we cut off early, approximate the effective dataset size
        if max_train_batches is not None:
            effective_size = max_train_batches * batch_size
        else:
            effective_size = len(loaders_local["train"].dataset)

        train_loss /= effective_size

        val_loss, val_auroc, val_auprc = evaluate_binary_classifier(
            model, loaders_local["val"], criterion, DEVICE
        )

        if scheduler:
            scheduler.step(val_loss)

        history.append(
            dict(
                epoch=epoch,
                train_loss=train_loss,
                val_loss=val_loss,
                val_auroc=val_auroc,
                val_auprc=val_auprc,
                lr=optimizer.param_groups[0]["lr"],
            )
        )

        # best by val loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = model.state_dict()
            # save best-by-val-loss
            torch.save(best_state, ckpt_dir / f"resnet_best_by_val_loss_{tl_mode}.pt")
            bad_epochs = 0
        else:
            bad_epochs += 1

        # best by AUROC
        if val_auroc is not None and val_auroc > best_val_auroc:
            best_val_auroc = val_auroc
            best_state_auroc = model.state_dict()
            torch.save(best_state_auroc, ckpt_dir / f"resnet_best_by_auroc_{tl_mode}.pt")

        # periodic checkpoint
        if epoch % save_every == 0:
            ckpt = {
                "epoch": epoch,
                "model": model.state_dict(),
                "opt": optimizer.state_dict(),
                "scaler": scaler.state_dict() if scaler is not None else None,
                "lr": optimizer.param_groups[0]["lr"],
                "val_loss": float(val_loss),
                "val_auroc": float(val_auroc),
                "val_auprc": float(val_auprc),
            }
            torch.save(ckpt, ckpt_dir / f"resnet_ckpt_epoch{epoch}_{tl_mode}.pt")

        print(
            f"[{tl_mode}] Epoch {epoch:02d} | "
            f"train_loss={train_loss:.4f} | "
            f"val_loss={val_loss:.4f} | "
            f"AUROC={val_auroc:.3f} | "
            f"AUPRC={val_auprc:.3f} | "
            f"lr={optimizer.param_groups[0]['lr']:.2e}"
        )

        if bad_epochs > patience:
            print(f"Early stopping at epoch {epoch} (no val improvement for {bad_epochs} epochs)")
            break

    # Load best state (val-loss) into model before returning
    if best_state is not None:

        model.load_state_dict(best_state)    return model, history, best_state


In [None]:
# Final training/eval with best partial params on full data
from torch import nn
from importlib import reload
import src.models.resnet_pcam_gpu as resnet_module

reload(resnet_module)  # picks up the updated ResNetConfig/ResNetPCam
ResNetConfig = resnet_module.ResNetConfig
ResNetPCam = resnet_module.ResNetPCam

# Use best partial params from Optuna
FINAL_TL_MODE = "partial"
FINAL_EPOCHS = 12  # longer run for GPU
FINAL_DROPOUT = 0.0


resnet_partial_final, hist_partial_final, best_state_partial = train_resnet_one_mode(
    tl_mode=FINAL_TL_MODE,
    num_epochs=FINAL_EPOCHS,
    lr=best_params_partial["lr"],
    weight_decay=best_params_partial["weight_decay"],
    batch_size=best_params_partial["batch_size"],
    dropout_p=FINAL_DROPOUT,
    max_train_batches=None,  # full epochs
    use_amp=True,
    num_workers=NUM_WORKERS,
    ckpt_dir=Path("experiments/runs"),
    save_every=1,
)

# Evaluate on test split
loaders_eval = get_pcam_dataloaders(
    data_root="data/raw",
    batch_size=best_params_partial["batch_size"],
    center_crop_size=CENTER_CROP,
    num_workers=NUM_WORKERS,
)
criterion = nn.BCEWithLogitsLoss()
test_loss, test_auroc, test_auprc = evaluate_binary_classifier(
    resnet_partial_final,
    loaders_eval["test"],
    criterion,
    DEVICE,
)

print(

    f"[partial] Test metrics | loss={test_loss:.4f} | AUROC={test_auroc:.3f} | AUPRC={test_auprc:.3f}")

=== Training ResNet18 with tl_mode='partial' ===
Trainable parameters: 8,394,241




[partial] Epoch 01 | train_loss=0.3487 | val_loss=0.3570 | AUROC=0.924 | AUPRC=0.925 | lr=5.65e-04
[partial] Epoch 02 | train_loss=0.3041 | val_loss=0.3730 | AUROC=0.919 | AUPRC=0.923 | lr=5.65e-04
[partial] Epoch 03 | train_loss=0.2866 | val_loss=0.3967 | AUROC=0.913 | AUPRC=0.919 | lr=2.82e-04
[partial] Epoch 04 | train_loss=0.2615 | val_loss=0.4018 | AUROC=0.916 | AUPRC=0.922 | lr=2.82e-04
Early stopping at epoch 4 (no val improvement for 3 epochs)
[partial] Test metrics | loss=0.4819 | AUROC=0.894 | AUPRC=0.898


In [None]:
# Final training/eval with best frozen params on full data
from torch import nn

FINAL_TL_MODE_FROZEN = "frozen"
FINAL_EPOCHS_FROZEN = 12
FINAL_DROPOUT_FROZEN = 0.0

resnet_frozen_final, hist_frozen_final, best_state_frozen = train_resnet_one_mode(
    tl_mode=FINAL_TL_MODE_FROZEN,
    num_epochs=FINAL_EPOCHS_FROZEN,
    lr=best_params_frozen["lr"],
    weight_decay=best_params_frozen["weight_decay"],
    batch_size=best_params_frozen["batch_size"],
    dropout_p=FINAL_DROPOUT_FROZEN,
    max_train_batches=None,
    use_amp=True,
    num_workers=NUM_WORKERS,
    ckpt_dir=Path("experiments/runs"),
    save_every=1,
)

loaders_eval_frozen = get_pcam_dataloaders(
    data_root="data/raw",
    batch_size=best_params_frozen["batch_size"],
    center_crop_size=CENTER_CROP,
    num_workers=NUM_WORKERS,
)
criterion = nn.BCEWithLogitsLoss()
test_loss, test_auroc, test_auprc = evaluate_binary_classifier(
    resnet_frozen_final,
    loaders_eval_frozen["test"],
    criterion,
    DEVICE,
)

print(

    f"[frozen] Test metrics | loss={test_loss:.4f} | AUROC={test_auroc:.3f} | AUPRC={test_auprc:.3f}")

=== Training ResNet18 with tl_mode='frozen' ===
Trainable parameters: 513




[frozen] Epoch 01 | train_loss=0.4969 | val_loss=0.4915 | AUROC=0.847 | AUPRC=0.840 | lr=7.14e-04
[frozen] Epoch 02 | train_loss=0.4919 | val_loss=0.4745 | AUROC=0.858 | AUPRC=0.849 | lr=7.14e-04
[frozen] Epoch 03 | train_loss=0.4904 | val_loss=0.4846 | AUROC=0.852 | AUPRC=0.842 | lr=7.14e-04
[frozen] Epoch 04 | train_loss=0.4887 | val_loss=0.4717 | AUROC=0.860 | AUPRC=0.853 | lr=7.14e-04
[frozen] Epoch 05 | train_loss=0.4890 | val_loss=0.4793 | AUROC=0.856 | AUPRC=0.845 | lr=7.14e-04
[frozen] Epoch 06 | train_loss=0.4900 | val_loss=0.4711 | AUROC=0.860 | AUPRC=0.854 | lr=7.14e-04
[frozen] Epoch 07 | train_loss=0.4899 | val_loss=0.4790 | AUROC=0.858 | AUPRC=0.848 | lr=7.14e-04
[frozen] Epoch 08 | train_loss=0.4898 | val_loss=0.4848 | AUROC=0.851 | AUPRC=0.844 | lr=3.57e-04
[frozen] Test metrics | loss=0.5127 | AUROC=0.833 | AUPRC=0.827
