In [44]:
import pandas as pd

import os
import math
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json 
import seaborn as sns 

In [45]:
plt.rcParams.update(
    {
        "text.usetex": True,
        "font.family": "serif",
        "axes.titlesize": 24,        # Larger axes/title fonts
        "legend.fontsize": 24,
    }
)
sns.set_context("talk")


In [46]:
palette_name = 'ocean_sunset'
palette = json.load(open(f'../palettes/{palette_name}.json', 'r'))

In [47]:
palette

{'Rich black': '#001219',
 'Midnight green': '#005f73',
 'Dark cyan': '#0a9396',
 'Tiffany Blue': '#94d2bd',
 'Vanilla': '#e9d8a6',
 'Gamboge': '#ee9b00',
 'Alloy orange': '#ca6702',
 'Rust': '#bb3e03',
 'Rufous': '#ae2012',
 'Auburn': '#9b2226'}

## Sampling collisions

In [48]:
output_dir = '../figures/sampling_collisions'
model_names = ['gemma-3-1b-pt', 'gemma-3-4b-pt', 'gemma-3-12b-pt', 'gpt2', 'gpt2-medium', 'gpt2-large'] #  'Llama-3.1-8B', 'Mistral-7B-v0.1', 'Phi-4-mini-instruct', 'TinyStories-1M', 'TinyStories-8M', 'TinyStories-33M'

In [49]:
model_name_to_data = {}

for model in model_names:

    filename_fn = lambda x: f'{x}-layer-stats.csv'
    df = pd.read_csv(f'../data/sampling_collisions/{filename_fn(model)}')
    model_name_to_data[model] = df
    print(df.head())

   layer       count        mean         std       min          max
0      1  4999950000  388.000867   81.484355  0.690545  1221.549438
1      2  4999950000  454.868565   83.023647  1.488919  1197.440063
2      3  4999950000  522.992854   88.149104  2.125064  1582.803589
3      4  4999950000  636.365883  139.904533  2.682702  1849.554932
4      5  4999950000  756.525868  315.431737  5.326268  6730.736328
   layer       count         mean          std       min           max
0      1  4999950000   755.600975   570.572009  0.581290  10856.806641
1      2  4999950000   787.215231   586.668592  1.633421  11545.596680
2      3  4999950000   885.661262   788.778374  1.971468  17010.734375
3      4  4999950000   948.739379   916.405098  2.184747  19956.787109
4      5  4999950000  1184.181242  1409.029561  4.980736  31204.494141
   layer       count         mean          std       min           max
0      1  4999950000   889.823674   562.720207  0.896930   8438.365234
1      2  4999950000   8

In [50]:
Z_25 = 0.6744897501960817  # Φ^{-1}(0.75) for normal ≈ 0.67449

def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return np.nan

def stats_row_to_bxp_dict(row):
    """
    Convert a row with (layer, mean, std, min, max) to a dict for matplotlib.bxp.
    Expected row fields: 'layer', 'mean', 'std', 'min', 'max'
    """
    lyr = int(row["layer"])
    mu  = _safe_float(row["mean"])
    sd  = _safe_float(row["std"])
    mn  = _safe_float(row["min"])
    mx  = _safe_float(row["max"])

    # Fallbacks if missing/NaN
    if np.isnan(mu): mu = 0.0
    if np.isnan(sd): sd = 0.0
    if np.isnan(mn): mn = mu
    if np.isnan(mx): mx = mu

    # Ensure min <= max
    if mx < mn:
        mn, mx = mx, mn

    # Approximate quartiles (normal assumption), then clamp to [min, max]
    q1 = mu - Z_25 * sd
    q3 = mu + Z_25 * sd
    med = mu

    q1 = min(max(q1, mn), mx)
    q3 = min(max(q3, mn), mx)
    med = min(max(med, mn), mx)

    # If q1 > q3 (pathological due to clamping), collapse to median
    if q1 > q3:
        q1 = q3 = med = min(max(mu, mn), mx)

    return {
        "label": f"{lyr}",
        "whislo": mn,   # bottom whisker
        "q1": q1,       # 25th percentile
        "med": med,     # median (here: mean, clamped)
        "q3": q3,       # 75th percentile
        "whishi": mx,   # top whisker
        "fliers": []
    }

In [68]:
for model in model_names:
    df = model_name_to_data[model]

    # Expect columns: layer, count, mean, std, min, max
    required = {"layer", "mean", "std", "min", "max"}
    missing = required - set(map(str.lower, df.columns))
    if missing:
        # Try case-insensitive rename
        lower_map = {c.lower(): c for c in df.columns}
        df = df.rename(columns={lower_map.get(k, k): k for k in required if k in lower_map})

    # Sort by layer
    if "layer" in df.columns:
        df = df.sort_values("layer")

    # Build bxp stats list
    bxp_stats = [stats_row_to_bxp_dict(row) for _, row in df.iterrows()]

    fig, ax = plt.subplots(figsize=(max(8, len(bxp_stats) * 0.35), 5))

    ax.bxp(
        bxp_stats,
        showfliers=False,
        widths=0.6,
        medianprops=dict(color=palette['Auburn'], linewidth=1.8)
    )

    ax.set_xlabel("Layer")   
    ax.set_ylabel("Value")

    # Improve readability
    ax.grid(axis="y", linestyle="--", alpha=0.4)

    ax.tick_params(axis="x", labelrotation=0, labelsize=14, pad=8, width=0.5)
    ax.tick_params(axis="y", labelsize=14, width=0.5)

    ax.tick_params(axis="y", which="minor", width=0.0, length=3)
    ax.tick_params(axis="y", which="major", width=1.2, length=6)

    all_vals = []
    for b in bxp_stats:
        all_vals.extend([b["whislo"], b["q1"], b["med"], b["q3"], b["whishi"]])
    if all(v > 0 for v in all_vals):
        ax.set_yscale("log")

    # Draw the collision threshold line
    ax.axhline(0.0, linestyle="--", linewidth=1.5, color=palette['Auburn'], alpha=0.9, zorder=10)
    ax.annotate(
        "collision threshold",
        xy=(0.995, 0.0), xycoords=("axes fraction", "data"),
        ha="right", va="bottom",
        fontsize=10, color=palette['Auburn'],
        bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7)
    )

    # === Global minimum line ===
    global_min = min(b["whislo"] for b in bxp_stats)
    ax.axhline(global_min, linestyle=":", linewidth=1.5, color=palette['Midnight green'], alpha=0.9, zorder=9)
    ax.annotate(
        f"minimum distance: {global_min:.2f}",
        xy=(0.5, global_min), xycoords=("axes fraction", "data"),
        xytext=(0, 5), textcoords="offset points",
        ha="center", va="bottom",
        fontsize=18, color=palette['Midnight green'],
        bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.7)
    )

    # Adjust y-limits with extra headroom
    ymax = max(b["whishi"] for b in bxp_stats)
    pad = max(1e-6, 0.05 * ymax)  # 👈 increase from 0.03 to 0.05 for more space
    ax.set_ylim(global_min - pad, ymax + 2 * pad)  

    plt.tight_layout()
    
    fig.subplots_adjust(bottom=0.18)  # 0–1; bump this up if still clipped

    out_path = os.path.join(output_dir, f"{model}-layer-boxplot.pdf")
    plt.savefig(out_path, dpi=200)
    plt.close(fig)


  ax.set_ylim(global_min - pad, ymax + 2 * pad)
  plt.tight_layout()
  ax.set_ylim(global_min - pad, ymax + 2 * pad)
  plt.tight_layout()
  ax.set_ylim(global_min - pad, ymax + 2 * pad)
  plt.tight_layout()
  ax.set_ylim(global_min - pad, ymax + 2 * pad)
  plt.tight_layout()
  ax.set_ylim(global_min - pad, ymax + 2 * pad)
  plt.tight_layout()
  ax.set_ylim(global_min - pad, ymax + 2 * pad)
  plt.tight_layout()
