In [None]:
# --- Reproduce CV folds, GT/preds, and plots ---
from pathlib import Path
import pickle, yaml
import numpy as np
import pandas as pd
from scipy.io import loadmat
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.metrics import (
    confusion_matrix, ConfusionMatrixDisplay,
    classification_report
)
import matplotlib.pyplot as plt

# 1) POINT THIS TO YOUR run's --outdir
RUN_DIR = Path("output/OVTMA_fov216")
EVAL_DIR = RUN_DIR / "evaluate" / "roi_supervised_best"

# 2) Load saved ROI embedding, ROI order, trained fold classifiers.
# NOTE(review): pickle.load can execute arbitrary code — only load artifacts you produced yourself.
E = loadmat(str(EVAL_DIR / "roi_embedding.mat"))["E"]  # (n_roi, d)
with open(EVAL_DIR / "group_ids.pkl", "rb") as fh:
    group_ids = pickle.load(fh)  # len n_roi, ROI keys aligned to rows of E
with open(EVAL_DIR / "best_clf_list.pkl", "rb") as fh:
    clf_list = pickle.load(fh)  # one fitted classifier per CV fold
assert len(group_ids) == E.shape[0], "Mismatch: group_ids length vs rows of E."

# 3) Load labels & diagnostics (to match encodings/splits)
df = pd.read_csv(RUN_DIR / "dataframes" / "df.csv")
with open(EVAL_DIR / "best_roi_supervision.yaml") as fh:
    diag = yaml.safe_load(fh)["diagnostics"]
classes_str = diag["classes"]  # label names in the fit, in encoding order
n_splits_eff = int(diag["n_splits_effective"])

# ROI-level labels in the embedding order.
# validate="many_to_one" fails loudly if one ROI carries several labels; otherwise the
# left merge would silently add rows and misalign labels with the rows of E.
roi_lab = pd.DataFrame({"ROI": group_ids}).merge(
    df[["ROI", "roi_label"]].drop_duplicates(),
    on="ROI", how="left", validate="many_to_one",
)
assert not roi_lab["roi_label"].isna().any(), "Missing ROI labels."

# Use the same label mapping as training (LabelEncoder equivalent).
# classes_ is assigned directly rather than fit, so code k corresponds to classes_str[k].
le = LabelEncoder()
le.classes_ = np.array(classes_str, dtype=object)
y = le.transform(roi_lab["roi_label"].values)

# Optional grouping (if present in df)
groups = None
use_groups = False
if "Subject" in df.columns:
    gmap = df[["ROI", "Subject"]].drop_duplicates()
    groups = (
        pd.DataFrame({"ROI": group_ids})
        .merge(gmap, on="ROI", how="left", validate="many_to_one")["Subject"]
        .values
    )
    # Heuristic (assumed to mirror training — confirm): use GroupKFold only when every
    # ROI has a subject and there are at least n_splits distinct subjects.
    group_series = pd.Series(groups)
    use_groups = bool(group_series.notna().all() and group_series.nunique() >= n_splits_eff)

# 4) Rebuild the splitter exactly like training
if use_groups:
    splitter = GroupKFold(n_splits=n_splits_eff)
    split_iter = list(splitter.split(E, y, groups=groups))
else:
    # Training used shuffle=True, random_state=cfg.random_state (defaults to 42).
    # A wrong seed yields different folds, i.e. "OOF" predictions on training rows,
    # so a fallback is reported instead of being swallowed silently.
    try:
        with open(RUN_DIR / "config" / "resolved_config.yaml") as fh:
            run_cfg = yaml.safe_load(fh)["roi_supervision"]
        rs = int(run_cfg.get("random_state", 42))
    except (OSError, yaml.YAMLError, KeyError, TypeError, ValueError, AttributeError) as err:
        print(f"WARNING: could not read roi_supervision.random_state ({err!r}); using 42.")
        rs = 42
    splitter = StratifiedKFold(n_splits=n_splits_eff, shuffle=True, random_state=rs)
    split_iter = list(splitter.split(E, y))

assert len(split_iter) == len(clf_list), "Mismatch: fold count vs saved classifiers."


def _to_codes(labels):
    """Return *labels* as integer codes into le.classes_, encoding them if they are names."""
    labels = np.asarray(labels)
    return labels if labels.dtype.kind in "iu" else le.transform(labels)


# 5) Collect per-fold predictions
n_classes = len(le.classes_)
fold_frames = []
for fold_idx, ((tr, va), clf) in enumerate(zip(split_iter, clf_list)):
    proba_fold = clf.predict_proba(E[va])
    # predict_proba columns follow clf.classes_, which omits any class absent from this
    # fold's training split; scatter them into the full le.classes_ column order.
    clf_codes = _to_codes(getattr(clf, "classes_", np.arange(proba_fold.shape[1])))
    proba = np.zeros((len(va), n_classes), dtype=float)
    proba[:, clf_codes] = proba_fold
    y_pred = _to_codes(clf.predict(E[va]))

    fold_data = {
        "fold": fold_idx,
        "ROI": [group_ids[idx] for idx in va],
        "y_true": le.classes_[y[va]],
        "y_pred": le.classes_[y_pred],
    }
    # per-class probabilities (nice to have)
    for ci, cname in enumerate(le.classes_):
        fold_data[f"proba_{cname}"] = proba[:, ci]
    fold_frames.append(pd.DataFrame(fold_data))

oof = pd.concat(fold_frames, ignore_index=True)
display(oof.head())  # your per-ROI GT/preds by fold

# 6) Overall confusion matrix (OOF)
cm = confusion_matrix(oof["y_true"], oof["y_pred"], labels=le.classes_)
# Pass an explicit Axes: ConfusionMatrixDisplay.plot() opens its own new figure when
# ax is None, which would leave the sized figure empty.
fig, ax = plt.subplots(figsize=(6, 5), dpi=120)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=le.classes_)
disp.plot(ax=ax, values_format="d")
ax.set_title("ROI-level Confusion Matrix (OOF)")
fig.tight_layout()
plt.show()

# 7) Per-class precision/recall (OOF)
report = classification_report(
    oof["y_true"], oof["y_pred"], labels=le.classes_, output_dict=True, zero_division=0
)
# output_dict keys each per-class row by the string form of its label.
per_class = pd.DataFrame({c: report[str(c)] for c in le.classes_}).T[
    ["precision", "recall", "f1-score", "support"]
]
display(per_class)

# Simple bar plots for precision/recall
for metric in ["precision", "recall"]:
    fig, ax = plt.subplots(figsize=(7, 3), dpi=120)
    per_class[metric].plot(kind="bar", ax=ax)
    ax.set_ylabel(metric.capitalize())
    ax.set_xlabel("Class")
    ax.set_title(f"Per-class {metric.capitalize()} (OOF)")
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    fig.tight_layout()
    plt.show()
