# Papilledema Project — Hybrid AI + Classical Baseline

1. Setup + data
2. Deep-learning detector (EfficientNet-B4) + Hard Negative Mining (HNM)
3. Ordinal staging (CORAL) + hard-example mining
4. End-to-end hybrid evaluation (7-class)
5. Explainability (Grad-CAM)
6. Classical ML baseline (handcrafted features + XGBoost)
7. Export plots + summary tables for reports/presentations

> **Final headline results (your test set):**
> - **Deep hybrid:** Detection AUC **0.9893**, Sens **1.0000**, Spec **0.9592**; Hybrid 7-class Acc **0.9310**
> - **Best classical baseline:** Detection AUC **0.9887**, Sens **1.0000**, Spec **0.9592**; Hybrid 7-class Acc **0.8276**




## 1) Setup (Colab)

In [None]:
# Mount Google Drive into the Colab runtime; every data, model and export
# path used below lives under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install dependencies (run once per Colab session)
!pip -q install timm albumentations opencv-python-headless pytorch-grad-cam xgboost scikit-learn tqdm

In [None]:
import os, json, random, math
import numpy as np
import pandas as pd
from tqdm import tqdm

import cv2
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.metrics import (classification_report, confusion_matrix, roc_auc_score,
                             roc_curve, precision_recall_curve, average_precision_score)


In [None]:
# Reproducibility
def seed_everything(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: Integer seed handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
# Seed every RNG once, before any model or DataLoader is created.
seed_everything(42)

# Global compute device used by all train/eval helpers below: the GPU when
# CUDA is available, otherwise the CPU. The bare `device` echoes it as the
# cell output.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

## 2) Data

In [None]:
# TODO: Update these paths to match your Drive layout
# NOTE(review): DATA_ROOT is not referenced by the code visible in this
# notebook; images are read from the `path` column of the split CSVs.
DATA_ROOT = "/content/drive/MyDrive/papilloedema"   # folder with N/, PSEUDO/, P/ etc
SPLITS_DIR = "/content/drive/MyDrive/papilloedema/splits"  # where your train/val/test CSVs are

# One CSV per split, all located in SPLITS_DIR.
train_csv = os.path.join(SPLITS_DIR, "train.csv")
val_csv   = os.path.join(SPLITS_DIR, "val.csv")
test_csv  = os.path.join(SPLITS_DIR, "test.csv")

# Load the three splits into DataFrames used by every later section.
train_df = pd.read_csv(train_csv)
val_df   = pd.read_csv(val_csv)
test_df  = pd.read_csv(test_csv)

# Preview the first rows of the training split as the cell output.
train_df.head()

In [None]:
# Expected columns (typical):
# - path: file path
# - detect: 0/1 (non-pap vs pap)
# - stage: 1..5 for pap (can be NaN/0 for non-pap)

def show_distributions(df, name):
    """Print the label distributions of one data split.

    Always prints the normalized frequency of each ``detect`` value. When a
    ``stage`` column exists and at least one row has ``detect == 1``, also
    prints the raw count of each stage among those rows.

    Args:
        df: DataFrame with a ``detect`` column and, optionally, ``stage``.
        name: Split label used in the printed headers.
    """
    print(f"\n{name} Detection Distribution")
    detect_share = df['detect'].value_counts(normalize=True)
    print(detect_share.sort_index())

    if 'stage' not in df.columns:
        return

    positives = df[df['detect'] == 1]
    if len(positives) == 0:
        return

    print(f"\n{name} Stage Distribution (pap only)")
    stage_counts = positives['stage'].value_counts()
    print(stage_counts.sort_index())

# Print detection (and, where available, stage) distributions for each split.
show_distributions(train_df, "Train")
show_distributions(val_df, "Val")
show_distributions(test_df, "Test")

## 3) Preprocessing (fundus crop + transforms)

In [None]:
def crop_to_fundus(rgb):
    """Placeholder for a circular fundus crop that masks dark borders.

    The current implementation is a pass-through: it returns ``rgb``
    unchanged, so no cropping happens. If you already have a working
    version in your older notebook, paste it here.
    """
    # --- minimal safe fallback: return input ---
    return rgb

# Per-channel mean/std (the standard ImageNet RGB statistics) passed to
# A.Normalize in build_tfms.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

def build_tfms(img_size=448, train=True):
    """Assemble the albumentations pipeline for one split.

    Both modes resize the longest side to ``img_size``, pad to an
    ``img_size`` x ``img_size`` square with a constant border, normalize
    with IMAGENET_MEAN / IMAGENET_STD and convert to a tensor. Training mode
    additionally inserts random horizontal flip, brightness/contrast jitter
    and a small shift/scale/rotate between the padding and the normalization.

    Args:
        img_size: Target side length of the square output.
        train: Whether to include the random augmentations.

    Returns:
        An ``A.Compose`` pipeline, called as ``tfms(image=rgb)['image']``.
    """
    resize_and_pad = [
        A.LongestMaxSize(img_size),
        A.PadIfNeeded(img_size, img_size, border_mode=cv2.BORDER_CONSTANT),
    ]

    augmentations = []
    if train:
        augmentations = [
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.3),
            A.ShiftScaleRotate(shift_limit=0.02, scale_limit=0.05, rotate_limit=10, p=0.5),
        ]

    to_model_input = [
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]

    return A.Compose(resize_and_pad + augmentations + to_model_input)


In [None]:
class FundusDataset(Dataset):
    """Dataset of fundus images described by a DataFrame.

    Each row must provide ``path`` plus the label column for the selected
    task: ``detect`` when ``task == 'detect'``, otherwise ``stage``.

    Args:
        df: Split DataFrame; its index is reset on construction.
        train: Selects the training (augmented) or evaluation transforms.
        task: ``'detect'`` for the binary label; any other value yields the
            integer stage label.
        img_size: Side length forwarded to ``build_tfms``.
    """

    def __init__(self, df, train=True, task='detect', img_size=448):
        self.df = df.reset_index(drop=True)
        self.train = train
        self.task = task
        self.tfms = build_tfms(img_size=img_size, train=train)

    def __len__(self):
        """Return the number of rows in the split."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load one sample and return ``(image_tensor, label, path)``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image at ``path``.
        """
        record = self.df.iloc[idx]
        path = record['path']

        image_bgr = cv2.imread(path)
        if image_bgr is None:
            raise FileNotFoundError(path)

        # OpenCV decodes to BGR; everything downstream works on RGB.
        image_rgb = crop_to_fundus(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
        x = self.tfms(image=image_rgb)['image']

        if self.task != 'detect':
            # stage: 1..5
            return x, int(record['stage']), path

        label = torch.tensor(record['detect'], dtype=torch.float32)
        return x, label, path


## 4) Deep Learning — Detection (EfficientNet-B4)

In [None]:
class DetectionModel(nn.Module):
    """Binary detector: timm backbone followed by a small MLP head.

    Args:
        backbone: timm model name.
        pretrained: Whether timm should load pretrained backbone weights.
    """

    def __init__(self, backbone="tf_efficientnet_b4_ns", pretrained=True):
        super().__init__()
        # num_classes=0 asks timm for a headless model; the head below is
        # sized from the backbone's reported feature width.
        self.backbone = timm.create_model(
            backbone,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg",
        )
        feat_dim = self.backbone.num_features
        hidden_dim = 512
        head_layers = [
            nn.Linear(feat_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, 1),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return one raw logit per sample (dimension 1 squeezed away)."""
        pooled = self.backbone(x)
        logits = self.head(pooled)
        return logits.squeeze(1)  # logits


In [None]:
@torch.no_grad()
def eval_detection(model, loader):
    """Run the detector over ``loader`` and collect labels and probabilities.

    Args:
        model: Module returning one logit per sample.
        loader: Iterable of ``(x, y, path)`` batches.

    Returns:
        ``(y_true, y_prob, auc)`` where ``y_true`` and ``y_prob`` are NumPy
        arrays concatenated over all batches and ``auc`` is the ROC AUC.
        ``auc`` is NaN when fewer than two distinct labels are present,
        which includes an empty loader (both arrays are then empty).
    """
    model.eval()
    ys, ps = [], []
    for x, y, _ in loader:
        logits = model(x.to(device))
        prob = torch.sigmoid(logits)
        # Labels are only consumed on the host, so they are not moved to
        # the compute device first.
        ys.append(y.detach().cpu().numpy())
        ps.append(prob.detach().cpu().numpy())

    if not ys:
        # np.concatenate raises on an empty list; report an empty result.
        return np.array([]), np.array([]), float("nan")

    y_true = np.concatenate(ys)
    y_prob = np.concatenate(ps)
    # ROC AUC is undefined when only one class is present.
    auc = roc_auc_score(y_true, y_prob) if len(np.unique(y_true)) > 1 else float("nan")
    return y_true, y_prob, auc

def binary_metrics_from_threshold(y_true, y_prob, thr=0.5):
    """Threshold probabilities and derive confusion counts and rates.

    Args:
        y_true: Array of ground-truth binary labels.
        y_prob: Array of predicted probabilities.
        thr: Probabilities ``>= thr`` are predicted as class 1.

    Returns:
        ``((tn, fp, fn, tp), recall, spec)`` with recall = tp / (tp + fn)
        and spec = tn / (tn + fp); a tiny epsilon in each denominator
        avoids division by zero.
    """
    eps = 1e-12
    hard_labels = (y_prob >= thr).astype(int)
    counts = confusion_matrix(y_true, hard_labels, labels=[0, 1])
    tn, fp, fn, tp = counts.ravel()
    sensitivity = tp / (tp + fn + eps)
    specificity = tn / (tn + fp + eps)
    return (tn, fp, fn, tp), sensitivity, specificity


In [None]:
def train_detection(train_df, val_df, epochs=10, batch_size=8, lr=3e-4, img_size=448, pos_weight=1.0,
                    best_path="/content/drive/MyDrive/papilloedema/models/det_best.pth"):
    """Train the binary detector and keep the checkpoint with the best val AUC.

    Args:
        train_df: Training split (needs ``path`` and ``detect`` columns).
        val_df: Validation split with the same columns.
        epochs: Number of passes over the training data.
        batch_size: Batch size for both loaders.
        lr: AdamW learning rate.
        img_size: Square input size forwarded to the dataset transforms.
        pos_weight: Positive-class weight for ``BCEWithLogitsLoss``.
        best_path: Where the best ``state_dict`` is written.

    Returns:
        ``best_path``. A checkpoint is always present at that path on
        return: if the validation AUC was NaN in every epoch, the
        final-epoch weights are saved instead.
    """
    train_ds = FundusDataset(train_df, train=True, task='detect', img_size=img_size)
    val_ds   = FundusDataset(val_df, train=False, task='detect', img_size=img_size)

    # The head contains BatchNorm1d, which raises in train mode on a batch
    # of a single sample. Drop the trailing batch only when it would be
    # such a singleton, so no data is discarded otherwise.
    drop_last = batch_size > 1 and len(train_ds) % batch_size == 1

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2,
                              pin_memory=True, drop_last=drop_last)
    val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    model = DetectionModel().to(device)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    best_auc = -1
    saved = False
    # dirname is '' for a bare filename, and os.makedirs('') raises.
    ckpt_dir = os.path.dirname(best_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    for ep in range(1, epochs+1):
        model.train()
        running = 0.0
        seen = 0
        for x, y, _ in train_loader:
            x = x.to(device)
            y = y.to(device)
            opt.zero_grad(set_to_none=True)
            logits = model(x)
            loss = loss_fn(logits, y)
            loss.backward()
            opt.step()
            running += loss.item() * x.size(0)
            seen += x.size(0)
        # Average over the samples actually processed (differs from the
        # dataset size when the trailing singleton batch is dropped).
        train_loss = running / max(seen, 1)

        y_true, y_prob, auc = eval_detection(model, val_loader)
        cm, rec, spec = binary_metrics_from_threshold(y_true, y_prob, thr=0.5)

        print(f"Epoch {ep}/{epochs} | TrainLoss {train_loss:.4f} | ValAUC {auc:.4f} | ValRecall {rec:.4f} | ValSpec {spec:.4f} | cm {cm}")

        # A NaN AUC compares False here, so it never replaces the best model.
        if auc > best_auc:
            best_auc = auc
            torch.save(model.state_dict(), best_path)
            saved = True
            print("✅ Saved improved detection model")

    if not saved:
        # AUC was never comparable (e.g. single-class validation set).
        # Save the final weights so the returned path points at a real file.
        torch.save(model.state_dict(), best_path)
        print("No valid validation AUC was obtained; saved final-epoch weights instead")

    print("Best AUC:", best_auc, "Saved:", best_path)
    return best_path


## 5) Deep Learning — Staging (CORAL ordinal)

In [None]:
def make_coral_targets(y, num_classes=5):
    """Encode stage labels as cumulative binary targets.

    A label ``y`` in ``1..num_classes`` becomes ``num_classes - 1`` floats,
    the k-th being 1.0 when ``y >= k + 2`` (thresholds 2..num_classes).

    Args:
        y: Tensor of stage labels.
        num_classes: Number of ordinal classes.

    Returns:
        Float tensor of shape ``[len(y), num_classes - 1]``.
    """
    labels_col = y.view(-1, 1)
    cut_points = torch.arange(2, num_classes + 1, device=y.device)
    exceeds = torch.ge(labels_col, cut_points.view(1, -1))
    return exceeds.float()

class StagingModel(nn.Module):
    """Ordinal staging network: timm backbone plus an MLP threshold head.

    The head emits ``num_classes - 1`` logits, one per cumulative threshold
    (see ``make_coral_targets`` / ``decode_coral``).

    Args:
        backbone: timm model name.
        pretrained: Whether timm should load pretrained backbone weights.
        num_classes: Number of ordinal stages.
    """

    def __init__(self, backbone="tf_efficientnet_b4_ns", pretrained=True, num_classes=5):
        super().__init__()
        self.num_classes = num_classes
        self.backbone = timm.create_model(
            backbone,
            pretrained=pretrained,
            num_classes=0,
            global_pool="avg",
        )
        feat_dim = self.backbone.num_features
        hidden_dim = 512
        n_thresholds = num_classes - 1  # 4 logits for the default 5 stages
        head_layers = [
            nn.Linear(feat_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, n_thresholds),
        ]
        self.head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the threshold logits, shape ``[B, num_classes - 1]``."""
        pooled = self.backbone(x)
        return self.head(pooled)

def decode_coral(logits):
    """Convert threshold logits into 1-based stage predictions.

    Args:
        logits: Tensor of shape ``[B, K]`` (K = 4 for five stages).

    Returns:
        Integer tensor of shape ``[B]``: one plus the number of thresholds
        whose sigmoid probability exceeds 0.5.
    """
    passed = torch.sigmoid(logits).gt(0.5)
    return passed.sum(dim=1).add(1)


In [None]:
@torch.no_grad()
def eval_staging(model, loader):
    """Evaluate the staging model and report mean absolute stage error.

    Args:
        model: Module producing threshold logits per sample.
        loader: Iterable of ``(x, y, path)`` batches with stage labels.

    Returns:
        ``(y_true, y_pred, mae)``: concatenated true and decoded stages as
        NumPy arrays, and their mean absolute difference as a float.
    """
    model.eval()
    truth_chunks, pred_chunks = [], []
    for x, y, _ in loader:
        stage_logits = model(x.to(device))
        decoded = decode_coral(stage_logits)
        pred_chunks.append(decoded.detach().cpu().numpy())
        truth_chunks.append(np.array(y))

    y_true = np.concatenate(truth_chunks)
    y_pred = np.concatenate(pred_chunks)
    abs_err = np.abs(y_true - y_pred)
    mae = float(np.mean(abs_err))
    return y_true, y_pred, mae


In [None]:
def train_staging(train_df, val_df, epochs=25, batch_size=8, lr=3e-4, img_size=448,
                  best_path="/content/drive/MyDrive/papilloedema/models/stage_best.pth"):
    """Train the ordinal staging model on positive cases only.

    Rows with ``detect == 1`` are selected from both splits; the model is
    trained with binary cross-entropy on cumulative threshold targets and
    the checkpoint with the lowest validation MAE is kept.

    Args:
        train_df: Training split (``path``, ``detect``, ``stage`` columns).
        val_df: Validation split with the same columns.
        epochs: Number of passes over the training positives.
        batch_size: Batch size for both loaders.
        lr: AdamW learning rate.
        img_size: Square input size forwarded to the dataset transforms.
        best_path: Where the best ``state_dict`` is written.

    Returns:
        ``best_path``.

    Raises:
        ValueError: If either split contains no ``detect == 1`` rows.
    """
    # only positives
    tr = train_df[train_df.detect==1].copy()
    va = val_df[val_df.detect==1].copy()
    if len(tr) == 0 or len(va) == 0:
        # Fail early with a clear message instead of a sampler/concatenate
        # error deep inside the loop.
        raise ValueError(
            f"Staging needs positive samples in both splits (train={len(tr)}, val={len(va)})"
        )

    train_ds = FundusDataset(tr, train=True, task='stage', img_size=img_size)
    val_ds   = FundusDataset(va, train=False, task='stage', img_size=img_size)

    # The head contains BatchNorm1d, which raises in train mode on a batch
    # of a single sample; drop the trailing batch only when it would be one.
    drop_last = batch_size > 1 and len(train_ds) % batch_size == 1

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2,
                              pin_memory=True, drop_last=drop_last)
    val_loader   = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    model = StagingModel().to(device)
    loss_fn = nn.BCEWithLogitsLoss()
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    # Halve the LR when validation MAE stops improving for 2 epochs.
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=2)

    best_mae = 1e9
    # dirname is '' for a bare filename, and os.makedirs('') raises.
    ckpt_dir = os.path.dirname(best_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)

    for ep in range(1, epochs+1):
        model.train()
        run = 0.0
        seen = 0
        for x, y, _ in train_loader:
            x = x.to(device)
            # The default collate already yields a tensor; as_tensor moves
            # it without the copy-construct warning torch.tensor() emits.
            y = torch.as_tensor(y, device=device)
            t = make_coral_targets(y, num_classes=5)
            opt.zero_grad(set_to_none=True)
            logits = model(x)
            loss = loss_fn(logits, t)
            loss.backward()
            opt.step()
            run += loss.item() * x.size(0)
            seen += x.size(0)
        # Average over the samples actually processed.
        train_loss = run / max(seen, 1)

        yv_true, yv_pred, mae = eval_staging(model, val_loader)
        sched.step(mae)

        print(f"Epoch {ep}/{epochs} | TrainLoss {train_loss:.4f} | ValMAE {mae:.4f}")

        if mae < best_mae:
            best_mae = mae
            torch.save(model.state_dict(), best_path)
            print("✅ Best staging model saved")

    print("Best Val MAE:", best_mae, "Saved:", best_path)
    return best_path


## 6) Hybrid evaluation (7-class)

In [None]:
@torch.no_grad()
def load_models(det_path, stage_path):
    """Rebuild both networks and load their saved weights for inference.

    Args:
        det_path: Path to the detection ``state_dict`` checkpoint.
        stage_path: Path to the staging ``state_dict`` checkpoint.

    Returns:
        ``(det, stg)``: both models on ``device`` and in eval mode.
    """
    # pretrained=False: every weight is overwritten by load_state_dict
    # right away, so fetching ImageNet weights first would be wasted work.
    det = DetectionModel(pretrained=False).to(device)
    det.load_state_dict(torch.load(det_path, map_location=device))
    det.eval()

    stg = StagingModel(pretrained=False).to(device)
    stg.load_state_dict(torch.load(stage_path, map_location=device))
    stg.eval()
    return det, stg

@torch.no_grad()
def preprocess_one(path, img_size=448):
    """Load one image and prepare it as a single-item model batch.

    Args:
        path: Image file path.
        img_size: Square input size for the evaluation transforms.

    Returns:
        ``(x, rgb)``: ``x`` is the transformed image with a leading batch
        dimension added; ``rgb`` is the RGB image after ``crop_to_fundus``
        and before any resizing.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``path`` (same behaviour
            as ``FundusDataset.__getitem__``).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread signals failure by returning None; without this check
        # cvtColor would fail with an opaque OpenCV assertion.
        raise FileNotFoundError(path)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    rgb = crop_to_fundus(rgb)
    tfm = build_tfms(img_size=img_size, train=False)
    x = tfm(image=rgb)['image'].unsqueeze(0)
    return x, rgb

@torch.no_grad()
def hybrid_predict(det_model, stage_model, path, det_thr=0.5, img_size=448):
    """Two-step prediction: detect first, then stage only the positives.

    Args:
        det_model: Detector returning one logit.
        stage_model: Staging model returning threshold logits.
        path: Image file path.
        det_thr: Detection probability at or above which staging runs.
        img_size: Square input size used for preprocessing.

    Returns:
        ``(label, prob, stage)``. For a negative detection this is
        ``(0, prob, None)``; otherwise ``label`` and ``stage`` are both the
        decoded stage.
    """
    batch, _ = preprocess_one(path, img_size=img_size)
    batch = batch.to(device)

    pap_prob = torch.sigmoid(det_model(batch)).item()
    if not (pap_prob >= det_thr):
        return 0, pap_prob, None

    stage_logits = stage_model(batch)
    stage = int(decode_coral(stage_logits).item())
    return stage, pap_prob, stage


In [None]:
@torch.no_grad()
def eval_hybrid(det_model, stage_model, df, det_thr=0.5, img_size=448):
    """Evaluate the detect-then-stage pipeline on every row of ``df``.

    The target is 0 for rows with ``detect == 0`` and the row's ``stage``
    otherwise, so this function scores six labels: 0 (non-pap) plus
    stages 1-5. A classification report and the confusion matrix are
    printed.

    Args:
        det_model: Detector passed to ``hybrid_predict``.
        stage_model: Staging model passed to ``hybrid_predict``.
        df: DataFrame with ``path``, ``detect`` and ``stage`` columns.
        det_thr: Detection threshold forwarded to ``hybrid_predict``.
        img_size: Input size forwarded to ``hybrid_predict``.

    Returns:
        ``(y_true, y_pred)`` as NumPy arrays.
    """
    targets, predictions = [], []
    for _, rec in tqdm(df.iterrows(), total=len(df)):
        is_non_pap = int(rec.detect) == 0
        target = 0 if is_non_pap else int(rec.stage)
        label, _prob, _stage = hybrid_predict(
            det_model, stage_model, rec.path, det_thr=det_thr, img_size=img_size
        )
        targets.append(target)
        predictions.append(label)

    y_true = np.array(targets)
    y_pred = np.array(predictions)

    labels = list(range(6))
    names = ["Non-pap"] + [f"G{grade}" for grade in range(1, 6)]
    report = classification_report(
        y_true, y_pred, labels=labels, target_names=names, digits=4, zero_division=0
    )
    print(report)
    print("Confusion:\n", confusion_matrix(y_true, y_pred, labels=labels))
    return y_true, y_pred


## 7) Plots (ROC / PR curves + confusion heatmaps)

In [None]:
def save_roc_pr_curves(y_true, y_prob, out_dir, prefix="det"):
    """Save ROC and precision-recall plots as PNG files.

    Writes ``{prefix}_roc.png`` and ``{prefix}_pr.png`` into ``out_dir``
    (created if missing), with the AUC / average precision in the titles.

    Args:
        y_true: Ground-truth binary labels.
        y_prob: Predicted probabilities.
        out_dir: Output directory.
        prefix: File-name prefix, also shown in the plot titles.
    """
    os.makedirs(out_dir, exist_ok=True)

    def _save_curve(xs, ys, x_label, y_label, title, file_name):
        # One line plot per figure, saved and closed immediately.
        plt.figure()
        plt.plot(xs, ys)
        plt.xlabel(x_label)
        plt.ylabel(y_label)
        plt.title(title)
        plt.grid(True)
        plt.savefig(os.path.join(out_dir, file_name), dpi=200, bbox_inches="tight")
        plt.close()

    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = roc_auc_score(y_true, y_prob)
    _save_curve(
        fpr, tpr,
        "False Positive Rate", "True Positive Rate",
        f"ROC ({prefix}) AUC={roc_auc:.4f}",
        f"{prefix}_roc.png",
    )

    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    avg_precision = average_precision_score(y_true, y_prob)
    _save_curve(
        recall, precision,
        "Recall", "Precision",
        f"PR ({prefix}) AP={avg_precision:.4f}",
        f"{prefix}_pr.png",
    )

def save_confusion_heatmap(cm, labels, out_path, title="Confusion Matrix"):
    """Render a confusion matrix as an annotated heatmap and save it.

    Args:
        cm: 2-D array indexed as ``cm[row, col]``, at least
            ``len(labels)`` x ``len(labels)``.
        labels: Tick labels used on both axes.
        out_path: Destination image path.
        title: Figure title.
    """
    n_labels = len(labels)
    tick_positions = range(n_labels)

    plt.figure(figsize=(6, 5))
    plt.imshow(cm)
    plt.title(title)
    plt.xticks(tick_positions, labels, rotation=45, ha="right")
    plt.yticks(tick_positions, labels)

    # Write each cell's count at its centre (x = column, y = row).
    for row, col in np.ndindex(n_labels, n_labels):
        plt.text(col, row, int(cm[row, col]), ha="center", va="center")

    plt.colorbar()
    plt.tight_layout()
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    plt.close()


## 8) Grad-CAM (optional explainability)

In [None]:
# NOTE: Grad-CAM requires gradients; do NOT wrap the CAM call with @torch.no_grad().
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

def find_last_conv_module(model):
    """Return the last ``Conv2d`` in ``model.modules()`` order, or None.

    Args:
        model: Any ``torch.nn.Module``.
    """
    conv_layers = [m for m in model.modules() if isinstance(m, torch.nn.Conv2d)]
    return conv_layers[-1] if conv_layers else None

class BinaryPositiveTarget:
    """Grad-CAM target that scores the positive-class logit."""

    def __call__(self, model_output):
        """Sum the logits, taking column 0 first when the output is 2-D.

        Args:
            model_output: logits [B] or [B,1]
        """
        is_two_dim = model_output.ndim == 2
        scores = model_output[:, 0] if is_two_dim else model_output
        return scores.sum()

def gradcam_detection(det_model, path, img_size=448):
    """Compute a Grad-CAM overlay for the detector on one image.

    Args:
        det_model: Detection model; its last ``Conv2d`` is the CAM layer.
        path: Image file path.
        img_size: Square input size used for the model and the overlay.

    Returns:
        The image returned by ``show_cam_on_image``: the heatmap blended
        onto the resized and padded RGB image.

    Raises:
        ValueError: If ``det_model`` contains no ``Conv2d`` module.
    """
    det_model.eval()
    x, rgb = preprocess_one(path, img_size=img_size)

    # The CAM has the spatial size of the network input (img_size x
    # img_size), while `rgb` is still at its native resolution. Apply the
    # same resize + pad steps as build_tfms so the overlay image and the
    # heatmap have matching shapes and geometry.
    to_model_geometry = A.Compose([
        A.LongestMaxSize(img_size),
        A.PadIfNeeded(img_size, img_size, border_mode=cv2.BORDER_CONSTANT),
    ])
    vis_rgb = to_model_geometry(image=rgb)['image']
    rgb_float = (vis_rgb / 255.0).astype(np.float32)

    x = x.to(device)
    layer = find_last_conv_module(det_model)
    if layer is None:
        raise ValueError("det_model has no Conv2d layer to attach Grad-CAM to")

    cam = GradCAM(model=det_model, target_layers=[layer])
    det_model.zero_grad(set_to_none=True)
    cam_map = cam(input_tensor=x, targets=[BinaryPositiveTarget()])[0]
    overlay = show_cam_on_image(rgb_float, cam_map, use_rgb=True)
    return overlay


## 9) Classical ML baseline (Approach A)

In [None]:
# This section assumes you already have a working feature extractor in your notebook.
# Paste your final feature extraction functions here (disc patch, LBP, Frangi vesselness, entropy, etc.)
# Then train:
#
# - Detection: xgboost.XGBClassifier with scale_pos_weight
# - Staging: CORAL-style threshold models or ordinal approach
#
# Finally: evaluate 7-class hybrid the same way as deep hybrid.


## 10) Export summary for report/presentation

In [None]:
# Create a quick summary table (edit values if you update models)
# NOTE: the metrics below are hard-coded literals; nothing in this cell
# recomputes them from the models or the test set.
summary = pd.DataFrame([
    {"System":"Deep hybrid (final)", "Det AUC":0.9893, "Det Sens":1.0, "Det Spec":0.9592, "Hybrid Acc":0.9310},
    {"System":"Classical baseline (best)", "Det AUC":0.9887, "Det Sens":1.0, "Det Spec":0.9592, "Hybrid Acc":0.8276},
])
# Display the table as the cell output.
summary

In [None]:
# Write the summary table to the Drive exports folder as CSV, without the
# DataFrame index column.
out_dir = "/content/drive/MyDrive/papilloedema/exports"
os.makedirs(out_dir, exist_ok=True)
summary.to_csv(os.path.join(out_dir, "model_summary.csv"), index=False)
print("Saved:", os.path.join(out_dir, "model_summary.csv"))