In [None]:
"""
Threshold distribution analysis script (Fig. 1 in the paper).

This script creates box plots showing threshold distributions across
leave-one-out target models for the LiRA attack.

Author: Najeeb Jebreel, optimized by Claude Sonnet 4.5
Date: 2025

Reproducibility notes:
- The left box aggregates per-target thresholds from a single run (M LOO targets).
- The right box pools thresholds across five independent runs (5×M targets).
- This directly corresponds to **Fig. 1**: dispersion (spread), central tendency (median),
  and robustness (rMAD) of the *online LiRA* decision threshold at a fixed target FPR.
- **Filtering exactly matches Fig. 1:** `mode == "target"`,
  `attack == "LiRA (online)"`, `target_fpr == 1e-5`, and `prior == 0.01`.
"""

# --- Headless-safe backend ---
# Use a non-interactive backend to make plotting work in servers/CI (no display needed).
import matplotlib
matplotlib.use("Agg")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# ------------------ Okabe–Ito palette (colorblind-safe) ------------------
# Hex color constants shared by the plotting code below, kept consistent with
# the paper's early figures:
# - OI_BLUE   : single-run box (edge, whiskers/caps, translucent fill)
# - OI_ORANGE : pooled-runs box (edge, whiskers/caps, translucent fill)
# - OI_VERMIL : median line of every box
# - OI_GREY   : subtle accents (outlier markers)
OI_BLUE   = "#0072B2"
OI_ORANGE = "#E69F00"
OI_VERMIL = "#D55E00"
OI_GREY   = "#888888"

# ------------------ Styling ------------------
def _style_for_papers():
    """Install the global Matplotlib rcParams used for the paper figures.

    The settings target compact, camera-ready vector PDFs:
      - font sizes meant to stay readable at ~85–100 mm figure widths;
      - top/right spines switched off for a cleaner look;
      - subtle boxplot defaults (small grey outlier markers, vermilion median).

    The change is global: ``plt.rcParams`` is updated in place and nothing is
    returned or restored afterwards.
    """
    resolution = {"figure.dpi": 300, "savefig.dpi": 300}

    typography = {
        "font.size": 9,
        "axes.labelsize": 9,
        "axes.titlesize": 9,
        "xtick.labelsize": 8,
        "ytick.labelsize": 8,
        # Type 42 embedding keeps text vector-friendly in PDF/PS output.
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
    }

    spines = {"axes.spines.top": False, "axes.spines.right": False}

    boxplot_defaults = {
        "boxplot.flierprops.marker": "o",
        "boxplot.flierprops.markersize": 2.0,
        "boxplot.flierprops.markerfacecolor": OI_GREY,
        "boxplot.flierprops.markeredgecolor": OI_GREY,
        "boxplot.medianprops.color": OI_VERMIL,
        "boxplot.medianprops.linewidth": 1.2,
        "boxplot.whiskerprops.linewidth": 0.9,
        "boxplot.capprops.linewidth": 0.9,
    }

    plt.rcParams.update({**resolution, **typography, **spines, **boxplot_defaults})

# ------------------ Data loading ------------------
def _load_thresholds(
    csv_path,
    target_fpr,
    attack_equals="LiRA (online)",
    prior_equals=0.01,
    mode_equals="target",
):
    """
    Load the thresholds from one CSV, keeping only rows that match every filter.

    Row filters (string comparisons ignore case and surrounding whitespace;
    float comparisons use ``np.isclose`` with rtol=1e-6, atol=1e-12):
      - mode       == ``mode_equals``    (default "target")
      - attack     == ``attack_equals``  (default "LiRA (online)")
      - target_fpr == ``target_fpr``
      - prior      == ``prior_equals``   (default 0.01)

    Parameters:
      csv_path:      path (or anything ``pd.read_csv`` accepts) of the metrics CSV.
      target_fpr:    target false-positive rate to select.
      attack_equals: attack name to select.
      prior_equals:  prior value to select.
      mode_equals:   mode name to select.

    Returns:
      1D float array with the finite values of the ``threshold`` column for the
      selected rows (empty when nothing matches).

    Raises:
      KeyError: if any of the columns ``mode``, ``attack``, ``target_fpr``,
        ``prior`` or ``threshold`` is absent. Previously a missing
        ``target_fpr``/``prior`` column silently produced an empty result,
        while the other columns raised; all five now fail the same way.
    """
    df = pd.read_csv(csv_path)

    required = ("mode", "attack", "target_fpr", "prior", "threshold")
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(
            f"{csv_path}: missing required column(s): {', '.join(missing)}"
        )

    def _matches_text(column, wanted):
        # Normalise both sides identically so stray whitespace or letter case
        # in either the CSV or the argument cannot break the match.
        normalised = df[column].astype(str).str.strip().str.casefold()
        return normalised == str(wanted).strip().casefold()

    def _matches_float(column, wanted):
        # Unparseable cells become NaN, which np.isclose never matches.
        numeric = pd.to_numeric(df[column], errors="coerce").to_numpy(dtype=float)
        return np.isclose(numeric, float(wanted), rtol=1e-6, atol=1e-12)

    selected = (
        _matches_text("mode", mode_equals)
        & _matches_text("attack", attack_equals)
        & _matches_float("target_fpr", target_fpr)
        & _matches_float("prior", prior_equals)
    )

    # Force float dtype so the finiteness check is well defined even when the
    # selection is empty or the column was read as integers/objects.
    vals = pd.to_numeric(df.loc[selected, "threshold"], errors="coerce").to_numpy(dtype=float)
    return vals[np.isfinite(vals)]

def _load_thresholds_many(
    csv_paths,
    target_fpr,
    attack_equals="LiRA (online)", # match Fig. 1 
    prior_equals=0.01,     # either 0.01, 0.1 or 0.5; all work similarly
    mode_equals="target", # per-model thresholds
):
    """
    Pool the filtered thresholds of several CSV files into one 1D array.

    Each path in ``csv_paths`` is loaded with ``_load_thresholds`` using the
    same filter arguments; files that contribute no thresholds are skipped.
    The remaining arrays are concatenated in the order of ``csv_paths``.

    Returns an empty array when no file yields any threshold (including when
    ``csv_paths`` itself is empty).
    """
    filters = dict(
        target_fpr=target_fpr,
        attack_equals=attack_equals,
        prior_equals=prior_equals,
        mode_equals=mode_equals,
    )
    per_file = (_load_thresholds(path, **filters) for path in csv_paths)
    non_empty = [chunk for chunk in per_file if chunk.size]
    if not non_empty:
        return np.array([])
    return np.concatenate(non_empty)


# ------------------ Robust dispersion (median-centric) ------------------
def _median_rmad(vals):
    """
    Return ``(median, rMAD)`` for a collection of values.

    Robust counterpart to (mean, std): the median plus the scaled median
    absolute deviation expressed relative to the median's magnitude:

        rMAD (%) = 100 * 1.4826 * MAD / |median|

    Dividing by ``|median|`` (rather than the signed median) keeps the
    dispersion non-negative when the median is negative; for positive medians
    the result is unchanged.

    Parameters:
      vals: array-like of numbers (lists are accepted as well as arrays).

    Returns:
      Tuple ``(med, rmad)`` of floats. ``rmad`` is NaN when the median is
      exactly 0; both are NaN when ``vals`` is empty.
    """
    arr = np.asarray(vals, dtype=float)
    if arr.size == 0:
        # Avoid NumPy's "mean of empty slice" warnings; the result is NaN anyway.
        return float("nan"), float("nan")

    med = float(np.median(arr))
    mad = float(np.median(np.abs(arr - med)))
    rmad = 100.0 * 1.4826 * mad / abs(med) if med != 0 else np.nan
    return med, rmad

# ------------------ Main: side-by-side BOX with Q3-aligned labels ------------------
def plot_thresholds_box_q3labels(
    single_csv,
    pooled_csvs,       # list[str]: CSVs to pool for the right box
    out_path,
    target_fpr=1e-5,
    labels=(r"(a) Single run ($M{=}256$)", r"(b) Five runs ($5{\times}256$)"),
    show_fliers=False,
    whisker_mode="tukey"  # "tukey" (1.5*IQR) or "p05p95" (5–95% whiskers)
):
    """
    Draw two side-by-side box plots of thresholds and save the figure.

    The left box uses the thresholds loaded from ``single_csv``; the right box
    pools the thresholds loaded from every file in ``pooled_csvs``. Each box is
    annotated just above its upper quartile (Q3) with the median and rMAD (%).

    Parameters:
      single_csv:   CSV path for the left (single-run) box.
      pooled_csvs:  iterable of CSV paths pooled into the right box.
      out_path:     output file; parent directories are created if needed.
      target_fpr:   target FPR used to filter rows in the CSVs.
      labels:       the two x-tick labels (left, right).
      show_fliers:  whether outlier markers are drawn.
      whisker_mode: "tukey" for 1.5×IQR whiskers, or "p05p95" for whiskers at
                    the empirical 5th/95th percentiles (case-insensitive).

    Returns:
      ``str(out_path)`` of the saved figure.

    Raises:
      ValueError: if ``whisker_mode`` is not one of the two supported modes,
        or if either input yields no thresholds for the filters.

    Side effects:
      Updates global rcParams via ``_style_for_papers()`` and writes the file.
    """
    # Validate up front: a mistyped mode used to fall back silently to Tukey
    # whiskers, which is easy to miss in a published figure.
    whisker_options = {"tukey": 1.5, "p05p95": (5, 95)}
    mode_key = str(whisker_mode).strip().casefold()
    if mode_key not in whisker_options:
        raise ValueError(
            f"Unknown whisker_mode {whisker_mode!r}; expected 'tukey' or 'p05p95'."
        )
    whis = whisker_options[mode_key]

    # 1) Load thresholds for a *single* run (left) and *pooled* runs (right).
    vals_single = _load_thresholds(single_csv, target_fpr)
    vals_pooled = _load_thresholds_many(pooled_csvs, target_fpr)
    if vals_single.size == 0:
        raise ValueError("No thresholds in single_csv for the given filters.")
    if vals_pooled.size == 0:
        raise ValueError("No thresholds found across pooled_csvs for the given filters.")

    _style_for_papers()

    # 2) Create a compact side-by-side box plot.
    fig, ax = plt.subplots(figsize=(3.6, 2.2))
    try:
        # Adjust this figure explicitly rather than pyplot's "current" figure.
        fig.subplots_adjust(left=0.12, right=0.99, top=0.95, bottom=0.24)

        data = [vals_single, vals_pooled]
        # Vertical boxes are Matplotlib's default; `vert=True` is not passed
        # because newer Matplotlib releases deprecate it in favour of `orientation`.
        bp = ax.boxplot(
            data,
            patch_artist=True,   # enable facecolor fill
            widths=0.55,
            whis=whis,
            showfliers=show_fliers,
        )

        # Colors: BLUE (single), ORANGE (pooled). Translucent fill keeps medians visible.
        edgecolors = [OI_BLUE, OI_ORANGE]
        facecolors = [OI_BLUE + "33", OI_ORANGE + "33"]  # hex alpha suffix '33' (~20% opacity)
        for patch, fc, ec in zip(bp['boxes'], facecolors, edgecolors):
            patch.set_facecolor(fc)
            patch.set_edgecolor(ec)
            patch.set_linewidth(1.0)
        for med_line in bp['medians']:
            med_line.set_color(OI_VERMIL)
            med_line.set_linewidth(1.2)
        for part in ("whiskers", "caps"):
            # Each box contributes two whisker lines and two cap lines, in box order.
            for i, line in enumerate(bp[part]):
                line.set_color(edgecolors[i // 2])
                line.set_linewidth(0.9)

        # Tick labels carry the (a)/(b) panel markers for cross-reference in the paper.
        ax.set_xticks([1, 2])
        ax.set_xticklabels(labels, fontsize=8.5)
        ax.set_ylabel(r"Threshold $\tau$")

        # ---- Q3-aligned annotations (placed just above Q3) ----
        for xpos, vals, color in zip([1, 2], data, edgecolors):
            med, rmad = _median_rmad(vals)
            q3 = float(np.percentile(vals, 75))
            ax.annotate(
                f"Median = {med:.2f}\n"
                f"rMAD = {rmad:.1f}%",
                xy=(xpos, q3),                    # anchor at top of box
                xytext=(5, 12),                   # shift right & upward (points)
                textcoords="offset points",
                ha="left", va="bottom",           # position above box
                fontsize=6.5, color="#1A2732",
                bbox=dict(facecolor="white", edgecolor="none", alpha=0.65, pad=0.4),
                clip_on=False
            )

        # 3) Save (transparent background enables flexible placement in LaTeX).
        out_path = Path(out_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_path, bbox_inches="tight", transparent=True)

        # On a non-interactive backend (this file selects "Agg") plt.show()
        # cannot display anything and only emits a warning, so skip it there.
        non_interactive = {"agg", "cairo", "pdf", "pgf", "ps", "svg", "template"}
        if str(matplotlib.get_backend()).lower() not in non_interactive:
            plt.show()
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    return str(out_path)


In [None]:
# Input CSVs with the thresholds, followed by the call that renders the figure.
# The CSV files are generated by post_analysis scripts in analysis_results/.

# Left box: thresholds of one run.
single_path = "cifar10/resnet18/2025-10-05_1535/per_model_metrics_two_modes.csv"

# Right box: thresholds pooled over these runs (the single run above is one of them).
pooled_paths = [
    "cifar10/resnet18/2025-10-19_1230/per_model_metrics_two_modes.csv",
    "cifar10/resnet18/2025-10-05_1535/per_model_metrics_two_modes.csv",
    "cifar10/resnet18/2025-10-07_1640/per_model_metrics_two_modes.csv",
    "cifar10/resnet18/2025-10-09_1745/per_model_metrics_two_modes.csv",
    "cifar10/resnet18/2025-10-15_1850/per_model_metrics_two_modes.csv",
]

# Kept as the cell's last expression so the notebook echoes the saved path.
plot_thresholds_box_q3labels(
    single_path,
    pooled_paths,
    "thresholds/thresh_sidebyside_box_q3labels.pdf",
    target_fpr=1e-5,       # matches the CSV 'target_fpr' used for Fig. 1
    show_fliers=False,
    whisker_mode="tukey",
)

  plt.show()


'thresholds\\thresh_sidebyside_box_q3labels.pdf'