In [2]:
# file: scripts/export_second_image_fixations_heatmaps.py

from __future__ import annotations

import argparse
import json
import math
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages
from PIL import Image

try:
    from scipy.ndimage import gaussian_filter  # type: ignore
except Exception:  # pragma: no cover
    gaussian_filter = None


CONDITIONS: Tuple[str, ...] = ("full", "central", "peripheral")
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")


@dataclass(frozen=True)
class PairKey:
    a: str
    b: str

    @property
    def as_str(self) -> str:
        return f"{self.a}__{self.b}"


@dataclass
class TrialRow:
    condition: str
    pair: PairKey
    first_image: str
    second_image: str
    correct: Optional[bool]
    raw: Dict[str, Any]


def _normalize_image_name(name: str) -> str:
    s = str(name).strip().lower()
    s = re.sub(r"\.(jpg|jpeg|png|bmp|tif|tiff)$", "", s)
    return s


def _clean_condition(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s if s in CONDITIONS else None


def _as_bool(x: Any) -> Optional[bool]:
    if x is None:
        return None
    if isinstance(x, bool):
        return x
    if isinstance(x, (int, np.integer)):
        if x in (0, 1):
            return bool(x)
        return None
    if isinstance(x, str):
        v = x.strip().lower()
        if v in ("true", "t", "yes", "y", "correct", "right", "1"):
            return True
        if v in ("false", "f", "no", "n", "incorrect", "wrong", "0"):
            return False
    return None


def _normalize_answer(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s or None


def _detect_correct(item: Dict[str, Any]) -> Optional[bool]:
    for k in ("acc", "correct", "is_correct", "response_correct", "trial_correct", "accuracy"):
        if k in item:
            b = _as_bool(item.get(k))
            if b is not None:
                return b

    subj = _normalize_answer(item.get("subj_answer"))
    corr = _normalize_answer(item.get("correct_response"))
    if subj is not None and corr is not None:
        return subj == corr
    return None


def _find_testing_json(root: Path) -> Path:
    hits = list(root.rglob("testing.json"))
    if not hits:
        raise FileNotFoundError(f"❌ testing.json not found under: {root}")
    hits.sort(key=lambda p: (len(p.parts), str(p)))
    return hits[0]


def _find_image_file(root: Path, image_name: str) -> Optional[Path]:
    base = _normalize_image_name(image_name)
    candidates: List[Path] = []
    for p in root.rglob("*"):
        if not p.is_file():
            continue
        if p.suffix.lower() not in IMAGE_EXTS:
            continue
        if _normalize_image_name(p.name) == base or _normalize_image_name(p.stem) == base:
            candidates.append(p)
    if not candidates:
        return None
    candidates.sort(key=lambda p: (len(p.parts), str(p)))
    return candidates[0]


def _iter_trials(data: Sequence[Dict[str, Any]]) -> Iterable[TrialRow]:
    for item in data:
        cond = _clean_condition(item.get("viewing_condition") or item.get("condition"))
        if cond is None:
            continue

        first = item.get("first_image")
        second = item.get("second_image")
        if not first or not second:
            continue

        first_s = str(first).strip()
        second_s = str(second).strip()

        a = _normalize_image_name(first_s)
        b = _normalize_image_name(second_s)
        pair = PairKey(*sorted((a, b)))

        yield TrialRow(
            condition=cond,
            pair=pair,
            first_image=first_s,
            second_image=second_s,
            correct=_detect_correct(item),
            raw=item,
        )


def _extract_fixations(item: Dict[str, Any]) -> np.ndarray:
    xs = item.get("fix_x")
    ys = item.get("fix_y")
    if not isinstance(xs, (list, tuple)) or not isinstance(ys, (list, tuple)):
        return np.zeros((0, 2), dtype=float)
    if len(xs) != len(ys) or len(xs) == 0:
        return np.zeros((0, 2), dtype=float)
    try:
        arr = np.column_stack([np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)])
        arr = arr[np.isfinite(arr).all(axis=1)]
        return arr
    except Exception:
        return np.zeros((0, 2), dtype=float)


def _screen_dims(item: Dict[str, Any], default: Tuple[int, int]) -> Tuple[int, int]:
    for kw, kh in (
        ("screen_width", "screen_height"),
        ("window_width", "window_height"),
        ("display_width", "display_height"),
        ("screenW", "screenH"),
    ):
        if kw in item and kh in item:
            try:
                w = int(float(item[kw]))
                h = int(float(item[kh]))
                if w > 0 and h > 0:
                    return w, h
            except Exception:
                pass
    return default


def _image_box_extent(screen_w: int, screen_h: int, half_box: int) -> Tuple[int, int, int, int]:
    """
    Matches your snippet:
    l, t, r, b = (screen_w//2 - half_box, screen_h//2 - half_box,
                  screen_w//2 + half_box, screen_h//2 + half_box)
    """
    l = screen_w // 2 - half_box
    t = screen_h // 2 - half_box
    r = screen_w // 2 + half_box
    b = screen_h // 2 + half_box
    return l, t, r, b


def _load_image_resized(path: Path, size_px: int) -> np.ndarray:
    img = Image.open(path).convert("L")
    img = img.resize((size_px, size_px), resample=Image.BICUBIC)
    return np.asarray(img)


def _fdm_in_box(
    pts: np.ndarray,
    l: int,
    t: int,
    r: int,
    b: int,
    size_px: int,
    sigma: float,
) -> np.ndarray:
    if len(pts) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    in_box = (pts[:, 0] >= l) & (pts[:, 0] <= r) & (pts[:, 1] >= t) & (pts[:, 1] <= b)
    p = pts[in_box]
    if len(p) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    x = (p[:, 0] - l) / max(1e-9, (r - l)) * (size_px - 1)
    y = (p[:, 1] - t) / max(1e-9, (b - t)) * (size_px - 1)
    x = np.clip(x, 0, size_px - 1)
    y = np.clip(y, 0, size_px - 1)

    heat, _, _ = np.histogram2d(y, x, bins=[size_px, size_px], range=[[0, size_px], [0, size_px]])
    heat = heat.astype(float)

    if gaussian_filter is not None:
        heat = gaussian_filter(heat, sigma=sigma)
    else:
        k = int(max(3, math.ceil(sigma * 3)) * 2 + 1)
        ax = np.arange(k) - k // 2
        kernel = np.exp(-(ax**2) / (2 * sigma**2))
        kernel /= kernel.sum()
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=0, arr=heat)
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=1, arr=heat)

    s = heat.sum()
    if s > 0:
        heat /= s  # density (matches your PDF-ish colorbar scale)
    return heat


def _aggregate_accuracy(trials: Sequence[TrialRow]) -> pd.DataFrame:
    rows: List[Dict[str, Any]] = []
    for t in trials:
        if t.correct is None:
            continue
        rows.append(
            {"condition": t.condition, "pair": t.pair.as_str, "correct": int(bool(t.correct))}
        )
    df = pd.DataFrame(rows)
    if df.empty:
        raise ValueError("❌ No trials with correctness detected (acc / subj_answer vs correct_response).")

    return (
        df.groupby(["condition", "pair"], as_index=False)
        .agg(N=("correct", "size"), right=("correct", "sum"))
        .assign(wrong=lambda x: x["N"] - x["right"])
        .assign(acc=lambda x: x["right"] / x["N"])
    )


def _rank_pairs_within_condition(agg: pd.DataFrame) -> pd.DataFrame:
    def ranker(group: pd.DataFrame) -> pd.DataFrame:
        g_inc = group.sort_values(["acc", "pair"], ascending=[True, True]).reset_index(drop=True)
        g_inc["rank_most_incorrect"] = np.arange(1, len(g_inc) + 1)

        g_cor = group.sort_values(["acc", "pair"], ascending=[False, True]).reset_index(drop=True)
        g_cor["rank_most_correct"] = np.arange(1, len(g_cor) + 1)

        merged = g_inc.merge(g_cor[["pair", "rank_most_correct"]], on="pair", how="left")
        merged["n_pairs_in_condition"] = len(group)
        return merged

    out = agg.groupby("condition", group_keys=False).apply(ranker).reset_index(drop=True)
    out["condition"] = pd.Categorical(out["condition"], categories=list(CONDITIONS), ordered=True)
    return out


def _collect_trials(all_trials: Sequence[TrialRow], condition: str, pair_str: str) -> List[TrialRow]:
    return [t for t in all_trials if t.condition == condition and t.pair.as_str == pair_str]


def _select_top_k(agg_ranked: pd.DataFrame, k: int) -> Dict[str, pd.DataFrame]:
    out: Dict[str, pd.DataFrame] = {}
    for cond in CONDITIONS:
        sub = agg_ranked[agg_ranked["condition"] == cond].copy()
        out[cond] = sub.sort_values(["acc", "pair"], ascending=[False, True]).head(k).reset_index(drop=True)
    return out


def _select_bottom_k(agg_ranked: pd.DataFrame, k: int) -> Dict[str, pd.DataFrame]:
    out: Dict[str, pd.DataFrame] = {}
    for cond in CONDITIONS:
        sub = agg_ranked[agg_ranked["condition"] == cond].copy()
        out[cond] = sub.sort_values(["acc", "pair"], ascending=[True, True]).head(k).reset_index(drop=True)
    return out


def _plot_fixations_only(
    ax: plt.Axes,
    img_arr: np.ndarray,
    pts: np.ndarray,
    screen_w: int,
    screen_h: int,
    extent: Tuple[int, int, int, int],  # (l, t, r, b)
    title: str,
) -> None:
    l, t, r, b = extent
    ax.imshow(img_arr, cmap="gray", origin="upper", extent=(l, r, b, t))

    if len(pts) > 0:
        ax.scatter(
            pts[:, 0],
            pts[:, 1],
            s=30,
            facecolors=(1.0, 0.42, 0.42, 0.50),  # like #FF6B6B, but RGBA
            edgecolors=(1.0, 1.0, 1.0, 0.80),
            linewidths=0.8,
        )

    ax.set_xlim(0, screen_w)
    ax.set_ylim(screen_h, 0)
    ax.set_xlabel("x (screen px)")
    ax.set_ylabel("y (screen px)")
    ax.set_title(title, fontsize=10)
    ax.grid(False)


def _plot_heatmap_only(
    ax: plt.Axes,
    img_arr: np.ndarray,
    heat: np.ndarray,
    screen_w: int,
    screen_h: int,
    extent: Tuple[int, int, int, int],  # (l, t, r, b)
    title: str,
    cmap: str,
    alpha: float,
    vmax_mode: str,
) -> Any:
    l, t, r, b = extent
    ax.imshow(img_arr, cmap="gray", origin="upper", extent=(l, r, b, t))

    if vmax_mode == "p99":
        vmax = float(np.quantile(heat, 0.99)) if heat.max() > 0 else 1.0
    else:
        vmax = float(heat.max()) if heat.max() > 0 else 1.0

    hm = ax.imshow(
        heat,
        origin="upper",
        extent=(l, r, b, t),
        cmap=cmap,
        alpha=alpha,
        vmin=0.0,
        vmax=vmax,
    )

    ax.set_xlim(0, screen_w)
    ax.set_ylim(screen_h, 0)
    ax.set_xlabel("x (screen px)")
    ax.set_ylabel("y (screen px)")
    ax.set_title(title, fontsize=10)
    ax.grid(False)
    return hm


def export_pdfs(
    root_path: str,
    output_dir: str,
    k: int = 3,
    half_box: int = 310,
    screen_default: Tuple[int, int] = (1000, 800),
    sigma: float = 18.0,
    heatmap_alpha: float = 0.70,
    heatmap_cmap: str = "turbo",
    vmax_mode: str = "p99",
) -> Tuple[Path, Path]:
    root = Path(root_path).expanduser().resolve()
    out_dir = Path(output_dir).expanduser().resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    json_path = _find_testing_json(root)
    data = json.loads(json_path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("❌ testing.json must be a list of objects.")

    all_trials = list(_iter_trials(data))
    agg = _aggregate_accuracy(all_trials)
    agg_ranked = _rank_pairs_within_condition(agg)

    agg_ranked.to_csv(out_dir / "accuracy_by_pair_condition.csv", index=False)

    top = _select_top_k(agg_ranked, k=k)
    bottom = _select_bottom_k(agg_ranked, k=k)

    top_pdf = out_dir / f"TOP{k}_second_image_fix_heat.pdf"
    bottom_pdf = out_dir / f"BOTTOM{k}_second_image_fix_heat.pdf"

    def _page_for(rank_i: int, group: str, selection: Dict[str, pd.DataFrame], pdf: PdfPages) -> None:
        fig, axes = plt.subplots(2, 3, figsize=(18, 10), constrained_layout=True)
        last_hm = None

        for col, cond in enumerate(CONDITIONS):
            ax_fix = axes[0, col]
            ax_hm = axes[1, col]

            sub = selection[cond]
            if rank_i >= len(sub):
                ax_fix.axis("off")
                ax_hm.axis("off")
                continue

            row = sub.iloc[rank_i]
            pair_str = str(row["pair"])
            trials = _collect_trials(all_trials, cond, pair_str)
            if not trials:
                ax_fix.axis("off")
                ax_hm.axis("off")
                continue

            # second image only
            img_path = _find_image_file(root, trials[0].second_image)
            if img_path is None:
                ax_fix.set_title(f"{cond.upper()} • missing second image", fontsize=10)
                ax_hm.set_title(f"{cond.upper()} • missing second image", fontsize=10)
                ax_fix.axis("off")
                ax_hm.axis("off")
                continue

            sw, sh = screen_default
            for t in trials[:3]:
                sw, sh = _screen_dims(t.raw, (sw, sh))

            l, ttop, r, b = _image_box_extent(sw, sh, half_box=half_box)
            box_size = int(half_box * 2)

            img_arr = _load_image_resized(img_path, size_px=box_size)

            pts_list = [_extract_fixations(t.raw) for t in trials]
            pts = np.vstack([p for p in pts_list if len(p) > 0]) if any(len(p) > 0 for p in pts_list) else np.zeros((0, 2))

            heat = _fdm_in_box(pts, l=l, t=ttop, r=r, b=b, size_px=box_size, sigma=sigma)

            acc = float(row["acc"])
            right = int(row["right"])
            wrong = int(row["wrong"])
            N = int(row["N"])
            n_pairs = int(row["n_pairs_in_condition"])
            pair_rank = int(row["rank_most_correct"]) if group == "TOP" else int(row["rank_most_incorrect"])

            title_base = f"Testing • {cond.upper()} • Acc: {acc*100:.1f}% • Pair {pair_rank}/{n_pairs} • N={N}"
            title_fix = title_base + f"\nSecond Image • Fixations • right={right} wrong={wrong} • Fix={len(pts)}"
            title_hm = title_base + f"\nSecond Image • Density Heatmap • right={right} wrong={wrong} • Fix={len(pts)}"

            _plot_fixations_only(
                ax=ax_fix,
                img_arr=img_arr,
                pts=pts,
                screen_w=sw,
                screen_h=sh,
                extent=(l, ttop, r, b),
                title=title_fix,
            )
            last_hm = _plot_heatmap_only(
                ax=ax_hm,
                img_arr=img_arr,
                heat=heat,
                screen_w=sw,
                screen_h=sh,
                extent=(l, ttop, r, b),
                title=title_hm,
                cmap=heatmap_cmap,
                alpha=heatmap_alpha,
                vmax_mode=vmax_mode,
            )

        fig.suptitle(
            f"Testing Phase • {group} {rank_i+1}/{k} • Second Image Only • Fixations + Heatmap",
            fontsize=14,
        )

        if last_hm is not None:
            cbar = fig.colorbar(last_hm, ax=axes.ravel().tolist(), shrink=0.85, pad=0.02)
            cbar.set_label("Fixation Density", rotation=90)

        pdf.savefig(fig)
        plt.close(fig)

    with PdfPages(top_pdf) as pdf:
        for i in range(k):
            _page_for(i, "TOP", top, pdf)

    with PdfPages(bottom_pdf) as pdf:
        for i in range(k):
            _page_for(i, "BOTTOM", bottom, pdf)

    return top_pdf, bottom_pdf


def _running_in_notebook(argv: List[str]) -> bool:
    return "-f" in argv or "ipykernel" in argv[0].lower()


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Export PDFs (TOPk and BOTTOMk), each page has 6 panels: 3 conditions × (fixations row + heatmap row), second image only."
    )
    parser.add_argument("--root", default=".", help="Folder to search for testing.json and images.")
    parser.add_argument("--out", default="plots", help="Output directory.")
    parser.add_argument("--k", type=int, default=3)
    parser.add_argument("--half-box", type=int, default=310, help="Half side of centered image box in screen px.")
    parser.add_argument("--screen-w", type=int, default=1000)
    parser.add_argument("--screen-h", type=int, default=800)
    parser.add_argument("--sigma", type=float, default=18.0)
    parser.add_argument("--heatmap-alpha", type=float, default=0.70)
    parser.add_argument("--heatmap-cmap", default="turbo")
    parser.add_argument("--vmax-mode", default="p99", choices=["p99", "max"])

    argv = sys.argv[1:]
    if _running_in_notebook(sys.argv):
        args = parser.parse_args([])  # ignore Jupyter argv
    else:
        args, _unknown = parser.parse_known_args(argv)

    top_pdf, bottom_pdf = export_pdfs(
        root_path=args.root,
        output_dir=args.out,
        k=int(args.k),
        half_box=int(args.half_box),
        screen_default=(int(args.screen_w), int(args.screen_h)),
        sigma=float(args.sigma),
        heatmap_alpha=float(args.heatmap_alpha),
        heatmap_cmap=str(args.heatmap_cmap),
        vmax_mode=str(args.vmax_mode),
    )
    print(f"✅ TOP PDF: {top_pdf}")
    print(f"✅ BOTTOM PDF: {bottom_pdf}")


if __name__ == "__main__":
    main()


  out = agg.groupby("condition", group_keys=False).apply(ranker).reset_index(drop=True)


✅ TOP PDF: /Users/daisybuathatseephol/Documents/three_json_output/plots/TOP3_second_image_fix_heat.pdf
✅ BOTTOM PDF: /Users/daisybuathatseephol/Documents/three_json_output/plots/BOTTOM3_second_image_fix_heat.pdf


In [3]:
# file: scripts/export_second_image_overlay_top_bottom_k.py

from __future__ import annotations

import argparse
import json
import math
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages
from PIL import Image

try:
    from scipy.ndimage import gaussian_filter  # type: ignore
except Exception:  # pragma: no cover
    gaussian_filter = None


CONDITIONS: Tuple[str, ...] = ("full", "central", "peripheral")
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")


@dataclass(frozen=True)
class PairKey:
    a: str
    b: str

    @property
    def as_str(self) -> str:
        return f"{self.a}__{self.b}"


@dataclass
class TrialRow:
    condition: str
    pair: PairKey
    first_image: str
    second_image: str
    correct: Optional[bool]
    raw: Dict[str, Any]


def _normalize_image_name(name: str) -> str:
    s = str(name).strip().lower()
    s = re.sub(r"\.(jpg|jpeg|png|bmp|tif|tiff)$", "", s)
    return s


def _clean_condition(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s if s in CONDITIONS else None


def _as_bool(x: Any) -> Optional[bool]:
    if x is None:
        return None
    if isinstance(x, bool):
        return x
    if isinstance(x, (int, np.integer)) and x in (0, 1):
        return bool(x)
    if isinstance(x, str):
        v = x.strip().lower()
        if v in ("true", "t", "yes", "y", "correct", "right", "1"):
            return True
        if v in ("false", "f", "no", "n", "incorrect", "wrong", "0"):
            return False
    return None


def _normalize_answer(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s or None


def _detect_correct(item: Dict[str, Any]) -> Optional[bool]:
    for k in ("acc", "correct", "is_correct", "response_correct", "trial_correct", "accuracy"):
        if k in item:
            b = _as_bool(item.get(k))
            if b is not None:
                return b

    subj = _normalize_answer(item.get("subj_answer"))
    corr = _normalize_answer(item.get("correct_response"))
    if subj is not None and corr is not None:
        return subj == corr
    return None


def _find_testing_json(root: Path) -> Path:
    hits = list(root.rglob("testing.json"))
    if not hits:
        raise FileNotFoundError(f"❌ testing.json not found under: {root}")
    hits.sort(key=lambda p: (len(p.parts), str(p)))
    return hits[0]


def _find_image_file(root: Path, image_name: str) -> Optional[Path]:
    base = _normalize_image_name(image_name)
    candidates: List[Path] = []
    for p in root.rglob("*"):
        if not p.is_file():
            continue
        if p.suffix.lower() not in IMAGE_EXTS:
            continue
        if _normalize_image_name(p.name) == base or _normalize_image_name(p.stem) == base:
            candidates.append(p)
    if not candidates:
        return None
    candidates.sort(key=lambda p: (len(p.parts), str(p)))
    return candidates[0]


def _iter_trials(data: Sequence[Dict[str, Any]]) -> Iterable[TrialRow]:
    for item in data:
        cond = _clean_condition(item.get("viewing_condition") or item.get("condition"))
        if cond is None:
            continue

        first = item.get("first_image")
        second = item.get("second_image")
        if not first or not second:
            continue

        first_s = str(first).strip()
        second_s = str(second).strip()

        a = _normalize_image_name(first_s)
        b = _normalize_image_name(second_s)
        pair = PairKey(*sorted((a, b)))

        yield TrialRow(
            condition=cond,
            pair=pair,
            first_image=first_s,
            second_image=second_s,
            correct=_detect_correct(item),
            raw=item,
        )


def _extract_fixations(item: Dict[str, Any]) -> np.ndarray:
    xs = item.get("fix_x")
    ys = item.get("fix_y")
    if not isinstance(xs, (list, tuple)) or not isinstance(ys, (list, tuple)):
        return np.zeros((0, 2), dtype=float)
    if len(xs) != len(ys) or len(xs) == 0:
        return np.zeros((0, 2), dtype=float)
    try:
        arr = np.column_stack([np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)])
        arr = arr[np.isfinite(arr).all(axis=1)]
        return arr
    except Exception:
        return np.zeros((0, 2), dtype=float)


def _screen_dims(item: Dict[str, Any], default: Tuple[int, int]) -> Tuple[int, int]:
    for kw, kh in (
        ("screen_width", "screen_height"),
        ("window_width", "window_height"),
        ("display_width", "display_height"),
        ("screenW", "screenH"),
    ):
        if kw in item and kh in item:
            try:
                w = int(float(item[kw]))
                h = int(float(item[kh]))
                if w > 0 and h > 0:
                    return w, h
            except Exception:
                pass
    return default


def _image_box_extent(screen_w: int, screen_h: int, half_box: int) -> Tuple[int, int, int, int]:
    l = screen_w // 2 - half_box
    t = screen_h // 2 - half_box
    r = screen_w // 2 + half_box
    b = screen_h // 2 + half_box
    return l, t, r, b


def _load_image_resized(path: Path, size_px: int) -> np.ndarray:
    img = Image.open(path).convert("L")
    img = img.resize((size_px, size_px), resample=Image.BICUBIC)
    return np.asarray(img)


def _fdm_in_box(
    pts: np.ndarray,
    l: int,
    t: int,
    r: int,
    b: int,
    size_px: int,
    sigma: float,
) -> np.ndarray:
    if len(pts) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    in_box = (pts[:, 0] >= l) & (pts[:, 0] <= r) & (pts[:, 1] >= t) & (pts[:, 1] <= b)
    p = pts[in_box]
    if len(p) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    x = (p[:, 0] - l) / max(1e-9, (r - l)) * (size_px - 1)
    y = (p[:, 1] - t) / max(1e-9, (b - t)) * (size_px - 1)
    x = np.clip(x, 0, size_px - 1)
    y = np.clip(y, 0, size_px - 1)

    heat, _, _ = np.histogram2d(y, x, bins=[size_px, size_px], range=[[0, size_px], [0, size_px]])
    heat = heat.astype(float)

    if gaussian_filter is not None:
        heat = gaussian_filter(heat, sigma=sigma)
    else:
        k = int(max(3, math.ceil(sigma * 3)) * 2 + 1)
        ax = np.arange(k) - k // 2
        kernel = np.exp(-(ax**2) / (2 * sigma**2))
        kernel /= kernel.sum()
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=0, arr=heat)
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=1, arr=heat)

    s = heat.sum()
    if s > 0:
        heat /= s  # density
    return heat


def _aggregate_accuracy(trials: Sequence[TrialRow]) -> pd.DataFrame:
    rows: List[Dict[str, Any]] = []
    for t in trials:
        if t.correct is None:
            continue
        rows.append({"condition": t.condition, "pair": t.pair.as_str, "correct": int(bool(t.correct))})
    df = pd.DataFrame(rows)
    if df.empty:
        raise ValueError("❌ No correctness detected (acc / subj_answer vs correct_response).")

    return (
        df.groupby(["condition", "pair"], as_index=False)
        .agg(N=("correct", "size"), right=("correct", "sum"))
        .assign(wrong=lambda x: x["N"] - x["right"])
        .assign(acc=lambda x: x["right"] / x["N"])
    )


def _rank_pairs_within_condition(agg: pd.DataFrame) -> pd.DataFrame:
    def ranker(group: pd.DataFrame) -> pd.DataFrame:
        g_inc = group.sort_values(["acc", "pair"], ascending=[True, True]).reset_index(drop=True)
        g_inc["rank_most_incorrect"] = np.arange(1, len(g_inc) + 1)

        g_cor = group.sort_values(["acc", "pair"], ascending=[False, True]).reset_index(drop=True)
        g_cor["rank_most_correct"] = np.arange(1, len(g_cor) + 1)

        merged = g_inc.merge(g_cor[["pair", "rank_most_correct"]], on="pair", how="left")
        merged["n_pairs_in_condition"] = len(group)
        return merged

    out = agg.groupby("condition", group_keys=False).apply(ranker).reset_index(drop=True)
    out["condition"] = pd.Categorical(out["condition"], categories=list(CONDITIONS), ordered=True)
    return out


def _collect_trials(all_trials: Sequence[TrialRow], condition: str, pair_str: str) -> List[TrialRow]:
    return [t for t in all_trials if t.condition == condition and t.pair.as_str == pair_str]


def _select_top_k(agg_ranked: pd.DataFrame, k: int) -> Dict[str, pd.DataFrame]:
    out: Dict[str, pd.DataFrame] = {}
    for cond in CONDITIONS:
        sub = agg_ranked[agg_ranked["condition"] == cond].copy()
        out[cond] = sub.sort_values(["acc", "pair"], ascending=[False, True]).head(k).reset_index(drop=True)
    return out


def _select_bottom_k(agg_ranked: pd.DataFrame, k: int) -> Dict[str, pd.DataFrame]:
    out: Dict[str, pd.DataFrame] = {}
    for cond in CONDITIONS:
        sub = agg_ranked[agg_ranked["condition"] == cond].copy()
        out[cond] = sub.sort_values(["acc", "pair"], ascending=[True, True]).head(k).reset_index(drop=True)
    return out


def _plot_overlay(
    ax: plt.Axes,
    img_arr: np.ndarray,
    pts: np.ndarray,
    heat: np.ndarray,
    screen_w: int,
    screen_h: int,
    extent_ltrb: Tuple[int, int, int, int],  # (l,t,r,b)
    title: str,
    heatmap_cmap: str,
    heatmap_alpha: float,
    vmax_mode: str,
) -> Any:
    l, t, r, b = extent_ltrb

    ax.imshow(img_arr, cmap="gray", origin="upper", extent=(l, r, b, t))

    vmax = float(np.quantile(heat, 0.99)) if (vmax_mode == "p99" and heat.max() > 0) else float(heat.max() or 1.0)
    hm = ax.imshow(
        heat,
        origin="upper",
        extent=(l, r, b, t),
        cmap=heatmap_cmap,
        alpha=heatmap_alpha,
        vmin=0.0,
        vmax=vmax,
    )

    if len(pts) > 0:
        ax.scatter(
            pts[:, 0],
            pts[:, 1],
            s=30,
            facecolors=(1.0, 0.42, 0.42, 0.50),
            edgecolors=(1.0, 1.0, 1.0, 0.80),
            linewidths=0.8,
        )

    ax.set_xlim(0, screen_w)
    ax.set_ylim(screen_h, 0)
    ax.set_xlabel("x (screen px)")
    ax.set_ylabel("y (screen px)")
    ax.set_title(title, fontsize=10)
    ax.grid(False)
    return hm


def export_pdfs(
    root_path: str,
    output_dir: str,
    k: int = 3,
    half_box: int = 310,
    screen_default: Tuple[int, int] = (1000, 800),
    sigma: float = 18.0,
    heatmap_alpha: float = 0.70,
    heatmap_cmap: str = "turbo",
    vmax_mode: str = "p99",
    single_pdf: bool = False,
) -> Tuple[Path, Optional[Path]]:
    root = Path(root_path).expanduser().resolve()
    out_dir = Path(output_dir).expanduser().resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    json_path = _find_testing_json(root)
    data = json.loads(json_path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("❌ testing.json must be a list of objects.")

    all_trials = list(_iter_trials(data))
    agg = _aggregate_accuracy(all_trials)
    agg_ranked = _rank_pairs_within_condition(agg)
    agg_ranked.to_csv(out_dir / "accuracy_by_pair_condition.csv", index=False)

    top = _select_top_k(agg_ranked, k=k)
    bottom = _select_bottom_k(agg_ranked, k=k)

    top_pdf = out_dir / f"TOP{k}_second_image_overlay.pdf"
    bottom_pdf = out_dir / f"BOTTOM{k}_second_image_overlay.pdf"
    combined_pdf = out_dir / f"TOP{k}_BOTTOM{k}_second_image_overlay.pdf"

    def _write_group(pdf: PdfPages, group_name: str, selection: Dict[str, pd.DataFrame]) -> None:
        for rank_i in range(k):
            fig, axes = plt.subplots(1, 3, figsize=(18, 6), constrained_layout=True)
            last_hm = None

            for col, cond in enumerate(CONDITIONS):
                ax = axes[col]
                sub = selection[cond]
                if rank_i >= len(sub):
                    ax.axis("off")
                    continue

                row = sub.iloc[rank_i]
                pair_str = str(row["pair"])
                trials = _collect_trials(all_trials, cond, pair_str)
                if not trials:
                    ax.axis("off")
                    continue

                img_path = _find_image_file(root, trials[0].second_image)
                if img_path is None:
                    ax.set_title(f"{cond.upper()} • missing second image", fontsize=10)
                    ax.axis("off")
                    continue

                sw, sh = screen_default
                for t in trials[:3]:
                    sw, sh = _screen_dims(t.raw, (sw, sh))

                l, ttop, r, b = _image_box_extent(sw, sh, half_box=half_box)
                box_size = int(half_box * 2)

                img_arr = _load_image_resized(img_path, size_px=box_size)

                pts_list = [_extract_fixations(t.raw) for t in trials]
                pts = np.vstack([p for p in pts_list if len(p) > 0]) if any(len(p) > 0 for p in pts_list) else np.zeros((0, 2))

                heat = _fdm_in_box(pts, l=l, t=ttop, r=r, b=b, size_px=box_size, sigma=sigma)

                acc = float(row["acc"])
                right = int(row["right"])
                wrong = int(row["wrong"])
                N = int(row["N"])
                n_pairs = int(row["n_pairs_in_condition"])
                pair_rank = int(row["rank_most_correct"]) if group_name == "TOP" else int(row["rank_most_incorrect"])

                title = (
                    f"Testing • {cond.upper()} • Acc: {acc*100:.1f}% • Pair {pair_rank}/{n_pairs}\n"
                    f"Second Image • right={right} wrong={wrong} • N={N} • Fix={len(pts)}"
                )

                last_hm = _plot_overlay(
                    ax=ax,
                    img_arr=img_arr,
                    pts=pts,
                    heat=heat,
                    screen_w=sw,
                    screen_h=sh,
                    extent_ltrb=(l, ttop, r, b),
                    title=title,
                    heatmap_cmap=heatmap_cmap,
                    heatmap_alpha=heatmap_alpha,
                    vmax_mode=vmax_mode,
                )

            fig.suptitle(
                f"Testing Phase • {group_name} {rank_i+1}/{k} • Second Image Only • Heatmap + Fixations",
                fontsize=14,
            )

            if last_hm is not None:
                cbar = fig.colorbar(last_hm, ax=axes.ravel().tolist(), shrink=0.9, pad=0.02)
                cbar.set_label("Fixation Density", rotation=90)

            pdf.savefig(fig)
            plt.close(fig)

    if single_pdf:
        with PdfPages(combined_pdf) as pdf:
            _write_group(pdf, "TOP", top)
            _write_group(pdf, "BOTTOM", bottom)
        return combined_pdf, None

    with PdfPages(top_pdf) as pdf:
        _write_group(pdf, "TOP", top)

    with PdfPages(bottom_pdf) as pdf:
        _write_group(pdf, "BOTTOM", bottom)

    return top_pdf, bottom_pdf


def _running_in_notebook(argv: List[str]) -> bool:
    return "-f" in argv or "ipykernel" in argv[0].lower()


def main() -> None:
    parser = argparse.ArgumentParser(
        description="Export PDFs where each page has 3 panels (FULL/CENTRAL/PERIPHERAL): second image with heatmap+fixations overlaid."
    )
    parser.add_argument("--root", default=".", help="Folder to search for testing.json and images.")
    parser.add_argument("--out", default="plots", help="Output directory.")
    parser.add_argument("--k", type=int, default=3)
    parser.add_argument("--half-box", type=int, default=310)
    parser.add_argument("--screen-w", type=int, default=1000)
    parser.add_argument("--screen-h", type=int, default=800)
    parser.add_argument("--sigma", type=float, default=18.0)
    parser.add_argument("--heatmap-alpha", type=float, default=0.70)
    parser.add_argument("--heatmap-cmap", default="turbo")
    parser.add_argument("--vmax-mode", default="p99", choices=["p99", "max"])
    parser.add_argument("--single-pdf", action="store_true", help="If set, outputs one combined PDF (TOPk then BOTTOMk).")

    argv = sys.argv[1:]
    if _running_in_notebook(sys.argv):
        args = parser.parse_args([])  # ignore Jupyter argv
    else:
        args, _unknown = parser.parse_known_args(argv)

    a, b = export_pdfs(
        root_path=args.root,
        output_dir=args.out,
        k=int(args.k),
        half_box=int(args.half_box),
        screen_default=(int(args.screen_w), int(args.screen_h)),
        sigma=float(args.sigma),
        heatmap_alpha=float(args.heatmap_alpha),
        heatmap_cmap=str(args.heatmap_cmap),
        vmax_mode=str(args.vmax_mode),
        single_pdf=bool(args.single_pdf),
    )
    if b is None:
        print(f"✅ PDF: {a}")
    else:
        print(f"✅ TOP PDF: {a}")
        print(f"✅ BOTTOM PDF: {b}")


if __name__ == "__main__":
    main()


  out = agg.groupby("condition", group_keys=False).apply(ranker).reset_index(drop=True)


✅ TOP PDF: /Users/daisybuathatseephol/Documents/three_json_output/plots/TOP3_second_image_overlay.pdf
✅ BOTTOM PDF: /Users/daisybuathatseephol/Documents/three_json_output/plots/BOTTOM3_second_image_overlay.pdf


In [4]:
# file: scripts/export_second_image_overlay_ordered_accuracy.py

from __future__ import annotations

import argparse
import json
import math
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages
from PIL import Image

try:
    from scipy.ndimage import gaussian_filter  # type: ignore
except Exception:  # pragma: no cover
    gaussian_filter = None


CONDITIONS: Tuple[str, ...] = ("full", "central", "peripheral")
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")


@dataclass(frozen=True)
class PairKey:
    a: str
    b: str

    @property
    def as_str(self) -> str:
        return f"{self.a}__{self.b}"


@dataclass
class TrialRow:
    condition: str
    pair: PairKey
    first_image: str
    second_image: str
    correct: Optional[bool]
    raw: Dict[str, Any]


def _normalize_image_name(name: str) -> str:
    s = str(name).strip().lower()
    s = re.sub(r"\.(jpg|jpeg|png|bmp|tif|tiff)$", "", s)
    return s


def _clean_condition(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s if s in CONDITIONS else None


def _as_bool(x: Any) -> Optional[bool]:
    if x is None:
        return None
    if isinstance(x, bool):
        return x
    if isinstance(x, (int, np.integer)) and x in (0, 1):
        return bool(x)
    if isinstance(x, str):
        v = x.strip().lower()
        if v in ("true", "t", "yes", "y", "correct", "right", "1"):
            return True
        if v in ("false", "f", "no", "n", "incorrect", "wrong", "0"):
            return False
    return None


def _normalize_answer(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s or None


def _detect_correct(item: Dict[str, Any]) -> Optional[bool]:
    for k in ("acc", "correct", "is_correct", "response_correct", "trial_correct", "accuracy"):
        if k in item:
            b = _as_bool(item.get(k))
            if b is not None:
                return b
    subj = _normalize_answer(item.get("subj_answer"))
    corr = _normalize_answer(item.get("correct_response"))
    if subj is not None and corr is not None:
        return subj == corr
    return None


def _find_testing_json(root: Path) -> Path:
    hits = list(root.rglob("testing.json"))
    if not hits:
        raise FileNotFoundError(f"❌ testing.json not found under: {root}")
    hits.sort(key=lambda p: (len(p.parts), str(p)))
    return hits[0]


def _find_image_file(root: Path, image_name: str) -> Optional[Path]:
    base = _normalize_image_name(image_name)
    candidates: List[Path] = []
    for p in root.rglob("*"):
        if not p.is_file():
            continue
        if p.suffix.lower() not in IMAGE_EXTS:
            continue
        if _normalize_image_name(p.name) == base or _normalize_image_name(p.stem) == base:
            candidates.append(p)
    if not candidates:
        return None
    candidates.sort(key=lambda p: (len(p.parts), str(p)))
    return candidates[0]


def _iter_trials(data: Sequence[Dict[str, Any]]) -> Iterable[TrialRow]:
    for item in data:
        cond = _clean_condition(item.get("viewing_condition") or item.get("condition"))
        if cond is None:
            continue
        first = item.get("first_image")
        second = item.get("second_image")
        if not first or not second:
            continue

        first_s = str(first).strip()
        second_s = str(second).strip()
        a = _normalize_image_name(first_s)
        b = _normalize_image_name(second_s)
        pair = PairKey(*sorted((a, b)))

        yield TrialRow(
            condition=cond,
            pair=pair,
            first_image=first_s,
            second_image=second_s,
            correct=_detect_correct(item),
            raw=item,
        )


def _extract_fixations(item: Dict[str, Any]) -> np.ndarray:
    xs = item.get("fix_x")
    ys = item.get("fix_y")
    if not isinstance(xs, (list, tuple)) or not isinstance(ys, (list, tuple)):
        return np.zeros((0, 2), dtype=float)
    if len(xs) != len(ys) or len(xs) == 0:
        return np.zeros((0, 2), dtype=float)
    try:
        arr = np.column_stack([np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)])
        arr = arr[np.isfinite(arr).all(axis=1)]
        return arr
    except Exception:
        return np.zeros((0, 2), dtype=float)


def _screen_dims(item: Dict[str, Any], default: Tuple[int, int]) -> Tuple[int, int]:
    for kw, kh in (
        ("screen_width", "screen_height"),
        ("window_width", "window_height"),
        ("display_width", "display_height"),
        ("screenW", "screenH"),
    ):
        if kw in item and kh in item:
            try:
                w = int(float(item[kw]))
                h = int(float(item[kh]))
                if w > 0 and h > 0:
                    return w, h
            except Exception:
                pass
    return default


def _image_box_extent(screen_w: int, screen_h: int, half_box: int) -> Tuple[int, int, int, int]:
    l = screen_w // 2 - half_box
    t = screen_h // 2 - half_box
    r = screen_w // 2 + half_box
    b = screen_h // 2 + half_box
    return l, t, r, b


def _load_image_resized(path: Path, size_px: int) -> np.ndarray:
    img = Image.open(path).convert("L")
    img = img.resize((size_px, size_px), resample=Image.BICUBIC)
    return np.asarray(img)


def _fdm_in_box(
    pts: np.ndarray,
    l: int,
    t: int,
    r: int,
    b: int,
    size_px: int,
    sigma: float,
) -> np.ndarray:
    if len(pts) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    in_box = (pts[:, 0] >= l) & (pts[:, 0] <= r) & (pts[:, 1] >= t) & (pts[:, 1] <= b)
    p = pts[in_box]
    if len(p) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    x = (p[:, 0] - l) / max(1e-9, (r - l)) * (size_px - 1)
    y = (p[:, 1] - t) / max(1e-9, (b - t)) * (size_px - 1)
    x = np.clip(x, 0, size_px - 1)
    y = np.clip(y, 0, size_px - 1)

    heat, _, _ = np.histogram2d(y, x, bins=[size_px, size_px], range=[[0, size_px], [0, size_px]])
    heat = heat.astype(float)

    if gaussian_filter is not None:
        heat = gaussian_filter(heat, sigma=sigma)
    else:
        k = int(max(3, math.ceil(sigma * 3)) * 2 + 1)
        ax = np.arange(k) - k // 2
        kernel = np.exp(-(ax**2) / (2 * sigma**2))
        kernel /= kernel.sum()
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=0, arr=heat)
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=1, arr=heat)

    s = heat.sum()
    if s > 0:
        heat /= s  # density
    return heat


def _aggregate_accuracy(trials: Sequence[TrialRow]) -> pd.DataFrame:
    rows: List[Dict[str, Any]] = []
    for t in trials:
        if t.correct is None:
            continue
        rows.append({"condition": t.condition, "pair": t.pair.as_str, "correct": int(bool(t.correct))})

    df = pd.DataFrame(rows)
    if df.empty:
        raise ValueError("❌ No correctness detected (acc / subj_answer vs correct_response).")

    return (
        df.groupby(["condition", "pair"], as_index=False)
        .agg(N=("correct", "size"), right=("correct", "sum"))
        .assign(wrong=lambda x: x["N"] - x["right"])
        .assign(acc=lambda x: x["right"] / x["N"])
    )


def _ranked_lists_by_condition(agg: pd.DataFrame, order: str) -> Dict[str, pd.DataFrame]:
    """
    Returns per-condition ordered list with rank_in_order (1..n) and n_pairs.
    order: 'high_to_low' or 'low_to_high'
    """
    out: Dict[str, pd.DataFrame] = {}
    asc = order == "low_to_high"
    for cond in CONDITIONS:
        sub = agg[agg["condition"] == cond].copy()
        sub = sub.sort_values(["acc", "pair"], ascending=[asc, True]).reset_index(drop=True)
        sub["rank_in_order"] = np.arange(1, len(sub) + 1)
        sub["n_pairs"] = len(sub)
        out[cond] = sub
    return out


def _collect_trials(all_trials: Sequence[TrialRow], condition: str, pair_str: str) -> List[TrialRow]:
    return [t for t in all_trials if t.condition == condition and t.pair.as_str == pair_str]


def _plot_overlay(
    ax: plt.Axes,
    img_arr: np.ndarray,
    pts: np.ndarray,
    heat: np.ndarray,
    screen_w: int,
    screen_h: int,
    extent_ltrb: Tuple[int, int, int, int],  # (l,t,r,b)
    title: str,
    heatmap_cmap: str,
    heatmap_alpha: float,
    vmax_mode: str,
) -> Any:
    l, t, r, b = extent_ltrb

    ax.imshow(img_arr, cmap="gray", origin="upper", extent=(l, r, b, t))

    if vmax_mode == "p99":
        vmax = float(np.quantile(heat, 0.99)) if heat.max() > 0 else 1.0
    else:
        vmax = float(heat.max()) if heat.max() > 0 else 1.0

    hm = ax.imshow(
        heat,
        origin="upper",
        extent=(l, r, b, t),
        cmap=heatmap_cmap,
        alpha=heatmap_alpha,
        vmin=0.0,
        vmax=vmax,
    )

    if len(pts) > 0:
        ax.scatter(
            pts[:, 0],
            pts[:, 1],
            s=30,
            facecolors=(1.0, 0.42, 0.42, 0.50),
            edgecolors=(1.0, 1.0, 1.0, 0.80),
            linewidths=0.8,
        )

    ax.set_xlim(0, screen_w)
    ax.set_ylim(screen_h, 0)
    ax.set_xlabel("x (screen px)")
    ax.set_ylabel("y (screen px)")
    ax.set_title(title, fontsize=10)
    ax.grid(False)
    return hm


def export_ordered_accuracy_pdf(
    root_path: str,
    output_dir: str,
    order: str = "high_to_low",
    max_pages: Optional[int] = None,
    half_box: int = 310,
    screen_default: Tuple[int, int] = (1000, 800),
    sigma: float = 18.0,
    heatmap_alpha: float = 0.70,
    heatmap_cmap: str = "turbo",
    vmax_mode: str = "p99",
) -> Path:
    if order not in ("high_to_low", "low_to_high"):
        raise ValueError("order must be 'high_to_low' or 'low_to_high'")

    root = Path(root_path).expanduser().resolve()
    out_dir = Path(output_dir).expanduser().resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    json_path = _find_testing_json(root)
    data = json.loads(json_path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("❌ testing.json must be a list of objects.")

    all_trials = list(_iter_trials(data))
    agg = _aggregate_accuracy(all_trials)
    agg.to_csv(out_dir / "accuracy_by_pair_condition.csv", index=False)

    ranked = _ranked_lists_by_condition(agg, order=order)
    n_pages = min(len(ranked[c]) for c in CONDITIONS if not ranked[c].empty)
    if max_pages is not None:
        n_pages = min(n_pages, int(max_pages))

    out_pdf = out_dir / f"SECOND_IMAGE_OVERLAY_ORDERED_{order.upper()}.pdf"

    with PdfPages(out_pdf) as pdf:
        for i in range(n_pages):
            fig, axes = plt.subplots(1, 3, figsize=(18, 6), constrained_layout=True)
            last_hm = None

            for col, cond in enumerate(CONDITIONS):
                ax = axes[col]
                sub = ranked[cond]
                row = sub.iloc[i]
                pair_str = str(row["pair"])
                trials = _collect_trials(all_trials, cond, pair_str)

                if not trials:
                    ax.set_title(f"{cond.upper()} • missing trials", fontsize=10)
                    ax.axis("off")
                    continue

                img_path = _find_image_file(root, trials[0].second_image)
                if img_path is None:
                    ax.set_title(f"{cond.upper()} • missing second image", fontsize=10)
                    ax.axis("off")
                    continue

                sw, sh = screen_default
                for t in trials[:3]:
                    sw, sh = _screen_dims(t.raw, (sw, sh))

                l, ttop, r, b = _image_box_extent(sw, sh, half_box=half_box)
                box_size = int(half_box * 2)

                img_arr = _load_image_resized(img_path, size_px=box_size)
                pts = np.vstack([_extract_fixations(t.raw) for t in trials])
                heat = _fdm_in_box(pts, l=l, t=ttop, r=r, b=b, size_px=box_size, sigma=sigma)

                acc = float(row["acc"])
                right = int(row["right"])
                wrong = int(row["wrong"])
                N = int(row["N"])
                rank_in_order = int(row["rank_in_order"])
                n_pairs = int(row["n_pairs"])

                title = (
                    f"Testing • {cond.upper()} • Acc: {acc*100:.1f}% • Rank {rank_in_order}/{n_pairs}\n"
                    f"Second Image • right={right} wrong={wrong} • N={N} • Fix={len(pts)}"
                )

                last_hm = _plot_overlay(
                    ax=ax,
                    img_arr=img_arr,
                    pts=pts,
                    heat=heat,
                    screen_w=sw,
                    screen_h=sh,
                    extent_ltrb=(l, ttop, r, b),
                    title=title,
                    heatmap_cmap=heatmap_cmap,
                    heatmap_alpha=heatmap_alpha,
                    vmax_mode=vmax_mode,
                )

            fig.suptitle(
                f"Testing Phase • Ordered by Accuracy ({order.replace('_', '→')}) • Page {i+1}/{n_pages}",
                fontsize=14,
            )

            if last_hm is not None:
                cbar = fig.colorbar(last_hm, ax=axes.ravel().tolist(), shrink=0.9, pad=0.02)
                cbar.set_label("Fixation Density", rotation=90)

            pdf.savefig(fig)
            plt.close(fig)

    return out_pdf


def _running_in_notebook(argv: List[str]) -> bool:
    return "-f" in argv or "ipykernel" in argv[0].lower()


def main() -> None:
    parser = argparse.ArgumentParser(
        description="One PDF, pages ordered by accuracy. Each page: 3 panels (full/central/peripheral), second image with heatmap+fixations overlaid."
    )
    parser.add_argument("--root", default=".", help="Folder to search for testing.json and images.")
    parser.add_argument("--out", default="plots", help="Output directory.")
    parser.add_argument("--order", default="high_to_low", choices=["high_to_low", "low_to_high"])
    parser.add_argument("--max-pages", type=int, default=None, help="Optional limit on number of pages (default: all ranks).")
    parser.add_argument("--half-box", type=int, default=310)
    parser.add_argument("--screen-w", type=int, default=1000)
    parser.add_argument("--screen-h", type=int, default=800)
    parser.add_argument("--sigma", type=float, default=18.0)
    parser.add_argument("--heatmap-alpha", type=float, default=0.70)
    parser.add_argument("--heatmap-cmap", default="turbo")
    parser.add_argument("--vmax-mode", default="p99", choices=["p99", "max"])

    argv = sys.argv[1:]
    if _running_in_notebook(sys.argv):
        args = parser.parse_args([])  # ignore Jupyter argv
    else:
        args, _unknown = parser.parse_known_args(argv)

    out_pdf = export_ordered_accuracy_pdf(
        root_path=args.root,
        output_dir=args.out,
        order=str(args.order),
        max_pages=args.max_pages,
        half_box=int(args.half_box),
        screen_default=(int(args.screen_w), int(args.screen_h)),
        sigma=float(args.sigma),
        heatmap_alpha=float(args.heatmap_alpha),
        heatmap_cmap=str(args.heatmap_cmap),
        vmax_mode=str(args.vmax_mode),
    )
    print(f"✅ PDF: {out_pdf}")


if __name__ == "__main__":
    main()


✅ PDF: /Users/daisybuathatseephol/Documents/three_json_output/plots/SECOND_IMAGE_OVERLAY_ORDERED_HIGH_TO_LOW.pdf


In [5]:
# file: scripts/export_bottom3_titles_onepdf.py

from __future__ import annotations

import argparse
import json
import math
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages
from PIL import Image

try:
    from scipy.ndimage import gaussian_filter  # type: ignore
except Exception:  # pragma: no cover
    gaussian_filter = None


CONDITIONS: Tuple[str, ...] = ("full", "central", "peripheral")
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")


@dataclass(frozen=True)
class PairKey:
    a: str
    b: str

    @property
    def as_str(self) -> str:
        return f"{self.a}__{self.b}"


@dataclass
class TrialRow:
    condition: str
    pair: PairKey
    first_image: str
    second_image: str
    correct: Optional[bool]
    raw: Dict[str, Any]


def _normalize_image_name(name: str) -> str:
    s = str(name).strip().lower()
    s = re.sub(r"\.(jpg|jpeg|png|bmp|tif|tiff)$", "", s)
    return s


def _clean_condition(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s if s in CONDITIONS else None


def _as_bool(x: Any) -> Optional[bool]:
    if x is None:
        return None
    if isinstance(x, bool):
        return x
    if isinstance(x, (int, np.integer)) and x in (0, 1):
        return bool(x)
    if isinstance(x, str):
        v = x.strip().lower()
        if v in ("true", "t", "yes", "y", "correct", "right", "1"):
            return True
        if v in ("false", "f", "no", "n", "incorrect", "wrong", "0"):
            return False
    return None


def _normalize_answer(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s or None


def _detect_correct(item: Dict[str, Any]) -> Optional[bool]:
    for k in ("acc", "correct", "is_correct", "response_correct", "trial_correct", "accuracy"):
        if k in item:
            b = _as_bool(item.get(k))
            if b is not None:
                return b
    subj = _normalize_answer(item.get("subj_answer"))
    corr = _normalize_answer(item.get("correct_response"))
    if subj is not None and corr is not None:
        return subj == corr
    return None


def _find_testing_json(root: Path) -> Path:
    hits = list(root.rglob("testing.json"))
    if not hits:
        raise FileNotFoundError(f"❌ testing.json not found under: {root}")
    hits.sort(key=lambda p: (len(p.parts), str(p)))
    return hits[0]


def _find_image_file(root: Path, image_name: str) -> Optional[Path]:
    base = _normalize_image_name(image_name)
    candidates: List[Path] = []
    for p in root.rglob("*"):
        if not p.is_file():
            continue
        if p.suffix.lower() not in IMAGE_EXTS:
            continue
        if _normalize_image_name(p.name) == base or _normalize_image_name(p.stem) == base:
            candidates.append(p)
    if not candidates:
        return None
    candidates.sort(key=lambda p: (len(p.parts), str(p)))
    return candidates[0]


def _iter_trials(data: Sequence[Dict[str, Any]]) -> Iterable[TrialRow]:
    for item in data:
        cond = _clean_condition(item.get("viewing_condition") or item.get("condition"))
        if cond is None:
            continue

        first = item.get("first_image")
        second = item.get("second_image")
        if not first or not second:
            continue

        first_s = str(first).strip()
        second_s = str(second).strip()

        a = _normalize_image_name(first_s)
        b = _normalize_image_name(second_s)
        pair = PairKey(*sorted((a, b)))

        yield TrialRow(
            condition=cond,
            pair=pair,
            first_image=first_s,
            second_image=second_s,
            correct=_detect_correct(item),
            raw=item,
        )


def _extract_fixations(item: Dict[str, Any]) -> np.ndarray:
    xs = item.get("fix_x")
    ys = item.get("fix_y")
    if not isinstance(xs, (list, tuple)) or not isinstance(ys, (list, tuple)):
        return np.zeros((0, 2), dtype=float)
    if len(xs) != len(ys) or len(xs) == 0:
        return np.zeros((0, 2), dtype=float)
    try:
        arr = np.column_stack([np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)])
        arr = arr[np.isfinite(arr).all(axis=1)]
        return arr
    except Exception:
        return np.zeros((0, 2), dtype=float)


def _screen_dims(item: Dict[str, Any], default: Tuple[int, int]) -> Tuple[int, int]:
    for kw, kh in (
        ("screen_width", "screen_height"),
        ("window_width", "window_height"),
        ("display_width", "display_height"),
        ("screenW", "screenH"),
    ):
        if kw in item and kh in item:
            try:
                w = int(float(item[kw]))
                h = int(float(item[kh]))
                if w > 0 and h > 0:
                    return w, h
            except Exception:
                pass
    return default


def _image_box_extent(screen_w: int, screen_h: int, half_box: int) -> Tuple[int, int, int, int]:
    l = screen_w // 2 - half_box
    t = screen_h // 2 - half_box
    r = screen_w // 2 + half_box
    b = screen_h // 2 + half_box
    return l, t, r, b


def _load_image_resized(path: Path, size_px: int) -> np.ndarray:
    img = Image.open(path).convert("L")
    img = img.resize((size_px, size_px), resample=Image.BICUBIC)
    return np.asarray(img)


def _fdm_in_box(
    pts: np.ndarray,
    l: int,
    t: int,
    r: int,
    b: int,
    size_px: int,
    sigma: float,
) -> np.ndarray:
    if len(pts) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    in_box = (pts[:, 0] >= l) & (pts[:, 0] <= r) & (pts[:, 1] >= t) & (pts[:, 1] <= b)
    p = pts[in_box]
    if len(p) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    x = (p[:, 0] - l) / max(1e-9, (r - l)) * (size_px - 1)
    y = (p[:, 1] - t) / max(1e-9, (b - t)) * (size_px - 1)
    x = np.clip(x, 0, size_px - 1)
    y = np.clip(y, 0, size_px - 1)

    heat, _, _ = np.histogram2d(y, x, bins=[size_px, size_px], range=[[0, size_px], [0, size_px]])
    heat = heat.astype(float)

    if gaussian_filter is not None:
        heat = gaussian_filter(heat, sigma=sigma)
    else:
        k = int(max(3, math.ceil(sigma * 3)) * 2 + 1)
        ax = np.arange(k) - k // 2
        kernel = np.exp(-(ax**2) / (2 * sigma**2))
        kernel /= kernel.sum()
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=0, arr=heat)
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=1, arr=heat)

    s = heat.sum()
    if s > 0:
        heat /= s
    return heat


def _aggregate_accuracy(trials: Sequence[TrialRow]) -> pd.DataFrame:
    rows: List[Dict[str, Any]] = []
    for t in trials:
        if t.correct is None:
            continue
        rows.append({"condition": t.condition, "pair": t.pair.as_str, "correct": int(bool(t.correct))})

    df = pd.DataFrame(rows)
    if df.empty:
        raise ValueError("❌ No correctness detected (acc / subj_answer vs correct_response).")

    return (
        df.groupby(["condition", "pair"], as_index=False)
        .agg(N=("correct", "size"), right=("correct", "sum"))
        .assign(wrong=lambda x: x["N"] - x["right"])
        .assign(acc=lambda x: x["right"] / x["N"])
    )


def _rank_most_incorrect(agg: pd.DataFrame) -> pd.DataFrame:
    def ranker(group: pd.DataFrame) -> pd.DataFrame:
        g = group.sort_values(["acc", "pair"], ascending=[True, True]).reset_index(drop=True)
        g["pair_rank_incorrect"] = np.arange(1, len(g) + 1)
        g["n_pairs"] = len(g)
        return g

    out = agg.groupby("condition", group_keys=False).apply(ranker).reset_index(drop=True)
    out["condition"] = pd.Categorical(out["condition"], categories=list(CONDITIONS), ordered=True)
    return out


def _collect_trials(all_trials: Sequence[TrialRow], condition: str, pair_str: str) -> List[TrialRow]:
    return [t for t in all_trials if t.condition == condition and t.pair.as_str == pair_str]


def _plot_overlay(
    ax: plt.Axes,
    img_arr: np.ndarray,
    pts: np.ndarray,
    heat: np.ndarray,
    screen_w: int,
    screen_h: int,
    extent_ltrb: Tuple[int, int, int, int],  # (l,t,r,b)
    title1: str,
    title2: str,
    heatmap_cmap: str,
    heatmap_alpha: float,
    vmax_mode: str,
) -> Any:
    l, t, r, b = extent_ltrb

    ax.imshow(img_arr, cmap="gray", origin="upper", extent=(l, r, b, t))

    if vmax_mode == "p99":
        vmax = float(np.quantile(heat, 0.99)) if heat.max() > 0 else 1.0
    else:
        vmax = float(heat.max()) if heat.max() > 0 else 1.0

    hm = ax.imshow(
        heat,
        origin="upper",
        extent=(l, r, b, t),
        cmap=heatmap_cmap,
        alpha=heatmap_alpha,
        vmin=0.0,
        vmax=vmax,
    )

    if len(pts) > 0:
        ax.scatter(
            pts[:, 0],
            pts[:, 1],
            s=30,
            facecolors=(1.0, 0.42, 0.42, 0.50),
            edgecolors=(1.0, 1.0, 1.0, 0.80),
            linewidths=0.8,
        )

    ax.set_xlim(0, screen_w)
    ax.set_ylim(screen_h, 0)
    ax.set_xlabel("x (screen px)")
    ax.set_ylabel("y (screen px)")
    ax.set_title(f"{title1}\n{title2}", fontsize=10)
    ax.grid(False)
    return hm


def export_bottom3_each_condition_one_pdf(
    root_path: str,
    output_dir: str,
    half_box: int = 310,
    screen_default: Tuple[int, int] = (1000, 800),
    sigma: float = 18.0,
    heatmap_alpha: float = 0.70,
    heatmap_cmap: str = "turbo",
    vmax_mode: str = "p99",
) -> Path:
    root = Path(root_path).expanduser().resolve()
    out_dir = Path(output_dir).expanduser().resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    json_path = _find_testing_json(root)
    data = json.loads(json_path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("❌ testing.json must be a list of objects.")

    all_trials = list(_iter_trials(data))
    agg = _aggregate_accuracy(all_trials)
    ranked = _rank_most_incorrect(agg)
    ranked.to_csv(out_dir / "accuracy_by_pair_condition.csv", index=False)

    out_pdf = out_dir / "BOTTOM3_EACH_VC_SECOND_IMAGE_OVERLAY.pdf"

    with PdfPages(out_pdf) as pdf:
        for cond in CONDITIONS:
            sub = ranked[ranked["condition"] == cond].sort_values(
                ["acc", "pair"], ascending=[True, True]
            ).head(3)

            fig, axes = plt.subplots(1, 3, figsize=(18, 6), constrained_layout=True)
            last_hm = None

            for j, (_, row) in enumerate(sub.iterrows()):
                ax = axes[j]

                pair_str = str(row["pair"])
                trials = _collect_trials(all_trials, cond, pair_str)
                if not trials:
                    ax.axis("off")
                    continue

                img_path = _find_image_file(root, trials[0].second_image)
                if img_path is None:
                    ax.set_title(f"{cond.upper()} • missing second image", fontsize=10)
                    ax.axis("off")
                    continue

                sw, sh = screen_default
                for t in trials[:3]:
                    sw, sh = _screen_dims(t.raw, (sw, sh))

                l, ttop, r, b = _image_box_extent(sw, sh, half_box=half_box)
                box_size = int(half_box * 2)

                img_arr = _load_image_resized(img_path, size_px=box_size)

                pts = np.vstack([_extract_fixations(t.raw) for t in trials])
                heat = _fdm_in_box(pts, l=l, t=ttop, r=r, b=b, size_px=box_size, sigma=sigma)

                acc = float(row["acc"])
                right = int(row["right"])
                wrong = int(row["wrong"])
                N = int(row["N"])
                pair_rank = int(row["pair_rank_incorrect"])
                n_pairs = int(row["n_pairs"])

                correctness_label = "Correct" if acc >= 0.5 else "Incorrect"

                title1 = f"Testing • Acc: {acc*100:.1f}% • Pair {pair_rank}/{n_pairs} • {correctness_label}"
                title2 = f"Subjects: right={right} • wrong={wrong} • N={N} subjects • Second: {len(pts)} fix"

                last_hm = _plot_overlay(
                    ax=ax,
                    img_arr=img_arr,
                    pts=pts,
                    heat=heat,
                    screen_w=sw,
                    screen_h=sh,
                    extent_ltrb=(l, ttop, r, b),
                    title1=title1,
                    title2=title2,
                    heatmap_cmap=heatmap_cmap,
                    heatmap_alpha=heatmap_alpha,
                    vmax_mode=vmax_mode,
                )

            # If fewer than 3, hide remaining axes
            for k in range(len(sub), 3):
                axes[k].axis("off")

            fig.suptitle(
                f"Testing Phase - {cond.upper()} Viewing Condition\n"
                f"Bottom 3 Trials (Low Accuracy)\n"
                f"Ordered from Most Incorrect \u2192 Least Incorrect",
                fontsize=14,
                fontweight="bold",
            )

            if last_hm is not None:
                cbar = fig.colorbar(last_hm, ax=axes.ravel().tolist(), shrink=0.9, pad=0.02)
                cbar.set_label("Fixation Density", rotation=90)

            pdf.savefig(fig)
            plt.close(fig)

    return out_pdf


def _running_in_notebook(argv: List[str]) -> bool:
    return "-f" in argv or "ipykernel" in argv[0].lower()


def main() -> None:
    parser = argparse.ArgumentParser(
        description="One PDF, 3 pages (full/central/peripheral). Each page shows bottom 3 pairs (most incorrect->least incorrect) with titles like your example."
    )
    parser.add_argument("--root", default=".", help="Folder to search for testing.json and images.")
    parser.add_argument("--out", default="plots", help="Output directory.")
    parser.add_argument("--half-box", type=int, default=310)
    parser.add_argument("--screen-w", type=int, default=1000)
    parser.add_argument("--screen-h", type=int, default=800)
    parser.add_argument("--sigma", type=float, default=18.0)
    parser.add_argument("--heatmap-alpha", type=float, default=0.70)
    parser.add_argument("--heatmap-cmap", default="turbo")
    parser.add_argument("--vmax-mode", default="p99", choices=["p99", "max"])

    argv = sys.argv[1:]
    if _running_in_notebook(sys.argv):
        args = parser.parse_args([])  # ignore Jupyter argv
    else:
        args, _unknown = parser.parse_known_args(argv)

    out_pdf = export_bottom3_each_condition_one_pdf(
        root_path=args.root,
        output_dir=args.out,
        half_box=int(args.half_box),
        screen_default=(int(args.screen_w), int(args.screen_h)),
        sigma=float(args.sigma),
        heatmap_alpha=float(args.heatmap_alpha),
        heatmap_cmap=str(args.heatmap_cmap),
        vmax_mode=str(args.vmax_mode),
    )
    print(f"✅ PDF: {out_pdf}")


if __name__ == "__main__":
    main()


  out = agg.groupby("condition", group_keys=False).apply(ranker).reset_index(drop=True)


✅ PDF: /Users/daisybuathatseephol/Documents/three_json_output/plots/BOTTOM3_EACH_VC_SECOND_IMAGE_OVERLAY.pdf


In [6]:
# file: scripts/export_top3_titles_onepdf.py

from __future__ import annotations

import argparse
import json
import math
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages
from PIL import Image

try:
    from scipy.ndimage import gaussian_filter  # type: ignore
except Exception:  # pragma: no cover
    gaussian_filter = None


CONDITIONS: Tuple[str, ...] = ("full", "central", "peripheral")
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")


@dataclass(frozen=True)
class PairKey:
    a: str
    b: str

    @property
    def as_str(self) -> str:
        return f"{self.a}__{self.b}"


@dataclass
class TrialRow:
    condition: str
    pair: PairKey
    first_image: str
    second_image: str
    correct: Optional[bool]
    raw: Dict[str, Any]


def _normalize_image_name(name: str) -> str:
    s = str(name).strip().lower()
    s = re.sub(r"\.(jpg|jpeg|png|bmp|tif|tiff)$", "", s)
    return s


def _clean_condition(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s if s in CONDITIONS else None


def _as_bool(x: Any) -> Optional[bool]:
    if x is None:
        return None
    if isinstance(x, bool):
        return x
    if isinstance(x, (int, np.integer)) and x in (0, 1):
        return bool(x)
    if isinstance(x, str):
        v = x.strip().lower()
        if v in ("true", "t", "yes", "y", "correct", "right", "1"):
            return True
        if v in ("false", "f", "no", "n", "incorrect", "wrong", "0"):
            return False
    return None


def _normalize_answer(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s or None


def _detect_correct(item: Dict[str, Any]) -> Optional[bool]:
    for k in ("acc", "correct", "is_correct", "response_correct", "trial_correct", "accuracy"):
        if k in item:
            b = _as_bool(item.get(k))
            if b is not None:
                return b
    subj = _normalize_answer(item.get("subj_answer"))
    corr = _normalize_answer(item.get("correct_response"))
    if subj is not None and corr is not None:
        return subj == corr
    return None


def _find_testing_json(root: Path) -> Path:
    hits = list(root.rglob("testing.json"))
    if not hits:
        raise FileNotFoundError(f"❌ testing.json not found under: {root}")
    hits.sort(key=lambda p: (len(p.parts), str(p)))
    return hits[0]


def _find_image_file(root: Path, image_name: str) -> Optional[Path]:
    base = _normalize_image_name(image_name)
    candidates: List[Path] = []
    for p in root.rglob("*"):
        if not p.is_file():
            continue
        if p.suffix.lower() not in IMAGE_EXTS:
            continue
        if _normalize_image_name(p.name) == base or _normalize_image_name(p.stem) == base:
            candidates.append(p)
    if not candidates:
        return None
    candidates.sort(key=lambda p: (len(p.parts), str(p)))
    return candidates[0]


def _iter_trials(data: Sequence[Dict[str, Any]]) -> Iterable[TrialRow]:
    for item in data:
        cond = _clean_condition(item.get("viewing_condition") or item.get("condition"))
        if cond is None:
            continue

        first = item.get("first_image")
        second = item.get("second_image")
        if not first or not second:
            continue

        first_s = str(first).strip()
        second_s = str(second).strip()

        a = _normalize_image_name(first_s)
        b = _normalize_image_name(second_s)
        pair = PairKey(*sorted((a, b)))

        yield TrialRow(
            condition=cond,
            pair=pair,
            first_image=first_s,
            second_image=second_s,
            correct=_detect_correct(item),
            raw=item,
        )


def _extract_fixations(item: Dict[str, Any]) -> np.ndarray:
    xs = item.get("fix_x")
    ys = item.get("fix_y")
    if not isinstance(xs, (list, tuple)) or not isinstance(ys, (list, tuple)):
        return np.zeros((0, 2), dtype=float)
    if len(xs) != len(ys) or len(xs) == 0:
        return np.zeros((0, 2), dtype=float)
    try:
        arr = np.column_stack([np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)])
        arr = arr[np.isfinite(arr).all(axis=1)]
        return arr
    except Exception:
        return np.zeros((0, 2), dtype=float)


def _screen_dims(item: Dict[str, Any], default: Tuple[int, int]) -> Tuple[int, int]:
    for kw, kh in (
        ("screen_width", "screen_height"),
        ("window_width", "window_height"),
        ("display_width", "display_height"),
        ("screenW", "screenH"),
    ):
        if kw in item and kh in item:
            try:
                w = int(float(item[kw]))
                h = int(float(item[kh]))
                if w > 0 and h > 0:
                    return w, h
            except Exception:
                pass
    return default


def _image_box_extent(screen_w: int, screen_h: int, half_box: int) -> Tuple[int, int, int, int]:
    l = screen_w // 2 - half_box
    t = screen_h // 2 - half_box
    r = screen_w // 2 + half_box
    b = screen_h // 2 + half_box
    return l, t, r, b


def _load_image_resized(path: Path, size_px: int) -> np.ndarray:
    img = Image.open(path).convert("L")
    img = img.resize((size_px, size_px), resample=Image.BICUBIC)
    return np.asarray(img)


def _fdm_in_box(
    pts: np.ndarray,
    l: int,
    t: int,
    r: int,
    b: int,
    size_px: int,
    sigma: float,
) -> np.ndarray:
    if len(pts) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    in_box = (pts[:, 0] >= l) & (pts[:, 0] <= r) & (pts[:, 1] >= t) & (pts[:, 1] <= b)
    p = pts[in_box]
    if len(p) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    x = (p[:, 0] - l) / max(1e-9, (r - l)) * (size_px - 1)
    y = (p[:, 1] - t) / max(1e-9, (b - t)) * (size_px - 1)
    x = np.clip(x, 0, size_px - 1)
    y = np.clip(y, 0, size_px - 1)

    heat, _, _ = np.histogram2d(y, x, bins=[size_px, size_px], range=[[0, size_px], [0, size_px]])
    heat = heat.astype(float)

    if gaussian_filter is not None:
        heat = gaussian_filter(heat, sigma=sigma)
    else:
        k = int(max(3, math.ceil(sigma * 3)) * 2 + 1)
        ax = np.arange(k) - k // 2
        kernel = np.exp(-(ax**2) / (2 * sigma**2))
        kernel /= kernel.sum()
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=0, arr=heat)
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=1, arr=heat)

    s = heat.sum()
    if s > 0:
        heat /= s
    return heat


def _aggregate_accuracy(trials: Sequence[TrialRow]) -> pd.DataFrame:
    rows: List[Dict[str, Any]] = []
    for t in trials:
        if t.correct is None:
            continue
        rows.append({"condition": t.condition, "pair": t.pair.as_str, "correct": int(bool(t.correct))})

    df = pd.DataFrame(rows)
    if df.empty:
        raise ValueError("❌ No correctness detected (acc / subj_answer vs correct_response).")

    return (
        df.groupby(["condition", "pair"], as_index=False)
        .agg(N=("correct", "size"), right=("correct", "sum"))
        .assign(wrong=lambda x: x["N"] - x["right"])
        .assign(acc=lambda x: x["right"] / x["N"])
    )


def _rank_most_correct(agg: pd.DataFrame) -> pd.DataFrame:
    def ranker(group: pd.DataFrame) -> pd.DataFrame:
        g = group.sort_values(["acc", "pair"], ascending=[False, True]).reset_index(drop=True)
        g["pair_rank_correct"] = np.arange(1, len(g) + 1)
        g["n_pairs"] = len(g)
        return g

    out = agg.groupby("condition", group_keys=False).apply(ranker).reset_index(drop=True)
    out["condition"] = pd.Categorical(out["condition"], categories=list(CONDITIONS), ordered=True)
    return out


def _collect_trials(all_trials: Sequence[TrialRow], condition: str, pair_str: str) -> List[TrialRow]:
    return [t for t in all_trials if t.condition == condition and t.pair.as_str == pair_str]


def _plot_overlay(
    ax: plt.Axes,
    img_arr: np.ndarray,
    pts: np.ndarray,
    heat: np.ndarray,
    screen_w: int,
    screen_h: int,
    extent_ltrb: Tuple[int, int, int, int],  # (l,t,r,b)
    title1: str,
    title2: str,
    heatmap_cmap: str,
    heatmap_alpha: float,
    vmax_mode: str,
) -> Any:
    l, t, r, b = extent_ltrb
    ax.imshow(img_arr, cmap="gray", origin="upper", extent=(l, r, b, t))

    if vmax_mode == "p99":
        vmax = float(np.quantile(heat, 0.99)) if heat.max() > 0 else 1.0
    else:
        vmax = float(heat.max()) if heat.max() > 0 else 1.0

    hm = ax.imshow(
        heat,
        origin="upper",
        extent=(l, r, b, t),
        cmap=heatmap_cmap,
        alpha=heatmap_alpha,
        vmin=0.0,
        vmax=vmax,
    )

    if len(pts) > 0:
        ax.scatter(
            pts[:, 0],
            pts[:, 1],
            s=30,
            facecolors=(1.0, 0.42, 0.42, 0.50),
            edgecolors=(1.0, 1.0, 1.0, 0.80),
            linewidths=0.8,
        )

    ax.set_xlim(0, screen_w)
    ax.set_ylim(screen_h, 0)
    ax.set_xlabel("x (screen px)")
    ax.set_ylabel("y (screen px)")
    ax.set_title(f"{title1}\n{title2}", fontsize=10)
    ax.grid(False)
    return hm


def export_top3_each_condition_one_pdf(
    root_path: str,
    output_dir: str,
    half_box: int = 310,
    screen_default: Tuple[int, int] = (1000, 800),
    sigma: float = 18.0,
    heatmap_alpha: float = 0.70,
    heatmap_cmap: str = "turbo",
    vmax_mode: str = "p99",
) -> Path:
    root = Path(root_path).expanduser().resolve()
    out_dir = Path(output_dir).expanduser().resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    json_path = _find_testing_json(root)
    data = json.loads(json_path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("❌ testing.json must be a list of objects.")

    all_trials = list(_iter_trials(data))
    agg = _aggregate_accuracy(all_trials)
    ranked = _rank_most_correct(agg)
    ranked.to_csv(out_dir / "accuracy_by_pair_condition.csv", index=False)

    out_pdf = out_dir / "TOP3_EACH_VC_SECOND_IMAGE_OVERLAY.pdf"

    with PdfPages(out_pdf) as pdf:
        for cond in CONDITIONS:
            sub = ranked[ranked["condition"] == cond].sort_values(
                ["acc", "pair"], ascending=[False, True]
            ).head(3)

            fig, axes = plt.subplots(1, 3, figsize=(18, 6), constrained_layout=True)
            last_hm = None

            for j, (_, row) in enumerate(sub.iterrows()):
                ax = axes[j]

                pair_str = str(row["pair"])
                trials = _collect_trials(all_trials, cond, pair_str)
                if not trials:
                    ax.axis("off")
                    continue

                img_path = _find_image_file(root, trials[0].second_image)
                if img_path is None:
                    ax.set_title(f"{cond.upper()} • missing second image", fontsize=10)
                    ax.axis("off")
                    continue

                sw, sh = screen_default
                for t in trials[:3]:
                    sw, sh = _screen_dims(t.raw, (sw, sh))

                l, ttop, r, b = _image_box_extent(sw, sh, half_box=half_box)
                box_size = int(half_box * 2)

                img_arr = _load_image_resized(img_path, size_px=box_size)

                pts = np.vstack([_extract_fixations(t.raw) for t in trials])
                heat = _fdm_in_box(pts, l=l, t=ttop, r=r, b=b, size_px=box_size, sigma=sigma)

                acc = float(row["acc"])
                right = int(row["right"])
                wrong = int(row["wrong"])
                N = int(row["N"])
                pair_rank = int(row["pair_rank_correct"])
                n_pairs = int(row["n_pairs"])

                correctness_label = "Correct" if acc >= 0.5 else "Incorrect"

                title1 = f"Testing • Acc: {acc*100:.1f}% • Pair {pair_rank}/{n_pairs} • {correctness_label}"
                title2 = f"Subjects: right={right} • wrong={wrong} • N={N} subjects • Second: {len(pts)} fix"

                last_hm = _plot_overlay(
                    ax=ax,
                    img_arr=img_arr,
                    pts=pts,
                    heat=heat,
                    screen_w=sw,
                    screen_h=sh,
                    extent_ltrb=(l, ttop, r, b),
                    title1=title1,
                    title2=title2,
                    heatmap_cmap=heatmap_cmap,
                    heatmap_alpha=heatmap_alpha,
                    vmax_mode=vmax_mode,
                )

            for k in range(len(sub), 3):
                axes[k].axis("off")

            fig.suptitle(
                f"Testing Phase - {cond.upper()} Viewing Condition\n"
                f"Top 3 Trials (High Accuracy)\n"
                f"Ordered from Most Correct \u2192 Least Correct",
                fontsize=14,
                fontweight="bold",
            )

            if last_hm is not None:
                cbar = fig.colorbar(last_hm, ax=axes.ravel().tolist(), shrink=0.9, pad=0.02)
                cbar.set_label("Fixation Density", rotation=90)

            pdf.savefig(fig)
            plt.close(fig)

    return out_pdf


def _running_in_notebook(argv: List[str]) -> bool:
    return "-f" in argv or "ipykernel" in argv[0].lower()


def main() -> None:
    parser = argparse.ArgumentParser(
        description="One PDF, 3 pages (full/central/peripheral). Each page shows top 3 pairs (most correct->least correct) with titles like your example."
    )
    parser.add_argument("--root", default=".", help="Folder to search for testing.json and images.")
    parser.add_argument("--out", default="plots", help="Output directory.")
    parser.add_argument("--half-box", type=int, default=310)
    parser.add_argument("--screen-w", type=int, default=1000)
    parser.add_argument("--screen-h", type=int, default=800)
    parser.add_argument("--sigma", type=float, default=18.0)
    parser.add_argument("--heatmap-alpha", type=float, default=0.70)
    parser.add_argument("--heatmap-cmap", default="turbo")
    parser.add_argument("--vmax-mode", default="p99", choices=["p99", "max"])

    argv = sys.argv[1:]
    if _running_in_notebook(sys.argv):
        args = parser.parse_args([])  # ignore Jupyter argv
    else:
        args, _unknown = parser.parse_known_args(argv)

    out_pdf = export_top3_each_condition_one_pdf(
        root_path=args.root,
        output_dir=args.out,
        half_box=int(args.half_box),
        screen_default=(int(args.screen_w), int(args.screen_h)),
        sigma=float(args.sigma),
        heatmap_alpha=float(args.heatmap_alpha),
        heatmap_cmap=str(args.heatmap_cmap),
        vmax_mode=str(args.vmax_mode),
    )
    print(f"✅ PDF: {out_pdf}")


if __name__ == "__main__":
    main()


  out = agg.groupby("condition", group_keys=False).apply(ranker).reset_index(drop=True)


✅ PDF: /Users/daisybuathatseephol/Documents/three_json_output/plots/TOP3_EACH_VC_SECOND_IMAGE_OVERLAY.pdf


In [7]:
# file: scripts/export_top3_bottom3_grid_onepdf.py

from __future__ import annotations

import argparse
import json
import math
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages
from PIL import Image

try:
    from scipy.ndimage import gaussian_filter  # type: ignore
except Exception:  # pragma: no cover
    gaussian_filter = None


CONDITIONS: Tuple[str, ...] = ("full", "central", "peripheral")
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")


@dataclass(frozen=True)
class PairKey:
    a: str
    b: str

    @property
    def as_str(self) -> str:
        return f"{self.a}__{self.b}"


@dataclass
class TrialRow:
    condition: str
    pair: PairKey
    first_image: str
    second_image: str
    correct: Optional[bool]
    raw: Dict[str, Any]


def _normalize_image_name(name: str) -> str:
    s = str(name).strip().lower()
    s = re.sub(r"\.(jpg|jpeg|png|bmp|tif|tiff)$", "", s)
    return s


def _clean_condition(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s if s in CONDITIONS else None


def _as_bool(x: Any) -> Optional[bool]:
    if x is None:
        return None
    if isinstance(x, bool):
        return x
    if isinstance(x, (int, np.integer)) and x in (0, 1):
        return bool(x)
    if isinstance(x, str):
        v = x.strip().lower()
        if v in ("true", "t", "yes", "y", "correct", "right", "1"):
            return True
        if v in ("false", "f", "no", "n", "incorrect", "wrong", "0"):
            return False
    return None


def _normalize_answer(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s or None


def _detect_correct(item: Dict[str, Any]) -> Optional[bool]:
    for k in ("acc", "correct", "is_correct", "response_correct", "trial_correct", "accuracy"):
        if k in item:
            b = _as_bool(item.get(k))
            if b is not None:
                return b
    subj = _normalize_answer(item.get("subj_answer"))
    corr = _normalize_answer(item.get("correct_response"))
    if subj is not None and corr is not None:
        return subj == corr
    return None


def _find_testing_json(root: Path) -> Path:
    hits = list(root.rglob("testing.json"))
    if not hits:
        raise FileNotFoundError(f"❌ testing.json not found under: {root}")
    hits.sort(key=lambda p: (len(p.parts), str(p)))
    return hits[0]


def _find_image_file(root: Path, image_name: str) -> Optional[Path]:
    base = _normalize_image_name(image_name)
    candidates: List[Path] = []
    for p in root.rglob("*"):
        if not p.is_file():
            continue
        if p.suffix.lower() not in IMAGE_EXTS:
            continue
        if _normalize_image_name(p.name) == base or _normalize_image_name(p.stem) == base:
            candidates.append(p)
    if not candidates:
        return None
    candidates.sort(key=lambda p: (len(p.parts), str(p)))
    return candidates[0]


def _iter_trials(data: Sequence[Dict[str, Any]]) -> Iterable[TrialRow]:
    for item in data:
        cond = _clean_condition(item.get("viewing_condition") or item.get("condition"))
        if cond is None:
            continue

        first = item.get("first_image")
        second = item.get("second_image")
        if not first or not second:
            continue

        first_s = str(first).strip()
        second_s = str(second).strip()

        a = _normalize_image_name(first_s)
        b = _normalize_image_name(second_s)
        pair = PairKey(*sorted((a, b)))

        yield TrialRow(
            condition=cond,
            pair=pair,
            first_image=first_s,
            second_image=second_s,
            correct=_detect_correct(item),
            raw=item,
        )


def _extract_fixations(item: Dict[str, Any]) -> np.ndarray:
    xs = item.get("fix_x")
    ys = item.get("fix_y")
    if not isinstance(xs, (list, tuple)) or not isinstance(ys, (list, tuple)):
        return np.zeros((0, 2), dtype=float)
    if len(xs) != len(ys) or len(xs) == 0:
        return np.zeros((0, 2), dtype=float)
    try:
        arr = np.column_stack([np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)])
        arr = arr[np.isfinite(arr).all(axis=1)]
        return arr
    except Exception:
        return np.zeros((0, 2), dtype=float)


def _screen_dims(item: Dict[str, Any], default: Tuple[int, int]) -> Tuple[int, int]:
    for kw, kh in (
        ("screen_width", "screen_height"),
        ("window_width", "window_height"),
        ("display_width", "display_height"),
        ("screenW", "screenH"),
    ):
        if kw in item and kh in item:
            try:
                w = int(float(item[kw]))
                h = int(float(item[kh]))
                if w > 0 and h > 0:
                    return w, h
            except Exception:
                pass
    return default


def _image_box_extent(screen_w: int, screen_h: int, half_box: int) -> Tuple[int, int, int, int]:
    l = screen_w // 2 - half_box
    t = screen_h // 2 - half_box
    r = screen_w // 2 + half_box
    b = screen_h // 2 + half_box
    return l, t, r, b


def _load_image_resized(path: Path, size_px: int) -> np.ndarray:
    img = Image.open(path).convert("L")
    img = img.resize((size_px, size_px), resample=Image.BICUBIC)
    return np.asarray(img)


def _fdm_in_box(
    pts: np.ndarray,
    l: int,
    t: int,
    r: int,
    b: int,
    size_px: int,
    sigma: float,
) -> np.ndarray:
    if len(pts) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    in_box = (pts[:, 0] >= l) & (pts[:, 0] <= r) & (pts[:, 1] >= t) & (pts[:, 1] <= b)
    p = pts[in_box]
    if len(p) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    x = (p[:, 0] - l) / max(1e-9, (r - l)) * (size_px - 1)
    y = (p[:, 1] - t) / max(1e-9, (b - t)) * (size_px - 1)
    x = np.clip(x, 0, size_px - 1)
    y = np.clip(y, 0, size_px - 1)

    heat, _, _ = np.histogram2d(y, x, bins=[size_px, size_px], range=[[0, size_px], [0, size_px]])
    heat = heat.astype(float)

    if gaussian_filter is not None:
        heat = gaussian_filter(heat, sigma=sigma)
    else:
        k = int(max(3, math.ceil(sigma * 3)) * 2 + 1)
        ax = np.arange(k) - k // 2
        kernel = np.exp(-(ax**2) / (2 * sigma**2))
        kernel /= kernel.sum()
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=0, arr=heat)
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=1, arr=heat)

    s = heat.sum()
    if s > 0:
        heat /= s
    return heat


def _aggregate_accuracy(trials: Sequence[TrialRow]) -> pd.DataFrame:
    rows: List[Dict[str, Any]] = []
    for t in trials:
        if t.correct is None:
            continue
        rows.append({"condition": t.condition, "pair": t.pair.as_str, "correct": int(bool(t.correct))})

    df = pd.DataFrame(rows)
    if df.empty:
        raise ValueError("❌ No correctness detected (acc / subj_answer vs correct_response).")

    return (
        df.groupby(["condition", "pair"], as_index=False)
        .agg(N=("correct", "size"), right=("correct", "sum"))
        .assign(wrong=lambda x: x["N"] - x["right"])
        .assign(acc=lambda x: x["right"] / x["N"])
    )


def _plot_overlay(
    ax: plt.Axes,
    img_arr: np.ndarray,
    pts: np.ndarray,
    heat: np.ndarray,
    screen_w: int,
    screen_h: int,
    extent_ltrb: Tuple[int, int, int, int],  # (l,t,r,b)
    title1: str,
    title2: str,
    heatmap_cmap: str,
    heatmap_alpha: float,
    vmax_mode: str,
) -> Any:
    l, t, r, b = extent_ltrb
    ax.imshow(img_arr, cmap="gray", origin="upper", extent=(l, r, b, t))

    if vmax_mode == "p99":
        vmax = float(np.quantile(heat, 0.99)) if heat.max() > 0 else 1.0
    else:
        vmax = float(heat.max()) if heat.max() > 0 else 1.0

    hm = ax.imshow(
        heat,
        origin="upper",
        extent=(l, r, b, t),
        cmap=heatmap_cmap,
        alpha=heatmap_alpha,
        vmin=0.0,
        vmax=vmax,
    )

    if len(pts) > 0:
        ax.scatter(
            pts[:, 0],
            pts[:, 1],
            s=30,
            facecolors=(1.0, 0.42, 0.42, 0.50),
            edgecolors=(1.0, 1.0, 1.0, 0.80),
            linewidths=0.8,
        )

    ax.set_xlim(0, screen_w)
    ax.set_ylim(screen_h, 0)
    ax.set_xlabel("x (screen px)")
    ax.set_ylabel("y (screen px)")
    ax.set_title(f"{title1}\n{title2}", fontsize=10)
    ax.grid(False)
    return hm


def export_top_bottom_3_one_pdf(
    root_path: str,
    output_dir: str,
    half_box: int = 310,
    screen_default: Tuple[int, int] = (1000, 800),
    sigma: float = 18.0,
    heatmap_alpha: float = 0.70,
    heatmap_cmap: str = "turbo",
    vmax_mode: str = "p99",
) -> Path:
    root = Path(root_path).expanduser().resolve()
    out_dir = Path(output_dir).expanduser().resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    json_path = _find_testing_json(root)
    data = json.loads(json_path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("❌ testing.json must be a list of objects.")

    all_trials = list(_iter_trials(data))
    agg = _aggregate_accuracy(all_trials)
    agg.to_csv(out_dir / "accuracy_by_pair_condition.csv", index=False)

    out_pdf = out_dir / "TOP3_BOTTOM3_EACH_VC_SECOND_IMAGE_OVERLAY.pdf"

    with PdfPages(out_pdf) as pdf:
        for cond in CONDITIONS:
            cond_agg = agg[agg["condition"] == cond].copy()
            cond_agg = cond_agg.sort_values(["acc", "pair"], ascending=[True, True]).reset_index(drop=True)
            n_pairs = len(cond_agg)

            bottom3 = cond_agg.head(3).copy()
            top3 = cond_agg.tail(3).sort_values(["acc", "pair"], ascending=[False, True]).copy()

            fig, axes = plt.subplots(2, 3, figsize=(18, 10), constrained_layout=True)
            last_hm = None

            def render_row(row_df: pd.DataFrame, row_idx: int, label: str, rank_base: str) -> None:
                nonlocal last_hm
                for j in range(3):
                    ax = axes[row_idx, j]
                    if j >= len(row_df):
                        ax.axis("off")
                        continue

                    row = row_df.iloc[j]
                    pair_str = str(row["pair"])
                    trials = [t for t in all_trials if t.condition == cond and t.pair.as_str == pair_str]
                    if not trials:
                        ax.axis("off")
                        continue

                    img_path = _find_image_file(root, trials[0].second_image)
                    if img_path is None:
                        ax.set_title(f"{cond.upper()} • missing second image", fontsize=10)
                        ax.axis("off")
                        continue

                    sw, sh = screen_default
                    for t in trials[:3]:
                        sw, sh = _screen_dims(t.raw, (sw, sh))

                    l, ttop, r, b = _image_box_extent(sw, sh, half_box=half_box)
                    box_size = int(half_box * 2)
                    img_arr = _load_image_resized(img_path, size_px=box_size)

                    pts = np.vstack([_extract_fixations(t.raw) for t in trials])
                    heat = _fdm_in_box(pts, l=l, t=ttop, r=r, b=b, size_px=box_size, sigma=sigma)

                    acc = float(row["acc"])
                    right = int(row["right"])
                    wrong = int(row["wrong"])
                    N = int(row["N"])

                    # Pair rank: mimic your existing PDFs:
                    # - top row: 1/24, 2/24, 3/24 ...
                    # - bottom row: 1/24, 2/24, 3/24 ... (most incorrect -> least incorrect)
                    # If you want global rank within all pairs, swap this logic.
                    if rank_base == "top":
                        pair_rank = j + 1
                    else:
                        pair_rank = j + 1

                    correctness_label = "Correct" if acc >= 0.5 else "Incorrect"
                    title1 = f"Testing • Acc: {acc*100:.1f}% • Pair {pair_rank}/{n_pairs} • {correctness_label}"
                    title2 = f"Subjects: right={right} • wrong={wrong} • N={N} subjects • Second: {len(pts)} fix"

                    last_hm = _plot_overlay(
                        ax=ax,
                        img_arr=img_arr,
                        pts=pts,
                        heat=heat,
                        screen_w=sw,
                        screen_h=sh,
                        extent_ltrb=(l, ttop, r, b),
                        title1=title1,
                        title2=title2,
                        heatmap_cmap=heatmap_cmap,
                        heatmap_alpha=heatmap_alpha,
                        vmax_mode=vmax_mode,
                    )
                    ax.text(
                        0.01,
                        0.99,
                        label,
                        transform=ax.transAxes,
                        va="top",
                        ha="left",
                        fontsize=10,
                        fontweight="bold",
                        bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7),
                    )

            # Top row: Top 3 (most correct -> least correct)
            render_row(top3.reset_index(drop=True), row_idx=0, label="TOP", rank_base="top")
            # Bottom row: Bottom 3 (most incorrect -> least incorrect)
            render_row(bottom3.reset_index(drop=True), row_idx=1, label="BOTTOM", rank_base="bottom")

            fig.suptitle(
                f"Testing Phase - {cond.upper()} Viewing Condition\n"
                f"Top 3 (High Accuracy) + Bottom 3 (Low Accuracy)\n"
                f"Top: Most Correct \u2192 Least Correct   |   Bottom: Most Incorrect \u2192 Least Incorrect",
                fontsize=14,
                fontweight="bold",
            )

            if last_hm is not None:
                cbar = fig.colorbar(last_hm, ax=axes.ravel().tolist(), shrink=0.85, pad=0.02)
                cbar.set_label("Fixation Density", rotation=90)

            pdf.savefig(fig)
            plt.close(fig)

    return out_pdf


def _running_in_notebook(argv: List[str]) -> bool:
    return "-f" in argv or "ipykernel" in argv[0].lower()


def main() -> None:
    parser = argparse.ArgumentParser(
        description="One PDF (3 pages): per viewing condition, one figure with 6 panels (Top 3 + Bottom 3)."
    )
    parser.add_argument("--root", default=".", help="Folder to search for testing.json and images.")
    parser.add_argument("--out", default="plots", help="Output directory.")
    parser.add_argument("--half-box", type=int, default=310)
    parser.add_argument("--screen-w", type=int, default=1000)
    parser.add_argument("--screen-h", type=int, default=800)
    parser.add_argument("--sigma", type=float, default=18.0)
    parser.add_argument("--heatmap-alpha", type=float, default=0.70)
    parser.add_argument("--heatmap-cmap", default="turbo")
    parser.add_argument("--vmax-mode", default="p99", choices=["p99", "max"])

    argv = sys.argv[1:]
    if _running_in_notebook(sys.argv):
        args = parser.parse_args([])  # ignore Jupyter argv
    else:
        args, _unknown = parser.parse_known_args(argv)

    out_pdf = export_top_bottom_3_one_pdf(
        root_path=args.root,
        output_dir=args.out,
        half_box=int(args.half_box),
        screen_default=(int(args.screen_w), int(args.screen_h)),
        sigma=float(args.sigma),
        heatmap_alpha=float(args.heatmap_alpha),
        heatmap_cmap=str(args.heatmap_cmap),
        vmax_mode=str(args.vmax_mode),
    )
    print(f"✅ PDF: {out_pdf}")


if __name__ == "__main__":
    main()


✅ PDF: /Users/daisybuathatseephol/Documents/three_json_output/plots/TOP3_BOTTOM3_EACH_VC_SECOND_IMAGE_OVERLAY.pdf


In [8]:
# file: scripts/sanity_check_testing_json.py

from __future__ import annotations

import argparse
import json
import math
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple

import pandas as pd


CONDITIONS = ("full", "central", "peripheral")
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")


@dataclass(frozen=True)
class PairKey:
    a: str
    b: str

    @property
    def as_str(self) -> str:
        return f"{self.a}__{self.b}"


def _norm(s: Any) -> str:
    return str(s).strip().lower()


def _normalize_image_name(name: Any) -> str:
    s = _norm(name)
    s = re.sub(r"\.(jpg|jpeg|png|bmp|tif|tiff)$", "", s)
    return s


def _clean_condition(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = _norm(x)
    return s if s in CONDITIONS else None


def _as_bool(x: Any) -> Optional[bool]:
    if x is None:
        return None
    if isinstance(x, bool):
        return x
    if isinstance(x, int) and x in (0, 1):
        return bool(x)
    if isinstance(x, str):
        v = _norm(x)
        if v in ("true", "t", "yes", "y", "correct", "right", "1"):
            return True
        if v in ("false", "f", "no", "n", "incorrect", "wrong", "0"):
            return False
    return None


def _normalize_answer(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = _norm(x)
    return s or None


def _detect_correct(item: Dict[str, Any]) -> Tuple[Optional[bool], Optional[bool]]:
    """
    Returns: (correct_inferred, correct_from_acc_field)
    - correct_from_acc_field is only from acc-like boolean field if present
    - correct_inferred uses acc if available else subj_answer==correct_response if available
    """
    acc_field = None
    for k in ("acc", "correct", "is_correct", "response_correct", "trial_correct", "accuracy"):
        if k in item:
            acc_field = _as_bool(item.get(k))
            break

    subj = _normalize_answer(item.get("subj_answer"))
    corr = _normalize_answer(item.get("correct_response"))
    from_answers = (subj == corr) if (subj is not None and corr is not None) else None

    inferred = acc_field if acc_field is not None else from_answers
    return inferred, acc_field


def _extract_fixations(item: Dict[str, Any]) -> Tuple[List[float], List[float]]:
    xs = item.get("fix_x")
    ys = item.get("fix_y")
    if not isinstance(xs, list) or not isinstance(ys, list):
        return [], []
    return xs, ys


def _finite_float_list(xs: Sequence[Any]) -> Tuple[bool, List[float]]:
    out: List[float] = []
    try:
        for v in xs:
            f = float(v)
            if not math.isfinite(f):
                return False, []
            out.append(f)
        return True, out
    except Exception:
        return False, []


def _find_testing_json(root: Path) -> Path:
    hits = list(root.rglob("testing.json"))
    if not hits:
        raise FileNotFoundError(f"testing.json not found under {root}")
    hits.sort(key=lambda p: (len(p.parts), str(p)))
    return hits[0]


def _maybe_find_image(root: Path, image_name: str) -> bool:
    base = _normalize_image_name(image_name)
    # quick heuristic: if exact filename exists anywhere, ok
    for ext in IMAGE_EXTS:
        if list(root.rglob(base + ext)):
            return True
    # otherwise look for any file with same stem
    for p in root.rglob("*"):
        if p.is_file() and p.suffix.lower() in IMAGE_EXTS:
            if _normalize_image_name(p.name) == base or _normalize_image_name(p.stem) == base:
                return True
    return False


def sanity_check(
    root: str,
    json_path: Optional[str],
    out_dir: str,
    check_images: bool = False,
) -> Path:
    root_p = Path(root).expanduser().resolve()
    out_p = Path(out_dir).expanduser().resolve()
    out_p.mkdir(parents=True, exist_ok=True)

    if json_path:
        jp = Path(json_path).expanduser().resolve()
        if not jp.exists():
            raise FileNotFoundError(f"--json not found: {jp}")
    else:
        jp = _find_testing_json(root_p)

    raw = json.loads(jp.read_text(encoding="utf-8"))
    if not isinstance(raw, list):
        raise ValueError("testing.json must be a list of dicts")
    if not raw or not all(isinstance(x, dict) for x in raw):
        raise ValueError("testing.json must be a non-empty list of objects")

    issues: List[Dict[str, Any]] = []
    rows: List[Dict[str, Any]] = []

    pair_first_seen_idx: Dict[str, Dict[str, int]] = {c: {} for c in CONDITIONS}
    pair_id_map: Dict[str, Dict[str, int]] = {c: {} for c in CONDITIONS}
    pair_id_counter: Dict[str, int] = {c: 0 for c in CONDITIONS}

    required_keys = ("viewing_condition", "first_image", "second_image")

    for idx, item in enumerate(raw):
        cond = _clean_condition(item.get("viewing_condition") or item.get("condition"))
        first_img = item.get("first_image")
        second_img = item.get("second_image")
        subj_id = item.get("subject_id")
        trial_index = item.get("trial_index")

        missing = [k for k in required_keys if item.get(k) in (None, "", [])]
        if missing:
            issues.append(
                {
                    "row_index": idx,
                    "type": "missing_required_keys",
                    "missing_keys": ",".join(missing),
                    "condition": item.get("viewing_condition"),
                }
            )
            continue

        if cond is None:
            issues.append(
                {
                    "row_index": idx,
                    "type": "invalid_condition",
                    "value": item.get("viewing_condition"),
                }
            )
            continue

        a = _normalize_image_name(first_img)
        b = _normalize_image_name(second_img)
        pair = PairKey(*sorted((a, b))).as_str

        # pair_id per condition by first appearance in JSON order
        if pair not in pair_id_map[cond]:
            pair_id_counter[cond] += 1
            pair_id_map[cond][pair] = pair_id_counter[cond]
            pair_first_seen_idx[cond][pair] = idx

        correct_inferred, correct_from_acc = _detect_correct(item)

        # fix arrays checks
        fix_x_raw, fix_y_raw = _extract_fixations(item)
        if bool(fix_x_raw) != bool(fix_y_raw):
            issues.append(
                {
                    "row_index": idx,
                    "type": "fix_arrays_missing_one_side",
                    "condition": cond,
                    "pair": pair,
                    "has_fix_x": isinstance(item.get("fix_x"), list),
                    "has_fix_y": isinstance(item.get("fix_y"), list),
                }
            )

        ok_x, fix_x = _finite_float_list(fix_x_raw) if fix_x_raw else (True, [])
        ok_y, fix_y = _finite_float_list(fix_y_raw) if fix_y_raw else (True, [])
        if not ok_x or not ok_y:
            issues.append(
                {
                    "row_index": idx,
                    "type": "fix_arrays_non_numeric_or_non_finite",
                    "condition": cond,
                    "pair": pair,
                }
            )
        if len(fix_x) != len(fix_y):
            issues.append(
                {
                    "row_index": idx,
                    "type": "fix_arrays_length_mismatch",
                    "condition": cond,
                    "pair": pair,
                    "len_x": len(fix_x),
                    "len_y": len(fix_y),
                }
            )

        # acc vs answers consistency
        subj = _normalize_answer(item.get("subj_answer"))
        corr = _normalize_answer(item.get("correct_response"))
        if correct_from_acc is not None and subj is not None and corr is not None:
            from_answers = subj == corr
            if from_answers != correct_from_acc:
                issues.append(
                    {
                        "row_index": idx,
                        "type": "acc_mismatch_subj_vs_correct",
                        "condition": cond,
                        "pair": pair,
                        "acc_field": correct_from_acc,
                        "subj_answer": subj,
                        "correct_response": corr,
                    }
                )

        # optional image existence
        img_ok = None
        if check_images:
            img_ok = bool(_maybe_find_image(root_p, str(second_img)))

            if not img_ok:
                issues.append(
                    {
                        "row_index": idx,
                        "type": "missing_image_file_second_image",
                        "condition": cond,
                        "pair": pair,
                        "second_image": str(second_img),
                    }
                )

        rows.append(
            {
                "row_index": idx,
                "condition": cond,
                "pair": pair,
                "pair_id": pair_id_map[cond][pair],
                "pair_first_seen_row": pair_first_seen_idx[cond][pair],
                "subject_id": subj_id,
                "trial_index": trial_index,
                "first_image": str(first_img),
                "second_image": str(second_img),
                "correct": correct_inferred,
                "correct_from_acc": correct_from_acc,
                "n_fix": min(len(fix_x), len(fix_y)),
                "image_ok_second": img_ok,
            }
        )

    df = pd.DataFrame(rows)
    issues_df = pd.DataFrame(issues)

    # aggregates
    summary: Dict[str, Any] = {}
    summary["json_path"] = str(jp)
    summary["n_rows_total"] = len(raw)
    summary["n_rows_valid"] = len(df)
    summary["n_rows_with_issues"] = int(issues_df["row_index"].nunique()) if not issues_df.empty else 0
    summary["issues_by_type"] = (
        issues_df["type"].value_counts().to_dict() if not issues_df.empty else {}
    )

    cond_counts = df["condition"].value_counts().to_dict()
    summary["rows_by_condition"] = {c: int(cond_counts.get(c, 0)) for c in CONDITIONS}

    # pair summary
    pair_summary = (
        df.groupby(["condition", "pair", "pair_id"], as_index=False)
        .agg(
            n_trials=("row_index", "size"),
            n_subjects=("subject_id", lambda s: s.nunique(dropna=True)),
            right=("correct", lambda s: int(pd.Series(s).dropna().sum()) if len(pd.Series(s).dropna()) else 0),
            n_correct_known=("correct", lambda s: int(pd.Series(s).dropna().shape[0])),
            first_seen=("pair_first_seen_row", "min"),
        )
    )
    pair_summary["wrong"] = pair_summary["n_correct_known"] - pair_summary["right"]
    pair_summary["acc"] = pair_summary.apply(
        lambda r: (r["right"] / r["n_correct_known"]) if r["n_correct_known"] else float("nan"), axis=1
    )

    summary["unique_pairs_by_condition"] = {
        c: int(pair_summary[pair_summary["condition"] == c]["pair"].nunique()) for c in CONDITIONS
    }

    # cross-condition overlap
    pairs_by_cond = {
        c: set(pair_summary[pair_summary["condition"] == c]["pair"].tolist()) for c in CONDITIONS
    }
    summary["pair_overlap_full_central"] = len(pairs_by_cond["full"] & pairs_by_cond["central"])
    summary["pair_overlap_full_peripheral"] = len(pairs_by_cond["full"] & pairs_by_cond["peripheral"])
    summary["pair_overlap_central_peripheral"] = len(pairs_by_cond["central"] & pairs_by_cond["peripheral"])
    summary["pair_overlap_all_three"] = len(
        pairs_by_cond["full"] & pairs_by_cond["central"] & pairs_by_cond["peripheral"]
    )

    # pair id map output
    pid_rows: List[Dict[str, Any]] = []
    for c in CONDITIONS:
        for pair, pid in pair_id_map[c].items():
            pid_rows.append(
                {
                    "condition": c,
                    "pair": pair,
                    "pair_id": pid,
                    "first_seen_row": pair_first_seen_idx[c][pair],
                }
            )
    pair_id_df = pd.DataFrame(pid_rows).sort_values(["condition", "pair_id"])

    # save artifacts
    (out_p / "sanity_report.json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
    df.to_csv(out_p / "rows_normalized.csv", index=False)
    pair_summary.to_csv(out_p / "pair_summary.csv", index=False)
    pair_id_df.to_csv(out_p / "pair_id_map.csv", index=False)
    if not issues_df.empty:
        issues_df.to_csv(out_p / "row_issues.csv", index=False)

    # console report
    print("\n" + "=" * 80)
    print("SANITY CHECK REPORT")
    print("=" * 80)
    print(f"JSON: {jp}")
    print(f"Total rows: {summary['n_rows_total']}")
    print(f"Valid rows parsed: {summary['n_rows_valid']}")
    print(f"Rows with any issues: {summary['n_rows_with_issues']}")
    print("\nRows by condition:")
    for c in CONDITIONS:
        print(f"  {c}: {summary['rows_by_condition'][c]}")
    print("\nUnique pairs by condition:")
    for c in CONDITIONS:
        print(f"  {c}: {summary['unique_pairs_by_condition'][c]}")
    print("\nPair overlap (should usually be 0 across conditions):")
    print(f"  full ∩ central: {summary['pair_overlap_full_central']}")
    print(f"  full ∩ peripheral: {summary['pair_overlap_full_peripheral']}")
    print(f"  central ∩ peripheral: {summary['pair_overlap_central_peripheral']}")
    print(f"  all three: {summary['pair_overlap_all_three']}")

    if summary["issues_by_type"]:
        print("\nIssues by type:")
        for k, v in summary["issues_by_type"].items():
            print(f"  {k}: {v}")
    else:
        print("\nIssues by type: none ✅")

    print(f"\nSaved to: {out_p}")
    print("  - sanity_report.json")
    print("  - rows_normalized.csv")
    print("  - pair_summary.csv")
    print("  - pair_id_map.csv")
    if not issues_df.empty:
        print("  - row_issues.csv")
    print("=" * 80 + "\n")

    return out_p / "sanity_report.json"


def main() -> None:
    parser = argparse.ArgumentParser(description="Sanity check testing.json (schema, fixations, pairs, acc).")
    parser.add_argument("--root", default=".", help="Root folder to search for testing.json (and images if enabled).")
    parser.add_argument("--json", default=None, help="Optional explicit path to testing.json.")
    parser.add_argument("--out", default="sanity_check_out", help="Output directory for reports.")
    parser.add_argument("--check-images", action="store_true", help="Also verify second_image files exist (slower).")

    args, _ = parser.parse_known_args()

    sanity_check(
        root=str(args.root),
        json_path=args.json,
        out_dir=str(args.out),
        check_images=bool(args.check_images),
    )


if __name__ == "__main__":
    main()



SANITY CHECK REPORT
JSON: /Users/daisybuathatseephol/Documents/three_json_output/Testing/testing.json
Total rows: 2304
Valid rows parsed: 2304
Rows with any issues: 1

Rows by condition:
  full: 768
  central: 768
  peripheral: 768

Unique pairs by condition:
  full: 24
  central: 24
  peripheral: 24

Pair overlap (should usually be 0 across conditions):
  full ∩ central: 0
  full ∩ peripheral: 0
  central ∩ peripheral: 0
  all three: 0

Issues by type:
  acc_mismatch_subj_vs_correct: 1

Saved to: /Users/daisybuathatseephol/Documents/three_json_output/sanity_check_out
  - sanity_report.json
  - rows_normalized.csv
  - pair_summary.csv
  - pair_id_map.csv
  - row_issues.csv



In [10]:
# file: scripts/export_top_bottom3_onepage_correct_pairid.py

from __future__ import annotations

import argparse
import json
import math
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages
from PIL import Image

try:
    from scipy.ndimage import gaussian_filter  # type: ignore
except Exception:  # pragma: no cover
    gaussian_filter = None


CONDITIONS: Tuple[str, ...] = ("full", "central", "peripheral")
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")


@dataclass(frozen=True)
class PairKey:
    a: str
    b: str

    @property
    def as_str(self) -> str:
        return f"{self.a}__{self.b}"


@dataclass
class TrialRow:
    condition: str
    pair: PairKey
    first_image: str
    second_image: str
    correct: Optional[bool]
    raw: Dict[str, Any]


def _normalize_image_name(name: Any) -> str:
    s = str(name).strip().lower()
    s = re.sub(r"\.(jpg|jpeg|png|bmp|tif|tiff)$", "", s)
    return s


def _clean_condition(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s if s in CONDITIONS else None


def _as_bool(x: Any) -> Optional[bool]:
    if x is None:
        return None
    if isinstance(x, bool):
        return x
    if isinstance(x, (int, np.integer)) and x in (0, 1):
        return bool(x)
    if isinstance(x, str):
        v = str(x).strip().lower()
        if v in ("true", "t", "yes", "y", "correct", "right", "1"):
            return True
        if v in ("false", "f", "no", "n", "incorrect", "wrong", "0"):
            return False
    return None


def _normalize_answer(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s or None


def _detect_correct(item: Dict[str, Any]) -> Optional[bool]:
    for k in ("acc", "correct", "is_correct", "response_correct", "trial_correct", "accuracy"):
        if k in item:
            b = _as_bool(item.get(k))
            if b is not None:
                return b
    subj = _normalize_answer(item.get("subj_answer"))
    corr = _normalize_answer(item.get("correct_response"))
    if subj is not None and corr is not None:
        return subj == corr
    return None


def _find_testing_json(root: Path) -> Path:
    hits = list(root.rglob("testing.json"))
    if not hits:
        raise FileNotFoundError(f"❌ testing.json not found under: {root}")
    hits.sort(key=lambda p: (len(p.parts), str(p)))
    return hits[0]


def _find_image_file(root: Path, image_name: str) -> Optional[Path]:
    base = _normalize_image_name(image_name)
    candidates: List[Path] = []
    for p in root.rglob("*"):
        if not p.is_file():
            continue
        if p.suffix.lower() not in IMAGE_EXTS:
            continue
        if _normalize_image_name(p.name) == base or _normalize_image_name(p.stem) == base:
            candidates.append(p)
    if not candidates:
        return None
    candidates.sort(key=lambda p: (len(p.parts), str(p)))
    return candidates[0]


def _iter_trials(data: Sequence[Dict[str, Any]]) -> Iterable[TrialRow]:
    for item in data:
        cond = _clean_condition(item.get("viewing_condition") or item.get("condition"))
        if cond is None:
            continue
        first = item.get("first_image")
        second = item.get("second_image")
        if not first or not second:
            continue

        first_s = str(first).strip()
        second_s = str(second).strip()

        a = _normalize_image_name(first_s)
        b = _normalize_image_name(second_s)
        pair = PairKey(*sorted((a, b)))

        yield TrialRow(
            condition=cond,
            pair=pair,
            first_image=first_s,
            second_image=second_s,
            correct=_detect_correct(item),
            raw=item,
        )


def _extract_fixations(item: Dict[str, Any]) -> np.ndarray:
    xs = item.get("fix_x")
    ys = item.get("fix_y")
    if not isinstance(xs, (list, tuple)) or not isinstance(ys, (list, tuple)):
        return np.zeros((0, 2), dtype=float)
    if len(xs) != len(ys) or len(xs) == 0:
        return np.zeros((0, 2), dtype=float)
    try:
        arr = np.column_stack([np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)])
        arr = arr[np.isfinite(arr).all(axis=1)]
        return arr
    except Exception:
        return np.zeros((0, 2), dtype=float)


def _screen_dims(item: Dict[str, Any], default: Tuple[int, int]) -> Tuple[int, int]:
    for kw, kh in (
        ("screen_width", "screen_height"),
        ("window_width", "window_height"),
        ("display_width", "display_height"),
        ("screenW", "screenH"),
    ):
        if kw in item and kh in item:
            try:
                w = int(float(item[kw]))
                h = int(float(item[kh]))
                if w > 0 and h > 0:
                    return w, h
            except Exception:
                pass
    return default


def _image_box_extent(screen_w: int, screen_h: int, half_box: int) -> Tuple[int, int, int, int]:
    l = screen_w // 2 - half_box
    t = screen_h // 2 - half_box
    r = screen_w // 2 + half_box
    b = screen_h // 2 + half_box
    return l, t, r, b


def _load_image_resized(path: Path, size_px: int) -> np.ndarray:
    img = Image.open(path).convert("L")
    img = img.resize((size_px, size_px), resample=Image.BICUBIC)
    return np.asarray(img)


def _fdm_in_box(
    pts: np.ndarray,
    l: int,
    t: int,
    r: int,
    b: int,
    size_px: int,
    sigma: float,
) -> np.ndarray:
    if len(pts) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    in_box = (pts[:, 0] >= l) & (pts[:, 0] <= r) & (pts[:, 1] >= t) & (pts[:, 1] <= b)
    p = pts[in_box]
    if len(p) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    x = (p[:, 0] - l) / max(1e-9, (r - l)) * (size_px - 1)
    y = (p[:, 1] - t) / max(1e-9, (b - t)) * (size_px - 1)
    x = np.clip(x, 0, size_px - 1)
    y = np.clip(y, 0, size_px - 1)

    heat, _, _ = np.histogram2d(y, x, bins=[size_px, size_px], range=[[0, size_px], [0, size_px]])
    heat = heat.astype(float)

    if gaussian_filter is not None:
        heat = gaussian_filter(heat, sigma=sigma)
    else:
        k = int(max(3, math.ceil(sigma * 3)) * 2 + 1)
        ax = np.arange(k) - k // 2
        kernel = np.exp(-(ax**2) / (2 * sigma**2))
        kernel /= kernel.sum()
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=0, arr=heat)
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=1, arr=heat)

    s = heat.sum()
    if s > 0:
        heat /= s
    return heat


def _aggregate_accuracy(trials: Sequence[TrialRow]) -> pd.DataFrame:
    rows: List[Dict[str, Any]] = []
    for t in trials:
        if t.correct is None:
            continue
        rows.append({"condition": t.condition, "pair": t.pair.as_str, "correct": int(bool(t.correct))})

    df = pd.DataFrame(rows)
    if df.empty:
        raise ValueError("❌ No correctness detected (acc / subj_answer vs correct_response).")

    return (
        df.groupby(["condition", "pair"], as_index=False)
        .agg(N=("correct", "size"), right=("correct", "sum"))
        .assign(wrong=lambda x: x["N"] - x["right"])
        .assign(acc=lambda x: x["right"] / x["N"])
    )


def _build_pair_id_map_from_json_order(trials: Sequence[TrialRow]) -> Dict[str, Dict[str, int]]:
    """
    Pair ID = first time a (condition, pair) appears in testing.json order.
    This matches your sanity_check pair_id_map.csv behavior and eliminates '?'
    """
    out: Dict[str, Dict[str, int]] = {c: {} for c in CONDITIONS}
    counters = {c: 0 for c in CONDITIONS}
    for t in trials:
        ps = t.pair.as_str
        if ps not in out[t.condition]:
            counters[t.condition] += 1
            out[t.condition][ps] = counters[t.condition]
    return out


def _plot_overlay(
    ax: plt.Axes,
    img_arr: np.ndarray,
    pts: np.ndarray,
    heat: np.ndarray,
    screen_w: int,
    screen_h: int,
    extent_ltrb: Tuple[int, int, int, int],  # (l,t,r,b)
    title1: str,
    title2: str,
    heatmap_cmap: str,
    heatmap_alpha: float,
    vmax_mode: str,
) -> Any:
    l, t, r, b = extent_ltrb
    ax.imshow(img_arr, cmap="gray", origin="upper", extent=(l, r, b, t))

    if vmax_mode == "p99":
        vmax = float(np.quantile(heat, 0.99)) if heat.max() > 0 else 1.0
    else:
        vmax = float(heat.max()) if heat.max() > 0 else 1.0

    hm = ax.imshow(
        heat,
        origin="upper",
        extent=(l, r, b, t),
        cmap=heatmap_cmap,
        alpha=heatmap_alpha,
        vmin=0.0,
        vmax=vmax,
    )

    if len(pts) > 0:
        ax.scatter(
            pts[:, 0],
            pts[:, 1],
            s=30,
            facecolors=(1.0, 0.42, 0.42, 0.50),
            edgecolors=(1.0, 1.0, 1.0, 0.80),
            linewidths=0.8,
        )

    ax.set_xlim(0, screen_w)
    ax.set_ylim(screen_h, 0)
    ax.set_xlabel("x (screen px)")
    ax.set_ylabel("y (screen px)")
    ax.set_title(f"{title1}\n{title2}", fontsize=10)
    ax.grid(False)
    return hm


def export_one_page_top_bottom3(
    root_path: str,
    output_dir: str,
    condition: str,
    half_box: int = 310,
    screen_default: Tuple[int, int] = (1000, 800),
    sigma: float = 18.0,
    heatmap_alpha: float = 0.70,
    heatmap_cmap: str = "turbo",
    vmax_mode: str = "p99",
) -> Path:
    cond = _clean_condition(condition)
    if cond is None:
        raise ValueError(f"--condition must be one of {CONDITIONS}")

    root = Path(root_path).expanduser().resolve()
    out_dir = Path(output_dir).expanduser().resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    json_path = _find_testing_json(root)
    data = json.loads(json_path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("❌ testing.json must be a list of objects.")

    all_trials = list(_iter_trials(data))
    pair_id_map = _build_pair_id_map_from_json_order(all_trials)

    # Ensure we really have 24 ids for this condition (else fail loudly)
    n_ids = len(pair_id_map[cond])
    if n_ids != 24:
        raise ValueError(
            f"Expected 24 pairs for condition={cond}, but found {n_ids}. "
            f"If this is intentional, change the check or print pair_id_map[{cond}]."
        )

    agg = _aggregate_accuracy(all_trials)
    cond_agg = agg[agg["condition"] == cond].copy()
    cond_agg = cond_agg.sort_values(["acc", "pair"], ascending=[True, True]).reset_index(drop=True)
    n_pairs = len(cond_agg)

    bottom3 = cond_agg.head(3).copy()  # most incorrect -> less incorrect
    top3 = cond_agg.tail(3).sort_values(["acc", "pair"], ascending=[False, True]).copy()  # most correct -> less correct

    out_pdf = out_dir / f"{cond.upper()}_TOP3_BOTTOM3_ONEPAGE_SECOND_IMAGE_OVERLAY.pdf"

    with PdfPages(out_pdf) as pdf:
        fig, axes = plt.subplots(2, 3, figsize=(18, 10), constrained_layout=True)
        last_hm = None

        def render(row_df: pd.DataFrame, row_idx: int, label: str) -> None:
            nonlocal last_hm
            for j in range(3):
                ax = axes[row_idx, j]
                if j >= len(row_df):
                    ax.axis("off")
                    continue

                row = row_df.iloc[j]
                pair_str = str(row["pair"])
                trials = [t for t in all_trials if t.condition == cond and t.pair.as_str == pair_str]
                if not trials:
                    ax.axis("off")
                    continue

                img_path = _find_image_file(root, trials[0].second_image)
                if img_path is None:
                    ax.set_title("missing second image", fontsize=10)
                    ax.axis("off")
                    continue

                sw, sh = screen_default
                for t in trials[:3]:
                    sw, sh = _screen_dims(t.raw, (sw, sh))

                l, ttop, r, b = _image_box_extent(sw, sh, half_box=half_box)
                box_size = int(half_box * 2)

                img_arr = _load_image_resized(img_path, size_px=box_size)
                pts = np.vstack([_extract_fixations(t.raw) for t in trials])
                heat = _fdm_in_box(pts, l=l, t=ttop, r=r, b=b, size_px=box_size, sigma=sigma)

                acc = float(row["acc"])
                right = int(row["right"])
                wrong = int(row["wrong"])
                N = int(row["N"])

                pair_id = pair_id_map[cond].get(pair_str)
                if pair_id is None:
                    raise RuntimeError(f"Pair ID missing for condition={cond} pair={pair_str}")

                correctness_label = "Correct" if acc >= 0.5 else "Incorrect"
                title1 = f"Testing • Acc: {acc*100:.1f}% • Pair {pair_id}/{n_pairs} • {correctness_label}"
                title2 = f"Subjects: right={right} • wrong={wrong} • N={N} subjects • Second: {len(pts)} fix"

                last_hm = _plot_overlay(
                    ax=ax,
                    img_arr=img_arr,
                    pts=pts,
                    heat=heat,
                    screen_w=sw,
                    screen_h=sh,
                    extent_ltrb=(l, ttop, r, b),
                    title1=title1,
                    title2=title2,
                    heatmap_cmap=heatmap_cmap,
                    heatmap_alpha=heatmap_alpha,
                    vmax_mode=vmax_mode,
                )

                ax.text(
                    0.01,
                    0.99,
                    label,
                    transform=ax.transAxes,
                    va="top",
                    ha="left",
                    fontsize=10,
                    fontweight="bold",
                    bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7),
                )

        render(top3.reset_index(drop=True), row_idx=0, label="TOP 3")
        render(bottom3.reset_index(drop=True), row_idx=1, label="BOTTOM 3")

        fig.suptitle(
            f"Testing Phase - {cond.upper()} Viewing Condition\n"
            f"Top 3 (High Accuracy) + Bottom 3 (Low Accuracy)\n"
            f"Top: Most Correct \u2192 Least Correct   |   Bottom: Most Incorrect \u2192 Least Incorrect",
            fontsize=14,
            fontweight="bold",
        )

        if last_hm is not None:
            cbar = fig.colorbar(last_hm, ax=axes.ravel().tolist(), shrink=0.85, pad=0.02)
            cbar.set_label("Fixation Density", rotation=90)

        pdf.savefig(fig)
        plt.close(fig)

    return out_pdf


def _running_in_notebook(argv: List[str]) -> bool:
    return "-f" in argv or "ipykernel" in argv[0].lower()


def main() -> None:
    parser = argparse.ArgumentParser(
        description="One PDF, one page: Top 3 + Bottom 3 for ONE viewing condition with correct Pair IDs (no '?')."
    )
    parser.add_argument("--root", default=".", help="Folder to search for testing.json and images.")
    parser.add_argument("--out", default="plots", help="Output directory.")
    parser.add_argument("--condition", default="full", choices=list(CONDITIONS))
    parser.add_argument("--half-box", type=int, default=310)
    parser.add_argument("--screen-w", type=int, default=1000)
    parser.add_argument("--screen-h", type=int, default=800)
    parser.add_argument("--sigma", type=float, default=18.0)
    parser.add_argument("--heatmap-alpha", type=float, default=0.70)
    parser.add_argument("--heatmap-cmap", default="turbo")
    parser.add_argument("--vmax-mode", default="p99", choices=["p99", "max"])

    argv = sys.argv[1:]
    if _running_in_notebook(sys.argv):
        args = parser.parse_args([])  # ignore Jupyter argv
    else:
        args, _unknown = parser.parse_known_args(argv)

    out_pdf = export_one_page_top_bottom3(
        root_path=args.root,
        output_dir=args.out,
        condition=str(args.condition),
        half_box=int(args.half_box),
        screen_default=(int(args.screen_w), int(args.screen_h)),
        sigma=float(args.sigma),
        heatmap_alpha=float(args.heatmap_alpha),
        heatmap_cmap=str(args.heatmap_cmap),
        vmax_mode=str(args.vmax_mode),
    )
    print(f"✅ PDF: {out_pdf}")


if __name__ == "__main__":
    main()


✅ PDF: /Users/daisybuathatseephol/Documents/three_json_output/plots/FULL_TOP3_BOTTOM3_ONEPAGE_SECOND_IMAGE_OVERLAY.pdf


In [13]:
# file: scripts/export_top_bottom3_onepage_correct_pairid.py
"""
Export fixation+FDM overlays for Top 3 (highest accuracy) and Bottom 3 (lowest accuracy)
pairs within a viewing condition, using the *correct* Pair ID (1..24) derived from
first appearance order in testing.json for that condition.

Outputs:
- If --condition full|central|peripheral: one PDF (one page, 2x3 grid)
- If --condition all: three PDFs (one per condition)

Example:
  python scripts/export_top_bottom3_onepage_correct_pairid.py \
    --root /path/to/three_json_output --out plots --condition all
"""

from __future__ import annotations

import argparse
import json
import math
import re
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.backends.backend_pdf import PdfPages
from PIL import Image

try:
    from scipy.ndimage import gaussian_filter  # type: ignore
except Exception:  # pragma: no cover
    gaussian_filter = None


CONDITIONS: Tuple[str, ...] = ("full", "central", "peripheral")
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff")


@dataclass(frozen=True)
class PairKey:
    a: str
    b: str

    @property
    def as_str(self) -> str:
        return f"{self.a}__{self.b}"


@dataclass
class TrialRow:
    condition: str
    pair: PairKey
    first_image: str
    second_image: str
    correct: Optional[bool]
    raw: Dict[str, Any]


def _normalize_image_name(name: Any) -> str:
    s = str(name).strip().lower()
    s = re.sub(r"\.(jpg|jpeg|png|bmp|tif|tiff)$", "", s)
    return s


def _clean_condition(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s if s in CONDITIONS else None


def _as_bool(x: Any) -> Optional[bool]:
    if x is None:
        return None
    if isinstance(x, bool):
        return x
    if isinstance(x, (int, np.integer)) and x in (0, 1):
        return bool(x)
    if isinstance(x, str):
        v = str(x).strip().lower()
        if v in ("true", "t", "yes", "y", "correct", "right", "1"):
            return True
        if v in ("false", "f", "no", "n", "incorrect", "wrong", "0"):
            return False
    return None


def _normalize_answer(x: Any) -> Optional[str]:
    if x is None:
        return None
    s = str(x).strip().lower()
    return s or None


def _detect_correct(item: Dict[str, Any]) -> Optional[bool]:
    for k in ("acc", "correct", "is_correct", "response_correct", "trial_correct", "accuracy"):
        if k in item:
            b = _as_bool(item.get(k))
            if b is not None:
                return b
    subj = _normalize_answer(item.get("subj_answer"))
    corr = _normalize_answer(item.get("correct_response"))
    if subj is not None and corr is not None:
        return subj == corr
    return None


def _find_testing_json(root: Path) -> Path:
    hits = list(root.rglob("testing.json"))
    if not hits:
        raise FileNotFoundError(f"❌ testing.json not found under: {root}")
    hits.sort(key=lambda p: (len(p.parts), str(p)))
    return hits[0]


def _find_image_file(root: Path, image_name: str) -> Optional[Path]:
    base = _normalize_image_name(image_name)
    candidates: List[Path] = []
    for p in root.rglob("*"):
        if not p.is_file():
            continue
        if p.suffix.lower() not in IMAGE_EXTS:
            continue
        if _normalize_image_name(p.name) == base or _normalize_image_name(p.stem) == base:
            candidates.append(p)
    if not candidates:
        return None
    candidates.sort(key=lambda p: (len(p.parts), str(p)))
    return candidates[0]


def _iter_trials(data: Sequence[Dict[str, Any]]) -> Iterable[TrialRow]:
    for item in data:
        cond = _clean_condition(item.get("viewing_condition") or item.get("condition"))
        if cond is None:
            continue

        first = item.get("first_image")
        second = item.get("second_image")
        if not first or not second:
            continue

        first_s = str(first).strip()
        second_s = str(second).strip()

        a = _normalize_image_name(first_s)
        b = _normalize_image_name(second_s)
        pair = PairKey(*sorted((a, b)))

        yield TrialRow(
            condition=cond,
            pair=pair,
            first_image=first_s,
            second_image=second_s,
            correct=_detect_correct(item),
            raw=item,
        )


def _extract_fixations(item: Dict[str, Any]) -> np.ndarray:
    xs = item.get("fix_x")
    ys = item.get("fix_y")
    if not isinstance(xs, (list, tuple)) or not isinstance(ys, (list, tuple)):
        return np.zeros((0, 2), dtype=float)
    if len(xs) != len(ys) or len(xs) == 0:
        return np.zeros((0, 2), dtype=float)
    try:
        arr = np.column_stack([np.asarray(xs, dtype=float), np.asarray(ys, dtype=float)])
        arr = arr[np.isfinite(arr).all(axis=1)]
        return arr
    except Exception:
        return np.zeros((0, 2), dtype=float)


def _screen_dims(item: Dict[str, Any], default: Tuple[int, int]) -> Tuple[int, int]:
    for kw, kh in (
        ("screen_width", "screen_height"),
        ("window_width", "window_height"),
        ("display_width", "display_height"),
        ("screenW", "screenH"),
    ):
        if kw in item and kh in item:
            try:
                w = int(float(item[kw]))
                h = int(float(item[kh]))
                if w > 0 and h > 0:
                    return w, h
            except Exception:
                pass
    return default


def _image_box_extent(screen_w: int, screen_h: int, half_box: int) -> Tuple[int, int, int, int]:
    l = screen_w // 2 - half_box
    t = screen_h // 2 - half_box
    r = screen_w // 2 + half_box
    b = screen_h // 2 + half_box
    return l, t, r, b


def _load_image_resized(path: Path, size_px: int) -> np.ndarray:
    img = Image.open(path).convert("L")
    img = img.resize((size_px, size_px), resample=Image.BICUBIC)
    return np.asarray(img)


def _fdm_in_box(
    pts: np.ndarray,
    l: int,
    t: int,
    r: int,
    b: int,
    size_px: int,
    sigma: float,
) -> np.ndarray:
    if len(pts) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    in_box = (pts[:, 0] >= l) & (pts[:, 0] <= r) & (pts[:, 1] >= t) & (pts[:, 1] <= b)
    p = pts[in_box]
    if len(p) == 0:
        return np.zeros((size_px, size_px), dtype=float)

    x = (p[:, 0] - l) / max(1e-9, (r - l)) * (size_px - 1)
    y = (p[:, 1] - t) / max(1e-9, (b - t)) * (size_px - 1)
    x = np.clip(x, 0, size_px - 1)
    y = np.clip(y, 0, size_px - 1)

    heat, _, _ = np.histogram2d(y, x, bins=[size_px, size_px], range=[[0, size_px], [0, size_px]])
    heat = heat.astype(float)

    if gaussian_filter is not None:
        heat = gaussian_filter(heat, sigma=sigma)
    else:
        k = int(max(3, math.ceil(sigma * 3)) * 2 + 1)
        ax = np.arange(k) - k // 2
        kernel = np.exp(-(ax**2) / (2 * sigma**2))
        kernel /= kernel.sum()
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=0, arr=heat)
        heat = np.apply_along_axis(lambda m: np.convolve(m, kernel, mode="same"), axis=1, arr=heat)

    s = heat.sum()
    if s > 0:
        heat /= s
    return heat


def _aggregate_accuracy(trials: Sequence[TrialRow]) -> pd.DataFrame:
    rows: List[Dict[str, Any]] = []
    for t in trials:
        if t.correct is None:
            continue
        rows.append({"condition": t.condition, "pair": t.pair.as_str, "correct": int(bool(t.correct))})

    df = pd.DataFrame(rows)
    if df.empty:
        raise ValueError("❌ No correctness detected (acc / subj_answer vs correct_response).")

    return (
        df.groupby(["condition", "pair"], as_index=False)
        .agg(N=("correct", "size"), right=("correct", "sum"))
        .assign(wrong=lambda x: x["N"] - x["right"])
        .assign(acc=lambda x: x["right"] / x["N"])
    )


def _build_pair_id_map_from_json_order(trials: Sequence[TrialRow]) -> Dict[str, Dict[str, int]]:
    out: Dict[str, Dict[str, int]] = {c: {} for c in CONDITIONS}
    counters = {c: 0 for c in CONDITIONS}
    for t in trials:
        ps = t.pair.as_str
        if ps not in out[t.condition]:
            counters[t.condition] += 1
            out[t.condition][ps] = counters[t.condition]
    return out


def _plot_overlay(
    ax: plt.Axes,
    img_arr: np.ndarray,
    pts: np.ndarray,
    heat: np.ndarray,
    screen_w: int,
    screen_h: int,
    extent_ltrb: Tuple[int, int, int, int],  # (l,t,r,b)
    title1: str,
    title2: str,
    heatmap_cmap: str,
    heatmap_alpha: float,
    vmax_mode: str,
):
    l, t, r, b = extent_ltrb
    ax.imshow(img_arr, cmap="gray", origin="upper", extent=(l, r, b, t))

    if vmax_mode == "p99":
        vmax = float(np.quantile(heat, 0.99)) if heat.max() > 0 else 1.0
    else:
        vmax = float(heat.max()) if heat.max() > 0 else 1.0

    hm = ax.imshow(
        heat,
        origin="upper",
        extent=(l, r, b, t),
        cmap=heatmap_cmap,
        alpha=heatmap_alpha,
        vmin=0.0,
        vmax=vmax,
    )

    if len(pts) > 0:
        ax.scatter(
            pts[:, 0],
            pts[:, 1],
            s=30,
            facecolors=(1.0, 0.42, 0.42, 0.50),
            edgecolors=(1.0, 1.0, 1.0, 0.80),
            linewidths=0.8,
        )

    ax.set_xlim(0, screen_w)
    ax.set_ylim(screen_h, 0)
    ax.set_xlabel("x (screen px)")
    ax.set_ylabel("y (screen px)")
    ax.set_title(f"{title1}\n{title2}", fontsize=10)
    ax.grid(False)
    return hm


def export_one_page_top_bottom3(
    root_path: str,
    output_dir: str,
    condition: str,
    half_box: int = 310,
    screen_default: Tuple[int, int] = (1000, 800),
    sigma: float = 18.0,
    heatmap_alpha: float = 0.70,
    heatmap_cmap: str = "turbo",
    vmax_mode: str = "p99",
    expected_pairs: int = 24,
) -> Path:
    cond = _clean_condition(condition)
    if cond is None:
        raise ValueError(f"--condition must be one of {CONDITIONS}")

    root = Path(root_path).expanduser().resolve()
    out_dir = Path(output_dir).expanduser().resolve()
    out_dir.mkdir(parents=True, exist_ok=True)

    json_path = _find_testing_json(root)
    data = json.loads(json_path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError("❌ testing.json must be a list of objects.")

    all_trials = list(_iter_trials(data))
    pair_id_map = _build_pair_id_map_from_json_order(all_trials)

    n_ids = len(pair_id_map[cond])
    if expected_pairs and n_ids != expected_pairs:
        raise ValueError(f"Expected {expected_pairs} pairs for {cond}, found {n_ids}.")

    agg = _aggregate_accuracy(all_trials)
    cond_agg = agg[agg["condition"] == cond].copy()
    cond_agg = cond_agg.sort_values(["acc", "pair"], ascending=[True, True]).reset_index(drop=True)
    n_pairs = len(cond_agg)

    bottom3 = cond_agg.head(3).copy()
    top3 = cond_agg.tail(3).sort_values(["acc", "pair"], ascending=[False, True]).copy()

    out_pdf = out_dir / f"{cond.upper()}_TOP3_BOTTOM3_ONEPAGE_SECOND_IMAGE_OVERLAY.pdf"

    with PdfPages(out_pdf) as pdf:
        fig, axes = plt.subplots(2, 3, figsize=(18, 10), constrained_layout=True)
        last_hm = None

        def render(row_df: pd.DataFrame, row_idx: int, label: str) -> None:
            nonlocal last_hm
            for j in range(3):
                ax = axes[row_idx, j]
                if j >= len(row_df):
                    ax.axis("off")
                    continue

                row = row_df.iloc[j]
                pair_str = str(row["pair"])
                trials = [t for t in all_trials if t.condition == cond and t.pair.as_str == pair_str]
                if not trials:
                    ax.axis("off")
                    continue

                img_path = _find_image_file(root, trials[0].second_image)
                if img_path is None:
                    ax.set_title("missing second image", fontsize=10)
                    ax.axis("off")
                    continue

                sw, sh = screen_default
                for tr in trials[:3]:
                    sw, sh = _screen_dims(tr.raw, (sw, sh))

                l, ttop, r, b = _image_box_extent(sw, sh, half_box=half_box)
                box_size = int(half_box * 2)

                img_arr = _load_image_resized(img_path, size_px=box_size)
                pts = np.vstack([_extract_fixations(tr.raw) for tr in trials])
                heat = _fdm_in_box(pts, l=l, t=ttop, r=r, b=b, size_px=box_size, sigma=sigma)

                acc = float(row["acc"])
                right = int(row["right"])
                wrong = int(row["wrong"])
                N = int(row["N"])

                pair_id = pair_id_map[cond].get(pair_str)
                if pair_id is None:
                    raise RuntimeError(f"Pair ID missing for condition={cond} pair={pair_str}")

                correctness_label = "Correct" if acc >= 0.5 else "Incorrect"
                title1 = f"Testing • Acc: {acc*100:.1f}% • Pair {pair_id}/{n_pairs} • {correctness_label}"
                title2 = f"Subjects: right={right} • wrong={wrong} • N={N} subjects • Second: {len(pts)} fix"

                last_hm = _plot_overlay(
                    ax=ax,
                    img_arr=img_arr,
                    pts=pts,
                    heat=heat,
                    screen_w=sw,
                    screen_h=sh,
                    extent_ltrb=(l, ttop, r, b),
                    title1=title1,
                    title2=title2,
                    heatmap_cmap=heatmap_cmap,
                    heatmap_alpha=heatmap_alpha,
                    vmax_mode=vmax_mode,
                )

                ax.text(
                    0.01,
                    0.99,
                    label,
                    transform=ax.transAxes,
                    va="top",
                    ha="left",
                    fontsize=10,
                    fontweight="bold",
                    bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7),
                )

        render(top3.reset_index(drop=True), row_idx=0, label="TOP 3")
        render(bottom3.reset_index(drop=True), row_idx=1, label="BOTTOM 3")

        fig.suptitle(
            f"Testing Phase - {cond.upper()} Viewing Condition\n"
            f"Top 3 (High Accuracy) + Bottom 3 (Low Accuracy)\n"
            f"Top: Most Correct \u2192 Least Correct   |   Bottom: Most Incorrect \u2192 Least Incorrect",
            fontsize=14,
            fontweight="bold",
        )

        if last_hm is not None:
            cbar = fig.colorbar(last_hm, ax=axes.ravel().tolist(), shrink=0.85, pad=0.02)
            cbar.set_label("Fixation Density", rotation=90)

        pdf.savefig(fig)
        plt.close(fig)

    return out_pdf


def _running_in_notebook(argv: List[str]) -> bool:
    return "-f" in argv or "ipykernel" in argv[0].lower()


def main() -> None:
    parser = argparse.ArgumentParser(
        description="One PDF, one page: Top 3 + Bottom 3 for one or all viewing conditions (correct Pair IDs)."
    )
    parser.add_argument("--root", default=".", help="Folder to search for testing.json and images.")
    parser.add_argument("--out", default="plots", help="Output directory.")
    parser.add_argument(
        "--condition",
        default="all",
        choices=["all", "full", "central", "peripheral"],
        help="Which condition to export (or all).",
    )
    parser.add_argument("--half-box", type=int, default=310)
    parser.add_argument("--screen-w", type=int, default=1000)
    parser.add_argument("--screen-h", type=int, default=800)
    parser.add_argument("--sigma", type=float, default=18.0)
    parser.add_argument("--heatmap-alpha", type=float, default=0.70)
    parser.add_argument("--heatmap-cmap", default="turbo")
    parser.add_argument("--vmax-mode", default="p99", choices=["p99", "max"])
    parser.add_argument("--expected-pairs", type=int, default=24)

    argv = sys.argv[1:]
    if _running_in_notebook(sys.argv):
        args = parser.parse_args([])
    else:
        args, _unknown = parser.parse_known_args(argv)

    conditions = list(CONDITIONS) if args.condition == "all" else [str(args.condition)]

    for cond in conditions:
        out_pdf = export_one_page_top_bottom3(
            root_path=str(args.root),
            output_dir=str(args.out),
            condition=cond,
            half_box=int(args.half_box),
            screen_default=(int(args.screen_w), int(args.screen_h)),
            sigma=float(args.sigma),
            heatmap_alpha=float(args.heatmap_alpha),
            heatmap_cmap=str(args.heatmap_cmap),
            vmax_mode=str(args.vmax_mode),
            expected_pairs=int(args.expected_pairs),
        )
        print(f"✅ PDF: {out_pdf}")


if __name__ == "__main__":
    main()


✅ PDF: /Users/daisybuathatseephol/Documents/three_json_output/plots/FULL_TOP3_BOTTOM3_ONEPAGE_SECOND_IMAGE_OVERLAY.pdf
✅ PDF: /Users/daisybuathatseephol/Documents/three_json_output/plots/CENTRAL_TOP3_BOTTOM3_ONEPAGE_SECOND_IMAGE_OVERLAY.pdf
✅ PDF: /Users/daisybuathatseephol/Documents/three_json_output/plots/PERIPHERAL_TOP3_BOTTOM3_ONEPAGE_SECOND_IMAGE_OVERLAY.pdf
