In [5]:
# ============================= Imports =======================================
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from pathlib import Path
from itertools import combinations
import shutil
from matplotlib.ticker import MaxNLocator, PercentFormatter

In [None]:
# ============================= YOU FILL THESE ================================
INPUTS = {
    # Identical seeds — weak_rotate_jitter_cutmix_drop0.1_wd1e-3_seed{0,7,42,123,1234}
    "seed0":    "cifar10/resnet18/weak_rotate_jitter_cutmix_drop0.1_wd1e-3_seed0",
    "seed7":    "cifar10/resnet18/weak_rotate_jitter_cutmix_drop0.1_wd1e-3_seed7",
    "seed42":   "cifar10/resnet18/weak_rotate_jitter_cutmix_drop0.1_wd1e-3_seed42",
    "seed123":  "cifar10/resnet18/weak_rotate_jitter_cutmix_drop0.1_wd1e-3_seed123",
    "seed1234": "cifar10/resnet18/weak_rotate_jitter_cutmix_drop0.1_wd1e-3_seed1234",

    # +1 different (BS=512, Drop=0.2)
    "diff_bs_drop": "cifar10/resnet18/weak_rotate_jitter_cutmix_drop0.2_wd1e-3_bs512",

    # +1 different (Drop=0.25, WD=3e-3)
    "diff_drop_wd": "cifar10/resnet18/weak_rotate_jitter_cutmix_drop0.25_wd3e-3",

    # +1 different (Arch)
    "diff_arch": "cifar10/wrn28-2/weak_rotate_jitter_cutmix_drop0.1_wd1e-3",

    # +1 different (TL)
    "transfer": "cifar10/efficientnetv2_rw_s/pretrained_weak_rotate_jitter_cutmix_drp0.25_wd_5e-2",
}

CSV_BASENAME = "samples_vulnerability_ranked_online_shadow_0p001pct.csv"
SAMPLE_ID_COL = "sample_id"



# ---------------------------- Outputs ---------------------------------------
OUT_FIG_DIR = "figures"
OUT_TAB_DIR = "tables"
os.makedirs(OUT_FIG_DIR, exist_ok=True)
os.makedirs(OUT_TAB_DIR, exist_ok=True)



'\nStability across seeds & training variations — publication-grade figures.\n\nExports (PDFs only):\n  figures/\n    jaccard_noleg.pdf          # Panel A (Avg. Jaccard), no legend\n    intersection.pdf           # Panel B (Avg. intersection)\n    union.pdf                  # Panel C (Avg. union)\n    legend.pdf                 # shared legend (thin strip)\ntables/\n  stability_7scenarios.csv     # exact numbers used in plots\n\nScenarios & labels:\n  - "Identical (2–5 seeds)"\n  - "+1 different (BS=512, Drop=0.2)"\n  - "+1 different (Drop=0.25, WD=3e-3)"\n  - "+1 different (Arch)"\n  - "+1 different (TL)"\n  - "+2 different"              # average over all variant pairs\n  - "+3 different"              # average over all variant triplets\n'

In [6]:
# ---- IEEE-friendly font setup with auto-fallback (Windows-safe) -------------
BASE_FONTSIZE = 8.5
PAPER_USES_TIMES = True
HAS_TEX = shutil.which("latex") is not None

if HAS_TEX:
    mpl.rcParams.update({
        "text.usetex": True,
        "pdf.fonttype": 42, "ps.fonttype": 42,
        "axes.labelsize": BASE_FONTSIZE,
        "axes.titlesize": BASE_FONTSIZE,
        "xtick.labelsize": BASE_FONTSIZE - 1,
        "ytick.labelsize": BASE_FONTSIZE - 1,
        "legend.fontsize": BASE_FONTSIZE - 1,
        "figure.titlesize": BASE_FONTSIZE,
        "font.family": "serif",
        "text.latex.preamble": r"\usepackage{newtxtext}\usepackage{newtxmath}" if PAPER_USES_TIMES else r"",
    })
else:
    mpl.rcParams.update({
        "text.usetex": False,
        "pdf.fonttype": 42, "ps.fonttype": 42,
        "axes.labelsize": BASE_FONTSIZE,
        "axes.titlesize": BASE_FONTSIZE,
        "xtick.labelsize": BASE_FONTSIZE - 1,
        "ytick.labelsize": BASE_FONTSIZE - 1,
        "legend.fontsize": BASE_FONTSIZE - 1,
        "figure.titlesize": BASE_FONTSIZE,
        "font.family": "serif",
        "font.serif": ["Times New Roman", "Times", "DejaVu Serif", "CMU Serif", "Nimbus Roman"],
        "mathtext.fontset": "stix" if PAPER_USES_TIMES else "cm",
    })


# ============================= Utilities =====================================
def _csv_path(p: str | Path) -> Path:
    """Return a real CSV path whether user provided a directory or a CSV file."""
    p = Path(p)
    if p.is_dir():
        p = p / CSV_BASENAME
    if p.suffix.lower() != ".csv":
        raise FileNotFoundError(f"Expected a directory or CSV; got: {p}")
    if not p.exists():
        raise FileNotFoundError(f"CSV not found: {p}")
    return p

def to_id_set(path_or_dir: str | Path, col=SAMPLE_ID_COL, tp_threshold: int = 1) -> set[str]:
    """Load CSV and return set of sample IDs where TP >= tp_threshold."""
    csvp = _csv_path(path_or_dir)
    df = pd.read_csv(csvp)
    if col not in df.columns:
        raise KeyError(f"Expected '{col}' in {csvp}; got {list(df.columns)}")
    if "tp" not in df.columns:
        raise KeyError(f"Expected 'tp' column in {csvp}")
    filtered = df[df["tp"] >= tp_threshold]
    return set(filtered[col].dropna().astype(str))


# Dynamic Y-limits computed from the data
def _nice_bounds_numeric(ymin, ymax):
    span = ymax - ymin
    pad = 0.05 * max(1.0, span)
    lo = ymin - pad
    hi = ymax + pad
    rng = hi - lo
    if   rng <= 20:  step = 1
    elif rng <= 100: step = 5
    elif rng <= 500: step = 10
    elif rng <= 2000: step = 50
    else:            step = 100
    lo = np.floor(lo / step) * step
    hi = np.ceil(hi / step) * step
    return lo, hi

def _nice_bounds_percent(ymin, ymax):
    lo = max(0.0, ymin - 0.02)
    hi = min(1.0, ymax + 0.02)
    lo = np.floor(lo * 20) / 20.0
    hi = np.ceil(hi * 20) / 20.0
    if np.isclose(lo, hi):
        hi = min(1.0, lo + 0.05)
    return lo, hi

def compute_dynamic_ylims(df):
    yl = {}
    col = df["avg_jaccard"].dropna()
    yl["avg_jaccard"] = _nice_bounds_percent(col.min(), col.max()) if not col.empty else (0, 1)
    col = df["avg_intersection"].dropna()
    yl["avg_intersection"] = _nice_bounds_numeric(col.min(), col.max()) if not col.empty else (0, 1)
    col = df["avg_union"].dropna()
    yl["avg_union"] = _nice_bounds_numeric(col.min(), col.max()) if not col.empty else (0, 1)
    return yl


# ============================= Metrics =======================================
def avg_agreement(named_sets, k):
    j, inter, uni = [], [], []
    for combo in combinations(named_sets, k):
        S = [s for _, s in combo]
        inter_set = set.intersection(*S)
        union_set = set.union(*S)
        if not union_set:
            continue
        j.append(len(inter_set) / len(union_set))
        inter.append(len(inter_set))
        uni.append(len(union_set))
    if not j:
        return np.nan, np.nan, np.nan
    return float(np.mean(j)), float(np.mean(inter)), float(np.mean(uni))

def compute_scenario(label, named_sets, kmin=None, kmax=None):
    M = len(named_sets)
    lo = 2 if kmin is None else kmin
    hi = M if kmax is None else kmax
    rows = []
    for k in range(lo, hi + 1):
        aj, ai, au = avg_agreement(named_sets, k)
        rows.append(dict(scenario=label, M=M, k=k,
                         avg_jaccard=aj, avg_intersection=ai, avg_union=au))
    return pd.DataFrame(rows)

def aggregate_over_configs(label, list_of_named_sets):
    dfs = [compute_scenario("tmp", ns) for ns in list_of_named_sets]
    out = dfs[0][["k"]].copy()
    out["avg_jaccard"] = np.mean([d["avg_jaccard"].values for d in dfs], axis=0)
    out["avg_intersection"] = np.mean([d["avg_intersection"].values for d in dfs], axis=0)
    out["avg_union"] = np.mean([d["avg_union"].values for d in dfs], axis=0)
    out["M"] = len(list_of_named_sets[0])
    out["scenario"] = label
    return out[["scenario","M","k","avg_jaccard","avg_intersection","avg_union"]]


# ============================= Load named sets ===============================
# Baselines (required)
BASE_TOKENS = ["seed0","seed7","seed42","seed123","seed1234"]
for t in BASE_TOKENS:
    if t not in INPUTS:
        raise KeyError(f"Missing INPUTS['{t}']")
sets_baseline = [(t, to_id_set(INPUTS[t])) for t in BASE_TOKENS]

# Variants (required for respective scenarios)
V_SEED3_BS512_DROP02 = ("seed3_bs512_drp0.2_wd1e-3", to_id_set(INPUTS["diff_bs_drop"]))
V_DROP25_WD3E3       = ("seed42_drp25_wd3e-3",       to_id_set(INPUTS["diff_drop_wd"]))
V_ARCH               = ("arch",                      to_id_set(INPUTS["diff_arch"]))
V_TL                 = ("tl",                        to_id_set(INPUTS["transfer"]))


# ============================= Scenarios =====================================
# 1) Identical (2–5 seeds)
scen_identical = compute_scenario("Identical (2–5 seeds)", sets_baseline, kmin=2, kmax=5)

# 2) +1 different (BS=512, Drop=0.2)
scen_bs_drop = compute_scenario("+1 different (BS=512, DRP=20%)", sets_baseline + [V_SEED3_BS512_DROP02])

# 3) +1 different (Drop=0.25, WD=3e-3)
scen_drop_wd = compute_scenario("+1 different (DRP=25%, WD=3e-3)", sets_baseline + [V_DROP25_WD3E3])

# 4) +1 different (Arch)
scen_arch = compute_scenario("+1 different (Arch)", sets_baseline + [V_ARCH])

# 5) +1 different (TL)
scen_tl = compute_scenario("+1 different (TL)", sets_baseline + [V_TL])

# 6) +2 different
two_variant_configs = [
    sets_baseline + [V_SEED3_BS512_DROP02, V_DROP25_WD3E3],
    sets_baseline + [V_SEED3_BS512_DROP02, V_ARCH],
    sets_baseline + [V_SEED3_BS512_DROP02, V_TL],
    sets_baseline + [V_DROP25_WD3E3,       V_ARCH],
    sets_baseline + [V_DROP25_WD3E3,       V_TL],
    sets_baseline + [V_ARCH,               V_TL],
]
scen_plus2 = aggregate_over_configs("+2 different", two_variant_configs)

# 7) +3 different
three_variant_configs = [
    sets_baseline + [V_SEED3_BS512_DROP02, V_DROP25_WD3E3, V_ARCH],
    sets_baseline + [V_SEED3_BS512_DROP02, V_DROP25_WD3E3, V_TL],
    sets_baseline + [V_SEED3_BS512_DROP02, V_ARCH,         V_TL],
    sets_baseline + [V_DROP25_WD3E3,       V_ARCH,         V_TL],
]
scen_plus3 = aggregate_over_configs("+3 different", three_variant_configs)

all_results = pd.concat(
    [scen_identical, scen_bs_drop, scen_drop_wd, scen_arch, scen_tl, scen_plus2, scen_plus3],
    ignore_index=True
)

# Save table
all_results.to_csv(os.path.join(OUT_TAB_DIR, "stability_7scenarios.csv"), index=False, float_format="%.4f")

# Compute dynamic y-limits after we have the results
DYNAMIC_YLIMS = compute_dynamic_ylims(all_results)


# ---------------------------- Styling ---------------------------------------
sns.set_theme(context="paper", style="whitegrid", font_scale=1.05)
palette = sns.color_palette("colorblind", 8)
COL = {
    "Identical (2–5 seeds)":                 palette[0],
    "+1 different (BS=512, DRP=20%)":       palette[4],
    "+1 different (DRP=25%, WD=3e-3)":     palette[1],
    "+1 different (Arch)":                   palette[6],
    "+1 different (TL)":                     palette[2],
    "+2 different":                          palette[3],
    "+3 different":                          palette[5],
}
STYLE = {
    "Identical (2–5 seeds)": "solid",
    "+1 different (BS=512, DRP=20%)": (0, (7,2)),
    "+1 different (DRP=25%, WD=3e-3)": (0, (5,2)),
    "+1 different (Arch)": (0, (3,1,1,1)),
    "+1 different (TL)": (0, (1,1)),
    "+2 different": (0, (3,1,1,1)),
    "+3 different": (0, (9,2,1,2)),
}
MARK, LW, MS = "o", 2.3, 6.0
ORDER_ALL = [
    "Identical (2–5 seeds)",
    "+1 different (BS=512, DRP=20%)",
    "+1 different (DRP=25%, WD=3e-3)",
    "+1 different (Arch)",
    "+1 different (TL)",
    "+2 different",
    "+3 different",
]
ORDER = [k for k in ORDER_ALL if (all_results["scenario"] == k).any()]

plt.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})


def _endpoint(ax, x, y, txt, color, dy=0):
    ax.annotate(txt, xy=(x, y), xytext=(4, dy), textcoords="offset points",
                fontsize=BASE_FONTSIZE-1, color=color, va="center", ha="left")

def draw_lines(ax, df, ykey, ylabel, pct=False):
    for scen in ORDER:
        sub = df[df["scenario"] == scen].sort_values("k")
        if sub.empty:
            continue
        ax.plot(sub["k"], sub[ykey],
                color=COL[scen], linestyle=STYLE[scen],
                marker=MARK, ms=MS, lw=LW, label=scen)
        x_end, y_end = sub["k"].iloc[-1], sub[ykey].iloc[-1]
        txt = f"{y_end*100:.1f}%" if pct else f"{y_end:.0f}"
        _endpoint(ax, x_end, y_end, txt, COL[scen])
    ax.set_xlabel(r"$k$")
    ax.set_ylabel(ylabel)
    ax.set_ylim(*DYNAMIC_YLIMS[ykey])
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.grid(True, linestyle=":", linewidth=0.8, alpha=0.6)
    ax.minorticks_on()
    ax.grid(which="minor", axis="y", linestyle=":", alpha=0.25)
    if pct:
        ax.yaxis.set_major_formatter(PercentFormatter(1.0))
        ax.axhline(0.10, color="grey", lw=1, ls=":", alpha=0.45)
        ax.axhline(0.05, color="grey", lw=1, ls=":", alpha=0.45)


# ---------- export panels (PDF only) ----------
def save_panel(path_pdf, ykey, ylabel, pct=False):
    """Save a single panel as PDF only."""
    fig, ax = plt.subplots(figsize=(3.4, 2.4))
    draw_lines(ax, all_results, ykey, ylabel, pct=pct)
    if ax.get_legend(): ax.get_legend().remove()
    fig.tight_layout()
    fig.savefig(path_pdf, bbox_inches="tight")
    plt.close(fig)

def save_legend(path_pdf):
    """Save standalone legend as PDF only."""
    fig_tmp, ax_tmp = plt.subplots()
    draw_lines(ax_tmp, all_results, "avg_jaccard", "Avg. Jaccard", pct=True)
    handles, labels = ax_tmp.get_legend_handles_labels()
    plt.close(fig_tmp)
    
    fig_leg = plt.figure(figsize=(5.0, 0.45))
    fig_leg.legend(handles, labels, loc="center", ncol=3, frameon=True,
                   fontsize=BASE_FONTSIZE-0.5, columnspacing=1.0, handletextpad=0.6)
    fig_leg.savefig(path_pdf, bbox_inches="tight", pad_inches=0.02)
    plt.close(fig_leg)


# ---------------------------- Make files -------------------------------------
save_panel(os.path.join(OUT_FIG_DIR, "jaccard_noleg.pdf"), "avg_jaccard", "Avg. Jaccard", pct=True)
save_panel(os.path.join(OUT_FIG_DIR, "intersection.pdf"), "avg_intersection", "Avg. intersection")
save_panel(os.path.join(OUT_FIG_DIR, "union.pdf"), "avg_union", "Avg. union")
save_legend(os.path.join(OUT_FIG_DIR, "legend.pdf"))

print("✓ Generated 4 essential figures:")
print(f"  {OUT_FIG_DIR}/jaccard_noleg.pdf")
print(f"  {OUT_FIG_DIR}/intersection.pdf")
print(f"  {OUT_FIG_DIR}/union.pdf")
print(f"  {OUT_FIG_DIR}/legend.pdf")
print(f"✓ Table: {OUT_TAB_DIR}/stability_7scenarios.csv")

✓ Generated 4 essential figures:
  figures/jaccard_noleg.pdf
  figures/intersection.pdf
  figures/union.pdf
  figures/legend.pdf
✓ Table: tables/stability_7scenarios.csv


In [12]:
# ============================= TP>=x @ FP=0 — compact heatmaps (Identical only) =============================
# Thresholds to report
TP_THRESHOLDS = [1, 2, 3, 4, 5, 10, 20, 64]

def compute_identical_over_thresholds(tp_values):
    """Return long DF over thresholds for the Identical (2–5 seeds) case only."""
    rows = []
    for tpx in tp_values:
        sets_baseline_x = [(name, to_id_set(INPUTS[name], tp_threshold=tpx)) for name,_ in sets_baseline]
        df = compute_scenario("Identical (2–5 seeds)", sets_baseline_x, kmin=2, kmax=5)
        df["tp_threshold"] = tpx
        rows.append(df)
    out = pd.concat(rows, ignore_index=True)
    return out

ident_only = compute_identical_over_thresholds(TP_THRESHOLDS)

# Save a tidy table with the exact values
out_tbl = ident_only[["tp_threshold","k","avg_jaccard","avg_intersection","avg_union"]].copy()
out_tbl.to_csv(os.path.join(OUT_TAB_DIR, "tpgeq_x_fp0_identical.csv"), index=False, float_format="%.6f")

# ---------- Publication-ready 3-panel heatmap (unified colors, single unlabeled cbar) ----------

def _fmt_percent(v): 
    return "" if pd.isna(v) else f"{v*100:.0f}%"

def _fmt_int(v):
    return "" if pd.isna(v) else f"{int(round(v))}"

def _pivot(df, value):
    p = df.pivot(index="k", columns="tp_threshold", values=value).sort_index()
    return p.reindex(columns=TP_THRESHOLDS)

# Pivot data
p_j = _pivot(ident_only, "avg_jaccard")
p_i = _pivot(ident_only, "avg_intersection")
p_u = _pivot(ident_only, "avg_union")

# Panel-specific ranges (keep correct normalization), but unify the colormap
CMAP = "viridis"
lims_j = (0, 1)
lims_i = _nice_bounds_numeric(np.nanmin(p_i.values), np.nanmax(p_i.values))
lims_u = _nice_bounds_numeric(np.nanmin(p_u.values), np.nanmax(p_u.values))

fig, axes = plt.subplots(1, 3, figsize=(7.2, 2.6))
plt.subplots_adjust(wspace=0.08, left=0.08, right=0.96, top=0.88, bottom=0.22)

# Panel A – Avg. Jaccard
sns.heatmap(
    p_j, ax=axes[0], cmap=CMAP, vmin=lims_j[0], vmax=lims_j[1],
    annot=p_j.applymap(_fmt_percent), fmt="", annot_kws={"fontsize": BASE_FONTSIZE-1},
    cbar=False, linewidths=0.5, linecolor="white"
)
axes[0].set_title("Avg. Jaccard (Identical)", fontsize=BASE_FONTSIZE+0.3)
axes[0].set_xlabel("")  # no xlabel on left panel
axes[0].set_ylabel("Runs combined (k)")
axes[0].tick_params(axis="both", labelsize=BASE_FONTSIZE-1)

# Panel B – Avg. Intersection
sns.heatmap(
    p_i, ax=axes[1], cmap=CMAP, vmin=lims_i[0], vmax=lims_i[1],
    annot=p_i.applymap(_fmt_int), fmt="", annot_kws={"fontsize": BASE_FONTSIZE-1},
    cbar=False, linewidths=0.5, linecolor="white"
)
axes[1].set_title("Avg. Intersection", fontsize=BASE_FONTSIZE+0.3)
axes[1].set_xlabel("TP ≥ x  at  FP = 0")  # xlabel only under the middle panel
axes[1].set_ylabel("")
axes[1].tick_params(axis="both", labelsize=BASE_FONTSIZE-1)

# Panel C – Avg. Union (draw the single shared, unlabeled colorbar here)
hm = sns.heatmap(
    p_u, ax=axes[2], cmap=CMAP, vmin=lims_u[0], vmax=lims_u[1],
    annot=p_u.applymap(_fmt_int), fmt="", annot_kws={"fontsize": BASE_FONTSIZE-1},
    cbar=True, linewidths=0.5, linecolor="white",
    cbar_kws={"orientation": "vertical", "fraction": 0.046, "pad": 0.02}
)
axes[2].set_title("Avg. Union", fontsize=BASE_FONTSIZE+0.3)
axes[2].set_xlabel("")  # no xlabel on right panel
axes[2].set_ylabel("")
axes[2].tick_params(axis="both", labelsize=BASE_FONTSIZE-1)

# Make the colorbar unlabeled and tickless (unified look, no labels)
cb = hm.collections[0].colorbar
cb.set_label("")                 # no label text
cb.set_ticks([])                 # remove tick labels
cb.outline.set_linewidth(0.6)    # keep a subtle outline

# Align tick labels across all panels
for ax in axes:
    ax.set_xticklabels([str(t) for t in TP_THRESHOLDS], rotation=0)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)

out_heatmap = os.path.join(OUT_FIG_DIR, "tpgeq_x_fp0_identical_heatmaps_unified.pdf")
fig.savefig(out_heatmap, bbox_inches="tight")
plt.close(fig)
print("✓ Unified-color, single unlabeled colorbar figure:", out_heatmap)


  annot=p_j.applymap(_fmt_percent), fmt="", annot_kws={"fontsize": BASE_FONTSIZE-1},
  annot=p_i.applymap(_fmt_int), fmt="", annot_kws={"fontsize": BASE_FONTSIZE-1},
  annot=p_u.applymap(_fmt_int), fmt="", annot_kws={"fontsize": BASE_FONTSIZE-1},


✓ Unified-color, single unlabeled colorbar figure: figures\tpgeq_x_fp0_identical_heatmaps_unified.pdf
