# PCam Small CNN: Step-by-step walkthrough
- Goal: establish a baseline, test short hypotheses, then prepare hyperparameter tuning and a final run.
- We start small (CPU, limited split), observe what works, and scale up gradually.

## 1) Setup & Assumptions
- CPU run for debugging; limited dataset for fast iteration.
- Metric focus: val AUROC (primary), AUPRC (secondary). Loss used for stability checks.
- Initial hypothesis: a small CNN will generalize with a moderate LR and light weight decay; batch size affects stability on CPU.

In [None]:
import torch
from pathlib import Path
from typing import Dict
import optuna

from src.datasets.dataloaders import get_pcam_dataloaders
from src.models.small_cnn import SmallCNN
from src.training.train_small_cnn import evaluate

# --- Runtime environment ---------------------------------------------------
# Debugging happens on the CPU; swap in "cuda" once a GPU run is wanted.
DEVICE = torch.device("cpu")
NUM_WORKERS = 0       # data-loader worker count; 0 is the CPU-friendly choice

# --- Data ------------------------------------------------------------------
DATA_ROOT = Path("data/raw")
CENTER_CROP = 64      # forwarded as `center_crop_size`; presumably pixels -- TODO confirm
LIMIT_DEBUG = 512     # per-split sample cap; None means the full split

# --- Quick-run hyperparameters (refined in later cells) --------------------
EPOCHS_DEBUG = 2
BATCH_DEBUG = 32
LR_DEBUG = 1e-3
WD_DEBUG = 1e-4

print("Using device:", DEVICE)

Using device: cpu


Why these defaults?
- `LIMIT_DEBUG=512`: minimizes I/O and compute time, enough for coarse trends.
- `LR_DEBUG=1e-3`: a robust starting point for Adam with small CNNs; often works without divergence.
- `WD_DEBUG=1e-4`: light regularization to prevent overfitting on small splits.
- `BATCH_DEBUG=32`: balance between stability and CPU speed.

In [10]:
def make_loaders(batch_size: int, limit: int | None) -> Dict[str, torch.utils.data.DataLoader]:
    """Build the PCam data loaders from the notebook-wide settings.

    Args:
        batch_size: Batch size forwarded to ``get_pcam_dataloaders``.
        limit: Forwarded as ``limit_per_split``; ``None`` is documented in
            this notebook as "use the full split".

    Returns:
        Whatever ``get_pcam_dataloaders`` returns; the rest of the notebook
        indexes it with the keys ``"train"`` and ``"val"``.
    """
    # Only batch size and split limit vary between experiments; the remaining
    # arguments come from the module-level constants.
    loader_kwargs = {
        "data_root": DATA_ROOT,
        "batch_size": batch_size,
        "num_workers": NUM_WORKERS,
        "center_crop_size": CENTER_CROP,
        "limit_per_split": limit,
    }
    return get_pcam_dataloaders(**loader_kwargs)


def train_one_run(
    batch_size: int,
    lr: float,
    weight_decay: float,
    epochs: int,
    limit: int | None,
    dropout_p: float = 0.1,
    patience: int = 2,
    use_scheduler: bool = True,
    verbose: bool = True,
):
    """Train a fresh ``SmallCNN`` and keep the best validation-loss checkpoint.

    Args:
        batch_size: Batch size passed to ``make_loaders``.
        lr: Initial learning rate for Adam.
        weight_decay: Weight decay for Adam.
        epochs: Maximum number of epochs to train.
        limit: Per-split sample cap passed to ``make_loaders`` (``None`` is
            documented in this notebook as "full split").
        dropout_p: Dropout probability handed to ``SmallCNN``.
        patience: Early-stopping patience. Training stops once validation loss
            has failed to improve for more than ``patience`` consecutive epochs.
        use_scheduler: If true, halve the LR via ``ReduceLROnPlateau``
            (``patience=1``) whenever validation loss plateaus.
        verbose: If true, print one summary line per epoch and the
            early-stopping notice.

    Returns:
        A ``(history, best_state)`` tuple. ``history`` is a list with one dict
        per completed epoch (keys ``epoch``, ``train_loss``, ``val_loss``,
        ``val_auroc``, ``val_auprc``, ``lr``). ``best_state`` is an independent
        copy of the model's ``state_dict`` taken at the epoch with the lowest
        validation loss, or ``None`` if no epoch ever improved on ``inf``
        (e.g. ``epochs < 1`` or a NaN validation loss).
    """
    import copy  # local import: only needed for the checkpoint snapshot

    loaders = make_loaders(batch_size, limit)
    model = SmallCNN(dropout_p=dropout_p).to(DEVICE)
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # The scheduler's own `verbose` flag is deprecated in recent PyTorch
    # releases, so it is not passed; the current LR is printed in the
    # per-epoch log line below instead.
    scheduler = (
        torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=1
        )
        if use_scheduler
        else None
    )

    best_val_loss = float("inf")
    best_state = None
    bad_epochs = 0
    history = []
    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = 0.0
        n_seen = 0
        for images, labels in loaders["train"]:
            images = images.to(DEVICE)
            labels = labels.float().to(DEVICE)

            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size so the epoch average
            # is a true per-sample mean even with a smaller final batch.
            train_loss += loss.item() * images.size(0)
            n_seen += images.size(0)

        # Normalize by the samples actually iterated, which stays correct even
        # if the loader drops incomplete batches or uses a sampler.
        train_loss /= n_seen
        val_loss, val_auroc, val_auprc = evaluate(model, loaders["val"], criterion, DEVICE)

        if scheduler is not None:
            scheduler.step(val_loss)

        # Read the LR after the scheduler step: this is the rate the next
        # epoch will train with.
        current_lr = optimizer.param_groups[0]["lr"]
        history.append(
            {
                "epoch": epoch,
                "train_loss": train_loss,
                "val_loss": val_loss,
                "val_auroc": val_auroc,
                "val_auprc": val_auprc,
                "lr": current_lr,
            }
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns references to the live parameter tensors;
            # without a deep copy, later optimizer steps would silently
            # overwrite this "best" snapshot with the final weights.
            best_state = copy.deepcopy(model.state_dict())
            bad_epochs = 0
        else:
            bad_epochs += 1

        if verbose:
            print(
                f"Epoch {epoch:02d} | train_loss={train_loss:.4f} | "
                f"val_loss={val_loss:.4f} | AUROC={val_auroc:.3f} | AUPRC={val_auprc:.3f} | lr={current_lr:.2e}"
            )

        if bad_epochs > patience:
            if verbose:
                print(f"Early stopping at epoch {epoch} (no val improvement for {bad_epochs} epochs)")
            break

    return history, best_state


## 2) Baseline on a small split
Hypothesis: with moderate LR/WD we expect reasonable AUROC > 0.8 after a few epochs on 512 samples. If unstable, adjust LR or batch size.

In [11]:
# Toggle for the small-split baseline experiment.
DO_BASELINE = True
if DO_BASELINE:
    from pprint import pprint

    # Only the training history is needed here; the checkpoint is discarded.
    hist_baseline, _ = train_one_run(
        batch_size=BATCH_DEBUG,
        lr=LR_DEBUG,
        weight_decay=WD_DEBUG,
        epochs=EPOCHS_DEBUG,
        limit=LIMIT_DEBUG,
    )
    # A bare `hist_baseline` expression nested inside the `if` is not echoed
    # by the notebook (only a trailing top-level expression is), so the
    # history is printed explicitly.
    pprint(hist_baseline)

Epoch 01 | train_loss=0.6154 | val_loss=0.6913 | AUROC=0.756 | AUPRC=0.771
Epoch 02 | train_loss=0.5331 | val_loss=0.5240 | AUROC=0.843 | AUPRC=0.813


Interpreting the baseline (manually after the run):
- If AUROC quickly exceeds 0.8, the model learns stably → we can fine-tune.
- If loss fluctuates or AUROC stays below 0.7 → reduce LR (e.g. 5e-4) or increase batch size (if using GPU); on CPU, smaller batches often give more stability.

## 3) Targeted tuning (Optuna, small search space)
Idea: small search on CPU with pruning to probe LR / weight decay / batch size.
- Search space: `lr` 1e-4..8e-4 (log), `weight_decay` 1e-5..1e-3 (log), `batch_size` {16, 32} for CPU.
- Keep trials small (5–10) and use `LIMIT_DEBUG`.
- Disabled by default; set `DO_OPTUNA = True` in the cell below when you want to run the search.

In [None]:
# Switch for the Optuna search; the study only runs when this is True.
DO_OPTUNA = False


def objective(trial):
    """Optuna objective: train a default ``SmallCNN`` briefly, return val AUROC.

    Samples ``lr``, ``weight_decay`` and ``batch_size`` from the trial, trains
    for ``EPOCHS_DEBUG`` epochs on the ``LIMIT_DEBUG`` split, and reports the
    validation AUROC after every epoch so the study's pruner can stop weak
    trials early.

    Returns:
        The validation AUROC of the last completed epoch.

    Raises:
        optuna.TrialPruned: When the pruner asks for the trial to stop.
    """
    # Hyperparameters are sampled in a fixed order: lr, weight decay, batch size.
    learning_rate = trial.suggest_float("lr", 1e-4, 8e-4, log=True)
    l2_penalty = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    bs = trial.suggest_categorical("batch_size", [16, 32])

    data = make_loaders(bs, LIMIT_DEBUG)
    net = SmallCNN().to(DEVICE)
    loss_fn = torch.nn.BCEWithLogitsLoss()
    opt = torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=l2_penalty)

    def run_training_epoch():
        """Do one full pass over the training loader, updating ``net`` in place."""
        net.train()
        for batch_x, batch_y in data["train"]:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.float().to(DEVICE)
            opt.zero_grad()
            batch_loss = loss_fn(net(batch_x), batch_y)
            batch_loss.backward()
            opt.step()

    for step in range(1, EPOCHS_DEBUG + 1):
        run_training_epoch()

        # Only the AUROC (second element) of the evaluation result is used.
        _, val_auroc, _ = evaluate(net, data["val"], loss_fn, DEVICE)
        trial.report(val_auroc, step=step)
        if trial.should_prune():
            raise optuna.TrialPruned()

    # NOTE: `val_auroc` is only bound inside the loop, so EPOCHS_DEBUG must be >= 1.
    return val_auroc

if DO_OPTUNA:
    # Maximize validation AUROC. The median pruner starts pruning after the
    # first trial (n_startup_trials=1) with no warm-up steps.
    study = optuna.create_study(
        direction="maximize",
        study_name="small_cnn_story_cpu",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=1, n_warmup_steps=0),
    )
    study.optimize(objective, n_trials=5)
    print("Best AUROC:", study.best_value)
    print("Best params:", study.best_params)
    trials_df = study.trials_dataframe()
    # A bare DataFrame expression nested inside the `if` is not echoed by the
    # notebook, so the per-trial summary is printed explicitly.
    print(
        trials_df[["number", "value", "state", "params_lr", "params_weight_decay", "params_batch_size"]]
    )

Interpreting the search results:
- Choose the best parameters as a starting point for the larger run.
- If the search shows little variation, widen the LR range or include `batch_size` 16/32/48 (if using GPU).

**New anti-overfitting measures**:
- Dropout 0.1 before the classifier in SmallCNN.
- ReduceLROnPlateau on val loss (factor 0.5, patience=1) for LR adaptation.
- Early stopping with patience=2 (training stops once val loss has gone more than 2 consecutive epochs without improving).


## 4) More complete run (more data, more epochs)
Plan: test the best settings on more data and for more epochs.
- On CPU: keep `LIMIT_FINAL` moderate (e.g. 2k–4k) and 6–10 epochs.
- On GPU: `LIMIT_FINAL=None` (full split) and 10–15 epochs; the LR scheduler (`ReduceLROnPlateau`) is optional and already enabled by default (`use_scheduler=True` in `train_one_run`).

In [None]:
# Example settings taken from a good search trial.
LR_FINAL = 2.24e-4
WD_FINAL = 4.28e-5
BATCH_FINAL = 16
EPOCHS_FINAL = 8
LIMIT_FINAL = 2048   # None for full dataset; keep moderate on CPU

# Off by default so that running this cell only defines the settings.
DO_FINAL = False

if DO_FINAL:
    from pprint import pprint

    # Only the training history is needed here; the checkpoint is discarded.
    hist_final, _ = train_one_run(
        batch_size=BATCH_FINAL,
        lr=LR_FINAL,
        weight_decay=WD_FINAL,
        epochs=EPOCHS_FINAL,
        limit=LIMIT_FINAL,
    )
    # A bare `hist_final` expression nested inside the `if` is not echoed by
    # the notebook, so the history is printed explicitly.
    pprint(hist_final)

In [None]:
# Final step: train with the chosen settings, report the best epoch, and
# write the checkpoint to disk.
DO_FINAL = True

if DO_FINAL:
    # Collect the run configuration in one place before launching training.
    final_run_kwargs = {
        "batch_size": BATCH_FINAL,
        "lr": LR_FINAL,
        "weight_decay": WD_FINAL,
        "epochs": EPOCHS_FINAL,
        "limit": LIMIT_FINAL,
        "verbose": True,
    }
    hist_final, best_state = train_one_run(**final_run_kwargs)

    # The epoch reported here is picked by validation AUROC ...
    best = max(hist_final, key=lambda record: record["val_auroc"])
    best_summary = (
        f"Best epoch: {best['epoch']} | "
        f"train_loss={best['train_loss']:.4f} | "
        f"val_loss={best['val_loss']:.4f} | "
        f"AUROC={best['val_auroc']:.3f} | "
        f"AUPRC={best['val_auprc']:.3f}"
    )
    print(best_summary)

    # ... whereas the checkpoint written to disk is the best val-loss state
    # returned by train_one_run, which may belong to a different epoch.
    if best_state is None:
        print("Warn: best_state is None, nothing saved")
    else:
        save_path = Path("experiments/runs/small_cnn_final.pt")
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(best_state, save_path)
        print("Saved model to", save_path)


Epoch 01 | train_loss=0.5387 | val_loss=0.4891 | AUROC=0.845 | AUPRC=0.814
Epoch 02 | train_loss=0.5094 | val_loss=0.4794 | AUROC=0.860 | AUPRC=0.851
Epoch 03 | train_loss=0.4949 | val_loss=0.4503 | AUROC=0.882 | AUPRC=0.877
Epoch 04 | train_loss=0.4912 | val_loss=0.4538 | AUROC=0.875 | AUPRC=0.873
Epoch 05 | train_loss=0.4863 | val_loss=0.4733 | AUROC=0.858 | AUPRC=0.874
Epoch 06 | train_loss=0.4747 | val_loss=0.5090 | AUROC=0.870 | AUPRC=0.870
Epoch 07 | train_loss=0.4689 | val_loss=0.4646 | AUROC=0.861 | AUPRC=0.861
Epoch 08 | train_loss=0.4739 | val_loss=0.4250 | AUROC=0.893 | AUPRC=0.897
Best epoch: 8 | train_loss=0.4739 | val_loss=0.4250 | AUROC=0.893 | AUPRC=0.897
Saved model to experiments/runs/small_cnn_final.pt
