In [1]:
from __future__ import annotations
import os, ast, random, inspect
from pathlib import Path
from typing import Dict, List

import numpy as np, pandas as pd, torch
import torch.utils.data as torchdata
from tqdm.auto import tqdm
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt, seaborn as sns
from transformers import AutoTokenizer, AutoModel, BertModel, AutoConfig

from IsoScore import IsoScore
from dadapy import Data
from skdim.id import MLE, MOM, TLE, CorrInt, FisherS, lPCA

2025-11-19 09:55:33.191035: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple, Callable

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel

# ============== Optional deps (gracefully skipped if not installed) ==============
HAS_DADAPY = False
try:
    from dadapy import Data  # DADApy ID estimators (TwoNN, GRIDE)
    HAS_DADAPY = True
except Exception:
    pass

HAS_SKDIM = False
try:
    from skdim.id import (
        MOM, TLE, CorrInt, FisherS, lPCA,
        MLE, DANCo, ESS, MiND_ML, MADA, KNN
    )
    HAS_SKDIM = True
except Exception:
    pass

# IsoScore: library if available, else a simple monotone fallback
try:
    from isoscore import IsoScore
    _HAS_ISOSCORE = True
except Exception:
    _HAS_ISOSCORE = False
    class _IsoScoreFallback:
        @staticmethod
        def IsoScore(X: np.ndarray) -> float:
            C = np.cov(X.T, ddof=0)
            ev = np.linalg.eigvalsh(C)
            if ev.mean() <= 0 or ev[-1] <= 0:
                return 0.0
            return float(np.clip(ev.mean() / (ev[-1] + 1e-9), 0.0, 1.0))
    IsoScore = _IsoScoreFallback()

# =============================== CONFIG ===============================
CSV_PATH      = "en_ewt-ud-train_sentences.csv"
LENGTH_COL    = "length"                # your per-sentence list[int] column

# - BERT:  BASELINE="bert-base-uncased", WORD_REP_MODE="first" (or "mean"/"last")
# - GPT-2: BASELINE="gpt2",             WORD_REP_MODE="last"  (or "mean")
BASELINE      = "bert-base-uncased"
WORD_REP_MODE = "first"

# Per-class token cap applied before embedding (larger than any class => no cap).
RAW_MAX_PER_CLASS = 253_700
# Bootstrap replicates for cheap (FAST_ONCE) vs expensive (HEAVY_ONCE) metrics.
N_BOOTSTRAP_FAST   = 50
N_BOOTSTRAP_HEAVY  = 200
# Rows drawn per replicate: M = min(cap, class size).
FAST_BS_MAX_SAMP_PER_CLASS  = 253_700  # M ~ N for classic bootstrap
HEAVY_BS_MAX_SAMP_PER_CLASS = 5000

# Partition-isotropy ("pfI") settings used by _pfI_once: number of random unit
# directions, and the lower/upper percentiles of log Z used instead of min/max.
PFI_DIRS = 128
PFI_Q_LO = 5.0
PFI_Q_HI = 95.0

LENGTH_MAX_CLASS = 10          # lengths above this are bucketed into this class ("10+")
EXCLUDE_ZERO_LENGTH = True

# Repro / device
RAND_SEED=42
PLOT_DIR = Path("results_LENGTH"); PLOT_DIR.mkdir(exist_ok=True, parents=True)
CSV_DIR  = Path("tables_LENGTH") / "length_bootstrap"; CSV_DIR.mkdir(exist_ok=True, parents=True)

BATCH_SIZE = 1
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda": torch.backends.cudnn.benchmark = True

sns.set_style("darkgrid")
plt.rcParams["figure.dpi"] = 120
EPS = 1e-12   # guards log(0) / division by zero in the spectral metrics

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _center(X: np.ndarray) -> np.ndarray:
    return X - X.mean(0, keepdims=True)

def _eigvals_from_X(X: np.ndarray) -> np.ndarray:
    Xc = _center(X.astype(np.float32, copy=False))
    try:
        _, S, _ = np.linalg.svd(Xc, full_matrices=False)
        lam = (S**2).astype(np.float64)
        lam.sort()
        return lam[::-1]
    except Exception:
        return np.array([], dtype=np.float64)

def _jitter_unique(X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    try:
        if np.unique(X, axis=0).shape[0] < X.shape[0]:
            X = X + np.random.normal(scale=eps, size=X.shape).astype(X.dtype)
    except Exception:
        pass
    return X

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)
    if n is None: raise ValueError("Cannot determine number of hidden layers")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)
    if d is None: raise ValueError("Cannot determine hidden size")
    return int(d)

# ========= Per-subsample single-value compute functions (your original set) =========
# --- Isotropy (fast) ---
def _iso_once(X: np.ndarray) -> float:
    return float(IsoScore.IsoScore(X))

def _spect_once(X: np.ndarray) -> float:
    ev = np.linalg.eigvalsh(np.cov(X.T, ddof=0))
    return float(ev[-1] / (ev.mean() + 1e-9))

def _rand_once(X: np.ndarray, K: int = 2000) -> float:
    n = X.shape[0]
    if n < 2: return np.nan
    rng = np.random.default_rng()
    K_eff = min(K, (n*(n-1))//2)
    i = rng.integers(0, n, size=K_eff)
    j = rng.integers(0, n, size=K_eff)
    same = i == j
    if same.any():
        j[same] = rng.integers(0, n, size=same.sum())
    A, B = X[i], X[j]
    num = np.sum(A*B, axis=1)
    den = (np.linalg.norm(A, axis=1)*np.linalg.norm(B, axis=1) + 1e-9)
    return float(np.mean(np.abs(num/den)))

def _sf_once(X: np.ndarray) -> float:
    """Spectral flatness: geometric mean / arithmetic mean of the scatter eigenvalues (NaN if none)."""
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    geo_mean = np.exp(np.log(eigs + EPS).mean())
    arith_mean = float(eigs.mean() + EPS)
    return float(geo_mean / arith_mean)

def _pfI_once(X: np.ndarray) -> float:
    n, d = X.shape
    if n < 2: return np.nan
    rng = np.random.default_rng()
    U = rng.standard_normal((PFI_DIRS, d)).astype(np.float32)
    U /= np.linalg.norm(U, axis=1, keepdims=True) + 1e-9
    S = U @ X.T
    m = np.max(S, axis=1, keepdims=True)
    logZ = (m + np.log(np.sum(np.exp(S - m), axis=1, keepdims=True))).ravel()
    lo = np.percentile(logZ, PFI_Q_LO)
    hi = np.percentile(logZ, PFI_Q_HI)
    return float(np.exp(lo - hi))  # ~ min Z / max Z (robust)

def _vmf_kappa_once(X: np.ndarray) -> float:
    if X.shape[0] < 2: return np.nan
    Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-9)
    R = np.linalg.norm(Xn.mean(axis=0))
    d = Xn.shape[1]
    if R < 1e-9: return 0.0
    return float(max(R * (d - R**2) / (1.0 - R**2 + 1e-9), 0.0))

# --- Linear ID (fast, spectral) ---
def _pcaXX_once(X: np.ndarray, var_ratio: float) -> float:
    """Number of leading principal components whose cumulative variance first reaches ``var_ratio`` of the total."""
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    cumulative = np.cumsum(eigs)
    n_before = np.searchsorted(cumulative, cumulative[-1] * var_ratio)
    return float(n_before + 1)

def _pca95_once(X: np.ndarray) -> float:
    """Principal components needed for 95% of the variance."""
    return _pcaXX_once(X, var_ratio=0.95)

def _pca99_once(X: np.ndarray) -> float:
    """Principal components needed for 99% of the variance."""
    return _pcaXX_once(X, var_ratio=0.99)

def _erank_once(X: np.ndarray) -> float:
    """Effective rank: exp of the Shannon entropy of the normalised eigenvalue spectrum."""
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    probs = eigs / (eigs.sum() + EPS)
    entropy = -np.sum(probs * np.log(probs + EPS))
    return float(np.exp(entropy))

def _pr_once(X: np.ndarray) -> float:
    """Participation ratio: (sum of eigenvalues)^2 / sum of squared eigenvalues."""
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    total = eigs.sum()
    total_sq = (eigs ** 2).sum()
    return float(total ** 2 / (total_sq + EPS))

def _stable_rank_once(X: np.ndarray) -> float:
    """Stable rank: sum of eigenvalues divided by the largest eigenvalue."""
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    return float(eigs.sum() / (eigs.max() + EPS))

# --- Non-linear (heavy) ---
def _dadapy_twonn_once(X: np.ndarray) -> float:
    """TwoNN intrinsic dimension from DADApy on the (de-duplicated by jitter) rows of ``X``; NaN without DADApy."""
    if not HAS_DADAPY:
        return np.nan
    dataset = Data(coordinates=_jitter_unique(X))
    twonn_id, _twonn_err, _twonn_scale = dataset.compute_id_2NN()
    return float(twonn_id)

def _dadapy_gride_once(X: np.ndarray) -> float:
    """GRIDE intrinsic dimension from DADApy; NaN without DADApy.

    Distances are precomputed with ``maxk=64`` and GRIDE runs with ``range_max=64``.
    The last entry of the returned ID sequence is used (presumably the largest
    neighbourhood scale; confirm against the DADApy docs).
    """
    if not HAS_DADAPY:
        return np.nan
    dataset = Data(coordinates=_jitter_unique(X))
    dataset.compute_distances(maxk=64)
    gride_ids, _gride_errs, _gride_scales = dataset.return_id_scaling_gride(range_max=64)
    return float(gride_ids[-1])

def _skdim_factory(name: str):
    """Return a zero-argument constructor for the scikit-dimension estimator keyed by ``name``.

    Returns None when scikit-dimension is unavailable or ``name`` is unknown. The
    ``lpca*`` keys share ``lPCA`` but pass different constructor arguments.
    """
    if not HAS_SKDIM:
        return None
    estimators = {
        "mom": MOM, "tle": TLE, "corrint": CorrInt, "fishers": FisherS,
        "lpca": lPCA, "lpca95": lPCA, "lpca99": lPCA,
        "mle": MLE, "danco": DANCo, "ess": ESS, "mind_ml": MiND_ML,
        "mada": MADA, "knn": KNN,
    }
    estimator_cls = estimators.get(name)
    if estimator_cls is None:
        return None
    ctor_kwargs = {
        "lpca": {"ver": "FO"},
        "lpca95": {"ver": "ratio", "alphaRatio": 0.95},
        "lpca99": {"ver": "ratio", "alphaRatio": 0.99},
    }.get(name, {})

    def _builder():
        return estimator_cls(**ctor_kwargs)
    return _builder

def _skdim_once_builder(name: str) -> Callable[[np.ndarray], float] | None:
    """Wrap a scikit-dimension estimator as ``X -> float`` (fitted ``dimension_``, NaN if absent); None if unavailable."""
    make_estimator = _skdim_factory(name)
    if make_estimator is None:
        return None

    def _once(X: np.ndarray) -> float:
        estimator = make_estimator()
        estimator.fit(_jitter_unique(X))
        return float(getattr(estimator, "dimension_", np.nan))
    return _once

# Registries (FULL set preserved)
# FAST_ONCE: cheap metrics (isotropy + linear/spectral ID). The driver bootstraps these
# with N_BOOTSTRAP_FAST replicates of up to FAST_BS_MAX_SAMP_PER_CLASS rows per class.
FAST_ONCE: Dict[str, Callable[[np.ndarray], float]] = {
    # Isotropy
    "iso": _iso_once, "spect": _spect_once, "rand": _rand_once,
    "sf": _sf_once, "pfI": _pfI_once, "vmf_kappa": _vmf_kappa_once,
    # Linear ID
    "pca95": _pca95_once, "pca99": _pca99_once,
    "erank": _erank_once, "pr": _pr_once, "stable_rank": _stable_rank_once,
}
# HEAVY_ONCE: neighbour-based / local ID estimators, bootstrapped with N_BOOTSTRAP_HEAVY
# replicates of up to HEAVY_BS_MAX_SAMP_PER_CLASS rows. skdim entries are None when the
# estimator is unavailable (the driver skips them). The DADApy entries are always
# callables and return NaN when DADApy is missing.
HEAVY_ONCE: Dict[str, Callable[[np.ndarray], float] | None] = {
    # Non-linear / local
    "twonn": _dadapy_twonn_once, "gride": _dadapy_gride_once,
    "mom": _skdim_once_builder("mom"), "tle": _skdim_once_builder("tle"),
    "corrint": _skdim_once_builder("corrint"), "fishers": _skdim_once_builder("fishers"),
    "lpca": _skdim_once_builder("lpca"), "lpca95": _skdim_once_builder("lpca95"),
    "lpca99": _skdim_once_builder("lpca99"),
    "mle": _skdim_once_builder("mle"), "danco": _skdim_once_builder("danco"),
    "ess": _skdim_once_builder("ess"), "mind_ml": _skdim_once_builder("mind_ml"),
    "mada": _skdim_once_builder("mada"), "knn": _skdim_once_builder("knn"),
}
# Human-readable metric names, used for plot y-labels and titles.
LABELS = {
    # Isotropy
    "iso":"IsoScore","spect":"Spectral Ratio","rand":"RandCos |μ|",
    "sf":"Spectral Flatness","pfI":"Partition Isotropy I","vmf_kappa":"vMF κ",
    # Linear ID
    "pca95":"PCs@95%","pca99":"PCs@99%","erank":"Effective Rank","pr":"Participation Ratio","stable_rank":"Stable Rank",
    # Non-linear / local
    "twonn":"TwoNN ID","gride":"GRIDE","mom":"MOM","tle":"TLE","corrint":"CorrInt","fishers":"FisherS",
    "lpca":"lPCA FO","lpca95":"lPCA 0.95","lpca99":"lPCA 0.99",
    "mle":"MLE","danco":"DANCo","ess":"ESS","mind_ml":"MiND-ML","mada":"MADA","knn":"KNN",
}
# Metrics the driver actually runs. NOTE(review): currently only GRIDE. Use
# list(FAST_ONCE) + list(HEAVY_ONCE) to run the full registry.
ALL_METRICS=[ "gride"]
# =============================== DATA: use existing length column ===============================
def _pick_length_col(df: pd.DataFrame) -> str:
    """Return the first candidate length-column name present in ``df``; raise ValueError if none is."""
    candidates = (LENGTH_COL, "token_length", "lengths", "len", "LEN", "Length")
    present = [name for name in candidates if name in df.columns]
    if not present:
        raise ValueError(f"No length column found. Tried: {LENGTH_COL}, token_length, lengths, len, LEN, Length.")
    return present[0]

def load_length_from_column(csv_path: str,
                            length_max: int = LENGTH_MAX_CLASS,
                            exclude_zero: bool = EXCLUDE_ZERO_LENGTH):
    """Build token-level rows labelled with a bucketed per-word length class.

    The CSV must provide ``sentence_id``, ``tokens`` (list[str], possibly stored as its
    string repr) and a per-word integer list in the first matching length column
    (see ``_pick_length_col``).

    Args:
        csv_path: path to the sentence-level CSV.
        length_max: larger values are clamped to this class (so it stands for "length_max+").
        exclude_zero: drop words whose length value is 0.

    Returns:
        (df_sent, df_tok): unique sentences with their token lists, and one row per kept
        word with columns sentence_id, word_id (0-based), length_class (str), word.
        Negative and non-integer length values are skipped.

    Raises:
        ValueError: if no token rows survive the filtering.
    """
    df_all = pd.read_csv(csv_path)
    len_col = _pick_length_col(df_all)
    df = df_all[["sentence_id", "tokens", len_col]].copy()
    df["sentence_id"] = df["sentence_id"].astype(str)
    df["tokens"] = df["tokens"].apply(_to_list)
    df[len_col] = df[len_col].apply(_to_list)

    rows = []
    for sid, toks, lens in df[["sentence_id", "tokens", len_col]].itertuples(index=False):
        # zip() stops at the shorter of the two lists.
        for wid, (word, raw_len) in enumerate(zip(toks, lens)):
            try:
                k = int(raw_len)
            except Exception:
                continue
            if k < 0 or (exclude_zero and k == 0):
                continue
            k = min(k, length_max)          # bucket upper tail into 'length_max'
            if exclude_zero and k == 0:     # only reachable when length_max <= 0
                continue
            rows.append((sid, wid, str(k), word))

    df_tok = pd.DataFrame(rows, columns=["sentence_id", "word_id", "length_class", "word"])
    df_sent = df[["sentence_id", "tokens"]].drop_duplicates("sentence_id")
    if df_tok.empty:
        raise ValueError("No token rows constructed—check that your length column contains integer lists.")
    return df_sent, df_tok

# =============================== Tokenizer/Model loader (robust for GPT‑2) ===============================
def _load_tok_and_model(baseline: str):
    """
    Load a tokenizer + base model (``AutoModel``) for ``baseline``, ready for inference.

    Robust loader:
    - Requests a fast tokenizer (``use_fast=True``), needed for ``.word_ids()``.
    - Passes ``add_prefix_space=True`` to the tokenizer (GPT-2's BPE tokenizer needs this
      for pre-split word input).
    - Forces right padding; if the tokenizer has no pad token, reuses its EOS token
      (GPT-2 case), and copies the pad id into the model config when that is unset.
    - If loading fails and ``baseline`` is "gpt2"/"gpt-2" (case-insensitive), retries
      'openai-community/gpt2' and then 'gpt2'.

    Returns:
        (tok, mdl, model_id): mdl is in eval mode on ``device`` (cast to fp16 on CUDA),
        built with ``output_hidden_states=True``; model_id is the id that loaded.

    Raises:
        RuntimeError: listing every attempted id and its error if nothing could be loaded.
    """
    model_id = baseline
    tok = None; mdl = None
    tried = []  # (model id, error message) for every failed attempt

    def _try_load(mid: str):
        tok = AutoTokenizer.from_pretrained(mid, use_fast=True, add_prefix_space=True)
        # Ensure right padding (GPT-2 tips)
        if getattr(tok, "padding_side", None) != "right":
            tok.padding_side = "right"
        # For GPT-2 and similar: add pad token if missing
        if tok.pad_token is None and getattr(tok, "eos_token", None) is not None:
            tok.pad_token = tok.eos_token
        # output_hidden_states=True so forward passes return every layer's states
        mdl = AutoModel.from_pretrained(mid, output_hidden_states=True)
        # Propagate pad id to model if absent
        if getattr(mdl.config, "pad_token_id", None) is None and tok.pad_token_id is not None:
            mdl.config.pad_token_id = tok.pad_token_id
        return tok, mdl

    # try main id
    try:
        tok, mdl = _try_load(model_id)
    except Exception as e1:
        tried.append((model_id, str(e1)))
        # GPT-2 robust fallback namespace (note: for baseline "gpt2" the plain id is retried once more)
        if model_id.lower() in {"gpt2", "gpt-2"}:
            for alt in ["openai-community/gpt2", "gpt2"]:
                try:
                    tok, mdl = _try_load(alt)
                    model_id = alt
                    break
                except Exception as e2:
                    tried.append((alt, str(e2)))
        if tok is None or mdl is None:
            # surface useful diagnostics
            raise RuntimeError(
                "Failed to load tokenizer/model. Attempts:\n" +
                "\n".join([f" - {mid}: {err}" for mid, err in tried])
            )

    mdl = mdl.eval().to(device)
    if device == "cuda":
        mdl.half()  # fp16 weights on GPU
    return tok, mdl, model_id

# =============================== EMBEDDING (BERT & GPT‑2) ===============================
def embed_subset(df_sent: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray]:
    """Embed the words listed in ``subset_df`` at every hidden layer of ``baseline``.

    Args:
        df_sent: one row per sentence with ``sentence_id`` and ``tokens`` (list of words).
        subset_df: rows with ``sentence_id`` and 0-based ``word_id``; row *position*
            defines the output index.
        baseline: model id passed to ``_load_tok_and_model``.
        word_rep_mode: sub-token pooling per word: "first", "last" or "mean".
        batch_size: sentences per forward pass.

    Returns:
        reps: float16 array (num_hidden_layers + 1, len(subset_df), hidden_size). Layer 0
            is the embedding output, and rows that could not be mapped stay zero.
        filled: bool array (len(subset_df),) marking rows that received a vector.

    Raises:
        ValueError: for an unknown ``word_rep_mode`` (checked before loading the model).

    The input DataFrames are not modified (sentence ids are converted to str on copies).
    """
    if word_rep_mode not in ("first", "last", "mean"):
        raise ValueError("WORD_REP_MODE must be {'first','last','mean'} (GPT‑2: 'last' or 'mean').")

    sent_ids = df_sent["sentence_id"].astype(str)
    sub_ids = subset_df["sentence_id"].astype(str)

    # sid -> list[(row position in subset_df == output index, word_id)]
    by_sid: Dict[str, List[Tuple[int, int]]] = {}
    for gidx, (sid, wid) in enumerate(zip(sub_ids, subset_df["word_id"])):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))

    sids = list(by_sid.keys())
    df_sel = (df_sent.assign(sentence_id=sent_ids)
              .loc[lambda d: d["sentence_id"].isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    tokzr, model, model_id = _load_tok_and_model(baseline)

    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True)
    # Use add_prefix_space when the tokenizer call accepts it explicitly (GPT-2-friendly)
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True

    L = _num_hidden_layers(model) + 1   # include embedding layer
    D = _hidden_size(model)
    N = len(subset_df)

    reps   = np.zeros((L, N, D), np.float16)
    filled = np.zeros(N, dtype=bool)

    use_amp = device == "cuda"
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast(enabled) alias.
    with torch.no_grad(), torch.autocast(device_type=device, enabled=use_amp):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{model_id} (embed subset)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # word index -> positions of its sub-tokens (fast tokenizers expose word_ids())
                mp: Dict[int, List[int]] = {}
                for tidx, wid in enumerate(enc_be.word_ids(b)):
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid)
                    if not toks:
                        continue
                    if word_rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif word_rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    else:  # "mean" (validated above)
                        vec = h[:, b, toks, :].mean(axis=1)
                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    filled[gidx] = True

            # free batch buffers
            del enc_be, enc_t, out, h
            if device == "cuda": torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing: print(f"⚠ Missing vectors for {missing} of {N} tokens")
    del model; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    return reps, filled

# =============================== BOOTSTRAP CORE ===============================
def _bs_layer_loop(rep_sub: np.ndarray, M: int, n_reps: int, compute_once: Callable[[np.ndarray], float]):
    L, N, D = rep_sub.shape
    rng = np.random.default_rng(RAND_SEED)
    A = np.full((n_reps, L), np.nan, np.float32)
    for r in range(n_reps):
        idx = rng.integers(0, N, size=M)
        for l in range(L):
            X = rep_sub[l, idx].astype(np.float32, copy=False)
            try:
                A[r, l] = float(compute_once(X))
            except Exception:
                A[r, l] = np.nan
    mu = np.nanmean(A, axis=0).astype(np.float32)
    lo = np.nanpercentile(A, 2.5, axis=0).astype(np.float32)
    hi = np.nanpercentile(A, 97.5, axis=0).astype(np.float32)
    return mu, lo, hi

# =============================== SAVE / PLOT ===============================
def save_metric_csv_all_classes(metric: str,
                                class_to_stats: Dict[str, Dict[str, np.ndarray]],
                                layers: np.ndarray,
                                baseline: str,
                                subset_name: str = "raw"):
    """Write a long-format CSV (one row per class x layer) for ``metric`` under ``CSV_DIR``.

    Each row holds the mean and CI bounds (NaN when missing or non-finite), the class
    token count and run metadata. ``baseline`` is stored verbatim in the ``model``
    column. Path separators in it (e.g. "openai-community/gpt2") are replaced by "_" in
    the file name, so the file is written directly inside ``CSV_DIR``.

    Returns:
        Path of the written CSV.
    """
    def _finite_or_nan(arr, i):
        return float(arr[i]) if isinstance(arr, np.ndarray) and np.isfinite(arr[i]) else np.nan

    rows = []
    for c, stats in class_to_stats.items():
        mu, lo, hi = stats["mean"], stats.get("lo"), stats.get("hi")
        for l, val in enumerate(mu):
            rows.append({
                "subset": subset_name, "model": baseline, "feature": "length",
                "class": c, "metric": metric, "layer": int(layers[l]),
                "mean": float(val) if np.isfinite(val) else np.nan,
                "ci_low": _finite_or_nan(lo, l),
                "ci_high": _finite_or_nan(hi, l),
                "n_tokens": int(stats.get("n", 0)), "word_rep_mode": WORD_REP_MODE,
                "source_csv": Path(CSV_PATH).name,
            })
    df = pd.DataFrame(rows)
    safe_baseline = str(baseline).replace("/", "_").replace("\\", "_")
    out = CSV_DIR / f"length_{subset_name}_{metric}_{safe_baseline}.csv"
    df.to_csv(out, index=False)
    return out

def plot_metric_with_ci(class_to_stats: Dict[str, Dict[str, np.ndarray]],
                        layers: np.ndarray, metric: str, title: str, out_path: Path,
                        palette: Dict[str, Tuple[float, float, float]] | None = None):
    """Plot one mean-vs-layer curve per class with a shaded CI band and save it to ``out_path``.

    Classes are drawn in numeric order (keys must be int-like strings). Classes whose
    mean is None or all-NaN are skipped, and a band is drawn only when both bounds are
    arrays and ``lo`` is not all-NaN. The figure is closed after saving.
    """
    fig, ax = plt.subplots(figsize=(10.5, 5.5))
    for cls_key in sorted(class_to_stats, key=int):
        stats = class_to_stats[cls_key]
        mu = stats["mean"]
        lo = stats.get("lo")
        hi = stats.get("hi")
        if mu is None or np.all(np.isnan(mu)):
            continue
        colour = palette.get(cls_key) if isinstance(palette, dict) else None
        ax.plot(layers, mu, label=cls_key, lw=1.8, color=colour)
        has_band = isinstance(lo, np.ndarray) and isinstance(hi, np.ndarray) and not np.all(np.isnan(lo))
        if has_band:
            ax.fill_between(layers, lo, hi, alpha=0.15, color=colour)
    ax.set_xlabel("Layer")
    ax.set_ylabel(LABELS.get(metric, metric.upper()))
    ax.set_title(title)
    ax.legend(ncol=6, fontsize="small", title=f"Length ( {LENGTH_MAX_CLASS} = {LENGTH_MAX_CLASS}+ )", frameon=False)
    fig.tight_layout()
    fig.savefig(out_path, dpi=220)
    plt.close(fig)

# =============================== DRIVER ===============================
def sample_raw(df_tok: pd.DataFrame, per_class_cap: int = RAW_MAX_PER_CLASS) -> pd.DataFrame:
    """Draw up to ``per_class_cap`` rows per ``length_class`` (without replacement, seeded) and concatenate them."""
    per_class = [
        group.sample(min(len(group), per_class_cap), random_state=RAND_SEED, replace=False)
        for _, group in df_tok.groupby("length_class", sort=False)
    ]
    return pd.concat(per_class, ignore_index=True)

def make_length_palette(classes: List[str]) -> Dict[str, Tuple[float, float, float]]:
    """Map each int-like class label (as str) to a seaborn "rocket" colour, assigned in ascending numeric order."""
    ordered = sorted(int(c) for c in classes)
    colours = sns.color_palette("rocket", len(ordered))
    return {str(value): colour for value, colour in zip(ordered, colours)}

def run_length_from_col_pipeline():
    """Run the length-class analysis end to end for every metric in ``ALL_METRICS``.

    Steps:
    1. Load token rows and length classes from ``CSV_PATH``.
    2. Cap rows per class.
    3. Embed all kept tokens once with ``BASELINE``.
    4. Per metric, bootstrap each class that has at least 3 embedded tokens across all
       layers, then write a CSV (``save_metric_csv_all_classes``) and a PNG under ``PLOT_DIR``.
    """
    # 1) Load token lists + length classes from existing column
    df_sent, len_df = load_length_from_column(
        CSV_PATH, length_max=LENGTH_MAX_CLASS, exclude_zero=EXCLUDE_ZERO_LENGTH
    )
    classes = sorted(len_df.length_class.unique(), key=lambda s: int(s))
    palette = make_length_palette(classes)
    print(f"✓ corpus ready — {len(len_df):,} tokens across length classes {classes}")

    # 2) Optional per-class cap (currently unlimited for fast metrics)
    raw_df = sample_raw(len_df, RAW_MAX_PER_CLASS)
    print("Sample sizes per length (raw cap):")
    # Class labels are strings: sort numerically so '10' follows '9' (plain sort_index puts it after '1').
    counts = raw_df.length_class.value_counts().sort_index(key=lambda idx: idx.astype(int))
    print(counts.to_dict())

    # 3) Embed once
    reps, filled = embed_subset(df_sent, raw_df, BASELINE, WORD_REP_MODE, BATCH_SIZE)
    raw_df = raw_df.reset_index(drop=True).loc[filled].reset_index(drop=True)
    cls_arr = raw_df.length_class.values
    L = reps.shape[0]; layers = np.arange(L)
    print(f"✓ embedded {len(raw_df):,} tokens  • layers={L}")

    # Hub ids like "openai-community/gpt2" contain "/", which would otherwise be read as
    # a sub-directory in output file names.
    model_tag = str(BASELINE).replace("/", "_").replace("\\", "_")

    # 4) Metric loop (over ALL_METRICS)
    for metric in ALL_METRICS:
        print(f"\n→ Computing metric: {metric} …")
        compute_once = FAST_ONCE.get(metric) or HEAVY_ONCE.get(metric)
        if compute_once is None:
            print(f"  (skipping {metric}: estimator unavailable)")
            continue

        is_fast = metric in FAST_ONCE
        n_bs = N_BOOTSTRAP_FAST if is_fast else N_BOOTSTRAP_HEAVY
        Mcap = FAST_BS_MAX_SAMP_PER_CLASS if is_fast else HEAVY_BS_MAX_SAMP_PER_CLASS

        class_results: Dict[str, Dict[str, np.ndarray]] = {}
        for c in classes:
            idx = np.where(cls_arr == c)[0]
            if idx.size < 3:
                continue
            sub = reps[:, idx]  # (L, n_c, D)
            Nc = sub.shape[1]
            M = min(Mcap, Nc)
            mu, lo, hi = _bs_layer_loop(sub, M, n_bs, compute_once)
            class_results[c] = {"mean": mu, "lo": lo, "hi": hi, "n": int(Nc)}

        # Save + plot immediately
        save_metric_csv_all_classes(metric, class_results, layers, BASELINE, subset_name="raw")
        plot_metric_with_ci(class_results, layers, metric,
                            title=f"{LABELS.get(metric, metric.upper())} • {BASELINE}",
                            out_path=PLOT_DIR / f"length_raw_{metric}_{model_tag}.png",
                            palette=palette)
        print(f"  ✓ saved: CSV= {CSV_DIR}/length_raw_{metric}_{model_tag}.csv  "
              f"plot= {PLOT_DIR}/length_raw_{metric}_{model_tag}.png")

        del class_results; gc.collect()
        if device == "cuda": torch.cuda.empty_cache()

    del reps; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    print("\n✓ done (incremental outputs produced per metric).")

if __name__ == "__main__":
    run_length_from_col_pipeline()


✓ corpus ready — 194,916 tokens across length classes ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
Sample sizes per length (raw cap):
{'1': 29626, '10': 7492, '2': 29918, '3': 36097, '4': 33008, '5': 18892, '6': 14274, '7': 12188, '8': 7895, '9': 5526}


bert-base-uncased (embed subset): 100%|██| 10067/10067 [01:21<00:00, 123.64it/s]


✓ embedded 194,916 tokens  • layers=13

→ Computing metric: gride …
  ✓ saved: CSV= tables_LENGTH/length_bootstrap/length_raw_gride_bert-base-uncased.csv  plot= results_LENGTH/length_raw_gride_bert-base-uncased.png

✓ done (incremental outputs produced per metric).


In [2]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple, Callable

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel
sns.set_style("darkgrid")  # seaborn theme for every figure drawn after this cell
# ============== Optional deps (gracefully skipped if not installed) ==============
# Each HAS_* flag records whether the import succeeded. The estimator wrappers below check
# it: the DADApy wrappers return NaN and the scikit-dimension factory returns None when
# the dependency is missing. Import errors are deliberately swallowed (best effort).
HAS_DADAPY = False
try:
    from dadapy import Data  # DADApy ID estimators (TwoNN, GRIDE)
    HAS_DADAPY = True
except Exception:
    pass

HAS_SKDIM = False
try:
    from skdim.id import (
        MOM, TLE, CorrInt, FisherS, lPCA,
        MLE, DANCo, ESS, MiND_ML, MADA, KNN
    )
    HAS_SKDIM = True
except Exception:
    pass

# IsoScore: library if available, else a simple monotone fallback
try:
    from isoscore import IsoScore
    _HAS_ISOSCORE = True
except Exception:
    _HAS_ISOSCORE = False
    class _IsoScoreFallback:
        @staticmethod
        def IsoScore(X: np.ndarray) -> float:
            C = np.cov(X.T, ddof=0)
            ev = np.linalg.eigvalsh(C)
            if ev.mean() <= 0 or ev[-1] <= 0:
                return 0.0
            return float(np.clip(ev.mean() / (ev[-1] + 1e-9), 0.0, 1.0))
    IsoScore = _IsoScoreFallback()

# =============================== CONFIG ===============================
CSV_PATH      = "en_ewt-ud-train_sentences.csv"
LENGTH_COL    = "length"

BASELINE      = "gpt2"
WORD_REP_MODE = "last"

# Per-class token cap applied before embedding.
RAW_MAX_PER_CLASS = 253_771
# Bootstrap replicates for cheap vs expensive metrics, and rows per replicate (M = min(cap, class size)).
N_BOOTSTRAP_FAST   = 50
N_BOOTSTRAP_HEAVY  = 200
FAST_BS_MAX_SAMP_PER_CLASS  = 184_870
HEAVY_BS_MAX_SAMP_PER_CLASS = 5000

# Partition-isotropy ("pfI") settings used by _pfI_once: number of random unit
# directions, and the lower/upper percentiles of log Z used instead of min/max.
PFI_DIRS = 128
PFI_Q_LO = 5.0
PFI_Q_HI = 95.0

LENGTH_MAX_CLASS   = 10
EXCLUDE_ZERO_LENGTH = True

# NEW: drop tokens at 1-based sentence index = 1 (the first word)
# UD CoNLL-U IDs start at 1; our 'wid' below is 0-based, so index==1 <=> wid==0
EXCLUDE_INDEX_1 = True

RAND_SEED=42
PLOT_DIR = Path("results_LENGTH_no_index"); PLOT_DIR.mkdir(exist_ok=True, parents=True)
CSV_DIR  = Path("tables_LENGTH_no_index") / "length_bootstrap"; CSV_DIR.mkdir(exist_ok=True, parents=True)

BATCH_SIZE = 1
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda": torch.backends.cudnn.benchmark = True

# Used by _sf_once/_erank_once/_pr_once/_stable_rank_once. This cell did not define it,
# so on a fresh kernel those metrics hit NameError (turned into NaN by the bootstrap loop).
EPS = 1e-12

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _center(X: np.ndarray) -> np.ndarray:
    return X - X.mean(0, keepdims=True)

def _eigvals_from_X(X: np.ndarray) -> np.ndarray:
    Xc = _center(X.astype(np.float32, copy=False))
    try:
        _, S, _ = np.linalg.svd(Xc, full_matrices=False)
        lam = (S**2).astype(np.float64)
        lam.sort()
        return lam[::-1]
    except Exception:
        return np.array([], dtype=np.float64)

def _jitter_unique(X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    try:
        if np.unique(X, axis=0).shape[0] < X.shape[0]:
            X = X + np.random.normal(scale=eps, size=X.shape).astype(X.dtype)
    except Exception:
        pass
    return X

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)
    if n is None: raise ValueError("Cannot determine number of hidden layers")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)
    if d is None: raise ValueError("Cannot determine hidden size")
    return int(d)

# ========= Per-subsample single-value compute functions (your original set) =========
# --- Isotropy (fast) ---
def _iso_once(X: np.ndarray) -> float:
    return float(IsoScore.IsoScore(X))

def _spect_once(X: np.ndarray) -> float:
    ev = np.linalg.eigvalsh(np.cov(X.T, ddof=0))
    return float(ev[-1] / (ev.mean() + 1e-9))

def _rand_once(X: np.ndarray, K: int = 2000) -> float:
    n = X.shape[0]
    if n < 2: return np.nan
    rng = np.random.default_rng()
    K_eff = min(K, (n*(n-1))//2)
    i = rng.integers(0, n, size=K_eff)
    j = rng.integers(0, n, size=K_eff)
    same = i == j
    if same.any():
        j[same] = rng.integers(0, n, size=same.sum())
    A, B = X[i], X[j]
    num = np.sum(A*B, axis=1)
    den = (np.linalg.norm(A, axis=1)*np.linalg.norm(B, axis=1) + 1e-9)
    return float(np.mean(np.abs(num/den)))

def _sf_once(X: np.ndarray) -> float:
    """Spectral flatness: geometric mean / arithmetic mean of the scatter eigenvalues (NaN if none)."""
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    geo_mean = np.exp(np.log(eigs + EPS).mean())
    arith_mean = float(eigs.mean() + EPS)
    return float(geo_mean / arith_mean)

def _pfI_once(X: np.ndarray) -> float:
    n, d = X.shape
    if n < 2: return np.nan
    rng = np.random.default_rng()
    U = rng.standard_normal((PFI_DIRS, d)).astype(np.float32)
    U /= np.linalg.norm(U, axis=1, keepdims=True) + 1e-9
    S = U @ X.T
    m = np.max(S, axis=1, keepdims=True)
    logZ = (m + np.log(np.sum(np.exp(S - m), axis=1, keepdims=True))).ravel()
    lo = np.percentile(logZ, PFI_Q_LO)
    hi = np.percentile(logZ, PFI_Q_HI)
    return float(np.exp(lo - hi))  # ~ min Z / max Z (robust)

def _vmf_kappa_once(X: np.ndarray) -> float:
    if X.shape[0] < 2: return np.nan
    Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-9)
    R = np.linalg.norm(Xn.mean(axis=0))
    d = Xn.shape[1]
    if R < 1e-9: return 0.0
    return float(max(R * (d - R**2) / (1.0 - R**2 + 1e-9), 0.0))

# --- Linear ID (fast, spectral) ---
def _pcaXX_once(X: np.ndarray, var_ratio: float) -> float:
    """Number of leading principal components whose cumulative variance first reaches ``var_ratio`` of the total."""
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    cumulative = np.cumsum(eigs)
    n_before = np.searchsorted(cumulative, cumulative[-1] * var_ratio)
    return float(n_before + 1)

def _pca95_once(X: np.ndarray) -> float:
    """Principal components needed for 95% of the variance."""
    return _pcaXX_once(X, var_ratio=0.95)

def _pca99_once(X: np.ndarray) -> float:
    """Principal components needed for 99% of the variance."""
    return _pcaXX_once(X, var_ratio=0.99)

def _erank_once(X: np.ndarray) -> float:
    """Effective rank: exp of the Shannon entropy of the normalised eigenvalue spectrum."""
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    probs = eigs / (eigs.sum() + EPS)
    entropy = -np.sum(probs * np.log(probs + EPS))
    return float(np.exp(entropy))

def _pr_once(X: np.ndarray) -> float:
    """Participation ratio: (sum of eigenvalues)^2 / sum of squared eigenvalues."""
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    total = eigs.sum()
    total_sq = (eigs ** 2).sum()
    return float(total ** 2 / (total_sq + EPS))

def _stable_rank_once(X: np.ndarray) -> float:
    """Stable rank: sum of eigenvalues divided by the largest eigenvalue."""
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    return float(eigs.sum() / (eigs.max() + EPS))

# --- Non-linear (heavy) ---
def _dadapy_twonn_once(X: np.ndarray) -> float:
    """TwoNN intrinsic dimension from DADApy on the (de-duplicated by jitter) rows of ``X``; NaN without DADApy."""
    if not HAS_DADAPY:
        return np.nan
    dataset = Data(coordinates=_jitter_unique(X))
    twonn_id, _twonn_err, _twonn_scale = dataset.compute_id_2NN()
    return float(twonn_id)

def _dadapy_gride_once(X: np.ndarray) -> float:
    """GRIDE intrinsic dimension from DADApy; NaN without DADApy.

    Distances are precomputed with ``maxk=64`` and GRIDE runs with ``range_max=64``.
    The last entry of the returned ID sequence is used (presumably the largest
    neighbourhood scale; confirm against the DADApy docs).
    """
    if not HAS_DADAPY:
        return np.nan
    dataset = Data(coordinates=_jitter_unique(X))
    dataset.compute_distances(maxk=64)
    gride_ids, _gride_errs, _gride_scales = dataset.return_id_scaling_gride(range_max=64)
    return float(gride_ids[-1])

def _skdim_factory(name: str):
    if not HAS_SKDIM: return None
    mapping = {
        "mom": MOM, "tle": TLE, "corrint": CorrInt, "fishers": FisherS,
        "lpca": lPCA, "lpca95": lPCA, "lpca99": lPCA,
        "mle": MLE, "danco": DANCo, "ess": ESS, "mind_ml": MiND_ML,
        "mada": MADA, "knn": KNN,
    }
    cls = mapping.get(name)
    if cls is None: return None
    def _builder():
        if name == "lpca":   return cls(ver="FO")
        if name == "lpca95": return cls(ver="ratio", alphaRatio=0.95)
        if name == "lpca99": return cls(ver="ratio", alphaRatio=0.99)
        return cls()
    return _builder

def _skdim_once_builder(name: str) -> Callable[[np.ndarray], float] | None:
    build = _skdim_factory(name)
    if build is None: return None
    def _once(X: np.ndarray) -> float:
        est = build()
        est.fit(_jitter_unique(X))
        return float(getattr(est, "dimension_", np.nan))
    return _once

# Registries (FULL set preserved)
# Registries (FULL set preserved).
# FAST_ONCE  : cheap isotropy / spectral scorers, X -> float.
# HEAVY_ONCE : neighbour-based intrinsic-dimension estimators. An entry is None when
#              its scikit-dimension backend is unavailable; the driver skips those.
FAST_ONCE: Dict[str, Callable[[np.ndarray], float]] = dict(
    # Isotropy
    iso=_iso_once, spect=_spect_once, rand=_rand_once,
    sf=_sf_once, pfI=_pfI_once, vmf_kappa=_vmf_kappa_once,
    # Linear ID
    pca95=_pca95_once, pca99=_pca99_once,
    erank=_erank_once, pr=_pr_once, stable_rank=_stable_rank_once,
)
HEAVY_ONCE: Dict[str, Callable[[np.ndarray], float] | None] = {
    # DADApy estimators
    "twonn": _dadapy_twonn_once,
    "gride": _dadapy_gride_once,
    # scikit-dimension estimators (built in this order)
    **{key: _skdim_once_builder(key) for key in (
        "mom", "tle", "corrint", "fishers",
        "lpca", "lpca95", "lpca99",
        "mle", "danco", "ess", "mind_ml", "mada", "knn",
    )},
}
# Human-readable axis label per metric key.
LABELS = {
    # Isotropy
    "iso": "IsoScore",
    "spect": "Spectral Ratio",
    "rand": "RandCos |μ|",
    "sf": "Spectral Flatness",
    "pfI": "Partition Isotropy I",
    "vmf_kappa": "vMF κ",
    # Linear ID
    "pca95": "PCs@95%",
    "pca99": "PCs@99%",
    "erank": "Effective Rank",
    "pr": "Participation Ratio",
    "stable_rank": "Stable Rank",
    # Non-linear / local
    "twonn": "TwoNN ID",
    "gride": "GRIDE",
    "mom": "MOM",
    "tle": "TLE",
    "corrint": "CorrInt",
    "fishers": "FisherS",
    "lpca": "lPCA FO",
    "lpca95": "lPCA 0.95",
    "lpca99": "lPCA 0.99",
    "mle": "MLE",
    "danco": "DANCo",
    "ess": "ESS",
    "mind_ml": "MiND-ML",
    "mada": "MADA",
    "knn": "KNN",
}
# Metrics the driver actually runs.
# NOTE(review): only GRIDE is enabled, although the registries hold the full set.
# Use list(FAST_ONCE) + list(HEAVY_ONCE) to run everything.
ALL_METRICS = ["gride"]
# =============================== DATA: use existing length column ===============================
def _pick_length_col(df: pd.DataFrame) -> str:
    for cand in [LENGTH_COL, "token_length", "lengths", "len", "LEN", "Length"]:
        if cand in df.columns:
            return cand
    raise ValueError(f"No length column found. Tried: {LENGTH_COL}, token_length, lengths, len, LEN, Length.")

def load_length_from_column(csv_path: str,
                            length_max: int = LENGTH_MAX_CLASS,
                            exclude_zero: bool = EXCLUDE_ZERO_LENGTH):
    """
    Build token-level rows labelled with a word-length class.

    Input CSV, per sentence:
      - sentence_id (str), tokens (list[str]), length (list[int]);
      - list cells may be stored as their string repr, which `_to_list` parses;
      - the length column is resolved by `_pick_length_col`.

    Each emitted row gets a string 'length_class': the length clipped to
    `length_max` (so by default "10" means 10+).

    Filtering:
      - negative and non-integer lengths are dropped;
      - zero lengths are dropped when `exclude_zero` is True, otherwise kept as
        class "0";
      - sentences whose tokens/length cells are not lists (e.g. empty CSV cells)
        are skipped;
      - the first word of each sentence (0-based word_id 0) is dropped when the
        global EXCLUDE_INDEX_1 is True.

    Returns (df_sent, df_tok):
      df_sent: one row per sentence_id with its tokens
      df_tok:  columns sentence_id, word_id, length_class, word
    Raises ValueError if no token rows were constructed.
    """
    df_all = pd.read_csv(csv_path)
    len_col = _pick_length_col(df_all)
    df = df_all[["sentence_id", "tokens", len_col]].copy()
    df["sentence_id"] = df["sentence_id"].astype(str)
    df["tokens"] = df["tokens"].apply(_to_list)
    df[len_col] = df[len_col].apply(_to_list)

    rows = []
    for sid, toks, lens in df[["sentence_id", "tokens", len_col]].itertuples(index=False):
        # Missing / unparseable cells (e.g. NaN) would make len() raise: skip the sentence.
        if not isinstance(toks, (list, tuple)) or not isinstance(lens, (list, tuple)):
            continue
        for wid in range(min(len(toks), len(lens))):
            # Drop 1-based index 1 (first word). If the 2nd word was meant, use wid == 1.
            if EXCLUDE_INDEX_1 and wid == 0:
                continue
            try:
                k = int(lens[wid])
            except (TypeError, ValueError, OverflowError):
                continue
            if k < 0 or (k == 0 and exclude_zero):
                continue
            # Bucket the upper tail into 'length_max'.
            rows.append((sid, wid, str(min(k, length_max)), toks[wid]))

    df_tok = pd.DataFrame(rows, columns=["sentence_id", "word_id", "length_class", "word"])
    if df_tok.empty:
        raise ValueError("No token rows constructed—check that your length column contains integer lists.")
    df_sent = df[["sentence_id", "tokens"]].drop_duplicates("sentence_id")
    return df_sent, df_tok

# =============================== Tokenizer/Model loader (robust for GPT‑2) ===============================
def _load_tok_and_model(baseline: str):
    """
    Robust loader:
    - Prefer fast tokenizers (needed for .word_ids()).
    - For GPT-2: set pad_token to eos_token and right-padding.
    - Fallback to 'openai-community/gpt2' if the plain 'gpt2' entry is unavailable.
    """
    model_id = baseline
    tok = None; mdl = None
    tried = []

    def _try_load(mid: str):
        tok = AutoTokenizer.from_pretrained(mid, use_fast=True, add_prefix_space=True)
        # Ensure right padding (GPT-2 tips)
        if getattr(tok, "padding_side", None) != "right":
            tok.padding_side = "right"
        # For GPT-2 and similar: add pad token if missing
        if tok.pad_token is None and getattr(tok, "eos_token", None) is not None:
            tok.pad_token = tok.eos_token
        mdl = AutoModel.from_pretrained(mid, output_hidden_states=True)
        # Propagate pad id to model if absent
        if getattr(mdl.config, "pad_token_id", None) is None and tok.pad_token_id is not None:
            mdl.config.pad_token_id = tok.pad_token_id
        return tok, mdl

    # try main id
    try:
        tok, mdl = _try_load(model_id)
    except Exception as e1:
        tried.append((model_id, str(e1)))
        # GPT-2 robust fallback namespace
        if model_id.lower() in {"gpt2", "gpt-2"}:
            for alt in ["openai-community/gpt2", "gpt2"]:
                try:
                    tok, mdl = _try_load(alt)
                    model_id = alt
                    break
                except Exception as e2:
                    tried.append((alt, str(e2)))
        if tok is None or mdl is None:
            # surface useful diagnostics
            raise RuntimeError(
                "Failed to load tokenizer/model. Attempts:\n" +
                "\n".join([f" - {mid}: {err}" for mid, err in tried])
            )

    mdl = mdl.eval().to(device)
    if device == "cuda":
        mdl.half()
    return tok, mdl, model_id

# =============================== EMBEDDING (BERT & GPT‑2) ===============================
def embed_subset(df_sent: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray]:
    """
    Embed the words listed in `subset_df` with every hidden layer of `baseline`.

    Args:
      df_sent:       sentence table with 'sentence_id' and 'tokens' (list[str]).
      subset_df:     token table with 'sentence_id' and 'word_id'; its row order
                     defines the output order.
      baseline:      model id passed to `_load_tok_and_model`.
      word_rep_mode: how sub-word pieces become one word vector:
                     'first', 'last' or 'mean'.
      batch_size:    sentences per forward pass.

    Returns:
      reps   (L, N, D) float16, where L = hidden layers + embedding layer and
             N = len(subset_df). Rows for missing words stay all-zero.
      filled (N,) bool mask; False where no sub-word piece mapped to the word.

    The input frames are not modified.
    Raises ValueError for an unknown `word_rep_mode`, before the model is loaded.
    NOTE(review): no truncation is requested, so a sentence longer than the model's
    position limit will make the forward pass fail.
    """
    if word_rep_mode not in ("first", "last", "mean"):
        raise ValueError("WORD_REP_MODE must be {'first','last','mean'} (GPT‑2: 'last' or 'mean').")

    # String ids on local copies; mutating the caller's frames can trigger SettingWithCopy.
    sents = df_sent.assign(sentence_id=df_sent["sentence_id"].astype(str))
    sub_ids = subset_df["sentence_id"].astype(str)

    # sid -> list[(global_idx, word_id)]
    by_sid: Dict[str, List[Tuple[int, int]]] = {}
    for gidx, (sid, wid) in enumerate(zip(sub_ids, subset_df["word_id"])):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))

    sids = list(by_sid.keys())
    df_sel = (sents[sents.sentence_id.isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    tokzr, model, model_id = _load_tok_and_model(baseline)

    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True)
    # Use add_prefix_space when supported (GPT-2-friendly)
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True

    L = _num_hidden_layers(model) + 1   # include embedding layer
    D = _hidden_size(model)
    N = len(subset_df)

    reps = np.zeros((L, N, D), np.float16)
    filled = np.zeros(N, dtype=bool)

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast (same fp16 behaviour on CUDA).
    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=(device == "cuda")):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{model_id} (embed subset)"):
            batch_ids = sids[start:start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # word index -> positions of its sub-word pieces (special/pad tokens map to None)
                pieces: Dict[int, List[int]] = {}
                for tidx, wid in enumerate(enc_be.word_ids(b)):
                    if wid is not None:
                        pieces.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid[sid]:
                    toks = pieces.get(wid)
                    if not toks:
                        continue
                    if word_rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif word_rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    else:  # "mean" (validated above)
                        vec = h[:, b, toks, :].mean(axis=1)
                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    filled[gidx] = True

            # free batch buffers
            del enc_be, enc_t, out, h
            if device == "cuda":
                torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing:
        print(f"⚠ Missing vectors for {missing} of {N} tokens")
    del model
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    return reps, filled

# =============================== BOOTSTRAP CORE ===============================
def _bs_layer_loop(rep_sub: np.ndarray, M: int, n_reps: int, compute_once: Callable[[np.ndarray], float]):
    L, N, D = rep_sub.shape
    rng = np.random.default_rng(RAND_SEED)
    A = np.full((n_reps, L), np.nan, np.float32)
    for r in range(n_reps):
        idx = rng.integers(0, N, size=M)
        for l in range(L):
            X = rep_sub[l, idx].astype(np.float32, copy=False)
            try:
                A[r, l] = float(compute_once(X))
            except Exception:
                A[r, l] = np.nan
    mu = np.nanmean(A, axis=0).astype(np.float32)
    lo = np.nanpercentile(A, 2.5, axis=0).astype(np.float32)
    hi = np.nanpercentile(A, 97.5, axis=0).astype(np.float32)
    return mu, lo, hi

# =============================== SAVE / PLOT ===============================
def save_metric_csv_all_classes(metric: str,
                                class_to_stats: Dict[str, Dict[str, np.ndarray]],
                                layers: np.ndarray,
                                baseline: str,
                                subset_name: str = "raw"):
    """
    Write one tidy CSV row per (class, layer) for `metric`.

    File: CSV_DIR / f"length_{subset_name}_{metric}_{baseline}.csv".

    `class_to_stats[c]` holds:
      - "mean": a per-layer array (required);
      - "lo" / "hi": per-layer CI arrays (optional);
      - "n": the token count (optional).
    Non-finite values are written as NaN; missing CI arrays give NaN CI columns.

    The output directory is created if needed, and the column header is written
    even when `class_to_stats` is empty. Returns the written path.
    """
    columns = ["subset", "model", "feature", "class", "metric", "layer", "mean",
               "ci_low", "ci_high", "n_tokens", "word_rep_mode", "source_csv"]

    def _finite_or_nan(arr, i):
        # CI arrays are optional; anything that is not an ndarray means "no CI".
        if not isinstance(arr, np.ndarray):
            return np.nan
        v = arr[i]
        return float(v) if np.isfinite(v) else np.nan

    source_name = Path(CSV_PATH).name
    rows = []
    for c, stats in class_to_stats.items():
        mu, lo, hi = stats["mean"], stats.get("lo"), stats.get("hi")
        n_tok = int(stats.get("n", 0))
        for l, val in enumerate(mu):
            rows.append({
                "subset": subset_name, "model": baseline, "feature": "length",
                "class": c, "metric": metric, "layer": int(layers[l]),
                "mean": float(val) if np.isfinite(val) else np.nan,
                "ci_low": _finite_or_nan(lo, l),
                "ci_high": _finite_or_nan(hi, l),
                "n_tokens": n_tok, "word_rep_mode": WORD_REP_MODE,
                "source_csv": source_name,
            })

    out = Path(CSV_DIR) / f"length_{subset_name}_{metric}_{baseline}.csv"
    out.parent.mkdir(parents=True, exist_ok=True)
    # Explicit columns keep the header (and column order) even for an empty result.
    pd.DataFrame(rows, columns=columns).to_csv(out, index=False)
    return out

def plot_metric_with_ci(class_to_stats: Dict[str, Dict[str, np.ndarray]],
                        layers: np.ndarray, metric: str, title: str, out_path: Path,
                        palette: Dict[str, Tuple[float, float, float]] | None = None):
    """
    Plot one mean-vs-layer line per length class with a shaded CI band, and save it.

    Drawing rules:
      - classes are drawn in numeric order of their string labels;
      - classes whose mean is None or all-NaN are skipped;
      - the band is drawn only when "lo"/"hi" are arrays and "lo" is not all-NaN;
      - the legend is added only if at least one class was drawn.
    The figure is always closed, even if saving fails.
    """
    fig, ax = plt.subplots(figsize=(10.5, 5.5))
    try:
        n_drawn = 0
        for c in sorted(class_to_stats.keys(), key=lambda s: int(s)):
            stats = class_to_stats[c]
            mu, lo, hi = stats["mean"], stats.get("lo"), stats.get("hi")
            if mu is None or np.all(np.isnan(mu)):
                continue
            color = palette.get(c) if isinstance(palette, dict) else None
            ax.plot(layers, mu, label=c, lw=1.8, color=color)
            n_drawn += 1
            if isinstance(lo, np.ndarray) and isinstance(hi, np.ndarray) and not np.all(np.isnan(lo)):
                ax.fill_between(layers, lo, hi, alpha=0.15, color=color)
        ax.set_xlabel("Layer")
        ax.set_ylabel(LABELS.get(metric, metric.upper()))
        ax.set_title(title)
        if n_drawn:  # avoids matplotlib's "No artists with labels" warning on an empty plot
            ax.legend(ncol=6, fontsize="small",
                      title=f"Length ( {LENGTH_MAX_CLASS} = {LENGTH_MAX_CLASS}+ )", frameon=False)
        fig.tight_layout()
        fig.savefig(out_path, dpi=220)
    finally:
        plt.close(fig)

# =============================== DRIVER ===============================
def sample_raw(df_tok: pd.DataFrame, per_class_cap: int | None = RAW_MAX_PER_CLASS) -> pd.DataFrame:
    """
    Randomly sample at most `per_class_cap` tokens from each length class.

    Sampling is without replacement and seeded by RAND_SEED.
    - `per_class_cap=None` means no cap: every token is kept, matching
      `sample_for_plot`. Previously None crashed inside min().
    - Classes keep their first-appearance order; rows within a class are shuffled.
    - An empty input returns an empty frame instead of failing in pd.concat.
    """
    if df_tok.empty:
        return df_tok.iloc[0:0].reset_index(drop=True)
    picks = []
    for _, sub in df_tok.groupby("length_class", sort=False):
        n = len(sub) if per_class_cap is None else min(len(sub), per_class_cap)
        picks.append(sub.sample(n, random_state=RAND_SEED, replace=False))
    return pd.concat(picks, ignore_index=True)

def make_length_palette(classes: List[str]) -> Dict[str, Tuple[float, float, float]]:
    """Map each numeric-string class label to a colour from seaborn's "rocket" palette.

    Labels are ordered numerically, so larger lengths get later palette entries.
    Keys are re-stringified from the parsed integers.
    """
    ordered = sorted(int(label) for label in classes)
    colours = sns.color_palette("rocket", len(ordered))
    return dict(zip(map(str, ordered), colours))

def run_length_from_col_pipeline():
    """
    End-to-end length analysis for BASELINE.

    1) Load token rows with length classes from CSV_PATH.
    2) Cap tokens per class via `sample_raw(RAW_MAX_PER_CLASS)`.
    3) Embed every sampled token once, at all layers.
    4) For each metric in ALL_METRICS:
       - bootstrap it per class and layer;
       - write a CSV (`save_metric_csv_all_classes`);
       - write a PNG (`plot_metric_with_ci`).

    Metrics whose estimator is unavailable are skipped, as are classes with fewer
    than 3 embedded tokens.
    """
    # 1) Load token lists + length classes from existing column
    df_sent, len_df = load_length_from_column(
        CSV_PATH, length_max=LENGTH_MAX_CLASS, exclude_zero=EXCLUDE_ZERO_LENGTH
    )
    classes = sorted(len_df.length_class.unique(), key=lambda s: int(s))
    palette = make_length_palette(classes)
    print(f"✓ corpus ready — {len(len_df):,} tokens across length classes {classes}")

    # 2) Optional per-class cap (RAW_MAX_PER_CLASS)
    raw_df = sample_raw(len_df, RAW_MAX_PER_CLASS)
    print("Sample sizes per length (raw cap):")
    counts = raw_df.length_class.value_counts().sort_index()
    print(counts.to_dict())

    # 3) Embed once
    reps, filled = embed_subset(df_sent, raw_df, BASELINE, WORD_REP_MODE, BATCH_SIZE)
    raw_df = raw_df.reset_index(drop=True).loc[filled].reset_index(drop=True)
    if not filled.all():
        # embed_subset returns one (zero-initialised) slot per requested token.
        # Keep only the filled ones so reps' token axis stays aligned with
        # raw_df / cls_arr; otherwise reps[:, idx] below would pick the wrong tokens.
        reps = reps[:, filled]
    cls_arr = raw_df.length_class.values
    L = reps.shape[0]; layers = np.arange(L)
    print(f"✓ embedded {len(raw_df):,} tokens  • layers={L}")

    # 4) Metric loop over ALL_METRICS
    for metric in ALL_METRICS:
        print(f"\n→ Computing metric: {metric} …")
        compute_once = FAST_ONCE.get(metric) or HEAVY_ONCE.get(metric)
        if compute_once is None:
            print(f"  (skipping {metric}: estimator unavailable)")
            continue

        # Fast and heavy metrics use separate bootstrap budgets.
        n_bs = N_BOOTSTRAP_FAST if metric in FAST_ONCE else N_BOOTSTRAP_HEAVY
        Mcap = FAST_BS_MAX_SAMP_PER_CLASS if metric in FAST_ONCE else HEAVY_BS_MAX_SAMP_PER_CLASS

        class_results: Dict[str, Dict[str, np.ndarray]] = {}
        for c in classes:
            idx = np.where(cls_arr == c)[0]
            if idx.size < 3:
                continue
            sub = reps[:, idx]  # (L, n_c, D)
            Nc = sub.shape[1]
            M = min(Mcap, Nc)
            mu, lo, hi = _bs_layer_loop(sub, M, n_bs, compute_once)
            class_results[c] = {"mean": mu, "lo": lo, "hi": hi, "n": int(Nc)}

        # Save + plot immediately
        save_metric_csv_all_classes(metric, class_results, layers, BASELINE, subset_name="raw")
        plot_metric_with_ci(class_results, layers, metric,
                            title=f"{LABELS.get(metric, metric.upper())} • {BASELINE}",
                            out_path=PLOT_DIR / f"length_raw_{metric}_{BASELINE}.png",
                            palette=palette)
        print(f"  ✓ saved: CSV= {CSV_DIR}/length_raw_{metric}_{BASELINE}.csv  "
              f"plot= {PLOT_DIR}/length_raw_{metric}_{BASELINE}.png")

        del class_results; gc.collect()
        if device == "cuda": torch.cuda.empty_cache()

    del reps; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    print("\n✓ done (incremental outputs produced per metric).")

if __name__ == "__main__":
    run_length_from_col_pipeline()


✓ corpus ready — 184,849 tokens across length classes ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
Sample sizes per length (raw cap):
{'1': 27854, '10': 7320, '2': 27939, '3': 33973, '4': 31213, '5': 18028, '6': 13729, '7': 11768, '8': 7673, '9': 5352}


openai-community/gpt2 (embed subset): 100%|█| 10067/10067 [01:34<00:00, 107.09it


✓ embedded 184,849 tokens  • layers=13

→ Computing metric: gride …
  ✓ saved: CSV= tables_LENGTH_no_index/length_bootstrap/length_raw_gride_gpt2.csv  plot= results_LENGTH_no_index/length_raw_gride_gpt2.png

✓ done (incremental outputs produced per metric).


# Visualize

In [None]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

from sklearn.decomposition import PCA
import plotly.graph_objects as go
from plotly import colors as pc

# ---- Data / model --------------------------------------------------------------
# The CSV must contain: sentence_id (str), tokens (list[str]) and a per-sentence
# list[int] of word lengths in column LENGTH_COL.
CSV_PATH = "en_ewt-ud-train_sentences.csv"
LENGTH_COL = "length"
BASELINE = "bert-base-uncased"   # e.g. "gpt2" or "bert-base-uncased"
WORD_REP_MODE = "first"          # BERT: 'first' (or 'mean'); GPT-2: 'last' (or 'mean')

# ---- Plot subsampling ------------------------------------------------------------
# Max points per class fed to PCA/plot; None keeps every token.
# NOTE(review): with None, every layer's traces hold all tokens, so the HTML can
# get very large and slow in the browser. Set an int cap for interactive use.
PLOT_MAX_PER_CLASS = None

# ---- Output ----------------------------------------------------------------------
OUT_DIR = Path("pca3d_by_classes")
OUT_DIR.mkdir(parents=True, exist_ok=True)
HTML_OUT = OUT_DIR / f"{BASELINE.replace('/','_')}_pca3d_by_length_classes.html"

# ---- Reproducibility + device ----------------------------------------------------
RAND_SEED = 42
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED)
np.random.seed(RAND_SEED)
torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.backends.cudnn.benchmark = True

# ---- Length classes ----------------------------------------------------------------
# Classes 1..LENGTH_MAX_CLASS; the top class means "LENGTH_MAX_CLASS or longer".
LENGTH_MAX_CLASS = 10
EXCLUDE_ZERO_LENGTH = True   # drop zero-length tokens (negative lengths are always dropped)

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)
    if n is None: raise ValueError("Cannot determine num_hidden_layers from model.config")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)
    if d is None: raise ValueError("Cannot determine hidden_size from model.config")
    return int(d)

def _load_tok_and_model(model_id: str):
    """
    Robust loader (fast tokenizer for word_ids; pad-right; GPT-2 pad_token=eos when missing).
    Tries 'openai-community/gpt2' if 'gpt2' fails.
    """
    cands = [model_id]
    if model_id.lower() in {"gpt2", "gpt-2"}:
        cands += ["openai-community/gpt2"]

    last_err = None
    for mid in cands:
        try:
            tok = AutoTokenizer.from_pretrained(mid, use_fast=True, add_prefix_space=True)
            if getattr(tok, "padding_side", None) != "right":
                tok.padding_side = "right"
            if tok.pad_token is None and getattr(tok, "eos_token", None) is not None:
                tok.pad_token = tok.eos_token
            mdl = AutoModel.from_pretrained(mid, output_hidden_states=True)
            if getattr(mdl.config, "pad_token_id", None) is None and tok.pad_token_id is not None:
                mdl.config.pad_token_id = tok.pad_token_id
            if device == "cuda": mdl = mdl.half()
            return tok, mdl.eval().to(device), mid
        except Exception as e:
            last_err = e
            continue
    raise RuntimeError(f"Failed to load model/tokenizer for {cands}: {last_err}")

def _pick_length_col(df: pd.DataFrame) -> str:
    for cand in [LENGTH_COL, "token_length", "lengths", "len", "LEN", "Length"]:
        if cand in df.columns:
            return cand
    raise ValueError("No length column found; set LENGTH_COL appropriately.")

def load_length_from_column(csv_path: str,
                            length_max: int = LENGTH_MAX_CLASS,
                            exclude_zero: bool = EXCLUDE_ZERO_LENGTH,
                            exclude_first_word: bool = False):
    """
    Build token-level rows labelled with a word-length class.

    CSV must have: sentence_id (str), tokens (list[str]), length (list[int]).
    List cells may be stored as their string repr (parsed by `_to_list`); the
    length column is resolved by `_pick_length_col`.

    Each emitted row gets a string 'length_class': the length clipped to
    `length_max` (so by default "10" means 10+).

    Filtering:
      - negative and non-integer lengths are dropped;
      - zero lengths are dropped when `exclude_zero` is True, otherwise kept as
        class "0" (same as the analysis cell's loader);
      - sentences whose tokens/length cells are not lists (e.g. empty cells) are
        skipped;
      - `exclude_first_word=True` drops word_id 0 of every sentence, like the
        analysis cell's EXCLUDE_INDEX_1 switch (default False = keep it).

    Returns (df_sent, df_tok); raises ValueError if no token rows were built.
    """
    df_all = pd.read_csv(csv_path)
    len_col = _pick_length_col(df_all)
    df = df_all[["sentence_id", "tokens", len_col]].copy()
    df["sentence_id"] = df["sentence_id"].astype(str)
    df["tokens"] = df["tokens"].apply(_to_list)
    df[len_col] = df[len_col].apply(_to_list)

    rows = []
    for sid, toks, lens in df[["sentence_id", "tokens", len_col]].itertuples(index=False):
        # Missing / unparseable cells (e.g. NaN) would make len() raise: skip the sentence.
        if not isinstance(toks, (list, tuple)) or not isinstance(lens, (list, tuple)):
            continue
        for wid in range(min(len(toks), len(lens))):
            if exclude_first_word and wid == 0:
                continue
            try:
                k = int(lens[wid])
            except (TypeError, ValueError, OverflowError):
                continue
            if k < 0 or (k == 0 and exclude_zero):
                continue
            # Clip only the upper tail. The previous max(k, 1) silently relabelled
            # kept zero-length tokens as class "1".
            rows.append((sid, wid, str(min(k, length_max)), toks[wid]))
    df_tok = pd.DataFrame(rows, columns=["sentence_id", "word_id", "length_class", "word"])
    df_sent = df[["sentence_id", "tokens"]].drop_duplicates("sentence_id")
    if df_tok.empty:
        raise ValueError("No token rows constructed—check that your length column contains integer lists.")
    return df_sent, df_tok

def sample_for_plot(df_tok: pd.DataFrame, per_class_cap: int | None = PLOT_MAX_PER_CLASS) -> pd.DataFrame:
    """Cap tokens per class for visualization to keep Plotly responsive.

    - `per_class_cap=None` keeps every row.
    - Otherwise at most `per_class_cap` rows are drawn, without replacement, from
      each length class (classes in sorted order).
    - Either way the result is shuffled with RAND_SEED and given a fresh RangeIndex.
    """
    if per_class_cap is None:
        pooled = df_tok
    else:
        pooled = pd.concat(
            [grp.sample(min(len(grp), per_class_cap), random_state=RAND_SEED, replace=False)
             for _, grp in df_tok.groupby("length_class", sort=True)],
            ignore_index=True,
        )
    return pooled.sample(frac=1.0, random_state=RAND_SEED).reset_index(drop=True)

def embed_subset(df_sent: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = 4) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """
    Embed the words listed in `subset_df` with every hidden layer of `baseline`.

    Returns:
      reps   (L, n_filled, D) float16: only tokens that received a vector
             (L = hidden layers + embedding layer)
      filled (N,) bool mask over the rows of `subset_df` (N = len(subset_df))
      words  list[str] of length n_filled (hover text), aligned with `reps`

    Sentences are truncated to the tokenizer's max length. Words cut off that way,
    or that yield no sub-word pieces, are reported as missing and dropped.
    The input frames are not modified.

    Raises:
      ValueError   for an unknown `word_rep_mode` (before the model is loaded).
      RuntimeError if the tokenizer provides no word_ids().
    """
    if word_rep_mode not in ("first", "last", "mean"):
        raise ValueError("WORD_REP_MODE must be {'first','last','mean'}")

    # String ids on local copies; mutating the caller's frames can trigger SettingWithCopy.
    sents = df_sent.assign(sentence_id=df_sent["sentence_id"].astype(str))
    sub_ids = subset_df["sentence_id"].astype(str)

    # sid -> list[(global_idx, word_id)]
    by_sid: Dict[str, List[Tuple[int, int]]] = {}
    for gidx, (sid, wid) in enumerate(zip(sub_ids, subset_df["word_id"])):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))

    sids = list(by_sid.keys())
    df_sel = (sents[sents.sentence_id.isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    tokzr, model, model_tag = _load_tok_and_model(baseline)

    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True

    L = _num_hidden_layers(model) + 1   # + embedding layer
    D = _hidden_size(model)
    N = len(subset_df)

    reps = np.zeros((L, N, D), np.float16)
    filled = np.zeros(N, dtype=bool)
    words = [""] * N

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast (same fp16 behaviour on CUDA).
    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=(device == "cuda")):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{model_tag} (embed subset)"):
            batch_ids = sids[start:start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                wids = enc_be.word_ids(b)
                if wids is None:
                    raise RuntimeError("Fast tokenizer required (word_ids() unavailable).")
                # word index -> positions of its sub-word pieces (special/pad tokens map to None)
                pieces: Dict[int, List[int]] = {}
                for tidx, wid in enumerate(wids):
                    if wid is not None:
                        pieces.setdefault(int(wid), []).append(int(tidx))

                toks_for_sent = df_sel.loc[sid, "tokens"]
                for gidx, wid in by_sid[sid]:
                    toks = pieces.get(wid)
                    if not toks:
                        continue
                    if word_rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif word_rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    else:  # "mean" (validated above)
                        vec = h[:, b, toks, :].mean(axis=1)
                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    words[gidx] = str(toks_for_sent[wid])
                    filled[gidx] = True

            del enc_be, enc_t, out, h
            if device == "cuda":
                torch.cuda.empty_cache()

    if (~filled).any():
        missing = int((~filled).sum())
        print(f"⚠ Missing vectors for {missing} tokens (skipped in PCA).")
        reps = reps[:, filled]
        words = [w for w, f in zip(words, filled) if f]
    del model
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    return reps, filled, words

def _make_class_palette(classes: List[str]) -> Dict[str, str]:
    """Return a stable mapping class -> color (hex)."""
    base = pc.qualitative.Plotly + pc.qualitative.Set2 + pc.qualitative.Set3 + pc.qualitative.Alphabet
    colors = (base * ((len(classes) // len(base)) + 1))[:len(classes)]
    def _to_hex(c):
        if isinstance(c, str): return c
        r, g, b = c
        return f"rgb({int(r*255)},{int(g*255)},{int(b*255)})"
    return {cls: _to_hex(colors[i]) for i, cls in enumerate(classes)}

# =============================== PCA + PLOTLY (by class, per layer) ===============================
def pca3d_per_layer_by_class(reps: np.ndarray,
                             words: List[str],
                             class_arr: np.ndarray,
                             classes: List[str],
                             model_tag: str):
    """
    Interactive 3-D PCA scatter of token embeddings, with one slider step per layer.

    reps:      (L, N, D)
    words:     list[str] length N (hover)
    class_arr: shape (N,) with class labels as strings (e.g., "1","2",...,"10")
    classes:   sorted list of unique class labels (strings). Each layer gets one
               trace per class, which is empty if the class has no tokens.

    PCA is fitted independently per layer, so axes and signs are not comparable
    across layers. Only layer 0's traces are visible at first; the slider toggles
    trace visibility. The figure is shown and written to HTML_OUT.
    """
    L = reps.shape[0]
    palette = _make_class_palette(classes)

    # Loop-invariant per class: token indices and hover labels.
    cls_to_idx = {c: np.where(class_arr == c)[0] for c in classes}
    cls_to_text = {c: [f"{words[i]} | len={c}" for i in idx] for c, idx in cls_to_idx.items()}

    traces = []
    for l in range(L):
        X = reps[l].astype(np.float32, copy=False)
        X = X - X.mean(0, keepdims=True)
        Y = PCA(n_components=3, random_state=RAND_SEED).fit_transform(X)  # (N, 3)
        for c in classes:
            idx = cls_to_idx[c]
            traces.append(go.Scatter3d(
                x=Y[idx, 0], y=Y[idx, 1], z=Y[idx, 2],
                mode="markers",
                marker=dict(size=2, opacity=0.75, color=palette[c]),
                text=cls_to_text[c],
                hovertemplate=(
                    "<b>%{text}</b><br>"
                    "PC1=%{x:.3f}<br>PC2=%{y:.3f}<br>PC3=%{z:.3f}"
                    f"<extra>Layer {l} • class {c}</extra>"
                ),
                name=f"{c}",
                legendgroup=c,          # one legend entry toggles the class on every layer
                showlegend=(l == 0),
                # Only layer 0 starts visible; without this every layer renders overlaid
                # until the slider is first moved.
                visible=(l == 0),
            ))

    # Traces are laid out layer-major: layer l owns traces [l*C, (l+1)*C).
    steps = []
    traces_per_layer = len(classes)
    total_traces = L * traces_per_layer
    for l in range(L):
        vis = [False] * total_traces
        start = l * traces_per_layer
        for k in range(traces_per_layer):
            vis[start + k] = True
        steps.append(dict(
            method="update",
            args=[{"visible": vis},
                  {"title": f"{model_tag} • PCA 3D by length class • layer {l} (drag to rotate)"}],
            label=str(l),
        ))

    sliders = [dict(
        active=0,
        steps=steps,
        currentvalue={"prefix": "Layer: "},
        pad={"t": 10}
    )]

    layout = go.Layout(
        title=f"{model_tag} • PCA 3D by length class • layer 0 (drag to rotate)",
        scene=dict(
            xaxis_title="PC1", yaxis_title="PC2", zaxis_title="PC3",
            aspectmode="data",
        ),
        margin=dict(l=0, r=0, b=0, t=40),
        sliders=sliders,
        showlegend=True
    )

    fig = go.Figure(data=traces, layout=layout)
    fig.show()
    fig.write_html(str(HTML_OUT), include_plotlyjs="cdn")
    print("✓ Saved interactive HTML to:", HTML_OUT)




def run_pca3d_by_length_classes():
    """
    Build the interactive per-layer 3-D PCA view of BASELINE embeddings, coloured by length class.

    Steps:
      - load token rows from CSV_PATH;
      - subsample them via `sample_for_plot(PLOT_MAX_PER_CLASS)`;
      - embed them (batch_size=4);
      - hand the result to `pca3d_per_layer_by_class`, which shows the figure and
        writes HTML_OUT.

    `classes` comes from the full corpus, so a class absent from the plot sample
    still gets (empty) traces.
    """
    # 1) Load tokens + length classes
    df_sent, df_tok = load_length_from_column(CSV_PATH, length_max=LENGTH_MAX_CLASS, exclude_zero=EXCLUDE_ZERO_LENGTH)
    classes = sorted(df_tok.length_class.unique(), key=lambda s: int(s))
    print(f"✓ corpus ready — {len(df_tok):,} tokens across length classes {classes}")

    # 2) Per-class subsample for the plot (None = keep everything), shuffled
    plot_df = sample_for_plot(df_tok, PLOT_MAX_PER_CLASS)
    print("Plot sample sizes per length (cap):")
    print(plot_df.length_class.value_counts().sort_index().to_dict())

    # 3) Embed; reps/words come back already restricted to filled tokens, so plot_df
    #    is filtered by the same mask to keep class_arr aligned with them.
    reps, filled, words = embed_subset(df_sent, plot_df, BASELINE, WORD_REP_MODE, batch_size=4)
    plot_df = plot_df.reset_index(drop=True).loc[filled].reset_index(drop=True)
    class_arr = plot_df.length_class.values
    L = reps.shape[0]
    print(f"✓ embedded {len(plot_df):,} tokens  • layers={L}")

    # 4) PCA 3D + Plotly
    pca3d_per_layer_by_class(reps, words, class_arr, classes, model_tag=BASELINE)

    # Cleanup: release the large embedding array (and cached GPU memory, if any)
    del reps; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()

# In a notebook __name__ is "__main__", so running this cell runs the visualization.
if __name__ == "__main__":
    run_pca3d_by_length_classes()
