In [3]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.font_manager import FontProperties
from transformers import AutoTokenizer, AutoModel

# ============================ CONFIG ============================
CSV_PATH   = "it_isdt-ud-train_sentences.csv"      # must have: sentence_id (str), tokens (list[str])
MODELS     = ["bert-base-uncased", "gpt2"]         # both models
BATCH_SIZE = 1
RAND_SEED  = 42

# Outlier detection: using 3-sigma criterion on per-dimension MEANS
USE_ABS_MEAN     = False      # if True, use mean(|x|); else mean(x)
OUTLIER_SIGMAS   = 3          # number of standard deviations from mean to define outlier

# Optional: cap points to accelerate embedding (None = all)
TOKEN_CAP: int | None = None

# Robust loading: work offline if models are cached
LOCAL_ONLY = (
    os.environ.get("TRANSFORMERS_OFFLINE") == "1"
    or os.environ.get("HF_HUB_OFFLINE") == "1"
)

OUT_DIR = Path("outlier_dims_all_tokens_3sigma"); OUT_DIR.mkdir(parents=True, exist_ok=True)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda": torch.backends.cudnn.benchmark = True

sns.set_style("darkgrid")

# ============================ LEGEND STYLE (SAME AS BEFORE) ============================
# Legend-only font + compact spacing (does NOT change axis/title fonts from seaborn context)
LEGEND_FP = FontProperties(
    family="DejaVu Sans Mono",  # cambia a "DejaVu Sans" / "serif" si quieres
    size=10                     # tamaño de la leyenda (independiente de sns font_scale)
)

LEGEND_KW = dict(
    prop=LEGEND_FP,
    frameon=False,
    borderpad=0.20,
    labelspacing=0.25,
    handlelength=1.15,
    handletextpad=0.40,
    borderaxespad=0.25,
    markerscale=0.85,
)

def small_legend(ax, **kwargs):
    """Apply compact legend styling on a given axis."""
    kw = dict(LEGEND_KW)
    kw.update(kwargs)
    return ax.legend(**kw)

# First figure y-axis clip: keep the axis small but still show the thresholds
FIRST_FIG_Y_CLIP_QUANTILE = 0.995  # use 99.5th percentile to clip y-axis
FIRST_FIG_TOP_PAD = 1.05           # padding factor above chosen limits

# Bubble sizes (second figure)
DOT_SIZE_MIN = 30.0
DOT_SIZE_MAX = 900.0

# ============================ HELPERS ============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None:
        n = getattr(model.config, "n_layer", None)
    if n is None:
        raise ValueError("Cannot determine number of hidden layers")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None:
        d = getattr(model.config, "n_embd", None)
    if d is None:
        raise ValueError("Cannot determine hidden size")
    return int(d)

def _subtoken_policy(model, resolved_id: str) -> str:
    """
    Decide whether to take the FIRST or LAST subtoken per word.
    - encoder/MLM families -> 'first'
    - decoder/causal families -> 'last'
    """
    t = (getattr(model.config, "model_type", "") or "").lower()
    enc_first = {"bert", "roberta", "albert", "electra", "distilbert", "deberta", "deberta-v2", "camembert"}
    dec_last  = {"gpt2", "gpt_neo", "gpt_neox", "gptj", "bloom", "opt", "llama", "falcon", "mistral", "xglm", "replit"}
    if t in dec_last or "gpt" in resolved_id.lower():
        return "last"
    if t in enc_first:
        return "first"
    return "first"

def _load_tok_and_model(model_id: str):
    """Fast tokenizer (word_ids) + model; robust for GPT‑2; offline‑friendly."""
    tried: List[Tuple[str, str]] = []

    def _try(mid: str):
        tok = AutoTokenizer.from_pretrained(mid, use_fast=True, add_prefix_space=True, local_files_only=LOCAL_ONLY)
        if getattr(tok, "padding_side", None) != "right":
            tok.padding_side = "right"
        if tok.pad_token is None and getattr(tok, "eos_token", None) is not None:
            tok.pad_token = tok.eos_token  # GPT‑2 has no pad_token by default
        mdl = AutoModel.from_pretrained(mid, output_hidden_states=True, local_files_only=LOCAL_ONLY)
        if getattr(mdl.config, "pad_token_id", None) is None and tok.pad_token_id is not None:
            mdl.config.pad_token_id = tok.pad_token_id
        return tok, mdl

    order = [model_id]
    if model_id.lower() in {"gpt2", "gpt-2"}:
        order += ["openai-community/gpt2", "distilgpt2"]

    last_err = None
    for mid in order:
        try:
            tok, mdl = _try(mid)
            mdl = mdl.eval().to(device)
            if device == "cuda": mdl.half()
            return tok, mdl, mid
        except Exception as e:
            tried.append((mid, repr(e))); last_err = e

    raise RuntimeError(
        "Could not load tokenizer/model. Attempts:\n" +
        "\n".join(f" - {m}: {err}" for m, err in tried)
    ) from last_err

def load_token_index(csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Return (df_all_sentences, token_index_df) for ALL tokens (no POS filtering)."""
    df = pd.read_csv(csv_path, usecols=["sentence_id","tokens"])
    df["sentence_id"] = df["sentence_id"].astype(str)
    df.tokens = df.tokens.apply(_to_list)

    rows = []
    for sid, toks in df[["sentence_id","tokens"]].itertuples(index=False):
        for wid in range(len(toks)):
            rows.append((sid, wid))
    word_df = pd.DataFrame(rows, columns=["sentence_id","word_id"])

    if TOKEN_CAP is not None and len(word_df) > TOKEN_CAP:
        word_df = word_df.sample(TOKEN_CAP, random_state=RAND_SEED).reset_index(drop=True)
    return df, word_df

def _create_layer_memmaps(L: int, N: int, D: int, base_dir: Path, tag: str) -> tuple[list[Path], list[np.memmap]]:
    base_dir.mkdir(parents=True, exist_ok=True)
    files, mms = [], []
    for l in range(L):
        fn = base_dir / f"{tag}_layer{l:02d}.mmap"
        mm = np.memmap(fn, dtype="float32", mode="w+", shape=(N, D))
        files.append(fn); mms.append(mm)
    return files, mms

def embed_to_memmaps(df_all: pd.DataFrame, token_df: pd.DataFrame, model_id: str,
                     batch_size: int = 1, out_dir: Path | None = None) -> tuple[list[Path], int, int, int, str]:
    """
    Embed ALL tokens and spill per-layer (N,D) memmaps.
    BERT-like -> use FIRST subtoken; GPT-like -> use LAST subtoken.
    """
    tokzr, model, resolved_id = _load_tok_and_model(model_id)
    policy = _subtoken_policy(model, resolved_id)  # 'first' or 'last'
    print(f"• Subtoken policy for {resolved_id}: {policy}")

    token_df = token_df.copy()
    token_df["sentence_id"] = token_df["sentence_id"].astype(str)
    by_sid: Dict[str, List[Tuple[int,int]]] = {}
    for gidx, (sid, wid) in enumerate(token_df[["sentence_id","word_id"]].itertuples(index=False)):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))
    sids = list(by_sid.keys())
    df_sel = (df_all[df_all.sentence_id.isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    L = _num_hidden_layers(model) + 1
    D = _hidden_size(model)
    N = len(token_df)

    tag = resolved_id.split("/")[-1]
    store_dir = (OUT_DIR / f"{tag}_memmaps") if out_dir is None else out_dir
    fns, mms = _create_layer_memmaps(L, N, D, store_dir, f"{tag}_{policy}")
    filled = np.zeros(N, dtype=bool)

    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True

    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{tag}:{policy} (embed→memmap)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L, B, T, D)

            for b, sid in enumerate(batch_ids):
                mp: Dict[int, List[int]] = {}
                wids = enc_be.word_ids(b)
                if wids is None:
                    raise RuntimeError("Fast tokenizer required (word_ids() unavailable).")
                for tidx, wid in enumerate(wids):
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid)
                    if not toks:
                        continue
                    if policy == "first":
                        vec = h[:, b, toks[0], :]      # (L, D)
                    else:  # "last"
                        vec = h[:, b, toks[-1], :]     # (L, D)
                    for l in range(L):
                        mms[l][gidx, :] = vec[l]
                    filled[gidx] = True

            del enc_be, enc_t, out, h
            if device == "cuda": torch.cuda.empty_cache()

    if (~filled).any():
        print(f"⚠ Missing vectors for {int((~filled).sum())} tokens (skipped).")

    for mm in mms:
        mm.flush()
        del mm
    gc.collect()
    if device == "cuda": torch.cuda.empty_cache()

    return fns, L, N, D, tag

# ---------- MEAN statistics ----------
def per_layer_dim_means(memmap_paths: List[Path], L: int, N: int, D: int, use_abs: bool = True) -> np.ndarray:
    """Return per-layer, per-dim MEANS: shape (L, D)."""
    means = np.zeros((L, D), dtype=np.float32)
    for l in range(L):
        mm = np.memmap(memmap_paths[l], dtype="float32", mode="r", shape=(N, D))
        X = np.array(mm, copy=False)
        if use_abs:
            means[l] = np.mean(np.abs(X), axis=0)
        else:
            means[l] = np.mean(X, axis=0)
        del mm; gc.collect()
    return means

def counts_prev_overlap_and_magnitude_3sigma(vals: np.ndarray, n_sigmas: float = 3.0
                                             ) -> tuple[np.ndarray, np.ndarray, float, float, np.ndarray]:
    """
    vals: (L,D) array of per‑dimension statistics (means here).
    n_sigmas: number of standard deviations for threshold.
    Returns:
      counts        (L,)   : # of outlier dims in each layer (≥ μ+σ*n or ≤ μ−σ*n)
      overlaps_prev (L,)   : # outlier dims overlapping with previous layer (0 at layer 0)
      thr_plus      (float): + threshold (global mean + nσ)
      thr_minus     (float): − threshold (global mean − nσ)
      magnitude     (L,)   : average exceedance magnitude for outlier dims in each layer
    """
    L, D = vals.shape
    flat = vals.reshape(-1)
    mu   = float(flat.mean())
    sigma= float(flat.std(ddof=0))
    thr_plus  = mu + n_sigmas * sigma
    thr_minus = mu - n_sigmas * sigma

    od_sets = []
    counts   = np.zeros(L, dtype=np.int32)
    overlaps_prev = np.zeros(L, dtype=np.int32)
    magnitude = np.zeros(L, dtype=np.float32)

    for l in range(L):
        mask = (vals[l] >= thr_plus) | (vals[l] <= thr_minus)
        idx = np.where(mask)[0]
        idx_set = set(idx)
        od_sets.append(idx_set)
        counts[l] = len(idx)

        if len(idx) > 0:
            above = vals[l, idx] - thr_plus
            below = thr_minus - vals[l, idx]
            exc = np.where(vals[l, idx] >= thr_plus, above, below)
            magnitude[l] = float(np.mean(exc))
        else:
            magnitude[l] = 0.0

        if l > 0:
            overlaps_prev[l] = len(idx_set.intersection(od_sets[l-1]))

    return counts, overlaps_prev, thr_plus, thr_minus, magnitude

# ============================ PLOTTING ============================
def make_plots(model_to_stats: Dict[str, dict], out_dir: Path = OUT_DIR):
    """
    model_to_stats[model_tag] = {
        'means': (L,D),
        'thr_plus': float,
        'thr_minus': float,
        'counts': (L,),
        'overlaps_prev': (L,),
        'mag': (L,)  # average exceedance per layer
    }
    """
    # ---------- Left: last-layer means for each model ----------
    fig, axes = plt.subplots(1, len(model_to_stats), figsize=(6.6 * len(model_to_stats), 4.6), sharey=True)
    if len(model_to_stats) == 1:
        axes = [axes]

    for ax, (tag, st) in zip(axes, model_to_stats.items()):
        means = st["means"]; thr_plus = st["thr_plus"]; thr_minus = st["thr_minus"]
        last_vals = means[-1]
        ax.scatter(np.arange(last_vals.size), last_vals, s=8, alpha=0.6, edgecolor="none")
        ax.axhline(thr_plus,  color="orange", linestyle="--", linewidth=1.5, label=f"μ+{OUTLIER_SIGMAS}σ")
        ax.axhline(thr_minus, color="orange", linestyle="--", linewidth=1.5, label=f"μ−{OUTLIER_SIGMAS}σ")

        # (Tu código original mantiene ylim fijo)
        ax.set_ylim(-80, 80)

        ax.set_title(f"{tag} • last layer")
        ax.set_xlabel("dimension")
        ax.set_ylabel("mean(|activation|)" if USE_ABS_MEAN else "mean(activation)")

        # ✅ SAME legend styling (small + different font)
        small_legend(ax, loc="upper left")

    plt.tight_layout()
    out_png1 = out_dir / "last_layer_means_both_3sigma.pdf"
    plt.savefig(out_png1, dpi=220)
    plt.close(fig)
    print("✓ wrote", out_png1)

    # ---------- Right: counts & overlaps-with-previous + bubble magnitude ----------
    sns.set_style("darkgrid")
    sns.set_context("paper", font_scale=2.5)
    plt.figure(figsize=(9, 5))
    ax2 = plt.gca()

    # For bubble scaling, normalize across all models/layers
    all_mag = np.concatenate([st["mag"] for st in model_to_stats.values()]) if len(model_to_stats) > 0 else np.array([0.0])
    mag_min, mag_max = float(all_mag.min()), float(all_mag.max())
    mag_span = max(mag_max - mag_min, 1e-12)

    for tag, st in model_to_stats.items():
        L = st["means"].shape[0]
        layers = np.arange(L)

        ax2.plot(layers, st["counts"], marker="o", linewidth=2.5, alpha=0.95, label=f"{tag}: #ODs")
        ax2.plot(layers, st["overlaps_prev"], marker="s", linestyle="--", linewidth=2.5, alpha=0.95,
                 label=f"{tag}: #OD ∩ prev")

        mags  = st["mag"]
        sizes = DOT_SIZE_MIN + (DOT_SIZE_MAX - DOT_SIZE_MIN) * ((mags - mag_min) / mag_span)
        ax2.scatter(layers, st["counts"], s=sizes, alpha=0.35, edgecolor="none")

    ax2.set_xlabel("layer")
    ax2.set_ylabel("# outlier dims")

    # ✅ SAME legend styling (small + different font)
    small_legend(ax2, ncol=2, loc="upper left")

    plt.tight_layout()
    out_png2 = out_dir / "od_counts_per_layer_both_3sigma_prev_bubbles.pdf"
    plt.savefig(out_png2, dpi=220)
    plt.close()
    print("✓ wrote", out_png2)

# ============================ DRIVER ============================
def run():
    df_all, token_df = load_token_index(CSV_PATH)
    print(f"✓ Total tokens: {len(token_df):,}")

    model_to_stats: Dict[str, dict] = {}
    for model_id in MODELS:
        print(f"\n=== Embedding all tokens • {model_id} (auto first/last subtoken) ===")
        mmap_paths, L, N, D, tag = embed_to_memmaps(df_all, token_df, model_id, batch_size=BATCH_SIZE)
        print(f"Shapes: layers={L}, tokens={N}, dim={D}")

        means = per_layer_dim_means(mmap_paths, L, N, D, use_abs=USE_ABS_MEAN)
        counts, overlaps_prev, thr_plus, thr_minus, magnitudes = counts_prev_overlap_and_magnitude_3sigma(means, OUTLIER_SIGMAS)
        print(f"{tag}: 3σ thresholds = [{thr_minus:.6f}, {thr_plus:.6f}]  • last-layer #OD={counts[-1]}")

        model_to_stats[tag] = {
            "means": means,
            "counts": counts,
            "overlaps_prev": overlaps_prev,
            "thr_plus": thr_plus,
            "thr_minus": thr_minus,
            "mag": magnitudes
        }

        # Clean up memmaps (optional)
        for p in mmap_paths:
            try: os.remove(p)
            except Exception: pass

    make_plots(model_to_stats, OUT_DIR)
    print("✓ done. See:", OUT_DIR.resolve())

if __name__ == "__main__":
    run()


✓ Total tokens: 253,771

=== Embedding all tokens • bert-base-uncased (auto first/last subtoken) ===


  with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):


• Subtoken policy for bert-base-uncased: first


bert-base-uncased:first (embed→memmap): 100%|█| 12112/12112 [01:52<00:00, 107.56


⚠ Missing vectors for 27 tokens (skipped).
Shapes: layers=13, tokens=253771, dim=768
bert-base-uncased: 3σ thresholds = [-1.187926, 1.139241]  • last-layer #OD=2

=== Embedding all tokens • gpt2 (auto first/last subtoken) ===


  with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):


• Subtoken policy for openai-community/gpt2: last


gpt2:last (embed→memmap): 100%|███████████| 12112/12112 [02:08<00:00, 94.58it/s]


Shapes: layers=13, tokens=253771, dim=768
gpt2: 3σ thresholds = [-13.937875, 14.172828]  • last-layer #OD=4
✓ wrote outlier_dims_all_tokens_3sigma/last_layer_means_both_3sigma.pdf
✓ wrote outlier_dims_all_tokens_3sigma/od_counts_per_layer_both_3sigma_prev_bubbles.pdf
✓ done. See: /home/ldomenichelli/geometric_profiling_of_a_neural_language_model/outlier_dims_all_tokens_3sigma


In [1]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel

# ============================ CONFIG ============================
CSV_PATH   = "it_isdt-ud-train_sentences.csv"      # must contain: sentence_id(str), tokens(list[str])
MODELS     = ["bert-base-uncased", "gpt2"]         # compare both
BATCH_SIZE = 1
RAND_SEED  = 42

# Outlier rule: global quantile over all (layer,dim) medians
USE_ABS_MEDIAN      = True     # use median(|x|) if True, else median(x)
OUTLIER_QUANTILE    = 0.99     # e.g., top 1% across all layers & dims

# First plot y-axis compression (clip to this quantile of last-layer medians)
FIRST_PLOT_Y_QUANT  = 0.995

# Bubble sizes for second plot (scaled from mean exceedance)
DOT_SIZE_MIN        = 30.0
DOT_SIZE_MAX        = 600.0

# Optional: cap tokens to speed up embedding (None = all)
TOKEN_CAP: int | None = None

# Offline-friendly load if models are cached
LOCAL_ONLY = (
    os.environ.get("TRANSFORMERS_OFFLINE") == "1"
    or os.environ.get("HF_HUB_OFFLINE") == "1"
)

OUT_DIR = Path("outlier_dims_all_tokens_quantile"); OUT_DIR.mkdir(parents=True, exist_ok=True)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda": torch.backends.cudnn.benchmark = True

sns.set_style("darkgrid")

# ============================ HELPERS ============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)
    if n is None: raise ValueError("Cannot determine number of hidden layers")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)
    if d is None: raise ValueError("Cannot determine hidden size")
    return int(d)

def _subtoken_policy(model, resolved_id: str) -> str:
    """
    FIRST subtoken for encoder/MLM families; LAST for decoder/causal families.
    """
    t = (getattr(model.config, "model_type", "") or "").lower()
    enc_first = {"bert", "roberta", "albert", "electra", "distilbert", "deberta", "deberta-v2", "camembert"}
    dec_last  = {"gpt2", "gpt_neo", "gpt_neox", "gptj", "bloom", "opt", "llama", "falcon", "mistral", "xglm", "replit"}
    if t in dec_last or "gpt" in resolved_id.lower():  # GPT-like → last piece
        return "last"
    if t in enc_first:                                  # BERT-like → first piece
        return "first"
    return "first"

def _load_tok_and_model(model_id: str):
    """Fast tokenizer (word_ids) + model; robust for GPT‑2; offline‑friendly."""
    tried: List[Tuple[str, str]] = []

    def _try(mid: str):
        tok = AutoTokenizer.from_pretrained(mid, use_fast=True, add_prefix_space=True, local_files_only=LOCAL_ONLY)
        if getattr(tok, "padding_side", None) != "right":
            tok.padding_side = "right"
        if tok.pad_token is None and getattr(tok, "eos_token", None) is not None:
            tok.pad_token = tok.eos_token  # GPT‑2 has no pad_token by default
        mdl = AutoModel.from_pretrained(mid, output_hidden_states=True, local_files_only=LOCAL_ONLY)
        if getattr(mdl.config, "pad_token_id", None) is None and tok.pad_token_id is not None:
            mdl.config.pad_token_id = tok.pad_token_id
        return tok, mdl

    order = [model_id]
    if model_id.lower() in {"gpt2", "gpt-2"}:
        order += ["openai-community/gpt2", "distilgpt2"]

    last_err = None
    for mid in order:
        try:
            tok, mdl = _try(mid)
            mdl = mdl.eval().to(device)
            if device == "cuda": mdl.half()
            return tok, mdl, mid
        except Exception as e:
            tried.append((mid, repr(e))); last_err = e

    raise RuntimeError("Could not load tokenizer/model. Attempts:\n" + "\n".join(f" - {m}: {err}" for m, err in tried)) from last_err

def load_token_index(csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Return (df_all_sentences, token_index_df) for ALL tokens."""
    df = pd.read_csv(csv_path, usecols=["sentence_id","tokens"])
    df["sentence_id"] = df["sentence_id"].astype(str)
    df.tokens = df.tokens.apply(_to_list)

    rows = []
    for sid, toks in df[["sentence_id","tokens"]].itertuples(index=False):
        for wid in range(len(toks)):
            rows.append((sid, wid))
    word_df = pd.DataFrame(rows, columns=["sentence_id","word_id"])

    if TOKEN_CAP is not None and len(word_df) > TOKEN_CAP:
        word_df = word_df.sample(TOKEN_CAP, random_state=RAND_SEED).reset_index(drop=True)
    return df, word_df

def _create_layer_memmaps(L: int, N: int, D: int, base_dir: Path, tag: str) -> tuple[list[Path], list[np.memmap]]:
    base_dir.mkdir(parents=True, exist_ok=True)
    files, mms = [], []
    for l in range(L):
        fn = base_dir / f"{tag}_layer{l:02d}.mmap"
        mm = np.memmap(fn, dtype="float32", mode="w+", shape=(N, D))
        files.append(fn); mms.append(mm)
    return files, mms

def embed_to_memmaps(df_all: pd.DataFrame, token_df: pd.DataFrame, model_id: str,
                     batch_size: int = 1, out_dir: Path | None = None) -> tuple[list[Path], int, int, int, str]:
    """
    Embed ALL tokens and spill per-layer (N,D) memmaps.
    BERT-like -> FIRST subtoken; GPT-like -> LAST subtoken.
    """
    tokzr, model, resolved_id = _load_tok_and_model(model_id)
    policy = _subtoken_policy(model, resolved_id)
    print(f"• Subtoken policy for {resolved_id}: {policy}")

    token_df = token_df.copy()
    token_df["sentence_id"] = token_df["sentence_id"].astype(str)
    by_sid: Dict[str, List[Tuple[int,int]]] = {}
    for gidx, (sid, wid) in enumerate(token_df[["sentence_id","word_id"]].itertuples(index=False)):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))
    sids = list(by_sid.keys())
    df_sel = (df_all[df_all.sentence_id.isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    L = _num_hidden_layers(model) + 1
    D = _hidden_size(model)
    N = len(token_df)

    tag = resolved_id.split("/")[-1]
    store_dir = (OUT_DIR / f"{tag}_memmaps") if out_dir is None else out_dir
    fns, mms = _create_layer_memmaps(L, N, D, store_dir, f"{tag}_{policy}")
    filled = np.zeros(N, dtype=bool)

    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True

    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{tag}:{policy} (embed→memmap)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L, B, T, D)

            for b, sid in enumerate(batch_ids):
                mp: Dict[int, List[int]] = {}
                wids = enc_be.word_ids(b)
                if wids is None:
                    raise RuntimeError("Fast tokenizer required (word_ids() unavailable).")
                for tidx, wid in enumerate(wids):
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid)
                    if not toks:
                        continue
                    if policy == "first":
                        vec = h[:, b, toks[0], :]      # (L, D)
                    else:  # "last"
                        vec = h[:, b, toks[-1], :]     # (L, D)
                    for l in range(L):
                        mms[l][gidx, :] = vec[l]
                    filled[gidx] = True

            del enc_be, enc_t, out, h
            if device == "cuda": torch.cuda.empty_cache()

    if (~filled).any():
        print(f"⚠ Missing vectors for {int((~filled).sum())} tokens (skipped).")

    for mm in mms:
        mm.flush()
        del mm
    gc.collect()
    if device == "cuda": torch.cuda.empty_cache()

    return fns, L, N, D, tag

# ---------- per-layer, per-dim statistics (MEDIAN) ----------
def per_layer_dim_medians(memmap_paths: List[Path], L: int, N: int, D: int, use_abs: bool = True) -> np.ndarray:
    """Return medians per layer & dim: shape (L, D)."""
    meds = np.zeros((L, D), dtype=np.float32)
    for l in range(L):
        mm = np.memmap(memmap_paths[l], dtype="float32", mode="r", shape=(N, D))
        X = np.array(mm, copy=False)
        meds[l] = np.median(np.abs(X), axis=0) if use_abs else np.median(X, axis=0)
        del mm; gc.collect()
    return meds

# ---------- outlier counting + prev-layer overlap + magnitude (exceedance) ----------
def counts_prev_overlap_and_magnitude_quantile(meds: np.ndarray, q: float) -> tuple[np.ndarray, np.ndarray, float, np.ndarray]:
    """
    meds: (L, D) per-dim medians.
    q: global quantile in [0,1].
    Returns:
        counts        (L,)  : # of outlier dimensions at each layer
        overlaps_prev (L,)  : |OD_l ∩ OD_{l-1}| (with 0 for l=0)
        thr           float : global quantile threshold
        magnitude     (L,)  : mean exceedance of outliers at each layer, i.e. mean( med[l, idx] - thr )
    """
    L, D = meds.shape
    flat  = meds.reshape(-1)
    thr   = float(np.quantile(flat, q))

    od_sets = []
    counts = np.zeros(L, dtype=np.int32)
    overlaps_prev = np.zeros(L, dtype=np.int32)
    magnitude = np.zeros(L, dtype=np.float32)

    for l in range(L):
        idx = np.where(meds[l] >= thr)[0]
        od_sets.append(set(idx))
        counts[l] = idx.size
        magnitude[l] = float(np.mean(meds[l, idx] - thr)) if idx.size > 0 else 0.0
        if l > 0:
            overlaps_prev[l] = len(od_sets[l].intersection(od_sets[l-1]))
    return counts, overlaps_prev, thr, magnitude

# ============================ PLOTTING ============================
def make_plots(model_to_stats: Dict[str, dict], out_dir: Path = OUT_DIR):
    """
    model_to_stats[model_tag] = {
        'meds': (L,D),
        'thr': float,
        'counts': (L,),
        'overlaps_prev': (L,),
        'mag': (L,)
    }
    """

    # ---------- Figure 1: last-layer medians per model (with global quantile line) ----------
    fig, axes = plt.subplots(1, len(model_to_stats), figsize=(6.6 * len(model_to_stats), 4.8), sharey=True)
    if len(model_to_stats) == 1:
        axes = [axes]

    for ax, (tag, st) in zip(axes, model_to_stats.items()):
        meds = st["meds"]; thr = st["thr"]; L, D = meds.shape
        last = meds[-1]
        ax.scatter(np.arange(D), last, s=10, alpha=0.65, edgecolor="none")
        ax.axhline(thr, color="orange", linestyle="--", linewidth=1.5, label=f"global q={OUTLIER_QUANTILE:.2f}")
        # compress y-axis using a high quantile (but not below the threshold)
        y_max = float(max(np.quantile(last, FIRST_PLOT_Y_QUANT), thr*1.05))
        ax.set_ylim(0.0, y_max)
        ax.set_title(f"{tag} • last layer")
        ax.set_xlabel("dimension")
        ax.set_ylabel("median(|activation|)" if USE_ABS_MEDIAN else "median(activation)")
        ax.legend(frameon=False, fontsize="small")

    plt.tight_layout()
    out_png1 = out_dir / "last_layer_medians_both_quantile.pdf"
    plt.savefig(out_png1, dpi=220); plt.close(fig)
    print("✓ wrote", out_png1)

    # ---------- Figure 2: counts & overlaps w/previous + bubble size ∝ magnitude ----------
    plt.figure(figsize=(10.5, 5.8))

    # global scaling for bubble sizes
    all_mag = np.concatenate([st["mag"] for st in model_to_stats.values()])
    mag_min, mag_max = float(all_mag.min()), float(all_mag.max())
    mag_span = max(mag_max - mag_min, 1e-12)

    for tag, st in model_to_stats.items():
        L = st["meds"].shape[0]
        layers = np.arange(L)

        # lines
        plt.plot(layers, st["counts"], marker="o", linewidth=2.5, alpha=0.95, label=f"{tag}: #ODs")
        plt.plot(layers, st["overlaps_prev"], marker="s", linestyle="--", linewidth=2.5, alpha=0.95, label=f"{tag}: #OD ∩ prev")

        # bubbles sized by mean exceedance
        mags  = st["mag"]
        sizes = DOT_SIZE_MIN + (DOT_SIZE_MAX - DOT_SIZE_MIN) * ((mags - mag_min) / mag_span)
        plt.scatter(layers, st["counts"], s=sizes, alpha=0.55, edgecolor="none")

    plt.xlabel("layer")
    plt.ylabel(f"# outlier dims (≥ global q={OUTLIER_QUANTILE:.2f})")
    plt.title("Outlier dimensions per layer (global-quantile rule)\nSolid: count • Dashed: overlap with previous • Dot size ∝ exceedance")
    plt.legend(ncol=2, fontsize="small", frameon=False)
    plt.tight_layout()
    out_png2 = out_dir / "od_counts_per_layer_both_quantile_bubbles.pdf"
    plt.savefig(out_png2, dpi=220); plt.close()
    print("✓ wrote", out_png2)

# ============================ DRIVER ============================
def run():
    df_all, token_df = load_token_index(CSV_PATH)
    print(f"✓ Total tokens: {len(token_df):,}")

    model_to_stats: Dict[str, dict] = {}
    for model_id in MODELS:
        print(f"\n=== Embedding all tokens • {model_id} (auto first/last subtoken) ===")
        mmap_paths, L, N, D, tag = embed_to_memmaps(df_all, token_df, model_id, batch_size=BATCH_SIZE)
        print(f"Shapes: layers={L}, tokens={N}, dim={D}")

        meds = per_layer_dim_medians(mmap_paths, L, N, D, use_abs=USE_ABS_MEDIAN)
        counts, overlaps_prev, thr, mag = counts_prev_overlap_and_magnitude_quantile(meds, OUTLIER_QUANTILE)
        print(f"{tag}: global q={OUTLIER_QUANTILE:.2f}  → threshold={thr:.6f}  • last-layer #OD={counts[-1]}")

        model_to_stats[tag] = {"meds": meds, "counts": counts, "overlaps_prev": overlaps_prev, "thr": thr, "mag": mag}

        # tidy up memmaps
        for p in mmap_paths:
            try: os.remove(p)
            except Exception: pass

    make_plots(model_to_stats, OUT_DIR)
    print("✓ done. See:", OUT_DIR.resolve())

if __name__ == "__main__":
    run()



KeyboardInterrupt

