In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import yaml
from pathlib import Path

import torch
import torchvision.utils as vutils

import os
from torch.utils.data import ConcatDataset
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder

from sklearn.metrics import roc_curve, roc_auc_score

from math import sqrt
from scipy.stats import beta
import re

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

In [3]:
def load_dataset(config):
    dataset_name = config['dataset']['name'].lower()
    data_dir = "D:/mona/mia_research/data"
    
    if dataset_name == 'cifar10':
        # Directly apply transform during loading for inference
        train_dataset = CIFAR10(root=data_dir, train=True, download=False)
        test_dataset = CIFAR10(root=data_dir, train=False, download=False)
        full_dataset = ConcatDataset([train_dataset, test_dataset])
        train_label = np.array(train_dataset.targets)
        test_label = np.array(test_dataset.targets)
        full_label = np.concatenate((train_label, test_label), axis=0)
    
    elif dataset_name == 'cifar100':
        train_dataset = CIFAR100(root=data_dir, train=True, download=False)
        test_dataset = CIFAR100(root=data_dir, train=False, download=False)
        full_dataset = ConcatDataset([train_dataset, test_dataset])
        train_label = np.array(train_dataset.targets)
        test_label = np.array(test_dataset.targets)
        full_label = np.concatenate((train_label, test_label), axis=0)
           
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    return full_dataset, full_label


In [4]:
# TWO-MODE EVAL: target-based & shadow-based
# - Recompute τ directly from scores (no CSV rounding)
# - Use scikit-learn ROC semantics: score >= τ
# - "Global threshold" attack: only target-mode baseline is produced

# -------- CONFIG --------
exp_path = Path("experiments/cifar10/wrn28-2/weak_rotate_jitter_cutmix_drop0.1_wd1e-3")
TARGET_FPRS = [0.00001, 0.001]   # 0.001% and 0.1%
PRIORS = [0.01, 0.1, 0.5]
DO_SANITY_CHECKS = True          # set False to skip exact ROC point checks

SCORE_FILES = {
    "LiRA (online)":               "online_scores_leave_one_out.npy",
    "LiRA (online, fixed var)":    "online_fixed_scores_leave_one_out.npy",
    "LiRA (offline)":              "offline_scores_leave_one_out.npy",
    "LiRA (offline, fixed var)":   "offline_fixed_scores_leave_one_out.npy",
    "Global threshold":            "global_scores_leave_one_out.npy",
}
LABELS_NPY = "membership_labels.npy"

# -------- OUTPUT DIR --------
def make_out_dir(src: Path) -> Path:
    parts = src.parts
    if len(parts) >= 4:
        dataset, model, configs = parts[-3], parts[-2], parts[-1]
    else:
        dataset, model, configs = "unknown_dataset", "unknown_model", src.name
    out = Path("analysis_results") / dataset / model / configs
    out.mkdir(parents=True, exist_ok=True)
    return out

out_dir = make_out_dir(exp_path)

# -------- LOAD --------
labels = np.load(exp_path / LABELS_NPY)  # [M, N], bool
labels = labels.astype(bool, copy=False)
scores = {name: np.load(exp_path / fname) for name, fname in SCORE_FILES.items()}
M, N = labels.shape
for name, arr in scores.items():
    if arr.shape != (M, N):
        raise ValueError(f"{name} shape {arr.shape} != {(M, N)}")

# -------- HELPERS --------
def compute_metrics_at_tau(scores_row: np.ndarray,
                          labels_row: np.ndarray,
                          tau: float):
    """
    Compute confusion matrix and rates at threshold tau.
    Returns tp, fp, tn, fn, tpr, fpr_achieved (NOT precision - computed later with prior).
    """
    pred = scores_row >= tau
    tp = int(np.sum(pred & labels_row))
    fp = int(np.sum(pred & ~labels_row))
    tn = int(np.sum(~pred & ~labels_row))
    fn = int(np.sum(~pred & labels_row))
    
    tpr = tp / (tp + fn) if (tp + fn) else 0.0
    fpr_achieved = fp / (fp + tn) if (fp + tn) else 0.0
    
    return tp, fp, tn, fn, tpr, fpr_achieved

def compute_precision(tpr: float, fpr_achieved: float, prior: float):
    """
    Compute precision given TPR, achieved FPR, and membership prior.
    Assumes TPR and FPR are representative from balanced evaluation.
    """
    ppos = tpr * prior + fpr_achieved * (1 - prior)
    if ppos > 0:
        return (tpr * prior) / ppos
    else:
        return np.nan  # Undefined when no positive predictions

def tau_for_target(scores_row: np.ndarray,
                   labels_row: np.ndarray,
                   tfpr: float):
    """
    Return (tau, fpr_at_tau, tpr_at_tau) where tau is the largest threshold such that
    FPR(score >= tau) <= tfpr. Uses all thresholds (drop_intermediate=False).
    """
    fpr, tpr, thr = roc_curve(labels_row.astype(bool),
                              scores_row,
                              drop_intermediate=False)
    idx = np.where(fpr <= tfpr)[0]
    if idx.size == 0:
        return np.inf, None, None
    j = idx[-1]
    return float(thr[j]), float(fpr[j]), float(tpr[j])

def median_others(arr: np.ndarray, exclude_idx: int) -> float:
    """Median of finite values in arr excluding arr[exclude_idx]. Returns inf if empty."""
    pool = np.delete(arr, exclude_idx)
    pool = pool[np.isfinite(pool)]
    return float(np.median(pool)) if pool.size else np.inf

# -------- EVAL (two modes) --------
rows_detail = []

for attack_name, arr in scores.items():
    # Precompute per-model AUC once (independent of τ / priors / mode)
    aucs = np.full(M, np.nan, dtype=float)
    for m in range(M):
        try:
            aucs[m] = roc_auc_score(labels[m].astype(int), arr[m])
        except ValueError:
            # happens if a model has only one class present; leave NaN
            pass

    for tfpr in TARGET_FPRS:
        # 1) Target τ per model from its own ROC
        taus_target = np.empty(M, dtype=np.float64)
        fprs_chk = np.empty(M, dtype=np.float64)
        tprs_chk = np.empty(M, dtype=np.float64)
        fprs_chk[:] = np.nan
        tprs_chk[:] = np.nan

        for m in range(M):
            tau_m, fpr_m, tpr_m = tau_for_target(arr[m], labels[m], tfpr)
            taus_target[m] = tau_m
            if fpr_m is not None:
                fprs_chk[m] = fpr_m
                tprs_chk[m] = tpr_m

        # 2) Shadow τ per target = median of other models' target τs
        taus_shadow = np.array([median_others(taus_target, m) for m in range(M)], dtype=np.float64)

        # 3) Sanity checks (optional)
        if DO_SANITY_CHECKS:
            finite_idxs = np.where(np.isfinite(taus_target))[0]
            for m in finite_idxs[:5]:
                tau_m = taus_target[m]
                _, _, _, _, tpr_re, fpr_re = compute_metrics_at_tau(arr[m], labels[m], tau_m)
                # allow tiny numerical noise
                if not (np.isclose(fpr_re, fprs_chk[m], atol=1e-12) and
                        np.isclose(tpr_re, tprs_chk[m], atol=1e-12)):
                    print(f"[SanityCheck WARN] {attack_name} model {m} @ {tfpr}: "
                          f"recomp (fpr={fpr_re}, tpr={tpr_re}) vs roc (fpr={fprs_chk[m]}, tpr={tprs_chk[m]})")

        # 4) Compute metrics for both modes
        # TARGET mode
        for m in range(M):
            tau_m = taus_target[m]
            if not np.isfinite(tau_m):
                tp=fp=0; tn=int(np.sum(~labels[m])); fn=int(np.sum(labels[m]))
                tpr=fpr_achieved=0.0
            else:
                tp, fp, tn, fn, tpr, fpr_achieved = compute_metrics_at_tau(arr[m], labels[m], tau_m)
            
            # Compute precision for each prior
            for prior in PRIORS:
                prec = compute_precision(tpr, fpr_achieved, prior)
                rows_detail.append(dict(
                    mode="target", 
                    attack=attack_name, 
                    target_fpr=tfpr,
                    achieved_fpr=fpr_achieved,  # FPR' - explicitly labeled
                    prior=prior,
                    model_idx=m, 
                    threshold=tau_m, 
                    tp=tp, fp=fp, tn=tn, fn=fn,
                    tpr=tpr,
                    precision=prec,
                    auc=aucs[m]
                ))

        # SHADOW mode (skip for the baseline "Global threshold" attack)
        if attack_name != "Global threshold":
            for m in range(M):
                tau_m = taus_shadow[m]
                if not np.isfinite(tau_m):
                    tp=fp=0; tn=int(np.sum(~labels[m])); fn=int(np.sum(labels[m]))
                    tpr=fpr_achieved=0.0
                else:
                    tp, fp, tn, fn, tpr, fpr_achieved = compute_metrics_at_tau(arr[m], labels[m], tau_m)
                
                # Compute precision for each prior
                for prior in PRIORS:
                    prec = compute_precision(tpr, fpr_achieved, prior)
                    rows_detail.append(dict(
                        mode="shadow",
                        attack=attack_name,
                        target_fpr=tfpr,
                        achieved_fpr=fpr_achieved,  # FPR' - explicitly labeled
                        prior=prior,
                        model_idx=m,
                        threshold=tau_m,
                        tp=tp, fp=fp, tn=tn, fn=fn,
                        tpr=tpr,
                        precision=prec,
                        auc=aucs[m]
                    ))

# -------- SAVE DETAIL + SUMMARY --------
detail_df = pd.DataFrame(rows_detail)
detail_path = out_dir / "per_model_metrics_two_modes.csv"
detail_df.to_csv(detail_path, index=False)

# Summary aggregation
summary = (detail_df
    .groupby(["mode","attack","target_fpr","prior"], as_index=False)
    .agg(TPR_Mean=("tpr","mean"), 
         TPR_Std=("tpr","std"),
         FPR_Achieved_Mean=("achieved_fpr","mean"),  # Changed: explicitly "achieved"
         FPR_Achieved_Std=("achieved_fpr","std"),
         Precision_Mean=("precision","mean"),  # Uses nanmean implicitly
         Precision_Std=("precision","std"),    # Uses nanstd implicitly
         AUC_Mean=("auc","mean"), 
         AUC_Std=("auc","std"))
)

# Express rates & AUC as %
for c in ["TPR_Mean","TPR_Std","FPR_Achieved_Mean","FPR_Achieved_Std",
          "Precision_Mean","Precision_Std","AUC_Mean","AUC_Std"]:
    summary[c] = (summary[c]*100).round(3)

summary["Target FPR (%)"] = (summary["target_fpr"]*100).round(4)
summary = summary.drop(columns=["target_fpr"])

# Reorder columns for clarity matching paper tables
summary = summary[["mode","attack","Target FPR (%)","prior",
                   "TPR_Mean","TPR_Std",
                   "FPR_Achieved_Mean","FPR_Achieved_Std",
                   "Precision_Mean","Precision_Std",
                   "AUC_Mean","AUC_Std"]]

summary_path = out_dir / "summary_statistics_two_modes.csv"
summary.to_csv(summary_path, index=False)

print(f"Saved:\n - {detail_path}\n - {summary_path}")

Saved:
 - analysis_results\cifar10\wrn28-2\weak_rotate_jitter_cutmix_drop0.1_wd1e-3\per_model_metrics_two_modes.csv
 - analysis_results\cifar10\wrn28-2\weak_rotate_jitter_cutmix_drop0.1_wd1e-3\summary_statistics_two_modes.csv


In [5]:
# -------------------------
# PER-SAMPLE VULNERABILITY (ONLINE @ 0.001% FPR, using SHADOW τ)
# Computes TP, FP, TN, FN for each sample across leave-one-out models
# -------------------------

ATTACK_ONLINE = "LiRA (online)"
FPR_001pct = 1e-5

# Safety checks
assert 'detail_df' in globals(), "detail_df not found; run the two-mode cell first."
assert 'scores' in globals() and ATTACK_ONLINE in scores, "scores for LiRA (online) not found."
assert 'labels' in globals(), "labels not found."

labels_bool = labels.astype(bool, copy=False)

# 1) Get shadow thresholds per model
mask = (
    (detail_df["mode"] == "shadow") &
    (detail_df["attack"] == ATTACK_ONLINE) &
    np.isclose(detail_df["target_fpr"].to_numpy(), FPR_001pct, atol=1e-12)
)
shadow_rows = detail_df.loc[mask, ["model_idx", "threshold"]].drop_duplicates(subset=["model_idx"])

if shadow_rows.empty:
    raise RuntimeError("No shadow thresholds in detail_df for LiRA (online) @ 0.001% FPR.")

M_detected, N_detected = scores[ATTACK_ONLINE].shape
taus_shadow = np.full(M_detected, np.inf, dtype=float)
for _, r in shadow_rows.iterrows():
    m = int(r["model_idx"])
    if 0 <= m < M_detected:
        taus_shadow[m] = float(r["threshold"])

# 2) Predict per model with shadow thresholds
scores_online = scores[ATTACK_ONLINE]
pred = (scores_online >= taus_shadow[:, None])  # [M, N] boolean predictions

# 3) Compute confusion matrix per sample
TP = np.sum(pred & labels_bool, axis=0).astype(int)
FP = np.sum(pred & ~labels_bool, axis=0).astype(int)
TN = np.sum(~pred & ~labels_bool, axis=0).astype(int)
FN = np.sum(~pred & labels_bool, axis=0).astype(int)

# 4) Create dataframe with confusion matrix statistics
sample_df = pd.DataFrame({
    "sample_id": np.arange(N_detected),
    "tp": TP,
    "fp": FP,
    "tn": TN,
    "fn": FN,
})

# 5) Sort by vulnerability: prioritize low FP (stable non-false-positives), then high TP (high detection)
# This gives samples that are reliably detected when members, rarely flagged when non-members
vulnerability_sorted = sample_df.sort_values(by=["fp", "tp"], ascending=[True, False])
vuln_path = out_dir / "samples_vulnerability_ranked_online_shadow_0p001pct.csv"
vulnerability_sorted.to_csv(vuln_path, index=False)

# 6) Most vulnerable: FP=0 (never false alarm) AND TP>0 (detected when member)
highly_vulnerable = sample_df[(sample_df["fp"] == 0) & (sample_df["tp"] > 0)]
high_vuln_path = out_dir / "samples_highly_vulnerable_online_shadow_0p001pct.csv"
highly_vulnerable.to_csv(high_vuln_path, index=False)

print(f"Saved:\n - {vuln_path} ({len(sample_df)} samples)")
print(f" - {high_vuln_path} ({len(highly_vulnerable)} highly vulnerable samples)")
print(f"\nMost vulnerable sample: TP={sample_df.iloc[0]['tp']}, FP={sample_df.iloc[0]['fp']}")

Saved:
 - analysis_results\cifar10\wrn28-2\weak_rotate_jitter_cutmix_drop0.1_wd1e-3\samples_vulnerability_ranked_online_shadow_0p001pct.csv (60000 samples)
 - analysis_results\cifar10\wrn28-2\weak_rotate_jitter_cutmix_drop0.1_wd1e-3\samples_highly_vulnerable_online_shadow_0p001pct.csv (875 highly vulnerable samples)

Most vulnerable sample: TP=0, FP=0


In [6]:

def _to_chw_float_tensor(img):
    if isinstance(img, torch.Tensor):
        t = img.clone()
        if t.ndim == 2: t = t.unsqueeze(0)
        elif t.ndim == 3 and t.shape[0] not in (1,3): t = t.permute(2,0,1)
        t = t.float()
        if t.numel() and t.max() > 1.0: t = t / 255.0
    else:
        arr = np.array(img)
        t = torch.from_numpy(arr)
        if t.ndim == 2: t = t.unsqueeze(-1)
        if t.ndim == 3 and t.shape[-1] in (1,3): t = t.permute(2,0,1)
        t = t.float()
        if t.numel() and t.max() > 1.0: t = t / 255.0
    if t.shape[0] == 1: t = t.repeat(3,1,1)
    elif t.shape[0] > 3: t = t[:3]
    return t

def display_top_k_vulnerable_samples(
    vulnerable_samples, full_dataset,
    k: int = 20, nrow: int = 5,
    padding: int = 2, normalize: bool = True, dpi: int = 300,
    out_dir: "Path|str" = ".", save_name: str = "vulnerable_samples.png",
    sample_id_col: str = "sample_id",
    font_size: int = 8,
    badge_margin: int = 2,      # distance from the tile corner
    overhang_left: int = 6,     # extra pixels LEFT (can place into padding/outside)
    overhang_up: int = 6,       # extra pixels UP
):
    """
    Grid of top-k samples with a 'TP:.. FP:..' badge in the TOP-LEFT.
    'overhang_left'/'overhang_up' move the badge further into the corner.
    """
    import pandas as pd
    if not isinstance(vulnerable_samples, pd.DataFrame):
        raise TypeError("vulnerable_samples must be a pandas DataFrame")
    for col in (sample_id_col, "tp", "fp"):
        if col not in vulnerable_samples.columns:
            raise KeyError(f"Column '{col}' missing from vulnerable_samples")

    vs = vulnerable_samples.head(k).copy()
    ids = vs[sample_id_col].to_numpy()

    images = [_to_chw_float_tensor(full_dataset[int(sid)][0]) for sid in ids]
    tensor = torch.stack(images)  # [k, 3, H, W]

    grid = vutils.make_grid(tensor, nrow=nrow, padding=padding, normalize=normalize, pad_value=1.0)
    grid_np = grid.permute(1,2,0).cpu().numpy()

    rows = (k + nrow - 1) // nrow
    H, W = tensor.shape[-2], tensor.shape[-1]
    fig_w = max(4, nrow * 1.2)
    fig_h = max(4, rows * 1.2)

    fig, ax = plt.subplots(figsize=(fig_w, fig_h), dpi=dpi)
    ax.imshow(grid_np, aspect="equal")
    ax.axis("off")

    # grid math per torchvision.make_grid
    stride_x, stride_y = W + padding, H + padding
    base_x = base_y = padding

    # annotate top-left corner
    for i in range(len(ids)):
        r, c = divmod(i, nrow)
        x0 = base_x + c * stride_x    # tile's left edge
        y0 = base_y + r * stride_y    # tile's top edge

        tp_val = int(vs.iloc[i]["tp"])
        fp_val = int(vs.iloc[i]["fp"])
        text = f"TP:{tp_val}  FP:{fp_val}"

        # place slightly inside the tile, then overhang into the corner
        x_text = x0 + badge_margin - overhang_left
        y_text = y0 + badge_margin - overhang_up

        ax.text(
            x_text, y_text, text,
            ha="left", va="top",
            fontsize=font_size, fontweight="bold",
            bbox=dict(boxstyle="round,pad=0.25", facecolor="white", alpha=0.9,
                      edgecolor="black", linewidth=0.7),
            color="black",
            clip_on=False,   # allow overhang beyond axes/tile
        )

    plt.tight_layout(pad=0.1)
    out_dir = Path(out_dir); out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / save_name
    fig.savefig(out_path, bbox_inches="tight", dpi=dpi, facecolor="white")
    plt.close(fig)
    print(f"Saved grid: {out_path}")
    return out_path


In [7]:
# Load config and dataset
cfg_path = exp_path / "attack_config.yaml"
if not cfg_path.exists():
    raise FileNotFoundError(f"Config not found at {cfg_path}")

with open(cfg_path, "r") as f:
    config = yaml.safe_load(f)

full_dataset, _ = load_dataset(config)

# vulnerability_sorted should already be in memory from previous cell
# This fallback is just for safety if running cells out of order
if "vulnerability_sorted" not in globals():
    rank_csv = out_dir / "samples_vulnerability_ranked_online_shadow_0p001pct.csv"
    if not rank_csv.exists():
        raise FileNotFoundError(f"Ranking CSV not found at {rank_csv}. "
                                "Run the per-sample vulnerability cell first.")
    vulnerability_sorted = pd.read_csv(rank_csv)

print(f"Total samples ranked: {len(vulnerability_sorted)}")
print(f"Top 5 most vulnerable:")
print(vulnerability_sorted.head()[["sample_id", "tp", "fp", "tn", "fn"]])

# Visualize top 20 most vulnerable
display_top_k_vulnerable_samples(
    vulnerable_samples=vulnerability_sorted,
    full_dataset=full_dataset,
    k=20,
    nrow=5,
    out_dir=out_dir,
    save_name="top20_vulnerable_online_shadow_0p001pct.png",
    font_size=7,
    badge_margin=2,
    overhang_left=3,
    overhang_up=4,
)

Total samples ranked: 60000
Top 5 most vulnerable:
       sample_id  tp  fp   tn  fn
38479      38479  67   0  128  61
40350      40350  62   0  128  66
47417      47417  59   0  128  69
47063      47063  58   0  128  70
10361      10361  56   0  128  72
Saved grid: analysis_results\cifar10\wrn28-2\weak_rotate_jitter_cutmix_drop0.1_wd1e-3\top20_vulnerable_online_shadow_0p001pct.png


WindowsPath('analysis_results/cifar10/wrn28-2/weak_rotate_jitter_cutmix_drop0.1_wd1e-3/top20_vulnerable_online_shadow_0p001pct.png')

In [141]:
# ---------------------------------------------------------
# Build an aggregated LaTeX table across benchmarks:
# Balanced prior (0.5), target-based thresholds, show:
#   - TPR @ 0.001% FPR
#   - TPR @ 0.1% FPR
#   - AUC
# Add reduction factors (×...) vs baseline per attack & FPR
# ---------------------------------------------------------

# ===== USER CONFIG =====
# Map display name -> CSV path (first is baseline)
BENCHMARKS = {
    "Purchase-100 (baseline)": "analysis_results/purchase100/fcn/2025-05-29_1623_drp_0_wd_5e-5/summary_statistics_two_modes.csv",
    "Purchase-100 (reg.)":     "analysis_results/purchase100/fcn/2025-06-15_2326_drp_0.5_wd_1e-3/summary_statistics_two_modes.csv",
    # "CIFAR-100 (TL)":       "analysis_results/cifar100/efficientnetv2_rw_s/tl/summary_statistics_two_modes.csv",
}


# Which attacks (and order) to show
ATTACK_ORDER = [
    "LiRA (online)",
    "LiRA (online, fixed var)",
    "LiRA (offline)",
    "LiRA (offline, fixed var)",
    "Global threshold",
]

# Pretty names for the attack column
ATTACK_DISPLAY = {
    "LiRA (online)": "Online",
    "LiRA (online, fixed var)": "Online (fixed var)",
    "LiRA (offline)": "Offline",
    "LiRA (offline, fixed var)": "Offline (fixed var)",
    "Global threshold": "Global threshold",
}

PRIOR = 0.5             # balanced prior
MODE  = "target"        # target-based threshold only
ALPHA1 = 0.001          # TPR @ 0.001% FPR (as percent)
ALPHA2 = 0.1            # TPR @ 0.1% FPR (as percent)

# Table meta
CAPTION = ("CIFAR-100 with balanced prior and target-based threshold. "
           "TPR reduction factors relative to baseline shown in parentheses for  (TL).")
LABEL   = "tab:gtsrb-reg"
COLUMN_FORMAT = "lcccc"  # Benchmark, Attack, TPR@0.001, TPR@0.1, AUC
DIGITS_MAIN = 3          # mean ± std digits
DIGITS_MUL  = 1          # ×multiplier digits (>=10 as integer)

# ===== HELPERS =====
def fmt_mu_sd(mu, sd, d=DIGITS_MAIN):
    if pd.isna(mu) or pd.isna(sd):
        return "--"
    return f"{mu:.{d}f} $\\pm$ {sd:.{d}f}"

def fmt_multiplier(r, d=DIGITS_MUL):
    if r is None or np.isinf(r) or np.isnan(r):
        return "($\\times\\infty$)"
    if r >= 10:
        return f"($\\times${int(round(r))})"
    return f"($\\times${r:.{d}f})"

def safe_ratio(baseline, variant):
    # baseline / variant, guarding zeros
    if variant is None or pd.isna(variant) or variant <= 0:
        return np.inf
    if baseline is None or pd.isna(baseline) or baseline < 0:
        return np.nan
    return baseline / variant

def load_summary(csv_path: str, bench_name: str) -> pd.DataFrame:
    df = pd.read_csv(csv_path)
    df["Benchmark"] = bench_name
    return df

def pick_single(df: pd.DataFrame, attack: str, alpha_percent: float):
    """Return (TPR_mu, TPR_sd) for given attack at the given Target FPR (%)."""
    sub = df[(df["attack"] == attack) &
             (df["mode"] == MODE) &
             (np.isclose(df["prior"], PRIOR, atol=1e-12)) &
             (np.isclose(df["Target FPR (%)"], alpha_percent, atol=1e-6))]
    if sub.empty:
        return (np.nan, np.nan)
    # Already aggregated in CSV (mean±std across models). Take the mean of means (robust if repeated).
    return (sub["TPR_Mean"].mean(), sub["TPR_Std"].mean())

def pick_auc(df: pd.DataFrame, attack: str):
    """Return (AUC_mu, AUC_sd) for given attack (mode/alpha/prior independent)."""
    sub = df[(df["attack"] == attack) & (df["mode"] == MODE)]
    if sub.empty:
        return (np.nan, np.nan)
    # AUC shouldn't depend on prior/alpha; average if repeated
    return (sub["AUC_Mean"].mean(), sub["AUC_Std"].mean())

# ===== LOAD & PREP =====
all_df = []
for bench, path in BENCHMARKS.items():
    all_df.append(load_summary(path, bench))
summary = pd.concat(all_df, ignore_index=True)

# ===== BUILD LATEX MANUALLY (for multirow & multipliers) =====
lines = []
lines.append("\\begin{table*}[!ht]")
lines.append("\\centering")
lines.append(f"\\caption{{{CAPTION}}}")
lines.append(f"\\label{{{LABEL}}}")
lines.append("\\resizebox{\\textwidth}{!}{%")
lines.append("\\begin{tabular}{lcccc}")
lines.append("\\toprule")
lines.append("Benchmark & Attack & TPR@ & TPR@ & AUC \\\\ & & 0.001\\% FPR (\\%) & 0.1\\% FPR (\\%) & (\\%) \\\\ ")
lines.append("\\midrule")

bench_names = list(BENCHMARKS.keys())
baseline_name = bench_names[0]

# Pre-slice each benchmark's df for faster lookups
bench_dfs = {b: summary[summary["Benchmark"] == b].copy() for b in bench_names}
baseline_df = bench_dfs[baseline_name]

for bi, bench in enumerate(bench_names):
    bdf = bench_dfs[bench]

    # We need baseline TPRs per attack to compute multipliers for non-baseline benches
    base_tpr1 = {}
    base_tpr2 = {}
    for atk in ATTACK_ORDER:
        mu1, _ = pick_single(baseline_df, atk, ALPHA1)
        mu2, _ = pick_single(baseline_df, atk, ALPHA2)
        base_tpr1[atk] = mu1
        base_tpr2[atk] = mu2

    # Multirow count equals number of attacks
    lines.append(f"\\multirow{{{len(ATTACK_ORDER)}}}{{*}}{{{bench}}} & ",)

    for ai, atk in enumerate(ATTACK_ORDER):
        # Pull means/stds for this benchmark & attack
        tpr1_mu, tpr1_sd = pick_single(bdf, atk, ALPHA1)
        tpr2_mu, tpr2_sd = pick_single(bdf, atk, ALPHA2)
        auc_mu,  auc_sd  = pick_auc(bdf, atk)

        # Format core numbers
        tpr1_txt = fmt_mu_sd(tpr1_mu, tpr1_sd)
        tpr2_txt = fmt_mu_sd(tpr2_mu, tpr2_sd)
        auc_txt  = fmt_mu_sd(auc_mu,  auc_sd)

        # Reduction multipliers (vs baseline) for non-baseline benches
        if bench != baseline_name:
            r1 = safe_ratio(base_tpr1.get(atk), tpr1_mu)
            r2 = safe_ratio(base_tpr2.get(atk), tpr2_mu)
            tpr1_txt += " " + fmt_multiplier(r1)
            tpr2_txt += " " + fmt_multiplier(r2)

        attack_disp = ATTACK_DISPLAY.get(atk, atk)

        if ai == 0:
            # first row already started with \multirow{...} & ...
            lines[-1] += f"{attack_disp} & {tpr1_txt} & {tpr2_txt} & {auc_txt} \\\\"
        else:
            lines.append(f"& {attack_disp} & {tpr1_txt} & {tpr2_txt} & {auc_txt} \\\\")
    # midrule between benchmarks
    if bi < len(bench_names) - 1:
        lines.append("\\midrule")

lines.append("\\bottomrule")
lines.append("\\end{tabular}%")
lines.append("}")
lines.append("\\end{table*}")

latex_table = "\n".join(lines)
print(latex_table)


\begin{table*}[!ht]
\centering
\caption{CIFAR-100 with balanced prior and target-based threshold. TPR reduction factors relative to baseline shown in parentheses for  (TL).}
\label{tab:gtsrb-reg}
\resizebox{\textwidth}{!}{%
\begin{tabular}{lcccc}
\toprule
Benchmark & Attack & TPR@ & TPR@ & AUC \\ & & 0.001\% FPR (\%) & 0.1\% FPR (\%) & (\%) \\ 
\midrule
\multirow{5}{*}{Purchase-100 (baseline)} & Online & 0.523 $\pm$ 0.243 & 4.491 $\pm$ 0.281 & 70.163 $\pm$ 0.286 \\
& Online (fixed var) & 0.180 $\pm$ 0.110 & 3.089 $\pm$ 0.188 & 69.521 $\pm$ 0.280 \\
& Offline & 0.007 $\pm$ 0.007 & 0.500 $\pm$ 0.077 & 55.105 $\pm$ 0.484 \\
& Offline (fixed var) & 0.022 $\pm$ 0.017 & 0.713 $\pm$ 0.078 & 56.110 $\pm$ 0.506 \\
& Global threshold & 0.001 $\pm$ 0.001 & 0.100 $\pm$ 0.015 & 54.834 $\pm$ 0.148 \\
\midrule
\multirow{5}{*}{Purchase-100 (reg.)} & Online & 0.022 $\pm$ 0.017 ($\times$24) & 0.825 $\pm$ 0.068 ($\times$5.4) & 62.640 $\pm$ 0.160 \\
& Online (fixed var) & 0.026 $\pm$ 0.019 ($\times$6.9) &

In [142]:
# -------------------------------------------------------
# ONE unified LaTeX table across many benchmarks
# - Target vs Shadow @ chosen Target FPR (%)
# - \multirow grouping by benchmark
# - Unified columns: Attack, τ, TPR, FPR', Precision@π=...
# - Prints LaTeX only (no saving)
# -------------------------------------------------------

# ===== CONFIG =====
BENCHMARKS = [
    ("Purchase-100 (baseline)", "analysis_results/purchase100/fcn/2025-05-29_1623_drp_0_wd_5e-5/summary_statistics_two_modes.csv"),
    ("Purchase-100 (reg.)",     "analysis_results/purchase100/fcn/2025-06-15_2326_drp_0.5_wd_1e-3/summary_statistics_two_modes.csv"),
    # ("CIFAR-100 (TL)",       "analysis_results/cifar100/efficientnetv2_rw_s/tl/summary_statistics_two_modes.csv"),
]


ALPHA_PERCENT = 0.001       # pick one target FPR (%) for this unified table
PRIORS_SHOW   = [0.01, 0.10, 0.50]
INCLUDE_GLOBAL = False

CAPTION = (f"Target vs Shadow at target FPR $={ALPHA_PERCENT:.3g}\\%$ across benchmarks. "
           "FPR$'$ is the achieved false-positive rate.")
LABEL   = f"tab:merged_target_shadow_{str(ALPHA_PERCENT).replace('.','p')}"

# friendly attack order
ATTACK_ORDER = [
    "LiRA (online)",
    "LiRA (online, fixed var)",
    "LiRA (offline)",
    "LiRA (offline, fixed var)",
    "Global threshold",
]

def fmt_mu_sd(mu, sd, digits=3):
    if pd.isna(mu) or pd.isna(sd):
        return "--"
    return f"{mu:.{digits}f} $\\pm$ {sd:.{digits}f}"

def load_one(csv_path: str, bench: str) -> pd.DataFrame:
    df = pd.read_csv(csv_path)
    df["Benchmark"] = bench
    return df

# ---- load and filter all
frames = [load_one(p, b) for (b, p) in BENCHMARKS]
all_df = pd.concat(frames, ignore_index=True)

df = all_df[np.isclose(all_df["Target FPR (%)"], ALPHA_PERCENT, atol=1e-6)].copy()
if not INCLUDE_GLOBAL:
    df = df[df["attack"] != "Global threshold"]
df = df[df["mode"].isin(["target","shadow"])]

# ---- build rows grouped by benchmark, then attack, Target->Shadow
precision_cols = [f"Precision@$\\pi$={int(p*100)}\\%" for p in PRIORS_SHOW]
lines = []
lines.append("\\begin{table}")
lines.append(f"\\caption{{{CAPTION}}}")
lines.append(f"\\label{{{LABEL}}}")
col_fmt = "l" + "c"*(3 + 1 + len(PRIORS_SHOW))  # Benchmark + (Attack, tau, TPR, FPR', precisions...)
header_cols = ["Benchmark","Attack","$\\tau$","TPR (\\%)","FPR$'$ (\\%)"] + precision_cols
lines.append(f"\\begin{{tabular}}{{{col_fmt}}}")
lines.append("\\toprule")
lines.append(" & ".join(header_cols) + " \\\\")
lines.append("\\midrule")

for bench, _ in BENCHMARKS:
    sub_bench = df[df["Benchmark"] == bench]
    if sub_bench.empty:
        continue
    # count rows for multirow: 2 modes per attack if both present
    row_count = 0
    for atk in ATTACK_ORDER:
        present_modes = sub_bench[sub_bench["attack"] == atk]["mode"].unique()
        row_count += len([m for m in ["target","shadow"] if m in present_modes])
    if row_count == 0:
        continue

    # start multirow block
    first_row_started = False
    for atk in ATTACK_ORDER:
        for mode in ["target","shadow"]:
            sub = sub_bench[(sub_bench["attack"] == atk) & (sub_bench["mode"] == mode)]
            if sub.empty: 
                continue
            # TPR / FPR' (independent of prior; averaging a no-op)
            tpr_mu, tpr_sd   = sub["TPR_Mean"].mean(), sub["TPR_Std"].mean()
            fprp_mu, fprp_sd = sub["FPR_Achieved_Mean"].mean(), sub["FPR_Achieved_Std"].mean()
            row = []
            if not first_row_started:
                row.append(f"\\multirow{{{row_count}}}{{*}}{{{bench}}}")
                first_row_started = True
            else:
                row.append("")  # empty benchmark cell for subsequent rows in block
            row += [atk, mode.title(), fmt_mu_sd(tpr_mu, tpr_sd), fmt_mu_sd(fprp_mu, fprp_sd)]
            # add precisions for requested priors
            for p in PRIORS_SHOW:
                subp = sub[np.isclose(sub["prior"], p, atol=1e-12)]
                if subp.empty:
                    row.append("--")
                else:
                    mu, sd = subp["Precision_Mean"].iloc[0], subp["Precision_Std"].iloc[0]
                    row.append(fmt_mu_sd(mu, sd))
            lines.append(" & ".join(row) + " \\\\")
    lines.append("\\midrule")

# close table
# remove trailing midrule if present
if lines[-1] == "\\midrule":
    lines[-1] = "\\bottomrule"
else:
    lines.append("\\bottomrule")
lines.append("\\end{tabular}")
lines.append("\\end{table}")

print("\n".join(lines))


\begin{table}
\caption{Target vs Shadow at target FPR $=0.001\%$ across benchmarks. FPR$'$ is the achieved false-positive rate.}
\label{tab:merged_target_shadow_0p001}
\begin{tabular}{lccccccc}
\toprule
Benchmark & Attack & $\tau$ & TPR (\%) & FPR$'$ (\%) & Precision@$\pi$=1\% & Precision@$\pi$=10\% & Precision@$\pi$=50\% \\
\midrule
\multirow{8}{*}{Purchase-100 (baseline)} & LiRA (online) & Target & 0.523 $\pm$ 0.243 & 0.000 $\pm$ 0.000 & 100.000 $\pm$ 0.000 & 100.000 $\pm$ 0.000 & 100.000 $\pm$ 0.000 \\
 & LiRA (online) & Shadow & 0.516 $\pm$ 0.047 & 0.001 $\pm$ 0.001 & 89.927 $\pm$ 11.145 & 98.841 $\pm$ 1.369 & 99.868 $\pm$ 0.157 \\
 & LiRA (online, fixed var) & Target & 0.180 $\pm$ 0.110 & 0.000 $\pm$ 0.000 & 100.000 $\pm$ 0.000 & 100.000 $\pm$ 0.000 & 100.000 $\pm$ 0.000 \\
 & LiRA (online, fixed var) & Shadow & 0.159 $\pm$ 0.015 & 0.001 $\pm$ 0.001 & 78.868 $\pm$ 22.268 & 96.698 $\pm$ 3.802 & 99.606 $\pm$ 0.465 \\
 & LiRA (offline) & Target & 0.007 $\pm$ 0.007 & 0.000 $\pm$ 0.000