In [None]:
# %% [markdown]
# # Grad-CAM export — ALL ARCH × FT combos (use thresholds from youden_thresholds.csv)
# - Replicates the exact pipeline you ran for "retfound_partial_ft_4"
# - For each (arch, ft) combo:
#     * Read BEST_THRESH from youden_thresholds.csv (split='valid', ft_blks mapping)
#     * Load model checkpoint from /home/kjw/Projects/dementia/ckpts/{MODEL_DESC}/ckpt.pth.tar
#     * Read INFER CSVs from /home/hch/dementia/infer_out/{MODEL_DESC}
#     * Filter TPs by BEST_THRESH (TEST, FUTURE)
#     * Save origin_images.pth & saliency_map.pth under ./tensors/{MODEL_DESC}_test|_future
#     * Save representative Grad-CAM (L/R heatmap & overlay) under ./rep_gradcam_lr/{MODEL_DESC}
#     * Compute Mask metrics & save RESULTS_*.csv
#     * Compute ATTN_summary_TEST/FUTURE.csv
# - Note: if multiple lora_rank exist for a (arch,lora,ft_blks="full"), pick the one with highest AUC.

# %%
import os, sys, math, json, random, warnings
warnings.filterwarnings("ignore")
from pathlib import Path
from dataclasses import dataclass

import numpy as np
import pandas as pd

from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose, Resize, ToTensor
from tqdm import tqdm
from sklearn.metrics import roc_curve, roc_auc_score

# ----- project-local imports (same as trainer.py / gradcam_dementia.py)
# Project root; the notebook chdir()s here (below) so that the project-local
# `data` and `models` packages can be imported and relative paths resolve.
WORK_DIR   = "/home/hch/dementia"   # for importing local modules
# Master metadata CSV; normalize_master() reads pngfilename / eye_orientation from it.
MASTER_CSV = "/home/hch/opportunistic/20250518_master4_merged2.csv"  # for eye_orientation merge

# segmentation roots (same as your code)
# Each directory is searched by image filename (see load_mask_by_filename /
# load_optic_disc_cup_union).
ARTERY_DIR = Path("/home/hch/opportunistic/AutoMorph_Data/Results/M2/artery_vein/artery_binary_process")
VEIN_DIR   = Path("/home/hch/opportunistic/AutoMorph_Data/Results/M2/artery_vein/vein_binary_process")
BINVES_DIR = Path("/home/hch/opportunistic/AutoMorph_Data/Results/M2/binary_vessel/binary_process")  # optional
ODC_RAW_DIR= Path("/home/hch/opportunistic/AutoMorph_Data/Results/M2/optic_disc_cup/raw")

# image size & batch
# IMG_SIZE is the square side (pixels) used for model input, CAM resizing and masks.
IMG_SIZE   = 448
# NOTE(review): BATCH is not referenced in the visible part of this notebook.
BATCH      = 8

# make paths consistent with existing codebase
# Side effect: every relative path below (./tensors, ./rep_gradcam_lr, ...) is
# resolved against WORK_DIR from here on.
os.chdir(WORK_DIR)

from data import DementiaDetectionDataset, DementiaPredictionDataset
from models.encoder import build_model

# -----------------------------------------------------------------------------
# 0) Config: all combos
# -----------------------------------------------------------------------------
# Backbone architectures and fine-tuning strategies to export; the main loop
# iterates over the cartesian product ARCHES x FTS.
ARCHES = ['retfound', 'mae', 'openclip']
FTS    = ['partial', 'lora']


def ft_blks_for(ft: str):
    """Return the `ft_blks` setting used for a fine-tuning strategy.

    'partial' maps to 4, 'lora' maps to 'full'; any other strategy has no
    block setting and yields None.
    """
    blocks_by_strategy = {'partial': 4, 'lora': 'full'}
    return blocks_by_strategy.get(ft)

# -----------------------------------------------------------------------------
# 1) Utilities: MODEL_DESC, CKPT/INFER paths, robust CSV threshold lookup
# -----------------------------------------------------------------------------
def _int_like_str(value) -> str:
    """Render `value` as text, dropping a spurious '.0' from whole numbers.

    pandas reads an integer column that contains missing values as float, so a
    LoRA rank of 4 arrives as 4.0; str() would turn that into '4.0' and the
    derived directory name would not match the one written at training time.
    Non-numeric values (e.g. 'full') and non-finite numbers are returned via
    plain str().
    """
    if isinstance(value, bool):
        return str(value)
    try:
        as_float = float(value)
    except (TypeError, ValueError):
        return str(value)
    if math.isfinite(as_float) and as_float == int(as_float):
        return str(int(as_float))
    return str(value)


def model_desc_from_csv_row(row: pd.Series) -> str:
    """Build the MODEL_DESC directory name from a youden_thresholds.csv row.

    Format:
      * partial : "{arch}_partial_ft_{ft_blks}"
      * lora    : "{arch}_lora_rank_{lora_rank}_ft_{ft_blks}"
      * other   : "{arch}_{ft}"

    Whole-number values stored as float (4.0) are written as integers ("4").
    A missing `lora_rank` key yields an empty rank segment, as before.
    """
    arch = str(row['arch'])
    ft   = str(row['ft'])
    desc = f"{arch}_{ft}"
    if ft == 'partial':
        desc += f"_ft_{_int_like_str(row['ft_blks'])}"
    elif ft == 'lora':
        lrk = _int_like_str(row.get('lora_rank', ''))
        desc += f"_rank_{lrk}_ft_{_int_like_str(row['ft_blks'])}"
    return desc

def model_desc_from_spec(arch: str, ft: str, lora_rank: str | int | None) -> str:
    desc = f"{arch}_{ft}"
    if ft == 'partial':
        desc += f"_ft_4"
    elif ft == 'lora':
        lrk = str(lora_rank if lora_rank is not None else "4")  # default guess 4
        desc += f"_rank_{lrk}_ft_full"
    return desc

def ckpt_path(desc: str) -> str:
    """Return the checkpoint file path for the model directory `desc`."""
    return "/home/kjw/Projects/dementia/ckpts/" + f"{desc}" + "/ckpt.pth.tar"

def infer_dir(desc: str) -> str:
    """Return the directory holding the inference CSVs for model `desc`."""
    return "/home/hch/dementia/infer_out/" + f"{desc}"

def tensors_dirs(desc: str) -> tuple[str, str]:
    """Create (if needed) and return the TEST and FUTURE tensor directories.

    Both paths are relative to the current working directory:
    ./tensors/{desc}_test and ./tensors/{desc}_future.
    """
    out_test, out_future = (f"./tensors/{desc}_{split}" for split in ("test", "future"))
    for directory in (out_test, out_future):
        os.makedirs(directory, exist_ok=True)
    return out_test, out_future

# -----------------------------------------------------------------------------
# 2) Read thresholds: choose best match per (arch, ft, ft_blks) — for lora pick the row with max AUC if multiple ranks
# -----------------------------------------------------------------------------
# Locate the thresholds file: first relative to the current working directory,
# then at a fixed absolute fallback.
THRESH_CSV = Path("./youden_thresholds.csv")  # saved in root_dir by trainer.py
if not THRESH_CSV.exists():
    # If the file was written to a different location, adjust this path.
    alt = Path("/home/hch/dementia/youden_thresholds.csv")
    if alt.exists():
        THRESH_CSV = alt
assert THRESH_CSV.exists(), f"youden_thresholds.csv not found: {THRESH_CSV}"

thr_df = pd.read_csv(THRESH_CSV)
# normalize columns to str for matching
# (only these four columns are converted; other columns such as lora_rank keep
# whatever dtype pandas inferred)
for col in ["arch","ft","ft_blks","split"]:
    if col in thr_df.columns:
        thr_df[col] = thr_df[col].astype(str)

# Filter only 'valid' split rows
thr_df = thr_df.query("split == 'valid'").copy()

def pick_threshold_row(arch: str, ft: str) -> pd.Series | None:
    """Select the threshold row for an (arch, ft) combination.

    Rows of the module-level `thr_df` are filtered by arch and ft, then by the
    ft_blks value that `ft_blks_for(ft)` prescribes (4 for 'partial', 'full'
    for 'lora'; no ft_blks filter otherwise). When several rows remain (e.g.
    multiple LoRA ranks) the one with the highest 'auc' is returned, if that
    column exists. Returns None when nothing matches.
    """
    candidates = thr_df.query("arch == @arch and ft == @ft")
    required_blks = ft_blks_for(ft)
    if required_blks is not None:
        # ft_blks was normalised to str when the CSV was loaded.
        required_blks = str(required_blks)
        candidates = candidates.query("ft_blks == @required_blks")
    if candidates.empty:
        return None
    if 'auc' in candidates.columns:
        candidates = candidates.sort_values('auc', ascending=False)
    return candidates.iloc[0]

# -----------------------------------------------------------------------------
# 3) Model build & load
# -----------------------------------------------------------------------------
def build_model_for(arch: str, ft: str, img_size: int = IMG_SIZE, enable_amp: bool = True, lora_rank=None):
    """Assemble an args object and build the model via the project's `build_model`.

    The args object carries arch, ft and img_size; ft_blks is added when
    `ft_blks_for(ft)` defines one, and lora_rank is added for ft == 'lora'
    (falling back to 4 when the rank is None, the string 'nan', or cannot be
    converted to int). `enable_amp` is accepted but not used here.
    """
    args = type("Args", (), {})()
    args.arch, args.ft, args.img_size = arch, ft, img_size

    blocks = ft_blks_for(ft)
    if blocks is not None:
        args.ft_blks = blocks

    # A string block count for 'partial' is coerced to int.
    if ft == 'partial' and isinstance(getattr(args, 'ft_blks', None), str):
        args.ft_blks = int(args.ft_blks)

    if ft == 'lora':
        try:
            has_rank = lora_rank is not None and str(lora_rank) != 'nan'
            rank = int(lora_rank) if has_rank else 4
        except Exception:
            rank = 4
        args.lora_rank = rank

    return build_model(args)

def safe_load_state_dict(model, sd):
    """Load checkpoint weights into `model`; return True only if weights were applied.

    `sd` may be a raw state dict or a checkpoint dict wrapping one under
    "state_dict", "model", "net", "module" or "ema_state". Wrapped candidates
    are tried first, then `sd` itself. For each candidate the keys are used as
    stored and, if that gives no overlap with the model, with a leading
    "module." prefix removed (as written by DataParallel-style wrappers).

    A candidate is only loaded when at least one of its keys matches a key of
    `model.state_dict()`. Without this check `load_state_dict(strict=False)`
    "succeeds" on any dict, including the outer checkpoint wrapper, so a
    failed load of the real weights used to be reported as success.

    Returns:
        True if some candidate was loaded, False otherwise (including when
        `sd` is not a dict).
    """
    if not isinstance(sd, dict):
        return False

    wrapper_keys = ("state_dict", "model", "net", "module", "ema_state")
    candidates = [sd[key] for key in wrapper_keys if isinstance(sd.get(key), dict)]
    candidates.append(sd)

    model_keys = set(model.state_dict().keys())
    prefix = "module."

    for candidate in candidates:
        stripped = {
            (name[len(prefix):] if isinstance(name, str) and name.startswith(prefix) else name): value
            for name, value in candidate.items()
        }
        variants = [candidate]
        if stripped.keys() != candidate.keys():
            variants.append(stripped)

        for weights in variants:
            if not model_keys.intersection(weights.keys()):
                continue  # nothing here belongs to this model
            try:
                result = model.load_state_dict(weights, strict=False)
            except Exception as exc:
                # e.g. shape mismatch; report and try the next candidate
                print(f"[WARN] load_state_dict failed: {exc}")
                continue
            missing = list(getattr(result, "missing_keys", []))
            unexpected = list(getattr(result, "unexpected_keys", []))
            if missing or unexpected:
                print(f"[WARN] state_dict loaded with {len(missing)} missing and "
                      f"{len(unexpected)} unexpected keys")
            return True
    return False

# -----------------------------------------------------------------------------
# 4) ViT attention patch & Grad-CAM rollout (same as your code)
# -----------------------------------------------------------------------------
def find_vit_backbone(module: nn.Module):
    """Depth-first search for the first (sub)module exposing a `blocks` attribute.

    Returns `module` itself if it has `blocks`; otherwise the first match found
    among its children (in registration order), or None when there is none.
    """
    if hasattr(module, "blocks"):
        return module
    matches = (find_vit_backbone(child) for child in module.children())
    return next((hit for hit in matches if hit is not None), None)

def patch_attention_to_capture(vit_module):
    """Replace each block's attention forward with one that records attention maps.

    For every `blk` in `vit_module.blocks`, `blk.attn.forward` is overwritten
    (on the instance) with a re-implementation of scaled dot-product attention
    that stores the post-softmax probabilities in `blk.attn._last_attn` and asks
    autograd to keep their gradient, so that a later backward pass exposes
    `_last_attn.grad`. Modules already carrying `_patched_capture` are skipped,
    which makes the call idempotent.

    The attention module is expected to provide `qkv`, `num_heads`, `attn_drop`,
    `proj` and `proj_drop`; `q_norm` / `k_norm` are used when present.
    """
    for blk in vit_module.blocks:
        attn = blk.attn
        if getattr(attn, "_patched_capture", False):
            continue
        # `_attn=attn` binds the current module as a default argument so each
        # closure keeps its own attention module instead of the loop's last one.
        # NOTE(review): the replacement only accepts `x`; if the enclosing block
        # passes extra arguments (e.g. an attention mask) this will raise - confirm
        # against the ViT implementation in use.
        def wrapped_forward(x, _attn=attn):
            B, N, C = x.shape
            qkv = _attn.qkv(x)
            head_dim = C // _attn.num_heads
            # -> [3, B, heads, N, head_dim]
            qkv = qkv.reshape(B, N, 3, _attn.num_heads, head_dim).permute(2,0,3,1,4)
            q, k, v = qkv[0], qkv[1], qkv[2]
            # NOTE(review): q/k are [B, heads, N, head_dim] here, so the transpose
            # makes the norm act over the token axis N rather than head_dim. This is
            # harmless for identity norms but looks wrong for a real per-head
            # LayerNorm - confirm.
            if hasattr(_attn, "q_norm") and _attn.q_norm is not None:
                q = _attn.q_norm(q.transpose(2,3)).transpose(2,3)
            if hasattr(_attn, "k_norm") and _attn.k_norm is not None:
                k = _attn.k_norm(k.transpose(2,3)).transpose(2,3)
            scale = (head_dim ** -0.5)
            q = q * scale
            attn_scores = torch.matmul(q, k.transpose(-2, -1))   # [B,H,N,N]
            attn_probs  = attn_scores.softmax(dim=-1)
            # Keep a handle on the probabilities (before dropout) for Grad-CAM.
            _attn._last_attn = attn_probs
            try:
                _attn._last_attn.retain_grad()
            except Exception:
                # Best effort: presumably fails when the tensor does not require
                # grad (e.g. under no_grad); the map is still recorded.
                pass
            attn_probs = _attn.attn_drop(attn_probs)
            x_out = torch.matmul(attn_probs, v)
            # merge heads back: [B, heads, N, head_dim] -> [B, N, C]
            x_out = x_out.transpose(1,2).reshape(B, N, C)
            x_out = _attn.proj(x_out)
            x_out = _attn.proj_drop(x_out)
            return x_out
        attn.forward = wrapped_forward
        attn._patched_capture = True


class ViTGradCamRollout:
    """Gradient-weighted attention rollout for a ViT patched by `patch_attention_to_capture`.

    The wrapped module must expose `blocks`, each with an `attn` whose
    `_last_attn` holds the attention probabilities ([B, heads, T, T]) of the
    most recent forward pass.
    """

    def __init__(self, vit_module):
        self.vit = vit_module

    def _normalize_attn(self, M):
        """Row-normalise M, add the identity (residual path), row-normalise again."""
        M = M / (M.sum(dim=-1, keepdim=True) + 1e-6)
        T = M.size(-1)
        M = M + torch.eye(T, device=M.device, dtype=M.dtype)
        M = M / (M.sum(dim=-1, keepdim=True) + 1e-6)
        return M

    def compute_cam(self, model, img_tensor, class_idx=1):
        """Return a [gh, gw] numpy CAM in [0, 1] for the first sample of `img_tensor`.

        Runs a forward pass, back-propagates the selected score, and combines
        each block's attention with its gradient into a rollout matrix whose
        CLS row (token 0 -> tokens 1..) becomes the patch map.

        Score selection: `out[0]` for 1-D output, `out[0, 0]` for a single
        logit, `out[0, class_idx]` for two or more logits, otherwise the first
        element of the flattened output.

        Raises:
            RuntimeError: if no block recorded an attention map.
        """
        model.zero_grad(set_to_none=True)
        out = model(img_tensor)
        if isinstance(out, dict):
            out = out.get("logits", out.get("pred", out))
        if not torch.is_tensor(out):
            out = torch.as_tensor(out)
        if out.ndim == 1:
            score = out[0]
        elif out.ndim == 2 and out.shape[1] == 1:
            score = out[0, 0]
        elif out.ndim == 2 and out.shape[1] >= 2:
            score = out[0, class_idx]
        else:
            score = out.reshape(-1)[0]
        score.backward(retain_graph=False)

        attn_list, grad_list = [], []
        for blk in self.vit.blocks:
            A = getattr(blk.attn, "_last_attn", None)
            if A is None:
                continue
            attn_list.append(A)
            grad_list.append(getattr(A, "grad", None))

        if len(attn_list) == 0:
            raise RuntimeError("No attention captured; check patching.")

        if hasattr(self.vit, "patch_embed") and hasattr(self.vit.patch_embed, "grid_size"):
            gh, gw = self.vit.patch_embed.grid_size
        else:
            # Fall back to a square grid inferred from the token count (minus CLS).
            Ttok = attn_list[0].shape[-1]
            n = Ttok - 1
            s = int(math.isqrt(n))
            gh = gw = s

        rollout = None
        for A, G in zip(attn_list, grad_list):
            # A[0] is [heads, T, T]; average over heads only. The previous code
            # applied .mean(dim=0) twice, which also collapsed the query axis to
            # a [T] vector and only ran because of broadcasting against eye(T).
            A = A[0].mean(dim=0)                      # [T,T]
            if G is not None:
                G = G[0].mean(dim=0)                  # [T,T]
                M = (A * G).clamp(min=0)
            else:
                M = A.clamp(min=0)
            M = self._normalize_attn(M)
            # NOTE(review): attention rollout is usually accumulated as
            # M @ rollout (later layers on the left); the original order is kept
            # here - confirm which is intended.
            rollout = M if rollout is None else torch.matmul(rollout, M)

        cls_to_patches = rollout[0, 1:].detach().cpu().numpy()
        cam = cls_to_patches.reshape(gh, gw)
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-6)
        return cam

# -----------------------------------------------------------------------------
# 5) Dataset + meta join & TP filtering
# -----------------------------------------------------------------------------
def try_get_df_row(ds, idx):
    """Fetch the metadata record at position `idx` from a dataset object.

    The attributes "df", "meta", "ann", "items", "data" are probed in that
    order. An attribute with `.iloc` is indexed positionally; a list/tuple is
    indexed directly when long enough. Returns None if nothing usable is found.
    """
    probe_order = ("df", "meta", "ann", "items", "data")
    containers = (getattr(ds, name) for name in probe_order if hasattr(ds, name))
    for container in containers:
        if hasattr(container, "iloc"):
            return container.iloc[idx]
        if isinstance(container, (list, tuple)) and idx < len(container):
            return container[idx]
    return None

def extract_path(row):
    """Return the image path stored in a metadata record, or None.

    The first of "pngfilename", "img_path", "filepath", "path" that is present
    in `row` wins. `row` may be None.
    """
    if row is None:
        return None
    for key in ("pngfilename", "img_path", "filepath", "path"):
        if key in row:
            return row[key]
    return None

def attach_imgpath(df_csv, ds):
    """Return a copy of `df_csv` with an "img_path" column looked up from `ds`.

    Each row's "idx" is used to fetch the dataset's metadata record, from which
    the image path is extracted (None when it cannot be resolved).
    """
    def _with_path(record):
        meta = try_get_df_row(ds, int(record["idx"]))
        return {**record.to_dict(), "img_path": extract_path(meta)}

    return pd.DataFrame([_with_path(record) for _, record in df_csv.iterrows()])

def normalize_master(master_csv: str = MASTER_CSV) -> pd.DataFrame:
    """Load the master CSV and derive the image path used as a join key.

    Every '.' in `pngfilename` becomes '_', then '_png' is turned back into
    '.png'; `img_path` is the M0 images directory plus that filename. Returns
    the columns img_path, eye_orientation and pngfilename.
    """
    master = pd.read_csv(master_csv)
    cleaned_names = (
        master['pngfilename']
        .str.replace('.', '_', regex=False)
        .str.replace('_png', '.png', regex=False)
    )
    master['pngfilename'] = cleaned_names
    master['img_path'] = '/home/hch/opportunistic/AutoMorph_Data/Results/M0/images/' + master['pngfilename']
    return master[['img_path','eye_orientation','pngfilename']]

# -----------------------------------------------------------------------------
# 6) Export tensors (origin_images.pth & saliency_map.pth)
# -----------------------------------------------------------------------------
# Resize to IMG_SIZE x IMG_SIZE and convert to a float tensor; no further
# normalisation is applied in this pipeline.
preprocess = Compose([Resize((IMG_SIZE, IMG_SIZE)), ToTensor()])
# Use the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_tensor_01(img_path):
    """Open an image as RGB and run it through the module-level `preprocess`."""
    rgb_image = Image.open(img_path).convert("RGB")
    return preprocess(rgb_image)  # [3,H,W], 0..1

def ensure_path(p):
    """Resolve an image path to an absolute path string, or None if unusable.

    Absolute paths are returned unchanged; relative paths are joined onto
    WORK_DIR. None, empty strings and values that are not path-like (for
    example a pandas NaN, which is a truthy float) yield None instead of
    raising or resolving to WORK_DIR itself.
    """
    if p is None:
        return None
    try:
        path = os.fspath(p)
    except TypeError:
        # not str / os.PathLike (e.g. float NaN from a missing CSV value)
        return None
    if not isinstance(path, str) or not path:
        return None
    if os.path.isabs(path):
        return path
    return os.path.join(WORK_DIR, path)

def compute_cam_tensor(rollout, model, img_path):
    """Compute the input tensor and full-resolution CAM for one image.

    Returns (image [3,H,W] on CPU, cam [1,H,W] in [0,1]), or (None, None) when
    the path cannot be resolved or the file does not exist. The CAM is
    quantised to 8 bits and upsampled to IMG_SIZE with cubic interpolation.
    """
    resolved = ensure_path(img_path)
    if not (resolved and os.path.exists(resolved)):
        return None, None

    batch = load_tensor_01(resolved).unsqueeze(0).to(device, non_blocking=True)  # [1,3,H,W]
    cam_grid = rollout.compute_cam(model, batch, class_idx=1)                    # [h',w'] in [0,1]
    cam_u8 = (cam_grid*255).astype(np.uint8)
    cam_resized = cv2.resize(cam_u8, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_CUBIC)
    cam_tensor = torch.from_numpy(cam_resized).float().div(255.0)                # [H,W] in [0,1]
    return batch.squeeze(0).cpu(), cam_tensor.unsqueeze(0)                       # [3,H,W], [1,H,W]

def export_tp_pack(tp_df: pd.DataFrame, out_dir: str, rollout: ViTGradCamRollout, model: nn.Module):
    """Compute and save images, CAMs and metadata for every true-positive row.

    Writes origin_images.pth ([N,3,H,W]), saliency_map.pth ([N,1,H,W]) and
    meta.json into `out_dir`. Rows whose image cannot be loaded are skipped, so
    N may be smaller than len(tp_df). Returns True when at least one image was
    exported, False otherwise.
    """
    os.makedirs(out_dir, exist_ok=True)
    images, cams, metas = [], [], []

    progress = tqdm(tp_df.iterrows(), total=len(tp_df), desc=f"Export {Path(out_dir).name}")
    for _, row in progress:
        img_path = row["img_path"]
        image, cam = compute_cam_tensor(rollout, model, img_path)
        if image is None:
            continue
        images.append(image.unsqueeze(0))  # [1,3,H,W]
        cams.append(cam.unsqueeze(0))      # [1,1,H,W]
        metas.append({
            "idx": int(row["idx"]),
            "img_path": ensure_path(img_path),
            "eye_orientation": row.get("eye_orientation", None),
            "pred": float(row["pred"]),
            "label_or_event": int(row.get("label", row.get("event", -1)))
        })

    if not images:
        print("No TP images to export:", out_dir)
        return False

    stacked_images = torch.cat(images, dim=0)   # [N,3,H,W]
    stacked_cams   = torch.cat(cams, dim=0)     # [N,1,H,W]
    torch.save(stacked_images.contiguous(), os.path.join(out_dir, "origin_images.pth"))
    torch.save(stacked_cams.contiguous(),   os.path.join(out_dir, "saliency_map.pth"))
    with open(os.path.join(out_dir, "meta.json"), "w") as handle:
        json.dump(metas, handle, indent=2)
    print(f"[✓] Saved: {out_dir}/origin_images.pth, saliency_map.pth  (N={len(metas)})")
    return True

# -----------------------------------------------------------------------------
# 7) Rep Grad-CAM L/R overlay (same as your code)
# -----------------------------------------------------------------------------
import matplotlib.pyplot as plt

# Output root for representative Grad-CAM images; one sub-directory per model
# description is created underneath by the save helpers.
REP_ROOT = Path("./rep_gradcam_lr")
REP_ROOT.mkdir(exist_ok=True)

def _to_numpy_img(t3chw: torch.Tensor):
    return t3chw.detach().cpu().clamp(0,1).permute(1,2,0).numpy()

def _overlay_cam(rgb_hw3, cam_hw, alpha=0.45, cmap_name="jet"):
    """Alpha-blend a colour-mapped CAM over an RGB image; result clipped to [0, 1]."""
    heat_rgb = plt.get_cmap(cmap_name)(cam_hw)[..., :3]   # drop the alpha channel
    blended = (1 - alpha) * rgb_hw3 + alpha * heat_rgb
    return np.clip(blended, 0, 1)

def _norm_eye_label(x):
    if x is None: return None
    s = str(x).strip().lower()
    if s in ["r","right","rt","오","오른","오른쪽","1","true","od","o.d"]:
        return "R"
    if s in ["l","left","lt","왼","왼쪽","0","false","os","o.s"]:
        return "L"
    try:
        return "R" if int(float(s)) == 1 else "L"
    except: return None

def _load_pack(src_dir: str):
    p = Path(src_dir)
    ori = torch.load(p/"origin_images.pth")
    cam = torch.load(p/"saliency_map.pth")
    with open(p/"meta.json","r") as f:
        metas = json.load(f)
    return ori, cam, metas

def _indices_by_eye(metas, eye_code):
    """Return positions in `metas` whose normalised eye_orientation equals `eye_code`."""
    matches = []
    for position, meta in enumerate(metas):
        if _norm_eye_label(meta.get("eye_orientation")) == eye_code:
            matches.append(position)
    return matches

def _avg_heatmap(cam: torch.Tensor, indices):
    if not indices: return None
    arr = cam[indices,0].detach().cpu().numpy()
    mean = arr.mean(axis=0)
    mn,mx = mean.min(), mean.max()
    if mx > mn: mean = (mean - mn)/(mx - mn)
    return mean

def _random_from(indices, seed=None):
    if not indices: return None
    rng = random.Random(seed)
    return rng.choice(indices)

def _save_img(array_hw3, fname, out_dir: Path):
    """Write an RGB array to out_dir/fname, creating the directory if needed."""
    out_dir.mkdir(exist_ok=True, parents=True)
    plt.imsave(out_dir / fname, array_hw3)
    print(f"[✓] Saved: {out_dir/fname}")

def _save_heatmap(array_hw, fname, out_dir: Path, cmap="jet"):
    """Write a 2-D map to out_dir/fname through colormap `cmap`, creating the directory if needed."""
    out_dir.mkdir(exist_ok=True, parents=True)
    plt.imsave(out_dir / fname, array_hw, cmap=cmap)
    print(f"[✓] Saved: {out_dir/fname}")

def save_heatmap_and_overlay(src_dir: str, model_desc: str, set_name="TEST", seed=0, cmap="jet"):
    """Save per-eye mean heatmaps and an overlay on one randomly chosen image.

    For the left and right eye separately: average the CAMs of that eye, save
    the heatmap, then blend it over a random image of the same eye (seed for
    left, seed+1 for right). Files go to REP_ROOT/model_desc. An eye with no
    images is skipped.
    """
    ori, cam, metas = _load_pack(src_dir)
    out_dir = REP_ROOT / model_desc

    per_eye = (
        ("L", seed,   f"{set_name}_Left_heatmap.png",  f"{set_name}_Left_overlay.png"),
        ("R", seed+1, f"{set_name}_Right_heatmap.png", f"{set_name}_Right_overlay.png"),
    )
    for eye_code, eye_seed, heatmap_name, overlay_name in per_eye:
        members = _indices_by_eye(metas, eye_code)
        mean_map = _avg_heatmap(cam, members)
        if mean_map is None:
            continue
        _save_heatmap(mean_map, heatmap_name, out_dir, cmap)
        chosen = _random_from(members, eye_seed)
        if chosen is None:
            continue
        blended = _overlay_cam(_to_numpy_img(ori[chosen]), mean_map, alpha=0.45, cmap_name=cmap)
        _save_img(blended, overlay_name, out_dir)

# -----------------------------------------------------------------------------
# 8) Mask utilities + attention summaries (ODC split into 4 wedges added)
# -----------------------------------------------------------------------------
import torchvision.transforms.functional as TF

def ensure_bool_mask(mask_img: Image.Image, size=(IMG_SIZE, IMG_SIZE)) -> torch.Tensor:
    """Resize a mask with nearest-neighbour sampling and binarise it.

    Non-PIL input is converted through `Image.fromarray`. Returns a uint8
    tensor that is 1 where the resized mask is > 0 and 0 elsewhere.
    """
    pil_mask = mask_img if isinstance(mask_img, Image.Image) else Image.fromarray(np.asarray(mask_img))
    resized = np.array(pil_mask.resize(size, resample=Image.NEAREST))
    binary = (resized > 0).astype(np.uint8)
    return torch.from_numpy(binary)

def load_mask_by_filename(root: Path, pngfilename: str) -> torch.Tensor | None:
    cand = root / pngfilename
    if cand.exists():
        return ensure_bool_mask(Image.open(cand))
    stem = Path(pngfilename).stem
    for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"):
        p = root / f"{stem}{ext}"
        if p.exists():
            return ensure_bool_mask(Image.open(p))
    return None

def load_optic_disc_cup_union(pngfilename: str) -> torch.Tensor | None:
    """Load the optic disc/cup mask as the union of the red and blue channels.

    The file is looked up in ODC_RAW_DIR by exact name, then by stem with the
    usual image extensions. Returns None when no file exists or reading /
    decoding fails.
    """
    stem = Path(pngfilename).stem
    extensions = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
    candidates = [ODC_RAW_DIR / pngfilename]
    candidates.extend(ODC_RAW_DIR / f"{stem}{ext}" for ext in extensions)
    source = next((path for path in candidates if path.exists()), None)
    if source is None:
        return None

    try:
        rgb = np.array(Image.open(source).convert("RGB"))
        red, blue = rgb[..., 0], rgb[..., 2]
        union = ((red > 0) | (blue > 0)).astype(np.uint8)
        return ensure_bool_mask(Image.fromarray(union * 255))
    except Exception:
        # best effort: an unreadable file is treated as "no optic disc mask"
        return None

def split_odc_quadrants(odc: torch.Tensor):
    """Split a 2-D optic-disc mask into four wedges around its centroid.

    The wedges are bounded by the two diagonals through the centroid (cx, cy);
    pixels lying exactly on a diagonal belong to no wedge. "superior" is the
    wedge towards smaller row indices, "right" the one towards larger column
    indices. Returns a dict with uint8 masks "superior", "inferior", "right",
    "left" plus the centroid as floats, or None for a missing, non-2-D or
    empty mask.
    """
    if odc is None or odc.ndim != 2:
        return None

    inside = odc > 0
    ys, xs = torch.nonzero(inside, as_tuple=True)
    if ys.numel() == 0:
        return None

    cy = ys.float().mean()
    cx = xs.float().mean()

    h, w = odc.shape
    row_idx = torch.arange(h, device=odc.device).view(-1, 1).expand(h, w).float()
    col_idx = torch.arange(w, device=odc.device).view(1, -1).expand(h, w).float()

    # The two diagonals: y-cy = (x-cx) and y-cy = -(x-cx)
    diag_pos = (row_idx - cy) - (col_idx - cx)   # <0 / >0
    diag_neg = (row_idx - cy) + (col_idx - cx)   # <0 / >0

    def _wedge(side_a, side_b):
        return (side_a & side_b & inside).to(torch.uint8)

    return {
        "superior": _wedge(diag_pos < 0, diag_neg < 0),
        "inferior": _wedge(diag_pos > 0, diag_neg > 0),
        "right": _wedge(diag_pos < 0, diag_neg > 0),
        "left": _wedge(diag_pos > 0, diag_neg < 0),
        "cx": float(cx.item()),
        "cy": float(cy.item()),
    }

def make_fov_mask(h=IMG_SIZE, w=IMG_SIZE) -> torch.Tensor:
    """Return a uint8 [h, w] mask of the centred disc with radius min(h, w) // 2."""
    rows, cols = np.ogrid[:h, :w]
    center_row, center_col = h // 2, w // 2
    radius = min(h, w) // 2
    squared_dist = (rows - center_row)**2 + (cols - center_col)**2
    inside = squared_dist <= (radius**2)
    return torch.from_numpy(inside.astype(np.uint8))

def normalize_gradcam_into_fov(gradcam_t: torch.Tensor, fov_mask: torch.Tensor) -> torch.Tensor | None:
    """Resize a CAM to IMG_SIZE, mask it to the field of view and scale it to sum to 100.

    A 3-D input with 1 or 3 leading channels is reduced to its first channel; a
    4-D input is squeezed. The map goes through a PIL round-trip with bilinear
    resizing. Returns None when the masked map has no positive mass.
    """
    rank = gradcam_t.dim()
    if rank == 4:
        gradcam_t = gradcam_t.squeeze()
    elif rank == 3:
        takes_first_channel = gradcam_t.size(0) in (1,3)
        gradcam_t = gradcam_t[0, ...] if takes_first_channel else gradcam_t.squeeze(0)

    pil_map = TF.to_pil_image(gradcam_t.float())
    pil_map = pil_map.resize((IMG_SIZE, IMG_SIZE), resample=Image.BILINEAR)
    masked = TF.pil_to_tensor(pil_map).squeeze().float() * fov_mask

    total = masked.sum().item()
    if total <= 0:
        return None
    return 100.0 * masked / total

def region_mean(gc_norm100: torch.Tensor, region_mask01: torch.Tensor) -> float:
    """Mean CAM value inside a 0/1 region mask; NaN when the region is empty."""
    area = region_mask01.sum().item()
    if area <= 0:
        return float("nan")
    mass = (gc_norm100 * region_mask01).sum().item()
    return mass / area

def compute_metrics_for_dataset(df_csv: Path, saliency_pth: Path, out_dir: Path) -> pd.DataFrame:
    """Compute per-image mean CAM intensity inside anatomical regions.

    Rows of `df_csv` are paired positionally with the CAMs in `saliency_pth`.
    For each image the artery/vein masks (required) and the optic disc/cup mask
    (optional) are loaded by `pngfilename`, the CAM is normalised inside the
    circular field of view, and the mean CAM value is taken over each region
    and its complement. Optic-disc wedges are labelled temporal/nasal from
    `eye_orientation`. Rows without vessel masks or with an empty CAM are
    skipped.

    The result is written to out_dir/RESULTS_<stem>.csv and returned. The
    returned frame always carries the full set of metric columns, even when no
    row could be processed, so downstream summaries do not hit a KeyError.

    Raises:
        ValueError: if `df_csv` has no 'pngfilename' column.
    """
    metric_cols = ["veins_n","not_veins_n","arteries_n","not_arteries_n","both_n","not_both_n",
                   "optic_disc_n","not_optic_disc_n","not_optic_disc_not_both_n",
                   "optic_disc_temporal_n","optic_disc_nasal_n","optic_disc_superior_n","optic_disc_inferior_n"]
    result_cols = ["pngfilename", *metric_cols[:6], "has_optic_disc", *metric_cols[6:]]

    out_dir.mkdir(parents=True, exist_ok=True)
    df = pd.read_csv(df_csv)
    if "pngfilename" not in df.columns:
        raise ValueError(f"'pngfilename' column not found in {df_csv}")
    gradcams = torch.load(saliency_pth)
    n = len(gradcams) if hasattr(gradcams, "__len__") else int(gradcams.size(0))
    df = df.reset_index(drop=True)
    if len(df) != n:
        # Pairing is positional, so a length mismatch usually means rows and CAMs
        # are no longer aligned; keep the historical truncation but say so.
        print(f"[WARN] {df_csv}: {len(df)} rows vs {n} saliency maps; truncating to the shorter length")
        m = min(len(df), n)
        df = df.iloc[:m].reset_index(drop=True)
        if hasattr(gradcams, "__len__"):
            gradcams = gradcams[:m]
        else:
            gradcams = gradcams[:m, ...]
    records = []
    fov = make_fov_mask()
    for idx in tqdm(range(len(df)), desc=f"Computing masks for {df_csv.stem}", total=len(df)):
        row = df.iloc[idx]
        pngfilename = str(row["pngfilename"])
        artery = load_mask_by_filename(ARTERY_DIR, pngfilename)
        vein   = load_mask_by_filename(VEIN_DIR,   pngfilename)
        if artery is None or vein is None:
            continue
        odc = load_optic_disc_cup_union(pngfilename)  # may be None
        both = (artery | vein).clamp_(0, 1)
        not_both = ((1 - both) * fov).clamp_(0, 1)
        if odc is not None:
            odc = (odc * fov).clamp_(0, 1)
            not_odc = ((1 - odc) * fov).clamp_(0, 1)
            not_odc_not_both = ((1 - ((both | odc).clamp_(0, 1))) * fov).clamp_(0, 1)
            quads = split_odc_quadrants(odc)
        else:
            not_odc = None; not_odc_not_both = None; quads = None

        gc = normalize_gradcam_into_fov(gradcams[idx], fov)
        if gc is None:
            continue

        # Restrict every region (and its complement) to the field of view.
        vein_ = (vein * fov).clamp_(0, 1);    not_vein = ((1 - vein) * fov).clamp_(0, 1)
        artery_ = (artery * fov).clamp_(0, 1);not_artery = ((1 - artery) * fov).clamp_(0, 1)
        both_   = (both * fov).clamp_(0, 1);  not_both_ = not_both

        eye_code = _norm_eye_label(row.get("eye_orientation", None))

        od_temporal = float("nan")
        od_nasal    = float("nan")
        od_sup      = float("nan")
        od_inf      = float("nan")

        if quads is not None:
            # superior / inferior
            od_sup = region_mean(gc, quads["superior"])
            od_inf = region_mean(gc, quads["inferior"])

            # temporal / nasal depend on laterality: for a left eye the image-right
            # wedge is treated as temporal, for a right eye the image-left wedge.
            if eye_code == "L":
                temporal_mask = quads["right"]
                nasal_mask    = quads["left"]
            elif eye_code == "R":
                temporal_mask = quads["left"]
                nasal_mask    = quads["right"]
            else:
                temporal_mask = None
                nasal_mask    = None

            if temporal_mask is not None:
                od_temporal = region_mean(gc, temporal_mask)
            if nasal_mask is not None:
                od_nasal = region_mean(gc, nasal_mask)

        result = {
            "pngfilename": pngfilename,
            "veins_n":               region_mean(gc, vein_),
            "not_veins_n":           region_mean(gc, not_vein),
            "arteries_n":            region_mean(gc, artery_),
            "not_arteries_n":        region_mean(gc, not_artery),
            "both_n":                region_mean(gc, both_),
            "not_both_n":            region_mean(gc, not_both_),
            "has_optic_disc":        bool(odc is not None),
            "optic_disc_n":          float("nan") if odc is None else region_mean(gc, odc),
            "not_optic_disc_n":      float("nan") if odc is None else region_mean(gc, not_odc),
            "not_optic_disc_not_both_n": float("nan") if odc is None else region_mean(gc, not_odc_not_both),
            "optic_disc_temporal_n": od_temporal,
            "optic_disc_nasal_n":    od_nasal,
            "optic_disc_superior_n": od_sup,
            "optic_disc_inferior_n": od_inf,
        }
        records.append(result)

    # Explicit columns keep the schema stable when `records` is empty.
    res = pd.DataFrame(records, columns=result_cols)
    out_csv = out_dir / f"RESULTS_{df_csv.stem.replace('_df','')}.csv"
    res.to_csv(out_csv, index=False)
    print(f"[Saved] {out_csv}  (n={len(res)})")
    if len(res):
        summary = res[metric_cols].mean(numeric_only=True)
        print("Column means:\n", summary.to_string())
    return res

def _safe_ratio(num: pd.Series, den: pd.Series) -> pd.Series:
    num = pd.to_numeric(num, errors="coerce"); den = pd.to_numeric(den, errors="coerce")
    ratio = num / den
    ratio[(~np.isfinite(den)) | (den <= 0)] = np.nan
    return ratio

def bootstrap_ci(data: np.ndarray, n_boot: int = 2000, ci: int = 95, seed: int = 42):
    """Percentile bootstrap confidence interval for the mean of `data`.

    Non-finite values are dropped first. `n_boot` resamples (with replacement,
    same size as the cleaned data) are drawn from a generator seeded with
    `seed`. Returns (lower, upper); (nan, nan) when no finite value remains.
    """
    values = np.asarray(data)
    values = values[np.isfinite(values)]
    if values.size == 0:
        return (np.nan, np.nan)

    rng = np.random.default_rng(seed)
    sample_size = values.size
    boot_means = np.empty(n_boot)
    for draw in range(n_boot):
        resample = rng.choice(values, size=sample_size, replace=True)
        boot_means[draw] = np.nanmean(resample)

    tail = (100 - ci) / 2.0
    lower = np.percentile(boot_means, tail)
    upper = np.percentile(boot_means, 100 - tail)
    return (lower, upper)

def summarize_attention(res_df: pd.DataFrame, label: str = "TEST", n_boot: int = 2000) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Summarise region attention relative to the background region.

    Every region mean is divided by `not_optic_disc_not_both_n` (inside the
    field of view, outside vessels and optic disc). For each ratio the mean, a
    bootstrap 95% CI and the number of finite values are reported.

    A metric column missing from `res_df` (e.g. an empty result frame) is
    treated as all-NaN instead of raising KeyError.

    Returns:
        (ratios, summary): the per-image ratio table, and a table indexed by
        ratio name with columns mean, 95%_CI_low, 95%_CI_high, n_used.
    """
    def _column(name: str) -> pd.Series:
        if name in res_df.columns:
            return res_df[name]
        return pd.Series(np.nan, index=res_df.index, dtype=float)

    ref = _column("not_optic_disc_not_both_n")
    sources = {
        "vascular_attention_vein":       "veins_n",
        "vascular_attention_artery":     "arteries_n",
        "vascular_attention_both":       "both_n",
        "optic_disc_attention":          "optic_disc_n",
        "optic_disc_attention_temporal": "optic_disc_temporal_n",
        "optic_disc_attention_nasal":    "optic_disc_nasal_n",
        "optic_disc_attention_superior": "optic_disc_superior_n",
        "optic_disc_attention_inferior": "optic_disc_inferior_n",
    }
    ratios = pd.DataFrame({
        ratio_name: _safe_ratio(_column(source), ref)
        for ratio_name, source in sources.items()
    })
    mean_vals = ratios.mean(numeric_only=True)
    results = {}
    for col in ratios.columns:
        ci_low, ci_high = bootstrap_ci(ratios[col].values, n_boot=n_boot, ci=95)
        results[col] = {
            "mean": float(mean_vals[col]),
            "95%_CI_low": float(ci_low),
            "95%_CI_high": float(ci_high),
            "n_used": int(np.isfinite(ratios[col].values).sum())
        }
    out = pd.DataFrame(results).T
    print(f"\n=== [{label}] Attention ratios (ref = not_optic_disc_not_both_n) ===")
    print(mean_vals.to_string())
    return ratios, out

In [None]:
# -----------------------------------------------------------------------------
# 9) MAIN LOOP over all combos
# -----------------------------------------------------------------------------
master_norm = normalize_master(MASTER_CSV)

# Number of leading (arch, ft) combos to skip, e.g. to resume an interrupted
# run. 0 processes every combo.
skip_n = 0
count = 0


def _drop_rows_without_image(tp_df: pd.DataFrame) -> pd.DataFrame:
    """Keep only TP rows whose image file exists.

    export_tp_pack silently skips unreadable images, while the mask metrics pair
    the saved CSV rows with the saved CAMs by position. Filtering up front keeps
    the CSV and the tensors aligned row-for-row.
    """
    if not len(tp_df):
        return tp_df

    def _exists(path) -> bool:
        full = ensure_path(path)
        return bool(full) and os.path.exists(full)

    present = tp_df["img_path"].map(_exists).astype(bool)
    n_missing = int((~present).sum())
    if n_missing:
        print(f"[WARN] Dropping {n_missing} TP rows whose image file is missing")
    return tp_df[present].reset_index(drop=True)


# The datasets do not depend on the combo, so they are built once on first use.
test_ds = None
fut_ds = None

for arch in ARCHES:
    for ft in FTS:
        if count < skip_n:
            print(f"[SKIP {count+1}/{skip_n}] arch={arch}, ft={ft}")
            count += 1
            continue

        row = pick_threshold_row(arch, ft)
        if row is None:
            print(f"[SKIP] No threshold row in youden_thresholds.csv for: arch={arch}, ft={ft}")
            continue

        # derive model_desc from the selected row (captures lora_rank if present)
        desc = model_desc_from_csv_row(row)

        # re-check ft_blks policy to ensure alignment with requested mapping
        if ft == 'partial' and str(row['ft_blks']) != '4':
            print(f"[SKIP] Found threshold with ft_blks={row['ft_blks']} but need 4: {arch}|{ft}")
            continue
        if ft == 'lora' and str(row['ft_blks']) != 'full':
            print(f"[SKIP] Found threshold with ft_blks={row['ft_blks']} but need full: {arch}|{ft}")
            continue

        BEST_THRESH = float(row['youden_thr'])
        AUC_USED    = float(row.get('auc', float('nan')))
        LORA_RANK   = row.get('lora_rank', None)

        ckpt = ckpt_path(desc)
        infer = infer_dir(desc)

        print("\n" + "="*80)
        print(f"[{desc}]  arch={arch}  ft={ft}  ft_blks={row['ft_blks']}  lora_rank={LORA_RANK}  AUC={AUC_USED:.3f}")
        print(f"  THRESH={BEST_THRESH:.6f}")
        print(f"  CKPT  : {ckpt}")
        print(f"  INFER : {infer}")
        print("="*80)

        # ---- Cheap input checks first, before any model or dataset is built
        if not os.path.exists(ckpt):
            print(f"[SKIP] Checkpoint not found: {ckpt}")
            continue
        test_csv_path = Path(infer) / "test_preds.csv"          # [idx, label, pred]
        fut_csv_path  = Path(infer) / "prediction_preds.csv"    # [idx, pred, event, obs_time]
        if not test_csv_path.exists() or not fut_csv_path.exists():
            print(f"[SKIP] Missing infer CSVs under {infer}")
            continue

        # ---- Build & load model
        model = build_model_for(arch, ft, IMG_SIZE, enable_amp=True, lora_rank=LORA_RANK).to(device).eval()
        sd = torch.load(ckpt, map_location="cpu")
        ok = safe_load_state_dict(model, sd)
        if not ok:
            print(f"[SKIP] Failed to load weights from: {ckpt}")
            continue

        # ---- Patch attention & rollout
        vit = find_vit_backbone(model)
        if vit is None:
            print(f"[SKIP] ViT backbone with 'blocks' not found in model: {desc}")
            continue
        patch_attention_to_capture(vit)
        rollout = ViTGradCamRollout(vit)

        # ---- Load datasets & infer CSVs
        if test_ds is None:
            test_ds = DementiaDetectionDataset(kind="test", img_sz=IMG_SIZE)
        if fut_ds is None:
            fut_ds = DementiaPredictionDataset(img_sz=IMG_SIZE)
        test_csv = pd.read_csv(test_csv_path)
        fut_csv  = pd.read_csv(fut_csv_path)

        # ---- attach image paths
        test_df = attach_imgpath(test_csv, test_ds)
        fut_df  = attach_imgpath(fut_csv,  fut_ds)

        # ---- merge eye_orientation
        test_df = pd.merge(test_df, master_norm, on='img_path', how='inner')
        fut_df  = pd.merge(fut_df,  master_norm, on='img_path', how='inner')

        # ---- filter TP by BEST_THRESH
        test_tp = test_df.query("pred >= @BEST_THRESH and label == 1").reset_index(drop=True)
        fut_tp  = fut_df .query("pred >= @BEST_THRESH and event == 1").reset_index(drop=True)
        # Keep only rows that will actually be exported (see helper above).
        test_tp = _drop_rows_without_image(test_tp)
        fut_tp  = _drop_rows_without_image(fut_tp)
        print(f"TP counts — TEST: {len(test_tp)} | FUTURE: {len(fut_tp)}")

        # ---- export origin_images.pth & saliency_map.pth
        OUT_ROOT_TEST, OUT_ROOT_FUTURE = tensors_dirs(desc)
        ok_test = export_tp_pack(test_tp, OUT_ROOT_TEST, rollout, model)
        ok_fut  = export_tp_pack(fut_tp,  OUT_ROOT_FUTURE, rollout, model)

        # ---- save the TP CSVs (like the base code)
        if len(test_tp):
            test_tp.to_csv(Path(OUT_ROOT_TEST)/"test_df.csv", index=False)
        if len(fut_tp):
            fut_tp.to_csv(Path(OUT_ROOT_FUTURE)/"fut_df.csv", index=False)

        # ---- representative Grad-CAM images (L/R heatmap & overlay)
        if ok_test:
            save_heatmap_and_overlay(OUT_ROOT_TEST,   desc, set_name="TEST",   seed=123)
        if ok_fut:
            save_heatmap_and_overlay(OUT_ROOT_FUTURE, desc, set_name="FUTURE", seed=456)

        # ---- Mask metrics + Attention summaries
        if ok_test:
            test_df_csv   = Path(OUT_ROOT_TEST) / "test_df.csv"
            test_saliency = Path(OUT_ROOT_TEST) / "saliency_map.pth"
            test_out_dir  = Path(OUT_ROOT_TEST) / "Mask"
            test_results  = compute_metrics_for_dataset(test_df_csv, test_saliency, test_out_dir)
            test_ratios, test_summary = summarize_attention(test_results, label=f"TEST:{desc}", n_boot=2000)
            test_summary.to_csv(test_out_dir / "ATTN_summary_TEST.csv")
        if ok_fut:
            fut_df_csv   = Path(OUT_ROOT_FUTURE) / "fut_df.csv"
            fut_saliency = Path(OUT_ROOT_FUTURE) / "saliency_map.pth"
            fut_out_dir  = Path(OUT_ROOT_FUTURE) / "Mask"
            fut_results  = compute_metrics_for_dataset(fut_df_csv, fut_saliency, fut_out_dir)
            fut_ratios,  fut_summary  = summarize_attention(fut_results,  label=f"FUTURE:{desc}", n_boot=2000)
            fut_summary.to_csv(fut_out_dir / "ATTN_summary_FUTURE.csv")

print("\n✅ Done for all available combos with thresholds resolved from youden_thresholds.csv")

In [None]:
# %% [markdown]
# Build nicely formatted tables from saved ATTN_summary files
from pathlib import Path
import pandas as pd
import numpy as np
import re

ROOT = Path(".")

def glob_summaries(kind: str):
    """Return the sorted ATTN_summary CSV paths found under ``ROOT`` for one split.

    ``kind == "TEST"`` selects the detection summaries
    (``tensors/*_test/Mask/ATTN_summary_TEST.csv``); any other value selects
    the prediction summaries (``tensors/*_future/Mask/ATTN_summary_FUTURE.csv``).
    """
    pattern = (
        "tensors/*_test/Mask/ATTN_summary_TEST.csv"
        if kind == "TEST"
        else "tensors/*_future/Mask/ATTN_summary_FUTURE.csv"
    )
    return sorted(map(Path, ROOT.glob(pattern)))

def parse_model_and_ft(desc: str):
    """Split a model description string into display labels.

    Parameters
    ----------
    desc : str
        Model description, e.g. ``"retfound_partial_ft_4"`` or
        ``"dinov2_lora_rank_8_ft_full"``. Matching is case-insensitive.

    Returns
    -------
    (model, ft_method) : tuple[str, str]
        ``model`` is the display name of the architecture whose key prefixes
        ``desc``; if no known key matches, it is the upper-cased text before
        the first underscore. ``ft_method`` is ``"partial_ft4"``,
        ``"lora_full(r<rank>)"`` / ``"lora_full"``, or ``"linear"`` as the
        fallback.
    """
    arch_map = {
        "retfound": "RETFound",
        "mae": "MAE",
        "openclip": "OpenCLIP",
        "dinov2": "DINOv2",
        "dinov3": "DINOv3",
        "retfound_dinov2": "RETFound_dinov2",
    }
    d = desc.lower()

    # Try the longest key first so "retfound_dinov2" is not shadowed by its
    # own prefix "retfound". Everything is matched on the lowercased string,
    # so mixed-case descriptions resolve to the same architecture.
    arch = next(
        (k for k in sorted(arch_map, key=len, reverse=True) if d.startswith(k)),
        None,
    )
    if arch is None:
        model = desc.split("_")[0].upper()
    else:
        model = arch_map[arch]

    if "_partial_ft_4" in d:
        ft_method = "partial_ft4"
    elif "_lora_" in d and "_ft_full" in d:
        # Include the LoRA rank in the label when it is present in the name.
        m = re.search(r"lora_rank_(\d+)_ft_full", d)
        ft_method = f"lora_full(r{m.group(1)})" if m else "lora_full"
    else:
        ft_method = "linear"
    return model, ft_method

def format_ci(mean, low, high):
    """Format an estimate and its interval as ``"m (lo - hi)"`` with two decimals.

    Any value that cannot be converted to ``float`` is rendered as ``"nan"``.
    """
    def r2(x):
        # Only conversion failures map to "nan"; a bare ``except`` would also
        # swallow KeyboardInterrupt / SystemExit.
        try:
            return f"{float(x):.2f}"
        except (TypeError, ValueError, OverflowError):
            return "nan"
    return f"{r2(mean)} ({r2(low)} - {r2(high)})"

def one_row_from_summary(path: Path):
    """Turn one ATTN_summary CSV into a display row (dict) for the ratio table.

    The model description is the name of the directory two levels above the
    file (``{MODEL_DESC}_test`` or ``{MODEL_DESC}_future``) with that split
    suffix removed. Each region cell is rendered as ``"mean (low - high)"``
    from the ``mean`` / ``95%_CI_low`` / ``95%_CI_high`` columns; a missing
    row or column yields ``"nan (nan - nan)"``. The ``_sort_key`` entry
    carries the description so the caller can order rows.
    """
    run_dir_name = path.parent.parent.name  # ex) retfound_partial_ft_4_test
    desc = run_dir_name
    for split_suffix in ("_test", "_future"):
        if run_dir_name.endswith(split_suffix):
            desc = run_dir_name[: -len(split_suffix)]
            break

    model, ft_method = parse_model_and_ft(desc)
    summary = pd.read_csv(path, index_col=0)

    stat_cols = ("mean", "95%_CI_low", "95%_CI_high")
    has_all_stats = all(col in summary.columns for col in stat_cols)

    def get_fmt(row_name):
        # Placeholder unless both the row and all three statistic columns exist.
        if has_all_stats and row_name in summary.index:
            return format_ci(*(summary.loc[row_name, col] for col in stat_cols))
        return "nan (nan - nan)"

    # (table column label, row name in the summary CSV), in output order.
    region_rows = (
        ("Artery", "vascular_attention_artery"),
        ("Vein", "vascular_attention_vein"),
        ("Vessel (Artery + Vein)", "vascular_attention_both"),
        ("Optic disc (Total)", "optic_disc_attention"),
        ("Optic disc (Temporal)", "optic_disc_attention_temporal"),
        ("Optic disc (Nasal)", "optic_disc_attention_nasal"),
        ("Optic disc (Superior)", "optic_disc_attention_superior"),
        ("Optic disc (Inferior)", "optic_disc_attention_inferior"),
    )

    row = {
        "Model": model,
        "Finetuning Method": ft_method,
        "Non-(vessel or optic disc)": "1.00 (reference)",
    }
    for label, row_name in region_rows:
        row[label] = get_fmt(row_name)
    row["_sort_key"] = desc  # used by the caller for ordering
    return row

def build_table(kind: str) -> pd.DataFrame:
    """Assemble the relative-ratio table for one split from its summary CSVs.

    Every path returned by ``glob_summaries(kind)`` is converted with
    ``one_row_from_summary``; a file that fails is reported with a ``[WARN]``
    line and skipped. Rows are ordered by their ``_sort_key`` and that helper
    column is dropped. With no usable rows, an empty frame with the expected
    columns is returned.
    """
    columns = [
        "Model",
        "Finetuning Method",
        "Non-(vessel or optic disc)",
        "Artery",
        "Vein",
        "Vessel (Artery + Vein)",
        "Optic disc (Total)",
        "Optic disc (Temporal)",
        "Optic disc (Nasal)",
        "Optic disc (Superior)",
        "Optic disc (Inferior)",
    ]

    records = []
    for summary_path in glob_summaries(kind):
        try:
            record = one_row_from_summary(summary_path)
        except Exception as e:
            # Best effort: one unreadable summary must not abort the table.
            print(f"[WARN] Failed to read {summary_path}: {e}")
            continue
        records.append(record)

    if len(records) == 0:
        return pd.DataFrame(columns=columns)

    table = pd.DataFrame(records)
    table = table.sort_values("_sort_key")
    table = table.drop(columns=["_sort_key"])
    return table[columns]

# === Build tables ===
# "TEST" summaries feed the detection table, "FUTURE" summaries the prediction table.
detection_df = build_table("TEST")
prediction_df = build_table("FUTURE")

# Rich display is optional: if the helper is unavailable (or fails), fall
# through silently to the plain-text printout below.
try:
    from caas_jupyter_tools import display_dataframe_to_user
    display_dataframe_to_user("Dementia detection — Relative ratio table", detection_df)
    display_dataframe_to_user("Dementia prediction — Relative ratio table", prediction_df)
except Exception:
    pass

# Plain-text rendering of both tables, without the row index.
print(
    "\n=== Dementia detection (Relative ratio, 95% CI) ===",
    detection_df.to_string(index=False),
    sep="\n",
)
print(
    "\n=== Dementia prediction (Relative ratio, 95% CI) ===",
    prediction_df.to_string(index=False),
    sep="\n",
)


In [None]:
# DataFrame.to_csv does not create missing parent directories, so make sure
# the output folder exists before writing (otherwise a fresh run raises OSError).
Path('./summary_tables').mkdir(parents=True, exist_ok=True)
detection_df.to_csv('./summary_tables/ATTN_summary_detection.csv')
prediction_df.to_csv('./summary_tables/ATTN_summary_prediction.csv')

## colorbar

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# --------------------------------------------------------------------
# Colorbar only (same as 'jet' colormap used in saliency)
# --------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(2.0, 10.5))

# Stand-alone mappable over the 0-1 range: it gives the colorbar its colors
# without drawing an image. The earlier imshow() that defined `im` was
# commented out, which left `im` undefined on a fresh kernel (NameError).
im = plt.cm.ScalarMappable(norm=plt.Normalize(vmin=0.0, vmax=1.0), cmap='jet')

# Hide the host axes
ax.set_axis_off()

# Add the vertical colorbar
cbar = plt.colorbar(im, ax=ax, fraction=0.5, pad=2)
#cbar.set_label("Saliency intensity", fontsize=12)
cbar.ax.tick_params(labelsize=10)
cbar.set_ticks([])
cbar.outline.set_visible(False)

plt.tight_layout()
plt.show()
