In [None]:
import os, numpy as np, matplotlib.pyplot as plt
from matplotlib import ticker
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# ---------- UTILITIES ----------
def _freedman_diaconis_bins(values, max_bins=80, min_bins=50):
    v = np.asarray(values, dtype=float)
    v = v[np.isfinite(v)]
    if v.size < 2:
        return min_bins
    iqr = np.subtract(*np.percentile(v, [75, 25]))
    if iqr <= 0:
        return min(max(min_bins, int(np.sqrt(v.size))), max_bins)
    bw = 2 * iqr * (v.size ** (-1 / 3))
    if bw <= 0:
        return min(max(min_bins, int(np.sqrt(v.size))), max_bins)
    n_bins = int(np.ceil((v.max() - v.min()) / bw))
    return int(np.clip(n_bins, min_bins, max_bins))


# ---------- MAIN FUNCTION ----------
def plot_sample_inout_distributions(
    sample_idx,
    single_sample=True,
    share_xlim=True,
    save_path="figure_sample_inout.pdf",
    color_member="#F53030",     # red
    color_nonmember="#1C9452",  # green
    font_family="DejaVu Sans",
    base_fontsize=8.5
):
    """
    Plot member vs. non-member score histograms, one panel per benchmark.

    The function reads four notebook-level globals at call time:
    ``exp_paths`` (experiment directories), ``benchmarks`` (panel titles,
    paired positionally with ``exp_paths``), ``scores_fname`` and
    ``labels_fname`` (``.npy`` file names inside each directory). Labels are
    loaded from ``exp_paths[0]`` only and reused for every benchmark.

    Parameters
    ----------
    sample_idx : int
        Column index into the 2-D label/score arrays. It is range-checked
        even when ``single_sample`` is False.
    single_sample : bool
        If True, only column ``sample_idx`` is plotted; otherwise every entry
        of the arrays is pooled.
    share_xlim : bool
        If True, all panels use one x-range spanning the finite scores of all
        benchmarks (plus 5% padding); otherwise each panel uses its own range.
    save_path : str or None
        Output file. A PNG with the same stem is additionally written at
        600 dpi. Falsy values disable saving.
    color_member, color_nonmember : str
        Histogram face colours for the two groups.
    font_family : str
        Font family written into ``plt.rcParams``.
    base_fontsize : float
        Base font size; axis labels and legend use ``base_fontsize - 0.5``.

    Returns
    -------
    None

    Raises
    ------
    AssertionError
        If ``sample_idx`` is out of range or a score array's shape differs
        from the label array's shape.
    ValueError
        If a panel's x-range has to be derived from a benchmark that has no
        finite scores.

    Notes
    -----
    * ``plt.rcParams`` is updated globally and not restored.
    * The figure is a fixed 2x2 grid, so at most four benchmarks are drawn.
    """

    # --- typography ---
    plt.rcParams.update({
        "pdf.fonttype": 42,
        "ps.fonttype": 42,
        "font.family": font_family,
        "font.size": base_fontsize,
        "axes.titlesize": base_fontsize,
        "axes.labelsize": base_fontsize - 0.5,
        "legend.fontsize": base_fontsize - 0.5,
        "axes.linewidth": 0.6,
        "grid.linewidth": 0.4,
        "xtick.major.size": 3,
        "ytick.major.size": 3,
    })

    # --- load data ---
    # Labels are 2-D; the second axis is the one indexed by sample_idx.
    labels = np.load(os.path.join(exp_paths[0], labels_fname))
    _, N = labels.shape
    assert 0 <= sample_idx < N, f"sample_idx {sample_idx} out of range [0, {N-1}]"
    if single_sample:
        y = labels[:, sample_idx].astype(bool)
    else:
        y = labels.astype(bool)

    bench_scores = {}
    xmins, xmaxs = [], []

    for bench, p in zip(benchmarks, exp_paths):
        s = np.load(os.path.join(p, scores_fname))
        assert s.shape == labels.shape, f"Shape mismatch for {bench}"
        if single_sample:
            scores_j = s[:, sample_idx].astype(np.float64)
        else:
            scores_j = s.astype(np.float64)

        bench_scores[bench] = scores_j
        finite_scores = scores_j[np.isfinite(scores_j)]
        if finite_scores.size:
            xmins.append(finite_scores.min())
            xmaxs.append(finite_scores.max())

    # --- shared x-range (only when requested and at least one finite score exists) ---
    shared_range = None
    if share_xlim and xmins and xmaxs:
        xlo, xhi = min(xmins), max(xmaxs)
        pad = 0.05 * (xhi - xlo + 1e-12)  # epsilon keeps the range non-degenerate
        shared_range = (xlo - pad, xhi + pad)

    # --- figure setup (compact 2x2 grid) ---
    fig, axes = plt.subplots(2, 2, figsize=(4.6, 3.3))
    axes = axes.ravel()

    for ax, bench in zip(axes, benchmarks):
        s = bench_scores[bench]
        s_member = s[y]
        s_nonmember = s[~y]

        # Members and non-members partition s, so the bin count can be derived from s itself.
        bins = _freedman_diaconis_bins(s)

        # determine x-range
        if shared_range is not None:
            xrange = shared_range
        else:
            finite = s[np.isfinite(s)]
            if finite.size == 0:
                # Without this guard, finite.min() fails with an opaque
                # "zero-size array" ValueError; keep the type, improve the message.
                plt.close(fig)
                raise ValueError(
                    f"No finite scores for benchmark {bench!r}; cannot determine histogram range"
                )
            lo, hi = finite.min(), finite.max()
            pad_local = 0.05 * (hi - lo + 1e-12)
            xrange = (lo - pad_local, hi + pad_local)

        # histograms only
        ax.hist(s_nonmember, bins=bins, range=xrange, density=True,
                color=color_nonmember, alpha=0.6, label="Non-member", edgecolor="none")
        ax.hist(s_member, bins=bins, range=xrange, density=True,
                color=color_member, alpha=0.6, label="Member", edgecolor="none")

        # aesthetics
        ax.set_title(bench, pad=3)
        ax.grid(axis="y", alpha=0.25, ls="--", lw=0.4)
        ax.xaxis.set_major_locator(ticker.MaxNLocator(4))
        ax.yaxis.set_major_locator(ticker.MaxNLocator(3))
        if shared_range is not None:
            ax.set_xlim(xrange)
        ax.tick_params(width=0.6)

    # shared legend with shadow and closer positioning
    handles, labels_ = axes[0].get_legend_handles_labels()
    legend = fig.legend(handles, labels_, loc="upper center", ncol=2, frameon=True,
                        borderaxespad=0.1, edgecolor='black', shadow=True,
                        fancybox=False, framealpha=1.0)
    # Legend() does not accept a `linewidth` keyword (it raises TypeError);
    # the frame border width has to be set on the frame patch instead.
    legend.get_frame().set_linewidth(0.6)

    # Tighter layout with reduced spacing between subplots
    fig.tight_layout(rect=[0, 0, 1, 0.94], h_pad=0.8, w_pad=0.8)

    # --- save high-quality outputs ---
    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        # Derive the PNG name from the file stem so only the extension changes
        # (str.replace would also rewrite ".pdf" occurring elsewhere in the path).
        png_path = os.path.splitext(save_path)[0] + ".png"
        if png_path != save_path:
            fig.savefig(save_path, bbox_inches="tight")
        fig.savefig(png_path, dpi=600, bbox_inches="tight")

    plt.show()

    # Example usage:

# ---- experiment configuration (read as globals by plot_sample_inout_distributions) ----
# File names looked up inside every experiment directory.
labels_fname = "membership_labels.npy"  # boolean membership labels (per the original note)
# Original note: higher score = member. Candidate score files named there:
# global_scores_leave_one_out or likelihood_ratios_online_leave_one_out.
scores_fname = "global_scores_leave_one_out.npy"

# (panel title, experiment directory) pairs, listed in plotting order.
_benchmark_dirs = [
    ("CRP+FLP", "d:/mona/mia_research/experiments/cifar10/resnet18/weak/"),
    ("Baseline", "d:/mona/mia_research/experiments/cifar10/resnet18/strong/"),
    ("AOF", "d:/mona/mia_research/experiments/cifar10/resnet18/weak_rotate_jitter_cutmix_drop0.1_wd1e-3_seed42/"),
    ("AOF+TL", "d:/mona/mia_research/experiments/cifar10/tl/"),
]
benchmarks = [title for title, _directory in _benchmark_dirs]
exp_paths = [directory for _title, directory in _benchmark_dirs]

# Draw and save the per-benchmark distributions for sample index 21.
plot_sample_inout_distributions(
    sample_idx=21,
    single_sample=True,
    save_path="figures/sample_inout_score.pdf",
)

FileNotFoundError: [Errno 2] No such file or directory: 'd:/mona/mia_research/experiments/c10/weak/membership_labels.npy'

In [None]:
#9442
# ---- config (reuse your paths) ----
exp_paths = [
    "d:/mona/mia_research/experiments/cifar10/resnet18/weak/",
    "d:/mona/mia_research/experiments/cifar10/resnet18/strong/",
    "d:/mona/mia_research/experiments/cifar10/resnet18/weak_rotate_jitter_cutmix_drop0.1_wd1e-3_seed42/",
    "d:/mona/mia_research/experiments/cifar10/tl/",
]