# PCam Small CNN on Colab (GPU)
End-to-end Colab notebook (GPU runtime): clone the repo, install dependencies, download PCam, and run training plus an optional Optuna search.


**Before running:**
- In Colab: `Runtime > Change runtime type > GPU`.
- Check that `REPO_URL` in the first code cell points to your repository (HTTPS clone URL).

In [None]:
import os, pathlib, subprocess, sys

# Set to your repo URL (HTTPS)
REPO_URL = "https://github.com/ellenaspiess/PCam.git"
REPO_DIR = pathlib.Path("/content/PCam")

if not REPO_DIR.exists():
    # Explicit check rather than `assert`, which is skipped under `python -O`.
    if not REPO_URL or REPO_URL == "<YOUR_REPO_URL>":
        raise ValueError("Please set REPO_URL")
    subprocess.run(["git", "clone", REPO_URL, str(REPO_DIR)], check=True)
else:
    # Cell re-run: update the existing checkout instead of cloning again.
    # check=True makes a failed pull (e.g. local changes, no network) raise.
    subprocess.run(["git", "-C", str(REPO_DIR), "pull"], check=True)

os.chdir(REPO_DIR)
# Make `src.*` importable; skip the insert on re-runs so sys.path does not
# accumulate duplicate entries.
if str(REPO_DIR) not in sys.path:
    sys.path.insert(0, str(REPO_DIR))
print("Working dir:", os.getcwd())

In [None]:
# Install dependencies. Colab normally ships torch/torchvision/torchaudio; since no
# versions are pinned and `-U` is not passed, pip leaves already-installed packages
# unchanged, so this effectively installs only whatever is missing.
!pip install -q torch torchvision torchaudio scikit-learn optuna tqdm

In [None]:
import torch
from pathlib import Path
from src.datasets.dataloaders import get_pcam_dataloaders
from src.training.train_small_cnn import train

# Use the attached Colab GPU when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
DATA_ROOT = Path("data/raw")
print("Using device:", DEVICE)

# Build tiny loaders (128 samples per split) once up front; intended to trigger
# the PCam download/verification before the real training cells run.
loaders = get_pcam_dataloaders(
    data_root=DATA_ROOT,
    batch_size=16,
    num_workers=2,
    center_crop_size=64,
    limit_per_split=128,
)
for split_name, split_loader in loaders.items():
    print(split_name, len(split_loader.dataset))


## Baseline / Final run
`limit_per_split=None` trains on the full dataset (takes longer, even on GPU); set it to a small integer for quick tests.

In [None]:
# Final training run. The lr value presumably comes from an earlier Optuna
# search — TODO confirm before changing it.
train(
    device=DEVICE,
    data_root=DATA_ROOT,
    limit_per_split=None,  # None = full dataset; a small int gives a quick smoke test
    batch_size=32,
    num_epochs=8,
    lr=2.24e-4,
)


## Optional: small Optuna search (GPU)
Keep the number of trials small so the search fits within a Colab session. Set `DO_OPTUNA = True` in the cell below to run it.

In [None]:
import optuna
from src.models.small_cnn import SmallCNN
from src.training.train_small_cnn import evaluate

def objective(trial, *, epochs=3, limit_per_split=2048, dropout_p=0.1):
    """Optuna objective: briefly train a SmallCNN and return its validation AUROC.

    Samples ``lr``, ``weight_decay`` and ``batch_size`` from ``trial``, trains for
    ``epochs`` epochs on a PCam subset, and reports validation AUROC after each
    epoch so the study's pruner can stop unpromising trials early.

    Args:
        trial: The Optuna trial passed in by ``study.optimize``.
        epochs: Training epochs per trial (default 3, the previous hard-coded value).
        limit_per_split: Per-split sample cap forwarded to ``get_pcam_dataloaders``
            to keep trials fast (default 2048, the previous hard-coded value).
        dropout_p: Dropout probability for ``SmallCNN`` (default 0.1, as before).

    Returns:
        The validation AUROC reported after the final epoch.

    Raises:
        ValueError: If ``epochs`` is less than 1 (there would be no AUROC to return).
        optuna.TrialPruned: If the study's pruner decides to stop this trial.
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    lr = trial.suggest_float("lr", 1e-4, 8e-4, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 48])

    # Loaders are rebuilt per trial because batch_size is one of the tuned parameters.
    loaders = get_pcam_dataloaders(
        data_root=DATA_ROOT,
        batch_size=batch_size,
        num_workers=2,
        center_crop_size=64,
        limit_per_split=limit_per_split,
    )

    model = SmallCNN(dropout_p=dropout_p).to(DEVICE)
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    for epoch in range(1, epochs + 1):
        # Re-enter train mode each epoch; presumably evaluate() switches the model
        # to eval mode — TODO confirm in src.training.train_small_cnn.
        model.train()
        for images, labels in loaders["train"]:
            images = images.to(DEVICE)
            # BCEWithLogitsLoss needs floating-point targets.
            labels = labels.float().to(DEVICE)
            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

        # evaluate() returns a 3-tuple; its middle element is used as val AUROC.
        _, val_auroc, _ = evaluate(model, loaders["val"], criterion, DEVICE)
        # Per-epoch report lets the pruner compare trials at the same step.
        trial.report(val_auroc, step=epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
    return val_auroc

# Flip to True to launch the hyperparameter search defined above.
DO_OPTUNA = False

if DO_OPTUNA:
    # MedianPruner with n_startup_trials=1: pruning is disabled until one trial
    # has finished, then trials lagging the median of earlier ones are stopped.
    median_pruner = optuna.pruners.MedianPruner(n_startup_trials=1)
    study = optuna.create_study(direction="maximize", pruner=median_pruner)
    study.optimize(objective, n_trials=5)

    print("Best AUROC:", study.best_value)
    print("Best params:", study.best_params)
