## Grad-CAM heatmaps and overlay images for the manuscript

In [None]:
# %% [markdown]
# # Rep Grad-CAM L/R export from saved tensors (manuscript version)
# - Assumes these files are ALREADY saved under ./tensors/{MODEL_DESC}_test and ./tensors/{MODEL_DESC}_future:
#     * origin_images.pth, saliency_map.pth, and test_df.csv (TEST folder) or fut_df.csv (FUTURE folder)
# - Saves FOV-masked avg heatmaps and overlays to:
#     * ./rep_gradcam_lr_manuscript/{MODEL_DESC}/{SET}_{Left|Right}_{heatmap|overlay}.png
# - Uses eye_orientation from test_df.csv / fut_df.csv (no meta.json required)

import os, random, warnings
warnings.filterwarnings("ignore")
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn.functional as F

import matplotlib.pyplot as plt

# -----------------------
# Config
# -----------------------
IMG_SIZE = 448  # default side length (px) for make_fov_mask
SEED = 123      # NOTE: re-bound to 0 later in this cell, before main() is called

# Models to export; each may have ./tensors/{desc}_test and/or ./tensors/{desc}_future.
# Alternative descriptors (add to the list to include them):
#   "dinov3_partial_ft_4", "retfound_dinov2_linear", "openclip_linear",
#   "mae_linear", "dinov2_linear"
MODEL_DESCS = ["retfound_partial_ft_4"]

TENSORS_ROOT = Path("./tensors")                # saved tensors, one folder per model/set
REP_ROOT = Path("./rep_gradcam_lr_manuscript")  # output root, one folder per model
REP_ROOT.mkdir(parents=True, exist_ok=True)

# -----------------------
# Utils
# -----------------------
def _norm_eye_label(x):
    """Map a raw eye-orientation value to ``"R"``, ``"L"`` or ``None``.

    Recognised (case-insensitive, whitespace-stripped) spellings include
    English ("right"/"left", "r"/"l", "rt"/"lt"), Korean, OD/OS and
    boolean / 0-1 codes. Any other value that parses as a number maps to
    "R" if its integer part is 1 and to "L" otherwise.

    Returns:
        "R", "L", or None for ``None``, NaN, infinities and unparseable text.
    """
    if x is None:
        return None
    s = str(x).strip().lower()
    if s in {"r", "right", "rt", "오", "오른", "오른쪽", "1", "true", "od", "o.d"}:
        return "R"
    if s in {"l", "left", "lt", "왼", "왼쪽", "0", "false", "os", "o.s"}:
        return "L"
    try:
        return "R" if int(float(s)) == 1 else "L"
    except (ValueError, OverflowError):
        # float() rejects non-numeric text; int() rejects NaN (ValueError)
        # and +/-inf (OverflowError). Nothing else can be raised here.
        return None

def make_fov_mask(h=IMG_SIZE, w=IMG_SIZE):
    """Return a circular field-of-view mask as a torch.uint8 tensor of 0/1.

    The circle is centred at (h // 2, w // 2) with radius min(h, w) // 2;
    pixels exactly on the boundary count as inside.
    """
    radius = min(h, w) // 2
    rows = np.arange(h)[:, None] - h // 2
    cols = np.arange(w)[None, :] - w // 2
    inside = rows ** 2 + cols ** 2 <= radius ** 2
    return torch.from_numpy(inside.astype(np.uint8))

def _ensure_hw01(arr_like):
    """Min-max normalise a tensor or array-like to float32 values in [0, 1].

    The result is a CPU tensor with the input's shape. A constant input
    (max == min) yields all zeros.
    """
    if isinstance(arr_like, torch.Tensor):
        x = arr_like.detach().cpu().float()
    else:
        arr = np.asarray(arr_like)
        # torch.from_numpy rejects negative strides (e.g. arr[::-1], np.flip),
        # so copy such views into a regular layout first.
        if any(stride < 0 for stride in arr.strides):
            arr = arr.copy()
        x = torch.from_numpy(arr).float()
    # Min-max normalisation
    mn, mx = float(x.min()), float(x.max())
    if mx > mn:
        x = (x - mn) / (mx - mn)
    else:
        x = x * 0.0
    return x

def _apply_fov_and_renorm(cam_hw_01: torch.Tensor, fov01: torch.Tensor):
    """Zero pixels outside the FOV and min-max rescale using in-FOV values only.

    Returns a new tensor clamped to [0, 1]. If the FOV mask is empty, the
    input tensor itself is returned unchanged. If all in-FOV values are
    equal, the masked map is only clamped (no rescaling).
    """
    masked = cam_hw_01 * fov01
    in_fov = masked[fov01 > 0]
    if not in_fov.numel():
        return cam_hw_01
    lo = float(in_fov.min())
    hi = float(in_fov.max())
    if hi > lo:
        # Out-of-FOV zeros fall below `lo` and are clamped back to 0.
        masked = (masked - lo) / (hi - lo + 1e-6)
    return masked.clamp_(0, 1)

def _to_numpy_img(t3chw: torch.Tensor):
    """Convert a channel-first image tensor (expected [3, H, W]) to an [H, W, 3] numpy array clipped to [0, 1]."""
    clipped = t3chw.detach().cpu().clamp(0, 1)
    return clipped.numpy().transpose(1, 2, 0)

def _overlay_cam(rgb_hw3, cam_hw01, alpha=0.45, cmap_name="jet"):
    """Alpha-blend a colour-mapped CAM over an RGB image.

    ``cam_hw01`` is mapped through the matplotlib colormap ``cmap_name``
    (alpha channel dropped) and mixed as ``(1 - alpha) * rgb + alpha * cam``;
    the result is clipped to [0, 1].
    """
    colormap = plt.get_cmap(cmap_name)
    heat_rgb = colormap(cam_hw01.detach().cpu().numpy())[..., :3]
    blended = rgb_hw3 * (1 - alpha) + heat_rgb * alpha
    return np.clip(blended, 0, 1)

def _save_img(array_hw3, fname, out_dir: Path):
    """Write an image array to ``out_dir/fname`` (creating the folder) and log the path."""
    target = out_dir / fname
    out_dir.mkdir(parents=True, exist_ok=True)
    plt.imsave(target, array_hw3)
    print(f"[✓] Saved: {target}")

def _save_heatmap(array_hw01, fname, out_dir: Path, cmap="jet"):
    """Write a 2-D map to ``out_dir/fname`` colour-mapped on a fixed 0..1 scale, and log the path."""
    target = out_dir / fname
    out_dir.mkdir(parents=True, exist_ok=True)
    plt.imsave(target, array_hw01, cmap=cmap, vmin=0.0, vmax=1.0)
    print(f"[✓] Saved: {target}")

def _avg_heatmap(cam_stack: torch.Tensor, indices, fov01: torch.Tensor):
    """Average the CAMs at ``indices`` inside the FOV and re-normalise the mean.

    Each selected CAM is min-max normalised on its own, masked by ``fov01``,
    averaged, then rescaled within the FOV. A 4-D stack is indexed as
    ``cam_stack[indices, 0]`` (first channel); any other rank is indexed
    as ``cam_stack[indices]``.

    Returns:
        The [H, W] mean map, or None if ``indices`` is empty.
    """
    if not indices:
        return None
    selected = cam_stack[indices, 0] if cam_stack.dim() == 4 else cam_stack[indices]
    normalized = torch.stack([_ensure_hw01(cam) for cam in selected], dim=0)
    mean_map = (normalized * fov01).mean(dim=0)
    return _apply_fov_and_renorm(mean_map, fov01)

def _random_from(indices, seed=None):
    """Pick one element of ``indices`` with a private RNG seeded by ``seed``; None if empty."""
    return random.Random(seed).choice(indices) if indices else None

def _load_pack_and_df(src_dir: Path):
    """Load the saved Grad-CAM tensors and the matching metadata CSV from ``src_dir``.

    Expects ``origin_images.pth`` and ``saliency_map.pth`` plus ``test_df.csv``
    and/or ``fut_df.csv``. When both CSVs are present, a folder whose name ends
    with ``_future`` uses ``fut_df.csv`` (set "FUTURE"); otherwise
    ``test_df.csv`` (set "TEST") is preferred. Without this rule a future
    folder containing a test CSV was labelled TEST, and its outputs
    overwrote the real TEST files.

    Tensors are loaded onto the CPU so files saved from a GPU session open on
    CPU-only machines. Images, saliency maps and rows are truncated to their
    common length so row ``i`` of the DataFrame matches index ``i`` of both
    tensors.

    Returns:
        (origin, salmap, df, set_name)

    Raises:
        FileNotFoundError: if neither CSV exists, or a tensor file is missing.
    """
    src_dir = Path(src_dir)
    origin_p = src_dir / "origin_images.pth"
    salmap_p = src_dir / "saliency_map.pth"

    csv_candidates = [(src_dir / "test_df.csv", "TEST"), (src_dir / "fut_df.csv", "FUTURE")]
    if src_dir.name.endswith("_future"):
        csv_candidates.reverse()
    for csv_p, set_name in csv_candidates:
        if csv_p.exists():
            df = pd.read_csv(csv_p)
            break
    else:
        raise FileNotFoundError(f"No test_df.csv or fut_df.csv under {src_dir}")

    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not (origin_p.exists() and salmap_p.exists()):
        raise FileNotFoundError(f"Missing tensor files in {src_dir}")
    origin = torch.load(origin_p, map_location="cpu")
    salmap = torch.load(salmap_p, map_location="cpu")

    n = min(len(df), salmap.size(0), origin.size(0))
    origin = origin[:n]
    salmap = salmap[:n]
    df = df.iloc[:n].reset_index(drop=True)
    return origin, salmap, df, set_name

def _indices_by_eye_from_df(df: pd.DataFrame, eye_code: str):
    """Row positions whose ``eye_orientation`` normalises to ``eye_code`` ("L"/"R").

    If the DataFrame has no ``eye_orientation`` column, every row position is returned.
    """
    if "eye_orientation" not in df.columns:
        return list(range(len(df)))
    labels = df["eye_orientation"].tolist()
    return [pos for pos, raw in enumerate(labels) if _norm_eye_label(raw) == eye_code]

def save_heatmap_and_overlay_with_fov(src_dir: Path, model_desc: str, seed=SEED, cmap="jet"):
    """Save averaged left/right Grad-CAM heatmaps and overlays for one tensor folder.

    For each eye with at least one image, writes to ``REP_ROOT/model_desc``:
      * ``{SET}_{Left|Right}_heatmap.png``: the FOV-masked mean CAM over that
        eye's images;
      * ``{SET}_{Left|Right}_overlay.png``: that mean CAM blended onto one
        image of the same eye, chosen with ``random.Random(seed)``.

    Args:
        src_dir: folder with origin_images.pth, saliency_map.pth and a CSV.
        model_desc: model descriptor, used as the output sub-folder name.
        seed: seed for picking the background image of each overlay.
        cmap: matplotlib colormap name for heatmap and overlay.
    """
    origin, salmap, df, set_name = _load_pack_and_df(src_dir)

    # Size the FOV mask from the saved saliency maps instead of assuming
    # IMG_SIZE, so tensors exported at another resolution still line up.
    h, w = int(salmap.shape[-2]), int(salmap.shape[-1])
    fov01 = make_fov_mask(h, w).float()

    out_dir = REP_ROOT / model_desc
    out_dir.mkdir(parents=True, exist_ok=True)

    for eye_code, eye_name in (("L", "Left"), ("R", "Right")):
        idxs = _indices_by_eye_from_df(df, eye_code)
        avg = _avg_heatmap(salmap, idxs, fov01)
        if avg is None:
            continue
        _save_heatmap(avg.detach().cpu().numpy(), f"{set_name}_{eye_name}_heatmap.png", out_dir, cmap=cmap)
        pick = _random_from(idxs, seed)
        if pick is not None:
            rgb = _to_numpy_img(origin[pick])  # [H,W,3], 0..1
            ov = _overlay_cam(rgb, avg, alpha=0.45, cmap_name=cmap)
            _save_img(ov, f"{set_name}_{eye_name}_overlay.png", out_dir)

# -----------------------
# Main
# -----------------------
def main():
    """Export representative L/R Grad-CAM figures for every model in MODEL_DESCS.

    Seeds ``random``, NumPy and torch with the module-level SEED (read at call
    time), then processes ./tensors/{desc}_test with SEED and
    ./tensors/{desc}_future with SEED + 1000, skipping folders that do not exist.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(SEED)

    splits = (("TEST", "_test", 0), ("FUTURE", "_future", 1000))
    for desc in MODEL_DESCS:
        print("\n" + "=" * 80)
        print(f"[{desc}]")
        for label, suffix, seed_offset in splits:
            set_dir = TENSORS_ROOT / f"{desc}{suffix}"
            if not set_dir.exists():
                print(f" - {label} dir not found (skip)")
                continue
            tag = f"{label}:"
            # Pad so both labels align: "TEST:   " / "FUTURE: ".
            print(f" - Processing {tag:<7} {set_dir}")
            save_heatmap_and_overlay_with_fov(set_dir, model_desc=desc, seed=SEED + seed_offset)

    print("\n✅ Done. Saved FOV-masked heatmaps & overlays to ./rep_gradcam_lr_manuscript/{MODEL_DESC}")

# NOTE(review): this re-binds SEED, which the Config section set to 123.
# main() reads SEED when it is called, so the run below uses seed 0 for TEST
# and 1000 for FUTURE. The `seed=SEED` default of
# save_heatmap_and_overlay_with_fov was bound when that function was defined
# (still 123), but main() always passes `seed` explicitly. Confirm 0 is the
# seed intended for the manuscript figures before changing either value.
SEED     = 0

if __name__ == "__main__":
    main()


## Individual saliency maps

In [None]:
# %% [markdown]
# # Single-image Grad-CAM: heatmap (left) + overlay (right)
# - MODE: "detection" (TEST) or "prediction" (FUTURE)
# - Loads tensors once per (MODEL_DESC, MODE) and caches for repeated plotting
# - Shows figure (1 row x 2 cols) AND saves 2 PNG files (heatmap / overlay)

import os, random, warnings
warnings.filterwarnings("ignore")
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

# -----------------------
# Config
# -----------------------
IMG_SIZE = 448                                         # default side length (px) for _make_fov_mask
TENSORS_ROOT = Path("./tensors")                       # saved tensors, one folder per model/set
OUT_ROOT = Path("./rep_gradcam_lr_manuscript_single")  # output root, one folder per model
OUT_ROOT.mkdir(parents=True, exist_ok=True)

# -----------------------
# Cache (for repeated use in notebook)
# -----------------------
# (model_desc, MODE) -> pack dict with keys origin, salmap, df, set_name, fov01, src_dir.
# Nothing here evicts entries: clear _CACHE (or restart) after re-exporting tensors.
_CACHE = dict()

# -----------------------
# Utils
# -----------------------
def _norm_eye_label(x):
    """Map a raw eye-orientation value to ``"R"``, ``"L"`` or ``None``.

    Recognised (case-insensitive, whitespace-stripped) spellings include
    English ("right"/"left", "r"/"l", "rt"/"lt"), Korean, OD/OS and
    boolean / 0-1 codes. Any other value that parses as a number maps to
    "R" if its integer part is 1 and to "L" otherwise.

    Returns:
        "R", "L", or None for ``None``, NaN, infinities and unparseable text.
    """
    if x is None:
        return None
    s = str(x).strip().lower()
    if s in {"r", "right", "rt", "오", "오른", "오른쪽", "1", "true", "od", "o.d"}:
        return "R"
    if s in {"l", "left", "lt", "왼", "왼쪽", "0", "false", "os", "o.s"}:
        return "L"
    try:
        return "R" if int(float(s)) == 1 else "L"
    except (ValueError, OverflowError):
        # float() rejects non-numeric text; int() rejects NaN (ValueError)
        # and +/-inf (OverflowError). Nothing else can be raised here.
        return None

def _ensure_hw01(arr_like):
    """Min-max normalise a tensor or array-like to float32 values in [0, 1].

    The result is a CPU tensor with the input's shape. A constant input
    (max == min) yields all zeros.
    """
    if isinstance(arr_like, np.ndarray) and any(stride < 0 for stride in arr_like.strides):
        # torch.as_tensor cannot wrap negative-stride numpy views
        # (e.g. arr[::-1], np.flip), so copy them to a regular layout first.
        arr_like = arr_like.copy()
    x = torch.as_tensor(arr_like, dtype=torch.float32).detach().cpu()
    mn, mx = float(x.min()), float(x.max())
    if mx > mn:
        x = (x - mn) / (mx - mn)
    else:
        x = x * 0.0
    return x

def _make_fov_mask(h=IMG_SIZE, w=IMG_SIZE):
    """Circular FOV mask as a float tensor of 0.0/1.0, centred, radius min(h, w) // 2 (boundary inclusive)."""
    yy, xx = np.ogrid[:h, :w]
    inside = (yy - h // 2) ** 2 + (xx - w // 2) ** 2 <= (min(h, w) // 2) ** 2
    return torch.from_numpy(inside.astype(np.uint8)).float()

def _apply_fov_and_renorm(cam_hw01: torch.Tensor, fov01: torch.Tensor):
    """Mask a CAM by the FOV and min-max rescale it using in-FOV values only.

    Returns a new tensor clamped to [0, 1]. An empty FOV returns the input
    tensor itself; a constant in-FOV map is only clamped.
    """
    masked = cam_hw01 * fov01
    in_fov = masked[fov01 > 0]
    if not in_fov.numel():
        return cam_hw01
    lo = float(in_fov.min())
    hi = float(in_fov.max())
    if hi > lo:
        masked = (masked - lo) / (hi - lo + 1e-6)
    return masked.clamp_(0, 1)

def _overlay_cam(rgb_hw3: np.ndarray, cam_hw01: torch.Tensor, alpha=0.45, cmap_name="jet"):
    """Blend a colour-mapped CAM onto an RGB image: ``(1 - alpha) * rgb + alpha * cam``, clipped to [0, 1]."""
    colormap = plt.get_cmap(cmap_name)
    heat_rgb = colormap(cam_hw01.detach().cpu().numpy())[..., :3]
    blended = rgb_hw3 * (1 - alpha) + heat_rgb * alpha
    return np.clip(blended, 0, 1)

def _to_numpy_rgb(t3chw: torch.Tensor):
    """Convert a channel-first image tensor (expected [3, H, W]) to an [H, W, 3] numpy array clipped to [0, 1]."""
    clipped = t3chw.detach().cpu().clamp(0, 1)
    return clipped.numpy().transpose(1, 2, 0)

def _indices_by_eye(df: pd.DataFrame, eye_code: str):
    """Row positions whose ``eye_orientation`` normalises to ``eye_code``; all rows if the column is absent."""
    if "eye_orientation" not in df.columns:
        return list(range(len(df)))
    matches = []
    for row_pos, raw_label in enumerate(df["eye_orientation"].tolist()):
        if _norm_eye_label(raw_label) != eye_code:
            continue
        matches.append(row_pos)
    return matches

def _load_pack(model_desc: str, MODE: str):
    """Load (and cache) the saved tensors and metadata for one model/set.

    MODE: "detection"  -> TENSORS_ROOT/{model_desc}_test,   test_df.csv, set "TEST"
          "prediction" -> TENSORS_ROOT/{model_desc}_future, fut_df.csv,  set "FUTURE"

    Tensors are loaded onto the CPU so GPU-saved files open on CPU-only
    machines. Images, maps and rows are truncated to their common length. The
    FOV mask is sized from the saliency maps' last two dimensions. The pack is
    stored in _CACHE under (model_desc, MODE) and returned as-is on later calls.

    Returns:
        dict with keys "origin", "salmap", "df", "set_name", "fov01", "src_dir".

    Raises:
        ValueError: if MODE is not "detection" or "prediction".
        FileNotFoundError: if any of the three input files is missing.
    """
    key = (model_desc, MODE)
    if key in _CACHE:
        return _CACHE[key]

    if MODE not in {"detection", "prediction"}:
        raise ValueError("MODE must be 'detection' or 'prediction'")

    is_detection = MODE == "detection"
    sub = f"{model_desc}_test" if is_detection else f"{model_desc}_future"
    src_dir = TENSORS_ROOT / sub

    origin_p = src_dir / "origin_images.pth"
    salmap_p = src_dir / "saliency_map.pth"
    csv_p    = src_dir / ("test_df.csv" if is_detection else "fut_df.csv")

    if not (origin_p.exists() and salmap_p.exists() and csv_p.exists()):
        raise FileNotFoundError(f"Missing files under: {src_dir}")

    origin = torch.load(origin_p, map_location="cpu")  # presumably [N,3,H,W], 0..1
    salmap = torch.load(salmap_p, map_location="cpu")  # presumably [N,1,H,W]
    df     = pd.read_csv(csv_p)

    # Align lengths so row i of df matches index i of both tensors.
    n = min(len(df), salmap.size(0), origin.size(0))
    origin, salmap, df = origin[:n], salmap[:n], df.iloc[:n].reset_index(drop=True)

    set_name = "TEST" if is_detection else "FUTURE"
    # Match the FOV to the saved maps rather than assuming IMG_SIZE.
    fov01 = _make_fov_mask(int(salmap.shape[-2]), int(salmap.shape[-1]))

    pack = {"origin": origin, "salmap": salmap, "df": df, "set_name": set_name, "fov01": fov01, "src_dir": src_dir}
    _CACHE[key] = pack
    return pack

# -----------------------
# Main render function
# -----------------------
def render_one(
    model_desc: str,
    MODE: str = "detection",         # "detection" or "prediction"
    EYE: str = "L",                   # "L" or "R" (also "left", "OD", "OS", ...)
    SEED: int = 0,
    idx: int | None = None,           # explicit sample index; if None, pick one at random from SEED
    cmap: str = "jet",
    alpha: float = 0.45,
    out_prefix: str | None = None     # override the output file-name prefix
):
    """Show and save the Grad-CAM heatmap and overlay for a single image.

    Seeds ``random``, NumPy and torch with ``SEED``, loads the cached pack for
    (model_desc, MODE), selects one image of the requested eye and displays
    a 1x2 figure (heatmap | overlay). It also writes
    ``OUT_ROOT/model_desc/{prefix}_heatmap.png`` and ``..._overlay.png``.
    The default prefix is ``{SET}_{eye}_idx{idx}_seed{SEED}``.

    Raises:
        ValueError: if no image matches the eye, or ``idx`` is not one of them.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    pack = _load_pack(model_desc, MODE)
    origin, salmap, df = pack["origin"], pack["salmap"], pack["df"]
    set_name, fov01 = pack["set_name"], pack["fov01"]

    # Accept the same spellings as df["eye_orientation"] (e.g. "OS" -> L,
    # "OD" -> R, Korean labels). Previously "OS" fell through to R.
    # Keep the first-letter rule as a fallback for free-form input.
    eye_code = _norm_eye_label(EYE)
    if eye_code is None:
        eye_code = "L" if str(EYE).upper().startswith("L") else "R"
    candidate = _indices_by_eye(df, eye_code)

    if not candidate:
        raise ValueError(f"No indices for eye={eye_code} in df['eye_orientation'].")

    if idx is None:
        idx = random.Random(SEED).choice(candidate)
    elif idx not in candidate:
        raise ValueError(f"idx={idx} is not in the {eye_code}-eye subset.")

    # --- pick tensors ---
    rgb = _to_numpy_rgb(origin[idx])                                 # [H,W,3], in 0..1
    cam01 = _apply_fov_and_renorm(_ensure_hw01(salmap[idx, 0]), fov01)  # re-normalise after FOV masking
    overlay = _overlay_cam(rgb, cam01, alpha=alpha, cmap_name=cmap)
    cam_np = cam01.detach().cpu().numpy()

    # --- figure: left=heatmap, right=overlay ---
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 5), dpi=160)

    ax0.imshow(cam_np, cmap=cmap, vmin=0, vmax=1)
    ax0.set_title("Heatmap", fontsize=16)
    ax0.axis("off")

    ax1.imshow(overlay)
    ax1.set_title("Overlay", fontsize=16)
    ax1.axis("off")

    plt.tight_layout()
    plt.show()
    plt.close(fig)  # don't accumulate open figures across repeated calls

    # --- save PNGs ---
    out_dir = OUT_ROOT / model_desc
    out_dir.mkdir(parents=True, exist_ok=True)

    base = out_prefix if out_prefix is not None else f"{set_name}_{eye_code}_idx{idx}_seed{SEED}"
    heatmap_path = out_dir / f"{base}_heatmap.png"
    overlay_path = out_dir / f"{base}_overlay.png"

    plt.imsave(heatmap_path, cam_np, cmap=cmap, vmin=0, vmax=1)
    plt.imsave(overlay_path, overlay)

    print(f"[✓] Saved heatmap : {heatmap_path}")
    print(f"[✓] Saved overlay : {overlay_path}")

# -----------------------
# Usage examples (run repeatedly as you like)
# -----------------------
# render_one(model_desc="retfound_partial_ft_4", MODE="detection", EYE="L", SEED=123)

# render_one(model_desc="retfound_partial_ft_4", MODE="prediction", EYE="R", SEED=7, idx=42)

# render_one(model_desc="retfound_partial_ft_4", MODE="prediction", EYE="L", SEED=0, out_prefix="figure2_example")


In [None]:
# Example: render and save one right-eye image from the FUTURE ("prediction") set, seed 52.
render_one(model_desc="retfound_partial_ft_4", MODE="prediction", EYE="R", SEED=52)

## Schematic illustration of the explainability analysis pipeline, including saliency map generation, anatomical segmentation, and computation of regional saliency and relative saliency ratios.

In [None]:
# with random saliency + FOV outline + padding canvas
# -*- coding: utf-8 -*-
"""
Schematic generator with REAL image & masks (single example, seeded)
- Picks one real image (seeded) that has artery and vein masks (and an optic disc/cup mask, whenever any image has all three)
- Uses a procedurally generated (seeded) saliency map (NOT real model saliency)
- Saves RGBA PNGs with outside-FOV alpha=0
- Adds thin FOV circle outline
- Adds small padding around the image to avoid outline clipping

Outputs (under ./schematic_out_real):
  base_FOV.png
  saliency_full_FOV.png
  not_optic_disc_not_both_mask.png
  not_optic_disc_not_both_saliency.png
  veins_mask.png / veins_saliency.png
  arteries_mask.png / arteries_saliency.png
  both_mask.png / both_saliency.png
  optic_disc_mask.png / optic_disc_saliency.png
  overlay_image_plus_saliency.png
  USED_FILENAME.txt  (which image was used)
"""

import os, random
from pathlib import Path

import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib import cm

# =========================

SEED = 123                                         # default seed: image choice + synthetic saliency
H, W = 448, 448                                    # working resolution of image, masks, saliency
PADDING = 12                                       # border (px) so the FOV outline is not clipped
OUT_H, OUT_W = H + 2 * PADDING, W + 2 * PADDING    # padded canvas size

# AutoMorph outputs (machine-specific absolute paths; edit for your setup)
IMAGES_DIR = Path("/home/hch/opportunistic/AutoMorph_Data/Results/M0/images")
ARTERY_DIR = Path("/home/hch/opportunistic/AutoMorph_Data/Results/M2/artery_vein/artery_binary_process")
VEIN_DIR = Path("/home/hch/opportunistic/AutoMorph_Data/Results/M2/artery_vein/vein_binary_process")
ODC_RAW_DIR = Path("/home/hch/opportunistic/AutoMorph_Data/Results/M2/optic_disc_cup/raw")

OUT_DIR = Path("./schematic_out_real")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# =========================
# Utilities
# =========================
def set_seed(seed=SEED):
    """Seed Python's ``random`` module and NumPy's global RNG with ``seed``."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

def circle_fov(h=H, w=W):
    """Circular FOV mask (uint8 0/1) centred on an h x w grid with radius min(h, w) // 2."""
    cy, cx, radius = h // 2, w // 2, min(h, w) // 2
    yy, xx = np.ogrid[:h, :w]
    dist2 = (yy - cy) ** 2 + (xx - cx) ** 2
    return (dist2 <= radius * radius).astype(np.uint8)

def circle_fov_padded(h=OUT_H, w=OUT_W):
    """FOV circle with the *unpadded* radius min(H, W) // 2, centred on an h x w canvas.

    With the default padded size this is the same circle as ``circle_fov(H, W)``
    after ``pad_to_canvas``; returns a uint8 0/1 mask.
    """
    radius = min(H, W) // 2
    cy, cx = h // 2, w // 2
    yy, xx = np.ogrid[:h, :w]
    return (((yy - cy) ** 2 + (xx - cx) ** 2) <= radius * radius).astype(np.uint8)

def pad_to_canvas(arr, pad=PADDING, fill=0):
    """Return a copy of ``arr`` with a ``pad``-pixel border of ``fill`` on all four sides.

    Works for 2-D [H, W] and 3-D [H, W, C] arrays and preserves dtype.
    ``pad=0`` returns an unpadded copy.

    Raises:
        ValueError: if ``arr`` is neither 2-D nor 3-D.
    """
    if arr.ndim not in (2, 3):
        raise ValueError("Unsupported ndim in pad_to_canvas")
    h, w = arr.shape[:2]
    out = np.full((h + 2 * pad, w + 2 * pad) + arr.shape[2:], fill, arr.dtype)
    # Explicit end indices: the former `pad:-pad` slice is empty when pad == 0,
    # which made the assignment fail with a broadcast error.
    out[pad:pad + h, pad:pad + w] = arr
    return out

def load_image_resize_to_hw(img_path: Path, h=H, w=W):
    """Open an image as RGB, bilinearly resize it to (h, w) and return a uint8 [h, w, 3] array."""
    with Image.open(img_path) as src:
        resized = src.convert("RGB").resize((w, h), resample=Image.BILINEAR)
    return np.array(resized, dtype=np.uint8)

def ensure_bool_mask_from_file(path: Path, h=H, w=W):
    """Load a binary-like mask image, resize it to (h, w) and return uint8 {0, 1}.

    Resizing uses nearest-neighbour. A pixel is foreground if any colour
    channel is non-zero. The alpha channel of RGBA/LA images is ignored;
    otherwise a fully opaque alpha would turn every pixel into foreground.

    Returns:
        The mask, or None if ``path`` is None, does not exist, or cannot be
        read (a warning is printed in the last case).
    """
    if path is None or not path.exists():
        return None
    try:
        with Image.open(path) as m:
            has_alpha = m.mode in ("RGBA", "LA")
            arr = np.array(m.resize((w, h), resample=Image.NEAREST))
    except Exception as err:  # best-effort: an unreadable mask is treated as missing
        print(f"[warn] could not read mask {path}: {err}")
        return None
    if arr.ndim == 3:
        if has_alpha:
            arr = arr[..., :-1]
        return np.any(arr > 0, axis=-1).astype(np.uint8)
    return (arr > 0).astype(np.uint8)

def find_file_by_stem(root: Path, stem: str):
    """Return the first existing ``root/<stem><ext>`` trying png, jpg, jpeg, bmp, tif, tiff in order; else None."""
    extensions = (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff")
    candidates = (root / f"{stem}{ext}" for ext in extensions)
    return next((p for p in candidates if p.exists()), None)

def find_any_image_by_stem(stem: str):
    """Locate the image file for ``stem`` inside IMAGES_DIR (None if absent)."""
    return find_file_by_stem(root=IMAGES_DIR, stem=stem)

def load_optic_disc_union_by_stem(stem: str, h=H, w=W):
    """Load the raw optic disc/cup segmentation for ``stem`` as a uint8 {0, 1} mask.

    The file is read as RGB and resized to (h, w) with nearest-neighbour. A
    pixel is foreground if its red or blue channel is non-zero. This presumably
    encodes disc and cup; TODO confirm against the AutoMorph output format.
    Returns None if no file exists or it cannot be read.
    """
    path = find_file_by_stem(ODC_RAW_DIR, stem)
    if path is None:
        return None
    try:
        rgb = np.array(Image.open(path).convert("RGB").resize((w, h), resample=Image.NEAREST))
        return ((rgb[..., 0] > 0) | (rgb[..., 2] > 0)).astype(np.uint8)
    except Exception:
        return None

def to_rgba(img_rgb, fov_mask):
    """Append an alpha channel to an [H, W, 3] image: 255 inside ``fov_mask``, 0 outside.

    ``fov_mask`` is any 0/1 or bool array of shape [H, W]; a new array is returned.
    """
    h, w, _ = img_rgb.shape
    opaque = np.full((h, w, 1), 255, np.uint8)
    rgba = np.concatenate([img_rgb, opaque], axis=-1)
    rgba[..., 3] = fov_mask.astype(np.uint8) * 255
    return rgba

def save_rgba(arr_rgba, path: Path):
    """Save an RGBA array to ``path`` via PIL."""
    image = Image.fromarray(arr_rgba, mode="RGBA")
    image.save(str(path))

def add_fov_outline(rgba_img: np.ndarray, fov_mask: np.ndarray,
                    color=(0, 0, 0), thickness=2, antialias=False,
                    rgb_thickness=1, rgb_antialias=True) -> np.ndarray:
    """Draw the outer contour of ``fov_mask`` onto a copy of an RGBA image.

    The contour is drawn in two passes:
      * colour pass: RGB channels in ``color``, using ``rgb_thickness`` and
        ``rgb_antialias``. The defaults (1 px, anti-aliased) are what this
        pass always used; it never honoured ``thickness``/``antialias``.
      * alpha pass: value 255 on the alpha channel, using ``thickness`` and
        ``antialias``, so the outline pixels are opaque.

    Args:
        rgba_img: [H, W, 4] uint8 image (not modified).
        fov_mask: [H, W] 0/1 or bool mask whose external contour is drawn.

    Returns:
        A new RGBA array with the outline.
    """
    mask255 = np.ascontiguousarray(fov_mask.astype(np.uint8) * 255)
    # `[-2]` selects the contour list under both OpenCV 3 (3 return values)
    # and OpenCV 4 (2 return values).
    contours = cv2.findContours(mask255, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]

    out = np.ascontiguousarray(rgba_img.copy())

    rgb = np.ascontiguousarray(out[..., :3].copy())
    rgb_line = cv2.LINE_AA if rgb_antialias else cv2.LINE_8
    for cnt in contours:
        cv2.drawContours(rgb, [cnt], -1, color, thickness=rgb_thickness, lineType=rgb_line)
    out[..., :3] = rgb

    alpha = np.ascontiguousarray(out[..., 3].copy())
    alpha_line = cv2.LINE_AA if antialias else cv2.LINE_8
    for cnt in contours:
        cv2.drawContours(alpha, [cnt], -1, 255, thickness=thickness, lineType=alpha_line)
    out[..., 3] = alpha
    return out
# --------------------------------------

def smooth_random_saliency(h=H, w=W, octaves=5, base_sigma=7.0):
    """Synthetic smooth saliency field in [0, 1) for illustrations (not model output).

    Uniform float32 noise from NumPy's global RNG (reproducible after
    set_seed) is Gaussian-blurred ``octaves`` times with sigma
    base_sigma * 1, base_sigma * 2, ..., then min-max scaled.
    """
    field = np.random.rand(h, w).astype(np.float32)
    for k in range(1, octaves + 1):
        sigma = base_sigma * k
        field = cv2.GaussianBlur(field, (0, 0), sigmaX=sigma, sigmaY=sigma)
    field = field - field.min()
    field = field / (field.max() + 1e-6)
    return field

def colorize_saliency(sal, cmap="jet"):
    """Map a [H, W] saliency array in [0, 1] through a matplotlib colormap to uint8 RGB [H, W, 3]."""
    try:
        colormap = plt.colormaps[cmap]   # colormap registry (newer matplotlib)
    except Exception:
        colormap = cm.get_cmap(cmap)     # fallback for older matplotlib
    rgba = colormap(sal)
    return (rgba[..., :3] * 255).astype(np.uint8)

def render_binary_mask(mask01, fov, color_one=(211, 211, 211), color_zero=(255,255,255)):
    """Render a region mask as an RGBA panel.

    Pixels where ``mask01 & fov`` (bitwise, after a uint8 cast) is non-zero
    get ``color_one``; every other pixel gets ``color_zero``. Alpha is
    ``fov * 255``, so everything outside the FOV is transparent.
    """
    visible = (mask01.astype(np.uint8) & fov.astype(np.uint8)).astype(bool)
    rows, cols = visible.shape
    canvas = np.empty((rows, cols, 4), np.uint8)
    canvas[..., :3] = np.array(color_zero, np.uint8)
    canvas[..., 3] = (fov * 255).astype(np.uint8)
    for channel in range(3):
        canvas[visible, channel] = color_one[channel]
    return canvas

def render_saliency_in_region(sal_rgb, region01, fov):
    """Return an RGBA copy of ``sal_rgb`` that is opaque only where ``region01 & fov`` is 1.

    The colour channels are copied unchanged. Alpha is 255 where the bitwise
    AND of the uint8-cast masks is 1, and 0 elsewhere.
    """
    h, w, _ = sal_rgb.shape
    shown = region01.astype(np.uint8) & fov.astype(np.uint8)
    rgba = np.concatenate([sal_rgb, np.zeros((h, w, 1), np.uint8)], axis=-1)
    rgba[..., 3] = (shown * 255).astype(np.uint8)
    return rgba

def logical_or(*arrs):
    """Element-wise OR of the non-None integer masks in ``arrs``, as uint8 {0, 1}.

    ``None`` entries are skipped; returns None when every entry is None.
    The result is always a new uint8 array clipped to {0, 1}, even for a
    single input, so callers can safely compute ``1 - result``. Previously a
    lone input was returned as-is: un-normalised and aliased to the caller's
    array.
    """
    out = None
    for a in arrs:
        if a is None:
            continue
        merged = a if out is None else (out | a)
        out = np.clip(merged, 0, 1).astype(np.uint8)
    return out

# =========================
# Candidate selection
# =========================
def collect_stems(root: Path):
    """Return the set of file-name stems of the regular files directly inside ``root``."""
    return {entry.stem for entry in root.glob("*") if entry.is_file()}

def pick_one_filename_with_all_masks(seed=SEED):
    """Pick ``(stem, image_path)`` for one fundus image that has the required masks.

    Preference: stems present in the artery, vein and optic-disc/cup folders.
    If there are none, stems present in the artery and vein folders only are
    used. A stem is kept only if a matching image exists in IMAGES_DIR.

    The choice is reproducible for a given ``seed``: candidates are sorted
    before sampling. Iteration order of a ``set`` of strings changes between
    interpreter runs (string hash randomisation), so sampling from it directly
    could pick a different image with the same seed.

    Raises:
        FileNotFoundError: if no usable stem is found.
    """
    art_stems = collect_stems(ARTERY_DIR)
    vein_stems = collect_stems(VEIN_DIR)
    odc_stems = collect_stems(ODC_RAW_DIR)

    common = sorted(art_stems & vein_stems & odc_stems)
    if not common:
        common = sorted(art_stems & vein_stems)

    valid = []
    for stem in common:
        img_path = find_any_image_by_stem(stem)
        if img_path is not None:
            valid.append((stem, img_path))
    if not valid:
        raise FileNotFoundError("No common stem with image + artery/vein (+/- ODC). Check paths.")
    rng = random.Random(seed)
    stem, imgp = rng.choice(valid)
    return stem, imgp

# =========================
# Main
# =========================
def main(seed=SEED):
    """Build every schematic panel for one real fundus image and save it to OUT_DIR.

    Steps:
      1. Seed ``random`` / ``np.random`` and pick one image stem with masks
         (stem and path are written to USED_FILENAME.txt).
      2. Load the image (resized to H x W) and its artery, vein and
         optic-disc-union-cup masks; ``both`` = arteries OR veins.
      3. Save the FOV-cropped image and a synthetic (seeded, NOT
         model-derived) saliency map.
      4. For each region (outside vessels and optic disc, veins, arteries,
         both, optic disc) save a grey mask panel and the saliency
         restricted to that region.
      5. Save a 55/45 blend of image and saliency.

    Every panel is padded by PADDING pixels and gets a FOV outline; pixels
    outside the FOV are transparent.
    """
    set_seed(seed)

    stem, img_path = pick_one_filename_with_all_masks(seed)
    with open(OUT_DIR / "USED_FILENAME.txt", "w", encoding="utf-8") as f:
        f.write(f"{stem}\n{img_path}\n")

    img = load_image_resize_to_hw(img_path, H, W)

    art_p = find_file_by_stem(ARTERY_DIR, stem)
    vein_p = find_file_by_stem(VEIN_DIR, stem)
    odc = load_optic_disc_union_by_stem(stem, H, W)   # optic disc ∪ cup

    arteries = ensure_bool_mask_from_file(art_p, H, W) if art_p else None
    veins = ensure_bool_mask_from_file(vein_p, H, W) if vein_p else None
    both = None
    if arteries is not None and veins is not None:
        both = np.clip(arteries | veins, 0, 1).astype(np.uint8)

    fov_small = circle_fov(H, W)               # FOV on the H x W working grid
    fov_big = circle_fov_padded(OUT_H, OUT_W)  # same circle on the padded canvas

    def save_padded(rgba_small, filename):
        """Pad an (H, W, 4) panel to the canvas, draw the FOV outline and save it."""
        rgba = pad_to_canvas(rgba_small, PADDING, fill=0)
        rgba = add_fov_outline(rgba, fov_big, color=(0, 0, 0), thickness=1)
        Image.fromarray(rgba, mode="RGBA").save(str(OUT_DIR / filename))

    save_padded(to_rgba(img, fov_small), "base_FOV.png")

    sal = smooth_random_saliency(H, W, octaves=3)
    sal_rgb = colorize_saliency(sal, cmap="jet")       # (H,W,3)
    save_padded(to_rgba(sal_rgb, fov_small), "saliency_full_FOV.png")

    # Background region: inside the FOV but outside vessels and optic disc.
    zeros = np.zeros_like(fov_small)
    excluded = logical_or(both if both is not None else zeros,
                          odc if odc is not None else zeros)
    not_od_not_both = (fov_small & (1 - excluded)).astype(np.uint8)

    def do_region(name, mask01):
        """Save the grey mask panel and the region-restricted saliency panel."""
        save_padded(render_binary_mask(mask01, fov_small, color_one=(211, 211, 211), color_zero=(255, 255, 255)),
                    f"{name}_mask.png")
        save_padded(render_saliency_in_region(sal_rgb, mask01, fov_small), f"{name}_saliency.png")

    # (3)(4) not_optic_disc_not_both
    do_region("not_optic_disc_not_both", not_od_not_both)

    # (5) Repeat for each anatomical region that was loaded.
    for name, region in (("veins", veins), ("arteries", arteries), ("both", both), ("optic_disc", odc)):
        if region is not None:
            do_region(name, (region & fov_small).astype(np.uint8))

    overlay = (0.55*img.astype(np.float32) + 0.45*sal_rgb.astype(np.float32)).clip(0,255).astype(np.uint8)
    save_padded(to_rgba(overlay, fov_small), "overlay_image_plus_saliency.png")

    print("\n[Saved files in]", OUT_DIR.resolve())
    for p in sorted(OUT_DIR.glob("*.png")):
        print("-", p.name)


In [None]:
# Build the schematic panels; seed 11 picks the image and the synthetic saliency field.
main(11)