# 1. CFP Images MobileNetV3L

In [1]:
import math
import os
import random
import re
from typing import Dict, List
import zipfile

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

CSV_PATH = "./filtered_glaucoma.csv"
IMG_ROOT = "./glaucoma_data/CFPs"  # <-- CFP images folder
CHECK_DIR = "./checkpoints"
CFP_DIR = "./glaucoma_data/ROI images"  # <-- ROI images folder
ROI_DIR = "./glaucoma_data/ROI images"  # ROI images folder
JSON_DIR = "./glaucoma_data/json"  # LabelMe JSON files matching image names

os.makedirs(CHECK_DIR, exist_ok=True)


SEED = 42
IMG_SIZE = 224
BATCH_SIZE = 16
EPOCHS = 80
LR = 1e-4
WD = 1e-4
NUM_WORKERS = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_POINTS = 59
PATIENCE = 10
MIN_DELTA = 0.01


random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [2]:
with zipfile.ZipFile("./glaucoma_data.zip", "r") as zip_ref:
    zip_ref.extractall("./")

In [3]:
# ---- tiny CSV reader (no pandas) ------------------------------------------------
def read_csv(fp: str) -> List[Dict[str, str]]:
    with open(fp, "r", encoding="utf-8") as f:
        lines = [l.rstrip("\n") for l in f if l.strip()]
    header = [h.strip() for h in lines[0].split(",")]
    rows = []
    for line in lines[1:]:
        parts = [p.strip() for p in line.split(",")]
        parts = (parts + [""] * len(header))[: len(header)]  # ensure equal length
        rows.append({h: parts[i] for i, h in enumerate(header)})
    return rows


# ---- detect columns --------------------------------------------------------------
IMAGE_COLS_CANDIDATES = ["image", "image_name", "img", "image_path", "filename", "file"]


def detect_columns(rows: List[Dict[str, str]]):
    if not rows:
        raise ValueError("CSV has no rows.")
    cols = list(rows[0].keys())

    # find image column
    image_col = None
    for c in IMAGE_COLS_CANDIDATES:
        if c in cols:
            image_col = c
            break
    if image_col is None:
        for c in cols:
            if any(
                rows[i][c].lower().endswith((".jpg", ".jpeg", ".png"))
                for i in range(min(10, len(rows)))
            ):
                image_col = c
                break
    if image_col is None:
        raise ValueError("Could not find image filename column.")

    # pick 59 VF columns: prefer v1..v59
    vf_cols = [f"v{i}" for i in range(1, NUM_POINTS + 1)]
    if all(c in cols for c in vf_cols):
        return image_col, vf_cols

    # fallback: numeric columns
    candidates = []
    for c in cols:
        if c == image_col:
            continue
        ok = True
        for r in rows[: min(20, len(rows))]:
            v = r[c].strip()
            if v == "":
                ok = False
                break
            try:
                float(v)
            except:
                ok = False
                break
        if ok:
            candidates.append(c)

    if len(candidates) < NUM_POINTS:
        raise ValueError("Not enough numeric VF columns detected.")

    # sort by trailing number if exists
    def keyfun(name):
        m = re.search(r"(\d+)$", name)
        return (name, int(m.group(1)) if m else 9999)

    candidates_sorted = sorted(candidates, key=keyfun)[:NUM_POINTS]
    return image_col, candidates_sorted

In [4]:
class CFPDataset(Dataset):
    def __init__(
        self, rows: List[Dict[str, str]], image_col: str, vf_cols: List[str], train: bool
    ):
        self.rows = rows
        self.image_col = image_col
        self.vf_cols = vf_cols
        self.train = train

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        aug = []
        if train:
            aug = [
                transforms.RandomHorizontalFlip(0.5),
                transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
            ]

        self.tf = transforms.Compose(
            [transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(), *aug, normalize]
        )

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        r = self.rows[i]
        name = r[self.image_col]

        path = name
        if not os.path.isabs(path):
            if os.path.basename(path) == path:
                path = os.path.join(IMG_ROOT, path)

        img = Image.open(path).convert("RGB")
        x = self.tf(img)
        y = torch.tensor([float(r[c]) for c in self.vf_cols], dtype=torch.float32)
        return x, y

In [5]:
class MobileNetV3LVF(nn.Module):
    def __init__(self, out_dim=59, pretrained=True):
        super().__init__()
        self.backbone = models.mobilenet_v3_large(
            weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V2 if pretrained else None
        )
        in_f = self.backbone.classifier[-1].in_features
        self.backbone.classifier[-1] = nn.Identity()

        self.regressor = nn.Sequential(
            nn.Linear(in_f, 512), nn.ReLU(inplace=True), nn.Dropout(0.25), nn.Linear(512, out_dim)
        )

    def forward(self, x):
        f = self.backbone(x)
        return self.regressor(f)

In [6]:
@torch.no_grad()
def mae(pred, true):
    return torch.mean(torch.abs(pred - true)).item()


@torch.no_grad()
def ms_mae(pred, true):
    pm = pred.mean(dim=1)
    tm = true.mean(dim=1)
    return torch.mean(torch.abs(pm - tm)).item()


def run_epoch(model, loader, opt=None):
    train = opt is not None
    model.train() if train else model.eval()
    crit = nn.MSELoss()

    n = 0
    loss_sum = 0.0
    pmae_sum = 0.0
    msmae_sum = 0.0

    for x, y in loader:
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        if train:
            opt.zero_grad()

        pred = model(x)
        loss = crit(pred, y)

        if train:
            loss.backward()
            opt.step()

        bs = x.size(0)
        n += bs

        loss_sum += loss.item() * bs
        pmae_sum += mae(pred, y) * bs
        msmae_sum += ms_mae(pred, y) * bs

    return {"loss": loss_sum / n, "pointwise_mae": pmae_sum / n, "ms_mae": msmae_sum / n}

In [7]:
import os
import random

import numpy as np
import torch

# --- reproducibility (same as before)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class EarlyStopper:
    """
    - Saves whenever val_metric strictly improves over 'best' (tolerance=1e-12).
    - Uses 'min_delta' only to decide whether to reset patience (ref metric).
    """

    def __init__(self, patience=10, min_delta=0.01, ckpt_path=None):
        self.patience = int(patience)
        self.min_delta = float(min_delta)
        self.ckpt_path = ckpt_path
        self.best_save = float("inf")  # for checkpoint saving (any improvement)
        self.best_ref = float("inf")  # for patience (needs >= min_delta improvement)
        self.bad_epochs = 0
        if self.ckpt_path:
            os.makedirs(os.path.dirname(self.ckpt_path), exist_ok=True)

    def update(self, val_metric, model, epoch_meta=None):
        saved = False
        # --- Save on ANY strict improvement
        if val_metric < self.best_save - 1e-12:
            self.best_save = val_metric
            if self.ckpt_path:
                torch.save(
                    {
                        "model": model.state_dict(),
                        "val_pointwise_mae": self.best_save,
                        **(epoch_meta or {}),
                    },
                    self.ckpt_path,
                )
            saved = True

        # --- Early-stopping patience uses min_delta
        if val_metric < self.best_ref - self.min_delta:
            self.best_ref = val_metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1

        should_stop = self.bad_epochs > self.patience
        return should_stop, saved


def main():
    rows = read_csv(CSV_PATH)
    image_col, vf_cols = detect_columns(rows)
    print(f"[OK] image column: {image_col}")
    print(f"[OK] using {len(vf_cols)} VF columns: {vf_cols[:5]} ... {vf_cols[-5:]}")

    # split
    N = len(rows)
    n_train = int(0.8 * N)
    random.shuffle(rows)
    train_rows = rows[:n_train]
    val_rows = rows[n_train:]

    train_ds = CFPDataset(train_rows, image_col, vf_cols, train=True)
    val_ds = CFPDataset(val_rows, image_col, vf_cols, train=False)

    train_dl = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True
    )
    val_dl = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True
    )

    model = MobileNetV3LVF(out_dim=NUM_POINTS, pretrained=True).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", patience=3, factor=0.5)

    ckpt_path = os.path.join(CHECK_DIR, "best_mobilenetv3l_original_cfp.pth")
    stopper = EarlyStopper(patience=PATIENCE, min_delta=MIN_DELTA, ckpt_path=ckpt_path)

    for epoch in range(1, EPOCHS + 1):
        tr = run_epoch(model, train_dl, opt)
        va = run_epoch(model, val_dl, None)

        print(
            f"\nEpoch {epoch:02d} | "
            f"train_loss={tr['loss']:.4f}  train_pMAE={tr['pointwise_mae']:.3f}  train_msMAE={tr['ms_mae']:.3f} || "
            f"val_loss={va['loss']:.4f}  val_pMAE={va['pointwise_mae']:.3f}  val_msMAE={va['ms_mae']:.3f}"
        )

        # scheduler on validation MAE
        sched.step(va["pointwise_mae"])

        # early stopping + save best
        should_stop, saved = stopper.update(
            va["pointwise_mae"], model, epoch_meta={"epoch": epoch}
        )
        if saved:
            print(f"\n ✅ Saved BEST → {ckpt_path} (pMAE={stopper.best_save:.3f})")
        if should_stop:
            print(f"\nEarly stopping at epoch {epoch} (best val pMAE={stopper.best_save:.3f})")
            break

    # load best before returning
    state = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state["model"])
    return model, train_dl, val_dl, image_col, vf_cols, ckpt_path


if __name__ == "__main__":
    model, train_dl, val_dl, image_col, vf_cols, CKPT = main()

[OK] image column: Corresponding CFP
[OK] using 59 VF columns: ['AGE', 'CCT', 'IOP_y', 'Interval Years', 'MD'] ... ['VF50', 'VF51', 'VF52', 'VF53', 'VF54']
Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_large-5c1a4163.pth


100%|██████████| 21.1M/21.1M [00:00<00:00, 180MB/s]



Epoch 01 | train_loss=5404.9284  train_pMAE=28.422  train_msMAE=28.259 || val_loss=5333.1244  val_pMAE=27.152  val_msMAE=26.899

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_original_cfp.pth (pMAE=27.152)

Epoch 02 | train_loss=3867.5985  train_pMAE=20.059  train_msMAE=13.692 || val_loss=2189.8170  val_pMAE=26.853  val_msMAE=15.710

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_original_cfp.pth (pMAE=26.853)

Epoch 03 | train_loss=1062.7709  train_pMAE=15.617  train_msMAE=4.944 || val_loss=397.3284  val_pMAE=10.038  val_msMAE=7.261

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_original_cfp.pth (pMAE=10.038)

Epoch 04 | train_loss=301.3700  train_pMAE=13.165  train_msMAE=3.995 || val_loss=191.8438  val_pMAE=7.447  val_msMAE=4.968

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_original_cfp.pth (pMAE=7.447)

Epoch 05 | train_loss=265.4436  train_pMAE=12.205  train_msMAE=3.822 || val_loss=129.9627  val_pMAE=6.882  val_msMAE=4.337

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_

In [8]:
# --- reload BEST checkpoint and evaluate ---
assert "CKPT" in globals(), (
    "CKPT not found. Make sure you ran the training cell that returns CKPT."
)
assert "val_dl" in globals(), (
    "val_dl not found. Make sure you ran the training cell that defines val_dl."
)

# rebuild the exact architecture
best_model = MobileNetV3LVF(out_dim=NUM_POINTS, pretrained=False).to(DEVICE)

state = torch.load(CKPT, map_location=DEVICE)
best_model.load_state_dict(state["model"])
best_model.eval()

# collect predictions on the validation set
all_true, all_pred = [], []
with torch.no_grad():
    for x, y in val_dl:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        p = best_model(x)
        all_true.append(y.cpu())
        all_pred.append(p.cpu())

y_true = torch.cat(all_true, dim=0)
y_pred = torch.cat(all_pred, dim=0)

print("✅ Using BEST checkpoint:", CKPT)
print("✅ Collected predictions:", y_true.shape, y_pred.shape)

✅ Using BEST checkpoint: ./checkpoints/best_mobilenetv3l_original_cfp.pth
✅ Collected predictions: torch.Size([127, 59]) torch.Size([127, 59])


In [9]:
import torch


def rmse(a, b):
    return float(torch.sqrt(torch.mean((a - b) ** 2)))


def mae_val(a, b):
    return float(torch.mean(torch.abs(a - b)))


def r2(a, b):
    ss_res = torch.sum((a - b) ** 2)
    ss_tot = torch.sum((a - torch.mean(a)) ** 2) + 1e-12
    return float(1 - ss_res / ss_tot)


pw_true, pw_pred = y_true.reshape(-1), y_pred.reshape(-1)
print("\n== POINTWISE ==")
print(
    f"RMSE: {rmse(pw_true, pw_pred):.4f} | MAE: {mae_val(pw_true, pw_pred):.4f} | R²: {r2(pw_true, pw_pred):.4f}"
)

t_mean, p_mean = y_true.mean(dim=1), y_pred.mean(dim=1)
print("== POINTWISE-MEAN ==")
print(
    f"RMSE: {rmse(t_mean, p_mean):.4f} | MAE: {mae_val(t_mean, p_mean):.4f} | R²: {r2(t_mean, p_mean):.4f}"
)

print("== MS (same as pointwise-mean) ==")
print(
    f"RMSE: {rmse(t_mean, p_mean):.4f} | MAE: {mae_val(t_mean, p_mean):.4f} | R²: {r2(t_mean, p_mean):.4f}\n"
)


== POINTWISE ==
RMSE: 9.9729 | MAE: 5.7592 | R²: 0.9792
== POINTWISE-MEAN ==
RMSE: 4.3776 | MAE: 3.2183 | R²: 0.2939
== MS (same as pointwise-mean) ==
RMSE: 4.3776 | MAE: 3.2183 | R²: 0.2939



# 2. ROI Images MobileNetV3L

In [10]:
import math
import os
import random
import re
from typing import Dict, List

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

os.makedirs(CHECK_DIR, exist_ok=True)

SEED = 42
IMG_SIZE = 224
BATCH_SIZE = 16
EPOCHS = 80
LR = 1e-4
WD = 1e-4
NUM_WORKERS = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_POINTS = 59

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [11]:
# ===================== LABELME JSON → OD/OC MASKS =====================
OD_LABELS = {"od", "disc", "optic_disc", "optic-disc", "optic disc"}
OC_LABELS = {"oc", "cup", "optic_cup", "optic-cup", "optic cup"}


def _poly_area(pts):
    x = [p[0] for p in pts]
    y = [p[1] for p in pts]
    return 0.5 * abs(
        sum(x[i] * y[(i + 1) % len(pts)] - x[(i + 1) % len(pts)] * y[i] for i in range(len(pts)))
    )


def _read_labelme(json_path):
    with open(json_path, "r") as f:
        data = json.load(f)
    od_polys, oc_polys = [], []
    for sh in data.get("shapes", []):
        label = str(sh.get("label", "")).strip().lower()
        pts = [(float(x), float(y)) for x, y in sh.get("points", [])]
        if len(pts) < 3:
            continue
        if label in OD_LABELS:
            od_polys.append(pts)
        elif label in OC_LABELS:
            oc_polys.append(pts)
    if len(od_polys) > 1:
        od_polys = [max(od_polys, key=_poly_area)]
    if len(oc_polys) > 1:
        oc_polys = [max(oc_polys, key=_poly_area)]
    return od_polys, oc_polys


def _guess_json_path(img_name: str):
    base = os.path.splitext(os.path.basename(img_name))[0]
    for ext in (".json", ".JSON", ".Json"):
        cand = os.path.join(JSON_DIR, base + ext)
        if os.path.exists(cand):
            return cand
    return ""


def build_masks_from_labelme(img_pil: Image.Image, img_name: str, out_size: int):
    W, H = img_pil.size
    od_mask = Image.new("L", (W, H), 0)
    oc_mask = Image.new("L", (W, H), 0)

    jpath = _guess_json_path(img_name)
    if jpath:
        try:
            od_polys, oc_polys = _read_labelme(jpath)
            d_od = ImageDraw.Draw(od_mask)
            d_oc = ImageDraw.Draw(oc_mask)
            for poly in od_polys:
                d_od.polygon(poly, outline=1, fill=1)
            for poly in oc_polys:
                d_oc.polygon(poly, outline=1, fill=1)
        except Exception as e:
            print(f"[WARN] parsing {jpath}: {e}")

    od_mask = od_mask.resize((out_size, out_size), resample=Image.NEAREST)
    oc_mask = oc_mask.resize((out_size, out_size), resample=Image.NEAREST)
    return od_mask, oc_mask

In [12]:
class CFPDataset(Dataset):
    def __init__(
        self, rows: List[Dict[str, str]], image_col: str, vf_cols: List[str], train: bool
    ):
        self.rows = rows
        self.image_col = image_col
        self.vf_cols = vf_cols
        self.train = train

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        aug = []
        if train:
            aug = [
                transforms.RandomHorizontalFlip(0.5),
                transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
            ]

        self.tf = transforms.Compose(
            [transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(), *aug, normalize]
        )

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        r = self.rows[i]
        name = r[self.image_col]

        path = name
        if not os.path.isabs(path):
            if os.path.basename(path) == path:
                path = os.path.join(CFP_DIR, path)

        img = Image.open(path).convert("RGB")
        x = self.tf(img)
        y = torch.tensor([float(r[c]) for c in self.vf_cols], dtype=torch.float32)
        return x, y

In [13]:
class MobileNetV3LVF(nn.Module):
    def __init__(self, out_dim=59, pretrained=True):
        super().__init__()
        self.backbone = models.mobilenet_v3_large(
            weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V2 if pretrained else None
        )
        in_f = self.backbone.classifier[-1].in_features
        self.backbone.classifier[-1] = nn.Identity()

        self.regressor = nn.Sequential(
            nn.Linear(in_f, 512), nn.ReLU(inplace=True), nn.Dropout(0.25), nn.Linear(512, out_dim)
        )

    def forward(self, x):
        f = self.backbone(x)
        return self.regressor(f)

In [14]:
@torch.no_grad()
def mae(pred, true):
    return torch.mean(torch.abs(pred - true)).item()


@torch.no_grad()
def ms_mae(pred, true):
    pm = pred.mean(dim=1)
    tm = true.mean(dim=1)
    return torch.mean(torch.abs(pm - tm)).item()


def run_epoch(model, loader, opt=None):
    train = opt is not None
    model.train() if train else model.eval()
    crit = nn.MSELoss()

    n = 0
    loss_sum = 0.0
    pmae_sum = 0.0
    msmae_sum = 0.0

    for x, y in loader:
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        if train:
            opt.zero_grad()

        pred = model(x)
        loss = crit(pred, y)

        if train:
            loss.backward()
            opt.step()

        bs = x.size(0)
        n += bs

        loss_sum += loss.item() * bs
        pmae_sum += mae(pred, y) * bs
        msmae_sum += ms_mae(pred, y) * bs

    return {"loss": loss_sum / n, "pointwise_mae": pmae_sum / n, "ms_mae": msmae_sum / n}

In [15]:
import os
import random

import numpy as np
import torch

# --- reproducibility (same as before)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class EarlyStopper:
    """
    - Saves whenever val_metric strictly improves over 'best' (tolerance=1e-12).
    - Uses 'min_delta' only to decide whether to reset patience (ref metric).
    """

    def __init__(self, patience=10, min_delta=0.01, ckpt_path=None):
        self.patience = int(patience)
        self.min_delta = float(min_delta)
        self.ckpt_path = ckpt_path
        self.best_save = float("inf")  # for checkpoint saving (any improvement)
        self.best_ref = float("inf")  # for patience (needs >= min_delta improvement)
        self.bad_epochs = 0
        if self.ckpt_path:
            os.makedirs(os.path.dirname(self.ckpt_path), exist_ok=True)

    def update(self, val_metric, model, epoch_meta=None):
        saved = False
        # --- Save on ANY strict improvement
        if val_metric < self.best_save - 1e-12:
            self.best_save = val_metric
            if self.ckpt_path:
                torch.save(
                    {
                        "model": model.state_dict(),
                        "val_pointwise_mae": self.best_save,
                        **(epoch_meta or {}),
                    },
                    self.ckpt_path,
                )
            saved = True

        # --- Early-stopping patience uses min_delta
        if val_metric < self.best_ref - self.min_delta:
            self.best_ref = val_metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1

        should_stop = self.bad_epochs > self.patience
        return should_stop, saved


def main():
    rows = read_csv(CSV_PATH)
    image_col, vf_cols = detect_columns(rows)
    print(f"[OK] image column: {image_col}")
    print(f"[OK] using {len(vf_cols)} VF columns: {vf_cols[:5]} ... {vf_cols[-5:]}")

    # split
    N = len(rows)
    n_train = int(0.8 * N)
    random.shuffle(rows)
    train_rows = rows[:n_train]
    val_rows = rows[n_train:]

    train_ds = CFPDataset(train_rows, image_col, vf_cols, train=True)
    val_ds = CFPDataset(val_rows, image_col, vf_cols, train=False)

    train_dl = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True
    )
    val_dl = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True
    )

    model = MobileNetV3LVF(out_dim=NUM_POINTS, pretrained=True).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", patience=3, factor=0.5)

    ckpt_path = os.path.join(CHECK_DIR, "best_mobilenetv3l_original_roi.pth")
    stopper = EarlyStopper(patience=PATIENCE, min_delta=MIN_DELTA, ckpt_path=ckpt_path)

    for epoch in range(1, EPOCHS + 1):
        tr = run_epoch(model, train_dl, opt)
        va = run_epoch(model, val_dl, None)

        print(
            f"\nEpoch {epoch:02d} | "
            f"train_loss={tr['loss']:.4f}  train_pMAE={tr['pointwise_mae']:.3f}  train_msMAE={tr['ms_mae']:.3f} || "
            f"val_loss={va['loss']:.4f}  val_pMAE={va['pointwise_mae']:.3f}  val_msMAE={va['ms_mae']:.3f}"
        )

        # scheduler on validation MAE
        sched.step(va["pointwise_mae"])

        # early stopping + save best
        should_stop, saved = stopper.update(
            va["pointwise_mae"], model, epoch_meta={"epoch": epoch}
        )
        if saved:
            print(f"\n ✅ Saved BEST → {ckpt_path} (pMAE={stopper.best_save:.3f})")
        if should_stop:
            print(f"\nEarly stopping at epoch {epoch} (best val pMAE={stopper.best_save:.3f})")
            break

    # load best before returning
    state = torch.load(ckpt_path, map_location=DEVICE)
    model.load_state_dict(state["model"])
    return model, train_dl, val_dl, image_col, vf_cols, ckpt_path


if __name__ == "__main__":
    model, train_dl, val_dl, image_col, vf_cols, CKPT = main()

[OK] image column: Corresponding CFP
[OK] using 59 VF columns: ['AGE', 'CCT', 'IOP_y', 'Interval Years', 'MD'] ... ['VF50', 'VF51', 'VF52', 'VF53', 'VF54']

Epoch 01 | train_loss=5414.5734  train_pMAE=28.475  train_msMAE=28.315 || val_loss=5417.1320  val_pMAE=27.656  val_msMAE=27.453

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_original_roi.pth (pMAE=27.656)

Epoch 02 | train_loss=3974.9825  train_pMAE=20.565  train_msMAE=14.576 || val_loss=2725.2064  val_pMAE=18.350  val_msMAE=6.962

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_original_roi.pth (pMAE=18.350)

Epoch 03 | train_loss=1162.0608  train_pMAE=15.938  train_msMAE=5.374 || val_loss=357.6270  val_pMAE=9.239  val_msMAE=6.112

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_original_roi.pth (pMAE=9.239)

Epoch 04 | train_loss=311.6845  train_pMAE=13.164  train_msMAE=4.081 || val_loss=484.8784  val_pMAE=9.342  val_msMAE=7.954

Epoch 05 | train_loss=265.9235  train_pMAE=12.184  train_msMAE=3.851 || val_loss=209.0681  val_pMA

In [16]:
# reload best ROI checkpoint
best_roi = MobileNetV3LVF(out_dim=NUM_POINTS, pretrained=False).to(DEVICE)
state = torch.load(CKPT, map_location=DEVICE)
best_roi.load_state_dict(state["model"])
best_roi.eval()

# collect predictions
all_true = []
all_pred = []

with torch.no_grad():
    for x, y in val_dl:
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        p = best_roi(x)

        all_true.append(y.cpu())
        all_pred.append(p.cpu())

y_true = torch.cat(all_true, dim=0)
y_pred = torch.cat(all_pred, dim=0)

print("✅ ROI predictions collected:", y_true.shape, y_pred.shape)

✅ ROI predictions collected: torch.Size([127, 59]) torch.Size([127, 59])


In [17]:
import torch


# helpers
def rmse(a, b):
    return float(torch.sqrt(torch.mean((a - b) ** 2)))


def mae_val(a, b):
    return float(torch.mean(torch.abs(a - b)))


def r2(a, b):
    ss_res = torch.sum((a - b) ** 2)
    ss_tot = torch.sum((a - torch.mean(a)) ** 2) + 1e-12
    return float(1 - ss_res / ss_tot)


# POINTWISE
pw_true = y_true.reshape(-1)
pw_pred = y_pred.reshape(-1)

print("\n== ROI: POINTWISE ==")
print(f"RMSE: {rmse(pw_true, pw_pred):.4f}")
print(f"MAE : {mae_val(pw_true, pw_pred):.4f}")
print(f"R²  : {r2(pw_true, pw_pred):.4f}")

# POINTWISE-MEAN / MS
t_mean = y_true.mean(dim=1)
p_mean = y_pred.mean(dim=1)

print("\n== ROI: POINTWISE-MEAN / MS ==")
print(f"RMSE: {rmse(t_mean, p_mean):.4f}")
print(f"MAE : {mae_val(t_mean, p_mean):.4f}")
print(f"R²  : {r2(t_mean, p_mean):.4f}\n")


== ROI: POINTWISE ==
RMSE: 12.1223
MAE : 6.1318
R²  : 0.9692

== ROI: POINTWISE-MEAN / MS ==
RMSE: 5.5650
MAE : 3.7295
R²  : -0.1412



# 3. ROI + OD/OD Segmentation MobileNetV3L

In [18]:
import json
import os
import random
import re
from typing import Dict, List, Tuple

import numpy as np
from PIL import Image, ImageDraw
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

# ---- paths ----

os.makedirs(CHECK_DIR, exist_ok=True)

# ---- training ----
SEED = 42
IMG_SIZE = 224
BATCH_SIZE = 16
EPOCHS = 80
LR = 1e-4
WD = 1e-4
NUM_WORKERS = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_POINTS = 59

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [19]:
# ===================== CSV UTILS =====================
def read_csv(fp: str) -> List[Dict[str, str]]:
    with open(fp, "r", encoding="utf-8") as f:
        lines = [l.rstrip("\n") for l in f if l.strip()]
    header = [h.strip() for h in lines[0].split(",")]
    rows = []
    for line in lines[1:]:
        parts = [p.strip() for p in line.split(",")]
        parts = (parts + [""] * len(header))[: len(header)]
        rows.append({h: parts[i] for i, h in enumerate(header)})
    return rows


IMAGE_COLS_CANDIDATES = ["image", "image_name", "img", "image_path", "filename", "file"]


def detect_columns(rows: List[Dict[str, str]]) -> Tuple[str, List[str]]:
    if not rows:
        raise ValueError("CSV has no rows.")
    cols = list(rows[0].keys())

    # image col
    image_col = None
    for c in IMAGE_COLS_CANDIDATES:
        if c in cols:
            image_col = c
            break
    if image_col is None:
        for c in cols:
            if any(
                rows[i][c].lower().endswith((".jpg", ".jpeg", ".png"))
                for i in range(min(10, len(rows)))
            ):
                image_col = c
                break
    if image_col is None:
        raise ValueError("Could not find image filename column.")

    # VF cols prefer v1..v59 else numeric fallback
    vf = [f"v{i}" for i in range(1, NUM_POINTS + 1)]
    if all(c in cols for c in vf):
        return image_col, vf

    cand = []
    for c in cols:
        if c == image_col:
            continue
        ok = True
        for r in rows[: min(20, len(rows))]:
            v = r[c].strip()
            if v == "":
                ok = False
                break
            try:
                float(v)
            except:
                ok = False
                break
        if ok:
            cand.append(c)
    if len(cand) < NUM_POINTS:
        raise ValueError(f"Need {NUM_POINTS} VF cols, found {len(cand)}.")

    def keyfun(name):
        m = re.search(r"(\d+)$", name)
        return (name, int(m.group(1)) if m else 9999)

    cand = sorted(cand, key=keyfun)[:NUM_POINTS]
    return image_col, cand

In [20]:
# ===================== LABELME JSON → OD/OC MASKS =====================
OD_LABELS = {"od", "disc", "optic_disc", "optic-disc", "optic disc"}
OC_LABELS = {"oc", "cup", "optic_cup", "optic-cup", "optic cup"}


def _poly_area(pts):
    x = [p[0] for p in pts]
    y = [p[1] for p in pts]
    return 0.5 * abs(
        sum(x[i] * y[(i + 1) % len(pts)] - x[(i + 1) % len(pts)] * y[i] for i in range(len(pts)))
    )


def _read_labelme(json_path):
    with open(json_path, "r") as f:
        data = json.load(f)
    od_polys, oc_polys = [], []
    for sh in data.get("shapes", []):
        label = str(sh.get("label", "")).strip().lower()
        pts = [(float(x), float(y)) for x, y in sh.get("points", [])]
        if len(pts) < 3:
            continue
        if label in OD_LABELS:
            od_polys.append(pts)
        elif label in OC_LABELS:
            oc_polys.append(pts)
    if len(od_polys) > 1:
        od_polys = [max(od_polys, key=_poly_area)]
    if len(oc_polys) > 1:
        oc_polys = [max(oc_polys, key=_poly_area)]
    return od_polys, oc_polys


def _guess_json_path(img_name: str):
    base = os.path.splitext(os.path.basename(img_name))[0]
    for ext in (".json", ".JSON", ".Json"):
        cand = os.path.join(JSON_DIR, base + ext)
        if os.path.exists(cand):
            return cand
    return ""


def build_masks_from_labelme(img_pil: Image.Image, img_name: str, out_size: int):
    W, H = img_pil.size
    od_mask = Image.new("L", (W, H), 0)
    oc_mask = Image.new("L", (W, H), 0)

    jpath = _guess_json_path(img_name)
    if jpath:
        try:
            od_polys, oc_polys = _read_labelme(jpath)
            d_od = ImageDraw.Draw(od_mask)
            d_oc = ImageDraw.Draw(oc_mask)
            for poly in od_polys:
                d_od.polygon(poly, outline=1, fill=1)
            for poly in oc_polys:
                d_oc.polygon(poly, outline=1, fill=1)
        except Exception as e:
            print(f"[WARN] parsing {jpath}: {e}")

    od_mask = od_mask.resize((out_size, out_size), resample=Image.NEAREST)
    oc_mask = oc_mask.resize((out_size, out_size), resample=Image.NEAREST)
    return od_mask, oc_mask

In [21]:
# ===================== DATASET (5-channel RGB+OD+OC) =====================
class ROI_OD_OC_Dataset(Dataset):
    def __init__(self, rows, image_col, vf_cols, train=True, img_root=ROI_DIR, img_size=IMG_SIZE):
        self.rows, self.image_col, self.vf_cols = rows, image_col, vf_cols
        self.train, self.img_root, self.img_size = train, img_root, img_size

        norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        aug = []
        if train:
            aug = [
                transforms.RandomHorizontalFlip(0.5),
                transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
            ]
        self.rgb_tf = transforms.Compose(
            [transforms.Resize((img_size, img_size)), transforms.ToTensor(), *aug, norm]
        )
        self.mask_tf = transforms.ToTensor()  # L → (1,H,W) float {0,1}

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, i):
        r = self.rows[i]
        fn = r[self.image_col]
        path = fn if os.path.isabs(fn) else os.path.join(self.img_root, fn)

        img = Image.open(path).convert("RGB")
        od_img, oc_img = build_masks_from_labelme(img, fn, self.img_size)

        x_rgb = self.rgb_tf(img)  # (3,H,W)
        x_od = self.mask_tf(od_img)  # (1,H,W)
        x_oc = self.mask_tf(oc_img)  # (1,H,W)
        x = torch.cat([x_rgb, x_od, x_oc], dim=0)  # (5,H,W)

        y = torch.tensor([float(r[c]) for c in self.vf_cols], dtype=torch.float32)
        return x, y

In [22]:
# ===================== MODEL (ResNet-50 with 5-ch input) =====================
class MobileNetV3L_5ch_VF(nn.Module):
    def __init__(self, out_dim=59, pretrained=True):
        super().__init__()
        base = models.mobilenet_v3_large(
            weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V2 if pretrained else None
        )
        # adapt Conv2dNormActivation: 3→5 channels (init extra channels with mean RGB weights)
        old = base.features[0][0]
        new = nn.Conv2d(
            5,
            old.out_channels,
            kernel_size=old.kernel_size,
            stride=old.stride,
            padding=old.padding,
            bias=(old.bias is not None),
        )
        with torch.no_grad():
            new.weight[:, :3, :, :] = old.weight
            mean_w = old.weight.mean(dim=1, keepdim=True)
            new.weight[:, 3:5, :, :] = mean_w.repeat(1, 2, 1, 1)
            if old.bias is not None:
                new.bias.copy_(old.bias)
        base.features[0][0] = new

        in_f = base.classifier[-1].in_features
        base.classifier[-1] = nn.Identity()
        self.backbone = base
        self.regressor = nn.Sequential(
            nn.Linear(in_f, 512), nn.ReLU(inplace=True), nn.Dropout(0.25), nn.Linear(512, out_dim)
        )

    def forward(self, x5):
        f = self.backbone(x5)
        return self.regressor(f)

In [23]:
# ===================== METRICS + EPOCH LOOP =====================
@torch.no_grad()
def mae(pred, true):
    return torch.mean(torch.abs(pred - true)).item()


@torch.no_grad()
def ms_mae(pred, true):
    pm = pred.mean(dim=1)
    tm = true.mean(dim=1)
    return torch.mean(torch.abs(pm - tm)).item()


def run_epoch(model, loader, opt=None):
    train = opt is not None
    model.train() if train else model.eval()
    crit = nn.MSELoss()

    n = 0
    loss_sum = 0.0
    pmae_sum = 0.0
    msmae_sum = 0.0
    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)

        if train:
            opt.zero_grad()
        pred = model(x)
        loss = crit(pred, y)
        if train:
            loss.backward()
            opt.step()

        bs = x.size(0)
        n += bs
        loss_sum += loss.item() * bs
        pmae_sum += mae(pred, y) * bs
        msmae_sum += ms_mae(pred, y) * bs

    return {"loss": loss_sum / n, "pointwise_mae": pmae_sum / n, "ms_mae": msmae_sum / n}

In [24]:
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader

# ---- (optional) reproducibility on small data
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# ---- Early Stopping helper
class EarlyStopper:
    def __init__(self, patience=10, min_delta=0.01, ckpt_path=None):
        self.patience = int(patience)
        self.min_delta = float(min_delta)
        self.ckpt_path = ckpt_path
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_metric, model, epoch_meta=None):
        # returns True if we should stop
        if val_metric < self.best - self.min_delta:
            self.best = val_metric
            self.bad_epochs = 0
            if self.ckpt_path:
                torch.save(
                    {
                        "model": model.state_dict(),
                        "val_pointwise_mae": self.best,
                        **(epoch_meta or {}),
                    },
                    self.ckpt_path,
                )
            return False
        else:
            self.bad_epochs += 1
            return self.bad_epochs > self.patience


# ===================== TRAIN =====================
def train_resnet50_roi_odoc(EPOCHS=80, PATIENCE=10, MIN_DELTA=0.01):
    rows = read_csv(CSV_PATH)
    image_col, vf_cols = detect_columns(rows)
    print(f"[OK] image column: {image_col}")
    print(f"[OK] {len(vf_cols)} VF cols: {vf_cols[:5]} ... {vf_cols[-5:]}")

    random.shuffle(rows)
    N = len(rows)
    n_train = int(0.8 * N)
    train_rows = rows[:n_train]
    val_rows = rows[n_train:]

    train_ds = ROI_OD_OC_Dataset(train_rows, image_col, vf_cols, train=True)
    val_ds = ROI_OD_OC_Dataset(val_rows, image_col, vf_cols, train=False)

    train_dl = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True
    )
    val_dl = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True
    )

    model = MobileNetV3L_5ch_VF(out_dim=NUM_POINTS, pretrained=True).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)

    # ---- LR scheduler (plateau)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", patience=3, factor=0.5)

    ckpt = os.path.join(CHECK_DIR, "best_mobilenetv3l_ROI_ODOC.pth")
    stopper = EarlyStopper(patience=PATIENCE, min_delta=MIN_DELTA, ckpt_path=ckpt)

    for epoch in range(1, EPOCHS + 1):
        tr = run_epoch(model, train_dl, opt)
        va = run_epoch(model, val_dl, None)

        # step scheduler on validation pMAE
        sched.step(va["pointwise_mae"])

        print(
            f"Epoch {epoch:02d} | "
            f"train_loss={tr['loss']:.4f}  train_pMAE={tr['pointwise_mae']:.3f}  train_MS={tr['ms_mae']:.3f} || "
            f"val_loss={va['loss']:.4f}  val_pMAE={va['pointwise_mae']:.3f}  val_MS={va['ms_mae']:.3f}"
        )

        # save best (and detect improvement for pretty print)
        prev_best = stopper.best
        should_stop = stopper.step(va["pointwise_mae"], model, epoch_meta={"epoch": epoch})
        if stopper.best < prev_best - MIN_DELTA:
            print(f"\n ✅ Saved BEST → {ckpt} (pMAE={stopper.best:.3f})\n")

        if should_stop:
            print(f"Early stopping at epoch {epoch} (best val pMAE={stopper.best:.3f})")
            break

    # load best weights before returning
    state = torch.load(ckpt, map_location=DEVICE)
    model.load_state_dict(state["model"])

    return model, train_dl, val_dl, image_col, vf_cols, ckpt


# run training and expose globals
model_odoc, train_dl_odoc, val_dl_odoc, image_col_odoc, vf_cols_odoc, CKPT_ODOC = (
    train_resnet50_roi_odoc(EPOCHS=80, PATIENCE=10, MIN_DELTA=0.01)
)

[OK] image column: Corresponding CFP
[OK] 59 VF cols: ['AGE', 'CCT', 'IOP_y', 'Interval Years', 'MD'] ... ['VF50', 'VF51', 'VF52', 'VF53', 'VF54']
Epoch 01 | train_loss=5405.0561  train_pMAE=28.408  train_MS=28.255 || val_loss=5188.2580  val_pMAE=26.195  val_MS=25.855

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_ROI_ODOC.pth (pMAE=26.195)

Epoch 02 | train_loss=3792.9892  train_pMAE=19.962  train_MS=13.381 || val_loss=2230.6751  val_pMAE=20.818  val_MS=8.704

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_ROI_ODOC.pth (pMAE=20.818)

Epoch 03 | train_loss=1001.0242  train_pMAE=15.370  train_MS=5.395 || val_loss=148.0195  val_pMAE=7.393  val_MS=4.556

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_ROI_ODOC.pth (pMAE=7.393)

Epoch 04 | train_loss=298.7601  train_pMAE=13.078  train_MS=3.963 || val_loss=154.1055  val_pMAE=6.761  val_MS=4.421

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_ROI_ODOC.pth (pMAE=6.761)

Epoch 05 | train_loss=265.8528  train_pMAE=12.276  train_MS=3.731 || 

In [25]:
# ===================== EVALUATE BEST + PAPER METRICS =====================
# reload best
best_odoc = MobileNetV3L_5ch_VF(out_dim=NUM_POINTS, pretrained=False).to(DEVICE)
state = torch.load(CKPT_ODOC, map_location=DEVICE)
best_odoc.load_state_dict(state["model"])
best_odoc.eval()

# predictions
all_true, all_pred = [], []
with torch.no_grad():
    for x, y in val_dl_odoc:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        p = best_odoc(x)
        all_true.append(y.cpu())
        all_pred.append(p.cpu())

y_true = torch.cat(all_true, dim=0)
y_pred = torch.cat(all_pred, dim=0)
print("✅ Collected predictions:", y_true.shape, y_pred.shape)


# metrics (safe names to avoid clobbering mae/ms_mae)
def rmse(a, b):
    return float(torch.sqrt(torch.mean((a - b) ** 2)))


def mae_value(a, b):
    return float(torch.mean(torch.abs(a - b)))


def r2(a, b):
    ss_res = torch.sum((a - b) ** 2)
    ss_tot = torch.sum((a - torch.mean(a)) ** 2) + 1e-12
    return float(1 - ss_res / ss_tot)


pw_true, pw_pred = y_true.reshape(-1), y_pred.reshape(-1)
print("\n== ROI+OD/OC: POINTWISE ==")
print(f"RMSE: {rmse(pw_true, pw_pred):.4f}")
print(f"MAE : {mae_value(pw_true, pw_pred):.4f}")
print(f"R²  : {r2(pw_true, pw_pred):.4f}")

t_mean, p_mean = y_true.mean(dim=1), y_pred.mean(dim=1)
print("\n== ROI+OD/OC: POINTWISE-MEAN / MS ==")
print(f"RMSE: {rmse(t_mean, p_mean):.4f}")
print(f"MAE : {mae_value(t_mean, p_mean):.4f}")
print(f"R²  : {r2(t_mean, p_mean):.4f}\n")

✅ Collected predictions: torch.Size([127, 59]) torch.Size([127, 59])

== ROI+OD/OC: POINTWISE ==
RMSE: 12.1275
MAE : 6.2384
R²  : 0.9692

== ROI+OD/OC: POINTWISE-MEAN / MS ==
RMSE: 5.5862
MAE : 3.9871
R²  : -0.1499



# 4. ROI + Clinical Features MobileNetV3L

In [26]:
import json
import os
import random
import re
from typing import Dict, List, Tuple

import numpy as np
from PIL import Image, ImageDraw
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

os.makedirs(CHECK_DIR, exist_ok=True)

# ---- training ----
SEED = 42
IMG_SIZE = 224
BATCH_SIZE = 16
EPOCHS = 80
LR = 1e-4
WD = 1e-4
NUM_WORKERS = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_POINTS = 59  # VF1..VF59

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# clinical columns present in your CSV (plus computed ones)
# from your columns: AGE, GENDER, IOP_y, MD exist; we’ll add computed CDR & PSD
CLIN_NUM_COLS = ["AGE", "IOP_y", "CDR"]  # numeric
CLIN_CAT_COLS = ["GENDER"]  # categorical (mapped to 0/1)
IMAGE_COLS_CANDIDATES = [
    "Corresponding CFP",
    "image",
    "image_name",
    "img",
    "image_path",
    "filename",
    "file",
]

In [27]:
def read_csv(fp: str) -> List[Dict[str, str]]:
    with open(fp, "r", encoding="utf-8") as f:
        lines = [l.rstrip("\n") for l in f if l.strip()]
    header = [h.strip() for h in lines[0].split(",")]
    rows = []
    for line in lines[1:]:
        parts = [p.strip() for p in line.split(",")]
        parts = (parts + [""] * len(header))[: len(header)]
        rows.append({h: parts[i] for i, h in enumerate(header)})
    return rows


def detect_columns(rows: List[Dict[str, str]]) -> Tuple[str, List[str]]:
    if not rows:
        raise ValueError("CSV has no rows.")
    cols = list(rows[0].keys())

    # image column: prefer "Corresponding CFP" if present
    image_col = None
    for c in IMAGE_COLS_CANDIDATES:
        if c in cols:
            image_col = c
            break
    if image_col is None:
        for c in cols:
            if any(
                rows[i][c].lower().endswith((".jpg", ".jpeg", ".png"))
                for i in range(min(10, len(rows)))
            ):
                image_col = c
                break
    if image_col is None:
        raise ValueError("Could not find image filename column.")

    # prefer explicit VF1..VF59 (ignore VF0, VF60)
    vf_cols_pref = [f"VF{i}" for i in range(1, 60)]
    if all(c in cols for c in vf_cols_pref):
        return image_col, vf_cols_pref

    # fallback: numeric detection (exclude clinical & image)
    excluded = set([image_col] + CLIN_NUM_COLS + CLIN_CAT_COLS + ["VF0", "VF60"])
    candidates = []
    for c in cols:
        if c in excluded:
            continue
        ok = True
        for r in rows[: min(20, len(rows))]:
            v = r[c].strip()
            if v == "":
                ok = False
                break
            try:
                float(v)
            except:
                ok = False
                break
        if ok:
            candidates.append(c)

    if len(candidates) < NUM_POINTS:
        raise ValueError(
            f"Not enough numeric VF columns; found {len(candidates)}, need {NUM_POINTS}."
        )

    def keyfun(name):
        m = re.search(r"(\d+)$", name)
        return (name, int(m.group(1)) if m else 9999)

    candidates_sorted = sorted(candidates, key=keyfun)[:NUM_POINTS]
    return image_col, candidates_sorted

In [28]:
# ---- CDR from LabelMe polygons (vertical cup/disc ratio) ----
OD_LABELS = {"od", "disc", "optic_disc", "optic-disc", "optic disc"}
OC_LABELS = {"oc", "cup", "optic_cup", "optic-cup", "optic cup"}


def _read_labelme_polys(json_path):
    with open(json_path, "r") as f:
        data = json.load(f)
    od_polys, oc_polys = [], []
    for sh in data.get("shapes", []):
        lab = str(sh.get("label", "")).strip().lower()
        pts = sh.get("points", [])
        if len(pts) < 3:
            continue
        if lab in OD_LABELS:
            od_polys.append(pts)
        if lab in OC_LABELS:
            oc_polys.append(pts)

    # keep polygon with max vertical height if multiple
    def vheight(poly):
        ys = [p[1] for p in poly]
        return (max(ys) - min(ys)) if ys else 0.0

    if len(od_polys) > 1:
        od_polys = [max(od_polys, key=vheight)]
    if len(oc_polys) > 1:
        oc_polys = [max(oc_polys, key=vheight)]
    return od_polys, oc_polys


def _guess_json(img_name: str):
    base = os.path.splitext(os.path.basename(img_name))[0]
    for ext in (".json", ".JSON", ".Json"):
        p = os.path.join(JSON_DIR, base + ext)
        if os.path.exists(p):
            return p
    return ""


def compute_cdr_from_json(img_name: str):
    """
    CDR = vertical height of OC / vertical height of OD.
    Returns None if JSON missing or polygons absent.
    """
    jpath = _guess_json(img_name)
    if not jpath:
        return None
    try:
        od_polys, oc_polys = _read_labelme_polys(jpath)
        if not od_polys or not oc_polys:
            return None

        def vheight(poly):
            ys = [float(y) for _, y in poly]
            return max(ys) - min(ys) if ys else 0.0

        h_od = vheight(od_polys[0])
        h_oc = vheight(oc_polys[0])
        if h_od <= 0:
            return None
        return float(h_oc / h_od)
    except Exception as e:
        print(f"[WARN] CDR parse failed for {img_name}: {e}")
        return None

In [29]:
def augment_rows_with_cdr(rows, image_col, vf_cols):
    augmented = []
    miss_cdr = miss_psd = 0
    for r in rows:
        r2 = dict(r)
        # compute CDR from JSON polygons
        cdr = compute_cdr_from_json(r2[image_col])
        if cdr is None:
            miss_cdr += 1
        r2["CDR"] = cdr

        augmented.append(r2)
    print(f"✅ Augmented rows: CDR missing={miss_cdr}, PSD missing={miss_psd}")
    return augmented

In [30]:
# ----------------- AUGMENT ROWS WITH CDR  -----------------
def augment_rows_with_cdr_psd(rows, image_col, vf_cols):
    augmented = []
    miss_cdr = miss_psd = 0
    for r in rows:
        r2 = dict(r)
        cdr = compute_cdr_from_json(r2[image_col])
        if cdr is None:
            miss_cdr += 1
        r2["CDR"] = cdr

        augmented.append(r2)
    return augmented

In [31]:
def to_float(x):
    x = str(x).strip()
    if x == "":
        return None
    try:
        return float(x)
    except:
        return None


def fit_clinical_stats(rows, clin_num_cols):
    stats = {}
    for c in clin_num_cols:
        vals = [to_float(r.get(c, "")) for r in rows]
        vals = [v for v in vals if v is not None]
        mean = np.mean(vals) if vals else 0.0
        std = np.std(vals) if vals else 1.0
        if std == 0:
            std = 1.0
        stats[c] = (float(mean), float(std))
    return stats


def encode_gender(x):
    s = str(x).strip().lower()
    if s in ("m", "male", "man"):
        return 1.0
    if s in ("f", "female", "woman"):
        return 0.0
    return 0.5  # unknown/other


def build_clinical_vector(r, stats):
    vec = []
    for c in CLIN_NUM_COLS:
        v = to_float(r.get(c, ""))
        mean, std = stats[c]
        v = mean if v is None else v
        v = (v - mean) / std
        vec.append(v)
    for c in CLIN_CAT_COLS:
        if c == "GENDER":
            vec.append(encode_gender(r.get(c, "")))
        else:
            vec.append(0.0)
    return torch.tensor(vec, dtype=torch.float32)

In [32]:
class ROIClinicalDataset(Dataset):
    def __init__(
        self, rows, image_col, vf_cols, clin_stats, train=True, img_root=ROI_DIR, img_size=IMG_SIZE
    ):
        self.rows = rows
        self.image_col = image_col
        self.vf_cols = vf_cols
        self.clin_stats = clin_stats
        self.train = train
        self.img_root = img_root

        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        aug = []
        if train:
            aug = [
                transforms.RandomHorizontalFlip(0.5),
                transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
            ]
        self.tf = transforms.Compose(
            [transforms.Resize((IMG_SIZE, IMG_SIZE)), transforms.ToTensor(), *aug, normalize]
        )

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        r = self.rows[idx]
        name = r[self.image_col]
        path = name if os.path.isabs(name) else os.path.join(self.img_root, name)
        img = Image.open(path).convert("RGB")
        x_img = self.tf(img)

        x_clin = build_clinical_vector(r, self.clin_stats)  # (clin_dim,)
        y = torch.tensor([float(r[c]) for c in self.vf_cols], dtype=torch.float32)  # (59,)

        return x_img, x_clin, y

In [33]:
# ----------------- DATASET: 5-CH ROI + CLINICAL -----------------
class ROI_ODOC_Clinical_Dataset(Dataset):
    def __init__(
        self, rows, image_col, vf_cols, clin_stats, train=True, img_root=ROI_DIR, img_size=IMG_SIZE
    ):
        self.rows = rows
        self.image_col = image_col
        self.vf_cols = vf_cols
        self.clin_stats = clin_stats
        self.train = train
        self.img_root = img_root
        self.img_size = img_size

        norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        aug = []
        if train:
            aug = [
                transforms.RandomHorizontalFlip(0.5),
                transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
            ]
        self.rgb_tf = transforms.Compose(
            [transforms.Resize((img_size, img_size)), transforms.ToTensor(), *aug, norm]
        )
        self.mask_tf = transforms.ToTensor()  # L→(1,H,W) float {0,1}

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        r = self.rows[idx]
        fn = r[self.image_col]
        path = fn if os.path.isabs(fn) else os.path.join(self.img_root, fn)

        img = Image.open(path).convert("RGB")
        od_img, oc_img = build_masks_from_labelme(img, fn, self.img_size)

        x_rgb = self.rgb_tf(img)  # (3,H,W)
        x_od = self.mask_tf(od_img)  # (1,H,W)
        x_oc = self.mask_tf(oc_img)  # (1,H,W)
        x5 = torch.cat([x_rgb, x_od, x_oc], dim=0)  # (5,H,W)

        x_clin = build_clinical_vector(r, self.clin_stats)  # (clin_dim,)
        y = torch.tensor([float(r[c]) for c in self.vf_cols], dtype=torch.float32)  # (59,)

        return x5, x_clin, y

In [34]:
class MobileNetV3L_ROI_Clinical(nn.Module):
    def __init__(self, clin_dim, out_dim=59, pretrained=True):
        super().__init__()
        self.backbone = models.mobilenet_v3_large(
            weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V2 if pretrained else None
        )
        in_f = self.backbone.classifier[-1].in_features
        self.backbone.classifier[-1] = nn.Identity()

        self.img_head = nn.Sequential(
            nn.Linear(in_f, 512), nn.ReLU(inplace=True), nn.Dropout(0.25)
        )

        self.clin_head = nn.Sequential(
            nn.Linear(clin_dim, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.10),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True),
        )

        self.fuse = nn.Sequential(
            nn.Linear(512 + 64, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.25),
            nn.Linear(256, out_dim),
        )

    def forward(self, x_img, x_clin):
        f = self.backbone(x_img)  # (B, 2048)
        f = self.img_head(f)  # (B, 512)
        g = self.clin_head(x_clin)  # (B, 64)
        z = torch.cat([f, g], dim=1)  # (B, 576)
        out = self.fuse(z)  # (B, 59)
        return out

In [35]:
# ----------------- MODEL: 5-CH RESNET50 + CLINICAL MLP (FUSION) -----------------
class MobileNetV3L_5ch_Clinical(nn.Module):
    def __init__(self, clin_dim, out_dim=59, pretrained=True):
        super().__init__()
        base = models.mobilenet_v3_large(
            weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V2 if pretrained else None
        )
        # adapt Conv2dNormActivation: 3→5 channels (init extra channels with mean RGB weights)
        old = base.features[0][0]
        new = nn.Conv2d(
            5,
            old.out_channels,
            kernel_size=old.kernel_size,
            stride=old.stride,
            padding=old.padding,
            bias=(old.bias is not None),
        )
        with torch.no_grad():
            new.weight[:, :3, :, :] = old.weight
            mean_w = old.weight.mean(dim=1, keepdim=True)
            new.weight[:, 3:5, :, :] = mean_w.repeat(1, 2, 1, 1)
            if old.bias is not None:
                new.bias.copy_(old.bias)
        base.features[0][0] = new

        in_f = base.classifier[-1].in_features
        base.classifier[-1] = nn.Identity()
        self.backbone = base

        self.img_head = nn.Sequential(
            nn.Linear(in_f, 512), nn.ReLU(inplace=True), nn.Dropout(0.25)
        )
        self.clin_head = nn.Sequential(
            nn.Linear(clin_dim, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.10),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True),
        )
        self.fuse = nn.Sequential(
            nn.Linear(512 + 64, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.25),
            nn.Linear(256, out_dim),
        )

    def forward(self, x5, xclin):
        f = self.backbone(x5)  # (B, 2048)
        f = self.img_head(f)  # (B, 512)
        g = self.clin_head(xclin)  # (B, 64)
        z = torch.cat([f, g], dim=1)
        out = self.fuse(z)  # (B, 59)
        return out

In [36]:
@torch.no_grad()
def mae(pred, true):
    return torch.mean(torch.abs(pred - true)).item()


@torch.no_grad()
def ms_mae(pred, true):
    pm = pred.mean(dim=1)
    tm = true.mean(dim=1)
    return torch.mean(torch.abs(pm - tm)).item()


def run_epoch(model, loader, opt=None):
    train = opt is not None
    model.train() if train else model.eval()
    crit = nn.MSELoss()

    n = 0
    loss_sum = 0.0
    pmae_sum = 0.0
    msmae_sum = 0.0
    for x_img, x_clin, y in loader:
        x_img = x_img.to(DEVICE)
        x_clin = x_clin.to(DEVICE)
        y = y.to(DEVICE)

        if train:
            opt.zero_grad()
        pred = model(x_img, x_clin)
        loss = crit(pred, y)
        if train:
            loss.backward()
            opt.step()

        bs = x_img.size(0)
        n += bs
        loss_sum += loss.item() * bs
        pmae_sum += mae(pred, y) * bs
        msmae_sum += ms_mae(pred, y) * bs

    return {"loss": loss_sum / n, "pointwise_mae": pmae_sum / n, "ms_mae": msmae_sum / n}

In [37]:
def train_resnet50_roi_odoc_with_cdr(EPOCHS=80, PATIENCE=10, MIN_DELTA=0.01):
    import os
    import random

    import torch
    from torch.utils.data import DataLoader

    # --- Early stopper that saves best on val pMAE ---
    class EarlyStopper:
        def __init__(self, patience=10, min_delta=0.01, ckpt_path=None):
            self.patience = int(patience)
            self.min_delta = float(min_delta)
            self.ckpt_path = ckpt_path
            self.best = float("inf")
            self.bad_epochs = 0

        def step(self, val_pmae, model, epoch_meta=None):
            if val_pmae < self.best - self.min_delta:
                self.best = val_pmae
                self.bad_epochs = 0
                if self.ckpt_path:
                    torch.save(
                        {
                            "model": model.state_dict(),
                            "val_pointwise_mae": self.best,
                            **(epoch_meta or {}),
                        },
                        self.ckpt_path,
                    )
                return False
            else:
                self.bad_epochs += 1
                return self.bad_epochs > self.patience

    # --- read & detect columns ---
    rows = read_csv(CSV_PATH)
    image_col, vf_cols = detect_columns(rows)
    print(f"[OK] image column: {image_col}")
    print(f"[OK] {len(vf_cols)} VF cols: {vf_cols[:5]} ... {vf_cols[-5:]}")

    # --- ADD CDR BEFORE SPLIT ---
    rows = augment_rows_with_cdr_psd(rows, image_col, vf_cols)  # <-- adds CDR fields per row

    random.shuffle(rows)
    N = len(rows)
    n_train = int(0.8 * N)
    train_rows = rows[:n_train]
    val_rows = rows[n_train:]

    # --- fit clinical stats on TRAIN only (handles imputation/encoding) ---
    clin_stats = fit_clinical_stats(train_rows, CLIN_NUM_COLS)
    clin_dim = len(CLIN_NUM_COLS) + len(CLIN_CAT_COLS)

    # --- datasets / loaders: use the dataset that returns (x5, x_clin, y) ---
    train_ds = ROI_ODOC_Clinical_Dataset(train_rows, image_col, vf_cols, clin_stats, train=True)
    val_ds = ROI_ODOC_Clinical_Dataset(val_rows, image_col, vf_cols, clin_stats, train=False)

    train_dl = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True
    )
    val_dl = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True
    )

    # --- model that accepts 5ch image + clinical vector ---
    model = MobileNetV3L_5ch_Clinical(clin_dim=clin_dim, out_dim=NUM_POINTS, pretrained=True).to(
        DEVICE
    )
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", patience=3, factor=0.5)

    ckpt = os.path.join(CHECK_DIR, "best_mobilenetv3l_ROI_ODOC_CDR.pth")
    stopper = EarlyStopper(patience=PATIENCE, min_delta=MIN_DELTA, ckpt_path=ckpt)

    for epoch in range(1, EPOCHS + 1):
        tr = run_epoch(model, train_dl, opt)  # should read (x5, x_clin, y) inside
        va = run_epoch(model, val_dl, None)

        print(
            f"\nEpoch {epoch:02d} | "
            f"train_loss={tr['loss']:.4f}  train_pMAE={tr['pointwise_mae']:.3f}  train_MS={tr['ms_mae']:.3f} || "
            f"val_loss={va['loss']:.4f}  val_pMAE={va['pointwise_mae']:.3f}  val_MS={va['ms_mae']:.3f}"
        )

        # step LR on validation pMAE
        sched.step(va["pointwise_mae"])

        # save best & decide stopping
        improved = va["pointwise_mae"] < stopper.best - MIN_DELTA
        should_stop = stopper.step(va["pointwise_mae"], model, epoch_meta={"epoch": epoch})
        if improved:
            print(f"\n ✅ Saved BEST → {ckpt} (pMAE={stopper.best:.3f})")

        if should_stop:
            print(f"\nEarly stopping at epoch {epoch} (best val pMAE={stopper.best:.3f})")
            break

    # load best before returning
    state = torch.load(ckpt, map_location=DEVICE)
    model.load_state_dict(state["model"])
    return model, train_dl, val_dl, image_col, vf_cols, ckpt, clin_dim


# run training and expose globals
model_odoc, train_dl_odoc, val_dl_odoc, image_col_odoc, vf_cols_odoc, CKPT_ODOC, CLIN_DIM_OD = (
    train_resnet50_roi_odoc_with_cdr(EPOCHS=80, PATIENCE=10, MIN_DELTA=0.01)
)

[OK] image column: Corresponding CFP
[OK] 59 VF cols: ['VF1', 'VF2', 'VF3', 'VF4', 'VF5'] ... ['VF55', 'VF56', 'VF57', 'VF58', 'VF59']

Epoch 01 | train_loss=482.6727  train_pMAE=20.338  train_MS=20.182 || val_loss=450.2610  val_pMAE=19.408  val_MS=19.210

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_ROI_ODOC_CDR.pth (pMAE=19.408)

Epoch 02 | train_loss=179.4450  train_pMAE=11.041  train_MS=8.822 || val_loss=135.4271  val_pMAE=10.412  val_MS=9.022

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_ROI_ODOC_CDR.pth (pMAE=10.412)

Epoch 03 | train_loss=60.1799  train_pMAE=6.198  train_MS=3.536 || val_loss=49.0235  val_pMAE=5.135  val_MS=3.822

 ✅ Saved BEST → ./checkpoints/best_mobilenetv3l_ROI_ODOC_CDR.pth (pMAE=5.135)

Epoch 04 | train_loss=50.5560  train_pMAE=5.685  train_MS=2.990 || val_loss=48.7353  val_pMAE=5.167  val_MS=3.908

Epoch 05 | train_loss=47.1727  train_pMAE=5.445  train_MS=2.726 || val_loss=48.8136  val_pMAE=5.586  val_MS=3.948

Epoch 06 | train_loss=46.1599  train_pMA

In [38]:
# ---------------------------------------------------------
# ✅ Corrected Cell 10: Reload best model & evaluate
# ---------------------------------------------------------

# Use the model already trained in Cell 9
# model_odoc → best model returned by train_resnet50_roi_odoc_with_cdr
# val_dl_odoc → validation dataloader
# CKPT_ODOC → checkpoint path
# CLIN_DIM_OD → clinical feature dimension
# image_col_odoc, vf_cols_odoc already created

best_roi_clin = model_odoc  # model already returned from training
best_roi_clin.eval()

# Load the best checkpoint
state = torch.load(CKPT_ODOC, map_location=DEVICE)
best_roi_clin.load_state_dict(state["model"])
best_roi_clin.eval()

# Collect predictions
all_true, all_pred = [], []
with torch.no_grad():
    for x_img, x_clin, y in val_dl_odoc:
        x_img = x_img.to(DEVICE)
        x_clin = x_clin.to(DEVICE)
        y = y.to(DEVICE)

        p = best_roi_clin(x_img, x_clin)
        all_true.append(y.cpu())
        all_pred.append(p.cpu())

y_true = torch.cat(all_true, dim=0)
y_pred = torch.cat(all_pred, dim=0)

print("✅ Collected predictions:", y_true.shape, y_pred.shape)


# --- Metrics ---
def rmse(a, b):
    return float(torch.sqrt(torch.mean((a - b) ** 2)))


def mae_value(a, b):
    return float(torch.mean(torch.abs(a - b)))


def r2(a, b):
    ss_res = torch.sum((a - b) ** 2)
    ss_tot = torch.sum((a - torch.mean(a)) ** 2) + 1e-12
    return float(1 - ss_res / ss_tot)


# Pointwise metrics (VF1..VF59 flattened)
pw_true = y_true.reshape(-1)
pw_pred = y_pred.reshape(-1)

print("\n== ROI+Clinical (with CDR): POINTWISE ==")
print(
    f"RMSE: {rmse(pw_true, pw_pred):.4f} | "
    f"MAE: {mae_value(pw_true, pw_pred):.4f} | "
    f"R²: {r2(pw_true, pw_pred):.4f}"
)

# Mean Sensitivity metrics
t_mean = y_true.mean(dim=1)
p_mean = y_pred.mean(dim=1)

print("\n== ROI+Clinical (with CDR): MEAN SENSITIVITY ==")
print(
    f"RMSE: {rmse(t_mean, p_mean):.4f} | "
    f"MAE: {mae_value(t_mean, p_mean):.4f} | "
    f"R²: {r2(t_mean, p_mean):.4f}"
)

✅ Collected predictions: torch.Size([127, 59]) torch.Size([127, 59])

== ROI+Clinical (with CDR): POINTWISE ==
RMSE: 6.4895 | MAE: 4.6358 | R²: 0.4951

== ROI+Clinical (with CDR): MEAN SENSITIVITY ==
RMSE: 4.2246 | MAE: 3.0022 | R²: 0.5021


# 5. ROI+ OD/OC + Clinical MobileNetV3L

In [39]:
# ==========================
# FULL FUSION: ROI + OD/OC + Clinical   →  VF (59)
# ==========================

import json
import os
import random
import re
from typing import Dict, List, Tuple

import numpy as np
from PIL import Image, ImageDraw
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

# ----------------- PATHS & CONFIG -----------------

os.makedirs(CHECK_DIR, exist_ok=True)

SEED = 42
IMG_SIZE = 224
BATCH_SIZE = 16
EPOCHS = 80
LR = 1e-4
WD = 1e-4
NUM_WORKERS = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_POINTS = 59  # VF1..VF59

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [40]:
# Clinical features for fusion (present or computed)
CLIN_NUM_COLS = ["AGE", "IOP_y", "CDR"]
CLIN_CAT_COLS = ["GENDER"]
IMAGE_COLS_CANDIDATES = [
    "Corresponding CFP",
    "image",
    "image_name",
    "img",
    "image_path",
    "filename",
    "file",
]


# ----------------- CSV UTILS -----------------
def read_csv(fp: str) -> List[Dict[str, str]]:
    with open(fp, "r", encoding="utf-8") as f:
        lines = [l.rstrip("\n") for l in f if l.strip()]
    header = [h.strip() for h in lines[0].split(",")]
    rows = []
    for line in lines[1:]:
        parts = [p.strip() for p in line.split(",")]
        parts = (parts + [""] * len(header))[: len(header)]
        rows.append({h: parts[i] for i, h in enumerate(header)})
    return rows


def detect_columns(rows: List[Dict[str, str]]) -> Tuple[str, List[str]]:
    if not rows:
        raise ValueError("CSV has no rows.")
    cols = list(rows[0].keys())

    # image column (prefer Corresponding CFP)
    image_col = None
    for c in IMAGE_COLS_CANDIDATES:
        if c in cols:
            image_col = c
            break
    if image_col is None:
        for c in cols:
            if any(
                rows[i][c].lower().endswith((".jpg", ".jpeg", ".png"))
                for i in range(min(10, len(rows)))
            ):
                image_col = c
                break
    if image_col is None:
        raise ValueError("Could not find image filename column.")

    # prefer explicit VF1..VF59
    vf_pref = [f"VF{i}" for i in range(1, 60)]
    if all(c in cols for c in vf_pref):
        return image_col, vf_pref

    # fallback: detect numeric columns (exclude clinical & image & VF0/VF60)
    excluded = set([image_col] + CLIN_NUM_COLS + CLIN_CAT_COLS + ["VF0", "VF60"])
    cand = []
    for c in cols:
        if c in excluded:
            continue
        ok = True
        for r in rows[: min(20, len(rows))]:
            v = r[c].strip()
            if v == "":
                ok = False
                break
            try:
                float(v)
            except:
                ok = False
                break
        if ok:
            cand.append(c)
    if len(cand) < NUM_POINTS:
        raise ValueError(f"Not enough numeric VF columns; found {len(cand)}, need {NUM_POINTS}.")

    def keyfun(name):
        m = re.search(r"(\d+)$", name)
        return (name, int(m.group(1)) if m else 9999)

    cand = sorted(cand, key=keyfun)[:NUM_POINTS]
    return image_col, cand

In [41]:
# ----------------- OD/OC POLYGONS → CDR & MASKS -----------------
OD_LABELS = {"od", "disc", "optic_disc", "optic-disc", "optic disc"}
OC_LABELS = {"oc", "cup", "optic_cup", "optic-cup", "optic cup"}


def _read_labelme_polys(json_path):
    with open(json_path, "r") as f:
        data = json.load(f)
    od_polys, oc_polys = [], []
    for sh in data.get("shapes", []):
        lab = str(sh.get("label", "")).strip().lower()
        pts = sh.get("points", [])
        if len(pts) < 3:
            continue
        if lab in OD_LABELS:
            od_polys.append(pts)
        if lab in OC_LABELS:
            oc_polys.append(pts)

    # keep polygon with max vertical height
    def vheight(poly):
        ys = [p[1] for p in poly]
        return (max(ys) - min(ys)) if ys else 0.0

    if len(od_polys) > 1:
        od_polys = [max(od_polys, key=vheight)]
    if len(oc_polys) > 1:
        oc_polys = [max(oc_polys, key=vheight)]
    return od_polys, oc_polys


def _guess_json(img_name: str):
    base = os.path.splitext(os.path.basename(img_name))[0]
    for ext in (".json", ".JSON", ".Json"):
        p = os.path.join(JSON_DIR, base + ext)
        if os.path.exists(p):
            return p
    return ""


def compute_cdr_from_json(img_name: str):
    """CDR = vertical cup height / vertical disc height (from polygons)."""
    jpath = _guess_json(img_name)
    if not jpath:
        return None
    try:
        od_polys, oc_polys = _read_labelme_polys(jpath)
        if not od_polys or not oc_polys:
            return None

        def vheight(poly):
            ys = [float(y) for _, y in poly]
            return max(ys) - min(ys) if ys else 0.0

        h_od = vheight(od_polys[0])
        h_oc = vheight(oc_polys[0])
        if h_od <= 0:
            return None
        return float(h_oc / h_od)
    except Exception as e:
        print(f"[WARN] CDR parse failed for {img_name}: {e}")
        return None


def build_masks_from_labelme(img_pil: Image.Image, img_name: str, out_size: int):
    """Binary OD/OC masks (L mode), resized to out_size."""
    W, H = img_pil.size
    od_mask = Image.new("L", (W, H), 0)
    oc_mask = Image.new("L", (W, H), 0)
    jpath = _guess_json(img_name)
    if jpath:
        try:
            od_polys, oc_polys = _read_labelme_polys(jpath)
            d_od = ImageDraw.Draw(od_mask)
            d_oc = ImageDraw.Draw(oc_mask)
            for poly in od_polys:
                d_od.polygon(poly, outline=1, fill=1)
            for poly in oc_polys:
                d_oc.polygon(poly, outline=1, fill=1)
        except Exception as e:
            print(f"[WARN] mask parse {jpath}: {e}")
    od_mask = od_mask.resize((out_size, out_size), resample=Image.NEAREST)
    oc_mask = oc_mask.resize((out_size, out_size), resample=Image.NEAREST)
    return od_mask, oc_mask

In [42]:
# ----------------- AUGMENT ROWS WITH CDR  -----------------
def augment_rows_with_cdr_psd(rows, image_col, vf_cols):
    augmented = []
    miss_cdr = miss_psd = 0
    for r in rows:
        r2 = dict(r)
        cdr = compute_cdr_from_json(r2[image_col])
        if cdr is None:
            miss_cdr += 1
        r2["CDR"] = cdr

        augmented.append(r2)
    return augmented

In [43]:
# ----------------- CLINICAL PREPROCESS -----------------
def to_float(x):
    x = str(x).strip()
    if x == "":
        return None
    try:
        return float(x)
    except:
        return None


def fit_clinical_stats(rows, clin_num_cols):
    stats = {}
    for c in clin_num_cols:
        vals = [to_float(r.get(c, "")) for r in rows]
        vals = [v for v in vals if v is not None]
        mean = np.mean(vals) if vals else 0.0
        std = np.std(vals) if vals else 1.0
        if std == 0:
            std = 1.0
        stats[c] = (float(mean), float(std))
    return stats


def encode_gender(x):
    s = str(x).strip().lower()
    if s in ("m", "male", "man"):
        return 1.0
    if s in ("f", "female", "woman"):
        return 0.0
    return 0.5  # unknown/other


def build_clinical_vector(r, stats):
    vec = []
    for c in CLIN_NUM_COLS:
        v = to_float(r.get(c, ""))
        mean, std = stats[c]
        v = mean if v is None else v
        v = (v - mean) / std
        vec.append(v)
    for c in CLIN_CAT_COLS:
        if c == "GENDER":
            vec.append(encode_gender(r.get(c, "")))
        else:
            vec.append(0.0)
    return torch.tensor(vec, dtype=torch.float32)

In [44]:
# ----------------- DATASET: 5-CH ROI + CLINICAL -----------------
class ROI_ODOC_Clinical_Dataset(Dataset):
    def __init__(
        self, rows, image_col, vf_cols, clin_stats, train=True, img_root=ROI_DIR, img_size=IMG_SIZE
    ):
        self.rows = rows
        self.image_col = image_col
        self.vf_cols = vf_cols
        self.clin_stats = clin_stats
        self.train = train
        self.img_root = img_root
        self.img_size = img_size

        norm = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        aug = []
        if train:
            aug = [
                transforms.RandomHorizontalFlip(0.5),
                transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
            ]
        self.rgb_tf = transforms.Compose(
            [transforms.Resize((img_size, img_size)), transforms.ToTensor(), *aug, norm]
        )
        self.mask_tf = transforms.ToTensor()  # L→(1,H,W) float {0,1}

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        r = self.rows[idx]
        fn = r[self.image_col]
        path = fn if os.path.isabs(fn) else os.path.join(self.img_root, fn)

        img = Image.open(path).convert("RGB")
        od_img, oc_img = build_masks_from_labelme(img, fn, self.img_size)

        x_rgb = self.rgb_tf(img)  # (3,H,W)
        x_od = self.mask_tf(od_img)  # (1,H,W)
        x_oc = self.mask_tf(oc_img)  # (1,H,W)
        x5 = torch.cat([x_rgb, x_od, x_oc], dim=0)  # (5,H,W)

        x_clin = build_clinical_vector(r, self.clin_stats)  # (clin_dim,)
        y = torch.tensor([float(r[c]) for c in self.vf_cols], dtype=torch.float32)  # (59,)

        return x5, x_clin, y

In [45]:
# ----------------- MODEL: 5-CH RESNET50 + CLINICAL MLP (FUSION) -----------------
class MobileNetV3L_5ch_Clinical(nn.Module):
    def __init__(self, clin_dim, out_dim=59, pretrained=True):
        super().__init__()
        base = models.mobilenet_v3_large(
            weights=models.MobileNet_V3_Large_Weights.IMAGENET1K_V2 if pretrained else None
        )
        # adapt Conv2dNormActivation: 3→5 channels (init extra channels with mean RGB weights)
        old = base.features[0][0]
        new = nn.Conv2d(
            5,
            old.out_channels,
            kernel_size=old.kernel_size,
            stride=old.stride,
            padding=old.padding,
            bias=(old.bias is not None),
        )
        with torch.no_grad():
            new.weight[:, :3, :, :] = old.weight
            mean_w = old.weight.mean(dim=1, keepdim=True)
            new.weight[:, 3:5, :, :] = mean_w.repeat(1, 2, 1, 1)
            if old.bias is not None:
                new.bias.copy_(old.bias)
        base.features[0][0] = new

        in_f = base.classifier[-1].in_features
        base.classifier[-1] = nn.Identity()
        self.backbone = base

        self.img_head = nn.Sequential(
            nn.Linear(in_f, 512), nn.ReLU(inplace=True), nn.Dropout(0.25)
        )
        self.clin_head = nn.Sequential(
            nn.Linear(clin_dim, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.10),
            nn.Linear(64, 64),
            nn.ReLU(inplace=True),
        )
        self.fuse = nn.Sequential(
            nn.Linear(512 + 64, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.25),
            nn.Linear(256, out_dim),
        )

    def forward(self, x5, xclin):
        f = self.backbone(x5)  # (B, 2048)
        f = self.img_head(f)  # (B, 512)
        g = self.clin_head(xclin)  # (B, 64)
        z = torch.cat([f, g], dim=1)
        out = self.fuse(z)  # (B, 59)
        return out

In [46]:
# ----------------- METRICS & EPOCH LOOP (same logic as your earlier code) -----------------
@torch.no_grad()
def mae(pred, true):  # pointwise MAE
    return torch.mean(torch.abs(pred - true)).item()


@torch.no_grad()
def ms_mae(pred, true):  # MS per-sample then MAE
    pm = pred.mean(dim=1)
    tm = true.mean(dim=1)
    return torch.mean(torch.abs(pm - tm)).item()


def run_epoch(model, loader, opt=None):
    train = opt is not None
    model.train() if train else model.eval()
    crit = nn.MSELoss()

    n = 0
    loss_sum = 0.0
    pmae_sum = 0.0
    msmae_sum = 0.0
    for x5, xclin, y in loader:
        x5, xclin, y = x5.to(DEVICE), xclin.to(DEVICE), y.to(DEVICE)

        if train:
            opt.zero_grad()
        pred = model(x5, xclin)
        loss = crit(pred, y)
        if train:
            loss.backward()
            opt.step()

        bs = x5.size(0)
        n += bs
        loss_sum += loss.item() * bs
        pmae_sum += mae(pred, y) * bs
        msmae_sum += ms_mae(pred, y) * bs

    return {"loss": loss_sum / n, "pointwise_mae": pmae_sum / n, "ms_mae": msmae_sum / n}

In [47]:
# ----------------- TRAIN -----------------
def train_full_fusion(EPOCHS=80, PATIENCE=10, MIN_DELTA=0.01):
    rows = read_csv(CSV_PATH)
    image_col, vf_cols = detect_columns(rows)
    print(f"[OK] image column: {image_col}")
    print(f"[OK] {len(vf_cols)} VF cols detected: {vf_cols[:5]} ... {vf_cols[-5:]}")

    # augment with CDR BEFORE splitting
    rows = augment_rows_with_cdr_psd(rows, image_col, vf_cols)

    import random

    random.shuffle(rows)
    N = len(rows)
    n_train = int(0.8 * N)
    train_rows = rows[:n_train]
    val_rows = rows[n_train:]

    # fit clinical stats on train (handles None via mean imputation)
    clin_stats = fit_clinical_stats(train_rows, CLIN_NUM_COLS)
    clin_dim = len(CLIN_NUM_COLS) + len(CLIN_CAT_COLS)

    # datasets / loaders
    train_ds = ROI_ODOC_Clinical_Dataset(train_rows, image_col, vf_cols, clin_stats, train=True)
    val_ds = ROI_ODOC_Clinical_Dataset(val_rows, image_col, vf_cols, clin_stats, train=False)

    train_dl = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True
    )
    val_dl = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=True
    )

    # model / optimizer
    model = MobileNetV3L_5ch_Clinical(clin_dim=clin_dim, out_dim=NUM_POINTS, pretrained=True).to(
        DEVICE
    )
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", patience=3, factor=0.5)

    best = float("inf")
    no_improve = 0
    ckpt = os.path.join(CHECK_DIR, "best_full_fusion_ROI_ODOC_CLIN1.pth")

    for epoch in range(1, EPOCHS + 1):
        tr = run_epoch(model, train_dl, opt)
        va = run_epoch(model, val_dl, None)

        sched.step(va["pointwise_mae"])

        print(
            f"Epoch {epoch:02d} | "
            f"train_loss={tr['loss']:.4f} train_pMAE={tr['pointwise_mae']:.3f} train_MS={tr['ms_mae']:.3f} || "
            f"val_loss={va['loss']:.4f} val_pMAE={va['pointwise_mae']:.3f} val_MS={va['ms_mae']:.3f}"
        )

        # early stopping on val pointwise MAE
        if va["pointwise_mae"] < best - MIN_DELTA:
            best = va["pointwise_mae"]
            no_improve = 0
            torch.save(
                {"epoch": epoch, "model": model.state_dict(), "val_pointwise_mae": best}, ckpt
            )
            print(f"  -> saved new best to {ckpt} (val pMAE={best:.3f})")
        else:
            no_improve += 1
            if no_improve > PATIENCE:
                print(f"Early stopping at epoch {epoch} (best val pMAE={best:.3f})")
                break

    # expose objects / paths like before
    return model, train_dl, val_dl, image_col, vf_cols, ckpt, clin_dim


# run training and expose globals
model_full, train_dl_full, val_dl_full, image_col_full, vf_cols_full, CKPT_FULL, CLIN_DIM = (
    train_full_fusion(EPOCHS=80, PATIENCE=10, MIN_DELTA=0.01)
)

[OK] image column: Corresponding CFP
[OK] 59 VF cols detected: ['VF1', 'VF2', 'VF3', 'VF4', 'VF5'] ... ['VF55', 'VF56', 'VF57', 'VF58', 'VF59']
Epoch 01 | train_loss=482.3802 train_pMAE=20.331 train_MS=20.176 || val_loss=442.9284 val_pMAE=19.239 val_MS=19.030
  -> saved new best to ./checkpoints/best_full_fusion_ROI_ODOC_CLIN1.pth (val pMAE=19.239)
Epoch 02 | train_loss=175.4235 train_pMAE=10.899 train_MS=8.607 || val_loss=86.2756 val_pMAE=8.033 val_MS=6.457
  -> saved new best to ./checkpoints/best_full_fusion_ROI_ODOC_CLIN1.pth (val pMAE=8.033)
Epoch 03 | train_loss=60.6013 train_pMAE=6.258 train_MS=3.625 || val_loss=61.5353 val_pMAE=5.596 val_MS=4.567
  -> saved new best to ./checkpoints/best_full_fusion_ROI_ODOC_CLIN1.pth (val pMAE=5.596)
Epoch 04 | train_loss=50.2914 train_pMAE=5.644 train_MS=2.971 || val_loss=49.5158 val_pMAE=5.480 val_MS=3.975
  -> saved new best to ./checkpoints/best_full_fusion_ROI_ODOC_CLIN1.pth (val pMAE=5.480)
Epoch 05 | train_loss=47.3519 train_pMAE=5.442 

In [48]:
# -------EVALUATE BEST + PRINT METRICS -----------------
# reload best
best_full = MobileNetV3L_5ch_Clinical(clin_dim=CLIN_DIM, out_dim=NUM_POINTS, pretrained=False).to(
    DEVICE
)
state = torch.load(CKPT_FULL, map_location=DEVICE)
best_full.load_state_dict(state["model"])
best_full.eval()

# collect preds
all_true, all_pred = [], []
with torch.no_grad():
    for x5, xclin, y in val_dl_full:
        x5, xclin, y = x5.to(DEVICE), xclin.to(DEVICE), y.to(DEVICE)
        p = best_full(x5, xclin)
        all_true.append(y.cpu())
        all_pred.append(p.cpu())

y_true = torch.cat(all_true, dim=0)
y_pred = torch.cat(all_pred, dim=0)
print("✅ Collected predictions:", y_true.shape, y_pred.shape)


# paper-style metrics
def rmse(a, b):
    return float(torch.sqrt(torch.mean((a - b) ** 2)))


def mae_value(a, b):
    return float(torch.mean(torch.abs(a - b)))


def r2(a, b):
    ss_res = torch.sum((a - b) ** 2)
    ss_tot = torch.sum((a - torch.mean(a)) ** 2) + 1e-12
    return float(1 - ss_res / ss_tot)


pw_true, pw_pred = y_true.reshape(-1), y_pred.reshape(-1)
print("\n== FULL FUSION (ROI + OD/OC + Clinical ): POINTWISE ==")
print(
    f"RMSE: {rmse(pw_true, pw_pred):.4f} | MAE: {mae_value(pw_true, pw_pred):.4f} | R²: {r2(pw_true, pw_pred):.4f}"
)

t_mean, p_mean = y_true.mean(dim=1), y_pred.mean(dim=1)
print("== FULL FUSION: POINTWISE-MEAN / MS ==")
print(
    f"RMSE: {rmse(t_mean, p_mean):.4f} | MAE: {mae_value(t_mean, p_mean):.4f} | R²: {r2(t_mean, p_mean):.4f}\n"
)

✅ Collected predictions: torch.Size([127, 59]) torch.Size([127, 59])

== FULL FUSION (ROI + OD/OC + Clinical ): POINTWISE ==
RMSE: 6.1533 | MAE: 4.3689 | R²: 0.5461
== FULL FUSION: POINTWISE-MEAN / MS ==
RMSE: 3.8140 | MAE: 2.5977 | R²: 0.5942



# 6. ROI + OD/OC + Clinical SWIN-T

In [49]:
# ==========================
# SWIN-T for FULL FUSION (ROI + OD/OC + Clinical) → VF (59)
# Reuses train_dl_full, val_dl_full, and CLIN_DIM from your existing setup
# ==========================
import os

import torch
import torch.nn as nn
from torchvision import models


# ---- metrics used inside the loop
@torch.no_grad()
def _mae(pred, true):
    return torch.mean(torch.abs(pred - true)).item()


@torch.no_grad()
def _ms_mae(pred, true):
    pm = pred.mean(dim=1)
    tm = true.mean(dim=1)
    return torch.mean(torch.abs(pm - tm)).item()


def _run_epoch_ff(model, loader, device, opt=None):
    train = opt is not None
    model.train() if train else model.eval()
    crit = nn.MSELoss()

    n = 0
    loss_sum = 0.0
    pmae_sum = 0.0
    msmae_sum = 0.0

    for batch in loader:
        # x5: (B,5,H,W) ; x_clin: (B, CLIN_DIM) ; y: (B,59)
        x5, x_clin, y = batch
        x5 = x5.to(device, non_blocking=True)
        x_clin = x_clin.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        if train:
            opt.zero_grad(set_to_none=True)
        pred = model(x5, x_clin)
        loss = crit(pred, y)
        if train:
            loss.backward()
            opt.step()

        bs = x5.size(0)
        n += bs
        loss_sum += loss.item() * bs
        pmae_sum += _mae(pred, y) * bs
        msmae_sum += _ms_mae(pred, y) * bs

    return {
        "loss": loss_sum / max(1, n),
        "pointwise_mae": pmae_sum / max(1, n),
        "ms_mae": msmae_sum / max(1, n),
    }


# ---- Model: Swin-T (RGB) + small CNN for OD/OC masks + clinical MLP → fused regressor
class SwinT_5ch_Clinical(nn.Module):
    """
    Uses Swin-T on RGB (x5[:, :3]),
    Encodes OD/OC masks (x5[:, 3:5]) with a lightweight CNN,
    Encodes clinical features with an MLP,
    Concats [swin_feat, mask_feat, clin_feat] → MLP → 59-D VF regression.
    """

    def __init__(self, clin_dim, out_dim=59, pretrained=True, dropout=0.25):
        super().__init__()
        # Swin-T backbone (ImageNet weights) → (B, feat_dim)
        self.backbone = models.swin_t(
            weights=models.Swin_T_Weights.IMAGENET1K_V1 if pretrained else None
        )
        feat_dim = self.backbone.head.in_features
        self.backbone.head = nn.Identity()  # keep pooled features

        # OD/OC mask encoder (2xHxW → vector)
        self.mask_enc = nn.Sequential(
            nn.Conv2d(2, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),  # (B,128)
        )
        mask_dim = 128

        # clinical encoder
        self.clin_head = nn.Sequential(
            nn.Linear(clin_dim, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
        )
        clin_out = 64

        # fusion & regressor
        fused_in = feat_dim + mask_dim + clin_out
        self.regressor = nn.Sequential(
            nn.Linear(fused_in, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(512, out_dim),
        )

    def forward(self, x5, xclin):
        x_rgb = x5[:, :3, :, :]
        x_msk = x5[:, 3:, :, :]
        f_rgb = self.backbone(x_rgb)
        f_msk = self.mask_enc(x_msk)
        f_cln = self.clin_head(xclin)
        z = torch.cat([f_rgb, f_msk, f_cln], dim=1)
        return self.regressor(z)


# ---- Early Stopping helper
class EarlyStopper:
    def __init__(self, patience=10, min_delta=0.01, ckpt_path=None):
        self.patience = patience
        self.min_delta = float(min_delta)
        self.ckpt_path = ckpt_path
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, current, model, epoch_meta=None):
        if current < self.best - self.min_delta:
            self.best = current
            self.bad_epochs = 0
            if self.ckpt_path:
                torch.save(
                    {
                        "model": model.state_dict(),
                        "val_pointwise_mae": self.best,
                        **(epoch_meta or {}),
                    },
                    self.ckpt_path,
                )
            return False  # do not stop
        else:
            self.bad_epochs += 1
            return self.bad_epochs > self.patience  # stop if exceeded


# ---- Train
def train_full_fusion_swinT(
    train_loader,
    val_loader,
    clin_dim,
    out_dim=59,
    epochs=80,
    lr=1e-4,
    wd=1e-4,
    device="cuda",
    pretrained=True,
    patience=10,
    min_delta=0.01,
    ckpt_dir="./checkpoints",
):
    os.makedirs(ckpt_dir, exist_ok=True)
    ckpt = os.path.join(ckpt_dir, "best_full_fusion_swint.pth")

    model = SwinT_5ch_Clinical(clin_dim=clin_dim, out_dim=out_dim, pretrained=pretrained).to(
        device
    )
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    # reduce LR when val pMAE plateaus
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", patience=3, factor=0.5)

    stopper = EarlyStopper(patience=patience, min_delta=min_delta, ckpt_path=ckpt)

    for epoch in range(1, epochs + 1):
        tr = _run_epoch_ff(model, train_loader, device, opt=opt)
        va = _run_epoch_ff(model, val_loader, device, opt=None)
        sched.step(va["pointwise_mae"])

        print(
            f"Epoch {epoch:02d} | "
            f"train_loss={tr['loss']:.4f} train_pMAE={tr['pointwise_mae']:.3f} train_MS={tr['ms_mae']:.3f} || "
            f"val_loss={va['loss']:.4f} val_pMAE={va['pointwise_mae']:.3f} val_MS={va['ms_mae']:.3f}"
        )

        should_stop = stopper.step(va["pointwise_mae"], model, epoch_meta={"epoch": epoch})
        if should_stop:
            print(f"Early stopping at epoch {epoch} (best val pMAE={stopper.best:.3f})")
            break

    # load best
    state = torch.load(ckpt, map_location=device)
    model.load_state_dict(state["model"])
    return model, ckpt

In [50]:
# # ===== RUN: SWIN-T FULL-FUSION TRAIN + EVAL =====
# from math import sqrt
# import os

# import torch

# # ---- config (tweak if you like)
# EPOCHS = 80
# LR = 1e-4
# WD = 1e-4
# PATIENCE = 10
# MIN_DELTA = 0.01
# CHECK_DIR = "./checkpoints"
# DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# os.makedirs(CHECK_DIR, exist_ok=True)

# # ---- quick checks
# assert "train_dl_full" in globals() and "val_dl_full" in globals(), (
#     "Missing loaders. Make sure you created train_dl_full and val_dl_full."
# )
# assert "CLIN_DIM" in globals(), "Missing CLIN_DIM."
# OUT_DIM = 59 if "NUM_POINTS" not in globals() else int(NUM_POINTS)

# # ---- train
# swinT_model, SWINT_CKPT = train_full_fusion_swinT(
#     train_loader=train_dl_full,
#     val_loader=val_dl_full,
#     clin_dim=CLIN_DIM,
#     out_dim=OUT_DIM,
#     epochs=EPOCHS,
#     lr=LR,
#     wd=WD,
#     device=DEVICE,
#     patience=PATIENCE,
#     min_delta=MIN_DELTA,
#     ckpt_dir=CHECK_DIR,
#     pretrained=True,
# )

# print(f"\n[OK] Training finished. Best checkpoint: {SWINT_CKPT}")


# # ---- evaluation helpers (pointwise + mean-of-points/“MS”)
# @torch.no_grad()
# def _rmse(a, b):
#     return float(torch.sqrt(torch.mean((a - b) ** 2)))


# @torch.no_grad()
# def _mae_scalar(a, b):
#     return float(torch.mean(torch.abs(a - b)))


# @torch.no_grad()
# def _r2(a, b):
#     ss_res = torch.sum((a - b) ** 2)
#     ss_tot = torch.sum((a - torch.mean(a)) ** 2) + 1e-12
#     return float(1.0 - ss_res / ss_tot)


# # ---- load best and evaluate on val
# best_swinT = SwinT_5ch_Clinical(clin_dim=CLIN_DIM, out_dim=OUT_DIM, pretrained=False).to(DEVICE)
# state = torch.load(SWINT_CKPT, map_location=DEVICE)
# best_swinT.load_state_dict(state["model"])
# best_swinT.eval()

# y_true, y_pred = [], []
# with torch.no_grad():
#     for x5, xclin, y in val_dl_full:
#         x5 = x5.to(DEVICE)
#         xclin = xclin.to(DEVICE)
#         y = y.to(DEVICE)
#         p = best_swinT(x5, xclin)
#         y_true.append(y.cpu())
#         y_pred.append(p.cpu())

# y_true = torch.cat(y_true, dim=0)
# y_pred = torch.cat(y_pred, dim=0)

# # pointwise metrics
# pw_true, pw_pred = y_true.reshape(-1), y_pred.reshape(-1)
# print("\n== SWIN-T FULL FUSION: POINTWISE ==")
# print(
#     f"RMSE: {_rmse(pw_true, pw_pred):.4f} | MAE: {_mae_scalar(pw_true, pw_pred):.4f} | R²: {_r2(pw_true, pw_pred):.4f}"
# )

# # mean-of-points (“MS”) metrics
# t_mean, p_mean = y_true.mean(dim=1), y_pred.mean(dim=1)
# print("== SWIN-T FULL FUSION: MEAN (MS) ==")
# print(
#     f"RMSE: {_rmse(t_mean, p_mean):.4f} | MAE: {_mae_scalar(t_mean, p_mean):.4f} | R²: {_r2(t_mean, p_mean):.4f}"
# )

# 7. Weighted Averaging Ensemble Technique (MobileNetV3L + SWIN-T)

In [51]:
# ==========================
# Simple Ensemble: MobileNetV3L_5ch_Clinical + Swint_5ch_Clinical
# Uses the SAME loaders (val_dl_full / test_dl_full) as your full-fusion setup
# ==========================
import os

import torch
import torch.nn as nn

# ---- your checkpoint paths
RESNET_CKPT_PATH = CHECK_DIR + "/best_full_fusion_ROI_ODOC_CLIN1.pth"
SWIN_CKPT_PATH = CHECK_DIR + "/best_full_fusion_swint.pth"


# ---- small metric helpers (names won't clash with your existing ones)
@torch.no_grad()
def _rmse(a, b):
    return float(torch.sqrt(torch.mean((a - b) ** 2)))


@torch.no_grad()
def _mae(a, b):
    return float(torch.mean(torch.abs(a - b)))


@torch.no_grad()
def _r2(a, b):
    ss_res = torch.sum((a - b) ** 2)
    ss_tot = torch.sum((a - torch.mean(a)) ** 2) + 1e-12
    return float(1 - ss_res / ss_tot)


def _load_model_states():
    # Build the exact architectures (no pretrain needed when loading checkpoints)
    resnet = MobileNetV3L_5ch_Clinical(clin_dim=CLIN_DIM, out_dim=NUM_POINTS, pretrained=False).to(
        DEVICE
    )
    swin = SwinT_5ch_Clinical(clin_dim=CLIN_DIM, out_dim=NUM_POINTS, pretrained=False).to(DEVICE)

    r_state = torch.load(RESNET_CKPT_PATH, map_location=DEVICE)
    s_state = torch.load(SWIN_CKPT_PATH, map_location=DEVICE)

    resnet.load_state_dict(r_state["model"])
    resnet.eval()
    swin.load_state_dict(s_state["model"])
    swin.eval()
    return resnet, swin


@torch.no_grad()
def ensemble_eval(loader, alpha=0.5):
    """
    alpha: weight for SWIN (0..1). 0.5 = simple average
    pred = (1-alpha)*resnet + alpha*swin
    """
    assert 0.0 <= alpha <= 1.0
    resnet, swin = _load_model_states()

    y_true_chunks, y_pred_res_chunks, y_pred_swin_chunks, y_pred_ens_chunks = [], [], [], []

    for x5, xclin, y in loader:
        x5 = x5.to(DEVICE, non_blocking=True)
        xcli = xclin.to(DEVICE, non_blocking=True)
        y = y.to(DEVICE, non_blocking=True)

        p_r = resnet(x5, xcli)  # (B, 59)
        p_s = swin(x5, xcli)  # (B, 59)
        p_e = (1.0 - alpha) * p_r + alpha * p_s

        y_true_chunks.append(y.cpu())
        y_pred_res_chunks.append(p_r.cpu())
        y_pred_swin_chunks.append(p_s.cpu())
        y_pred_ens_chunks.append(p_e.cpu())

    y_true = torch.cat(y_true_chunks, dim=0)
    p_res = torch.cat(y_pred_res_chunks, dim=0)
    p_swin = torch.cat(y_pred_swin_chunks, dim=0)
    p_ens = torch.cat(y_pred_ens_chunks, dim=0)

    # pointwise metrics (flatten all 59 points)
    pw_true, pw_res, pw_swin, pw_ens = (
        y_true.reshape(-1),
        p_res.reshape(-1),
        p_swin.reshape(-1),
        p_ens.reshape(-1),
    )

    print("\n== INDIVIDUAL MODELS (pointwise) ==")
    print(
        f"MobileNet  → RMSE: {_rmse(pw_true, pw_res):.4f} | MAE: {_mae(pw_true, pw_res):.4f} | R²: {_r2(pw_true, pw_res):.4f}"
    )
    print(
        f"Swin-B  → RMSE: {_rmse(pw_true, pw_swin):.4f} | MAE: {_mae(pw_true, pw_swin):.4f} | R²: {_r2(pw_true, pw_swin):.4f}"
    )

    print("\n== ENSEMBLE (pointwise) ==")
    print(
        f"Avg (α={alpha:.2f}) → RMSE: {_rmse(pw_true, pw_ens):.4f} | MAE: {_mae(pw_true, pw_ens):.4f} | R²: {_r2(pw_true, pw_ens):.4f}"
    )

    # MS metrics = mean of 59 points per sample
    t_mean, r_mean, s_mean, e_mean = (
        y_true.mean(dim=1),
        p_res.mean(dim=1),
        p_swin.mean(dim=1),
        p_ens.mean(dim=1),
    )

    print("\n== INDIVIDUAL MODELS (MS) ==")
    print(
        f"MobileNet  → RMSE: {_rmse(t_mean, r_mean):.4f} | MAE: {_mae(t_mean, r_mean):.4f} | R²: {_r2(t_mean, r_mean):.4f}"
    )
    print(
        f"Swin-B  → RMSE: {_rmse(t_mean, s_mean):.4f} | MAE: {_mae(t_mean, s_mean):.4f} | R²: {_r2(t_mean, s_mean):.4f}"
    )

    print("\n== ENSEMBLE (MS) ==")
    print(
        f"Avg (α={alpha:.2f}) → RMSE: {_rmse(t_mean, e_mean):.4f} | MAE: {_mae(t_mean, e_mean):.4f} | R²: {_r2(t_mean, e_mean):.4f}\n"
    )

    return {
        "y_true": y_true,
        "pred_resnet": p_res,
        "pred_swin": p_swin,
        "pred_ensemble": p_ens,
    }


# ===== RUN on your existing loaders =====
# Use val set:
assert "val_dl_full" in globals(), "val_dl_full not found. Run your full-fusion data cell first."
_ = ensemble_eval(val_dl_full, alpha=0.5)  # try alpha=0.3, 0.7, etc.

# If you also have test_dl_full:
# assert 'test_dl_full' in globals()
# _ = ensemble_eval(test_dl_full, alpha=0.5)


== INDIVIDUAL MODELS (pointwise) ==
MobileNet  → RMSE: 6.1533 | MAE: 4.3689 | R²: 0.5461
Swin-B  → RMSE: 5.8988 | MAE: 4.1052 | R²: 0.5829

== ENSEMBLE (pointwise) ==
Avg (α=0.50) → RMSE: 5.8362 | MAE: 4.1308 | R²: 0.5917

== INDIVIDUAL MODELS (MS) ==
MobileNet  → RMSE: 3.8140 | MAE: 2.5977 | R²: 0.5942
Swin-B  → RMSE: 3.4719 | MAE: 2.4142 | R²: 0.6637

== ENSEMBLE (MS) ==
Avg (α=0.50) → RMSE: 3.3722 | MAE: 2.3352 | R²: 0.6828



In [52]:
# ==========================
# SIMPLE ENSEMBLE: ResNet (full fusion) + Swin-B (full fusion)
# Grid-search alpha on VAL to weight SWIN higher/lower
# ==========================
import os

import torch
import torch.nn as nn


# --- Metrics (same style you used)
@torch.no_grad()
def _rmse(a, b):
    return float(torch.sqrt(torch.mean((a - b) ** 2)))


@torch.no_grad()
def _mae(a, b):
    return float(torch.mean(torch.abs(a - b)))


@torch.no_grad()
def _r2(a, b):
    ss_res = torch.sum((a - b) ** 2)
    ss_tot = torch.sum((a - torch.mean(a)) ** 2) + 1e-12
    return float(1 - ss_res / ss_tot)


# --- Load both models
RESNET_CKPT = CHECK_DIR + "/best_full_fusion_ROI_ODOC_CLIN1.pth"
SWIN_CKPT = CHECK_DIR + "/best_full_fusion_swint.pth"

resnet = MobileNetV3L_5ch_Clinical(clin_dim=CLIN_DIM, out_dim=NUM_POINTS, pretrained=False).to(
    DEVICE
)
swin = SwinT_5ch_Clinical(clin_dim=CLIN_DIM, out_dim=NUM_POINTS, pretrained=False).to(DEVICE)

resnet.load_state_dict(torch.load(RESNET_CKPT, map_location=DEVICE)["model"])
swin.load_state_dict(torch.load(SWIN_CKPT, map_location=DEVICE)["model"])
resnet.eval()
swin.eval()

# --- Collect full VAL predictions once for speed
y_true_list, pred_resnet_list, pred_swin_list = [], [], []
with torch.no_grad():
    for x5, xclin, y in val_dl_full:
        x5 = x5.to(DEVICE)
        xclin = xclin.to(DEVICE)
        y = y.to(DEVICE)
        pr = resnet(x5, xclin)
        ps = swin(x5, xclin)
        y_true_list.append(y.cpu())
        pred_resnet_list.append(pr.cpu())
        pred_swin_list.append(ps.cpu())

y_true = torch.cat(y_true_list, dim=0)  # (N, 59)
pred_r = torch.cat(pred_resnet_list, dim=0)  # (N, 59)
pred_s = torch.cat(pred_swin_list, dim=0)  # (N, 59)

# --- Grid-search alpha in [0,1] to minimize pointwise MAE (you can switch to MS if preferred)
best_alpha, best_mae = None, float("inf")
for a in [i / 20 for i in range(21)]:  # 0.00, 0.05, ..., 1.00
    ens = a * pred_s + (1 - a) * pred_r
    mae_pw = _mae(ens.reshape(-1), y_true.reshape(-1))
    if mae_pw < best_mae:
        best_mae = mae_pw
        best_alpha = a

# --- Final ensemble metrics with the chosen alpha
ens = best_alpha * pred_s + (1 - best_alpha) * pred_r
pw_true, pw_pred = y_true.reshape(-1), ens.reshape(-1)
print(
    f"\n[Ensemble] best alpha={best_alpha:.2f} (weights: SWIN={best_alpha:.2f}, MOBILENET={(1 - best_alpha):.2f})"
)
print("POINTWISE:")
print(
    f"  RMSE={_rmse(pw_true, pw_pred):.4f} | MAE={_mae(pw_true, pw_pred):.4f} | R²={_r2(pw_true, pw_pred):.4f}"
)
t_mean, p_mean = y_true.mean(dim=1), ens.mean(dim=1)
print("MS (mean sensitivity):")
print(
    f"  RMSE={_rmse(t_mean, p_mean):.4f} | MAE={_mae(t_mean, p_mean):.4f} | R²={_r2(t_mean, p_mean):.4f}"
)

# Save alpha if you want to reuse for test-time ensembling
BEST_ALPHA = best_alpha


[Ensemble] best alpha=0.80 (weights: SWIN=0.80, MOBILENET=0.20)
POINTWISE:
  RMSE=5.8273 | MAE=4.0901 | R²=0.5929
MS (mean sensitivity):
  RMSE=3.3642 | MAE=2.3283 | R²=0.6842


In [53]:
!ls $CHECK_DIR

best_full_fusion_ROI_ODOC_CLIN1.pth  best_mobilenetv3l_original_roi.pth
best_full_fusion_swint.pth	     best_mobilenetv3l_ROI_ODOC_CDR.pth
best_mobilenetv3l_original_cfp.pth   best_mobilenetv3l_ROI_ODOC.pth


In [54]:
import shutil

from IPython.display import FileLink

shutil.make_archive("final_checkpoints_archive", "zip", CHECK_DIR)
FileLink("final_checkpoints_archive.zip")