In [None]:
from pathlib import Path
from typing import Dict, List, Iterable, Optional, Tuple
import json
import re

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from scipy.ndimage import gaussian_filter

# ===== CONFIG =====
# Screen resolution as (width, height) in pixels; `_box` derives the centred
# stimulus rectangle from it and the plotting helpers use it for axis limits.
SCREEN_SIZE = (1024, 768)

# Per-dataset locations, keyed by the display name used in titles and logs.
#   "json"        - path to the trial list loaded by `load_datasets`
#   "images"      - folder holding the stimulus images for that dataset
#   "dual_images" - copied verbatim into the loaded record by `load_datasets`
# Paths are relative, so they resolve against the current working directory.
datasets = {
    "Training 1": {"json": "Training 1/training1.json", "images": "Training 1/training1_images", "dual_images": False},
    "Training 2": {"json": "Training 2/training2.json", "images": "Training 2/training2_images", "dual_images": True},
    "Testing":    {"json": "Testing/testing.json",       "images": "Testing/testing_images",       "dual_images": True},
}

# ===== LOAD =====
def load_datasets(cfg: Dict) -> Dict:
    """Load every dataset described in ``cfg`` whose JSON file exists.

    Args:
        cfg: Mapping of dataset name to a dict with the keys ``"json"``
            (path to the trial file), ``"images"`` (image folder) and
            ``"dual_images"`` (flag passed through unchanged).

    Returns:
        Mapping of dataset name to ``{"data", "image_folder", "dual_images"}``.
        Datasets whose JSON file is missing are reported on stdout and
        left out of the result rather than raising.
    """
    all_data = {}
    for name, paths in cfg.items():
        json_path = Path(paths["json"])
        if not json_path.exists():
            # Status glyphs were mojibake (UTF-8 read as Mac Roman); restored.
            print(f"⚠️  {name}: JSON not found at {json_path}")
            continue
        with json_path.open("r", encoding="utf-8") as f:
            trials = json.load(f)
        all_data[name] = {
            "data": trials,
            "image_folder": Path(paths["images"]),
            "dual_images": paths["dual_images"],
        }
        print(f"✅ {name}: {len(trials)} trials • images: {paths['images']}")
    return all_data

# ===== FIELD ACCESS =====
# Candidate field names for the participant identifier, in priority order.
SUBJECT_KEYS = (
    "subject",
    "subject_id",
    "subjectID",
    "participant",
    "participant_id",
    "worker_id",
    "uid",
    "id",
)


def get_subject_id(t: Dict) -> Optional[str]:
    """Return the trial's subject identifier as a string, or ``None``.

    The first key from ``SUBJECT_KEYS`` that holds a non-``None`` value wins.
    """
    candidates = (t.get(key) for key in SUBJECT_KEYS)
    found = next((value for value in candidates if value is not None), None)
    return None if found is None else str(found)

def _norm(v) -> Optional[str]:
    if v is None:
        return None
    s = str(v).strip().lower()
    return s if s else None

def infer_correctness(t: Dict) -> Optional[bool]:
    # Prefer explicit acc flag
    if "acc" in t:
        v = t["acc"]
        if isinstance(v, bool):
            return v
        if isinstance(v, (int, float)):
            return bool(v)
        if isinstance(v, str):
            sv = _norm(v)
            if sv in {"1", "true", "correct", "right"}:
                return True
            if sv in {"0", "false", "incorrect", "wrong"}:
                return False
    # Fallback: compare subject answer to correct response
    ans = _norm(t.get("subj_answer"))
    gt  = _norm(t.get("correct_response"))
    if ans is not None and gt is not None:
        return ans == gt
    return None

def split_by_correctness(trials: List[Dict]) -> Dict[str, List[Dict]]:
    """Partition trials into ``"correct"`` and ``"incorrect"`` lists.

    Trials whose correctness cannot be inferred are dropped. If no trial
    at all can be classified, every trial is placed under ``"correct"``
    so that downstream plots still have data to show.
    """
    verdicts = [(trial, infer_correctness(trial)) for trial in trials]
    right = [trial for trial, ok in verdicts if ok is True]
    wrong = [trial for trial, ok in verdicts if ok is False]
    if trials and not (right or wrong):
        # Nothing was classifiable: keep visuals alive by treating all as correct.
        right = trials[:]
    return {"correct": right, "incorrect": wrong}

def unique_subject_count(it: Iterable[Dict]) -> int:
    """Count distinct subjects among the given trials.

    If none of the trials carries a subject identifier, the number of
    trials is returned instead.

    Args:
        it: Any iterable of trial dicts, including one-shot iterators.
    """
    # Materialize once: the fallback needs a second pass, which would see an
    # already-exhausted generator and wrongly report 0.
    trials = list(it)
    ids = {sid for t in trials if (sid := get_subject_id(t)) is not None}
    return len(ids) if ids else len(trials)

# ===== FIXATIONS =====
def _box(screen_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
    w, h = screen_size
    return (w // 2 - 310, h // 2 - 310, w // 2 + 310, h // 2 + 310)

def _as_float_1d(values) -> np.ndarray:
    """Coerce a JSON value (list, bare number, or ``None``) to a 1-D float array."""
    if values is None:
        return np.empty(0, dtype=float)
    # atleast_1d: a bare number would otherwise become a 0-d array that
    # cannot be measured with len() or indexed with a boolean mask.
    return np.atleast_1d(np.asarray(values, dtype=float))


def _collect_fixations(trials: List[Dict], image_key: str, screen_size=SCREEN_SIZE) -> Tuple[np.ndarray, np.ndarray]:
    """Gather the fixation coordinates that fall inside the stimulus box.

    For each trial, ``fix_x`` / ``fix_y`` are read and, when the trial has a
    ``test_image_fixation_idx``, split by fixation order: fixations before
    that index belong to ``"first_image"``, the rest to any other
    ``image_key``. Trials without that index contribute all their
    fixations regardless of ``image_key``.

    Args:
        trials: Trial dicts to pool.
        image_key: ``"first_image"`` selects the early fixations; any other
            value selects the late ones.
        screen_size: ``(width, height)`` used to locate the stimulus box.

    Returns:
        ``(xs, ys)`` float arrays of the pooled in-box fixations.
    """
    left, top, right, bottom = _box(screen_size)
    kept_x, kept_y = [], []
    for tr in trials:
        # Missing keys and explicit JSON nulls both mean "no fixations".
        xs = _as_float_1d(tr.get("fix_x"))
        ys = _as_float_1d(tr.get("fix_y"))
        raw_order = tr.get("fix_index")
        if raw_order is None:
            # No explicit ordering recorded: number fixations 1..n.
            order = np.arange(1, xs.size + 1, dtype=float)
        else:
            order = _as_float_1d(raw_order)
        idx = tr.get("test_image_fixation_idx")
        if idx is not None:
            mask = order < idx if image_key == "first_image" else order >= idx
            xs, ys = xs[mask], ys[mask]
        inside = (xs >= left) & (xs <= right) & (ys >= top) & (ys <= bottom)
        kept_x.append(xs[inside])
        kept_y.append(ys[inside])
    if not kept_x:
        return np.array([]), np.array([])
    return np.concatenate(kept_x), np.concatenate(kept_y)

def count_fixations(trials: List[Dict], image_key: str, screen_size=SCREEN_SIZE) -> int:
    """Return how many fixations `_collect_fixations` yields for these trials."""
    in_box_x = _collect_fixations(trials, image_key, screen_size)[0]
    n_fixations = in_box_x.size
    return int(n_fixations)

# ===== PLOTTING (unchanged look) =====
def _imshow_bg(ax, image_folder: Path, img_name: str):
    """Draw the stimulus image on ``ax``, stretched over the stimulus box.

    Args:
        ax: Matplotlib axes to draw on.
        image_folder: Directory containing the stimulus image.
        img_name: File name of the image inside ``image_folder``.
    """
    left, top, right, bottom = _box(SCREEN_SIZE)
    # Image.open keeps the file handle open until the image is closed. This
    # function runs once per panel, so release it as soon as convert() has
    # produced an independent in-memory copy.
    with Image.open(image_folder / img_name) as src:
        background = src.convert("RGB")
    # extent is (left, right, bottom, top); bottom has the larger y because
    # screen coordinates grow downward.
    ax.imshow(background, extent=(left, right, bottom, top))

def plot_fixations(ax, trials_subset: List[Dict], image_folder: Path, img_name: str, image_key: str):
    """Scatter the pooled in-box fixations over the stimulus image.

    The axes are set to full screen coordinates with y increasing downward.
    """
    _imshow_bg(ax, image_folder, img_name)
    fix_x, fix_y = _collect_fixations(trials_subset, image_key, SCREEN_SIZE)
    if fix_x.size > 0:
        marker_style = {
            "s": 30,
            "c": "#FF6B6B",
            "alpha": 0.5,
            "edgecolor": "white",
            "linewidth": 0.8,
        }
        ax.scatter(fix_x, fix_y, **marker_style)
    ax.set_xlim(0, SCREEN_SIZE[0])
    ax.set_ylim(SCREEN_SIZE[1], 0)  # inverted so y grows downward
    ax.set_xlabel("x (screen px)")
    ax.set_ylabel("y (screen px)")

def plot_heatmap(ax, trials_subset: List[Dict], image_folder: Path, img_name: str, image_key: str):
    """Overlay a smoothed fixation-density map on the stimulus image.

    Returns:
        The overlay image artist (suitable for a colorbar), or ``None``
        when there are no in-box fixations to plot.
    """
    _imshow_bg(ax, image_folder, img_name)
    fix_x, fix_y = _collect_fixations(trials_subset, image_key, SCREEN_SIZE)
    overlay = None
    if fix_x.size:
        left, top, right, bottom = _box(SCREEN_SIZE)
        x_bins = np.linspace(left, right, 150)
        y_bins = np.linspace(top, bottom, 150)
        counts, x_edges, y_edges = np.histogram2d(fix_x, fix_y, bins=[x_bins, y_bins])
        density = gaussian_filter(counts, sigma=15)
        # histogram2d indexes as [x, y]; imshow expects rows = y, hence the transpose.
        overlay = ax.imshow(
            density.T,
            extent=[x_edges[0], x_edges[-1], y_edges[-1], y_edges[0]],
            origin="upper",
            cmap="jet",
            alpha=0.6,
            interpolation="bilinear",
        )
    ax.set_xlim(0, SCREEN_SIZE[0])
    ax.set_ylim(SCREEN_SIZE[1], 0)  # inverted so y grows downward
    ax.set_xlabel("x (screen px)")
    ax.set_ylabel("y (screen px)")
    return overlay

# ===== TITLES =====
def build_title(dataset_name: str,
                item_kind: str,    # "Image" or "Pair"
                idx: int, total: int,
                subset_label: str, # "Correct"/"Incorrect"
                right_subjects: int, wrong_subjects: int,
                n_subjects_subset: int,
                fix_first: int, fix_second: Optional[int],
                viewing: Optional[str] = None) -> str:
    """Build the two-line figure title.

    Line 1 names the dataset, the optional viewing condition, the item
    position (``idx/total``) and the subset. Line 2 reports subject counts
    and fixation counts; the second-image count is omitted when
    ``fix_second`` is ``None``.

    Returns:
        The two lines joined by a newline.
    """
    # The separator was stored as mojibake (UTF-8 bullet read as Mac Roman),
    # which rendered as garbage in the figure titles.
    sep = " • "

    head = [f"{dataset_name}"]
    if viewing:
        head.append(f"{str(viewing).capitalize()} Viewing")
    head.append(f"{item_kind} {idx}/{total}")
    head.append(f"{subset_label}")

    stats = [
        f"Subjects: right={right_subjects}",
        f"wrong={wrong_subjects}",
        f"N={n_subjects_subset} subjects",
        f"First: {fix_first} fix",
    ]
    if fix_second is not None:
        stats.append(f"Second: {fix_second} fix")

    return f"{sep.join(head)}\n{sep.join(stats)}"

# ===== UTIL =====
def sanitize(name: str, maxlen: int = 80) -> str:
    """Turn ``name`` into a filesystem-friendly slug of at most ``maxlen`` characters.

    Runs of characters other than word characters and hyphens become a
    single hyphen; leading and trailing hyphens are removed before truncation.
    """
    hyphenated = re.sub(r"[^\w\-]+", "-", str(name))
    collapsed = re.sub(r"-+", "-", hyphenated)
    trimmed = collapsed.strip("-")
    return trimmed[:maxlen]

def ensure_dir(p: Path) -> Path:
    """Create directory ``p`` (and any missing parents) and return ``p`` itself."""
    p.mkdir(parents=True, exist_ok=True)
    return p

# ===== MAIN GENERATOR =====
def _group_trials(data: List[Dict], keys: Tuple[str, ...]) -> Dict[tuple, List[Dict]]:
    """Bucket trials by the tuple of values under ``keys``; skip trials missing any of them."""
    groups: Dict[tuple, List[Dict]] = {}
    for t in data:
        combo = tuple(t.get(k) for k in keys)
        if all(combo):
            groups.setdefault(combo, []).append(t)
    return groups


def _save_fdm_figure(subset: List[Dict], folder: Path, images: Tuple[str, ...], title: str, out: Path) -> None:
    """Render scatter (top row) and heatmap (bottom row) panels for one or two images and save to ``out``."""
    n = len(images)
    if n > 1:
        figsize, wspace = (16, 14), 0.3
        panels = [
            ("first_image", "First Image - Fixation Points Only", "First Image - Density Heatmap"),
            ("second_image", "Second Image - Fixation Points Only", "Second Image - Density Heatmap"),
        ]
        cbar_label, cbar_fontsize = "Fixation Density", 10
    else:
        figsize, wspace = (10, 14), 0.05
        panels = [("first_image", "Fixation Points Only", "Density Heatmap Only")]
        cbar_label, cbar_fontsize = "Density", 9

    fig = plt.figure(figsize=figsize, dpi=150)
    try:
        # One column per image plus a narrow trailing column for the colorbar.
        gs = fig.add_gridspec(2, n + 1, height_ratios=[1, 1], width_ratios=[1] * n + [0.05],
                              hspace=0.35, wspace=wspace)

        for col, (img_name, (image_key, scatter_title, _)) in enumerate(zip(images, panels)):
            ax = fig.add_subplot(gs[0, col])
            plot_fixations(ax, subset, folder, img_name, image_key)
            ax.set_title(scatter_title, fontsize=11, fontweight="bold")

        heatmaps = []
        for col, (img_name, (image_key, _, heat_title)) in enumerate(zip(images, panels)):
            ax = fig.add_subplot(gs[1, col])
            heatmaps.append(plot_heatmap(ax, subset, folder, img_name, image_key))
            ax.set_title(heat_title, fontsize=11, fontweight="bold")

        # NOTE: the colorbar is tied to the first heatmap only; a second
        # heatmap is drawn with its own colour scaling.
        if heatmaps[0] is not None:
            cbar = fig.colorbar(heatmaps[0], cax=fig.add_subplot(gs[1, n]))
            cbar.set_label(cbar_label, fontsize=cbar_fontsize)

        fig.suptitle(title, fontsize=12, fontweight="bold", y=0.96)
        # Save via the figure itself rather than pyplot's implicit current figure.
        fig.savefig(out, bbox_inches="tight", dpi=150)
    finally:
        # Always release the figure, even if an image fails to load.
        plt.close(fig)


def _render_group(dataset_name: str, item_kind: str, idx: int, total: int,
                  trials: List[Dict], folder: Path, images: Tuple[str, ...],
                  subdir: Path, stem: str, viewing: Optional[str] = None) -> None:
    """Write the "correct" and "incorrect" PDFs for one image / pair / condition group."""
    split = split_by_correctness(trials)
    right_n_all = unique_subject_count(split["correct"])
    wrong_n_all = unique_subject_count(split["incorrect"])
    for key in ("correct", "incorrect"):
        subset = split[key]
        fix_first = count_fixations(subset, "first_image", SCREEN_SIZE)
        fix_second = count_fixations(subset, "second_image", SCREEN_SIZE) if len(images) > 1 else None
        title = build_title(
            dataset_name=dataset_name,
            item_kind=item_kind,
            idx=idx, total=total,
            subset_label=key.capitalize(),
            right_subjects=right_n_all,
            wrong_subjects=wrong_n_all,
            n_subjects_subset=unique_subject_count(subset),
            fix_first=fix_first, fix_second=fix_second, viewing=viewing,
        )
        _save_fdm_figure(subset, folder, images, title, subdir / f"{stem}_{key}.pdf")


def generate_all_fdms(all_data: Dict):
    """Generate correct/incorrect fixation-density PDFs for every loaded dataset.

    Output goes under ``fdm_outputs_correctness/`` in the current working
    directory, with one sub-folder per dataset (``training1``, ``training2``,
    ``testing``) and, inside it, one folder per image, image pair, or
    pair-and-viewing-condition combination. Each folder receives a
    ``*_correct.pdf`` and a ``*_incorrect.pdf``.

    Args:
        all_data: Mapping as returned by ``load_datasets``. Only the keys
            ``"Training 1"``, ``"Training 2"`` and ``"Testing"`` are
            processed; absent datasets are skipped.
    """
    root = ensure_dir(Path("fdm_outputs_correctness"))
    print(f"\n📁 Output root: {root.resolve()}")

    # --- Training 1 ---
    if "Training 1" in all_data:
        cfg = all_data["Training 1"]
        folder = cfg["image_folder"]
        # Group once instead of rescanning every trial for each image.
        groups = _group_trials(cfg["data"], ("first_image",))
        base = ensure_dir(root / "training1")
        print(f"\n📊 Training 1: {len(groups)} images")
        for i, combo in enumerate(sorted(groups), 1):
            (first,) = combo
            subdir = ensure_dir(base / f"image_{i:02d}_{sanitize(first)}")
            _render_group("Training 1", "Image", i, len(groups), groups[combo], folder,
                          (first,), subdir, f"image_{i:02d}")

    # --- Training 2 ---
    if "Training 2" in all_data:
        cfg = all_data["Training 2"]
        folder = cfg["image_folder"]
        groups = _group_trials(cfg["data"], ("first_image", "second_image"))
        base = ensure_dir(root / "training2")
        print(f"\n📊 Training 2: {len(groups)} pairs")
        for i, combo in enumerate(sorted(groups), 1):
            first, second = combo
            subdir = ensure_dir(base / f"pair_{i:02d}_{sanitize(first)}__{sanitize(second)}")
            _render_group("Training 2", "Pair", i, len(groups), groups[combo], folder,
                          (first, second), subdir, f"pair_{i:02d}")

    # --- Testing ---
    if "Testing" in all_data:
        cfg = all_data["Testing"]
        folder = cfg["image_folder"]
        groups = _group_trials(cfg["data"], ("first_image", "second_image", "viewing_condition"))
        base = ensure_dir(root / "testing")
        print(f"\n📊 Testing: {len(groups)} pairs × viewing")
        for i, combo in enumerate(sorted(groups), 1):
            first, second, viewing = combo
            # Sanitize the condition for the file name too, not only the
            # folder: a raw value containing a path separator would
            # otherwise point the PDF at a non-existent directory.
            safe_viewing = sanitize(str(viewing))
            subdir = ensure_dir(base / f"pair_{i:03d}_{safe_viewing}_{sanitize(first)}__{sanitize(second)}")
            _render_group("Testing", "Pair", i, len(groups), groups[combo], folder,
                          (first, second), subdir, f"pair_{i:03d}_{safe_viewing}", viewing=str(viewing))

    print("\n🎉 Done. Titles now match your spec, visuals unchanged.")
    print(f"📁 Root: {root}/")

# ===== RUN =====
if __name__ == "__main__":
    # Script entry point: load whatever datasets are present, then render
    # the correct/incorrect fixation-density figures for each of them.
    # The ellipsis below was mojibake (UTF-8 read as Mac Roman); restored.
    print("Loading datasets…")
    all_data = load_datasets(datasets)
    print("\n" + "=" * 50)
    print("GENERATING CORRECT/INCORRECT FDMs (Final titles)")
    print("=" * 50)
    generate_all_fdms(all_data)


Loading datasets…
✅ Training 1: 1364 trials • images: Training 1/training1_images
✅ Training 2: 2320 trials • images: Training 2/training2_images
✅ Testing: 2304 trials • images: Testing/testing_images

GENERATING CORRECT/INCORRECT FDMs (Final titles)

📁 Output root: /Users/daisybuathatseephol/Documents/three_json_output/fdm_outputs_correctness

📊 Training 1: 40 images

📊 Training 2: 72 pairs

📊 Testing: 72 pairs × viewing

🎉 Done. Titles now match your spec, visuals unchanged.
📁 Root: fdm_outputs_correctness/
