## AUROC, C-index bootstrap

In [None]:
# === Unified bootstrap across all models ===


import os, sys, warnings
import numpy as np
import pandas as pd
from pathlib import Path

from sklearn.metrics import roc_auc_score
import statsmodels.api as sm
from lifelines import CoxPHFitter
from lifelines.utils import concordance_index

# =========================
# Silence warnings in this process and export PYTHONWARNINGS to the environment.
warnings.filterwarnings("ignore")
os.environ["PYTHONWARNINGS"] = "ignore"

# Project root holding the input CSVs; per-model inference outputs live in infer_out/.
WORK_DIR  = "/home/hch/dementia"
INFER_DIR = os.path.join(WORK_DIR, "infer_out")
OUT_SUMMARY = os.path.join(INFER_DIR, "summary_metrics_pretty.csv")
os.chdir(WORK_DIR)  # relative CSV paths below resolve against WORK_DIR

N_BOOT = 2000          # number of bootstrap resamples (adjust if needed)
SEED_AUC = 1234        # RNG seed for the AUROC bootstrap
SEED_CIDX = 5678       # RNG seed for the C-index bootstrap

# =========================
def fmt_p(p):
    """Format a p-value for reporting.

    Returns NaN for missing input, the string "<0.001" for values below
    0.001, and otherwise the value rounded to three decimals as a float.
    Input that cannot be converted to float is returned unchanged.
    """
    if pd.isna(p):
        return np.nan
    try:
        value = float(p)
    except Exception:
        # Not numeric: hand the caller back whatever it passed in.
        return p
    if value < 0.001:
        return "<0.001"
    return float(f"{value:.3f}")

def safe_auc(y, s):
    """Return the AUROC of scores `s` against labels `y`, or NaN.

    NaN is returned when `y` is empty, does not contain exactly two distinct
    values, or when anything raises during the computation.
    """
    try:
        if len(y) == 0 or pd.Series(y).nunique() != 2:
            return np.nan
        return float(roc_auc_score(y, s))
    except Exception:
        return np.nan

def cindex_from_risk(times, risk, events):
    """Concordance index for a risk score (higher risk = earlier event).

    The risk is negated before being passed to `concordance_index`, which
    receives it as `predicted_scores`. Returns NaN if anything raises.
    """
    try:
        durations = np.asarray(times)
        neg_risk = -np.asarray(risk)
        observed = np.asarray(events)
        return concordance_index(
            event_times=durations,
            predicted_scores=neg_risk,
            event_observed=observed,
        )
    except Exception:
        return np.nan

def progress_print(i, n, tag):
    """Write an in-place progress line to stdout for iteration `i` of `n`.

    The line is refreshed roughly every 5% of `n` (and always on the last
    iteration); a newline is emitted once the final iteration is reached.
    """
    done = i + 1
    interval = max(1, n // 20)  # ~5% steps, at least every iteration
    finished = done == n
    if finished or done % interval == 0:
        percent = int(round(done / n * 100))
        sys.stdout.write(f"\r[{tag}] {done}/{n} ({percent}%)")
        sys.stdout.flush()
    if finished:
        sys.stdout.write("\n")

# =========================
# Detection data: keep only rows flagged as test and renumber them 0..n-1.
det_all  = pd.read_csv("dementia_detection.csv", low_memory=False)
det_test = det_all[det_all["test"] == True].reset_index(drop=True)

# Prediction (future-event) data, renumbered 0..n-1 as loaded.
surv_all = pd.read_csv("dementia_prediction.csv", low_memory=False).reset_index(drop=True)

def add_common_covs(df):
    """Return a copy of `df` with the shared covariate columns normalised.

    - `STDY_DT` (if present) is parsed to datetime; unparseable values -> NaT.
    - `SEXINT` is derived as 1 where `SEX == "M"` and 0 otherwise when it is
      not already present (NaN if `SEX` is missing too).
    - `EXERCISE_STATUS_bin` is derived from `EXERCISE_STATUS` when not already
      present: >= 2 -> 1, <= 1 -> 0, anything else -> NaN.
    - The numeric covariates are coerced to numbers, or created as NaN.
    """
    out = df.copy()
    existing = out.columns

    if "STDY_DT" in existing:
        out["STDY_DT"] = pd.to_datetime(out["STDY_DT"], errors="coerce")

    if "SEXINT" not in out.columns:
        out["SEXINT"] = (out["SEX"] == "M").astype(int) if "SEX" in out.columns else np.nan

    if "EXERCISE_STATUS_bin" not in out.columns:
        if "EXERCISE_STATUS" not in out.columns:
            out["EXERCISE_STATUS_bin"] = np.nan
        else:
            level = pd.to_numeric(out["EXERCISE_STATUS"], errors="coerce")
            out["EXERCISE_STATUS_bin"] = np.select(
                [level >= 2, level <= 1], [1.0, 0.0], default=np.nan
            )

    numeric_covs = ("STDY_AGE", "cholesterol_updated", "sbp", "bmi", "days_diff")
    for name in numeric_covs:
        out[name] = pd.to_numeric(out[name], errors="coerce") if name in out.columns else np.nan
    return out

# Normalise covariate columns on both cohorts (each call returns a new frame).
det_test = add_common_covs(det_test)
surv_all = add_common_covs(surv_all)

def recompute_survival_frame(frame: pd.DataFrame) -> pd.DataFrame:
    """Build the survival cohort with `event` and `obs_time` columns.

    Rows whose `days_diff` is at least 730 become events; rows with missing
    `days_diff` are kept as non-events; every other row is dropped.
    `obs_time` is counted in days from the 730-day mark and capped at 3650.

    Reads the `days_diff` and `STDY_DT` columns (`STDY_DT` must already be
    datetime, since `.dt.days` is taken on the difference). The returned frame
    keeps the index labels of the retained rows; the index is not reset.
    """
    fut = frame.copy()
    # NaN days_diff compares False here, so those rows get event == 0.
    fut["event"] = (fut["days_diff"] >= 730).astype(int)
    # Keep events and rows without days_diff; rows with days_diff < 730 are removed.
    fut = fut[fut["event"].eq(1) | fut["days_diff"].isna()].copy()
    # NOTE(review): presumably the administrative end of follow-up — confirm.
    ref = pd.Timestamp("2019-01-01")
    fut["obs_time"] = np.where(
        fut.event == 1,
        (fut["days_diff"] - 730).clip(lower=0),      # events: time past the 730-day mark
        (ref - fut["STDY_DT"]).dt.days - 730         # non-events: time from STDY_DT to ref, minus 730
    )
    # Cap observation time at 3650 days ...
    fut.loc[fut.obs_time > 3650, "obs_time"] = 3650
    # ... and treat events sitting at the cap as non-events.
    fut.loc[(fut.event == 1) & (fut.obs_time == 3650), "event"] = 0
    return fut

# NOTE(review): recompute_survival_frame drops rows without resetting the index,
# so from here on surv_all's index labels need not equal row positions.
surv_all = recompute_survival_frame(surv_all)

def calc_caide_napoe(frame: pd.DataFrame) -> pd.Series:
    """Compute the CAIDE score without the APOE term for each row of `frame`.

    Points: age (<47: 0, 47-53: 3, >53: 4), SEXINT == 1 (1), sbp >= 140 (2),
    bmi >= 30 (2), cholesterol >= 6.5 after multiplying by 0.02586 (2),
    EXERCISE_STATUS < 2 or missing (1). Education always contributes 0.
    Missing age makes the score NaN; other missing inputs fail their
    comparison (scoring 0, except exercise which scores 1).

    Returns a float Series named "CAIDE_noAPOE" on `frame.index`.
    """
    needed = ("STDY_AGE", "SEXINT", "sbp", "bmi", "cholesterol_updated", "EXERCISE_STATUS")
    data = frame.copy()
    for name in needed:
        data[name] = pd.to_numeric(data[name], errors="coerce") if name in data.columns else np.nan

    age = data["STDY_AGE"]
    age_pts = np.select(
        [age > 53, age < 47, age.between(47, 53)],
        [4, 0, 3],
        default=np.nan,
    )
    male_pts = (data["SEXINT"] == 1).to_numpy().astype(int)
    # No education variable is available, so it contributes 0 points.
    edu_pts = np.zeros(len(data), dtype=float)
    sbp_pts = 2 * (data["sbp"] >= 140).to_numpy().astype(int)
    bmi_pts = 2 * (data["bmi"] >= 30).to_numpy().astype(int)
    chol_mmol = data["cholesterol_updated"] * 0.02586
    chol_pts = 2 * (chol_mmol >= 6.5).to_numpy().astype(int)
    active = (data["EXERCISE_STATUS"] >= 2).to_numpy().astype(int)
    inactive_pts = 1 - active

    total = age_pts + male_pts + edu_pts + sbp_pts + bmi_pts + chol_pts + inactive_pts
    return pd.Series(total, index=frame.index, name="CAIDE_noAPOE")

def caide_valid_mask(frame: pd.DataFrame) -> pd.Series:
    """Flag rows that are eligible for CAIDE scoring.

    Only a non-missing `STDY_AGE` is required. The other CAIDE inputs
    (SEXINT, sbp, bmi, cholesterol_updated, EXERCISE_STATUS) are deliberately
    not required here; `calc_caide_napoe` scores them with its own
    missing-value behaviour.

    Returns a boolean Series aligned to `frame.index`.
    """
    required = ["STDY_AGE"]
    return frame[required].notna().all(axis=1)

# CAIDE score and eligibility mask for the detection test set.
det_CAIDE = calc_caide_napoe(det_test)
det_CAIDE_mask = caide_valid_mask(det_test)

# Same for the survival cohort, additionally requiring obs_time and event to be present.
surv_CAIDE = calc_caide_napoe(surv_all)
surv_CAIDE_mask = caide_valid_mask(surv_all) & surv_all["obs_time"].notna() & surv_all["event"].notna()

# Covariates used as adjustment variables in the logistic and Cox models further down.
COVARS = ["STDY_AGE","SEXINT","cholesterol_updated","sbp","bmi","EXERCISE_STATUS_bin"]
X_test_base = det_test[COVARS]
test_cov_mask = X_test_base.notna().all(axis=1)   # rows with complete covariates

X_fut_base = surv_all[COVARS]
fut_cov_mask = X_fut_base.notna().all(axis=1)     # rows with complete covariates
# Plain arrays: these are addressed by row position, not by index label.
fut_obs_time = surv_all["obs_time"].to_numpy()
fut_event    = surv_all["event"].to_numpy()

# =========================
# Every sub-directory of INFER_DIR that contains a test_preds.csv is treated as one model.
model_dirs = sorted([p for p in Path(INFER_DIR).glob("*") if p.is_dir() and (p/"test_preds.csv").exists()])

# test: (idx, label, pred) for each model
test_series = {}   # model_desc -> Series(pred, index=idx)
test_labels = {}   # model_desc -> Series(label, index=idx)

# future: (idx, pred) for each model
fut_series  = {}   # model_desc -> Series(pred, index=idx)

for mdir in model_dirs:
    desc = mdir.name  # the directory name doubles as the model description

    # ---- Test
    test_csv = mdir / "test_preds.csv"
    try:
        tp = pd.read_csv(
            test_csv,
            usecols=["idx","label","pred"],
            dtype={"idx": np.int64, "label": np.float64, "pred": np.float64}
        )
    except Exception as e:
        # Unreadable or malformed file: skip the whole model (test and future).
        print(f"[WARN] skip {desc} (bad test_preds.csv): {e}")
        continue

    s_pred = pd.Series(tp["pred"].to_numpy(dtype=float), index=tp["idx"].to_numpy(dtype=np.int64), name=desc)
    s_lab  = pd.Series(tp["label"].to_numpy(dtype=float), index=tp["idx"].to_numpy(dtype=np.int64), name="label")
    test_series[desc] = s_pred
    test_labels[desc] = s_lab

    # ---- Future (optional)
    fut_csv = mdir / "prediction_preds.csv"
    if fut_csv.exists():
        try:
            fp = pd.read_csv(
                fut_csv,
                usecols=["idx","pred"],
                dtype={"idx": np.int64, "pred": np.float64}
            )
            fut_series[desc] = pd.Series(fp["pred"].to_numpy(dtype=float),
                                         index=fp["idx"].to_numpy(dtype=np.int64),
                                         name=desc)
        except Exception as e:
            # The model keeps its test predictions; only the future part is dropped.
            print(f"[WARN] skip future for {desc}: {e}")
# =========================
# Common test cohort: idx values present for every model ...
all_test_idx_sets = [s.index for s in test_series.values()]
if len(all_test_idx_sets) == 0:
    raise RuntimeError("No valid test predictions found.")

common_test_idx = set(all_test_idx_sets[0])
for idx_set in all_test_idx_sets[1:]:
    common_test_idx &= set(idx_set)

# Labels are taken from the first model only.
# NOTE(review): assumes every model's test_preds.csv carries the same labels — confirm.
first_model_for_label = next(iter(test_labels))
labels_first = test_labels[first_model_for_label]

# ... that are CAIDE-eligible (row positions in det_test) and have a finite label.
common_test_idx &= set(np.where(det_CAIDE_mask.to_numpy())[0])
common_test_idx &= set(labels_first.index[np.isfinite(labels_first.values)])
common_test_idx = np.array(sorted(common_test_idx), dtype=np.int64)

# Common future cohort, built the same way from the models that have future predictions.
all_fut_idx_sets = [s.index for s in fut_series.values()] if len(fut_series) > 0 else []
common_fut_idx = None  # stays None when no model has future predictions
if len(all_fut_idx_sets) > 0:
    common_fut_idx = set(all_fut_idx_sets[0])
    for idx_set in all_fut_idx_sets[1:]:
        common_fut_idx &= set(idx_set)
    # Both masks below are positional over surv_all's rows.
    common_fut_idx &= set(np.where(surv_CAIDE_mask.to_numpy())[0])
    fin_mask = np.isfinite(fut_obs_time) & np.isfinite(fut_event)
    common_fut_idx &= set(np.where(fin_mask)[0])
    common_fut_idx = np.array(sorted(common_fut_idx), dtype=np.int64)

# =========================
# Assemble label/score arrays on the common cohorts (one column per model).
model_names = sorted(test_series.keys())

y_test_all   = labels_first.reindex(common_test_idx).to_numpy(dtype=float)
s_caide_test = det_CAIDE.reindex(common_test_idx).to_numpy(dtype=float)

S_test = np.zeros((len(common_test_idx), len(model_names)), dtype=float)
for j, m in enumerate(model_names):
    S_test[:, j] = test_series[m].reindex(common_test_idx).to_numpy(dtype=float)

have_future = common_fut_idx is not None and len(common_fut_idx) > 1
if have_future:
    t_fut_all = fut_obs_time[common_fut_idx]
    e_fut_all = fut_event[common_fut_idx]
    # common_fut_idx holds row POSITIONS in surv_all (it was intersected with
    # np.where(...) over positional masks, and times/events above are indexed
    # positionally). surv_all was row-filtered without an index reset, so a
    # label-based reindex here could pick CAIDE scores from different rows;
    # index the CAIDE scores positionally as well.
    r_caide_f = surv_CAIDE.to_numpy(dtype=float)[common_fut_idx]

    R_fut = np.zeros((len(common_fut_idx), len(model_names)), dtype=float)
    for j, m in enumerate(model_names):
        if m in fut_series:
            R_fut[:, j] = fut_series[m].reindex(common_fut_idx).to_numpy(dtype=float)
        else:
            # Model has no future predictions: leave its column as NaN.
            R_fut[:, j] = np.nan

# =========================
# Point estimates on the full common cohorts (no resampling).
auc_caide = safe_auc(y_test_all, s_caide_test)
auc_models = [safe_auc(y_test_all, S_test[:, j]) for j in range(len(model_names))]

# C-index values stay NaN unless a common future cohort exists.
cidx_caide = np.nan
cidx_models = [np.nan] * len(model_names)
if have_future:
    cidx_caide = cindex_from_risk(t_fut_all, r_caide_f, e_fut_all)
    for j in range(len(model_names)):
        cidx_models[j] = cindex_from_risk(t_fut_all, R_fut[:, j], e_fut_all)

# =========================
# Separate generators so the AUROC and C-index bootstraps are independently reproducible.
rng_auc  = np.random.default_rng(SEED_AUC)
rng_cidx = np.random.default_rng(SEED_CIDX)

# One entry per usable resample; each resample is shared by CAIDE and all models.
auc_boot_mat   = []  
auc_boot_caide = []  
diffs_auc_mat  = []  

if pd.Series(y_test_all).nunique() == 2 and len(y_test_all) >= 2:
    n = len(y_test_all)
    for b in range(N_BOOT):
        idx = rng_auc.integers(0, n, size=n)  # resample rows with replacement
        yb = y_test_all[idx]
        progress_print(b, N_BOOT, "AUC  boot")
        # A resample with a single class has no defined AUROC; skip it entirely,
        # so the number of stored resamples (B_eff) can be smaller than N_BOOT.
        if pd.Series(yb).nunique() < 2:
            continue
        s_cai_b = s_caide_test[idx]

        # CAIDE AUROC
        try:
            auc_cai_b = roc_auc_score(yb, s_cai_b)
        except Exception:
            auc_cai_b = np.nan

        # Per-model AUROC and its difference to CAIDE on the same resample.
        aucs_b = np.full(len(model_names), np.nan, dtype=float)
        diffs_b = np.full(len(model_names), np.nan, dtype=float)
        for j in range(len(model_names)):
            s_m_b = S_test[idx, j]
            try:
                auc_m_b = roc_auc_score(yb, s_m_b)
                aucs_b[j] = auc_m_b
                diffs_b[j] = auc_m_b - auc_cai_b
            except Exception:
                # Leave NaN for this model in this resample.
                pass

        auc_boot_mat.append(aucs_b)
        auc_boot_caide.append(auc_cai_b)
        diffs_auc_mat.append(diffs_b)

# Percentile (2.5 / 97.5) confidence intervals, ignoring NaN entries.
if len(auc_boot_mat) > 0:
    auc_boot_mat   = np.vstack(auc_boot_mat)                      # (B_eff, n_models)
    auc_boot_caide = np.asarray(auc_boot_caide, dtype=float)      # (B_eff,)
    diffs_auc_mat  = np.vstack(diffs_auc_mat)                     # (B_eff, n_models)
    auc_lo = np.nanpercentile(auc_boot_mat, 2.5, axis=0)
    auc_hi = np.nanpercentile(auc_boot_mat, 97.5, axis=0)
    # CAIDE AUROC 95% CI
    auc_caide_lo, auc_caide_hi = np.nanpercentile(auc_boot_caide, [2.5, 97.5])
    # ΔAUROC 95% CI (the reported centre is the mean of the bootstrap differences)
    diff_auc_means = np.nanmean(diffs_auc_mat, axis=0)
    diff_auc_lo    = np.nanpercentile(diffs_auc_mat, 2.5, axis=0)
    diff_auc_hi    = np.nanpercentile(diffs_auc_mat, 97.5, axis=0)
else:
    # No usable resample: the reporting code checks for None / NaN.
    auc_lo = auc_hi = diff_auc_means = diff_auc_lo = diff_auc_hi = None
    auc_caide_lo = auc_caide_hi = np.nan

# C-index bootstrap on the common future cohort; same layout as the AUROC bootstrap.
cidx_boot_mat   = []
cidx_boot_caide = []
diffs_cidx_mat  = []

if have_future and len(t_fut_all) >= 2:
    n = len(t_fut_all)
    for b in range(N_BOOT):
        idx = rng_cidx.integers(0, n, size=n)  # resample rows with replacement
        tb, eb = t_fut_all[idx], e_fut_all[idx]
        r_cai_b = r_caide_f[idx]
        progress_print(b, N_BOOT, "Cidx boot")
        # cindex_from_risk already returns NaN on failure; the try is a second safety net.
        try:
            c_cai_b = cindex_from_risk(tb, r_cai_b, eb)
        except Exception:
            c_cai_b = np.nan

        # Per-model C-index and its difference to CAIDE on the same resample.
        cidxs_b = np.full(len(model_names), np.nan, dtype=float)
        diffs_b = np.full(len(model_names), np.nan, dtype=float)
        for j in range(len(model_names)):
            r_m_b = R_fut[idx, j]
            try:
                c_m_b = cindex_from_risk(tb, r_m_b, eb)
                cidxs_b[j] = c_m_b
                diffs_b[j] = c_m_b - c_cai_b
            except Exception:
                # Leave NaN for this model in this resample.
                pass

        cidx_boot_mat.append(cidxs_b)
        cidx_boot_caide.append(c_cai_b)
        diffs_cidx_mat.append(diffs_b)

# Percentile (2.5 / 97.5) confidence intervals, ignoring NaN entries.
if len(cidx_boot_mat) > 0:
    cidx_boot_mat   = np.vstack(cidx_boot_mat)                    # (B_eff, n_models)
    cidx_boot_caide = np.asarray(cidx_boot_caide, dtype=float)    # (B_eff,)
    diffs_cidx_mat  = np.vstack(diffs_cidx_mat)                   # (B_eff, n_models)
    cidx_lo = np.nanpercentile(cidx_boot_mat, 2.5, axis=0)
    cidx_hi = np.nanpercentile(cidx_boot_mat, 97.5, axis=0)
    # CAIDE C-index 95% CI
    cidx_caide_lo, cidx_caide_hi = np.nanpercentile(cidx_boot_caide, [2.5, 97.5])
    # ΔC-index 95% CI (the reported centre is the mean of the bootstrap differences)
    diff_cidx_means = np.nanmean(diffs_cidx_mat, axis=0)
    diff_cidx_lo    = np.nanpercentile(diffs_cidx_mat, 2.5, axis=0)
    diff_cidx_hi    = np.nanpercentile(diffs_cidx_mat, 97.5, axis=0)
else:
    # No future cohort / no resample: the reporting code checks for None / NaN.
    cidx_lo = cidx_hi = diff_cidx_means = diff_cidx_lo = diff_cidx_hi = None
    cidx_caide_lo = cidx_caide_hi = np.nan

# =========================
# Console report: CAIDE first, then one section per model.
print("\n=== CAIDE (Common cohort) ===")
if not pd.isna(auc_caide):
    print(f"AUROC (CAIDE): {auc_caide:.3f}", end="")
    if not np.isnan(auc_caide_lo):
        print(f"  | 95% CI [{auc_caide_lo:.3f}, {auc_caide_hi:.3f}]")
    else:
        print()  # terminate the line opened with end=""
else:
    print("AUROC (CAIDE): NA")
if have_future and not pd.isna(cidx_caide):
    line = f"C-index (CAIDE): {cidx_caide:.3f}"
    if not np.isnan(cidx_caide_lo):
        line += f"  | 95% CI [{cidx_caide_lo:.3f}, {cidx_caide_hi:.3f}]"
    print(line)

for j, m in enumerate(model_names):
    print(f"\n=== {m} (Common cohort) ===")
    # AUROC & 95% CI
    auc_m = auc_models[j]
    if not pd.isna(auc_m):
        line = f"AUROC (Model): {auc_m:.3f}"
        # auc_lo is None (not an ndarray) when the bootstrap produced no resamples.
        if isinstance(auc_lo, np.ndarray) and not pd.isna(auc_lo[j]):
            line += f"  | 95% CI [{auc_lo[j]:.3f}, {auc_hi[j]:.3f}]"
        print(line)
    else:
        print("AUROC (Model): NA")
    # ΔAUROC 95% CI
    if isinstance(diff_auc_means, np.ndarray) and not any(pd.isna([diff_auc_means[j], diff_auc_lo[j], diff_auc_hi[j]])):
        print(f"ΔAUROC (Model − CAIDE) 95% CI: {diff_auc_means[j]:.4f} [{diff_auc_lo[j]:.4f}, {diff_auc_hi[j]:.4f}]")
    else:
        print("ΔAUROC 95% CI: NA")

    if have_future:
        # C-index & 95% CI
        c_m = cidx_models[j]
        if not pd.isna(c_m):
            line = f"C-index (Model): {c_m:.3f}"
            # cidx_lo is None (not an ndarray) when the bootstrap produced no resamples.
            if isinstance(cidx_lo, np.ndarray) and not pd.isna(cidx_lo[j]):
                line += f"  | 95% CI [{cidx_lo[j]:.3f}, {cidx_hi[j]:.3f}]"
            print(line)
        else:
            print("C-index (Model): NA")
        # ΔC-index 95% CI
        if isinstance(diff_cidx_means, np.ndarray) and not any(pd.isna([diff_cidx_means[j], diff_cidx_lo[j], diff_cidx_hi[j]])):
            print(f"ΔC-index (Model − CAIDE) 95% CI: {diff_cidx_means[j]:.4f} [{diff_cidx_lo[j]:.4f}, {diff_cidx_hi[j]:.4f}]")
        else:
            print("ΔC-index 95% CI: NA")

# =========================
# Per-model summary: covariate-adjusted OR (test set, logistic regression) and
# HR (future set, Cox regression). The prediction is multiplied by 10 before
# fitting, so each effect is per 0.1 increase of the predicted value.
rows = []
for j, m in enumerate(model_names):
    # OR (Test, per 10% ↑)
    s_pred_full = pd.Series(S_test[:, j], index=common_test_idx)
    y_full      = pd.Series(y_test_all, index=common_test_idx, name="label")

    pred10 = (s_pred_full.to_numpy(dtype=float) * 10.0)
    OR_a = ORa_low = ORa_hi = ORa_p = np.nan
    try:
        # Complete covariates (positional lookup) plus finite prediction and label.
        mask_logit = (
            test_cov_mask.to_numpy()[common_test_idx] &
            np.isfinite(pred10) & np.isfinite(y_full.to_numpy(dtype=float))
        )
        if mask_logit.any():
            Xlog = X_test_base.iloc[common_test_idx[mask_logit]].copy()
            Xlog = sm.add_constant(pd.concat([pd.Series(pred10[mask_logit], name="pred10", index=Xlog.index), Xlog], axis=1))
            ylog = pd.Series(y_full.to_numpy(dtype=float)[mask_logit], index=Xlog.index, name="label_det")
            lg = sm.Logit(ylog, Xlog).fit(disp=False)
            OR_a = float(np.exp(lg.params["pred10"]))
            ci = lg.conf_int().loc["pred10"]
            # Positional access: first entry is the lower bound, second the upper.
            ORa_low, ORa_hi = float(np.exp(ci.iloc[0])), float(np.exp(ci.iloc[1]))
            ORa_p = fmt_p(lg.pvalues["pred10"])
    except Exception as e:
        # Best effort: keep NaN for this model, but do not hide the reason.
        print(f"[WARN] adjusted OR failed for {m}: {e}")

    # HR (Future, per 10% ↑)
    HR_a = HRa_low = HRa_hi = HRa_p = np.nan
    n_fut_used = 0
    if have_future and m in fut_series:
        r_pred_full = pd.Series(R_fut[:, j], index=common_fut_idx)
        pred10_f = (r_pred_full.to_numpy(dtype=float) * 10.0)
        try:
            mask_cox = (
                fut_cov_mask.to_numpy()[common_fut_idx] &
                np.isfinite(pred10_f) &
                np.isfinite(t_fut_all) &
                np.isfinite(e_fut_all)
            )
            n_fut_used = int(mask_cox.sum())
            if mask_cox.any():
                Xcox = X_fut_base.iloc[common_fut_idx[mask_cox]].copy()
                df_cox = pd.concat(
                    [
                        pd.Series(t_fut_all[mask_cox], name="obs_time", index=Xcox.index),
                        pd.Series(e_fut_all[mask_cox], name="event", index=Xcox.index),
                        pd.Series(pred10_f[mask_cox],  name="pred10", index=Xcox.index),
                        Xcox
                    ],
                    axis=1
                )
                cph = CoxPHFitter()
                cph.fit(df_cox, duration_col="obs_time", event_col="event", show_progress=False)
                HR_a = float(np.exp(cph.params_["pred10"]))
                ci = cph.confidence_intervals_.loc["pred10"]
                # This row is labelled with strings, so ci[0] would depend on the
                # deprecated integer-key positional fallback; use .iloc instead.
                HRa_low, HRa_hi = float(np.exp(ci.iloc[0])), float(np.exp(ci.iloc[1]))
                HRa_p = fmt_p(cph.summary.loc["pred10","p"])
        except Exception as e:
            # Best effort: keep NaN for this model, but do not hide the reason.
            print(f"[WARN] adjusted HR failed for {m}: {e}")

    rows.append(dict(
        model_desc = m,
        Test_AUC_MODEL = None if pd.isna(auc_models[j]) else float(f"{auc_models[j]:.3f}"),
        Test_AUC_CAIDE = None if pd.isna(auc_caide) else float(f"{auc_caide:.3f}"),
        OR_adj = OR_a, OR_adj_low = ORa_low, OR_adj_hi = ORa_hi, OR_adj_p = ORa_p,
        Future_cindex_MODEL = None if (not have_future or pd.isna(cidx_models[j])) else float(f"{cidx_models[j]:.3f}"),
        Future_cindex_CAIDE = None if (not have_future or pd.isna(cidx_caide)) else float(f"{cidx_caide:.3f}"),
        # NOTE(review): the p-value column is named "HRa_p" (not "HR_adj_p" like its
        # OR counterpart); kept as-is so existing readers of the CSV keep working.
        HR_adj = HR_a, HR_adj_low = HRa_low, HR_adj_hi = HRa_hi, HRa_p = HRa_p,
        n_test_used = int(len(common_test_idx)),
        n_fut_used  = int(n_fut_used if have_future else 0),
    ))

summary = pd.DataFrame(rows).sort_values("model_desc").reset_index(drop=True)
# Render effect sizes as 3-decimal strings; missing values become empty cells.
for col in ["OR_adj","OR_adj_low","OR_adj_hi","HR_adj","HR_adj_low","HR_adj_hi"]:
    if col in summary.columns:
        summary[col] = summary[col].apply(lambda v: "" if pd.isna(v) else f"{float(v):.3f}")

summary.to_csv(OUT_SUMMARY, index=False, encoding="utf-8")
print(f"\n[Saved] {OUT_SUMMARY}\n")
print(summary)


# detailed performance metrics bootstrap

In [None]:
# -*- coding: utf-8 -*-
"""
Test-set metrics (incl. AUROC) at validation Youden-J thresholds with single-cycle bootstrap for all models.

- Thresholds: from youden_thresholds.csv (split='valid')
- Models: defined by ARCHES × FTS, pick one row per (arch, ft) (ties -> highest AUC in csv)
- Test predictions: /home/hch/dementia/infer_out/{MODEL_DESC}/test_preds.csv (expects: idx, label, pred)
- Evaluation:
  * Use intersection of idx across all selected models
  * Compute Accuracy, Sensitivity, Specificity, PPV, NPV, F1, AUROC
  * One bootstrap cycle for all models together (N_BOOT=2000), same resample indices
- Output: single table with point estimates and 95% CIs (rounded to 3 decimals)
"""

import os, warnings
warnings.filterwarnings("ignore")
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import trange
from sklearn.metrics import roc_auc_score

# ---------------------------
# Inputs/outputs for the threshold-based metrics bootstrap.
THRESH_CSV = Path("./youden_thresholds.csv")
INFER_ROOT = Path("/home/hch/dementia/infer_out")
N_BOOT = 2000
RNG_SEED = 42
OUT_CSV = Path("./test_metrics_bootstrap.csv")

# Model specs to evaluate: every (arch, ft) combination of the lists below.
ARCHES = ['retfound', 'mae', 'openclip', 'dinov2', 'dinov3', 'retfound_dinov2']
FTS    = ['linear', 'partial', 'lora']

# ---------------------------
def ft_blks_for(ft: str):
    """Return the `ft_blks` setting for a fine-tuning mode.

    'partial' -> 4, 'lora' -> 'full', anything else -> None.
    """
    if ft == 'partial':
        blocks = 4
    elif ft == 'lora':
        blocks = 'full'
    else:
        blocks = None
    return blocks

def model_desc_from_csv_row(row: pd.Series) -> str:
    """Build the model description string for a thresholds-CSV row.

    Format: "{arch}_{ft}", followed by "_ft_{ft_blks}" for partial
    fine-tuning or "_rank_{lora_rank}_ft_{ft_blks}" for LoRA (a missing
    `lora_rank` is rendered as an empty string).
    """
    arch_name = str(row['arch'])
    mode = str(row['ft'])
    pieces = [f"{arch_name}_{mode}"]
    if mode == 'partial':
        pieces.append(f"_ft_{row['ft_blks']}")
    elif mode == 'lora':
        rank = str(row.get('lora_rank', ''))
        pieces.append(f"_rank_{rank}_ft_{row['ft_blks']}")
    return "".join(pieces)

def pick_threshold_row(thr_df: pd.DataFrame, arch: str, ft: str) -> pd.Series | None:
    df = thr_df.copy()
    df = df.query("arch == @arch and ft == @ft")
    if ft == 'partial':
        df = df.query("ft_blks == '4'")
    elif ft == 'lora':
        df = df.query("ft_blks == 'full'")
    # linear은 ft_blks 조건 없음
    if len(df) == 0:
        return None
    if 'auc' in df.columns:
        df = df.sort_values('auc', ascending=False)
    return df.iloc[0]

def safe_div(num, den):
    """Divide `num` by `den` as floats; NaN if `den` is zero or not finite."""
    numerator, denominator = float(num), float(den)
    if np.isfinite(denominator) and denominator != 0:
        return numerator / denominator
    return np.nan

def counts_to_metrics(tp, fp, tn, fn):
    """Turn confusion-matrix counts into classification metrics.

    Returns a tuple (accuracy, sensitivity, specificity, PPV, NPV, F1);
    any metric whose denominator is zero comes back as NaN (via safe_div).
    """
    total = tp + fp + tn + fn
    actual_pos = tp + fn
    actual_neg = tn + fp
    predicted_pos = tp + fp
    predicted_neg = tn + fn
    return (
        safe_div(tp + tn, total),              # accuracy
        safe_div(tp, actual_pos),              # sensitivity (TPR)
        safe_div(tn, actual_neg),              # specificity (TNR)
        safe_div(tp, predicted_pos),           # PPV (precision)
        safe_div(tn, predicted_neg),           # NPV
        safe_div(2*tp, 2*tp + fp + fn),        # F1
    )

def safe_roc_auc(y_true: np.ndarray, y_prob: np.ndarray):
    """AUROC of `y_prob` against integer-cast `y_true`, or NaN.

    NaN is returned when fewer than two distinct labels are present or when
    the AUROC computation raises.
    """
    labels = np.asarray(y_true).astype(int)
    if np.unique(labels).size >= 2:
        try:
            return float(roc_auc_score(labels, y_prob))
        except Exception:
            pass
    return np.nan

def format3(x):
    """Render `x` with three decimals; None or a non-finite float gives ''."""
    if x is None:
        return ""
    if isinstance(x, float) and not np.isfinite(x):
        # Covers NaN as well as +/-inf.
        return ""
    return f"{x:.3f}"

def ci_str(low, high):
    """Format a confidence interval as '[low, high]' with three decimals.

    Returns '' when either bound is None or a non-finite float.
    """
    for bound in (low, high):
        if bound is None:
            return ""
        if isinstance(bound, float) and not np.isfinite(bound):
            return ""
    return f"[{low:.3f}, {high:.3f}]"

# ---------------------------
# Load the thresholds table and keep only the rows computed on the validation split.
assert THRESH_CSV.exists(), f"youden_thresholds.csv not found at: {THRESH_CSV}"
thr_df = pd.read_csv(THRESH_CSV)
# Cast the key columns to str so the string comparisons in pick_threshold_row
# (e.g. ft_blks == '4') behave the same regardless of how pandas inferred the dtypes.
for col in ["arch","ft","ft_blks","split"]:
    if col in thr_df.columns:
        thr_df[col] = thr_df[col].astype(str)
thr_df = thr_df.query("split == 'valid'").copy()

# ---------------------------
# Pick one thresholds row per (arch, ft) combination that exists in the CSV.
selected = []
for arch in ARCHES:
    for ft in FTS:
        row = pick_threshold_row(thr_df, arch, ft)
        if row is None:
            continue
        # Re-check the ft_blks constraint on the chosen row (pick_threshold_row
        # already filters on it, so these guards are defensive).
        if ft == 'partial' and str(row['ft_blks']) != '4':
            continue
        if ft == 'lora' and str(row['ft_blks']) != 'full':
            continue
        desc = model_desc_from_csv_row(row)
        thr  = float(row['youden_thr'])
        selected.append({
            "arch": arch,
            "ft": ft,
            "model_desc": desc,
            "threshold": thr
        })

# Drop repeated model descriptions (first occurrence wins), then go back to a list of dicts.
tmp_df = pd.DataFrame(selected).drop_duplicates(subset=["model_desc"]).reset_index(drop=True)
selected = tmp_df.to_dict(orient="records")
if len(selected) == 0:
    raise RuntimeError("no models selected. check youden_thresholds.csv and ARCHES/FTS.")

print(f"number of selected models: {len(selected)}")
for s in selected:
    print(f" - {s['model_desc']} (thr={s['threshold']:.6f})")

# ---------------------------
# Load each selected model's test predictions and collect its set of idx values.
preds_by_model = {}
idx_sets = []

for s in selected:
    desc = s["model_desc"]
    test_csv = INFER_ROOT / desc / "test_preds.csv"
    if not test_csv.exists():
        print(f"[SKIP] missing: {test_csv}")
        continue
    df = pd.read_csv(test_csv)
    # A missing required column is a hard error (the message names file and column).
    for col in ["idx", "label", "pred"]:
        if col not in df.columns:
            raise ValueError(f"{test_csv}, {col}")
    df = df[["idx","label","pred"]].copy()
    df["idx"] = df["idx"].astype(int)
    df = df.dropna(subset=["label","pred"]).reset_index(drop=True)
    preds_by_model[desc] = df
    idx_sets.append(set(df["idx"].tolist()))

if len(preds_by_model) == 0:
    raise RuntimeError("no models with test_preds.csv")

# Evaluate every model on the idx values they all share, in sorted order.
common_idx = set.intersection(*idx_sets) if len(idx_sets) > 1 else idx_sets[0]
common_idx = sorted(list(common_idx))
if len(common_idx) == 0:
    raise RuntimeError("error")  # no sample is shared by all models
print(f"samples: {len(common_idx)}")

data_by_model = {}
for s in selected:
    desc = s["model_desc"]; thr = s["threshold"]
    if desc not in preds_by_model:
        continue
    df = preds_by_model[desc]
    sub = df[df["idx"].isin(common_idx)].copy()
    # Reorder rows to follow common_idx so position k means the same sample for every model.
    # NOTE(review): this relies on idx being unique within each file — confirm.
    sub = sub.set_index("idx").loc[common_idx].reset_index()
    y_true = sub["label"].astype(int).to_numpy()
    y_prob = sub["pred"].astype(float).to_numpy()
    data_by_model[desc] = {
        "threshold": thr,
        "y_true": y_true,
        "y_prob": y_prob
    }

if len(data_by_model) == 0:
    raise RuntimeError("error")  # none of the selected models had usable predictions

# ---------------------------
def metrics_on_threshold(y_true, y_prob, thr):
    """Confusion counts and derived metrics at a fixed probability threshold.

    A sample is predicted positive when `y_prob >= thr`. Returns a dict with
    TP/FP/TN/FN counts plus Accuracy, Sensitivity, Specificity, PPV, NPV, F1.
    """
    predicted_pos = (y_prob >= thr).astype(int) == 1
    predicted_neg = ~predicted_pos
    is_pos = y_true == 1
    is_neg = y_true == 0

    tp = int((is_pos & predicted_pos).sum())
    fp = int((is_neg & predicted_pos).sum())
    tn = int((is_neg & predicted_neg).sum())
    fn = int((is_pos & predicted_neg).sum())

    names = ("Accuracy", "Sensitivity", "Specificity", "PPV", "NPV", "F1")
    result = {"TP": tp, "FP": fp, "TN": tn, "FN": fn}
    result.update(zip(names, counts_to_metrics(tp, fp, tn, fn)))
    return result

# Point estimates on the full common cohort: threshold metrics plus AUROC per model.
point_estimates = {}
for desc, pack in data_by_model.items():
    pe = metrics_on_threshold(pack["y_true"], pack["y_prob"], pack["threshold"])
    pe["AUROC"] = safe_roc_auc(pack["y_true"], pack["y_prob"])
    point_estimates[desc] = pe

# ---------------------------
# Bootstrap: one set of resample indices per iteration, shared by all models.
rng = np.random.default_rng(RNG_SEED)
n = len(common_idx)

metrics_list = ["AUROC","Accuracy","Sensitivity","Specificity","PPV","NPV","F1"]
# boot_store[metric][model_desc] -> array of N_BOOT values, NaN until filled.
boot_store = {m: {desc: np.full(N_BOOT, np.nan, dtype=float) for desc in data_by_model.keys()} 
              for m in metrics_list}

for b in trange(N_BOOT, desc="Bootstrapping (single-cycle for all models)"):
    # Row positions drawn with replacement from [0, n).
    sample_idx = rng.integers(low=0, high=n, size=n, endpoint=False)
    for desc, pack in data_by_model.items():
        y_t = pack["y_true"][sample_idx]
        y_p = pack["y_prob"][sample_idx]
        thr = pack["threshold"]  # threshold stays fixed across resamples
        # AUROC
        boot_store["AUROC"][desc][b] = safe_roc_auc(y_t, y_p)
        # Threshold-based metrics
        m = metrics_on_threshold(y_t, y_p, thr)
        for k in ["Accuracy","Sensitivity","Specificity","PPV","NPV","F1"]:
            boot_store[k][desc][b] = m[k]

def ci95(arr):
    """Percentile 95% interval (2.5th, 97.5th) of the finite values in `arr`.

    Returns (nan, nan) when no finite value is present.
    """
    values = np.asarray(arr, dtype=float)
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return (np.nan, np.nan)
    lower = np.percentile(finite, 2.5)
    upper = np.percentile(finite, 97.5)
    return (lower, upper)

# ---------------------------
# Build the output table: one row per model, point estimate plus bootstrap 95% CI per metric.
rows = []
for s in selected:
    desc = s["model_desc"]
    if desc not in data_by_model:
        continue  # model was selected but had no usable predictions
    pe = point_estimates[desc]
    row = {
        "model_desc": desc,
        "threshold": round(s["threshold"], 6),
    }
    au_point = pe.get("AUROC", np.nan)
    au_low, au_high = ci95(boot_store["AUROC"][desc])
    row["AUROC"] = format3(au_point)
    row["AUROC 95% CI"] = ci_str(au_low, au_high)
    for metric in ["Accuracy","Sensitivity","Specificity","PPV","NPV","F1"]:
        point = pe[metric]
        low, high = ci95(boot_store[metric][desc])
        row[metric] = format3(point)
        row[f"{metric} 95% CI"] = ci_str(low, high)
    rows.append(row)

res_df = pd.DataFrame(rows)
# Fix the column order of the saved table.
cols_order = ["model_desc","threshold",
              "AUROC","AUROC 95% CI",
              "Accuracy","Accuracy 95% CI",
              "Sensitivity","Sensitivity 95% CI",
              "Specificity","Specificity 95% CI",
              "PPV","PPV 95% CI",
              "NPV","NPV 95% CI",
              "F1","F1 95% CI"]
res_df = res_df.reindex(columns=cols_order)

OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
res_df.to_csv(OUT_CSV, index=False)
print("\n=== Test metrics incl. AUROC @ validation Youden-J thresholds (single-cycle bootstrap) ===")
print(res_df.to_string(index=False))
print(f"\n[Saved] {OUT_CSV}")


In [None]:
# -*- coding: utf-8 -*-
"""
Format 'test_metrics_bootstrap.csv' + 'caide_test_metrics_threshold5.csv' to match the layout of Supplementary Table S2.

Output columns:
  Model | Fine-tuning method | Accuracy (95% CI) | Sensitivity (95% CI) | Specificity (95% CI)
        | PPV (95% CI) | NPV (95% CI) | F1-score (95% CI)

- Model order: OpenCLIP, MAE, DINOv3, RETFound-MAE, RETFound-DINOv2
- Fine-tuning rows per Model: Classifier only (linear), Partial finetuning (partial), LoRA tuning (lora)
- Cells show "point" in first line and "(low - high)" in second line. Missing combos remain blank.

Reads:
  ./test_metrics_bootstrap.csv
  ./caide_test_metrics_threshold5.csv

Saves:
  ./supp_table_S2_formatted.csv
"""

import os, warnings, re
warnings.filterwarnings("ignore")
from pathlib import Path
import numpy as np
import pandas as pd

# -----------------------
# Paths
# -----------------------
# Input tables (model metrics and CAIDE metrics) and the formatted output path.
TEST_METRICS_CSV  = Path("./test_metrics_bootstrap.csv")
CAIDE_METRICS_CSV = Path("./caide_test_metrics_threshold5.csv")
OUT_CSV           = Path("./supp_table_S2_formatted.csv")

# Both inputs must exist before anything is read.
assert TEST_METRICS_CSV.exists(), f"Not found: {TEST_METRICS_CSV}"
assert CAIDE_METRICS_CSV.exists(), f"Not found: {CAIDE_METRICS_CSV}"

df = pd.read_csv(TEST_METRICS_CSV)
caide = pd.read_csv(CAIDE_METRICS_CSV)

# -----------------------
# Helpers
# -----------------------
def parse_model_desc(desc: str):
    """
    Split a model_desc into its architecture and fine-tuning keys.

    Examples of accepted descriptions:
      openclip_linear                -> ("openclip", "linear")
      mae_partial_ft_4               -> ("mae", "partial")
      dinov3_lora_rank_4_ft_full     -> ("dinov3", "lora")
      retfound_dinov2_partial_ft_4   -> ("retfound_dinov2", "partial")
      retfound_mae_linear            -> ("retfound_mae", "linear")

    Return (arch_key, ft_key); ft_key is "" when no token follows the arch.
    """
    tokens = desc.split("_")
    # "retfound_dinov2" / "retfound_mae" are two-token architecture names.
    two_token_arch = (
        tokens[0] == "retfound"
        and len(tokens) >= 2
        and tokens[1] in ("dinov2", "mae")
    )
    arch_len = 2 if two_token_arch else 1
    arch_key = "_".join(tokens[:arch_len])
    # The fine-tuning key is the token right after the architecture name.
    ft_key = tokens[arch_len] if len(tokens) > arch_len else ""
    return arch_key, ft_key

# Display names for architecture keys produced by parse_model_desc.
# The original literal listed "retfound" twice; in a dict literal the later
# entry wins, so the table showed "RETFound" instead of the documented
# "RETFound-MAE". Each key now appears exactly once.
ARCH_DISPLAY = {
    "openclip": "OpenCLIP",
    "mae": "MAE",
    "dinov3": "DINOv3",
    "dinov2": "DINOv2",            # in case
    "retfound": "RETFound-MAE",
    "retfound_mae": "RETFound-MAE",   # explicit "retfound_mae_*" descriptions
    "retfound_dinov2": "RETFound-DINOv2",
}

# Display names for fine-tuning keys.
FT_DISPLAY = {
    "linear": "Classifier only",
    "partial": "Partial finetuning",
    "lora": "LoRA tuning",
}

# Row order of the formatted table (architectures, then fine-tuning methods within each).
MODEL_ORDER = ["openclip", "mae", "dinov3", "retfound", "retfound_dinov2"]
FT_ORDER    = ["linear", "partial", "lora"]

# Metric columns carried over into the formatted table.
METRICS = ["Accuracy", "Sensitivity", "Specificity", "PPV", "NPV", "F1"]

def normalize_ci_text(ci_str: str) -> str:
    """
    Convert '[0.680, 0.730]' -> '0.680 - 0.730'
    """
    if not isinstance(ci_str, str) or ci_str.strip() == "":
        return ""
    s = ci_str.strip()
    if s.startswith("[") and s.endswith("]"):
        s = s[1:-1]
    s = re.sub(r"\s*,\s*", " - ", s)
    return s

def cell(point, ci):
    """
    Render one table cell: the point estimate, then '(low - high)' on a
    second line when a confidence interval is available.

    A point that is None, a non-finite float, or blank once stringified is
    treated as missing. Returns '' when both point and interval are missing,
    and the bare point when only the interval is missing.
    """
    if point is None:
        point_text = ""
    elif isinstance(point, float) and not np.isfinite(point):
        point_text = ""
    else:
        point_text = str(point)
        if point_text.strip() == "":
            point_text = ""

    ci_text = normalize_ci_text(ci)
    if ci_text:
        return f"{point_text}\n({ci_text})"
    return point_text

# -----------------------
# Preprocess: attach parsed keys
# -----------------------
# Work on a copy, then add the architecture / fine-tuning keys parsed from
# each row's model_desc (stringified first, so a missing value parses as "nan").
df = df.copy()
parsed_keys = [parse_model_desc(str(desc)) for desc in df["model_desc"]]
df["arch_key"] = [arch for arch, _ in parsed_keys]
df["ft_key"] = [ft for _, ft in parsed_keys]

# -----------------------
# Build formatted table rows
# -----------------------
# One output row per (architecture, fine-tuning method), in MODEL_ORDER x FT_ORDER.
rows = []
for arch_key in MODEL_ORDER:
    for position, ft in enumerate(FT_ORDER):
        matches = df[(df["arch_key"] == arch_key) & (df["ft_key"] == ft)]
        # Only a unique match is used; zero or several matches leave the
        # metric cells blank.
        record = matches.iloc[0] if len(matches) == 1 else None

        row = {
            # The model name is printed only on the first row of its group.
            "Model": ARCH_DISPLAY.get(arch_key, arch_key) if position == 0 else "",
            "Fine-tuning method": FT_DISPLAY.get(ft, ft),
        }
        for m in METRICS:
            if record is None:
                row[f"{m} (95% CI)"] = ""
            else:
                row[f"{m} (95% CI)"] = cell(record.get(m, ""), record.get(f"{m} 95% CI", ""))
        rows.append(row)

# -----------------------
# Append CAIDE row at bottom
# -----------------------
# Only the first record of the CAIDE metrics table is used.
if len(caide) > 0:
    caide_first = caide.iloc[0]
    caide_row = {"Model": "CAIDE score", "Fine-tuning method": ""}
    caide_row.update(
        {
            f"{m} (95% CI)": cell(caide_first.get(m, ""), caide_first.get(f"{m} 95% CI", ""))
            for m in METRICS
        }
    )
    rows.append(caide_row)

# -----------------------
# Create DataFrame & save
# -----------------------
# Fixed column order: the two label columns, then one column per metric.
metric_headers = [f"{m} (95% CI)" for m in METRICS]
out_cols = ["Model", "Fine-tuning method", *metric_headers]
out_df = pd.DataFrame(rows, columns=out_cols)

# Write the table (creating the target directory if needed) and echo it.
OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
out_df.to_csv(OUT_CSV, index=False)
print("\n=== Supplementary Table S2 (formatted) ===")
print(out_df.to_string(index=False))
print(f"\n[Saved] {OUT_CSV}")


In [None]:
# -*- coding: utf-8 -*-
"""
CAIDE-only detailed performance on TEST set (threshold = 5 points)
- Metrics: AUROC + (Accuracy, Sensitivity, Specificity, PPV, NPV, F1)
- 2000× bootstrap for 95% CI (single cycle, CAIDE only)
- Output: ./caide_test_metrics_threshold5.csv
"""

import os, sys, warnings
warnings.filterwarnings("ignore")
os.environ["PYTHONWARNINGS"] = "ignore"

import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import trange
from sklearn.metrics import roc_auc_score

# =========================
# Bootstrap settings: number of resamples and the seed for the generator.
N_BOOT   = 2000
RNG_SEED = 2025

# Cut-off in points; a sample is classified positive when its score >= this value.
# NOTE(review): the cell docstring, the output file name and the result labels
# all say "threshold 5", but this constant is 6 -- confirm which is intended.
CAIDE_THRESHOLD = 6  # points

# Relative paths below resolve against WORK_DIR once the chdir has happened.
WORK_DIR = "/home/hch/dementia"
os.chdir(WORK_DIR)

IN_CSV  = Path("dementia_detection.csv")
OUT_CSV = Path("./caide_test_metrics_threshold5.csv")

# =========================
def add_common_covs(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of `df` with the covariate columns used for CAIDE scoring.

    - SEXINT: kept if already present; otherwise 1 where SEX == "M" and 0
      elsewhere, or NaN when there is no SEX column.
      NOTE(review): a missing SEX value compares unequal to "M" and is
      therefore coded 0 rather than NaN.
    - EXERCISE_STATUS, STDY_AGE, cholesterol_updated, sbp, bmi: coerced to
      numeric (unparseable values become NaN); created as all-NaN when absent.

    The input frame is not modified.
    """
    out = df.copy()

    if "SEXINT" not in out.columns:
        has_sex = "SEX" in out.columns
        out["SEXINT"] = (out["SEX"] == "M").astype(int) if has_sex else np.nan

    numeric_cols = ["EXERCISE_STATUS", "STDY_AGE", "cholesterol_updated", "sbp", "bmi"]
    for name in numeric_cols:
        present = name in out.columns
        out[name] = pd.to_numeric(out[name], errors="coerce") if present else np.nan
    return out

def calc_caide_napoe(frame: pd.DataFrame) -> pd.Series:
    """
    Compute the CAIDE score variant used here (named "CAIDE_noAPOE") per row.

    Points awarded:
      - STDY_AGE: < 47 -> 0, 47..53 -> 3, > 53 -> 4
      - SEXINT == 1 -> 1, otherwise 0
      - education: always 0 (no education input is used)
      - sbp >= 140 -> 2
      - bmi >= 30 -> 2
      - cholesterol_updated * 0.02586 >= 6.5 -> 2
        (the factor is presumably a mg/dL -> mmol/L conversion; confirm units)
      - EXERCISE_STATUS >= 2 -> 0, otherwise 1

    Inputs are coerced to numeric. A missing or unparseable input (or an
    absent column) makes that row's score NaN. Previously such values fell
    through the comparisons and were silently scored as 0 (or 1 for
    exercise), so incomplete rows looked like valid low scores. Rows with
    complete inputs score exactly as before.

    Returns a float Series aligned to `frame.index`.
    """
    x = frame.copy()
    for col in ["STDY_AGE", "SEXINT", "sbp", "bmi", "cholesterol_updated", "EXERCISE_STATUS"]:
        if col in x.columns:
            x[col] = pd.to_numeric(x[col], errors="coerce")
        else:
            x[col] = np.nan

    def _points(values, awarded, if_true, if_false):
        """Score one item, keeping NaN wherever the input itself is missing."""
        return np.where(values.isna(), np.nan, np.where(awarded, if_true, if_false))

    age = x["STDY_AGE"]
    # np.select leaves NaN ages unmatched, so they take the NaN default.
    age_pts = np.select([age < 47, (47 <= age) & (age <= 53), age > 53], [0, 3, 4], default=np.nan)
    sex_pts = _points(x["SEXINT"], x["SEXINT"] == 1, 1, 0)
    edu_pts = np.zeros(len(x), dtype=float)
    sbp_pts = _points(x["sbp"], x["sbp"] >= 140, 2, 0)
    bmi_pts = _points(x["bmi"], x["bmi"] >= 30, 2, 0)
    chol_mmol = x["cholesterol_updated"] * 0.02586
    chol_pts = _points(chol_mmol, chol_mmol >= 6.5, 2, 0)
    ex = x["EXERCISE_STATUS"]
    pa_pts = _points(ex, ex >= 2, 0, 1)

    score = age_pts + sex_pts + edu_pts + sbp_pts + bmi_pts + chol_pts + pa_pts
    return pd.Series(score, index=frame.index, name="CAIDE_noAPOE")

def caide_valid_mask(frame: pd.DataFrame) -> pd.Series:
    """Flag rows in which none of the six CAIDE input columns is missing."""
    required = ["STDY_AGE", "SEXINT", "sbp", "bmi", "cholesterol_updated", "EXERCISE_STATUS"]
    has_gap = frame[required].isna().any(axis=1)
    return ~has_gap

def safe_div(num, den):
    """Divide as floats; NaN when the denominator is zero or not finite."""
    numerator, denominator = float(num), float(den)
    if np.isfinite(denominator) and denominator != 0:
        return numerator / denominator
    return np.nan

def counts_to_metrics(tp, fp, tn, fn):
    """
    Turn confusion-matrix counts into
    (accuracy, sensitivity, specificity, PPV, NPV, F1).

    Each ratio goes through safe_div, so a zero denominator gives NaN.
    """
    total = tp + fp + tn + fn
    return (
        safe_div(tp + tn, total),             # accuracy
        safe_div(tp, tp + fn),                # sensitivity (TPR)
        safe_div(tn, tn + fp),                # specificity (TNR)
        safe_div(tp, tp + fp),                # PPV (precision)
        safe_div(tn, tn + fn),                # NPV
        safe_div(2*tp, 2*tp + fp + fn),       # F1
    )

def safe_auc(y_true: np.ndarray, y_prob: np.ndarray):
    """
    AUROC of `y_prob` against integer-cast `y_true`.

    Returns NaN when fewer than two distinct labels are present, or when
    roc_auc_score raises.
    """
    labels = np.asarray(y_true).astype(int)
    if len(np.unique(labels)) >= 2:
        try:
            auc = roc_auc_score(labels, y_prob)
        except Exception:
            return np.nan
        else:
            return float(auc)
    return np.nan

def format3(x):
    """Format with three decimals; '' for None or a non-finite float."""
    if x is None:
        return ""
    # `not isfinite` already covers NaN as well as +/-inf.
    if isinstance(x, float) and not np.isfinite(x):
        return ""
    return f"{x:.3f}"

def ci_str(lo, hi):
    """Render '[lo, hi]' with three decimals; '' if either bound is None or a non-finite float."""
    for bound in (lo, hi):
        if bound is None:
            return ""
        if isinstance(bound, float) and not np.isfinite(bound):
            return ""
    return f"[{lo:.3f}, {hi:.3f}]"

def metrics_from_arrays(y_true: np.ndarray, caide_score: np.ndarray, thr: float):
    """
    Evaluate the score as a binary classifier at cut-off `thr`.

    A sample is predicted positive when caide_score >= thr. Returns a dict
    with the confusion counts (TP, FP, TN, FN), the AUROC of the raw score,
    and Accuracy, Sensitivity, Specificity, PPV, NPV and F1 at the cut-off.
    """
    labels = np.asarray(y_true).astype(int)
    predicted = (caide_score >= thr).astype(int)

    is_pos, is_neg = labels == 1, labels == 0
    said_pos, said_neg = predicted == 1, predicted == 0
    counts = {
        "TP": int((is_pos & said_pos).sum()),
        "FP": int((is_neg & said_pos).sum()),
        "TN": int((is_neg & said_neg).sum()),
        "FN": int((is_pos & said_neg).sum()),
    }

    rate_names = ("Accuracy", "Sensitivity", "Specificity", "PPV", "NPV", "F1")
    rates = counts_to_metrics(counts["TP"], counts["FP"], counts["TN"], counts["FN"])
    # AUROC is threshold-free, so it is computed from the raw score.
    auc = safe_auc(labels, caide_score)
    return {**counts, "AUROC": auc, **dict(zip(rate_names, rates))}

# =========================
# Load the detection table and keep the rows flagged as the test split.
# Explicit exceptions replace `assert` (stripped under `python -O`) and the
# former bare "error" messages, so a failure names what is actually missing.
if not IN_CSV.exists():
    raise FileNotFoundError(f"Input not found: {IN_CSV}")
det_all = pd.read_csv(IN_CSV, low_memory=False)

if "test" not in det_all.columns:
    raise ValueError(f"Column 'test' not found in {IN_CSV}; cannot select the test split.")
if "label" not in det_all.columns:
    raise ValueError(f"Column 'label' not found in {IN_CSV}; cannot evaluate against ground truth.")
det_test = det_all[det_all["test"] == True].reset_index(drop=True)

# Derive covariates, score every test row, and find rows with complete inputs.
det_test = add_common_covs(det_test)
caide_score = calc_caide_napoe(det_test)
valid_mask  = caide_valid_mask(det_test)
label_valid = det_test["label"].notna()

# Evaluate only rows that have both a label and all CAIDE inputs.
mask = valid_mask & label_valid
det_used = det_test.loc[mask].copy()

y_true = det_used["label"].astype(int).to_numpy()
s_caide = caide_score.loc[mask].astype(float).to_numpy()
n = len(det_used)
if n == 0:
    raise RuntimeError(
        "No usable test rows: every test row lacks a label or at least one CAIDE input."
    )

print(f"CAIDE test-set evaluation (threshold = {CAIDE_THRESHOLD})")
print(f"- usable samples: n = {n}")

# =========================
# Point estimates on the full usable test set.
point = metrics_from_arrays(y_true, s_caide, CAIDE_THRESHOLD)

# =========================
# Bootstrap: resample rows with replacement and recompute every metric.
rng = np.random.default_rng(RNG_SEED)

# One NaN-initialised vector of length N_BOOT per metric.
boot = {
    name: np.full(N_BOOT, np.nan, dtype=float)
    for name in ("AUROC", "Accuracy", "Sensitivity", "Specificity", "PPV", "NPV", "F1")
}

for b in trange(N_BOOT, desc="Bootstrapping (CAIDE only)"):
    resample = rng.integers(0, n, size=n)
    replicate = metrics_from_arrays(y_true[resample], s_caide[resample], CAIDE_THRESHOLD)
    for name, draws in boot.items():
        draws[b] = replicate[name]

def ci95(arr):
    """Return the (2.5th, 97.5th) percentiles of the finite values, or (NaN, NaN) if there are none."""
    values = np.asarray(arr, dtype=float)
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return (np.nan, np.nan)
    lower = np.percentile(finite, 2.5)
    upper = np.percentile(finite, 97.5)
    return (lower, upper)

# =========================
# Assemble the single result row: counts, then each metric with its 95% CI.
# The threshold shown in the label and in the banner is taken from
# CAIDE_THRESHOLD so the report always states the cut-off that was actually
# applied (the text used to be hard-coded to "5").
metric_names = ["AUROC", "Accuracy", "Sensitivity", "Specificity", "PPV", "NPV", "F1"]

rows = []
row = {
    "Metric": f"CAIDE (threshold={CAIDE_THRESHOLD}) - TEST",
    "n": int(n),
    "TP": int(point["TP"]),
    "FP": int(point["FP"]),
    "TN": int(point["TN"]),
    "FN": int(point["FN"]),
}
for metric in metric_names:
    lo, hi = ci95(boot[metric])
    row[metric] = format3(point[metric])
    row[f"{metric} 95% CI"] = ci_str(lo, hi)
rows.append(row)

# Column order: identifiers and counts, then (value, CI) pairs per metric.
res_columns = ["Metric", "n", "TP", "FP", "TN", "FN"]
for metric in metric_names:
    res_columns += [metric, f"{metric} 95% CI"]
res = pd.DataFrame(rows, columns=res_columns)

OUT_CSV.parent.mkdir(parents=True, exist_ok=True)
res.to_csv(OUT_CSV, index=False)
print(f"\n=== CAIDE detailed performance (TEST, thr={CAIDE_THRESHOLD}) ===")
print(res.to_string(index=False))
print(f"\n[Saved] {OUT_CSV}")
