# Papilledema: Detection + Severity Staging (Organized Notebook)
This notebook is a cleaned, *organized* version of the work from `Untitled41 (2).ipynb`.
It consolidates repeated cells, keeps one source of truth for datasets/models/training, and separates **training**, **evaluation**, and **inference**.

**Backbones:** EfficientNet-B4 (timm)

**Tasks:**
1) Binary detection (papilledema vs not)
2) Ordinal staging (CORAL)

> Update the paths in the Config section to match your Drive/folders.

## 0) Setup & Config

In [None]:
# Mount Google Drive when running on Colab. Outside Colab the import fails, so the
# mount is skipped and the paths below should point at local folders instead.
try:
    from google.colab import drive
except ImportError:  # not running inside Colab
    drive = None
if drive is not None:
    drive.mount('/content/drive')

# ----------------------------
# Config (edit these)
# ----------------------------
DATASET_ROOT = "/content/drive/MyDrive/papilloedema"     # folder with class subfolders OR where your images live
SPLIT_DIR    = "/content/drive/MyDrive/papilloedema_splits"
WEIGHTS_DIR  = "/content/drive/MyDrive"                 # where .pth are saved

IMG_SIZE      = 448   # final square crop fed to the network
BATCH_SIZE    = 8
NUM_WORKERS   = 2
RANDOM_STATE  = 42    # seed for RNGs and for the train/val/test split

DET_WEIGHTS_PATH   = f"{WEIGHTS_DIR}/best_detection_model.pth"
STAGE_WEIGHTS_PATH = f"{WEIGHTS_DIR}/best_staging_model_coral.pth"


## 1) Imports, Reproducibility

In [None]:
import os, random
from pathlib import Path
import numpy as np
import pandas as pd

import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import timm

import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score, recall_score, confusion_matrix, classification_report,
    accuracy_score, cohen_kappa_score, mean_absolute_error
)

def set_seed(seed: int = 42):
    """Seed every RNG this notebook uses and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU RNG and the RNGs
    of all visible CUDA devices, then turns on deterministic cuDNN kernels and
    turns off cuDNN autotuning (which can select different kernels between runs).

    Args:
        seed: value passed to every seeding function.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

set_seed(RANDOM_STATE)

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)


## 2) Build Metadata DataFrame (path, detect, stage)
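If you already have `train.csv`/`val.csv`/`test.csv`, skip this section and section 3. Load those files into `train_df`, `val_df` and `test_df` with `pd.read_csv` and continue at section 4. Section 3 re-splits `df` and would overwrite the existing files.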

In [None]:
# NOTE: this logic assumes your folder names encode class/stage.
# If your original notebook already produced df with columns: path, detect, stage,
# you can directly load it instead.

def build_df_from_folder(root: str) -> pd.DataFrame:
    """Scan class subfolders of ``root`` and build a (path, detect, stage) table.

    Every immediate subdirectory of ``root`` is treated as one class folder. Each
    file directly inside it with a ``.jpg``, ``.jpeg``, ``.png``, ``.tif`` or ``.tiff``
    suffix (case-insensitive) becomes one row. Loose files in ``root`` and anything
    in deeper subfolders are ignored.

    Labels come from the lowercased folder name. Adapt this to your naming scheme:
      * ``detect`` is 1 if the name contains ``"pap"``, else 0.
        NOTE(review): a folder such as ``"non_pap"`` would also match. Tighten the
        rule if such names exist.
      * ``stage`` is ``k`` for the first ``"stage{k}"`` (k = 1..5) found in the
        name, else 0.
        NOTE(review): these stages are 1-based, while the CORAL staging code expects
        labels in 0..NUM_STAGE_CLASSES-1. Remap before training if needed.

    Folders and files are visited in sorted order. ``Path.iterdir``/``glob`` order
    is filesystem-dependent, so sorting keeps the row order reproducible across
    machines. That in turn keeps the seeded train/val/test split reproducible.

    Args:
        root: dataset directory containing one subfolder per class.

    Returns:
        DataFrame with columns ``path`` (str), ``detect`` (int) and ``stage`` (int).

    Raises:
        ValueError: if no matching image files were found.
    """
    image_exts = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}
    records = []

    for cls in sorted(Path(root).iterdir()):
        if not cls.is_dir():
            continue

        # Labels depend only on the folder name, so parse them once per class.
        cls_name = cls.name.lower()
        detect = 1 if "pap" in cls_name else 0
        stage = next((k for k in range(1, 6) if f"stage{k}" in cls_name), 0)

        for p in sorted(cls.glob("*")):
            # is_file() also skips subdirectories whose names happen to end in an image suffix.
            if p.is_file() and p.suffix.lower() in image_exts:
                records.append({"path": str(p), "detect": detect, "stage": stage})

    df = pd.DataFrame(records)
    if df.empty:
        raise ValueError("No images found. Check DATASET_ROOT and folder structure.")
    return df

df = build_df_from_folder(root=DATASET_ROOT)
print(df.head())
print(df.groupby(by=["detect", "stage"]).size())  # image count per (detect, stage) pair


## 3) Train/Val/Test Split (stratified by detect)

In [None]:
os.makedirs(SPLIT_DIR, exist_ok=True)

# 70% train; the remaining 30% is halved into val and test (15% each).
# Both steps stratify on the binary detection label.
train_df, temp_df = train_test_split(
    df,
    test_size=0.30,
    stratify=df["detect"],
    random_state=RANDOM_STATE,
)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.50,
    stratify=temp_df["detect"],
    random_state=RANDOM_STATE,
)

for split_name, split_frame in (("train", train_df), ("val", val_df), ("test", test_df)):
    split_frame.to_csv(f"{SPLIT_DIR}/{split_name}.csv", index=False)

def check_dist(d, name):
    """Print the label balance of one data split.

    Prints the normalized ``detect`` distribution. If the split contains any
    ``detect == 1`` rows, it also prints their raw ``stage`` counts.

    Args:
        d: DataFrame with ``detect`` and ``stage`` columns.
        name: split label used in the printed headers (e.g. "Train").
    """
    # Use the "\n" escape: a literal line break inside a single-quoted f-string is a SyntaxError.
    print(f"\n{name} detect dist")
    print(d["detect"].value_counts(normalize=True))

    positives = d[d["detect"] == 1]
    if not positives.empty:
        print(f"{name} stage dist (detect=1 only)")
        print(positives["stage"].value_counts())

for split_frame, split_label in ((train_df, "Train"), (val_df, "Val"), (test_df, "Test")):
    check_dist(split_frame, split_label)


## 4) Transforms

In [None]:
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

# Each image is resized so its longest side is RESIZE_TO and padded with a constant
# border to a RESIZE_TO x RESIZE_TO square. It is then cropped to IMG_SIZE (random
# crop for train, center crop for val/test). The crop size must not exceed the padded
# size, so RESIZE_TO follows IMG_SIZE when IMG_SIZE > 512. For IMG_SIZE <= 512 this
# equals the original fixed 512.
RESIZE_TO = max(512, IMG_SIZE)

train_transform = A.Compose([
    A.LongestMaxSize(max_size=RESIZE_TO),
    A.PadIfNeeded(RESIZE_TO, RESIZE_TO, border_mode=cv2.BORDER_CONSTANT),
    A.RandomCrop(IMG_SIZE, IMG_SIZE),
    A.HorizontalFlip(p=0.5),
    A.Rotate(limit=15, p=0.5),
    A.RandomBrightnessContrast(p=0.4),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Deterministic pipeline for validation, test and single-image inference.
val_transform = A.Compose([
    A.LongestMaxSize(max_size=RESIZE_TO),
    A.PadIfNeeded(RESIZE_TO, RESIZE_TO, border_mode=cv2.BORDER_CONSTANT),
    A.CenterCrop(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])


## 5) Dataset + Dataloaders (single implementation)
Supports optional fundus cropping and returning the image path (useful for Grad-CAM).

In [None]:
def crop_to_fundus(img_rgb, pad: int = 8, threshold: int = 10):
    """Crop an RGB fundus photo to the bounding box of its non-black region.

    The image is converted to grayscale and pixels brighter than ``threshold`` are
    kept. The bounding box of the largest external contour is then expanded by
    ``pad`` pixels on every side and clamped to the image bounds.

    Args:
        img_rgb: 3-channel RGB image (as produced by cv2.imread + BGR->RGB).
        pad: margin in pixels added on each side of the detected box.
        threshold: grayscale intensity above which a pixel counts as foreground.
            The default of 10 matches the original notebook.

    Returns:
        A slice of ``img_rgb`` covering the padded box, or ``img_rgb`` itself when
        no foreground contour is found.
    """
    gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
    _, mask = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
    # findContours returns 3 values in OpenCV 3 and 2 in OpenCV 4; [-2] is the contour list in both.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return img_rgb

    x, y, w, h = cv2.boundingRect(max(contours, key=cv2.contourArea))
    img_h, img_w = img_rgb.shape[:2]
    # Clamp each edge independently, so clipping at the left/top edge does not
    # push extra padding onto the right/bottom edge.
    x0, y0 = max(0, x - pad), max(0, y - pad)
    x1, y1 = min(img_w, x + w + pad), min(img_h, y + h + pad)
    return img_rgb[y0:y1, x0:x1]

class PapilledemaDataset(Dataset):
    """Map-style dataset over a (path, detect, stage) DataFrame.

    Each item is ``(image, detect, stage)``, or ``(image, detect, stage, path)``
    when ``return_path`` is True. ``detect`` is a float32 scalar tensor (as used by
    BCEWithLogitsLoss) and ``stage`` an int64 scalar tensor.

    Args:
        df: DataFrame with ``path``, ``detect`` and ``stage`` columns. Its index is
            reset so positional lookup works after splitting.
        transform: optional albumentations-style callable ``transform(image=...)``
            returning a dict with an ``"image"`` entry.
        use_fundus_crop: if True, apply ``crop_to_fundus`` before ``transform``.
        return_path: if True, also return the image path (e.g. for Grad-CAM).
    """

    def __init__(self, df: pd.DataFrame, transform=None, use_fundus_crop=False, return_path=False):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.use_fundus_crop = use_fundus_crop
        self.return_path = return_path

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = row["path"]

        img = cv2.imread(path)
        if img is None:
            # cv2.imread returns None instead of raising for missing/unreadable files.
            # Without this check the failure surfaces later as an opaque cvtColor error.
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.use_fundus_crop:
            img = crop_to_fundus(img)

        if self.transform is not None:
            img = self.transform(image=img)["image"]

        detect = torch.tensor(row["detect"], dtype=torch.float32)
        stage  = torch.tensor(row["stage"], dtype=torch.long)

        if self.return_path:
            return img, detect, stage, path
        return img, detect, stage

# drop_last=True for training: both model heads use BatchNorm1d, which raises in
# train mode on a batch of a single sample. That happens whenever
# len(train_df) % BATCH_SIZE == 1. Evaluation runs in eval mode, so val/test keep all samples.
train_loader = DataLoader(
    PapilledemaDataset(train_df, train_transform, use_fundus_crop=False),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, drop_last=True,
)
val_loader = DataLoader(
    PapilledemaDataset(val_df, val_transform, use_fundus_crop=False),
    batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS,
)
test_loader = DataLoader(
    PapilledemaDataset(test_df, val_transform, use_fundus_crop=False),
    batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS,
)

# Sanity check: shapes and a few labels from one training batch.
x, d, s = next(iter(train_loader))
print("Batch:", x.shape, d[:5], s[:5])


## 6) Models
### 6.1 Detection (binary)
### 6.2 Staging (CORAL ordinal)

The original notebook uses EfficientNet-B4 via `timm` and a small MLP head.

In [None]:
class DetectionModel(nn.Module):
    """Binary papilledema classifier: a timm backbone followed by a small MLP head.

    ``forward`` returns one raw logit per image (the head's single output column is
    squeezed away). Apply a sigmoid to get a probability. Training in this notebook
    uses BCEWithLogitsLoss on these logits.

    Args:
        backbone: timm model name.
        pretrained: whether timm should load pretrained backbone weights.
    """

    def __init__(self, backbone="efficientnet_b4", pretrained=True):
        super().__init__()
        # num_classes=0 removes timm's classifier, so the backbone returns pooled features.
        self.backbone = timm.create_model(backbone, pretrained=pretrained, num_classes=0)
        hidden = 256
        layers = [
            nn.Linear(self.backbone.num_features, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden, 1),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, x):
        return self.head(self.backbone(x)).squeeze(1)

# ---------- CORAL utilities ----------
def labels_to_coral(stage_labels: torch.Tensor, num_classes: int):
    """Encode integer stage labels as CORAL binary targets.

    Column ``k`` of the result is 1 when ``label > k``, for k = 0..num_classes-2.
    Label ``y`` therefore becomes ``y`` leading ones followed by zeros. Labels at
    or above ``num_classes - 1`` encode as all ones.

    Args:
        stage_labels: 1-D tensor of labels, shape [B].
        num_classes: number of ordinal classes K.

    Returns:
        Tensor of shape [B, K-1] in the default floating dtype, on the labels' device.
    """
    thresholds = torch.arange(num_classes - 1, device=stage_labels.device)
    exceeds = stage_labels.unsqueeze(1) > thresholds  # broadcast [B, 1] vs [K-1]
    return exceeds.to(torch.get_default_dtype())

def coral_logits_to_label(logits: torch.Tensor):
    """Decode CORAL threshold logits of shape [B, K-1] into labels in 0..K-1.

    The predicted label is the number of thresholds whose sigmoid probability is
    strictly greater than 0.5. The result is an int64 tensor of shape [B].
    """
    passed = torch.sigmoid(logits).gt(0.5)
    return passed.sum(dim=1)

class StagingModelCORAL(nn.Module):
    """Ordinal stage classifier: timm backbone plus an MLP head with K-1 threshold logits.

    Output column ``k`` is the logit for "stage > k", matching the targets from
    ``labels_to_coral``. Labels are decoded with ``coral_logits_to_label``. This
    notebook trains it with BCEWithLogitsLoss.

    NOTE(review): the last layer is a plain Linear with independent weights per
    threshold. Canonical CORAL shares one weight vector across thresholds and
    learns only K-1 biases, which guarantees rank-consistent probabilities. This
    head does not guarantee that, although decoding by counting passed thresholds
    still yields a label. Changing the head would alter state_dict keys and break
    existing checkpoints.

    Args:
        backbone: timm model name.
        pretrained: whether timm should load pretrained backbone weights.
        num_classes: number of ordinal stages K (the head emits K-1 logits).
    """

    def __init__(self, backbone="efficientnet_b4", pretrained=True, num_classes=4):
        super().__init__()
        self.num_classes = num_classes
        # num_classes=0 removes timm's classifier, so the backbone returns pooled features.
        self.backbone = timm.create_model(backbone, pretrained=pretrained, num_classes=0)
        in_features = self.backbone.num_features
        hidden = 256
        self.head = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden, num_classes - 1),  # one logit per ordinal threshold
        )

    def forward(self, x):
        features = self.backbone(x)
        return self.head(features)


## 7) Training Utilities

In [None]:
@torch.no_grad()
def validate_detection(model, loader, criterion, threshold=0.5):
    """Evaluate a detection model over ``loader`` without gradient tracking.

    Args:
        model: module returning one logit per image.
        loader: yields ``(images, detect, stage)`` batches; ``stage`` is ignored.
        criterion: loss applied to ``(logits, detect)``.
        threshold: probability cut-off used for the recall computation.

    Returns:
        ``(mean_loss, roc_auc, recall)`` as Python floats.
    """
    model.eval()
    batch_losses, scores, labels = [], [], []
    for images, detect, _ in loader:
        inputs = images.to(device)
        targets = detect.to(device)
        logits = model(inputs)
        batch_losses.append(criterion(logits, targets).item())
        scores += torch.sigmoid(logits).cpu().tolist()
        labels += targets.cpu().tolist()

    scores = np.array(scores)
    labels = np.array(labels).astype(int)
    auc = roc_auc_score(labels, scores)
    recall = recall_score(labels, scores >= threshold)
    return float(np.mean(batch_losses)), float(auc), float(recall)

def train_one_epoch_detection(model, loader, optimizer, criterion):
    """Run one training epoch of the binary detection model.

    Args:
        model: module returning one logit per image.
        loader: yields ``(images, detect, stage)`` batches; ``stage`` is ignored.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss applied to ``(logits, detect)``.

    Returns:
        Mean loss over the batches that were trained on, or NaN if none were.
    """
    model.train()
    losses = []
    for images, detect, _ in loader:
        # The detection head's BatchNorm1d raises in train mode on a single-sample
        # batch (e.g. a short final batch), so skip such batches instead of crashing.
        if images.size(0) < 2:
            continue
        images = images.to(device)
        detect = detect.to(device)

        optimizer.zero_grad(set_to_none=True)
        logits = model(images)
        loss = criterion(logits, detect)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    # Explicit NaN avoids numpy's "mean of empty slice" warning when nothing was trained.
    return float(np.mean(losses)) if losses else float("nan")

@torch.no_grad()
def validate_stage(model, loader, criterion, num_classes: int):
    """Evaluate a CORAL staging model on the ``detect == 1`` samples of ``loader``.

    Batches with no positive samples are skipped. The loss is computed against
    ``labels_to_coral`` targets, and predictions are decoded with
    ``coral_logits_to_label``.

    Args:
        model: staging model returning ``num_classes - 1`` logits per image.
        loader: yields ``(images, detect, stage)`` batches.
        criterion: loss applied to ``(logits, coral_targets)``.
        num_classes: number of ordinal stages K.

    Returns:
        ``(mean_loss, mae, quadratic_weighted_kappa, accuracy)``. All four are NaN
        when the loader contains no ``detect == 1`` samples.
    """
    model.eval()
    batch_losses, true_stages, pred_stages = [], [], []
    for images, detect, stage in loader:
        keep = detect == 1
        if not keep.any():
            continue
        inputs = images[keep].to(device)
        stages = stage[keep].to(device)

        logits = model(inputs)
        coral_targets = labels_to_coral(stages, num_classes=num_classes)
        batch_losses.append(criterion(logits, coral_targets).item())

        true_stages += stages.cpu().tolist()
        pred_stages += coral_logits_to_label(logits).cpu().tolist()

    if not true_stages:
        return np.nan, np.nan, np.nan, np.nan

    y_true = np.array(true_stages)
    y_pred = np.array(pred_stages)
    mae = mean_absolute_error(y_true, y_pred)
    qwk = cohen_kappa_score(y_true, y_pred, weights="quadratic")
    acc = accuracy_score(y_true, y_pred)
    return float(np.mean(batch_losses)), float(mae), float(qwk), float(acc)


## 8) Train Detection

In [None]:
det_model = DetectionModel(pretrained=True).to(device)

# Weight the positive class by the negative:positive ratio to counter class imbalance.
num_pos = int(train_df["detect"].sum())
num_neg = len(train_df) - num_pos
# Explicit float32 keeps pos_weight in the same dtype as the float32 logits/targets.
# The ratio is otherwise computed from NumPy integers and can yield a float64 tensor.
pos_weight = torch.tensor([num_neg / max(num_pos, 1)], dtype=torch.float32, device=device)

det_criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
det_optimizer = optim.AdamW(det_model.parameters(), lr=3e-4, weight_decay=1e-4)
# Halve the LR when validation AUC (maximized) stops improving for 2 epochs.
det_scheduler = optim.lr_scheduler.ReduceLROnPlateau(det_optimizer, mode="max", factor=0.5, patience=2)

EPOCHS = 20
best_auc = -1  # AUC is >= 0, so the first epoch always saves a checkpoint
for epoch in range(EPOCHS):
    tr = train_one_epoch_detection(det_model, train_loader, det_optimizer, det_criterion)
    vl, auc, rec = validate_detection(det_model, val_loader, det_criterion)
    det_scheduler.step(auc)

    print(f"Epoch {epoch+1:02d}/{EPOCHS} | TrainLoss {tr:.4f} | ValLoss {vl:.4f} | ValAUC {auc:.4f} | ValRecall {rec:.4f}")

    if auc > best_auc:
        best_auc = auc
        torch.save(det_model.state_dict(), DET_WEIGHTS_PATH)
        print("✅ Saved:", DET_WEIGHTS_PATH)


## 9) Train Staging (CORAL)
You can train staging on a separate papilledema-only dataset, but this notebook keeps a simpler approach: it reuses the same loaders and keeps only `detect==1` samples (via a mask) in both the training loop and validation.
If you have a separate stage-only dataframe, build a dedicated loader for it.

In [None]:
NUM_STAGE_CLASSES = 4  # <-- set to your true number of stages (e.g., 4 or 5). Must match your labels.
# NOTE(review): build_df_from_folder assigns stages 1..5 to "stage{k}" folders (0 otherwise),
# while the CORAL targets assume labels in 0..NUM_STAGE_CLASSES-1. Remap the labels if they differ.

stage_model = StagingModelCORAL(pretrained=True, num_classes=NUM_STAGE_CLASSES).to(device)
stage_criterion = nn.BCEWithLogitsLoss()
stage_optimizer = optim.AdamW(stage_model.parameters(), lr=3e-4, weight_decay=1e-4)
# Halve the LR when validation QWK (maximized) stops improving for 2 epochs.
stage_scheduler = optim.lr_scheduler.ReduceLROnPlateau(stage_optimizer, mode="max", factor=0.5, patience=2)

EPOCHS_STAGE = 15
best_qwk = -np.inf  # QWK lies in [-1, 1]; -inf lets the first finite score be saved
for epoch in range(EPOCHS_STAGE):
    # Train only on detect==1 samples
    stage_model.train()
    losses = []
    for images, detect, stage in train_loader:
        mask = (detect == 1)
        # Need at least 2 samples: the head's BatchNorm1d raises in train mode on a
        # batch of one. That happens whenever a batch contains exactly one positive.
        if mask.sum() < 2:
            continue
        images = images[mask].to(device)
        stage = stage[mask].to(device)

        stage_optimizer.zero_grad(set_to_none=True)
        logits = stage_model(images)
        target = labels_to_coral(stage, num_classes=NUM_STAGE_CLASSES)
        loss = stage_criterion(logits, target)
        loss.backward()
        stage_optimizer.step()
        losses.append(loss.item())

    tr_loss = float(np.mean(losses)) if losses else np.nan
    vl_loss, mae, qwk, acc = validate_stage(stage_model, val_loader, stage_criterion, num_classes=NUM_STAGE_CLASSES)
    stage_scheduler.step(qwk)

    print(f"Epoch {epoch+1:02d}/{EPOCHS_STAGE} | TrainLoss {tr_loss:.4f} | ValLoss {vl_loss:.4f} | MAE {mae:.3f} | QWK {qwk:.3f} | Acc {acc:.3f}")

    # A NaN QWK (e.g. no positives in val) compares False and never overwrites the checkpoint.
    if qwk > best_qwk:
        best_qwk = qwk
        torch.save(stage_model.state_dict(), STAGE_WEIGHTS_PATH)
        print("✅ Saved:", STAGE_WEIGHTS_PATH)

if not np.isfinite(best_qwk):
    print("⚠️ No staging checkpoint saved: validation QWK was never finite (check val stage labels).")


## 10) Final Evaluation on Test Set (Detection + Staging)

In [None]:
# Rebuild both architectures without fetching pretrained backbone weights, restore
# the best checkpoints saved during training, and switch to inference mode.
det_model = DetectionModel(pretrained=False).to(device)
det_model.load_state_dict(state_dict=torch.load(f=DET_WEIGHTS_PATH, map_location=device))
det_model.eval()

stage_model = StagingModelCORAL(num_classes=NUM_STAGE_CLASSES, pretrained=False).to(device)
stage_model.load_state_dict(state_dict=torch.load(f=STAGE_WEIGHTS_PATH, map_location=device))
stage_model.eval()

# ---- Detection report ----
@torch.no_grad()
def get_detection_probs(model, loader):
    """Collect true detection labels and predicted probabilities over ``loader``.

    Args:
        model: module returning one logit per image.
        loader: yields ``(images, detect, stage)`` batches; ``stage`` is ignored.

    Returns:
        ``(y_true, probs)``: integer label array and float probability array, in
        loader order.
    """
    model.eval()
    scores, labels = [], []
    for images, detect, _ in loader:
        logits = model(images.to(device))
        scores += torch.sigmoid(logits).cpu().tolist()
        labels += detect.tolist()
    return np.array(labels).astype(int), np.array(scores)

# Threshold the test-set probabilities at 0.5, then report AUC and per-class metrics.
y_true, probs = get_detection_probs(det_model, test_loader)
y_pred = (probs >= 0.5).astype(int)
test_auc = roc_auc_score(y_true, probs)
print("Detection AUC:", test_auc)
print(classification_report(y_true, y_pred, digits=4))

# ---- Staging report (only detect==1) ----
@torch.no_grad()
def stage_on_test(det_model, stage_model, loader):
    """Predict stages for the ground-truth ``detect == 1`` samples of ``loader``.

    ``det_model`` is only switched to eval mode here. Gating is done with the
    true labels, not detector predictions. To gate by predictions instead,
    derive ``keep`` from ``det_model`` outputs.

    Returns:
        ``(true_stages, predicted_stages)`` as NumPy arrays (empty if there are no
        positive samples).
    """
    det_model.eval()
    stage_model.eval()
    true_stages, pred_stages = [], []
    for images, detect, stage in loader:
        keep = detect == 1
        if not keep.any():
            continue
        logits = stage_model(images[keep].to(device))
        true_stages.extend(stage[keep].tolist())
        pred_stages.extend(coral_logits_to_label(logits).cpu().tolist())
    return np.array(true_stages), np.array(pred_stages)

y_s, p_s = stage_on_test(det_model, stage_model, test_loader)
if len(y_s) == 0:
    print("No detect==1 samples found in test split (check labels).")
else:
    # Ordinal metrics: mean absolute stage error, quadratic-weighted kappa, exact-match accuracy.
    print("Staging MAE:", mean_absolute_error(y_s, p_s))
    print("Staging QWK:", cohen_kappa_score(y_s, p_s, weights="quadratic"))
    print("Staging Acc:", accuracy_score(y_s, p_s))
    print(confusion_matrix(y_s, p_s))


## 11) Inference Helper (single image)
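This mirrors the end of your original notebook: optional fundus crop, transforms, then detection → staging. The training loaders are built with `use_fundus_crop=False`, so pass `use_crop=False` if you want inference preprocessing to match training.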

In [None]:
import matplotlib.pyplot as plt

@torch.no_grad()
def predict_one_image(image_path: str, det_threshold=0.5, use_crop=True):
    """Run detection and, if positive, staging on one image file.

    Uses the module-level ``det_model``, ``stage_model``, ``val_transform`` and
    ``device``.

    NOTE(review): the training/eval loaders in this notebook use
    ``use_fundus_crop=False``, but this helper crops by default. Pass
    ``use_crop=False`` to match the training preprocessing.

    Args:
        image_path: path readable by ``cv2.imread``.
        det_threshold: probability at or above which the image counts as positive.
        use_crop: apply ``crop_to_fundus`` before the transforms.

    Returns:
        ``(vis, det_prob, det_pred, stage_pred)``:
            * ``vis`` is the RGB image after the optional crop and before transforms.
            * ``det_prob`` is the detection probability.
            * ``det_pred`` is 0 or 1.
            * ``stage_pred`` is the predicted stage index, or ``None`` when ``det_pred == 0``.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread returns None instead of raising for missing/unreadable files.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if use_crop:
        img = crop_to_fundus(img)

    # Keep an untransformed copy for display.
    vis = img.copy()

    # Same deterministic preprocessing as validation/test, plus a batch dimension.
    x = val_transform(image=img)["image"].unsqueeze(0).to(device)

    # Both heads contain BatchNorm1d, which cannot run in train mode on a batch of one.
    det_model.eval()
    stage_model.eval()

    det_prob = torch.sigmoid(det_model(x)).item()
    det_pred = int(det_prob >= det_threshold)

    stage_pred = None
    if det_pred == 1:
        stage_pred = int(coral_logits_to_label(stage_model(x)).item())

    return vis, det_prob, det_pred, stage_pred

# Example:
# vis, prob, det, stage = predict_one_image("/content/drive/MyDrive/some_image.jpg")
# plt.imshow(vis); plt.axis("off")
# print("det_prob:", prob, "det:", det, "stage:", stage)
