# Benchmark histograms + ROC (dedicated)

This notebook is a focused, decluttered way to:
- Generate **overall summary tables** across all datasets in a benchmark run
- Plot **per-dataset histograms (ID vs OOD)** + **ROC/AUROC** for each detector

It reads the benchmark artifacts produced by `scripts/benchmark_ood_evaluation.py`:
- `*_results_dataset_based.csv` (per-dataset detector metrics)
- `*_{detector}_detailed.csv` (per-sample scores + flags)

**Convention:** higher score = more OOD.


In [97]:
from __future__ import annotations

from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# === CONFIG ===
# Point this at the benchmark output directory you want to analyze.
# Default: full-suite run (baselines + 6 Stein) across adversarial/cifar10c/cifar10p/ood_classics.
#RESULTS_DIR = Path("../results/benchmark_results_full_suite_v1").resolve()
RESULTS_DIR = Path("../results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t500_seed0").resolve()
RESULTS_DIR = Path("../results/benchmark_results_full_suite__baselines_plus_stein_perdiml2_ablations__ddpm_xt_sigma_t500_seed0").resolve()

# Choose which ID/OOD definition to analyze.
# You can run the notebook twice with different modes; outputs will be written into separate subfolders.
OOD_DEFINITION_MODE = "dataset"  # 'dataset' | 'misclassified' | 'dataset_and_misclassified'

OUT_DIR_BASE = (RESULTS_DIR / "plots_notebook").resolve()
OUT_DIR = (OUT_DIR_BASE / f"mode_{OOD_DEFINITION_MODE}").resolve()

OUT_DIR.mkdir(parents=True, exist_ok=True)
(OUT_DIR / "summaries").mkdir(parents=True, exist_ok=True)
(OUT_DIR / "visualizations").mkdir(parents=True, exist_ok=True)
(OUT_DIR / "plots").mkdir(parents=True, exist_ok=True)

print("RESULTS_DIR:", RESULTS_DIR)
print("OUT_DIR:", OUT_DIR)
print("OOD_DEFINITION_MODE:", OOD_DEFINITION_MODE)

print("RESULTS_DIR=", RESULTS_DIR)
print("OUT_DIR=", OUT_DIR)
print("OOD_DEFINITION_MODE=", OOD_DEFINITION_MODE)

RESULTS_DIR: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0
OUT_DIR: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset
OOD_DEFINITION_MODE: dataset
RESULTS_DIR= /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0
OUT_DIR= /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset
OOD_DEFINITION_MODE= dataset


In [98]:
import re
from dataclasses import dataclass
from typing import Literal, Optional


def sanitize_filename(name: str) -> str:
    name = name.replace(":", "_")
    name = re.sub(r"[^a-zA-Z0-9._-]+", "_", name)
    return name


def roc_curve_auc(scores: np.ndarray, labels: np.ndarray):
    """ROC + AUROC assuming higher score = more OOD (label==1)."""
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64)
    if scores.size == 0:
        return np.array([0.0, 1.0]), np.array([0.0, 1.0]), float("nan")

    order = np.argsort(-scores, kind="mergesort")
    s = scores[order]
    y = labels[order]

    tps = np.cumsum(y == 1)
    fps = np.cumsum(y == 0)
    P = float(tps[-1])
    N = float(fps[-1])
    if P == 0 or N == 0:
        return np.array([0.0, 1.0]), np.array([0.0, 1.0]), float("nan")

    distinct_last = np.r_[s[1:] != s[:-1], True]
    tps = tps[distinct_last]
    fps = fps[distinct_last]

    fpr = fps / N
    tpr = tps / P

    if fpr.size == 0 or fpr[0] != 0.0 or tpr[0] != 0.0:
        fpr = np.r_[0.0, fpr]
        tpr = np.r_[0.0, tpr]
    if fpr[-1] != 1.0 or tpr[-1] != 1.0:
        fpr = np.r_[fpr, 1.0]
        tpr = np.r_[tpr, 1.0]

    auroc = float(np.trapezoid(tpr, fpr))
    return fpr, tpr, auroc


def split_id_ood_masks(df_det: pd.DataFrame, mode: Literal["dataset", "misclassified", "dataset_and_misclassified"]):
    is_ood = df_det["is_ood"].astype(int).to_numpy()
    is_correct = df_det["is_classified_correctly"].astype(int).to_numpy()

    if mode == "dataset":
        id_mask = (is_ood == 0)
        ood_mask = (is_ood == 1)
    elif mode == "misclassified":
        id_mask = (is_ood == 0) & (is_correct == 1)
        ood_mask = (is_ood == 0) & (is_correct == 0)
    elif mode == "dataset_and_misclassified":
        id_mask = (is_ood == 0) & (is_correct == 1)
        ood_mask = (is_ood == 1) | ((is_ood == 0) & (is_correct == 0))
    else:
        raise ValueError(mode)

    return id_mask, ood_mask


def select_score_column(df_det: pd.DataFrame, detector_name: str, stein_tail: Literal["two_sided", "upper"]):
    if detector_name.startswith("stein_"):
        col = "stein_oodness_two_sided" if stein_tail == "two_sided" else "stein_oodness_upper"
        if col in df_det.columns and df_det[col].notna().any():
            return df_det[col].astype(float).to_numpy(), col
        return df_det["score"].astype(float).to_numpy(), "score"

    return df_det["score"].astype(float).to_numpy(), "score"


@dataclass(frozen=True)
class DatasetRun:
    prefix: str
    id_dataset: str
    ood_dataset_sanitized: str
    results_csv: Path


def discover_dataset_runs(results_dir: Path) -> list[DatasetRun]:
    runs = []
    for p in sorted(results_dir.glob("*_results_dataset_based.csv")):
        stem = p.name.replace("_results_dataset_based.csv", "")
        if "_vs_" not in stem:
            continue
        id_dataset, ood_sanitized = stem.split("_vs_", 1)
        runs.append(DatasetRun(prefix=stem, id_dataset=id_dataset, ood_dataset_sanitized=ood_sanitized, results_csv=p))
    return runs


In [99]:
# Discover all dataset runs from RESULTS_DIR
runs = discover_dataset_runs(RESULTS_DIR)
manifest_df = pd.DataFrame(
    [
        {
            "dataset_prefix": r.prefix,
            "id_dataset": r.id_dataset,
            "ood_dataset_sanitized": r.ood_dataset_sanitized,
            "results_csv": str(r.results_csv),
        }
        for r in runs
    ]
)

print("Discovered runs:", len(manifest_df))
manifest_df.head()


Discovered runs: 45


Unnamed: 0,dataset_prefix,id_dataset,ood_dataset_sanitized,results_csv
0,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,/Users/michalkozyra/Developer/PhD/stein_shift_...
1,cifar10_vs_adversarial_autoattack_linf_4_255,cifar10,adversarial_autoattack_linf_4_255,/Users/michalkozyra/Developer/PhD/stein_shift_...
2,cifar10_vs_adversarial_autoattack_linf_8_255,cifar10,adversarial_autoattack_linf_8_255,/Users/michalkozyra/Developer/PhD/stein_shift_...
3,cifar10_vs_adversarial_fgsm_linf_4_255,cifar10,adversarial_fgsm_linf_4_255,/Users/michalkozyra/Developer/PhD/stein_shift_...
4,cifar10_vs_adversarial_fgsm_linf_8_255,cifar10,adversarial_fgsm_linf_8_255,/Users/michalkozyra/Developer/PhD/stein_shift_...


In [100]:
# Load all dataset-based results into one dataframe
frames = []
for _, row in manifest_df.iterrows():
    p = Path(row["results_csv"])
    if not p.exists():
        continue
    df = pd.read_csv(p)
    df["dataset_prefix"] = row["dataset_prefix"]
    df["id_dataset"] = row["id_dataset"]
    df["ood_dataset_sanitized"] = row["ood_dataset_sanitized"]
    frames.append(df)

df_all = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
print("df_all rows:", len(df_all), "datasets:", int(df_all["dataset_prefix"].nunique()) if len(df_all) else 0)
df_all.head()


df_all rows: 360 datasets: 45


Unnamed: 0,Detector,AUROC,FPR95,AUPR_IN,AUPR_OUT,AUTC,stein_tail,AUROC_two_sided,FPR95_two_sided,AUPR_IN_two_sided,...,spearman_correlation_ood_upper,id_top1_accuracy,id_top5_accuracy,ood_top1_accuracy,ood_top5_accuracy,dataset_prefix,id_dataset,ood_dataset_sanitized,ood_top1_confidence,ood_entropy
0,stein_per_dimension_l2,0.524716,0.9018,0.552778,0.507765,0.487639,two_sided,0.524716,0.9018,0.552778,...,0.131501,0.8456,0.9929,0.165,0.9345,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,,
1,stein_per_dimension_l2_no_lap,0.521862,0.9044,0.546251,0.508143,0.489065,two_sided,0.521862,0.9044,0.546251,...,0.142271,0.8456,0.9929,0.165,0.9345,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,,
2,stein_per_dimension_l2_lap_only,0.460465,0.9603,0.471942,0.47221,0.519768,two_sided,0.460465,0.9603,0.471942,...,0.11331,0.8456,0.9929,0.165,0.9345,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,,
3,stein_per_dimension_l2_grad_only,0.473228,0.9555,0.481894,0.479253,0.513384,two_sided,0.473228,0.9555,0.481894,...,0.166697,0.8456,0.9929,0.165,0.9345,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,,
4,stein_per_dimension_l2_score_only,0.500809,0.9488,0.499524,0.501727,0.499597,two_sided,0.500809,0.9488,0.499524,...,0.001109,0.8456,0.9929,0.165,0.9345,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,,


In [101]:
# Per-dataset histogram + ROC plots (saved to OUT_DIR/plots/...)

PLOT_BENCHMARKS = ["adversarial", "cifar10c", "cifar10p", "ood_classics"]


def _detailed_csv_path(dataset_prefix: str, detector: str) -> Path:
    return RESULTS_DIR / f"{dataset_prefix}_{detector}_detailed.csv"


def _benchmark_from_prefix(dataset_prefix: str) -> str:
    if "_vs_adversarial_" in dataset_prefix:
        return "adversarial"
    if "_vs_cifar10c_" in dataset_prefix:
        return "cifar10c"
    if "_vs_cifar10p_" in dataset_prefix:
        return "cifar10p"
    # classics
    if dataset_prefix in {"cifar10_vs_svhn", "cifar10_vs_lsun", "cifar10_vs_isun", "cifar10_vs_textures", "cifar10_vs_places365"}:
        return "ood_classics"
    return "unknown"


def _stein_tail_policy_for_prefix(dataset_prefix: str) -> str:
    bench = _benchmark_from_prefix(dataset_prefix)
    # Policy used everywhere else in the notebook:
    # - adversarial: two-sided
    # - cifar10c/cifar10p/ood_classics: upper
    return "two_sided" if bench == "adversarial" else "upper"


# If True, generate *both* tails for Stein detectors (useful for debugging).
# If False (default), generate exactly one tail per dataset according to policy.
PLOT_BOTH_STEIN_TAILS = False


def plot_dataset_detector(dataset_prefix: str, detector: str, *, ood_definition_mode: str = OOD_DEFINITION_MODE):
    det_csv = _detailed_csv_path(dataset_prefix, detector)
    if not det_csv.exists():
        raise FileNotFoundError(det_csv)

    df_det = pd.read_csv(det_csv)
    id_mask, ood_mask = split_id_ood_masks(df_det, ood_definition_mode)
    if int(id_mask.sum()) == 0 or int(ood_mask.sum()) == 0:
        print(f"[skip] empty split for {dataset_prefix} {detector} mode={ood_definition_mode} (n_id={int(id_mask.sum())}, n_ood={int(ood_mask.sum())})")
        return

    if detector.startswith("stein_"):
        if PLOT_BOTH_STEIN_TAILS:
            tails = ["two_sided", "upper"]
        else:
            tails = [_stein_tail_policy_for_prefix(dataset_prefix)]
    else:
        tails = [None]

    for tail in tails:
        stein_tail = tail if tail is not None else "two_sided"
        scores, score_col = select_score_column(df_det, detector, stein_tail=stein_tail)
        labels = np.zeros(len(scores), dtype=np.int64)
        labels[ood_mask] = 1

        fpr, tpr, au = roc_curve_auc(scores[id_mask | ood_mask], labels[id_mask | ood_mask])
        fp95 = _fpr95_from_roc(fpr, tpr)

        out_dir = OUT_DIR / "plots" / ood_definition_mode / detector
        out_dir.mkdir(parents=True, exist_ok=True)
        out_path = out_dir / f"{dataset_prefix}__{detector}__tail_{tail if tail is not None else 'na'}.png"

        if detector.startswith("stein_"):
            fig, axs = plt.subplots(2, 2, figsize=(12, 8))
            ax_hist = axs[0, 0]
            ax_roc = axs[0, 1]
            ax_raw = axs[1, 0]
            ax_txt = axs[1, 1]

            ax_hist.hist(scores[id_mask], bins=60, alpha=0.6, label="ID")
            ax_hist.hist(scores[ood_mask], bins=60, alpha=0.6, label="OOD")
            ax_hist.set_title(f"{dataset_prefix}\n{detector} score={score_col} tail={tail}")
            ax_hist.legend()

            ax_roc.plot(fpr, tpr)
            ax_roc.set_xlim(0, 1)
            ax_roc.set_ylim(0, 1)
            ax_roc.set_xlabel("FPR")
            ax_roc.set_ylabel("TPR")
            ax_roc.set_title(f"ROC AUROC={au:.4f} FPR95={fp95:.4f}")

            raw = df_det["score"].astype(float).to_numpy()
            ax_raw.hist(raw[id_mask], bins=60, alpha=0.6, label="ID")
            ax_raw.hist(raw[ood_mask], bins=60, alpha=0.6, label="OOD")
            ax_raw.set_title("Raw score (from CSV)")
            ax_raw.legend()

            ax_txt.axis("off")
            ax_txt.text(
                0.0,
                0.95,
                "\n".join(
                    [
                        f"dataset_prefix: {dataset_prefix}",
                        f"detector: {detector}",
                        f"ood_definition_mode: {ood_definition_mode}",
                        f"tail: {tail}",
                        f"score_col: {score_col}",
                        f"n_id: {int(id_mask.sum())}",
                        f"n_ood: {int(ood_mask.sum())}",
                        f"AUROC: {au:.6f}",
                        f"FPR95: {fp95:.6f}",
                    ]
                ),
                va="top",
            )
        else:
            fig, axs = plt.subplots(1, 2, figsize=(12, 4))
            ax_hist, ax_roc = axs
            ax_hist.hist(scores[id_mask], bins=60, alpha=0.6, label="ID")
            ax_hist.hist(scores[ood_mask], bins=60, alpha=0.6, label="OOD")
            ax_hist.set_title(f"{dataset_prefix}\n{detector} score={score_col}")
            ax_hist.legend()

            ax_roc.plot(fpr, tpr)
            ax_roc.set_xlim(0, 1)
            ax_roc.set_ylim(0, 1)
            ax_roc.set_xlabel("FPR")
            ax_roc.set_ylabel("TPR")
            ax_roc.set_title(f"ROC AUROC={au:.4f} FPR95={fp95:.4f}")

        fig.tight_layout()
        fig.savefig(out_path, dpi=160)
        plt.close(fig)
        print("Wrote:", out_path)


# Example (uncomment to run):
# plot_dataset_detector("cifar10_vs_svhn", "stein_per_dimension_l2")


In [102]:
# Generate per-dataset plots into OUT_DIR/plots/...
# This is intentionally optional because it can be time-consuming.

GENERATE_DATASET_PLOTS = True
MAX_DATASET_RUNS = None  # e.g. 10 for quick smoke test

# If None: use all detectors present in df_all.
# Otherwise: restrict to a specific list.
DATASET_PLOT_DETECTORS = None

# Only plot these benchmark families (adversarial/cifar10c/cifar10p/ood_classics)
DATASET_PLOT_BENCHMARKS = ["adversarial", "cifar10c", "cifar10p", "ood_classics"]

if GENERATE_DATASET_PLOTS:
    if df_all is None or len(df_all) == 0:
        raise ValueError("df_all is empty; run the data-loading cells first")

    # dataset runs to plot
    runs_df = df_sub[["dataset_prefix", "benchmark"]].drop_duplicates().copy()
    runs_df = runs_df[runs_df["benchmark"].isin(DATASET_PLOT_BENCHMARKS)]
    runs = runs_df["dataset_prefix"].astype(str).tolist()
    if MAX_DATASET_RUNS is not None:
        runs = runs[: int(MAX_DATASET_RUNS)]

    # detectors to plot
    if DATASET_PLOT_DETECTORS is None:
        dets = sorted(df_all["Detector"].astype(str).unique().tolist())
    else:
        dets = list(DATASET_PLOT_DETECTORS)

    print(f"Generating dataset plots: n_runs={len(runs)}, n_detectors={len(dets)}, mode={OOD_DEFINITION_MODE}")

    n_ok = 0
    n_skip = 0
    n_fail = 0
    for dataset_prefix in runs:
        for det in dets:
            try:
                plot_dataset_detector(dataset_prefix, det, ood_definition_mode=OOD_DEFINITION_MODE)
                n_ok += 1
            except FileNotFoundError:
                # detector might not have detailed CSV saved (e.g. skipped/failed detectors)
                n_skip += 1
            except Exception as e:
                print(f"[fail] {dataset_prefix} {det}: {type(e).__name__}: {e}")
                n_fail += 1

    print(f"Done dataset plots: ok={n_ok}, skipped={n_skip}, failed={n_fail}")
else:
    print("Skipping dataset plots (set GENERATE_DATASET_PLOTS=True to enable)")


Generating dataset plots: n_runs=45, n_detectors=8, mode=dataset
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/plots/dataset/stein_per_dimension_l2/cifar10_vs_adversarial_autoattack_linf_2_255__stein_per_dimension_l2__tail_two_sided.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/plots/dataset/stein_per_dimension_l2_grad_only/cifar10_vs_adversarial_autoattack_linf_2_255__stein_per_dimension_l2_grad_only__tail_two_sided.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/plots/dataset/stein_per_dimension_l2_lap_only/cifar10_vs_adversarial_autoattack_linf_2_255__stein_per_dimension_l2_lap_only__tail_two_sided.png
Wrote: /Users/michalkozyra/Develo

In [103]:
# Quick sanity check


In [104]:
# Adversarial-only summary tables (two-sided Stein)
#
# This section computes summary statistics per detector and per attack configuration
# from the exported `*_results_dataset_based.csv` files.
#
# - Scope: adversarial benchmarks only (`cifar10_vs_adversarial_*`)
# - Stein policy: use the two-sided columns when available (`AUROC_two_sided`, `FPR95_two_sided`)
#   otherwise fall back to `AUROC`/`FPR95`.


In [105]:
import re

# Restore combined summary base (adversarial + cifar10c + cifar10p + ood_classics) and derived columns.
# This restores the previously-added derived columns like AUROC_used/FPR95_used and df_sub.

CLASSIC_OOD = {"svhn", "lsun", "isun", "textures", "places365"}

# include: adversarial / cifar10c / cifar10p / classic ood (svhn, etc)
df_sub = df_all[
    df_all["dataset_prefix"].str.contains(r"_vs_adversarial_", regex=True)
    | df_all["dataset_prefix"].str.contains(r"_vs_cifar10c_", regex=True)
    | df_all["dataset_prefix"].str.contains(r"_vs_cifar10p_", regex=True)
    | df_all["ood_dataset_sanitized"].isin(CLASSIC_OOD)
].copy()

is_adv = df_sub["dataset_prefix"].str.contains(r"_vs_adversarial_", regex=True)
is_c10c = df_sub["dataset_prefix"].str.contains(r"_vs_cifar10c_", regex=True)
is_c10p = df_sub["dataset_prefix"].str.contains(r"_vs_cifar10p_", regex=True)
is_classics = df_sub["ood_dataset_sanitized"].isin(CLASSIC_OOD)

# Category label used throughout notebook
# NOTE: new 4th category alongside adversarial/cifar10c/cifar10p
df_sub["benchmark"] = np.where(
    is_adv,
    "adversarial",
    np.where(is_c10c, "cifar10c", np.where(is_c10p, "cifar10p", np.where(is_classics, "ood_classics", None))),
)

# --- parse dataset metadata ---
pat_adv = re.compile(
    r"_vs_adversarial_(?P<attack>autoattack|pgd|fgsm)_(?P<threat>linf|l2)_(?P<eps_num>\d+)_255(?:_steps[_=](?P<steps>\d+))?"
)
pat_c10c = re.compile(r"_vs_cifar10c_(?P<corruption>[a-z0-9_]+)$")
pat_c10p = re.compile(r"_vs_cifar10p_(?P<perturbation>[a-z0-9_]+)$")

for col in ["attack", "threat", "eps", "steps", "corruption", "perturbation"]:
    if col not in df_sub.columns:
        df_sub[col] = None

if is_adv.any():
    gd = df_sub.loc[is_adv, "dataset_prefix"].str.extract(pat_adv)
    df_sub.loc[is_adv, "attack"] = gd["attack"].values
    df_sub.loc[is_adv, "threat"] = gd["threat"].values
    df_sub.loc[is_adv, "eps"] = gd["eps_num"].astype(float).values / 255.0
    df_sub.loc[is_adv, "steps"] = pd.to_numeric(gd["steps"], errors="coerce").values

if is_c10c.any():
    gd = df_sub.loc[is_c10c, "dataset_prefix"].str.extract(pat_c10c)
    df_sub.loc[is_c10c, "corruption"] = gd["corruption"].values

if is_c10p.any():
    gd = df_sub.loc[is_c10p, "dataset_prefix"].str.extract(pat_c10p)
    df_sub.loc[is_c10p, "perturbation"] = gd["perturbation"].values

# --- derived "used" metrics columns ---
is_stein = df_sub["Detector"].astype(str).str.startswith("stein")

auroc_base = df_sub["AUROC"].astype(float)
fpr_base = df_sub["FPR95"].astype(float)

auroc_two = df_sub["AUROC_two_sided"].astype(float) if "AUROC_two_sided" in df_sub.columns else auroc_base
fpr_two = df_sub["FPR95_two_sided"].astype(float) if "FPR95_two_sided" in df_sub.columns else fpr_base
auroc_up = df_sub["AUROC_upper"].astype(float) if "AUROC_upper" in df_sub.columns else auroc_base
fpr_up = df_sub["FPR95_upper"].astype(float) if "FPR95_upper" in df_sub.columns else fpr_base

stein_auroc = np.where(df_sub["benchmark"] == "adversarial", auroc_two, auroc_up)
stein_fpr = np.where(df_sub["benchmark"] == "adversarial", fpr_two, fpr_up)

df_sub["AUROC_used"] = np.where(is_stein, stein_auroc, auroc_base)
df_sub["FPR95_used"] = np.where(is_stein, stein_fpr, fpr_base)

# Split convenience frames (used by later cells)
df_adv = df_sub[df_sub["benchmark"] == "adversarial"].copy()
df_c10c = df_sub[df_sub["benchmark"] == "cifar10c"].copy()
df_c10p = df_sub[df_sub["benchmark"] == "cifar10p"].copy()
df_classics = df_sub[df_sub["benchmark"] == "ood_classics"].copy()

print("Counts:")
print(" adversarial datasets:", int(df_adv["dataset_prefix"].nunique()) if len(df_adv) else 0)
print(" cifar10c datasets:", int(df_c10c["dataset_prefix"].nunique()) if len(df_c10c) else 0)
print(" cifar10p datasets:", int(df_c10p["dataset_prefix"].nunique()) if len(df_c10p) else 0)
print(" ood_classics datasets:", int(df_classics["dataset_prefix"].nunique()) if len(df_classics) else 0)

# keep backward-compatible preview
(df_sub[["dataset_prefix", "benchmark", "Detector", "AUROC_used", "FPR95_used"]].head(10))

Counts:
 adversarial datasets: 10
 cifar10c datasets: 19
 cifar10p datasets: 11
 ood_classics datasets: 5


Unnamed: 0,dataset_prefix,benchmark,Detector,AUROC_used,FPR95_used
0,cifar10_vs_adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2,0.524716,0.9018
1,cifar10_vs_adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_no_lap,0.521862,0.9044
2,cifar10_vs_adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_lap_only,0.460465,0.9603
3,cifar10_vs_adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_grad_only,0.473228,0.9555
4,cifar10_vs_adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_score_only,0.500809,0.9488
5,cifar10_vs_adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_lap_only_std,0.460465,0.9603
6,cifar10_vs_adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_no_lap_std,0.521862,0.9044
7,cifar10_vs_adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_std_balanced,0.497459,0.9262
8,cifar10_vs_adversarial_autoattack_linf_4_255,adversarial,stein_per_dimension_l2,0.622627,0.9189
9,cifar10_vs_adversarial_autoattack_linf_4_255,adversarial,stein_per_dimension_l2_no_lap,0.619524,0.9268


In [106]:
# Summary tables + visualizations for all benchmark families (including new: ood_classics)

VIS_DIR = OUT_DIR / "visualizations"
SUM_DIR = OUT_DIR / "summaries"


def _savefig(fig, path: Path):
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(path, dpi=180)
    plt.close(fig)
    print("Wrote:", path)


# Convenience: a compact table we can use for plotting
ultimate_plot = ultimate.copy()
ultimate_plot = ultimate_plot[ultimate_plot["ood_definition_mode"] == OOD_DEFINITION_MODE].copy()

print("ultimate_plot rows:", len(ultimate_plot))
ultimate_plot[["benchmark", "detector", "AUROC", "FPR95"]].head()


ultimate_plot rows: 360


Unnamed: 0,benchmark,detector,AUROC,FPR95
0,adversarial,stein_per_dimension_l2,0.592372,0.9405
3,adversarial,stein_per_dimension_l2_no_lap,0.591303,0.9388
6,adversarial,stein_per_dimension_l2_lap_only,0.460535,0.9562
9,adversarial,stein_per_dimension_l2_grad_only,0.4733,0.9608
12,adversarial,stein_per_dimension_l2_score_only,0.564742,0.9306


In [107]:
# Visualization: AUROC distribution per detector (boxplots)

for bench in ["adversarial", "cifar10c", "cifar10p", "ood_classics"]:
    dfb = ultimate_plot[ultimate_plot["benchmark"] == bench].copy()
    if len(dfb) == 0:
        print("[skip] no rows for", bench)
        continue

    fig, ax = plt.subplots(figsize=(max(10, 0.45 * dfb["detector"].nunique()), 5))
    sns.boxplot(data=dfb, x="detector", y="AUROC", ax=ax)
    ax.set_title(f"{bench}: AUROC distribution by detector (mode={OOD_DEFINITION_MODE})")
    ax.set_ylim(0.0, 1.0)
    ax.tick_params(axis="x", rotation=45)

    _savefig(fig, VIS_DIR / f"{bench}_auroc_boxplot_by_detector__mode_{OOD_DEFINITION_MODE}.png")


Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/adversarial_auroc_boxplot_by_detector__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/cifar10c_auroc_boxplot_by_detector__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/cifar10p_auroc_boxplot_by_detector__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/ood_classics_auroc_boxplot_by_detector__mode_dataset.png


In [108]:
# Visualization: overall mean AUROC per detector (bar plots), annotate with mean FPR95

for bench in ["adversarial", "cifar10c", "cifar10p", "ood_classics"]:
    dfb = ultimate_plot[ultimate_plot["benchmark"] == bench].copy()
    if len(dfb) == 0:
        continue

    mean_df = dfb.groupby("detector", as_index=False).agg(AUROC=("AUROC", "mean"), FPR95=("FPR95", "mean"))
    mean_df = mean_df.sort_values("AUROC", ascending=False)

    out_csv = SUM_DIR / f"{bench}_overall_detector_mean__mode_{OOD_DEFINITION_MODE}.csv"
    mean_df.to_csv(out_csv, index=False)
    print("Wrote:", out_csv)

    fig, ax = plt.subplots(figsize=(max(10, 0.45 * len(mean_df)), 5))
    sns.barplot(data=mean_df, x="detector", y="AUROC", ax=ax)
    ax.set_title(f"{bench}: mean AUROC by detector (mode={OOD_DEFINITION_MODE})")
    ax.set_ylim(0.0, 1.0)
    ax.tick_params(axis="x", rotation=45)

    # annotate with FPR95
    for i, r in enumerate(mean_df.itertuples(index=False)):
        ax.text(i, float(r.AUROC) + 0.01, f"FPR95={float(r.FPR95):.3f}", ha="center", va="bottom", fontsize=8, rotation=90)

    _savefig(fig, VIS_DIR / f"{bench}_mean_auroc_bar_by_detector__mode_{OOD_DEFINITION_MODE}.png")


Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/summaries/adversarial_overall_detector_mean__mode_dataset.csv
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/adversarial_mean_auroc_bar_by_detector__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/summaries/cifar10c_overall_detector_mean__mode_dataset.csv
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/cifar10c_mean_auroc_bar_by_detector__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_s

In [109]:
# Visualization: heatmaps (mean AUROC)

# CIFAR-10-C heatmap: detector x corruption
c10c = ultimate_plot[ultimate_plot["benchmark"] == "cifar10c"].copy()
if len(c10c):
    pivot = c10c.pivot_table(index="detector", columns="member_key", values="AUROC", aggfunc="mean")
    fig, ax = plt.subplots(figsize=(max(10, 0.25 * pivot.shape[1]), max(6, 0.35 * pivot.shape[0])))
    sns.heatmap(pivot, vmin=0.0, vmax=1.0, cmap="viridis", ax=ax)
    ax.set_title(f"cifar10c: mean AUROC by corruption (mode={OOD_DEFINITION_MODE})")
    _savefig(fig, VIS_DIR / f"cifar10c_heatmap_mean_auroc_by_corruption__mode_{OOD_DEFINITION_MODE}.png")

# CIFAR-10-P heatmap: detector x perturbation
c10p = ultimate_plot[ultimate_plot["benchmark"] == "cifar10p"].copy()
if len(c10p):
    pivot = c10p.pivot_table(index="detector", columns="member_key", values="AUROC", aggfunc="mean")
    fig, ax = plt.subplots(figsize=(max(10, 0.25 * pivot.shape[1]), max(6, 0.35 * pivot.shape[0])))
    sns.heatmap(pivot, vmin=0.0, vmax=1.0, cmap="viridis", ax=ax)
    ax.set_title(f"cifar10p: mean AUROC by perturbation (mode={OOD_DEFINITION_MODE})")
    _savefig(fig, VIS_DIR / f"cifar10p_heatmap_mean_auroc_by_perturbation__mode_{OOD_DEFINITION_MODE}.png")

# Adversarial heatmap: detector x member_key (attack:threat:eps:steps)
adv = ultimate_plot[ultimate_plot["benchmark"] == "adversarial"].copy()
if len(adv):
    pivot = adv.pivot_table(index="detector", columns="member_key", values="AUROC", aggfunc="mean")
    fig, ax = plt.subplots(figsize=(max(10, 0.25 * pivot.shape[1]), max(6, 0.35 * pivot.shape[0])))
    sns.heatmap(pivot, vmin=0.0, vmax=1.0, cmap="viridis", ax=ax)
    ax.set_title(f"adversarial: mean AUROC by attack config (mode={OOD_DEFINITION_MODE})")
    _savefig(fig, VIS_DIR / f"adversarial_heatmap_mean_auroc__mode_{OOD_DEFINITION_MODE}.png")

# NEW: classic OOD heatmap: detector x dataset (svhn/lsun/isun/textures/places365)
classic = ultimate_plot[ultimate_plot["benchmark"] == "ood_classics"].copy()
if len(classic):
    pivot = classic.pivot_table(index="detector", columns="member_key", values="AUROC", aggfunc="mean")
    fig, ax = plt.subplots(figsize=(max(10, 0.6 * pivot.shape[1]), max(6, 0.35 * pivot.shape[0])))
    sns.heatmap(pivot, vmin=0.0, vmax=1.0, cmap="viridis", ax=ax)
    ax.set_title(f"ood_classics: mean AUROC by dataset (mode={OOD_DEFINITION_MODE})")
    _savefig(fig, VIS_DIR / f"ood_classics_heatmap_mean_auroc_by_dataset__mode_{OOD_DEFINITION_MODE}.png")


Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/cifar10c_heatmap_mean_auroc_by_corruption__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/cifar10p_heatmap_mean_auroc_by_perturbation__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/adversarial_heatmap_mean_auroc__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/ood_classics_heatmap_mean_auroc_by_dataset__mode_dataset.png


In [110]:
# Rankings tables per family_key (member_key) and overall

rank_tables = []

for bench in ["adversarial", "cifar10c", "cifar10p", "ood_classics"]:
    dfb = ultimate_plot[ultimate_plot["benchmark"] == bench].copy()
    if len(dfb) == 0:
        continue

    # Per member_key ranking (mean over datasets in that group; for adversarial member_key already includes eps/steps)
    by_member = (
        dfb.groupby(["member_key", "detector"], as_index=False)
        .agg(AUROC=("AUROC", "mean"), FPR95=("FPR95", "mean"))
        .sort_values(["member_key", "AUROC"], ascending=[True, False])
    )
    out_member = SUM_DIR / f"{bench}_detector_by_member_key__mode_{OOD_DEFINITION_MODE}.csv"
    by_member.to_csv(out_member, index=False)
    print("Wrote:", out_member)

    # Overall detector ranking (pooled mean over member_key)
    overall = dfb.groupby("detector", as_index=False).agg(AUROC=("AUROC", "mean"), FPR95=("FPR95", "mean"))
    overall = overall.sort_values("AUROC", ascending=False)
    out_overall = SUM_DIR / f"{bench}_overall_detector_mean__mode_{OOD_DEFINITION_MODE}.csv"
    # (already written in cell 15; keep this as idempotent)
    overall.to_csv(out_overall, index=False)

    overall["benchmark"] = bench
    rank_tables.append(overall)

rank_overall = pd.concat(rank_tables, ignore_index=True) if rank_tables else pd.DataFrame()
rank_overall.head()


Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/summaries/adversarial_detector_by_member_key__mode_dataset.csv
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/summaries/cifar10c_detector_by_member_key__mode_dataset.csv
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/summaries/cifar10p_detector_by_member_key__mode_dataset.csv
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/summaries/ood_classics_detector_by_member_key__mode_dataset.csv


Unnamed: 0,detector,AUROC,FPR95,benchmark
0,stein_per_dimension_l2_std_balanced,0.62684,0.84978,adversarial
1,stein_per_dimension_l2_no_lap_std,0.626399,0.87107,adversarial
2,stein_per_dimension_l2_no_lap,0.626399,0.87107,adversarial
3,stein_per_dimension_l2,0.6256,0.87059,adversarial
4,stein_per_dimension_l2_grad_only,0.609333,0.84167,adversarial


In [111]:
# Visualization: overall ranking barplots (mean AUROC), annotate with FPR95

for bench in ["adversarial", "cifar10c", "cifar10p", "ood_classics"]:
    dfb = ultimate_plot[ultimate_plot["benchmark"] == bench].copy()
    if len(dfb) == 0:
        continue

    overall = dfb.groupby("detector", as_index=False).agg(AUROC=("AUROC", "mean"), FPR95=("FPR95", "mean"))
    overall = overall.sort_values("AUROC", ascending=False)

    fig, ax = plt.subplots(figsize=(max(10, 0.45 * len(overall)), 5))
    sns.barplot(data=overall, x="AUROC", y="detector", ax=ax)
    ax.set_title(f"{bench}: overall ranking by mean AUROC (mode={OOD_DEFINITION_MODE})")
    ax.set_xlim(0.0, 1.0)

    for i, r in enumerate(overall.itertuples(index=False)):
        ax.text(float(r.AUROC) + 0.01, i, f"FPR95={float(r.FPR95):.3f}", va="center", fontsize=9)

    _savefig(fig, VIS_DIR / f"{bench}_overall_ranking_mean_auroc__mode_{OOD_DEFINITION_MODE}.png")


Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/adversarial_overall_ranking_mean_auroc__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/cifar10c_overall_ranking_mean_auroc__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/cifar10p_overall_ranking_mean_auroc__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/ood_classics_overall_ranking_mean_auroc__mode_dataset.png


In [112]:
# Visualization: top-N ranking plots per member_key (attack/corruption/perturbation/classic-dataset)

TOPN = 10

for bench in ["adversarial", "cifar10c", "cifar10p", "ood_classics"]:
    dfb = ultimate_plot[ultimate_plot["benchmark"] == bench].copy()
    if len(dfb) == 0:
        continue

    by_member = (
        dfb.groupby(["member_key", "detector"], as_index=False)
        .agg(AUROC=("AUROC", "mean"), FPR95=("FPR95", "mean"))
        .sort_values(["member_key", "AUROC"], ascending=[True, False])
    )

    for member, g in by_member.groupby("member_key"):
        g = g.sort_values("AUROC", ascending=False).head(TOPN)
        if len(g) == 0:
            continue

        fig, ax = plt.subplots(figsize=(10, max(3, 0.35 * len(g))))
        sns.barplot(data=g, x="AUROC", y="detector", ax=ax)
        ax.set_xlim(0.0, 1.0)
        ax.set_title(f"{bench}: top{TOPN} by AUROC for {member} (mode={OOD_DEFINITION_MODE})")

        for i, r in enumerate(g.itertuples(index=False)):
            ax.text(float(r.AUROC) + 0.01, i, f"FPR95={float(r.FPR95):.3f}", va="center", fontsize=9)

        out = VIS_DIR / f"{bench}_ranking_top{TOPN}_mean_auroc__member_{sanitize_filename(str(member))}__mode_{OOD_DEFINITION_MODE}.png"
        _savefig(fig, out)


Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/adversarial_ranking_top10_mean_auroc__member_autoattack_linf_0.00784313725490196_nan__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/adversarial_ranking_top10_mean_auroc__member_autoattack_linf_0.01568627450980392_nan__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/visualizations/adversarial_ranking_top10_mean_auroc__member_autoattack_linf_0.03137254901960784_nan__mode_dataset.png
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_datase

In [113]:
# Convenience exports: one combined CSV for plotting elsewhere

combined_out = SUM_DIR / f"ultimate_metrics_long__mode_{OOD_DEFINITION_MODE}.csv"
ultimate_plot.to_csv(combined_out, index=False)
print("Wrote:", combined_out)

# Also export classic-only slice (useful for quick inspection)
classic_out = SUM_DIR / f"ood_classics_metrics_long__mode_{OOD_DEFINITION_MODE}.csv"
ultimate_plot[ultimate_plot["benchmark"] == "ood_classics"].to_csv(classic_out, index=False)
print("Wrote:", classic_out)


Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/summaries/ultimate_metrics_long__mode_dataset.csv
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/summaries/ood_classics_metrics_long__mode_dataset.csv


In [114]:
# Ultimate table (restored): one row per dataset_prefix Ã— detector Ã— mode
# NOTE: This is used for downstream pivot tables / aggregation.

MODES = ["dataset", "misclassified", "dataset_and_misclassified"]


def _fpr95_from_roc(fpr: np.ndarray, tpr: np.ndarray) -> float:
    fpr = np.asarray(fpr, dtype=np.float64)
    tpr = np.asarray(tpr, dtype=np.float64)
    if fpr.size == 0 or tpr.size == 0:
        return float("nan")
    idx = np.where(tpr >= 0.95)[0]
    if idx.size == 0:
        return float("nan")
    return float(fpr[idx[0]])


def _stein_tail_for_benchmark(bench: str) -> str:
    return "two_sided" if bench == "adversarial" else "upper"


def _compute_metrics_from_detailed(df_det: pd.DataFrame, detector: str, mode: str, stein_tail: str):
    id_mask, ood_mask = split_id_ood_masks(df_det, mode=mode)
    scores, score_col = select_score_column(df_det, detector, stein_tail=stein_tail)

    id_scores = scores[id_mask]
    ood_scores = scores[ood_mask]
    labels = np.r_[np.zeros_like(id_scores, dtype=np.int64), np.ones_like(ood_scores, dtype=np.int64)]
    all_scores = np.r_[id_scores, ood_scores]

    fpr, tpr, auroc = roc_curve_auc(all_scores, labels)
    fpr95 = _fpr95_from_roc(fpr, tpr)

    return float(auroc), float(fpr95), int(id_scores.size), int(ood_scores.size), str(score_col)


base_cols = [
    "dataset_prefix",
    "id_dataset",
    "ood_dataset_sanitized",
    "benchmark",
    "Detector",
    "attack",
    "threat",
    "eps",
    "steps",
    "corruption",
    "perturbation",
    "AUROC_used",
    "FPR95_used",
]

# Defensive: ensure df_sub includes *all* detectors for the selected dataset_prefix set.
# We've seen cases where the dataset-level plots exist for baselines, but downstream
# summary tables (ultimate/overall rankings) end up Stein-only due to partial reruns
# or stale intermediate frames.

df_sub_for_ultimate = df_sub.copy()

# Only consider the dataset_prefix universe weâ€™re summarizing.
prefixes = sorted(df_sub_for_ultimate["dataset_prefix"].astype(str).unique().tolist())

# Build a stable (dataset_prefix, Detector) key for joins.
def _mk_key(df: pd.DataFrame) -> pd.Series:
    return df["dataset_prefix"].astype(str) + "||" + df["Detector"].astype(str)

existing_keys = set(_mk_key(df_sub_for_ultimate).tolist())

# Candidate extra rows: any detector rows present in df_all for those prefixes,
# but missing from df_sub_for_ultimate.
df_all_sel = df_all[df_all["dataset_prefix"].astype(str).isin(prefixes)].copy()
df_all_sel["__key__"] = _mk_key(df_all_sel)
extra = df_all_sel[~df_all_sel["__key__"].isin(existing_keys)].copy()

if len(extra):
    # Attach benchmark parsing metadata from df_sub_for_ultimate (per dataset_prefix).
    meta_cols = [
        "dataset_prefix",
        "benchmark",
        "attack",
        "threat",
        "eps",
        "steps",
        "corruption",
        "perturbation",
    ]
    meta = df_sub_for_ultimate.drop_duplicates(subset=["dataset_prefix"])[meta_cols].copy()
    extra = extra.merge(meta, on="dataset_prefix", how="left")

    # Recompute AUROC_used/FPR95_used for extra rows.
    # Baselines use raw AUROC/FPR95; Stein uses per-benchmark tail policy (same as df_sub logic).
    is_stein_extra = extra["Detector"].astype(str).str.startswith("stein")

    auroc_base = extra["AUROC"].astype(float)
    fpr_base = extra["FPR95"].astype(float)

    auroc_two = extra["AUROC_two_sided"].astype(float) if "AUROC_two_sided" in extra.columns else auroc_base
    fpr_two = extra["FPR95_two_sided"].astype(float) if "FPR95_two_sided" in extra.columns else fpr_base
    auroc_up = extra["AUROC_upper"].astype(float) if "AUROC_upper" in extra.columns else auroc_base
    fpr_up = extra["FPR95_upper"].astype(float) if "FPR95_upper" in extra.columns else fpr_base

    stein_auroc = np.where(extra["benchmark"].astype(str) == "adversarial", auroc_two, auroc_up)
    stein_fpr = np.where(extra["benchmark"].astype(str) == "adversarial", fpr_two, fpr_up)

    extra["AUROC_used"] = np.where(is_stein_extra, stein_auroc, auroc_base)
    extra["FPR95_used"] = np.where(is_stein_extra, stein_fpr, fpr_base)

    # Align columns and append.
    keep_cols = sorted(set(df_sub_for_ultimate.columns) | set(extra.columns))
    for c in keep_cols:
        if c not in df_sub_for_ultimate.columns:
            df_sub_for_ultimate[c] = None
        if c not in extra.columns:
            extra[c] = None
    df_sub_for_ultimate = pd.concat([df_sub_for_ultimate[keep_cols], extra[keep_cols]], ignore_index=True)

base = df_sub_for_ultimate[base_cols].copy()
base = base.drop_duplicates(subset=["dataset_prefix", "Detector"]).reset_index(drop=True)

print(
    "[ultimate] prefixes=", len(prefixes),
    " detectors=", int(base["Detector"].nunique()),
    " (including baselines=", int((~base["Detector"].astype(str).str.startswith("stein")).sum()), ")",
)

rows = []
for r in base.itertuples(index=False):
    dataset_prefix = str(r.dataset_prefix)
    det = str(r.Detector)
    bench = str(r.benchmark)
    stein_tail = _stein_tail_for_benchmark(bench)

    det_csv = RESULTS_DIR / f"{dataset_prefix}_{det}_detailed.csv"
    df_det = pd.read_csv(det_csv) if det_csv.exists() else None

    for mode in MODES:
        rec = {
            "dataset_prefix": dataset_prefix,
            "id_dataset": str(r.id_dataset),
            "ood_dataset_sanitized": str(r.ood_dataset_sanitized),
            "benchmark": bench,
            "detector": det,
            "ood_definition_mode": mode,
            "stein_tail_used": (stein_tail if det.startswith("stein_") else None),
            "attack": r.attack,
            "threat": r.threat,
            "eps": r.eps,
            "steps": r.steps,
            "corruption": r.corruption,
            "perturbation": r.perturbation,
        }

        # Default: take dataset-mode metrics from the summary CSV (fast) and recompute
        # from detailed CSV for the other ood_definition_mode variants.
        #
        # IMPORTANT: for Stein detectors we want *tail-aware* metrics in all summaries/plots.
        # The benchmark per-dataset CSV does not always include AUROC_upper/AUROC_two_sided
        # columns, so for Stein we must take AUROC/FPR95 from the detailed CSV using the
        # selected tail policy (upper for cifar10c/p + ood_classics, two_sided for adversarial).

        if mode == "dataset":
            rec["AUROC"] = float(r.AUROC_used)
            rec["FPR95"] = float(r.FPR95_used)
            rec["metric_source"] = "used_from_summary"
        else:
            rec["metric_source"] = "detailed_csv"

        if df_det is not None:
            au, fp, n_id, n_ood, score_col = _compute_metrics_from_detailed(df_det, det, mode, stein_tail)
            rec["AUROC_recomputed"] = au
            rec["FPR95_recomputed"] = fp
            rec["n_id"] = n_id
            rec["n_ood"] = n_ood
            rec["score_col"] = score_col

            # Use recomputed values for non-dataset modes, and ALSO for Stein dataset-mode
            # so that upper-tail is reflected in ultimate/summaries/visualizations.
            if (mode != "dataset") or det.startswith("stein_"):
                rec["AUROC"] = au
                rec["FPR95"] = fp
                rec["metric_source"] = (
                    "detailed_csv_used_tail" if (mode == "dataset" and det.startswith("stein_")) else rec["metric_source"]
                )
        else:
            rec["AUROC_recomputed"] = float("nan")
            rec["FPR95_recomputed"] = float("nan")
            rec["n_id"] = None
            rec["n_ood"] = None
            rec["score_col"] = None

        rows.append(rec)

ultimate = pd.DataFrame(rows)

ultimate["family_key"] = ultimate["benchmark"]
ultimate["member_key"] = np.where(
    ultimate["benchmark"] == "adversarial",
    ultimate["attack"].astype(str) + ":" + ultimate["threat"].astype(str) + ":" + ultimate["eps"].astype(str) + ":" + ultimate["steps"].astype(str),
    np.where(
        ultimate["benchmark"] == "cifar10c",
        ultimate["corruption"].astype(str),
        np.where(
            ultimate["benchmark"] == "cifar10p",
            ultimate["perturbation"].astype(str),
            np.where(ultimate["benchmark"] == "ood_classics", ultimate["ood_dataset_sanitized"].astype(str), "unknown"),
        ),
    ),
)

out_path = OUT_DIR / "summaries" / "ultimate_metrics_long.csv"
ultimate.to_csv(out_path, index=False)
print("Wrote:", out_path)
print("rows:", len(ultimate))
ultimate.head(10)

[ultimate] prefixes= 45  detectors= 8  (including baselines= 0 )
Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/summaries/ultimate_metrics_long.csv
rows: 1080


Unnamed: 0,dataset_prefix,id_dataset,ood_dataset_sanitized,benchmark,detector,ood_definition_mode,stein_tail_used,attack,threat,eps,...,AUROC,FPR95,metric_source,AUROC_recomputed,FPR95_recomputed,n_id,n_ood,score_col,family_key,member_key
0,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2,dataset,two_sided,autoattack,linf,0.007843,...,0.524784,0.9488,detailed_csv_used_tail,0.524784,0.9488,10000,10000,stein_oodness_two_sided,adversarial,autoattack:linf:0.00784313725490196:nan
1,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2,misclassified,two_sided,autoattack,linf,0.007843,...,0.632341,0.876656,detailed_csv,0.632341,0.876656,8456,1544,stein_oodness_two_sided,adversarial,autoattack:linf:0.00784313725490196:nan
2,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2,dataset_and_misclassified,two_sided,autoattack,linf,0.007843,...,0.555343,0.939924,detailed_csv,0.555343,0.939924,8456,11544,stein_oodness_two_sided,adversarial,autoattack:linf:0.00784313725490196:nan
3,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_no_lap,dataset,two_sided,autoattack,linf,0.007843,...,0.521931,0.949,detailed_csv_used_tail,0.521931,0.949,10000,10000,stein_oodness_two_sided,adversarial,autoattack:linf:0.00784313725490196:nan
4,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_no_lap,misclassified,two_sided,autoattack,linf,0.007843,...,0.639079,0.882569,detailed_csv,0.639079,0.882569,8456,1544,stein_oodness_two_sided,adversarial,autoattack:linf:0.00784313725490196:nan
5,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_no_lap,dataset_and_misclassified,two_sided,autoattack,linf,0.007843,...,0.554969,0.939688,detailed_csv,0.554969,0.939688,8456,11544,stein_oodness_two_sided,adversarial,autoattack:linf:0.00784313725490196:nan
6,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_lap_only,dataset,two_sided,autoattack,linf,0.007843,...,0.460535,0.9562,detailed_csv_used_tail,0.460535,0.9562,10000,10000,stein_oodness_two_sided,adversarial,autoattack:linf:0.00784313725490196:nan
7,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_lap_only,misclassified,two_sided,autoattack,linf,0.007843,...,0.586684,0.890847,detailed_csv,0.586684,0.890847,8456,1544,stein_oodness_two_sided,adversarial,autoattack:linf:0.00784313725490196:nan
8,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_lap_only,dataset_and_misclassified,two_sided,autoattack,linf,0.007843,...,0.489079,0.949149,detailed_csv,0.489079,0.949149,8456,11544,stein_oodness_two_sided,adversarial,autoattack:linf:0.00784313725490196:nan
9,cifar10_vs_adversarial_autoattack_linf_2_255,cifar10,adversarial_autoattack_linf_2_255,adversarial,stein_per_dimension_l2_grad_only,dataset,two_sided,autoattack,linf,0.007843,...,0.4733,0.9608,detailed_csv_used_tail,0.4733,0.9608,10000,10000,stein_oodness_two_sided,adversarial,autoattack:linf:0.00784313725490196:nan


In [115]:
set(ultimate.detector)

{'stein_per_dimension_l2',
 'stein_per_dimension_l2_grad_only',
 'stein_per_dimension_l2_lap_only',
 'stein_per_dimension_l2_lap_only_std',
 'stein_per_dimension_l2_no_lap',
 'stein_per_dimension_l2_no_lap_std',
 'stein_per_dimension_l2_score_only',
 'stein_per_dimension_l2_std_balanced'}

In [None]:
# Pivot table: rows=(ood_dataset_sanitized, metric), cols=detector
# Adds a helper column `benchmark` (adversarial/cifar10c/cifar10p/ood_classics).

# Default detector set for the pivot table.
# Includes baselines + the new 6-Stein subset used by full-suite runs.
DETECTORS = [
    "msp",
    "energy",
    "odin",
    "mahalanobis",
    "knn",
    "gsc",
'stein_per_dimension_l2',
 'stein_per_dimension_l2_no_lap',
 'stein_per_dimension_l2_score_only',
]

# Select which split to pivot
PIVOT_OOD_DEFINITION_MODE = OOD_DEFINITION_MODE  # or set explicitly: 'dataset'/'misclassified'/'dataset_and_misclassified'

sel = ultimate[(ultimate["ood_definition_mode"] == PIVOT_OOD_DEFINITION_MODE) & (ultimate["detector"].isin(DETECTORS))].copy()

# Map each ood_dataset_sanitized to a single benchmark family (should be unique)
bench_map = sel.groupby("ood_dataset_sanitized", as_index=False)["benchmark"].agg(lambda s: s.dropna().astype(str).unique())
bench_map["benchmark"] = bench_map["benchmark"].apply(lambda a: (a[0] if len(a) else None))

# ensure single value per dataset+detector
agg = (
    sel.groupby(["ood_dataset_sanitized", "detector"], as_index=False)
    .agg(AUROC=("AUROC", "mean"), FPR95=("FPR95", "mean"))
)

long = agg.melt(
    id_vars=["ood_dataset_sanitized", "detector"],
    value_vars=["AUROC", "FPR95"],
    var_name="metric",
    value_name="value",
)

pivot = long.pivot_table(
    index=["ood_dataset_sanitized", "metric"],
    columns="detector",
    values="value",
    aggfunc="mean",
).reindex(columns=DETECTORS)

# Add benchmark column (repeated for both AUROC/FPR95 rows)
pivot_df = pivot.reset_index().merge(bench_map[["ood_dataset_sanitized", "benchmark"]], on="ood_dataset_sanitized", how="left")
pivot_df = pivot_df.set_index(["ood_dataset_sanitized", "metric"])

out_path = OUT_DIR / "summaries" / f"ultimate_pivot__oodmode_{PIVOT_OOD_DEFINITION_MODE}.csv"
pivot_df.to_csv(out_path)
print("Wrote:", out_path)

pivot_df

Wrote: /Users/michalkozyra/Developer/PhD/stein_shift_detection/results/benchmark_results_stein_perdiml2_only__ddpm_xt_sigma_t50_seed0/plots_notebook/mode_dataset/summaries/ultimate_pivot__oodmode_dataset.csv


Unnamed: 0_level_0,Unnamed: 1_level_0,msp,energy,odin,mahalanobis,knn,gsc,react,stein_per_dimension_l2,stein_per_dimension_l2_grad_only,stein_per_dimension_l2_lap_only,stein_per_dimension_l2_lap_only_std,stein_per_dimension_l2_no_lap,stein_per_dimension_l2_no_lap_std,stein_per_dimension_l2_score_only,stein_per_dimension_l2_std_balanced,benchmark
ood_dataset_sanitized,metric,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
adversarial_autoattack_linf_2_255,AUROC,,,,,,,,0.524784,0.473300,0.460535,0.460535,0.521931,0.521931,0.500878,0.497528,adversarial
adversarial_autoattack_linf_2_255,FPR95,,,,,,,,0.948800,0.960800,0.956200,0.956200,0.949000,0.949000,0.947400,0.954800,adversarial
adversarial_autoattack_linf_4_255,AUROC,,,,,,,,0.622695,0.667659,0.665778,0.665778,0.619593,0.619593,0.515140,0.651572,adversarial
adversarial_autoattack_linf_4_255,FPR95,,,,,,,,0.919700,0.875900,0.874800,0.874800,0.914400,0.914400,0.949600,0.887000,adversarial
adversarial_autoattack_linf_8_255,AUROC,,,,,,,,0.851822,0.868403,0.864422,0.864422,0.849997,0.849997,0.580506,0.862165,adversarial
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
places365,FPR95,,,,,,,,0.555800,0.545300,0.549900,0.549900,0.562200,0.562200,0.977600,0.546800,ood_classics
svhn,AUROC,,,,,,,,0.810411,0.821951,0.822856,0.822856,0.801572,0.801572,0.884691,0.823975,ood_classics
svhn,FPR95,,,,,,,,0.491700,0.483400,0.475200,0.475200,0.512300,0.512300,0.342600,0.471400,ood_classics
textures,AUROC,,,,,,,,0.770598,0.766633,0.730060,0.730060,0.777512,0.777512,0.551786,0.759923,ood_classics


In [117]:
pivot_df.iloc[:10,:]

Unnamed: 0_level_0,Unnamed: 1_level_0,msp,energy,odin,mahalanobis,knn,gsc,react,stein_per_dimension_l2,stein_per_dimension_l2_grad_only,stein_per_dimension_l2_lap_only,stein_per_dimension_l2_lap_only_std,stein_per_dimension_l2_no_lap,stein_per_dimension_l2_no_lap_std,stein_per_dimension_l2_score_only,stein_per_dimension_l2_std_balanced,benchmark
ood_dataset_sanitized,metric,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1
adversarial_autoattack_linf_2_255,AUROC,,,,,,,,0.524784,0.4733,0.460535,0.460535,0.521931,0.521931,0.500878,0.497528,adversarial
adversarial_autoattack_linf_2_255,FPR95,,,,,,,,0.9488,0.9608,0.9562,0.9562,0.949,0.949,0.9474,0.9548,adversarial
adversarial_autoattack_linf_4_255,AUROC,,,,,,,,0.622695,0.667659,0.665778,0.665778,0.619593,0.619593,0.51514,0.651572,adversarial
adversarial_autoattack_linf_4_255,FPR95,,,,,,,,0.9197,0.8759,0.8748,0.8748,0.9144,0.9144,0.9496,0.887,adversarial
adversarial_autoattack_linf_8_255,AUROC,,,,,,,,0.851822,0.868403,0.864422,0.864422,0.849997,0.849997,0.580506,0.862165,adversarial
adversarial_autoattack_linf_8_255,FPR95,,,,,,,,0.6261,0.56,0.6236,0.6236,0.6295,0.6295,0.9374,0.5956,adversarial
adversarial_fgsm_linf_4_255,AUROC,,,,,,,,0.511351,0.481649,0.449645,0.449645,0.531992,0.531992,0.53056,0.500344,adversarial
adversarial_fgsm_linf_4_255,FPR95,,,,,,,,0.9478,0.9532,0.9546,0.9546,0.9476,0.9476,0.9442,0.9513,adversarial
adversarial_fgsm_linf_8_255,AUROC,,,,,,,,0.513129,0.484302,0.454067,0.454067,0.525716,0.525716,0.626964,0.499295,adversarial
adversarial_fgsm_linf_8_255,FPR95,,,,,,,,0.9479,0.9523,0.953,0.953,0.9492,0.9492,0.9254,0.951,adversarial


In [118]:
pivot_df.groupby(["benchmark", "metric"]).mean()

Unnamed: 0_level_0,Unnamed: 1_level_0,msp,energy,odin,mahalanobis,knn,gsc,react,stein_per_dimension_l2,stein_per_dimension_l2_grad_only,stein_per_dimension_l2_lap_only,stein_per_dimension_l2_lap_only_std,stein_per_dimension_l2_no_lap,stein_per_dimension_l2_no_lap_std,stein_per_dimension_l2_score_only,stein_per_dimension_l2_std_balanced
benchmark,metric,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1
adversarial,AUROC,,,,,,,,0.614447,0.609333,0.601011,0.601011,0.618048,0.618048,0.533745,0.617656
adversarial,FPR95,,,,,,,,0.86504,0.84167,0.84263,0.84263,0.86685,0.86685,0.94365,0.84625
cifar10c,AUROC,,,,,,,,0.619308,0.616676,0.618394,0.618394,0.625016,0.625016,0.636395,0.625423
cifar10c,FPR95,,,,,,,,0.869111,0.869505,0.868079,0.868079,0.865789,0.865789,0.786979,0.866611
cifar10p,AUROC,,,,,,,,0.595368,0.596303,0.593248,0.593248,0.593646,0.593646,0.578873,0.594547
cifar10p,FPR95,,,,,,,,0.861109,0.861609,0.859282,0.859282,0.8639,0.8639,0.875618,0.860618
ood_classics,AUROC,,,,,,,,0.764703,0.760316,0.71483,0.71483,0.774439,0.774439,0.445187,0.750762
ood_classics,FPR95,,,,,,,,0.56774,0.5543,0.58616,0.58616,0.56264,0.56264,0.86226,0.56782


In [119]:
pivot_df.groupby(level="metric").mean(numeric_only=True)

Unnamed: 0_level_0,msp,energy,odin,mahalanobis,knn,gsc,react,stein_per_dimension_l2,stein_per_dimension_l2_grad_only,stein_per_dimension_l2_lap_only,stein_per_dimension_l2_lap_only_std,stein_per_dimension_l2_no_lap,stein_per_dimension_l2_no_lap_std,stein_per_dimension_l2_score_only,stein_per_dimension_l2_std_balanced
metric,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
AUROC,,,,,,,,0.628531,0.626024,0.619099,0.619099,0.632402,0.632402,0.578277,0.630076
FPR95,,,,,,,,0.832764,0.826367,0.828949,0.828949,0.83188,0.83188,0.851827,0.827422


In [120]:
DETECTORS_TO_REPORT = ['msp',
 'energy',
 'odin',
 'mahalanobis',
 'knn',
 'stein_per_dimension_l2']

In [121]:
pivot_df.loc[:,DETECTORS_TO_REPORT].to_latex(escape=False)

'\\begin{tabular}{llrrrrrr}\n\\toprule\n &  & msp & energy & odin & mahalanobis & knn & stein_per_dimension_l2 \\\\\nood_dataset_sanitized & metric &  &  &  &  &  &  \\\\\n\\midrule\n\\multirow[t]{2}{*}{adversarial_autoattack_linf_2_255} & AUROC & NaN & NaN & NaN & NaN & NaN & 0.524784 \\\\\n & FPR95 & NaN & NaN & NaN & NaN & NaN & 0.948800 \\\\\n\\cline{1-8}\n\\multirow[t]{2}{*}{adversarial_autoattack_linf_4_255} & AUROC & NaN & NaN & NaN & NaN & NaN & 0.622695 \\\\\n & FPR95 & NaN & NaN & NaN & NaN & NaN & 0.919700 \\\\\n\\cline{1-8}\n\\multirow[t]{2}{*}{adversarial_autoattack_linf_8_255} & AUROC & NaN & NaN & NaN & NaN & NaN & 0.851822 \\\\\n & FPR95 & NaN & NaN & NaN & NaN & NaN & 0.626100 \\\\\n\\cline{1-8}\n\\multirow[t]{2}{*}{adversarial_fgsm_linf_4_255} & AUROC & NaN & NaN & NaN & NaN & NaN & 0.511351 \\\\\n & FPR95 & NaN & NaN & NaN & NaN & NaN & 0.947800 \\\\\n\\cline{1-8}\n\\multirow[t]{2}{*}{adversarial_fgsm_linf_8_255} & AUROC & NaN & NaN & NaN & NaN & NaN & 0.513129 \\\\\

In [122]:
import re

def format_latex_fn(df):
    """
    Returns a *compilable* LaTeX tabular string with:
      - benchmark as the 2nd column (right after ood_dataset_sanitized)
      - metric as the 3rd column
      - multirow on ood_dataset_sanitized (assumes AUROC then FPR95 per dataset)
      - correct tabular column spec and cline range

    Required columns (must exist in df):
      ood_dataset_sanitized, benchmark, metric,
      msp, energy, odin, mahalanobis, knn,
      stein_full_fixed0, stein_full_top1, stein_per_dimension_l2,
      stein_first_order_fixed0, stein_first_order_top1, stein_first_order_all
    """
    desired_order = [
        "ood_dataset_sanitized",
        "metric",
        "msp", "energy", "odin", "mahalanobis", "knn", "gsc","stein_per_dimension_l2",
    ]

    missing = [c for c in desired_order if c not in df.columns]
    if missing:
        raise ValueError(f"Missing columns: {missing}")

    df2 = df.loc[:, desired_order].copy()

    # 14 columns total: 3 text-ish + 11 numeric (benchmark treated as text)
    # l l l r... (11 r's)
    col_format = "lll" + "r" * (len(desired_order) - 3)

    latex = df2.to_latex(
        index=False,
        escape=True,           # escape underscores etc.
        column_format=col_format,
        float_format="%.6f",
        longtable=False
    )

    # Ensure tabular spec is exactly what we want (pandas sometimes inserts its own)
    latex = re.sub(
        r"\\begin\{tabular\}\{.*?\}",
        rf"\\begin{{tabular}}{{{col_format}}}",
        latex
    )

    ncols = len(desired_order)

    # Insert multirow in first column and clines after each FPR95 row
    lines = latex.splitlines()
    out = []
    last_ds = None

    for line in lines:
        stripped = line.strip()

        # pass through structural lines
        if (
            stripped.startswith("\\begin{tabular}")
            or stripped.startswith("\\end{tabular}")
            or stripped.startswith("\\toprule")
            or stripped.startswith("\\midrule")
            or stripped.startswith("\\bottomrule")
        ):
            out.append(line)
            continue

        # keep non-row lines
        if "&" not in line:
            out.append(line)
            continue

        cells = [c.strip() for c in line.split("&")]

        # header row: keep as is (it already has the right order)
        if cells[0] == "ood\\_dataset\\_sanitized":
            out.append(line)
            continue

        ds = cells[0]
        metric = cells[2]  # metric is now the 3rd column

        if ds == last_ds:
            cells[0] = ""  # second row of pair
        else:
            last_ds = ds
            # Only apply multirow on the AUROC row (assumes AUROC then FPR95)
            if "AUROC" in metric:
                cells[0] = f"\\multirow[t]{{2}}{{*}}{{{ds}}}"

        new_line = " & ".join(cells)
        out.append(new_line)

        if "FPR95" in metric:
            out.append(f"\\cline{{1-{ncols}}}")

    latex2 = "\n".join(out)

    # Remove a trailing \cline right before \bottomrule (optional cleanup)
    latex2 = latex2.replace(f"\\cline{{1-{ncols}}}\n\\bottomrule", "\\bottomrule")

    return latex2


In [123]:
desired_order = [
        "ood_dataset_sanitized",
        "metric",
        "msp", "energy", "odin", "mahalanobis", "knn", "gsc","stein_per_dimension_l2",
    ]

RAW_LATEX = format_latex_fn(pivot_df.reset_index().loc[:,desired_order])

In [124]:
from collections import defaultdict

METHOD_COL_START = 2  # after dataset, metric

OOD_GROUPS = {
    "adversarial": lambda d: d.startswith("adversarial"),
    "cifar10c": lambda d: d.startswith("cifar10c"),
    "cifar10p": lambda d: d.startswith("cifar10p"),
    "ood_benchmarks": lambda d: d.lower() in {
        "isun", "lsun", "places365", "svhn", "textures"
    },
}

# ---------- parse ----------
rows = []
for line in RAW_LATEX.splitlines():
    line = line.strip()
    if not line or line.startswith("\\"):
        continue
    if "&" not in line:
        continue

    cells = [c.strip() for c in line.rstrip("\\").split("&")]

    # âœ… robust header skip (works with escaped underscores)
    if len(cells) > 1 and cells[1] == "metric":
        continue

    rows.append(cells)

# ---------- group by dataset (logical, not visual) ----------
by_dataset = defaultdict(list)
last_dataset = None

for r in rows:
    dataset = r[0] if r[0] else last_dataset
    by_dataset[dataset].append(r)
    last_dataset = dataset

# ---------- bold extrema ----------
def bold_extreme(cells):
    metric = cells[1]
    vals = list(map(float, cells[METHOD_COL_START:]))

    if metric == "AUROC":
        idx = vals.index(max(vals))
    else:  # FPR95
        idx = vals.index(min(vals))

    col = METHOD_COL_START + idx
    cells[col] = rf"\textbf{{{cells[col]}}}"
    return cells

for ds, rs in by_dataset.items():
    for i in range(len(rs)):
        rs[i] = bold_extreme(rs[i])

# ---------- restore visual structure ----------
for ds, rs in by_dataset.items():
    for i in range(1, len(rs)):
        rs[i][0] = ""   # blank dataset cell for FPR95 rows

# ---------- split into 4 tables ----------
tables = defaultdict(list)
for ds, rs in by_dataset.items():
    for name, fn in OOD_GROUPS.items():
        if fn(ds):
            tables[name].extend(rs)

# ---------- emit ----------
HEADER = r"""
\begin{table}[!tb]
\centering
\resizebox{\textwidth}{!}{%
\begin{tabular}{lllrrrrrr}
\toprule
ood\_dataset\_sanitized & metric & msp & energy & odin & mahalanobis & knn & gsc & stein\_per\_dimension\_l2 \\
\midrule
"""

FOOTER = r"""
\bottomrule
\end{tabular}%
}
\end{table}
"""

def fmt_sig(x, sig=4):
    """
    Format a number (string or float) to `sig` significant digits.
    Preserves LaTeX \\textbf{...} wrappers if present.
    """
    x = x.strip()
    is_bold = x.startswith(r"\textbf{") and x.endswith("}")
    if is_bold:
        x_inner = x[len(r"\textbf{"):-1]
    else:
        x_inner = x

    try:
        val = float(x_inner)
        formatted = f"{val:.{sig}g}"
    except ValueError:
        # not a number (e.g., dataset name)
        return x

    return rf"\textbf{{{formatted}}}" if is_bold else formatted

def emit(rows):
    out_lines = []
    for r in rows:
        r_out = r.copy()
        for i in range(METHOD_COL_START, len(r_out)):
            r_out[i] = fmt_sig(r_out[i], sig=4)
        out_lines.append(" & ".join(r_out) + r" \\")
    return "\n".join(out_lines)

for name in ["adversarial", "cifar10c", "cifar10p", "ood_benchmarks"]:
    print(f"\n% ===== TABLE: {name.upper()} =====")
    print(HEADER)
    print(emit(tables[name]))
    print(FOOTER)



% ===== TABLE: ADVERSARIAL =====

\begin{table}[!tb]
\centering
\resizebox{\textwidth}{!}{%
\begin{tabular}{lllrrrrrr}
\toprule
ood\_dataset\_sanitized & metric & msp & energy & odin & mahalanobis & knn & gsc & stein\_per\_dimension\_l2 \\
\midrule

adversarial\_autoattack\_linf\_2\_255 & AUROC & \textbf{nan} & nan & nan & nan & nan & nan & 0.5248 \\
 & FPR95 & \textbf{nan} & nan & nan & nan & nan & nan & 0.9488 \\
adversarial\_autoattack\_linf\_4\_255 & AUROC & \textbf{nan} & nan & nan & nan & nan & nan & 0.6227 \\
 & FPR95 & \textbf{nan} & nan & nan & nan & nan & nan & 0.9197 \\
adversarial\_autoattack\_linf\_8\_255 & AUROC & \textbf{nan} & nan & nan & nan & nan & nan & 0.8518 \\
 & FPR95 & \textbf{nan} & nan & nan & nan & nan & nan & 0.6261 \\
adversarial\_fgsm\_linf\_4\_255 & AUROC & \textbf{nan} & nan & nan & nan & nan & nan & 0.5114 \\
 & FPR95 & \textbf{nan} & nan & nan & nan & nan & nan & 0.9478 \\
adversarial\_fgsm\_linf\_8\_255 & AUROC & \textbf{nan} & nan & nan & nan & nan 