In [100]:
import pandas as pd 

In [101]:
import pandas as pd

import os
import math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json 
import seaborn as sns 

In [102]:
# Apply the seaborn context FIRST: set_context() rewrites the font-size
# rcParams (axes.titlesize and legend.fontsize among them), so calling it after
# the update below would silently discard the explicit sizes chosen here.
sns.set_context("talk")
plt.rcParams.update(
    {
        "text.usetex": False,        # Render text with matplotlib itself, not LaTeX
        "font.family": "serif",
        "axes.titlesize": 24,        # Larger axes/title fonts
        "legend.fontsize": 24,
    }
)


In [103]:
palette_name = 'ocean_sunset'
# Load the named colour palette (colour name -> hex string) from its JSON file.
# The context manager closes the file handle as soon as parsing is done
# instead of leaving it open until garbage collection.
with open(f'../palettes/{palette_name}.json', 'r', encoding='utf-8') as palette_file:
    palette = json.load(palette_file)

In [104]:
# Show the loaded palette (colour name -> hex code) as the cell output.
palette

{'Rich black': '#001219',
 'Midnight green': '#005f73',
 'Dark cyan': '#0a9396',
 'Tiffany Blue': '#94d2bd',
 'Vanilla': '#e9d8a6',
 'Gamboge': '#ee9b00',
 'Alloy orange': '#ca6702',
 'Rust': '#bb3e03',
 'Rufous': '#ae2012',
 'Auburn': '#9b2226'}

## Exhaustive

In [105]:
# Models to analyse. Each name selects the '<name>-sample-stats.csv' input file
# and is embedded in the filename of the figure produced for that model.
model_names = ['gemma-3-1b-pt', 'gpt2']

In [106]:
def filename_fn(x):
    """Return the per-sample statistics CSV filename for model name ``x``."""
    return f'{x}-sample-stats.csv'


# Maps each model name to the DataFrame read from its statistics CSV.
model_name_to_data = {}

for model in model_names:
    print(f'Processing model: {model}')

    stats_path = f'../data/exhaustive/{filename_fn(model)}'
    df = pd.read_csv(stats_path)
    model_name_to_data[model] = df

    print(df)

Processing model: gemma-3-1b-pt
   sample_idx        count          mean          std        min           max
0           0  34359869440   9491.273973  2122.655164  28.347353  46543.382812
1           1  34359869440   9338.912643  1939.911406  27.463047  38455.750000
2           2  34359869440  10061.169184  1784.916624  26.331551  48011.203125
3           3  34359869440  10436.978361  2685.334255  43.648937  52808.609375
4           4  34359869440   9646.630449  2012.337411  29.853027  38111.019531
5           5  34359869440  10436.449318  2344.532139  18.496944  36690.414062
6           6  34359869440   9509.748273  2237.611196  17.228855  36161.695312
7           7  34359869440   9157.106724  2083.316769  20.826887  34140.921875
8           8  34359869440  11058.262927  2475.411253  22.809589  62162.691406
9           9  34359869440  10886.070964  2507.579198  18.572681  52399.929688
Processing model: gpt2
   sample_idx       count        mean        std       min          max
0   

## Box plot

In [107]:
# Standard-normal quantile used to turn (mean, std) into approximate quartiles:
# Q1 ≈ mean - Z_25 * std and Q3 ≈ mean + Z_25 * std.
Z_25 = 0.6744897501960817  # Φ^{-1}(0.75) for normal ≈ 0.67449

def _safe_float(x):
    """Coerce ``x`` to ``float``, yielding NaN whenever the conversion raises."""
    try:
        value = float(x)
    except Exception:
        value = np.nan
    return value

def stats_row_to_bxp_dict(row):
    """
    Turn one row of summary statistics into a stats dict for ``Axes.bxp``.

    Fields read from ``row``: 'sample_idx', 'mean', 'std', 'min', 'max'.

    Only summary statistics are available, so the box is an approximation:
    the median is taken to be the mean and the quartiles are
    ``mean ± Z_25 * std`` (normal assumption), each clamped to [min, max].
    The whiskers are the min and max themselves.

    Unparseable/NaN values fall back to 0.0 for mean and std, and to the
    (possibly defaulted) mean for min and max. A reversed min/max pair is
    swapped.

    Returns a dict with keys 'label', 'whislo', 'q1', 'med', 'q3', 'whishi'
    and 'fliers' (always an empty list).
    """
    sample = int(row["sample_idx"])
    mu, sd, mn, mx = (_safe_float(row[key]) for key in ("mean", "std", "min", "max"))

    # Substitute defaults for missing values; min/max default to the mean
    # *after* the mean itself has been defaulted.
    mu = 0.0 if np.isnan(mu) else mu
    sd = 0.0 if np.isnan(sd) else sd
    mn = mu if np.isnan(mn) else mn
    mx = mu if np.isnan(mx) else mx

    # Order the whisker bounds so that lo <= hi.
    lo, hi = (mx, mn) if mx < mn else (mn, mx)

    def clamp(value):
        # Restrict value to the closed interval [lo, hi].
        return min(max(value, lo), hi)

    q1 = clamp(mu - Z_25 * sd)
    q3 = clamp(mu + Z_25 * sd)
    med = clamp(mu)

    # An inverted box (e.g. from a negative std) collapses onto the median.
    if q1 > q3:
        q1 = q3 = med = clamp(mu)

    return {
        "label": f"{sample}",
        "whislo": lo,   # bottom whisker
        "q1": q1,       # approximate 25th percentile
        "med": med,     # median stand-in: the clamped mean
        "q3": q3,       # approximate 75th percentile
        "whishi": hi,   # top whisker
        "fliers": []
    }

In [108]:
output_dir = "../figures/exhaustive_collisions"
# savefig() does not create missing directories; make sure the target exists
# so this cell also works on a fresh checkout.
os.makedirs(output_dir, exist_ok=True)

for model in model_names:
    df = model_name_to_data[model]

    # Columns consumed by stats_row_to_bxp_dict.
    required = {"sample_idx", "mean", "std", "min", "max"}

    # Normalise header case first (e.g. "Mean" -> "mean"), then validate.
    # Renaming only when a lower-cased match was already missing can never
    # repair anything, so the rename is applied unconditionally instead.
    lower_map = {str(c).lower(): c for c in df.columns}
    df = df.rename(
        columns={
            lower_map[k]: k
            for k in required
            if k not in df.columns and k in lower_map
        }
    )
    missing = required - set(df.columns)
    if missing:
        raise KeyError(
            f"{model}: statistics table lacks required column(s) {sorted(missing)}"
        )

    # Build bxp stats list
    bxp_stats = [stats_row_to_bxp_dict(row) for _, row in df.iterrows()]
    if not bxp_stats:
        # Nothing to draw (and min() below would fail on an empty sequence).
        print(f"Skipping {model}: statistics table is empty")
        continue

    # Width grows with the number of boxes so labels do not overlap.
    fig, ax = plt.subplots(figsize=(max(8, len(bxp_stats) * 0.35), 3))

    ax.bxp(
        bxp_stats,
        showfliers=False,
        widths=0.6,
        medianprops=dict(color=palette['Auburn'], linewidth=1.8)
    )

    ax.set_xlabel("Sample Index")
    ax.set_ylabel("L2 Distance")

    # Improve readability
    ax.grid(axis="y", linestyle="--", alpha=0.4)

    ax.tick_params(axis="x", labelrotation=0, labelsize=14, pad=8, width=0.5)
    ax.tick_params(axis="y", labelsize=14, width=0.5)

    # Later calls win: major y ticks end up 1.2 wide, minor y ticks invisible.
    ax.tick_params(axis="y", which="minor", width=0.0, length=3)
    ax.tick_params(axis="y", which="major", width=1.2, length=6)

    # A log axis cannot represent non-positive values, so only switch to it
    # when every plotted statistic is strictly positive. The fixed limits and
    # ticks further down are decade values chosen with the log axis in mind.
    all_vals = [
        b[key]
        for b in bxp_stats
        for key in ("whislo", "q1", "med", "q3", "whishi")
    ]
    if all(v > 0 for v in all_vals):
        ax.set_yscale("log")

    # Draw the collision threshold line
    collision_threshold = 1e-6
    ax.axhline(collision_threshold, linestyle="--", linewidth=2.0,
               color=palette['Auburn'], alpha=0.9, zorder=10)
    ax.annotate(
        "Collision threshold",
        # x in axes fraction (horizontally centred), y in data coordinates.
        xy=(0.5, collision_threshold), xycoords=("axes fraction", "data"),
        xytext=(0, 8), textcoords="offset points",
        ha="center", va="bottom",
        fontsize=24, color=palette['Auburn'],
        bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="none", alpha=0.8)
    )

    # === Global minimum line ===
    # Smallest lower whisker across samples; currently only used by the
    # disabled annotation below.
    global_min = min(b["whislo"] for b in bxp_stats)
    # ax.axhline(global_min, linestyle=":", linewidth=1.5, color=palette['Midnight green'], alpha=0.9, zorder=9)
    # ax.annotate(
    #     f"minimum distance: {global_min:.2f}",
    #     xy=(0.5, global_min), xycoords=("axes fraction", "data"),
    #     xytext=(0, 5), textcoords="offset points",
    #     ha="center", va="bottom",
    #     fontsize=18, color=palette['Midnight green'],
    #     bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7)
    # )

    # Set model-specific y-axis limits and ticks
    if 'gemma' in model.lower():
        yticks = [1e-7, 1e-1, 1e4]
        combined_ylim = (1e-7, 1e5)
    else:  # GPT2 and other models
        yticks = [1e-7, 1e-1, 1e2]
        combined_ylim = (1e-7, 1e4)

    ax.set_ylim(combined_ylim)
    ax.set_yticks(yticks)

    fig.tight_layout()

    fig.subplots_adjust(bottom=0.25)  # Increase bottom margin for xlabel

    out_path = os.path.join(output_dir, f"exhaustive-{model}-boxplot.pdf")
    fig.savefig(out_path, dpi=200, bbox_inches='tight', pad_inches=0.1)
    # Release the figure so looping over models does not accumulate open figures.
    plt.close(fig)
