In [None]:
"""Reproducibility across seeds & training variations.

This script rebuilds the *reproducibility/stability/coverage* analyses used in the paper (Fig. 2 and Fig. 3).
It also operates on sets of 0FP detections (optionally with TP≥x support) from multiple runs.

Author: Najeeb Jebreel, optmized by Cloude Sonnet 4.5
Date: 2025
"""

# ============================= Imports =======================================
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import shutil
from matplotlib.ticker import MaxNLocator, PercentFormatter

from analysis_utils import * 
from analysis_utils import _nice_bounds_numeric

In [None]:
# ============================= YOU FILL THESE ================================
"""
Each directory is expected to include CSV_BASENAME with per-sample detections
at ultra-low FPR (0.001%) produced by *online LiRA (shadow-based)*.
The five "seed*" entries are the *Identical* scenario (same config, different seeds).
The remaining entries each change ONE factor (+1 different): batch size/dropout, dropout/weight-decay,
architecture, or transfer learning (TL). Later we also average over +2 and +3 different variants.
The INPUTS below are generated in the analysis_results folder from post_analysis scripts. Make sure they are generated before running this script.
"""


INPUTS = {
    # Identical setting, different seeds {0,7,42,123,1234}
    "seed0":    "cifar10/resnet18/2025-10-19_1230",
    "seed7":    "cifar10/resnet18/2025-10-05_1535",
    "seed42":   "cifar10/resnet18/2025-10-07_1640",
    "seed123":  "cifar10/resnet18/2025-10-09_1745",
    "seed1234": "cifar10/resnet18/2025-10-15_1850",

    # +1 different (BS=512, Drop=0.2)
    "diff_bs_drop": "cifar10/resnet18/22025-10-11_1855",

    # +1 different (Drop=0.25, WD=3e-3)
    "diff_drop_wd": "cifar10/resnet18/2025-10-13_1725",

    # +1 different (Arch)
    "diff_arch": "cifar10/wrn28-2/2025-10-14_1820",

    # +1 different (TL)
    "transfer": "cifar10/efficientnetv2_rw_s/2025-10-21_1440",
}

# CSV filename expected in each directory. Must contain at least:
#   - SAMPLE_ID_COL (string identifiers of samples flagged as members)
#   - "tp" column (integer support within the run; number of shadow votes)
# The CSV is sorted by vulnerability score in its original pipeline, but only the columns above are required here.
CSV_BASENAME = "samples_vulnerability_ranked_online_shadow_0p001pct.csv"
SAMPLE_ID_COL = "sample_id"


# ---------------------------- Outputs ---------------------------------------
# Figures and tables are written here. Directories are created if missing.
OUT_FIG_DIR = "figures"
OUT_TAB_DIR = "tables"
os.makedirs(OUT_FIG_DIR, exist_ok=True)
os.makedirs(OUT_TAB_DIR, exist_ok=True)


# ============================= Docstring / Header ============================
"""

Aggregates across combinations of runs to compute:

  • Avg. Jaccard = |∩| / |∪|  (agreement / reproducibility)
  • Avg. intersection = |∩|  (size of common core / stability)
  • Avg. union = |∪|        (coverage / breadth)

Figure mapping for reproducibility:
  - The three line panels reproduce **Fig. 2 (a–c)**:
      jaccard_noleg.pdf   → Fig. 2a  (Avg. Jaccard vs k)
      intersection.pdf    → Fig. 2b  (Avg. intersection vs k)
      union.pdf           → Fig. 2c  (Avg. union vs k)
    A separate legend strip (legend.pdf) allows precise multi-panel assembly.

  - The heatmaps (Identical seeds only; FP=0; sweeping TP≥x) reproduce **Fig. 3**.

Exports (PDFs only):
  figures/
    jaccard_noleg.pdf          # Panel A (Avg. Jaccard), no legend      [Fig. 2a]
    intersection.pdf           # Panel B (Avg. intersection)             [Fig. 2b]
    union.pdf                  # Panel C (Avg. union)                    [Fig. 2c]
    legend.pdf                 # shared legend (thin strip) for Fig. 2
tables/
  stability_7scenarios.csv     # exact numbers used in the Fig. 2 plots

Scenarios & labels:
  - "Identical (2–5 seeds)"                # 5 same-config runs; vary seeds only
  - "+1 different (BS=512, Drop=0.2)"      # change batch size & dropout
  - "+1 different (Drop=0.25, WD=3e-3)"    # change dropout & weight decay
  - "+1 different (Arch)"                  # change architecture
  - "+1 different (TL)"                    # add transfer learning
  - "+2 different"                         # average over all pairs of +1 variants
  - "+3 different"                         # average over all triplets of +1 variants
"""

# Base typography (points). If LaTeX is available.
BASE_FONTSIZE = 8.5
PAPER_USES_TIMES = True
HAS_TEX = shutil.which("latex") is not None

# Choose LaTeX or fallback font configuration. PDF/PS font types set for vector text in outputs.
if HAS_TEX:
    mpl.rcParams.update({
        "text.usetex": True,
        "pdf.fonttype": 42, "ps.fonttype": 42,
        "axes.labelsize": BASE_FONTSIZE,
        "axes.titlesize": BASE_FONTSIZE,
        "xtick.labelsize": BASE_FONTSIZE - 1,
        "ytick.labelsize": BASE_FONTSIZE - 1,
        "legend.fontsize": BASE_FONTSIZE - 1,
        "figure.titlesize": BASE_FONTSIZE,
        "font.family": "serif",
        "text.latex.preamble": r"\usepackage{newtxtext}\usepackage{newtxmath}" if PAPER_USES_TIMES else r"",
    })
else:
    mpl.rcParams.update({
        "text.usetex": False,
        "pdf.fonttype": 42, "ps.fonttype": 42,
        "axes.labelsize": BASE_FONTSIZE,
        "axes.titlesize": BASE_FONTSIZE,
        "xtick.labelsize": BASE_FONTSIZE - 1,
        "ytick.labelsize": BASE_FONTSIZE - 1,
        "legend.fontsize": BASE_FONTSIZE - 1,
        "figure.titlesize": BASE_FONTSIZE,
        "font.family": "serif",
        "font.serif": ["Times New Roman", "Times", "DejaVu Serif", "CMU Serif", "Nimbus Roman"],
        "mathtext.fontset": "stix" if PAPER_USES_TIMES else "cm",
    })


# ============================= Load named sets ===============================
# Baselines (required). These five are the *Identical (2–5 seeds)* scenario used in Fig. 2 and Fig. 3.
BASE_TOKENS = ["seed0","seed7","seed42","seed123","seed1234"]
for t in BASE_TOKENS:
    if t not in INPUTS:
        # Early, explicit failure if a required directory is missing.
        raise KeyError(f"Missing INPUTS['{t}']")
# Build (name, set_of_ids) tuples for baseline runs with TP≥1 by default.
sets_baseline = [(t, to_id_set(INPUTS[t])) for t in BASE_TOKENS]

# Variants (required for respective scenarios): each introduces exactly one change.
# These will be used to compute "+1 different" scenarios; later we aggregate over 2/3-change combos.
V_SEED3_BS512_DROP02 = ("seed3_bs512_drp0.2_wd1e-3", to_id_set(INPUTS["diff_bs_drop"]))
V_DROP25_WD3E3       = ("seed42_drp25_wd3e-3",       to_id_set(INPUTS["diff_drop_wd"]))
V_ARCH               = ("arch",                      to_id_set(INPUTS["diff_arch"]))
V_TL                 = ("tl",                        to_id_set(INPUTS["transfer"]))


# ============================= Scenarios =====================================
# 1) Identical (2–5 seeds): vary k from 2..5 over the five seed runs.
scen_identical = compute_scenario("Identical (2–5 seeds)", sets_baseline, kmin=2, kmax=5)

# 2) +1 different (BS=512, Drop=0.2): baseline seeds + one altered run.
scen_bs_drop = compute_scenario("+1 different (BS=512, DRP=20%)", sets_baseline + [V_SEED3_BS512_DROP02])

# 3) +1 different (Drop=0.25, WD=3e-3)
scen_drop_wd = compute_scenario("+1 different (DRP=25%, WD=3e-3)", sets_baseline + [V_DROP25_WD3E3])

# 4) +1 different (Arch)
scen_arch = compute_scenario("+1 different (Arch)", sets_baseline + [V_ARCH])

# 5) +1 different (TL)
scen_tl = compute_scenario("+1 different (TL)", sets_baseline + [V_TL])

# 6) +2 different: average the metrics across all pairs of (+1) variants.
two_variant_configs = [
    sets_baseline + [V_SEED3_BS512_DROP02, V_DROP25_WD3E3],
    sets_baseline + [V_SEED3_BS512_DROP02, V_ARCH],
    sets_baseline + [V_SEED3_BS512_DROP02, V_TL],
    sets_baseline + [V_DROP25_WD3E3,       V_ARCH],
    sets_baseline + [V_DROP25_WD3E3,       V_TL],
    sets_baseline + [V_ARCH,               V_TL],
]
scen_plus2 = aggregate_over_configs("+2 different", two_variant_configs)

# 7) +3 different: same idea for all triplets of (+1) variants.
three_variant_configs = [
    sets_baseline + [V_SEED3_BS512_DROP02, V_DROP25_WD3E3, V_ARCH],
    sets_baseline + [V_SEED3_BS512_DROP02, V_DROP25_WD3E3, V_TL],
    sets_baseline + [V_SEED3_BS512_DROP02, V_ARCH,         V_TL],
    sets_baseline + [V_DROP25_WD3E3,       V_ARCH,         V_TL],
]
scen_plus3 = aggregate_over_configs("+3 different", three_variant_configs)

# Concatenate all scenarios → master table for plotting (Fig. 2).
all_results = pd.concat(
    [scen_identical, scen_bs_drop, scen_drop_wd, scen_arch, scen_tl, scen_plus2, scen_plus3],
    ignore_index=True
)

# Save table with exact numbers behind Fig. 2 (useful for reproducibility/CI).
all_results.to_csv(os.path.join(OUT_TAB_DIR, "reproducibility_7scenarios.csv"), index=False, float_format="%.4f")

# Compute dynamic y-limits after results are known (keeps axes tight yet readable).
DYNAMIC_YLIMS = compute_dynamic_ylims(all_results)


# ---------------------------- Styling ---------------------------------------
# Global seaborn theme tuned for paper figures; colorblind palette for accessibility.
sns.set_theme(context="paper", style="whitegrid", font_scale=1.05)
palette = sns.color_palette("colorblind", 8)
# Consistent color/linestyle mapping per scenario to match figure captions.
COL = {
    "Identical (2–5 seeds)":                 palette[0],
    "+1 different (BS=512, DRP=20%)":       palette[4],
    "+1 different (DRP=25%, WD=3e-3)":     palette[1],
    "+1 different (Arch)":                   palette[6],
    "+1 different (TL)":                     palette[2],
    "+2 different":                          palette[3],
    "+3 different":                          palette[5],
}
STYLE = {
    "Identical (2–5 seeds)": "solid",
    "+1 different (BS=512, DRP=20%)": (0, (7,2)),
    "+1 different (DRP=25%, WD=3e-3)": (0, (5,2)),
    "+1 different (Arch)": (0, (3,1,1,1)),
    "+1 different (TL)": (0, (1,1)),
    "+2 different": (0, (3,1,1,1)),
    "+3 different": (0, (9,2,1,2)),
}
# Marker/line width/size; adjusted for legibility at small panel sizes.
MARK, LW, MS = "o", 2.3, 6.0
# Desired scenario order. Filter to only those present in all_results.
ORDER_ALL = [
    "Identical (2–5 seeds)",
    "+1 different (BS=512, DRP=20%)",
    "+1 different (DRP=25%, WD=3e-3)",
    "+1 different (Arch)",
    "+1 different (TL)",
    "+2 different",
    "+3 different",
]
ORDER = [k for k in ORDER_ALL if (all_results["scenario"] == k).any()]

# Ensure vector text in saved PDFs/PS.
plt.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})


def _endpoint(ax, x, y, txt, color, dy=0):
    """Annotate the last point of a line with its numeric value (e.g., '35%' or '507')."""
    ax.annotate(txt, xy=(x, y), xytext=(4, dy), textcoords="offset points",
                fontsize=BASE_FONTSIZE-1, color=color, va="center", ha="left")

def draw_lines(ax, df, ykey, ylabel, pct=False):
    """Draw a multi-scenario line plot vs k for a given metric key.

    Arguments:
      ax    : Matplotlib axes to draw on.
      df    : DataFrame with columns ['scenario','k', ykey].
      ykey  : One of {'avg_jaccard','avg_intersection','avg_union'}.
      ylabel: Axis label text.
      pct   : If True, format y as percentage (used for Avg. Jaccard).

    Styling:
      - Uses COLOR/STYLE dictionaries for consistent legend mapping.
      - Adds two faint reference lines at 10% and 5% when pct=True (context for low-agreement regimes).
      - X ticks forced to integers (k is integer).
    """
    for scen in ORDER:
        sub = df[df["scenario"] == scen].sort_values("k")
        if sub.empty:
            continue
        ax.plot(sub["k"], sub[ykey],
                color=COL[scen], linestyle=STYLE[scen],
                marker=MARK, ms=MS, lw=LW, label=scen)
        x_end, y_end = sub["k"].iloc[-1], sub[ykey].iloc[-1]
        txt = f"{y_end*100:.1f}%" if pct else f"{y_end:.0f}"
        _endpoint(ax, x_end, y_end, txt, COL[scen])
    ax.set_xlabel(r"$k$")
    ax.set_ylabel(ylabel)
    ax.set_ylim(*DYNAMIC_YLIMS[ykey])
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.grid(True, linestyle=":", linewidth=0.8, alpha=0.6)
    ax.minorticks_on()
    ax.grid(which="minor", axis="y", linestyle=":", alpha=0.25)
    if pct:
        ax.yaxis.set_major_formatter(PercentFormatter(1.0))
        ax.axhline(0.10, color="grey", lw=1, ls=":", alpha=0.45)
        ax.axhline(0.05, color="grey", lw=1, ls=":", alpha=0.45)


# ---------- export panels (PDF only) ----------
def save_panel(path_pdf, ykey, ylabel, pct=False):
    """Save a single metric-vs-k panel (PDF). Legend is removed here and exported separately.

    The three calls below (Jaccard/Intersection/Union) correspond to **Fig. 2 (a–c)**.
    """
    fig, ax = plt.subplots(figsize=(3.4, 2.4))
    draw_lines(ax, all_results, ykey, ylabel, pct=pct)
    if ax.get_legend(): ax.get_legend().remove()
    fig.tight_layout()
    fig.savefig(path_pdf, bbox_inches="tight")
    plt.close(fig)

def save_legend(path_pdf):
    """Save a standalone legend as a thin PDF strip.

    This decouples legend placement from the panels so they can be composed
    precisely in a multi-panel figure (as done for Fig. 2).
    """
    fig_tmp, ax_tmp = plt.subplots()
    draw_lines(ax_tmp, all_results, "avg_jaccard", "Avg. Jaccard", pct=True)
    handles, labels = ax_tmp.get_legend_handles_labels()
    plt.close(fig_tmp)
    
    fig_leg = plt.figure(figsize=(5.0, 0.45))
    fig_leg.legend(handles, labels, loc="center", ncol=3, frameon=True,
                   fontsize=BASE_FONTSIZE-0.5, columnspacing=1.0, handletextpad=0.6)
    fig_leg.savefig(path_pdf, bbox_inches="tight", pad_inches=0.02)
    plt.close(fig_leg)


# ---------------------------- Make files -------------------------------------
# Rebuild Fig. 2 (a–c) panels; legend exported separately for layout control.
save_panel(os.path.join(OUT_FIG_DIR, "jaccard_noleg.pdf"), "avg_jaccard", "Avg. Jaccard", pct=True)   # Fig. 2a
save_panel(os.path.join(OUT_FIG_DIR, "intersection.pdf"), "avg_intersection", "Avg. intersection")    # Fig. 2b
save_panel(os.path.join(OUT_FIG_DIR, "union.pdf"), "avg_union", "Avg. union")                         # Fig. 2c
save_legend(os.path.join(OUT_FIG_DIR, "legend.pdf"))

print("✓ Generated 4 figures:")
print(f"  {OUT_FIG_DIR}/jaccard_noleg.pdf")
print(f"  {OUT_FIG_DIR}/intersection.pdf")
print(f"  {OUT_FIG_DIR}/union.pdf")
print(f"  {OUT_FIG_DIR}/legend.pdf")
print(f"✓ Table: {OUT_TAB_DIR}/reproducibility_7scenarios.csv")



✓ Generated 4 figures:
  figures/jaccard_noleg.pdf
  figures/intersection.pdf
  figures/union.pdf
  figures/legend.pdf
✓ Table: tables/reproducibility_7scenarios.csv


In [5]:
# ============================= TP>=x @ FP=0 — compact heatmaps (Identical only) =============================
# Reproduces **Fig. 3** (Identical seeds; FP=0) by sweeping within-run support thresholds TP≥x.
# The x-axis in the heatmaps corresponds to TP≥x; rows are k ∈ {2,3,4,5}.
# Thresholds to report
TP_THRESHOLDS = [1, 2, 3, 4, 5, 10, 20, 64]

def compute_identical_over_thresholds(tp_values):
    """Return long DF over thresholds for the Identical (2–5 seeds) case only.

    For each threshold x in TP_THRESHOLDS:
      1) Rebuild each baseline set with TP≥x (stronger per-run support).
      2) Recompute metrics over k=2..5 using the same combination logic.
      3) Append to a long-form DataFrame with 'tp_threshold' attached.
    """
    rows = []
    for tpx in tp_values:
        sets_baseline_x = [(name, to_id_set(INPUTS[name], tp_threshold=tpx)) for name,_ in sets_baseline]
        df = compute_scenario("Identical (2–5 seeds)", sets_baseline_x, kmin=2, kmax=5)
        df["tp_threshold"] = tpx
        rows.append(df)
    out = pd.concat(rows, ignore_index=True)
    return out

# Long-form table used to generate the three heatmaps of Fig. 3 (Jaccard/∩/∪).
ident_only = compute_identical_over_thresholds(TP_THRESHOLDS)

# Save tidy table with exact values behind Fig. 3.
out_tbl = ident_only[["tp_threshold","k","avg_jaccard","avg_intersection","avg_union"]].copy()
out_tbl.to_csv(os.path.join(OUT_TAB_DIR, "tpgeq_x_fp0_identical.csv"), index=False, float_format="%.6f")


def _fmt_percent(v): 
    return "" if pd.isna(v) else f"{v*100:.0f}%"

def _fmt_int(v):
    return "" if pd.isna(v) else f"{int(round(v))}"

def _pivot(df, value):
    """Pivot helper producing k×(TP≥x) matrices for a chosen value column."""
    p = df.pivot(index="k", columns="tp_threshold", values=value).sort_index()
    return p.reindex(columns=TP_THRESHOLDS)

# Pivot data into matrices for each metric (columns = thresholds, rows = k).
p_j = _pivot(ident_only, "avg_jaccard")
p_i = _pivot(ident_only, "avg_intersection")
p_u = _pivot(ident_only, "avg_union")

# Panel-specific ranges (keep correct normalization), but unify the colormap
# to achieve a compact, consistent aesthetic across the three matrices.
CMAP = "viridis"
lims_j = (0, 1)
lims_i = _nice_bounds_numeric(np.nanmin(p_i.values), np.nanmax(p_i.values))
lims_u = _nice_bounds_numeric(np.nanmin(p_u.values), np.nanmax(p_u.values))

# A single-row, three-column layout; spacing tuned for captions and labels.
fig, axes = plt.subplots(1, 3, figsize=(7.2, 2.6))
plt.subplots_adjust(wspace=0.08, left=0.08, right=0.96, top=0.88, bottom=0.22)

# Panel A – Avg. Jaccard (percentage heatmap, no colorbar for compactness).
sns.heatmap(
    p_j, ax=axes[0], cmap=CMAP, vmin=lims_j[0], vmax=lims_j[1],
    annot=p_j.applymap(_fmt_percent), fmt="", annot_kws={"fontsize": BASE_FONTSIZE-1},
    cbar=False, linewidths=0.5, linecolor="white"
)
axes[0].set_title("Avg. Jaccard (Identical)", fontsize=BASE_FONTSIZE+0.3)
axes[0].set_xlabel("")  # no xlabel on left panel
axes[0].set_ylabel("Runs combined (k)")
axes[0].tick_params(axis="both", labelsize=BASE_FONTSIZE-1)

# Panel B – Avg. Intersection (integer counts, still without colorbar).
sns.heatmap(
    p_i, ax=axes[1], cmap=CMAP, vmin=lims_i[0], vmax=lims_i[1],
    annot=p_i.applymap(_fmt_int), fmt="", annot_kws={"fontsize": BASE_FONTSIZE-1},
    cbar=False, linewidths=0.5, linecolor="white"
)
axes[1].set_title("Avg. Intersection", fontsize=BASE_FONTSIZE+0.3)
axes[1].set_xlabel("TP ≥ x  at  FP = 0")  # xlabel only under the middle panel
axes[1].set_ylabel("")
axes[1].tick_params(axis="both", labelsize=BASE_FONTSIZE-1)

# Panel C – Avg. Union (integer counts) with a single shared, unlabeled colorbar.
# We attach the only colorbar here to keep the figure compact and uniform.
hm = sns.heatmap(
    p_u, ax=axes[2], cmap=CMAP, vmin=lims_u[0], vmax=lims_u[1],
    annot=p_u.applymap(_fmt_int), fmt="", annot_kws={"fontsize": BASE_FONTSIZE-1},
    cbar=True, linewidths=0.5, linecolor="white",
    cbar_kws={"orientation": "vertical", "fraction": 0.046, "pad": 0.02}
)
axes[2].set_title("Avg. Union", fontsize=BASE_FONTSIZE+0.3)
axes[2].set_xlabel("")  # no xlabel on right panel
axes[2].set_ylabel("")
axes[2].tick_params(axis="both", labelsize=BASE_FONTSIZE-1)

# Make the colorbar unlabeled and tickless (unified look, no labels)
cb = hm.collections[0].colorbar
cb.set_label("")                 # no label text
cb.set_ticks([])                 # remove tick labels
cb.outline.set_linewidth(0.6)    # keep a subtle outline

# Align tick labels across all panels for easy reading.
for ax in axes:
    ax.set_xticklabels([str(t) for t in TP_THRESHOLDS], rotation=0)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

# Save the Fig. 3-style composite and close.
out_heatmap = os.path.join(OUT_FIG_DIR, "tpgeq_x_fp0_identical_heatmaps_unified.pdf")
fig.savefig(out_heatmap, bbox_inches="tight")
plt.close(fig)
print("✓ Unified-heatmap figure:", out_heatmap)


  annot=p_j.applymap(_fmt_percent), fmt="", annot_kws={"fontsize": BASE_FONTSIZE-1},
  annot=p_i.applymap(_fmt_int), fmt="", annot_kws={"fontsize": BASE_FONTSIZE-1},
  annot=p_u.applymap(_fmt_int), fmt="", annot_kws={"fontsize": BASE_FONTSIZE-1},


✓ Unified-heatmap figure: figures\tpgeq_x_fp0_identical_heatmaps_unified.pdf
