In [2]:
from __future__ import annotations
import os, ast, random, inspect
from pathlib import Path
from typing import Dict, List
import numpy as np, pandas as pd, torch
import torch.utils.data as torchdata
from tqdm.auto import tqdm
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt, seaborn as sns
from transformers import AutoTokenizer, AutoModel, BertModel, AutoConfig
from IsoScore import IsoScore
from dadapy import Data
from skdim.id import MLE, MOM, TLE, CorrInt, FisherS, lPCA

# BERT POS

In [1]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple, Callable

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel

# --- Optional dependencies ------------------------------------------------
# Each backend is probed independently so that a missing package only
# disables the metrics that need it instead of aborting the notebook.

# DADApy: used by the TwoNN / GRIDE helpers below.
HAS_DADAPY = False
try:
    from dadapy import Data
    HAS_DADAPY = True
except Exception:
    pass

# scikit-dimension: used by the skdim-based helpers below.
HAS_SKDIM = False
try:
    from skdim.id import (
        MOM, TLE, CorrInt, FisherS, lPCA,
        MLE, DANCo, ESS, MiND_ML, MADA, KNN
    )
    HAS_SKDIM = True
except Exception:
    pass

# IsoScore: the first cell of this notebook imports the package as `IsoScore`
# (capitalised). Importing only the lowercase spelling fails on case-sensitive
# filesystems and would silently swap the metric for the proxy below, so both
# spellings are tried before falling back.
_HAS_ISOSCORE = False
try:
    from IsoScore import IsoScore
    _HAS_ISOSCORE = True
except Exception:
    try:
        from isoscore import IsoScore
        _HAS_ISOSCORE = True
    except Exception:
        pass

if not _HAS_ISOSCORE:
    import warnings

    class _IsoScoreFallback:
        """Stand-in exposing the same `IsoScore(X)` entry point as the package."""

        @staticmethod
        def IsoScore(X: np.ndarray) -> float:
            """Return the mean/peak covariance-eigenvalue ratio of X, clipped to [0, 1]."""
            C = np.cov(X.T, ddof=0)
            ev = np.linalg.eigvalsh(C)  # ascending, so ev[-1] is the largest
            if ev.mean() <= 0 or ev[-1] <= 0:
                return 0.0
            # mean/peak eigenvalue ratio in [0,1]; higher ≈ more isotropic
            return float(np.clip(ev.mean() / ev[-1], 0.0, 1.0))

    warnings.warn(
        "IsoScore package not importable; 'iso' will use an eigenvalue-ratio "
        "proxy, not the real IsoScore.",
        RuntimeWarning,
    )
    IsoScore = _IsoScoreFallback()

# =============================== CONFIG ===============================
CSV_PATH   = "en_ewt-ud-train_sentences.csv"   # read with columns sentence_id / tokens / pos
BASELINE   = "bert-base-uncased"               # model id passed to from_pretrained
WORD_REP_MODE = "first"                        # which sub-token vector stands for a word
EXCLUDE_POS = {"X", "SYM", "PART", "INTJ"}     # POS tags dropped while building the word table
RAW_MAX_PER_POS = int(1e12)                    # per-POS sampling cap (this large => effectively none)

# Numerical guards / probe settings read by the metric helpers below.
# NOTE(review): these names were referenced by _sf_once, _erank_once, _pr_once,
# _stable_rank_once and _pfI_once but never defined in this notebook, so those
# helpers raised NameError (turned into NaN by the bootstrap loop). The values
# here are assumed defaults — confirm against the intended experiment setup.
EPS      = 1e-12    # additive guard for logs / divisions on eigenvalue spectra
PFI_DIRS = 256      # number of random unit directions drawn by _pfI_once
PFI_Q_LO = 5.0      # lower percentile of log Z used by _pfI_once
PFI_Q_HI = 95.0     # upper percentile of log Z used by _pfI_once

# Bootstrap replicates for the fast and the heavy metric families.
N_BOOTSTRAP_FAST   = 50
N_BOOTSTRAP_HEAVY  = 200

# Per-POS cap on the bootstrap sample size M.
FAST_BS_MAX_SAMP_PER_POS  = int(1e12)
HEAVY_BS_MAX_SAMP_PER_POS = 5000
# GRIDE multi-scale max neighbor rank
DADAPY_GRID_RANGE_MAX = 64             # 32–128 is typical
RAND_SEED=42
PLOT_DIR     = Path("results_POS"); PLOT_DIR.mkdir(exist_ok=True, parents=True)
CSV_DIR      = Path("tables_POS") / "pos_bootstrap"; CSV_DIR.mkdir(exist_ok=True, parents=True)

# Throughput: start higher than 1 unless GPU is tiny
BATCH_SIZE = 1                         # try 8 → 16 → 32; back off if OOM

# Reproducibility & device
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda": torch.backends.cudnn.benchmark = True

# Seaborn style
sns.set_style("darkgrid")

plt.rcParams["figure.dpi"] = 120


# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _center(X: np.ndarray) -> np.ndarray:
    return X - X.mean(0, keepdims=True)

def _eigvals_from_X(X: np.ndarray) -> np.ndarray:
    """Squared singular values of the centered float32 view of X, in descending order.

    These equal the covariance eigenvalues up to a constant factor. An empty
    float64 array is returned if the decomposition raises.
    """
    centered = _center(X.astype(np.float32, copy=False))
    try:
        singular_values = np.linalg.svd(centered, full_matrices=False)[1]
        spectrum = np.square(singular_values).astype(np.float64)
        # Sort ascending, then flip so the largest value comes first.
        return np.sort(spectrum)[::-1]
    except Exception:
        return np.array([], dtype=np.float64)

def _jitter_unique(X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Add tiny noise if there are duplicate rows (helps NN-based estimators)."""
    try:
        if np.unique(X, axis=0).shape[0] < X.shape[0]:
            X = X + np.random.normal(scale=eps, size=X.shape).astype(X.dtype)
    except Exception:
        pass
    return X

# ========= Per-subsample single-value compute functions (used inside bootstrap) =========
# --- Isotropy (fast) ---
def _iso_once(X: np.ndarray) -> float:
    """Evaluate IsoScore.IsoScore on X and return it as a Python float."""
    score = IsoScore.IsoScore(X)
    return float(score)

def _spect_once(X: np.ndarray) -> float:
    ev = np.linalg.eigvalsh(np.cov(X.T, ddof=0))
    return float(ev[-1] / (ev.mean() + 1e-9))

def _rand_once(X: np.ndarray, K: int = 2000) -> float:
    n = X.shape[0]
    if n < 2: return np.nan
    rng = np.random.default_rng()
    K_eff = min(K, (n*(n-1))//2)
    i = rng.integers(0, n, size=K_eff)
    j = rng.integers(0, n, size=K_eff)
    same = i == j
    if same.any():
        j[same] = rng.integers(0, n, size=same.sum())
    A, B = X[i], X[j]
    num = np.sum(A*B, axis=1)
    den = (np.linalg.norm(A, axis=1)*np.linalg.norm(B, axis=1) + 1e-9)
    return float(np.mean(np.abs(num/den)))

def _sf_once(X: np.ndarray) -> float:
    """Spectral flatness: geometric mean over arithmetic mean of the spectrum.

    The spectrum comes from _eigvals_from_X; the module-level EPS is added
    inside the log and to the arithmetic mean. Returns NaN for an empty spectrum.
    """
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    mean_log = np.mean(np.log(spectrum + EPS))
    geometric = np.exp(mean_log)
    arithmetic = float(spectrum.mean() + EPS)
    return float(geometric / arithmetic)

def _pfI_once(X: np.ndarray) -> float:
    """Ratio exp(lo - hi) of two percentiles of log Z over random unit directions.

    For each of PFI_DIRS random unit vectors u, log Z(u) = logsumexp_i(u · x_i).
    The result is exp of the PFI_Q_LO percentile minus the PFI_Q_HI percentile
    of those values. Returns NaN when X has fewer than two rows.
    """
    n_rows, dim = X.shape
    if n_rows < 2:
        return np.nan
    generator = np.random.default_rng()
    directions = generator.standard_normal((PFI_DIRS, dim)).astype(np.float32)
    lengths = np.linalg.norm(directions, axis=1, keepdims=True)
    directions /= lengths + 1e-9
    scores = directions @ X.T
    # Numerically stable log-sum-exp: shift by the row maximum.
    row_max = np.max(scores, axis=1, keepdims=True)
    shifted_sum = np.sum(np.exp(scores - row_max), axis=1, keepdims=True)
    log_z = (row_max + np.log(shifted_sum)).ravel()
    low = np.percentile(log_z, PFI_Q_LO)
    high = np.percentile(log_z, PFI_Q_HI)
    return float(np.exp(low - high))  # ≈ min Z / max Z (robust)

def _vmf_kappa_once(X: np.ndarray) -> float:
    if X.shape[0] < 2: return np.nan
    Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-9)
    R = np.linalg.norm(Xn.mean(axis=0))
    d = Xn.shape[1]
    if R < 1e-9: return 0.0
    # standard closed-form approximation
    return float(max(R * (d - R**2) / (1.0 - R**2 + 1e-9), 0.0))

# --- Linear ID (fast) ---
def _pcaXX_once(X: np.ndarray, var_ratio: float) -> float:
    """Number of leading spectrum entries whose cumulative sum reaches `var_ratio` of the total.

    Returns NaN when _eigvals_from_X yields an empty spectrum.
    """
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    cumulative = np.cumsum(spectrum)
    target = cumulative[-1] * var_ratio
    # searchsorted gives a 0-based position; +1 turns it into a component count.
    n_components = np.searchsorted(cumulative, target) + 1
    return float(n_components)

def _pca95_once(X: np.ndarray) -> float:
    """_pcaXX_once with the variance threshold fixed at 0.95."""
    return _pcaXX_once(X, var_ratio=0.95)

def _pca99_once(X: np.ndarray) -> float:
    """_pcaXX_once with the variance threshold fixed at 0.99."""
    return _pcaXX_once(X, var_ratio=0.99)

def _erank_once(X: np.ndarray) -> float:
    """Effective rank: exp of the Shannon entropy of the normalised spectrum.

    Uses the module-level EPS in the normalisation and inside the log.
    Returns NaN for an empty spectrum.
    """
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    weights = spectrum / (spectrum.sum() + EPS)
    entropy = -(weights * np.log(weights + EPS)).sum()
    return float(np.exp(entropy))

def _pr_once(X: np.ndarray) -> float:
    """Participation ratio (sum λ)^2 / (sum λ^2 + EPS) of the spectrum; NaN if empty."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    total = spectrum.sum()
    total_of_squares = (spectrum**2).sum()
    return float((total**2) / (total_of_squares + EPS))

def _stable_rank_once(X: np.ndarray) -> float:
    """Sum of the spectrum divided by its maximum (plus EPS); NaN if empty."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    peak = spectrum.max()
    return float(spectrum.sum() / (peak + EPS))

# --- Non-linear (heavy) ---
def _dadapy_twonn_once(X: np.ndarray) -> float:
    """First value returned by DADApy's compute_id_2NN on jittered X; NaN without DADApy."""
    if HAS_DADAPY:
        dataset = Data(coordinates=_jitter_unique(X))
        # Only the first of the three returned values is used here.
        estimate, _unused_a, _unused_b = dataset.compute_id_2NN()
        return float(estimate)
    return np.nan

def _dadapy_gride_once(X: np.ndarray) -> float:
    """Last entry of DADApy's GRIDE id-scaling sequence on jittered X; NaN without DADApy.

    Distances and the scaling range both use DADAPY_GRID_RANGE_MAX.
    """
    if HAS_DADAPY:
        dataset = Data(coordinates=_jitter_unique(X))
        dataset.compute_distances(maxk=DADAPY_GRID_RANGE_MAX)
        scaling_ids, _unused_a, _unused_b = dataset.return_id_scaling_gride(
            range_max=DADAPY_GRID_RANGE_MAX
        )
        return float(scaling_ids[-1])
    return np.nan

def _skdim_factory(name: str):
    """Return a zero-argument builder for a fresh skdim estimator, or None.

    Args:
        name: metric key, e.g. "mom", "tle", "lpca", "lpca95", "lpca99".

    Returns:
        A callable that constructs a new estimator on each call, or None when
        scikit-dimension is unavailable or `name` is not a known key.
    """
    if not HAS_SKDIM: return None
    mapping = {
        "mom": MOM, "tle": TLE, "corrint": CorrInt, "fishers": FisherS,
        # "lpca95" was registered in HEAVY_ONCE / LABELS / PLOT_ORDER but missing
        # here, so it was always reported as unavailable.
        "lpca": lPCA, "lpca95": lPCA, "lpca99": lPCA,
        "mle": MLE, "danco": DANCo,  "mind_ml": MiND_ML,
        "mada": MADA, "knn": KNN,
    }
    # Constructor arguments for the lPCA variants; every other estimator uses
    # its defaults.
    ctor_kwargs = {
        "lpca":   {"ver": "FO"},                          # FO variant
        "lpca95": {"ver": "ratio", "alphaRatio": 0.95},   # ratio (0.95) variant
        "lpca99": {"ver": "ratio", "alphaRatio": 0.99},   # ratio (0.99) variant
    }
    cls = mapping.get(name)
    if cls is None: return None
    kwargs = ctor_kwargs.get(name, {})

    def _builder():
        return cls(**kwargs)
    return _builder

def _skdim_once_builder(name: str) -> Callable[[np.ndarray], float] | None:
    """Wrap the skdim estimator for `name` as a function X -> float, or return None.

    The returned function builds a new estimator, fits it on jittered X and
    reads its `dimension_` attribute (NaN if the attribute is absent).
    """
    make_estimator = _skdim_factory(name)
    if make_estimator is None:
        return None

    def _once(X: np.ndarray) -> float:
        estimator = make_estimator()
        points = _jitter_unique(X)
        estimator.fit(points)
        return float(getattr(estimator, "dimension_", np.nan))

    return _once


# =============================== DATA ===============================
def load_word_df(csv_path: str, exclude_pos: set[str] = EXCLUDE_POS):
    """Read the sentence CSV and expand it into one row per kept word.

    Returns:
        (df, word_df): `df` is the sentence table with `tokens` / `pos` passed
        through _to_list and `sentence_id` cast to str; `word_df` has columns
        sentence_id, word_id, pos, word and omits words whose POS is in
        `exclude_pos`.
    """
    df = pd.read_csv(csv_path, usecols=["sentence_id","tokens","pos"])
    df["sentence_id"] = df["sentence_id"].astype(str)  # keep string IDs
    df["tokens"] = df["tokens"].apply(_to_list)
    df["pos"] = df["pos"].apply(_to_list)

    # word_id is the position of the word inside its sentence.
    rows = [
        (sid, wid, tag, tok)
        for sid, toks, tags in zip(df["sentence_id"], df["tokens"], df["pos"])
        for wid, (tok, tag) in enumerate(zip(toks, tags))
        if tag not in exclude_pos
    ]
    word_df = pd.DataFrame(rows, columns=["sentence_id","word_id","pos","word"])
    return df, word_df

def sample_raw(word_df: pd.DataFrame, per_pos_cap: int = RAW_MAX_PER_POS) -> pd.DataFrame:
    """Sample up to `per_pos_cap` rows per POS without replacement (seeded), with no frequency matching."""
    capped_groups = [
        group.sample(min(len(group), per_pos_cap), random_state=RAND_SEED, replace=False)
        for _, group in word_df.groupby("pos", sort=False)
    ]
    return pd.concat(capped_groups, ignore_index=True)


def make_class_palette(classes: List[str]) -> Dict[str, Tuple[float, float, float]]:
    """
    Map each class to a colour, assigned in sorted class order so the mapping
    is stable across calls.

    Colours come from the tab20 / tab20b / tab20c seaborn palettes (20 each);
    when there are more classes than pooled colours, a "husl" palette with one
    colour per class is used instead.
    """
    pool: List[Tuple[float, float, float]] = []
    for palette_name in ("tab20", "tab20b", "tab20c"):
        try:
            pool.extend(sns.color_palette(palette_name, 20))
        except Exception:
            pass

    # Not enough pooled colours: switch to evenly spaced hues.
    if len(pool) < len(classes):
        pool = list(sns.color_palette("husl", len(classes)))

    assignment: Dict[str, Tuple[float, float, float]] = {}
    for position, label in enumerate(sorted(classes)):
        assignment[label] = pool[position % len(pool)]
    return assignment

# =============================== EMBEDDING ===============================
def embed_subset(df_all_sentences: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray]:
    """Embed the selected words with every hidden layer of `baseline`.

    Args:
        df_all_sentences: sentence table with `sentence_id` and `tokens` columns.
        subset_df: one row per word to embed, with `sentence_id` and `word_id`.
        baseline: model id passed to the tokenizer / model loaders.
        word_rep_mode: "first", "last" or "mean" — which sub-token vector(s)
            represent a word.
        batch_size: number of sentences per forward pass.

    Returns:
        reps: float16 array of shape (L, N, D), L = hidden layers + 1,
            N = len(subset_df), D = hidden size.
        filled: bool array of shape (N,); False where no sub-token was found
            for the word (the matching column of `reps` stays zero).

    Raises:
        ValueError: if `word_rep_mode` is not one of the three supported modes.
    """
    # Previously every mode other than "first" (including "last") was silently
    # treated as "mean"; validate up front instead.
    if word_rep_mode not in {"first", "last", "mean"}:
        raise ValueError("WORD_REP_MODE must be {'first','last','mean'}")

    # String-typed id columns are derived locally; the caller's frames are not modified.
    sent_ids = df_all_sentences["sentence_id"].astype(str)
    subset_sids = subset_df["sentence_id"].astype(str)

    # sid -> list[(global_idx, word_id)]
    by_sid: Dict[str, List[Tuple[int,int]]] = {}
    for gidx, (sid, wid) in enumerate(zip(subset_sids, subset_df["word_id"])):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))

    # Materialize in the exact order we will batch
    sids = list(by_sid.keys())
    df_sel = (df_all_sentences.assign(sentence_id=sent_ids)
              .loc[lambda frame: frame["sentence_id"].isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    tokzr = AutoTokenizer.from_pretrained(baseline, use_fast=True, add_prefix_space=True)
    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True)
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True

    model = AutoModel.from_pretrained(baseline, output_hidden_states=True).eval().to(device)
    if device == "cuda":
        model.half()

    L = model.config.num_hidden_layers + 1
    D = model.config.hidden_size
    N = len(subset_df)

    reps   = np.zeros((L, N, D), np.float16)
    filled = np.zeros(N, dtype=bool)

    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{baseline} (embed subset)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # word_id -> token positions for this item
                mp: Dict[int, List[int]] = {}
                for tidx, wid in enumerate(enc_be.word_ids(b)):
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid)
                    if not toks: continue
                    if word_rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif word_rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    else:  # "mean" (validated above)
                        vec = h[:, b, toks, :].mean(axis=1)
                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    filled[gidx] = True

            # free batch buffers
            del enc_be, enc_t, out, h
            if device == "cuda": torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing:
        print(f"⚠ Missing vectors for {missing} of {N} sampled words")
    del model; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    return reps, filled


# =============================== BOOTSTRAP CORE ===============================
def _bs_layer_loop(rep_sub: np.ndarray, M: int, n_reps: int, compute_once: Callable[[np.ndarray], float]):
    """Bootstrap a scalar metric per layer.

    For each of `n_reps` replicates, M row indices are drawn with replacement
    (generator seeded with RAND_SEED) and `compute_once` is applied to the
    selected rows of every layer.

    Args:
        rep_sub: array of shape (L, N, D).
        M: number of rows drawn per replicate.
        n_reps: number of bootstrap replicates.
        compute_once: callable mapping an (M, D) float32 array to a scalar.

    Returns:
        (mu, lo, hi): float32 arrays of length L with the NaN-aware mean and
        the 2.5th / 97.5th percentiles across replicates.

    Failed evaluations are stored as NaN (best effort), but one summary
    warning is emitted so real errors such as a NameError inside a metric do
    not go unnoticed.
    """
    import warnings

    L, N, D = rep_sub.shape
    rng = np.random.default_rng(RAND_SEED)
    A = np.full((n_reps, L), np.nan, np.float32)
    n_failed = 0
    first_error = None
    for r in range(n_reps):
        idx = rng.integers(0, N, size=M)
        for l in range(L):
            X = rep_sub[l, idx].astype(np.float32, copy=False)
            try:
                A[r, l] = float(compute_once(X))
            except Exception as exc:
                A[r, l] = np.nan
                n_failed += 1
                if first_error is None:
                    first_error = exc
    if n_failed:
        warnings.warn(
            f"{n_failed} of {n_reps * L} bootstrap evaluations failed and were "
            f"recorded as NaN; first error: {first_error!r}",
            RuntimeWarning,
            stacklevel=2,
        )
    with warnings.catch_warnings():
        # All-NaN layers are already reported above; silence NumPy's
        # "Mean of empty slice" / "All-NaN slice" RuntimeWarnings.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mu = np.nanmean(A, axis=0).astype(np.float32)
        lo = np.nanpercentile(A, 2.5, axis=0).astype(np.float32)
        hi = np.nanpercentile(A, 97.5, axis=0).astype(np.float32)
    return mu, lo, hi

# fast metric registry (name -> callable(X)->float)
# Fast metrics: metric key -> function taking X and returning a float.
FAST_ONCE: Dict[str, Callable[[np.ndarray], float]] = dict(
    iso=_iso_once,
    spect=_spect_once,
    rand=_rand_once,
    sf=_sf_once,
    vmf_kappa=_vmf_kappa_once,
    erank=_erank_once,
    pr=_pr_once,
    stable_rank=_stable_rank_once,
)

# Heavy metrics: the two DADApy helpers plus one skdim wrapper per key.
# A value of None means the estimator could not be built.
HEAVY_ONCE: Dict[str, Callable[[np.ndarray], float] | None] = {
    "twonn": _dadapy_twonn_once,
    "gride": _dadapy_gride_once,
    **{
        skdim_key: _skdim_once_builder(skdim_key)
        for skdim_key in (
            "mom", "tle", "corrint", "fishers",
            "lpca", "lpca95", "lpca99",
            "mle", "mada", "knn",
        )
    },
}

# Display names used for axis labels and plot titles.
LABELS = {
    # Isotropy
    "iso": "IsoScore",
    "spect": "Spectral Ratio",
    "rand": "RandCos |μ|",
    "sf": "Spectral Flatness",
    "vmf_kappa": "vMF κ",
    # Linear ID
    "erank": "Effective Rank",
    "pr": "Participation Ratio",
    "stable_rank": "Stable Rank",
    "lpca95": "lPCA95",
    "lpca99": "lPCA99",
    "lpca": "lPCA FO",
    # Non-linear
    "twonn": "TwoNN ID",
    "gride": "GRIDE",
    "mom": "MOM",
    "tle": "TLE",
    "corrint": "CorrInt",
    "fishers": "FisherS",
    "mle": "MLE",
    "mada": "MADA",
    "knn": "KNN",
}

# Order in which metrics are laid out when all of them are run.
PLOT_ORDER = (
    # isotropy
    "iso", "sf", "vmf_kappa", "spect", "rand",
    # linear ID
    "erank", "pr", "stable_rank", "lpca95", "lpca99", "lpca",
    # non-linear ID
    "twonn", "gride", "mom", "tle", "corrint", "fishers",
    "mle", "mada", "knn",
)

# Metrics actually computed by the driver; prune to reduce runtime.
#ALL_METRICS = list(PLOT_ORDER)
ALL_METRICS = ["gride"]


# =============================== SAVE / PLOT ===============================
def save_metric_csv_all_pos(metric: str,
                            pos_to_stats: Dict[str, Dict[str, np.ndarray]],
                            layers: np.ndarray,
                            baseline: str,
                            subset_name: str = "raw"):
    """Write one CSV with per-POS, per-layer bootstrap statistics for `metric`.

    Args:
        metric: metric key; stored in the `metric` column and the file name.
        pos_to_stats: POS -> {"mean", "lo", "hi", "n"}; "lo"/"hi"/"n" are optional.
        layers: layer ids aligned with the entries of each "mean" array.
        baseline: model id; stored verbatim in the `model` column.
        subset_name: label stored in the `subset` column and the file name.

    The file is written to CSV_DIR / f"pos_{subset_name}_{metric}_{model}.csv",
    where `model` is `baseline` with "/" replaced by "_" so namespaced model
    ids do not turn into nested paths.
    """
    columns = ["subset", "model", "feature", "class", "metric", "layer",
               "mean", "ci_low", "ci_high", "n_tokens", "word_rep_mode", "source_csv"]

    def _finite_at(arr, i):
        # NaN when the bound is absent (not an ndarray) or not finite.
        if isinstance(arr, np.ndarray) and np.isfinite(arr[i]):
            return float(arr[i])
        return np.nan

    rows = []
    for p, stats in pos_to_stats.items():
        mu, lo, hi = stats["mean"], stats.get("lo"), stats.get("hi")
        n_tokens = int(stats.get("n", 0))
        for l, val in enumerate(mu):
            rows.append({
                "subset": subset_name, "model": baseline, "feature": "pos",
                "class": p, "metric": metric, "layer": int(layers[l]),
                "mean": float(val) if np.isfinite(val) else np.nan,
                "ci_low": _finite_at(lo, l),
                "ci_high": _finite_at(hi, l),
                "n_tokens": n_tokens, "word_rep_mode": WORD_REP_MODE,
                "source_csv": Path(CSV_PATH).name,
            })
    # Explicit columns keep the header even when there are no rows.
    df = pd.DataFrame(rows, columns=columns)
    safe_model = baseline.replace("/", "_")
    out = CSV_DIR / f"pos_{subset_name}_{metric}_{safe_model}.csv"
    df.to_csv(out, index=False)

def plot_metric_with_ci(pos_to_stats: Dict[str, Dict[str, np.ndarray]],
                        layers: np.ndarray, metric: str, title: str, out_path: Path,
                        palette: Dict[str, Tuple[float, float, float]] | None = None):
    """Plot one mean curve per POS over layers, with a shaded lo/hi band, and save it.

    Classes whose mean is None or entirely NaN are skipped. The band is drawn
    only when both bounds are arrays and the lower bound is not all NaN. The
    figure is saved to `out_path` and then closed.
    """
    plt.figure(figsize=(9, 5))
    has_palette = isinstance(palette, dict)
    for tag, stats in pos_to_stats.items():
        mean_curve = stats["mean"]
        if mean_curve is None or np.all(np.isnan(mean_curve)):
            continue
        lower = stats.get("lo")
        upper = stats.get("hi")
        line_color = palette.get(tag) if has_palette else None
        plt.plot(layers, mean_curve, label=tag, lw=1.8, color=line_color)
        band_available = (
            isinstance(lower, np.ndarray)
            and isinstance(upper, np.ndarray)
            and not np.all(np.isnan(lower))
        )
        if band_available:
            plt.fill_between(layers, lower, upper, alpha=0.15, color=line_color)
    plt.xlabel("Layer")
    plt.ylabel(LABELS.get(metric, metric.upper()))
    plt.title(title)

    # Three legend columns once there are more than 12 classes, otherwise two.
    legend_columns = 2 if len(pos_to_stats) <= 12 else 3
    plt.legend(ncol=legend_columns, fontsize="small", title="POS", frameon=False)

    plt.tight_layout()
    plt.savefig(out_path, dpi=220)
    plt.close()



# =============================== DRIVER ===============================
def run_pos_pipeline():
    """Run the per-POS pipeline: load, sample, embed once, then bootstrap each metric.

    For every metric in ALL_METRICS a CSV (via save_metric_csv_all_pos) and a
    PNG (via plot_metric_with_ci) are written before moving to the next metric.
    """
    # 1) Load
    df_all, word_df = load_word_df(CSV_PATH, EXCLUDE_POS)
    POS_TAGS = sorted(word_df.pos.unique())
    palette = make_class_palette(POS_TAGS)
    print(f"✓ corpus ready — {len(word_df):,} tokens across {len(POS_TAGS)} POS")
    print(f"• DADApy: {'available' if HAS_DADAPY else 'missing'}  • scikit-dimension: {'available' if HAS_SKDIM else 'missing'}")

    # 2) Raw sampling only (no frequency match)
    raw_df = sample_raw(word_df, RAW_MAX_PER_POS)
    print("Sample sizes per POS (raw cap):")
    print(raw_df.pos.value_counts().to_dict())

    # 3) Embed once
    reps, filled = embed_subset(df_all, raw_df, BASELINE, WORD_REP_MODE, BATCH_SIZE)
    raw_df = raw_df.reset_index(drop=True).loc[filled].reset_index(drop=True)
    # Keep `reps` aligned with the filtered `raw_df`. Without this, indices
    # derived from `pos_arr` pointed at the wrong columns whenever a token had
    # no vector. Skipped when nothing is missing to avoid copying the array.
    if not filled.all():
        reps = reps[:, filled]
    pos_arr = raw_df.pos.values
    L = reps.shape[0]; layers = np.arange(L)
    print(f"✓ embedded {len(raw_df):,} tokens  • layers={L}")

    # File-name-safe model tag (namespaced ids contain "/").
    model_tag = BASELINE.replace("/", "_")

    # 4) Metric-by-metric loop (incremental outputs)
    for metric in ALL_METRICS:
        print(f"\n→ Computing metric: {metric} …")

        # Decide registry & bootstrap settings
        if metric in FAST_ONCE:
            compute_once = FAST_ONCE[metric]
            n_bs = N_BOOTSTRAP_FAST
            Mcap = FAST_BS_MAX_SAMP_PER_POS
        else:
            compute_once = HEAVY_ONCE.get(metric)
            n_bs = N_BOOTSTRAP_HEAVY
            Mcap = HEAVY_BS_MAX_SAMP_PER_POS

        if compute_once is None:
            print(f"  (skipping {metric}: estimator unavailable)")
            continue

        # Per-POS bootstrap
        metric_results: Dict[str, Dict[str, np.ndarray]] = {}
        for p in POS_TAGS:
            idx = np.where(pos_arr == p)[0]
            if idx.size < 3:
                continue
            sub = reps[:, idx]  # (L, n_p, D)
            Np = sub.shape[1]
            M = min(Mcap, Np)

            mu, lo, hi = _bs_layer_loop(sub, M, n_bs, compute_once)
            metric_results[p] = {"mean": mu, "lo": lo, "hi": hi, "n": int(Np)}

        # Save + plot immediately for this metric
        plot_path = PLOT_DIR / f"raw_{metric}_{model_tag}.png"
        save_metric_csv_all_pos(metric, metric_results, layers, BASELINE, subset_name="raw")
        plot_metric_with_ci(metric_results, layers, metric,
                    title=f"{LABELS.get(metric, metric.upper())} • {BASELINE}",
                    out_path=plot_path,
                    palette=palette)

        # Report the locations actually used; the old message printed
        # directories that do not match CSV_DIR / PLOT_DIR.
        csv_path = CSV_DIR / f"pos_raw_{metric}_{model_tag}.csv"
        print(f"  ✓ saved: CSV= {csv_path}  plot= {plot_path}")

        # light cleanup for safety
        del metric_results; gc.collect()
        if device == "cuda": torch.cuda.empty_cache()

    # Cleanup
    del reps; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    print("\n✓ done (incremental outputs produced per metric).")


# Run the full pipeline only when this code executes as the top-level module.
if __name__ == "__main__":
    run_pos_pipeline()


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f500151a380>>
Traceback (most recent call last):
  File "/home/ldomenichelli/venv/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 

KeyboardInterrupt



In [2]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
from sklearn.decomposition import PCA
import plotly.graph_objects as go
import matplotlib.pyplot as plt  # only for color palettes

CSV_PATH   = "en_ewt-ud-train_sentences.csv"   # read with columns sentence_id / tokens / pos

BASELINE      = "gpt2"          # model id passed to the loader
WORD_REP_MODE = "last"          # one of "first" / "last" / "mean"

# Seed is defined here so the cell runs on a fresh kernel; previously it was
# only available as leftover state from another cell (same value, 42).
RAND_SEED = 42

# Plotting / sampling
PCA_PER_CLASS_MAX_POINTS = 3000     # per-class cap on plotted points

# Throughput + device
BATCH_SIZE = 2
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.backends.cudnn.benchmark = True

# Output
OUT_DIR = Path("pca3d_pos_classes"); OUT_DIR.mkdir(parents=True, exist_ok=True)
HTML_OUT = OUT_DIR / f"{BASELINE.replace('/','_')}_pca3d_pos_classes.html"

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)
    if n is None: raise ValueError("Cannot determine number of hidden layers")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)
    if d is None: raise ValueError("Cannot determine hidden size")
    return int(d)

def _load_tok_and_model(model_id: str):
    """
    Robust loader (BERT/GPT‑2):
      - fast tokenizer for .word_ids()
      - right padding
      - set PAD=EOS for GPT‑like models

    Args:
        model_id: model id to load; "gpt2" / "gpt-2" also try the
            "openai-community/gpt2" namespace.

    Returns:
        (tokenizer, model, resolved_id): the model is in eval mode on `device`
        (and in half precision on CUDA); `resolved_id` is the id that loaded.

    Raises:
        RuntimeError: if every candidate id fails; the message lists each
        attempt and the last error is chained as the cause.
    """
    tried = []
    def _try(mid: str):
        tok = AutoTokenizer.from_pretrained(mid, use_fast=True, add_prefix_space=True)
        if getattr(tok, "padding_side", None) != "right":
            tok.padding_side = "right"
        if tok.pad_token is None and getattr(tok, "eos_token", None) is not None:
            tok.pad_token = tok.eos_token
        mdl = AutoModel.from_pretrained(mid, output_hidden_states=True)
        if getattr(mdl.config, "pad_token_id", None) is None and tok.pad_token_id is not None:
            mdl.config.pad_token_id = tok.pad_token_id
        return tok, mdl

    order = [model_id]
    if model_id.lower() in {"gpt2", "gpt-2"}:
        order += ["openai-community/gpt2", "gpt2"]  # tolerate both namespaces
    # Drop repeated ids while keeping order: for model_id == "gpt2" the list
    # above held "gpt2" twice, causing a redundant retry and a duplicate error line.
    order = list(dict.fromkeys(order))

    last_err = None
    for mid in order:
        try:
            tok, mdl = _try(mid)
            mdl = mdl.eval().to(device)
            if device == "cuda": mdl.half()
            return tok, mdl, mid
        except Exception as e:
            tried.append((mid, repr(e))); last_err = e

    raise RuntimeError("Could not load tokenizer/model. Attempts:\n" +
                       "\n".join(f" - {m}: {err}" for m, err in tried)) from last_err

# =============================== DATA (POS classes) ===============================
def load_pos_tokens(csv_path: str, exclude: set[str] | None = None):
    """
    Read the per-sentence CSV (tokens and pos parsed with _to_list) and expand
    it to one row per token.

    Tokens and tags are paired up to the shorter of the two sequences; tags
    found in `exclude` are dropped.

    Returns:
      df_sent: sentence_id, tokens (one row per sentence_id)
      df_tok : per-token rows with columns [sentence_id, word_id, pos, word]

    Raises:
      ValueError: if no token rows were produced.
    """
    df = pd.read_csv(csv_path, usecols=["sentence_id","tokens","pos"])
    df["sentence_id"] = df["sentence_id"].astype(str)
    df["tokens"] = df["tokens"].apply(_to_list)
    df["pos"] = df["pos"].apply(_to_list)

    rows = []
    for sid, toks, poss in zip(df["sentence_id"], df["tokens"], df["pos"]):
        # zip stops at the shorter sequence, matching min(len(toks), len(poss)).
        for wid, (tok, tag) in enumerate(zip(toks, poss)):
            tag = str(tag)
            if exclude and tag in exclude:
                continue
            rows.append((sid, wid, tag, str(tok)))

    df_tok  = pd.DataFrame(rows, columns=["sentence_id","word_id","pos","word"])
    df_sent = df[["sentence_id","tokens"]].drop_duplicates("sentence_id")
    if df_tok.empty:
        raise ValueError("No token rows constructed—check POS column content.")
    return df_sent, df_tok

# =============================== EMBEDDING ===============================
def embed_subset(df_sent: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray, str]:
    """
    Embed the selected tokens with every hidden layer of `baseline`.

    Args:
      df_sent: sentence table with `sentence_id` and `tokens` columns.
      subset_df: one row per token to embed, with `sentence_id` and `word_id`.
      baseline: model id handed to _load_tok_and_model.
      word_rep_mode: "first", "last" or "mean" — which sub-token vector(s)
        represent a word.
      batch_size: number of sentences per forward pass.

    Returns:
      reps   (L, N, D) float16
      filled (N,) bool; False where no sub-token was found for the token
      model_tag (resolved id)

    Raises:
      ValueError: if `word_rep_mode` is not a supported mode (checked before
        the model is loaded).
      RuntimeError: if the tokenizer cannot provide word ids.
    """
    # Fail fast: previously an invalid mode only surfaced after the model had
    # been loaded and the first batch had been run.
    if word_rep_mode not in {"first", "last", "mean"}:
        raise ValueError("WORD_REP_MODE must be {'first','last','mean'}")

    # String-typed id columns are derived locally; the caller's frames are not
    # modified (assigning into them could raise SettingWithCopyWarning).
    sent_ids = df_sent["sentence_id"].astype(str)
    subset_sids = subset_df["sentence_id"].astype(str)

    # sid -> list[(global_idx, word_id)]
    by_sid: Dict[str, List[Tuple[int,int]]] = {}
    for gidx, (sid, wid) in enumerate(zip(subset_sids, subset_df["word_id"])):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))

    sids = list(by_sid.keys())
    df_sel = (df_sent.assign(sentence_id=sent_ids)
              .loc[lambda frame: frame["sentence_id"].isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    tokzr, model, model_id = _load_tok_and_model(baseline)

    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True

    L = _num_hidden_layers(model) + 1
    D = _hidden_size(model)
    N = len(subset_df)

    reps   = np.zeros((L, N, D), np.float16)
    filled = np.zeros(N, dtype=bool)

    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{model_id} (embed subset)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # map word_id -> token positions
                mp: Dict[int, List[int]] = {}
                wids = enc_be.word_ids(b)
                if wids is None:
                    raise RuntimeError("Fast tokenizer required (word_ids() unavailable).")
                for tidx, wid in enumerate(wids):
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid)
                    if not toks: continue
                    if word_rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif word_rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    else:  # "mean" (validated above)
                        vec = h[:, b, toks, :].mean(axis=1)
                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    filled[gidx] = True

            del enc_be, enc_t, out, h
            if device == "cuda": torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing:
        print(f"⚠ Missing vectors for {missing} of {N} tokens")
    del model; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    return reps, filled, model_id

# =============================== COLORS ===============================
def _class_palette(classes: List[str]) -> Dict[str, str]:
    """
    Map each class to a hex colour string, assigned in alphabetical class order.

    Colours are pooled from the tab20, tab20b and tab20c colormaps; when there
    are more classes than pooled colours, evenly spaced "hsv" colours are used.
    """
    pool = []
    for cmap_name in ("tab20", "tab20b", "tab20c"):
        try:
            pool.extend(plt.get_cmap(cmap_name).colors)
        except Exception:
            pass
    if len(pool) < len(classes):
        hsv = plt.get_cmap("hsv")
        pool = [hsv(k/len(classes)) for k in range(len(classes))]

    def _hex(rgb) -> str:
        # Scale the first three channels to 0-255, truncating toward zero.
        red, green, blue = (int(channel*255) for channel in rgb[:3])
        return "#{:02x}{:02x}{:02x}".format(red, green, blue)

    palette: Dict[str, str] = {}
    for position, label in enumerate(sorted(classes)):  # alphabetical for stability
        palette[label] = _hex(pool[position % len(pool)])
    return palette

# =============================== PCA + PLOTLY BY CLASS ===============================
def pca3d_layers_by_class(reps: np.ndarray,
                          words: List[str],
                          class_arr: np.ndarray,
                          classes: List[str],
                          model_tag: str,
                          html_out: Path):
    """
    Save an interactive 3-D PCA scatter of token vectors, coloured by class,
    with a slider that switches between layers.

    reps: (L, N, D); axis 1 must be row-aligned with `words` and `class_arr`
    words: list[str] length N (hover text)
    class_arr: array length N with POS labels (strings)
    classes: list of unique labels
    model_tag: text used in the figure title
    html_out: path of the HTML file to write (parent directory is created)

    Raises ValueError when no point is selected for plotting.
    """
    L = reps.shape[0]
    classes_sorted = sorted(classes)
    palette = _class_palette(classes_sorted)

    # Build consistent subset across layers: sample up to cap per class
    rng = np.random.default_rng(RAND_SEED)
    sel_idx: List[int] = []
    for c in classes_sorted:
        idx_c = np.where(class_arr == c)[0]
        if PCA_PER_CLASS_MAX_POINTS is not None and len(idx_c) > PCA_PER_CLASS_MAX_POINTS:
            idx_c = rng.choice(idx_c, size=PCA_PER_CLASS_MAX_POINTS, replace=False)
        sel_idx.extend(idx_c.tolist())
    sel_idx = np.array(sel_idx, dtype=np.int64)
    if sel_idx.size == 0:
        raise ValueError("No points selected for PCA plotting (check POS tags and caps).")

    # Per-class local positions within sel_idx; classes without points get no trace.
    sel_classes = class_arr[sel_idx]
    cls_pos: Dict[str, np.ndarray] = {c: np.where(sel_classes == c)[0] for c in classes_sorted}
    active_classes = [c for c in classes_sorted if cls_pos[c].size > 0]

    reps_sel = reps[:, sel_idx, :].astype(np.float32, copy=False)
    words_sel = [words[i] for i in sel_idx]

    # PCA per layer (each layer gets its own projection)
    Y_layers: List[np.ndarray] = []
    for l in range(L):
        X = reps_sel[l]  # (n_sel, D)
        Xc = X - X.mean(0, keepdims=True)
        pca = PCA(n_components=3, random_state=RAND_SEED)
        Y_layers.append(pca.fit_transform(Xc))  # (n_sel, 3)

    # Build traces: (layer, class) grid, layer-major so the slider can address
    # one contiguous run of traces per layer.
    traces = []
    for l in range(L):
        Y = Y_layers[l]
        for c in active_classes:
            pos = cls_pos[c]
            # customdata for stable hover: [class, layer]
            custom = np.column_stack([np.full(pos.size, c, dtype=object),
                                      np.full(pos.size, l, dtype=int)])
            traces.append(
                go.Scatter3d(
                    x=Y[pos, 0], y=Y[pos, 1], z=Y[pos, 2],
                    mode="markers",
                    marker=dict(size=2, opacity=0.75, color=palette[c]),
                    text=[words_sel[i] for i in pos],
                    customdata=custom,
                    hovertemplate=(
                        "<b>%{text}</b>"
                        "<br>POS=%{customdata[0]} • layer=%{customdata[1]}"
                        "<br>x=%{x:.3f}<br>y=%{y:.3f}<br>z=%{z:.3f}"
                        "<extra></extra>"
                    ),
                    name=f"{c}",
                    legendgroup=f"{c}",
                    visible=(l == 0),
                    # Hidden traces contribute no legend entry, so every layer's
                    # traces must opt in; otherwise the legend is empty as soon
                    # as the slider leaves layer 0.
                    showlegend=True
                )
            )

    # Slider toggles visibility of all class traces for a given layer
    traces_per_layer = len(active_classes)
    steps = []
    for l in range(L):
        start = l * traces_per_layer
        vis = [start <= k < start + traces_per_layer for k in range(len(traces))]
        steps.append(dict(
            method="update",
            args=[{"visible": vis},
                  {"title": f"{model_tag} • PCA 3D by POS • layer {l} (drag to rotate)"}],
            label=str(l),
        ))

    sliders = [dict(
        active=0,
        steps=steps,
        currentvalue={"prefix": "Layer: "},
        pad={"t": 10}
    )]

    layout = go.Layout(
        title=f"{model_tag} • PCA 3D by POS • layer 0 (drag to rotate)",
        scene=dict(xaxis_title="PC1", yaxis_title="PC2", zaxis_title="PC3", aspectmode="data"),
        margin=dict(l=0, r=0, b=0, t=40),
        sliders=sliders,
        showlegend=True
    )

    fig = go.Figure(data=traces, layout=layout)
    Path(html_out).parent.mkdir(parents=True, exist_ok=True)
    fig.write_html(str(html_out), include_plotlyjs="cdn")
    print("✓ Saved interactive HTML to:", html_out)

# =============================== DRIVER ===============================
def run_pca3d_pos_classes():
    """
    Load the POS-tagged corpus, embed every token once and write the per-layer
    3-D PCA HTML plot to HTML_OUT.
    """
    # 1) Load POS tokens
    #    (Optionally exclude certain POS by passing exclude={"X","SYM","PART","INTJ"} etc.)
    df_sent, df_tok = load_pos_tokens(CSV_PATH, exclude=None)
    pos_tags = sorted(df_tok.pos.unique().tolist())
    print(f"✓ corpus ready — {len(df_tok):,} tokens across POS={pos_tags}")

    # 2) (Optional) cap per POS for speed happens inside the PCA function; here we keep all
    subset_df = df_tok[["sentence_id","word_id","pos","word"]].copy()

    # 3) Embed once
    reps, filled, resolved_model = embed_subset(df_sent, subset_df, BASELINE, WORD_REP_MODE, BATCH_SIZE)
    subset_df = subset_df.reset_index(drop=True).loc[filled].reset_index(drop=True)
    # Drop the same rows from reps so axis 1 stays aligned with the labels below;
    # skipped when nothing is missing to avoid copying the whole array.
    if not filled.all():
        reps = reps[:, filled]

    # 4) Collect hover text + labels
    words = subset_df["word"].astype(str).tolist()
    cls_arr = subset_df["pos"].astype(str).values
    classes = sorted(subset_df["pos"].unique().tolist())

    # 5) PCA+Plotly
    pca3d_layers_by_class(reps, words, cls_arr, classes, model_tag=resolved_model, html_out=HTML_OUT)

    # Cleanup
    del reps; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()

if __name__ == "__main__":
    run_pca3d_pos_classes()


✓ corpus ready — 194,916 tokens across POS=['ADJ', 'ADP', 'ADV', 'AUX', 'CCONJ', 'DET', 'INTJ', 'NOUN', 'NUM', 'PART', 'PRON', 'PROPN', 'PUNCT', 'SCONJ', 'SYM', 'VERB', 'X']



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.

openai-community/gpt2 (embed subset): 100%|█| 5034/5034 [01:05<00:00, 77.31it/s]


✓ Saved interactive HTML to: pca3d_pos_classes/gpt2_pca3d_pos_classes.html


In [None]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple, Callable

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel

# ============== Optional deps (gracefully skipped if not installed) ==============
# Optional third-party estimators: each import is attempted once and a flag (or
# a fallback object) records the outcome so later code can skip what is missing.
HAS_DADAPY = False
try:
    from dadapy import Data  # DADApy ID estimators (TwoNN, GRIDE)
    HAS_DADAPY = True
except Exception:
    # Any import failure leaves HAS_DADAPY False; the DADApy metrics then return NaN.
    pass

HAS_SKDIM = False
try:
    from skdim.id import (
        MOM, TLE, CorrInt, FisherS, lPCA,
        MLE, DANCo, ESS, MiND_ML, MADA, KNN
    )
    HAS_SKDIM = True
except Exception:
    # Any import failure leaves HAS_SKDIM False; the skdim factories then return None.
    pass

# IsoScore: use library if available, else a simple monotone fallback
# NOTE(review): the notebook's first cell imports this package as
# `from IsoScore import IsoScore` (capitalised). Confirm the lowercase module
# name resolves in this environment; if it does not, the fallback below is used
# without any message.
try:
    from isoscore import IsoScore
    _HAS_ISOSCORE = True
except Exception:
    _HAS_ISOSCORE = False
    class _IsoScoreFallback:
        """Stand-in exposing the same `IsoScore(X)` call as the library object."""
        @staticmethod
        def IsoScore(X: np.ndarray) -> float:
            """Return mean/largest covariance eigenvalue of X (columns as variables), clipped to [0, 1]."""
            C = np.cov(X.T, ddof=0)
            ev = np.linalg.eigvalsh(C)
            # eigvalsh returns ascending eigenvalues, so ev[-1] is the largest.
            if ev.mean() <= 0 or ev[-1] <= 0:
                return 0.0
            # mean/peak eigenvalue ratio in [0,1]; higher ≈ more isotropic
            return float(np.clip(ev.mean() / (ev[-1] + 1e-12), 0.0, 1.0))
    IsoScore = _IsoScoreFallback()

# =============================== CONFIG ===============================
# Tunable settings for this cell's pipeline, plus seeding and device selection.
CSV_PATH        = "en_ewt-ud-train_sentences.csv"
BASELINE        = "gpt2"                 # model id passed to AutoTokenizer / AutoModel
WORD_REP_MODE   = "last"                 # {"last","mean"}; embed_subset treats any other value as "last"
EXCLUDE_POS     = {"X", "SYM", "PART", "INTJ"}  # POS tags dropped by load_word_df

# Cap on tokens kept per POS by sample_raw (applied before embedding)

RAW_MAX_PER_POS = int(1e12)         # effectively no cap

# Bootstrap replicates
N_BOOTSTRAP_FAST   = 50         # used for metrics registered in FAST_ONCE
N_BOOTSTRAP_HEAVY  = 200            # used for every other metric (HEAVY_ONCE)

# Per-replicate sample size M (min(cap, N_pos))
FAST_BS_MAX_SAMP_PER_POS  = int(1e12)   # => M = N_pos (the classic bootstrap uses M=N)
HEAVY_BS_MAX_SAMP_PER_POS = 5000        # 1000–5000 is a practical range for TwoNN/GRIDE/skdim

# Passed as both `maxk` and `range_max` in _dadapy_gride_once
DADAPY_GRID_RANGE_MAX     = 64

# Runtime / output (both directories are created as a side effect of running this cell)
BATCH_SIZE   = 8
RAND_SEED    = 42
PLOT_DIR     = Path("AO_POS"); PLOT_DIR.mkdir(exist_ok=True, parents=True)
CSV_DIR      = Path("tables") / "pos_bootstrap"; CSV_DIR.mkdir(exist_ok=True, parents=True)

# Isotropy extras
PFI_DIRS = 256    # number of random unit directions drawn in _pfI_once
PFI_Q_LO = 5.0    # lower percentile of log Z used in _pfI_once
PFI_Q_HI = 95.0   # upper percentile of log Z used in _pfI_once
EPS      = 1e-12  # small constant added inside logs and denominators

# Repro & device
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda": torch.backends.cudnn.benchmark = True

# Seaborn style
sns.set_style("darkgrid")
plt.rcParams["figure.dpi"] = 120

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _center(X: np.ndarray) -> np.ndarray:
    return X - X.mean(0, keepdims=True)

def _eigvals_from_X(X: np.ndarray) -> np.ndarray:
    """
    Squared singular values of the centred data, in descending order (float64).

    These equal the covariance eigenvalues up to a constant factor. An empty
    array is returned when the SVD raises.
    """
    centered = _center(X.astype(np.float32, copy=False))
    try:
        singular_values = np.linalg.svd(centered, full_matrices=False)[1]
        spectrum = np.square(singular_values).astype(np.float64)
        return np.sort(spectrum)[::-1]
    except Exception:
        return np.array([], dtype=np.float64)

def _jitter_unique(X: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Add tiny noise if there are duplicate rows (helps NN-based estimators)."""
    try:
        if np.unique(X, axis=0).shape[0] < X.shape[0]:
            X = X + np.random.normal(scale=eps, size=X.shape).astype(X.dtype)
    except Exception:
        pass
    return X

# ========= Single-shot compute functions used inside bootstrap =========
# --- Isotropy (fast) ---
def _iso_once(X: np.ndarray) -> float:
    """Score X with whichever `IsoScore` object the import block bound, as a float."""
    score = IsoScore.IsoScore(X)
    return float(score)

def _spect_once(X: np.ndarray) -> float:
    ev = np.linalg.eigvalsh(np.cov(X.T, ddof=0))
    return float(ev[-1] / (ev.mean() + 1e-9))

def _rand_once(X: np.ndarray, K: int = 2000) -> float:
    n = X.shape[0]
    if n < 2: return np.nan
    rng = np.random.default_rng()
    K_eff = min(K, (n*(n-1))//2)
    i = rng.integers(0, n, size=K_eff)
    j = rng.integers(0, n, size=K_eff)
    same = i == j
    if same.any():
        j[same] = rng.integers(0, n, size=same.sum())
    A, B = X[i], X[j]
    num = np.sum(A*B, axis=1)
    den = (np.linalg.norm(A, axis=1)*np.linalg.norm(B, axis=1) + 1e-9)
    return float(np.mean(np.abs(num/den)))

def _sf_once(X: np.ndarray) -> float:
    """Spectral flatness: geometric mean over arithmetic mean of the spectrum from _eigvals_from_X (EPS-stabilised)."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    log_mean = np.log(spectrum + EPS).mean()
    geometric = np.exp(log_mean)
    arithmetic = float(spectrum.mean() + EPS)
    return float(geometric / arithmetic)

def _pfI_once(X: np.ndarray, rng=None) -> float:
    """
    Partition-function isotropy estimate over PFI_DIRS random unit directions.

    For each direction u, log Z(u) = log sum_i exp(u . x_i) is computed with the
    max subtracted for numerical stability. The result is
    exp(percentile_lo - percentile_hi) of those log Z values.

    X:   (n, d) array of vectors.
    rng: optional numpy Generator. When omitted, a generator is seeded from the
         global NumPy random state, so results repeat after `np.random.seed(...)`.

    Returns NaN when X has fewer than two rows.
    """
    n, d = X.shape
    if n < 2: return np.nan
    if rng is None:
        rng = np.random.default_rng(np.random.randint(0, 2**31 - 1))
    U = rng.standard_normal((PFI_DIRS, d)).astype(np.float32)
    U /= np.linalg.norm(U, axis=1, keepdims=True) + 1e-9
    S = U @ X.T
    m = np.max(S, axis=1, keepdims=True)
    logZ = (m + np.log(np.sum(np.exp(S - m), axis=1, keepdims=True))).ravel()
    lo = np.percentile(logZ, PFI_Q_LO)
    hi = np.percentile(logZ, PFI_Q_HI)
    return float(np.exp(lo - hi))  # ≈ min Z / max Z (robust)

def _vmf_kappa_once(X: np.ndarray) -> float:
    if X.shape[0] < 2: return np.nan
    Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-9)
    R = np.linalg.norm(Xn.mean(axis=0))
    d = Xn.shape[1]
    if R < 1e-9: return 0.0
    return float(max(R * (d - R**2) / (1.0 - R**2 + 1e-9), 0.0))

# --- Linear ID (fast) ---
def _pcaXX_once(X: np.ndarray, var_ratio: float) -> float:
    """Component count at which the cumulative spectrum from _eigvals_from_X reaches `var_ratio` of its total."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    cumulative = np.cumsum(spectrum)
    target = cumulative[-1] * var_ratio
    n_components = np.searchsorted(cumulative, target) + 1
    return float(n_components)

def _pca95_once(X: np.ndarray) -> float:
    """_pcaXX_once with a 0.95 variance ratio."""
    return _pcaXX_once(X, var_ratio=0.95)

def _pca99_once(X: np.ndarray) -> float:
    """_pcaXX_once with a 0.99 variance ratio."""
    return _pcaXX_once(X, var_ratio=0.99)

def _erank_once(X: np.ndarray) -> float:
    """Effective rank: exp of the Shannon entropy of the normalised spectrum from _eigvals_from_X."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    weights = spectrum / (spectrum.sum() + EPS)
    entropy = -np.sum(weights * np.log(weights + EPS))
    return float(np.exp(entropy))

def _pr_once(X: np.ndarray) -> float:
    """Participation ratio: (sum of spectrum)^2 over the sum of squared spectrum values (plus EPS)."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    total = spectrum.sum()
    total_of_squares = np.square(spectrum).sum()
    return float(total**2 / (total_of_squares + EPS))

def _stable_rank_once(X: np.ndarray) -> float:
    """Stable rank: sum of the spectrum from _eigvals_from_X over its largest value (plus EPS)."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    total, peak = spectrum.sum(), spectrum.max()
    return float(total / (peak + EPS))

# --- Non-linear (heavy) ---
def _dadapy_twonn_once(X: np.ndarray) -> float:
    """First value returned by DADApy's `compute_id_2NN` on de-duplicated X; NaN when DADApy is missing."""
    if not HAS_DADAPY:
        return np.nan
    dataset = Data(coordinates=_jitter_unique(X))
    estimate, _err, _scale = dataset.compute_id_2NN()
    return float(estimate)

def _dadapy_gride_once(X: np.ndarray) -> float:
    """
    Last entry of the ID sequence returned by DADApy's `return_id_scaling_gride`,
    with DADAPY_GRID_RANGE_MAX used for both `maxk` and `range_max`; NaN when
    DADApy is missing.
    """
    if not HAS_DADAPY:
        return np.nan
    dataset = Data(coordinates=_jitter_unique(X))
    dataset.compute_distances(maxk=DADAPY_GRID_RANGE_MAX)
    id_sequence, _errs, _scales = dataset.return_id_scaling_gride(range_max=DADAPY_GRID_RANGE_MAX)
    return float(id_sequence[-1])

def _skdim_factory(name: str):
    """
    Return a zero-argument callable that builds a fresh skdim estimator for `name`.

    Returns None when scikit-dimension is unavailable or `name` is unknown.
    "lpca" builds lPCA(ver="FO"); "lpca99" builds lPCA(ver="ratio", alphaRatio=0.99);
    every other name uses the estimator's default constructor.
    """
    if not HAS_SKDIM:
        return None
    registry = {
        "mom": MOM, "tle": TLE, "corrint": CorrInt, "fishers": FisherS,
        "lpca": lPCA, "lpca99": lPCA,
        "mle": MLE, "danco": DANCo, "mind_ml": MiND_ML,
        "mada": MADA, "knn": KNN, "ess": ESS,
    }
    if name not in registry:
        return None
    estimator_cls = registry[name]

    # Constructor arguments for the two lPCA variants; everything else takes none.
    variant_kwargs = {
        "lpca": {"ver": "FO"},
        "lpca99": {"ver": "ratio", "alphaRatio": 0.99},
    }
    ctor_kwargs = variant_kwargs.get(name, {})

    def _builder():
        return estimator_cls(**ctor_kwargs)
    return _builder

def _skdim_once_builder(name: str) -> Callable[[np.ndarray], float] | None:
    """
    Wrap the skdim estimator registered under `name` as a function X -> dimension.

    Returns None when `_skdim_factory` has nothing for `name`. The returned
    function fits a new estimator on de-duplicated X at every call and reads
    its `dimension_` attribute (NaN if absent).
    """
    make_estimator = _skdim_factory(name)
    if make_estimator is None:
        return None

    def _once(X: np.ndarray) -> float:
        estimator = make_estimator()
        estimator.fit(_jitter_unique(X))
        return float(getattr(estimator, "dimension_", np.nan))
    return _once

# =============================== DATA ===============================
def load_word_df(csv_path: str, exclude_pos: set[str] = EXCLUDE_POS):
    """
    Read the sentence CSV and flatten it into one row per word.

    Returns (df, word_df): `df` is the sentence table with `tokens` and `pos`
    parsed by _to_list; `word_df` has columns sentence_id, word_id, pos, word and
    omits words whose tag is in `exclude_pos`. `word_id` is the word's position
    within its sentence.
    """
    df = pd.read_csv(csv_path, usecols=["sentence_id","tokens","pos"])
    df["sentence_id"] = df["sentence_id"].astype(str)
    df["tokens"] = df["tokens"].apply(_to_list)
    df["pos"] = df["pos"].apply(_to_list)

    rows = [
        (sid, wid, tag, tok)
        for sid, toks, tags in df[["sentence_id","tokens","pos"]].itertuples(index=False)
        for wid, (tok, tag) in enumerate(zip(toks, tags))
        if tag not in exclude_pos
    ]
    word_df = pd.DataFrame(rows, columns=["sentence_id","word_id","pos","word"])
    return df, word_df

def sample_raw(word_df: pd.DataFrame, per_pos_cap: int = RAW_MAX_PER_POS) -> pd.DataFrame:
    """Sample up to `per_pos_cap` rows from each POS group (no replacement, fixed seed) and concatenate them."""
    picks = [
        group.sample(min(len(group), per_pos_cap), random_state=RAND_SEED, replace=False)
        for _, group in word_df.groupby("pos", sort=False)
    ]
    return pd.concat(picks, ignore_index=True)




# =============================== EMBEDDING (GPT‑2) ===============================
def _num_hidden_layers(model) -> int:
    """Works for GPT‑2 (n_layer) and BERT-like (num_hidden_layers)."""
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None:
        n = getattr(model.config, "n_layer", None)
    if n is None:
        raise ValueError("Cannot determine number of hidden layers from model.config")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None:
        d = getattr(model.config, "n_embd", None)  # GPT‑2
    if d is None:
        raise ValueError("Cannot determine hidden size from model.config")
    return int(d)

def embed_subset(df_all_sentences: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray]:
    """
    Embed the words listed in `subset_df` with `baseline` and return per-layer vectors.

    df_all_sentences: sentence table with `sentence_id` and `tokens` (list of words).
    subset_df:        one row per word to embed, with `sentence_id` and `word_id`
                      (index of the word inside its sentence's `tokens`).
    word_rep_mode:    "mean" averages a word's sub-tokens; "last" (and any other
                      value) takes the last sub-token.

    Returns (reps, filled):
      reps   float32 array of shape (L, N, D); L = hidden layers + 1 (embedding
             output included), N = len(subset_df), row order follows subset_df.
      filled bool array (N,); rows that stayed False keep all-zero vectors, so
             callers must filter `reps` and their labels with the same mask.

    Neither input frame is modified. Uses word_ids() to map subword tokens back
    to original word indices, so a fast tokenizer is required. The tokenizer is
    created with add_prefix_space=True and pad_token falls back to eos_token.
    """
    # String views of the ids; the callers' frames are left untouched.
    all_sids = df_all_sentences["sentence_id"].astype(str)
    sub_sids = subset_df["sentence_id"].astype(str).tolist()

    # sid -> list[(global_idx, word_id)]
    by_sid: Dict[str, List[Tuple[int,int]]] = {}
    for gidx, (sid, wid) in enumerate(zip(sub_sids, subset_df["word_id"].tolist())):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))

    # Materialize in batch order
    sids = list(by_sid.keys())
    keep = all_sids.isin(sids).to_numpy()
    df_sel = (df_all_sentences.loc[keep]
              .assign(sentence_id=all_sids.to_numpy()[keep])
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    # Fast tokenizer with prefix space (needed with is_split_into_words)
    tokzr = AutoTokenizer.from_pretrained(baseline, use_fast=True, add_prefix_space=True)

    # Ensure padding works: GPT‑2 has no pad token by default
    if tokzr.pad_token is None:
        tokzr.pad_token = tokzr.eos_token

    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True)

    # Model
    model = AutoModel.from_pretrained(baseline, output_hidden_states=True).eval().to(device)
    if getattr(model.config, "pad_token_id", None) is None and tokzr.pad_token_id is not None:
        model.config.pad_token_id = tokzr.pad_token_id
    if device == "cuda":
        model.half()

    L = _num_hidden_layers(model) + 1   # include embedding layer
    D = _hidden_size(model)
    N = len(subset_df)

    reps   = np.zeros((L, N, D), np.float32)   # keep float32 to avoid duplicate-row issues in KNN metrics
    filled = np.zeros(N, dtype=bool)

    # torch.autocast replaces the deprecated torch.cuda.amp.autocast; it is a no-op on CPU here.
    with torch.no_grad(), torch.autocast(device_type=device, enabled=(device == "cuda")):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{baseline} (embed subset)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)            # fast tokenizer
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            # hidden_states: tuple len=L of (B,T,D)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # word_id -> token positions for this item
                mp: Dict[int, List[int]] = {}
                # word_ids(b) works only with fast tokenizers
                for tidx, wid in enumerate(enc_be.word_ids(b)):
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid)
                    if not toks:
                        continue
                    if word_rep_mode == "mean":
                        vec = h[:, b, toks, :].mean(axis=1) # mean over subtokens
                    else:  # "last", and the fallback for any other value
                        vec = h[:, b, toks[-1], :]          # last subtoken
                    reps[:, gidx, :] = vec.astype(np.float32, copy=False)
                    filled[gidx] = True

            # free batch buffers
            del enc_be, enc_t, out, h
            if device == "cuda":
                torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing:
        print(f"⚠ Missing vectors for {missing} of {N} sampled words")
    del model; gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    return reps, filled

# =============================== BOOTSTRAP CORE ===============================
def _bs_layer_loop(rep_sub: np.ndarray, M: int, n_reps: int, compute_once: Callable[[np.ndarray], float]):
    """
    Bootstrap a scalar metric per layer.

    For each of `n_reps` replicates, M row indices are drawn with replacement
    (the same indices for every layer) and `compute_once` is applied to each
    layer's sample; a call that raises is recorded as NaN. Returns float32
    arrays (mean, 2.5th percentile, 97.5th percentile), one entry per layer,
    all computed ignoring NaNs.
    """
    n_layers, n_points, _ = rep_sub.shape
    rng = np.random.default_rng(RAND_SEED)
    draws = np.full((n_reps, n_layers), np.nan, np.float32)
    for rep in range(n_reps):
        picked = rng.integers(0, n_points, size=M)
        for layer in range(n_layers):
            sample = rep_sub[layer, picked].astype(np.float32, copy=False)
            try:
                draws[rep, layer] = float(compute_once(sample))
            except Exception:
                draws[rep, layer] = np.nan
    mu = np.nanmean(draws, axis=0).astype(np.float32)
    lo, hi = (np.nanpercentile(draws, q, axis=0).astype(np.float32) for q in (2.5, 97.5))
    return mu, lo, hi

# Fast metric registry: metric key -> function mapping one layer's sample to a scalar.
# run_pos_pipeline bootstraps keys found here with N_BOOTSTRAP_FAST replicates
# and the FAST_BS_MAX_SAMP_PER_POS cap.
FAST_ONCE: Dict[str, Callable[[np.ndarray], float]] = {
    "iso": _iso_once,
    "spect": _spect_once,
    "rand": _rand_once,
    "sf": _sf_once,
    "pfI": _pfI_once,
    "vmf_kappa": _vmf_kappa_once,
    "erank": _erank_once,
    "pr": _pr_once,
    "stable_rank": _stable_rank_once,
    "pca95": _pca95_once,
    "pca99": _pca99_once,
}

def _lpca_ratio_once_builder(alpha: float):
    """
    Return X -> dimension estimated by lPCA(ver="ratio", alphaRatio=alpha),
    or None when scikit-dimension is unavailable.
    """
    if not HAS_SKDIM:
        return None

    def _once(X: np.ndarray) -> float:
        est = lPCA(ver="ratio", alphaRatio=alpha)
        est.fit(_jitter_unique(X))
        return float(getattr(est, "dimension_", np.nan))
    return _once

# Heavy metric registry: metric key -> single-sample function, or None when the
# backing library is missing (run_pos_pipeline then skips the metric).
HEAVY_ONCE: Dict[str, Callable[[np.ndarray], float] | None] = {
    "twonn": _dadapy_twonn_once,
    "gride": _dadapy_gride_once,
    "mom":   _skdim_once_builder("mom"),
    "tle":   _skdim_once_builder("tle"),
    "corrint": _skdim_once_builder("corrint"),
    "fishers": _skdim_once_builder("fishers"),
    "lpca":  _skdim_once_builder("lpca"),
    # Previously wired to the 0.99 estimator, so "lPCA 0.95" reported 0.99 results.
    "lpca95": _lpca_ratio_once_builder(0.95),
    "lpca99": _skdim_once_builder("lpca99"),
    "mle":   _skdim_once_builder("mle"),
    "danco": _skdim_once_builder("danco"),
    "ess":   _skdim_once_builder("ess"),
    "mada":  _skdim_once_builder("mada"),
    "knn":   _skdim_once_builder("knn"),
}

# Display names used for axis labels and plot titles; a metric without an entry
# falls back to its upper-cased key at the call sites.
LABELS = {
    # Isotropy
    "iso":"IsoScore","sf":"Spectral Flatness","pfI":"Partition Isotropy I",
    "vmf_kappa":"vMF κ","spect":"Spectral Ratio","rand":"RandCos |μ|",
    # Linear ID
    "erank":"Effective Rank","pr":"Participation Ratio","stable_rank":"Stable Rank",
    "pca95":"PCA 0.95","pca99":"PCA 0.99",
    # Non-linear
    "twonn":"TwoNN ID","gride":"GRIDE",
    "mom":"MOM","tle":"TLE","corrint":"CorrInt","fishers":"FisherS",
    "lpca":"lPCA FO","lpca99":"lPCA 0.99","lpca95":"lPCA 0.95", "mle":"MLE","danco":"DANCo",
    "ess":"ESS","mada":"MADA","knn":"KNN",
}

# Metric keys in the order they are meant to be processed.
# NOTE(review): "lpca95" is registered in HEAVY_ONCE and LABELS but is not listed
# here, so it is never run when ALL_METRICS is built from PLOT_ORDER.
PLOT_ORDER = (
    "iso","sf","pfI","vmf_kappa","spect","rand",
    "erank","pr","stable_rank","pca95","pca99",
    "twonn","gride","mom","tle","corrint","fishers","lpca","lpca99",
    "mle","danco","ess","mada","knn"
)
# ALL_METRICS is what run_pos_pipeline iterates over. The full list is disabled:
#ALL_METRICS = list(PLOT_ORDER)
# ...and only GRIDE is computed at the moment.
ALL_METRICS =["gride"]

# =============================== SAVE / PLOT ===============================
def save_metric_csv_all_pos(metric: str,
                            pos_to_stats: Dict[str, Dict[str, np.ndarray]],
                            layers: np.ndarray,
                            baseline: str,
                            subset_name: str = "raw"):
    """
    Write one CSV row per (POS class, layer) for `metric` into CSV_DIR.

    Each entry of `pos_to_stats` provides "mean" and optionally "lo", "hi" and
    "n". Non-finite means and missing or non-finite CI bounds are written as NaN.
    """
    def _bound_at(bound, layer_idx):
        # A CI bound is usable only if it is an array holding a finite value here.
        if isinstance(bound, np.ndarray) and np.isfinite(bound[layer_idx]):
            return float(bound[layer_idx])
        return np.nan

    records = []
    for pos_tag, stats in pos_to_stats.items():
        means, lows, highs = stats["mean"], stats.get("lo"), stats.get("hi")
        for layer_idx, mean_val in enumerate(means):
            record = {
                "subset": subset_name, "model": baseline, "feature": "pos",
                "class": pos_tag, "metric": metric, "layer": int(layers[layer_idx]),
                "mean": float(mean_val) if np.isfinite(mean_val) else np.nan,
                "ci_low": _bound_at(lows, layer_idx),
                "ci_high": _bound_at(highs, layer_idx),
                "n_tokens": int(stats.get("n", 0)), "word_rep_mode": WORD_REP_MODE,
                "source_csv": Path(CSV_PATH).name,
            }
            records.append(record)
    target = CSV_DIR / f"pos_{subset_name}_{metric}_{baseline}.csv"
    pd.DataFrame(records).to_csv(target, index=False)

def plot_metric_with_ci(pos_to_stats: Dict[str, Dict[str, np.ndarray]],
                        layers: np.ndarray, metric: str, title: str, out_path: Path):
    """
    Draw one mean-vs-layer line per POS class, with a shaded band between "lo"
    and "hi" when both are arrays, then save the figure to `out_path` and close it.

    Classes whose mean is None or all-NaN are skipped.
    """
    fig, ax = plt.subplots(figsize=(9, 5))
    for pos_tag, stats in pos_to_stats.items():
        means = stats["mean"]
        lows, highs = stats.get("lo"), stats.get("hi")
        if means is None or np.all(np.isnan(means)):
            continue
        ax.plot(layers, means, label=pos_tag, lw=1.8)
        has_band = isinstance(lows, np.ndarray) and isinstance(highs, np.ndarray)
        if has_band and not np.all(np.isnan(lows)):
            ax.fill_between(layers, lows, highs, alpha=0.15)
    ax.set_xlabel("Layer")
    ax.set_ylabel(LABELS.get(metric, metric.upper()))
    ax.set_title(title)
    ax.legend(ncol=2, fontsize="small", title="POS", frameon=False)
    fig.tight_layout()
    fig.savefig(out_path, dpi=220)
    plt.close(fig)

# =============================== DRIVER ===============================
def run_pos_pipeline():
    """
    End-to-end run: load the corpus, sample per POS, embed once, then for every
    metric in ALL_METRICS bootstrap per POS class and per layer, writing one CSV
    and one plot per metric as soon as it is computed.
    """
    # 1) Load
    df_all, word_df = load_word_df(CSV_PATH, EXCLUDE_POS)
    POS_TAGS = sorted(word_df.pos.unique())
    print(f"✓ corpus ready — {len(word_df):,} tokens across {len(POS_TAGS)} POS")
    print(f"• DADApy: {'available' if HAS_DADAPY else 'missing'}  • scikit-dimension: {'available' if HAS_SKDIM else 'missing'}")

    # 2) Raw sampling only (no frequency match)
    raw_df = sample_raw(word_df, RAW_MAX_PER_POS)
    print("Sample sizes per POS (raw cap):")
    print(raw_df.pos.value_counts().to_dict())

    # 3) Embed once (GPT‑2, last/mean subtoken per word)
    reps, filled = embed_subset(df_all, raw_df, BASELINE, WORD_REP_MODE, BATCH_SIZE)
    raw_df = raw_df.reset_index(drop=True).loc[filled].reset_index(drop=True)
    # Drop the same rows from reps so that pos_arr indices address the right
    # vectors; skipped when nothing is missing to avoid copying the whole array.
    if not filled.all():
        reps = reps[:, filled]
    pos_arr = raw_df.pos.values
    L = reps.shape[0]; layers = np.arange(L)
    print(f"✓ embedded {len(raw_df):,} tokens  • layers={L}")

    # 4) Metric-by-metric loop (incremental outputs)
    for metric in ALL_METRICS:
        print(f"\n→ Computing metric: {metric} …")

        # Decide registry & bootstrap settings
        if metric in FAST_ONCE:
            compute_once = FAST_ONCE[metric]
            n_bs = N_BOOTSTRAP_FAST
            Mcap = FAST_BS_MAX_SAMP_PER_POS
        else:
            compute_once = HEAVY_ONCE.get(metric)
            n_bs = N_BOOTSTRAP_HEAVY
            Mcap = HEAVY_BS_MAX_SAMP_PER_POS

        if compute_once is None:
            print(f"  (skipping {metric}: estimator unavailable)")
            continue

        # Per-POS bootstrap
        metric_results: Dict[str, Dict[str, np.ndarray]] = {}
        for p in POS_TAGS:
            idx = np.where(pos_arr == p)[0]
            if idx.size < 3:
                continue
            sub = reps[:, idx]  # (L, n_p, D)
            Np = sub.shape[1]
            M = min(Mcap, Np)

            mu, lo, hi = _bs_layer_loop(sub, M, n_bs, compute_once)
            metric_results[p] = {"mean": mu, "lo": lo, "hi": hi, "n": int(Np)}

        # Save + plot immediately for this metric
        csv_path = CSV_DIR / f"pos_raw_{metric}_{BASELINE}.csv"   # name built by save_metric_csv_all_pos
        plot_path = PLOT_DIR / f"raw_{metric}_{BASELINE}.png"
        save_metric_csv_all_pos(metric, metric_results, layers, BASELINE, subset_name="raw")
        plot_metric_with_ci(metric_results, layers, metric,
                    title=f"{LABELS.get(metric, metric.upper())} • {BASELINE}",
                    out_path=plot_path)
        # Report the paths actually written (the old message named a plot file
        # with a WORD_REP_MODE suffix that was never created).
        print(f"  ✓ saved: CSV= {csv_path}  plot= {plot_path}")

        # light cleanup
        del metric_results; gc.collect()
        if device == "cuda": torch.cuda.empty_cache()

    # Cleanup
    del reps; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    print("\n✓ done (incremental outputs produced per metric).")

if __name__ == "__main__":
    run_pos_pipeline()


## No index

In [2]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple, Callable

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel, GPT2TokenizerFast

# ============== Optional deps (gracefully skipped if not installed) ==============
# Optional third-party estimators: each import is attempted once and a flag (or
# a fallback object) records the outcome so later code can skip what is missing.
HAS_DADAPY = False
try:
    from dadapy import Data  # DADApy ID estimators (TwoNN, GRIDE)
    HAS_DADAPY = True
except Exception:
    # Any import failure simply leaves HAS_DADAPY False.
    pass

HAS_SKDIM = False
try:
    from skdim.id import (
        MOM, TLE, CorrInt, FisherS, lPCA,
        MLE, DANCo, ESS, MiND_ML, MADA, KNN
    )
    HAS_SKDIM = True
except Exception:
    # Any import failure simply leaves HAS_SKDIM False.
    pass

# IsoScore: use library if available, else a simple monotone fallback
# NOTE(review): the notebook's first cell imports this package as
# `from IsoScore import IsoScore` (capitalised). Confirm the lowercase module
# name resolves in this environment; if it does not, the fallback below is used
# without any message.
try:
    from isoscore import IsoScore
    _HAS_ISOSCORE = True
except Exception:
    _HAS_ISOSCORE = False
    class _IsoScoreFallback:
        """Stand-in exposing the same `IsoScore(X)` call as the library object."""
        @staticmethod
        def IsoScore(X: np.ndarray) -> float:
            """Return mean/largest covariance eigenvalue of X (columns as variables), clipped to [0, 1]."""
            C = np.cov(X.T, ddof=0)
            ev = np.linalg.eigvalsh(C)
            # eigvalsh returns ascending eigenvalues, so ev[-1] is the largest.
            if ev.mean() <= 0 or ev[-1] <= 0:
                return 0.0
            # mean/peak eigenvalue ratio in [0,1]; higher ≈ more isotropic
            return float(np.clip(ev.mean() / (ev[-1] + 1e-12), 0.0, 1.0))
    IsoScore = _IsoScoreFallback()

# =============================== CONFIG ===============================
# Tunable settings for this cell's pipeline, plus seeding and device selection.
CSV_PATH        = "en_ewt-ud-train_sentences.csv"
BASELINE        = "gpt2"                 # model id (GPT-2)
WORD_REP_MODE   = "last"                 # {"last","mean"} word representation
EXCLUDE_POS     = {"X", "SYM", "PART", "INTJ"}

# NEW: drop tokens at 0-based sentence index == 1 (the 2nd token)
# If the first token was intended (1-based "1"), change the check wid == 1 → wid == 0 in load_word_df.
# NOTE(review): the code that reads this flag is not visible in this part of the cell.
EXCLUDE_INDEX_1 = True

# Sampling cap per POS (increase if you want)
RAW_MAX_PER_POS = int(1e12)         # effectively no cap


# Bootstrap replicates
N_BOOTSTRAP_FAST   = 50
N_BOOTSTRAP_HEAVY  = 100

# Per-replicate sample size M (min(cap, N_pos))
FAST_BS_MAX_SAMP_PER_POS  = int(1e12)
HEAVY_BS_MAX_SAMP_PER_POS = 5000

DADAPY_GRID_RANGE_MAX     = 64

# Runtime / output (both directories are created as a side effect of running this cell)
BATCH_SIZE   = 8
RAND_SEED    = 42
PLOT_DIR     = Path("AO_POS"); PLOT_DIR.mkdir(exist_ok=True, parents=True)
CSV_DIR      = Path("tables") / "pos_bootstrap"; CSV_DIR.mkdir(exist_ok=True, parents=True)

# Isotropy extras
PFI_DIRS = 256    # number of random unit directions drawn in _pfI_once
PFI_Q_LO = 5.0    # lower percentile of log Z used in _pfI_once
PFI_Q_HI = 95.0   # upper percentile of log Z used in _pfI_once
EPS      = 1e-12  # small constant added inside logs and denominators

# Repro & device
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda": torch.backends.cudnn.benchmark = True

# Seaborn style
sns.set_style("darkgrid")
plt.rcParams["figure.dpi"] = 120

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _center(X: np.ndarray) -> np.ndarray:
    return X - X.mean(0, keepdims=True)

def _eigvals_from_X(X: np.ndarray) -> np.ndarray:
    """
    Squared singular values of the centred data, in descending order (float64).

    These equal the covariance eigenvalues up to a constant factor. An empty
    array is returned when the SVD raises.
    """
    centered = _center(X.astype(np.float32, copy=False))
    try:
        singular_values = np.linalg.svd(centered, full_matrices=False)[1]
        spectrum = np.square(singular_values).astype(np.float64)
        return np.sort(spectrum)[::-1]
    except Exception:
        return np.array([], dtype=np.float64)

def _jitter_unique(X: np.ndarray, eps: float = 1e-7) -> np.ndarray:
    """Add tiny noise if there are duplicate rows (helps NN-based estimators)."""
    try:
        if np.unique(X, axis=0).shape[0] < X.shape[0]:
            X = X + np.random.normal(scale=eps, size=X.shape).astype(X.dtype)
    except Exception:
        pass
    return X

# ========= Single-shot compute functions used inside bootstrap =========
# --- Isotropy (fast) ---
def _iso_once(X: np.ndarray) -> float:
    """Score X with whichever `IsoScore` object the import block bound, as a float."""
    score = IsoScore.IsoScore(X)
    return float(score)

def _spect_once(X: np.ndarray) -> float:
    ev = np.linalg.eigvalsh(np.cov(X.T, ddof=0))
    return float(ev[-1] / (ev.mean() + 1e-9))

def _rand_once(X: np.ndarray, K: int = 2000) -> float:
    n = X.shape[0]
    if n < 2: return np.nan
    rng = np.random.default_rng()
    K_eff = min(K, (n*(n-1))//2)
    i = rng.integers(0, n, size=K_eff)
    j = rng.integers(0, n, size=K_eff)
    same = i == j
    if same.any():
        j[same] = rng.integers(0, n, size=same.sum())
    A, B = X[i], X[j]
    num = np.sum(A*B, axis=1)
    den = (np.linalg.norm(A, axis=1)*np.linalg.norm(B, axis=1) + 1e-9)
    return float(np.mean(np.abs(num/den)))

def _sf_once(X: np.ndarray) -> float:
    """Spectral flatness: geometric mean over arithmetic mean of the eigenvalue spectrum.

    EPS is added inside the log and to the arithmetic mean so zero eigenvalues
    do not produce -inf or a division by zero. Returns NaN when the spectrum
    could not be computed.
    """
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    mean_log = np.mean(np.log(spectrum + EPS))
    geometric = np.exp(mean_log)
    arithmetic = float(spectrum.mean() + EPS)
    return float(geometric / arithmetic)

def _pfI_once(X: np.ndarray) -> float:
    """Partition-function isotropy estimated over random unit directions.

    For PFI_DIRS random unit vectors u, computes log Z(u) = log sum_x exp(u·x)
    with the usual max-shift for numerical stability, then returns
    exp(q_lo - q_hi), where q_lo / q_hi are the PFI_Q_LO / PFI_Q_HI percentiles
    of log Z — a robust stand-in for min Z / max Z. Returns NaN for fewer than
    two rows. The directions come from an unseeded generator.
    """
    n_rows, dim = X.shape
    if n_rows < 2:
        return np.nan
    rng = np.random.default_rng()
    directions = rng.standard_normal((PFI_DIRS, dim)).astype(np.float32)
    directions /= np.linalg.norm(directions, axis=1, keepdims=True) + 1e-9
    scores = directions @ X.T
    row_max = np.max(scores, axis=1, keepdims=True)
    shifted_sum = np.sum(np.exp(scores - row_max), axis=1, keepdims=True)
    log_z = (row_max + np.log(shifted_sum)).ravel()
    q_lo = np.percentile(log_z, PFI_Q_LO)
    q_hi = np.percentile(log_z, PFI_Q_HI)
    return float(np.exp(q_lo - q_hi))  # ≈ min Z / max Z (robust)

def _vmf_kappa_once(X: np.ndarray) -> float:
    if X.shape[0] < 2: return np.nan
    Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-9)
    R = np.linalg.norm(Xn.mean(axis=0))
    d = Xn.shape[1]
    if R < 1e-9: return 0.0
    return float(max(R * (d - R**2) / (1.0 - R**2 + 1e-9), 0.0))

# --- Linear ID (fast) ---
def _pcaXX_once(X: np.ndarray, var_ratio: float) -> float:
    """Number of leading components whose cumulative variance reaches `var_ratio` of the total.

    Works on the descending spectrum from `_eigvals_from_X`; returns NaN when
    that spectrum is empty.
    """
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    cumulative = np.cumsum(spectrum)
    target = cumulative[-1] * var_ratio
    # searchsorted gives the first index where the running total reaches the
    # target; +1 converts that index into a component count.
    n_components = np.searchsorted(cumulative, target) + 1
    return float(n_components)

def _pca95_once(X: np.ndarray) -> float:
    """Component count needed to reach 95% of the total variance."""
    return _pcaXX_once(X, var_ratio=0.95)

def _pca99_once(X: np.ndarray) -> float:
    """Component count needed to reach 99% of the total variance."""
    return _pcaXX_once(X, var_ratio=0.99)

def _erank_once(X: np.ndarray) -> float:
    """Effective rank: exp of the Shannon entropy of the normalised eigenvalue spectrum.

    EPS guards both the normalisation and the log. Returns NaN when the
    spectrum is empty.
    """
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    weights = spectrum / (spectrum.sum() + EPS)
    entropy = -(weights * np.log(weights + EPS)).sum()
    return float(np.exp(entropy))

def _pr_once(X: np.ndarray) -> float:
    """Participation ratio (sum λ)² / sum λ² of the eigenvalue spectrum; NaN if it is empty."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    total = spectrum.sum()
    total_sq = (spectrum ** 2).sum()
    return float((total ** 2) / (total_sq + EPS))

def _stable_rank_once(X: np.ndarray) -> float:
    """Stable rank: sum of the eigenvalues over the largest one; NaN if the spectrum is empty."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    largest = spectrum.max()
    return float(spectrum.sum() / (largest + EPS))

# --- Non-linear (heavy) ---
def _dadapy_twonn_once(X: np.ndarray) -> float:
    """TwoNN intrinsic dimension via DADApy; NaN when DADApy is not installed.

    Duplicate rows are jittered first. Only the first of the three values
    returned by `compute_id_2NN` is used.
    """
    if not HAS_DADAPY:
        return np.nan
    dataset = Data(coordinates=_jitter_unique(X))
    estimate, _, _ = dataset.compute_id_2NN()
    return float(estimate)

def _dadapy_gride_once(X: np.ndarray) -> float:
    """GRIDE intrinsic dimension via DADApy; NaN when DADApy is not installed.

    Distances are computed up to DADAPY_GRID_RANGE_MAX neighbours, then the
    scaling analysis is run with the same `range_max`. The last entry of the
    returned id array is reported (presumably the largest scale — confirm
    against the DADApy documentation).
    """
    if not HAS_DADAPY:
        return np.nan
    dataset = Data(coordinates=_jitter_unique(X))
    dataset.compute_distances(maxk=DADAPY_GRID_RANGE_MAX)
    id_per_scale, _, _ = dataset.return_id_scaling_gride(range_max=DADAPY_GRID_RANGE_MAX)
    return float(id_per_scale[-1])

def _skdim_factory(name: str):
    """Return a zero-argument factory that builds a fresh skdim estimator, or None.

    None is returned when scikit-dimension is unavailable or `name` is unknown.
    A new estimator is built on every call so no fitted state is shared between
    bootstrap replicates.

    The lPCA variants are the only ones built with constructor arguments:
    "lpca" uses ver="FO"; "lpca95" / "lpca99" use ver="ratio" with alphaRatio
    0.95 / 0.99. "lpca95" was previously missing, so no 0.95 estimator could be
    requested by name.
    """
    if not HAS_SKDIM: return None
    mapping = {
        "mom": MOM, "tle": TLE, "corrint": CorrInt, "fishers": FisherS,
        "lpca": lPCA, "lpca95": lPCA, "lpca99": lPCA,
        "mle": MLE, "danco": DANCo, "mind_ml": MiND_ML,
        "mada": MADA, "knn": KNN, "ess": ESS,
    }
    # Constructor arguments per variant; names not listed here use the defaults.
    ctor_kwargs = {
        "lpca":   dict(ver="FO"),
        "lpca95": dict(ver="ratio", alphaRatio=0.95),
        "lpca99": dict(ver="ratio", alphaRatio=0.99),
    }
    cls = mapping.get(name)
    if cls is None: return None
    kwargs = ctor_kwargs.get(name, {})

    def _builder():
        return cls(**kwargs)
    return _builder

def _skdim_once_builder(name: str) -> Callable[[np.ndarray], float] | None:
    """Wrap the skdim estimator called `name` as an `X -> float` function.

    Returns None when `_skdim_factory` has no estimator for `name`. The wrapper
    fits a fresh estimator on the jittered data and returns its `dimension_`
    attribute (NaN if the estimator does not expose one).
    """
    factory = _skdim_factory(name)
    if factory is None:
        return None

    def _once(X: np.ndarray) -> float:
        estimator = factory()
        estimator.fit(_jitter_unique(X))
        return float(getattr(estimator, "dimension_", np.nan))

    return _once

# =============================== DATA ===============================
def load_word_df(csv_path: str,
                 exclude_pos: set[str] = EXCLUDE_POS,
                 exclude_index_1: bool = EXCLUDE_INDEX_1):
    """
    Read the sentence-level CSV and expand it into one row per word.

    Args:
        csv_path: CSV with columns sentence_id, tokens, pos; tokens and pos are
            stringified lists.
        exclude_pos: POS tags whose words are dropped.
        exclude_index_1: when true, the word at position 0 of every sentence is
            dropped (despite the name, the check is against word_id == 0).

    Returns:
        (df, word_df): the sentence frame with parsed lists, and a word frame
        with columns sentence_id, word_id, pos, word.
    """
    df = pd.read_csv(csv_path, usecols=["sentence_id","tokens","pos"])
    df["sentence_id"] = df["sentence_id"].astype(str)
    df["tokens"] = df["tokens"].apply(_to_list)
    df["pos"] = df["pos"].apply(_to_list)

    first_wid = 1 if exclude_index_1 else 0
    rows = [
        (sid, wid, poss[wid], toks[wid])
        for sid, toks, poss in df[["sentence_id","tokens","pos"]].itertuples(index=False)
        # only positions present in both lists are usable
        for wid in range(first_wid, min(len(toks), len(poss)))
        if poss[wid] not in exclude_pos
    ]

    word_df = pd.DataFrame(rows, columns=["sentence_id","word_id","pos","word"])

    # Extra safety: enforce the filter even if logic above changes later
    if exclude_index_1 and not word_df.empty:
        keep = word_df.word_id != 0
        word_df = word_df[keep].reset_index(drop=True)

    return df, word_df

def sample_raw(word_df: pd.DataFrame, per_pos_cap: int = RAW_MAX_PER_POS) -> pd.DataFrame:
    """Draw at most `per_pos_cap` words per POS tag, without frequency matching.

    Each POS group is sampled without replacement using RAND_SEED; the groups
    are concatenated in first-appearance order with a fresh integer index.
    """
    picks = [
        group.sample(min(len(group), per_pos_cap), random_state=RAND_SEED, replace=False)
        for _, group in word_df.groupby("pos", sort=False)
    ]
    return pd.concat(picks, ignore_index=True)

# =============================== MODEL LOADING (robust for GPT‑2) ===============================
def _load_tok_and_model(baseline: str):
    """
    Load a fast tokenizer and the matching model, with fallback ids for GPT-2.

    - GPT-2: use GPT2TokenizerFast; force right padding; set pad_token=eos_token if missing.
    - Try both 'gpt2' and 'openai-community/gpt2' (some envs only have one).
    - Set model.config.pad_token_id if missing.

    Any baseline whose name contains "gpt2" gets those two ids as fallbacks, so
    check the returned id to see which checkpoint actually loaded.

    Returns:
        (tokenizer, model, model_id): the model is in eval mode on `device`
        (and in half precision on CUDA); model_id is the candidate that loaded.

    Raises:
        RuntimeError: if every candidate fails; chained to the last error.
    """
    candidates = [baseline]
    if "gpt2" in baseline.lower():
        for alias in ("openai-community/gpt2", "gpt2"):
            if alias != baseline:
                candidates.append(alias)

    last_err = None
    for mid in candidates:
        try:
            if "gpt2" in mid.lower():
                # The notebook's import cell only brings in AutoTokenizer/AutoModel.
                # Without this import the NameError was swallowed by the handler
                # below and every GPT-2 candidate failed.
                from transformers import GPT2TokenizerFast
                tok = GPT2TokenizerFast.from_pretrained(mid, add_prefix_space=True)
            else:
                tok = AutoTokenizer.from_pretrained(mid, use_fast=True, add_prefix_space=True)

            if getattr(tok, "padding_side", None) != "right":
                tok.padding_side = "right"
            if tok.pad_token is None and getattr(tok, "eos_token", None) is not None:
                tok.pad_token = tok.eos_token

            mdl = AutoModel.from_pretrained(mid, output_hidden_states=True)
            if getattr(mdl.config, "pad_token_id", None) is None and tok.pad_token_id is not None:
                mdl.config.pad_token_id = tok.pad_token_id

            mdl = mdl.eval().to(device)
            if device == "cuda":
                mdl.half()
            return tok, mdl, mid
        except Exception as e:
            # Best effort: remember the failure and try the next candidate id.
            last_err = e
            continue

    raise RuntimeError(
        f"Failed to load tokenizer/model for {candidates}. Last error: {last_err}"
    ) from last_err

# =============================== EMBEDDING ===============================
def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)
    if n is None: raise ValueError("Cannot determine number of hidden layers from model.config")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)  # GPT‑2
    if d is None: raise ValueError("Cannot determine hidden size from model.config")
    return int(d)

def embed_subset(df_all_sentences: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray]:
    """
    Embed the selected words and return their hidden states at every layer.

    Uses the fast tokenizer's word_ids() to map subword tokens back to the
    original word indices.

    Args:
        df_all_sentences: sentence frame with columns sentence_id and tokens
            (a list of words per sentence). Not modified.
        subset_df: one row per selected word, with sentence_id and word_id.
            Not modified.
        baseline: model id passed to `_load_tok_and_model`.
        word_rep_mode: "first" (first subtoken), "last" (last subtoken) or
            "mean" (average over subtokens). Any other value uses "last".
        batch_size: sentences per forward pass.

    Returns:
        reps: float32 array (L, N, D) with L = hidden layers + 1 (embedding
            output included) and N = len(subset_df). Rows whose `filled` entry
            is False remain all-zero.
        filled: bool array (N,), True where a vector was written.
    """
    # String-typed copies of the id columns; the caller's frames are left untouched.
    sentences = df_all_sentences.assign(sentence_id=df_all_sentences["sentence_id"].astype(str))
    subset = subset_df.assign(sentence_id=subset_df["sentence_id"].astype(str))

    # sid -> list[(global_idx, word_id)]
    by_sid: Dict[str, List[Tuple[int,int]]] = {}
    for gidx, (sid, wid) in enumerate(subset[["sentence_id","word_id"]].itertuples(index=False)):
        by_sid.setdefault(str(sid), []).append((gidx, int(wid)))

    # Materialize in batch order
    sids = list(by_sid.keys())
    df_sel = (sentences[sentences.sentence_id.isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    tokzr, model, model_id = _load_tok_and_model(baseline)

    # truncation=True keeps over-long sentences within the tokenizer's limit;
    # words cut off this way are counted as missing below instead of failing
    # the forward pass.
    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True

    L = _num_hidden_layers(model) + 1   # include embedding layer
    D = _hidden_size(model)
    N = len(subset)

    reps   = np.zeros((L, N, D), np.float32)  # use float32 (stable for some estimators)
    filled = np.zeros(N, dtype=bool)

    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{model_id} (embed subset)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)            # fast tokenizer
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            # hidden_states: tuple len=L of (B,T,D)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # word_id -> token positions for this item
                mp: Dict[int, List[int]] = {}
                wids = enc_be.word_ids(b)
                if wids is None:
                    raise RuntimeError("Fast tokenizer required; word_ids() returned None.")
                for tidx, wid in enumerate(wids):
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid)
                    if not toks:
                        continue
                    if word_rep_mode == "first":
                        # Previously "first" fell through to the last-subtoken
                        # branch, so results labelled "first" were really "last".
                        vec = h[:, b, toks[0], :]           # first subtoken
                    elif word_rep_mode == "mean":
                        vec = h[:, b, toks, :].mean(axis=1) # mean over subtokens
                    else:  # "last", and the fallback for any unrecognised mode
                        vec = h[:, b, toks[-1], :]
                    reps[:, gidx, :] = vec.astype(np.float32, copy=False)
                    filled[gidx] = True

            del enc_be, enc_t, out, h
            if device == "cuda":
                torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing:
        print(f"⚠ Missing vectors for {missing} of {N} sampled words")
    del model; gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    return reps, filled

# =============================== BOOTSTRAP CORE ===============================
def _bs_layer_loop(rep_sub: np.ndarray, M: int, n_reps: int, compute_once: Callable[[np.ndarray], float]):
    """Bootstrap a scalar metric independently for every layer.

    Each replicate draws M row indices with replacement from the N points and
    reuses that draw for all layers. The generator is re-seeded with RAND_SEED
    on every call. A failing `compute_once` call is recorded as NaN.

    Returns:
        (mu, lo, hi): float32 arrays of length L holding the NaN-aware mean and
        the 2.5 / 97.5 percentiles over the replicates.
    """
    n_layers, n_points, _ = rep_sub.shape
    rng = np.random.default_rng(RAND_SEED)
    estimates = np.full((n_reps, n_layers), np.nan, np.float32)
    for rep in range(n_reps):
        draw = rng.integers(0, n_points, size=M)
        for layer in range(n_layers):
            sample = rep_sub[layer, draw].astype(np.float32, copy=False)
            try:
                value = float(compute_once(sample))
            except Exception:
                value = np.nan
            estimates[rep, layer] = value
    mu = np.nanmean(estimates, axis=0).astype(np.float32)
    lo = np.nanpercentile(estimates, 2.5, axis=0).astype(np.float32)
    hi = np.nanpercentile(estimates, 97.5, axis=0).astype(np.float32)
    return mu, lo, hi

# fast metric registry: metric name -> callable(X) -> float, bootstrapped with
# the "fast" replicate count and sample cap in run_pos_pipeline
FAST_ONCE: Dict[str, Callable[[np.ndarray], float]] = {
    "iso": _iso_once,
    "spect": _spect_once,
    "rand": _rand_once,
    "sf": _sf_once,
    "pfI": _pfI_once,
    "vmf_kappa": _vmf_kappa_once,
    "erank": _erank_once,
    "pr": _pr_once,
    "stable_rank": _stable_rank_once,
    "pca95": _pca95_once,
    "pca99": _pca99_once,
}

# heavy metric registry: an entry is None when _skdim_once_builder cannot
# supply the estimator; run_pos_pipeline skips None entries.
HEAVY_ONCE: Dict[str, Callable[[np.ndarray], float] | None] = {
    "twonn": _dadapy_twonn_once,
    "gride": _dadapy_gride_once,
    "mom":   _skdim_once_builder("mom"),
    "tle":   _skdim_once_builder("tle"),
    "corrint": _skdim_once_builder("corrint"),
    "fishers": _skdim_once_builder("fishers"),
    "lpca":  _skdim_once_builder("lpca"),
    # Ask for the 0.95 variant under its own name. This entry used to be built
    # from "lpca99", which reported the 0.99 estimator under the 0.95 label.
    "lpca95": _skdim_once_builder("lpca95"),
    "lpca99": _skdim_once_builder("lpca99"),
    "mle":   _skdim_once_builder("mle"),
    "danco": _skdim_once_builder("danco"),
    "ess":   _skdim_once_builder("ess"),
    "mada":  _skdim_once_builder("mada"),
    "knn":   _skdim_once_builder("knn"),
}

# Axis / title labels; metrics without an entry fall back to metric.upper().
LABELS = {
    # Isotropy
    "iso":"IsoScore","sf":"Spectral Flatness","pfI":"Partition Isotropy I",
    "vmf_kappa":"vMF κ","spect":"Spectral Ratio","rand":"RandCos |μ|",
    # Linear ID
    "erank":"Effective Rank","pr":"Participation Ratio","stable_rank":"Stable Rank",
    # Non-linear
    "twonn":"TwoNN ID","gride":"GRIDE",
    "mom":"MOM","tle":"TLE","corrint":"CorrInt","fishers":"FisherS",
    "lpca":"lPCA FO","lpca99":"lPCA 0.99","lpca95":"lPCA 0.95", "mle":"MLE","danco":"DANCo",
    "ess":"ESS","mada":"MADA","knn":"KNN",
}

# Choose plotting order
PLOT_ORDER = (
    "iso","sf","pfI","vmf_kappa","spect","rand",
    "erank","pr","stable_rank","pca95","pca99",
    "twonn","gride","mom","tle","corrint","fishers","lpca","lpca95","lpca99",
    "mle","danco","ess","mada","knn"
)
# Metrics actually computed by run_pos_pipeline. Currently narrowed to GRIDE;
# switch to the commented line to run everything in PLOT_ORDER.
#ALL_METRICS = list(PLOT_ORDER)
ALL_METRICS = ["gride"]

# =============================== SAVE / PLOT ===============================
def save_metric_csv_all_pos(metric: str,
                            pos_to_stats: Dict[str, Dict[str, np.ndarray]],
                            layers: np.ndarray,
                            baseline: str,
                            subset_name: str = "raw"):
    """Write one long-format CSV (one row per POS × layer) for a single metric.

    Args:
        metric: metric key, stored in the "metric" column and the file name.
        pos_to_stats: POS tag -> {"mean", "lo", "hi", "n"}; "lo"/"hi" may be
            absent and "n" defaults to 0.
        layers: layer ids, indexed in parallel with the "mean" array.
        baseline: model name, stored in the "model" column and the file name.
        subset_name: sampling-scheme tag, stored in "subset" and the file name.

    The file goes to CSV_DIR / f"pos_{subset_name}_{metric}_{baseline}.csv".
    Non-finite values are written as NaN. The header row is written even when
    `pos_to_stats` is empty.
    """
    columns = ["subset", "model", "feature", "class", "metric", "layer",
               "mean", "ci_low", "ci_high", "n_tokens", "word_rep_mode", "source_csv"]
    rows = []
    for p, stats in pos_to_stats.items():
        mu, lo, hi = stats["mean"], stats.get("lo"), stats.get("hi")
        for l, val in enumerate(mu):
            rows.append({
                "subset": subset_name, "model": baseline, "feature": "pos",
                "class": p, "metric": metric, "layer": int(layers[l]),
                "mean": float(val) if np.isfinite(val) else np.nan,
                "ci_low": float(lo[l]) if isinstance(lo, np.ndarray) and np.isfinite(lo[l]) else np.nan,
                "ci_high": float(hi[l]) if isinstance(hi, np.ndarray) and np.isfinite(hi[l]) else np.nan,
                "n_tokens": int(stats.get("n", 0)), "word_rep_mode": WORD_REP_MODE,
                "source_csv": Path(CSV_PATH).name,
            })
    # Explicit columns keep the schema (and header) stable when `rows` is empty.
    df = pd.DataFrame(rows, columns=columns)
    out = CSV_DIR / f"pos_{subset_name}_{metric}_{baseline}.csv"
    df.to_csv(out, index=False)

def plot_metric_with_ci(pos_to_stats: Dict[str, Dict[str, np.ndarray]],
                        layers: np.ndarray, metric: str, title: str, out_path: Path):
    """Draw one line per POS (mean vs. layer) with a shaded lo–hi band and save it.

    POS entries whose mean is None or all-NaN are skipped. The band is drawn
    only when both bounds are arrays and `lo` is not all-NaN. The figure is
    written to `out_path` at 220 dpi and then closed.
    """
    fig, ax = plt.subplots(figsize=(9, 5))
    for pos_tag, stats in pos_to_stats.items():
        mean = stats["mean"]
        lower = stats.get("lo")
        upper = stats.get("hi")
        if mean is None or np.all(np.isnan(mean)):
            continue
        ax.plot(layers, mean, label=pos_tag, lw=1.8)
        have_band = isinstance(lower, np.ndarray) and isinstance(upper, np.ndarray)
        if have_band and not np.all(np.isnan(lower)):
            ax.fill_between(layers, lower, upper, alpha=0.15)
    ax.set_xlabel("Layer")
    ax.set_ylabel(LABELS.get(metric, metric.upper()))
    ax.set_title(title)
    ax.legend(ncol=2, fontsize="small", title="POS", frameon=False)
    fig.tight_layout()
    fig.savefig(out_path, dpi=220)
    plt.close(fig)

# =============================== DRIVER ===============================
def run_pos_pipeline():
    """Run the per-POS analysis end to end.

    Steps: load the corpus, sample words per POS, embed them once with
    BASELINE, then for every metric in ALL_METRICS bootstrap each POS class
    per layer and immediately write a CSV and a plot. All settings come from
    the module-level configuration; nothing is returned.
    """
    # 1) Load
    df_all, word_df = load_word_df(CSV_PATH, EXCLUDE_POS, EXCLUDE_INDEX_1)
    POS_TAGS = sorted(word_df.pos.unique())
    print(f"✓ corpus ready — {len(word_df):,} tokens across {len(POS_TAGS)} POS")
    print(f"• DADApy: {'available' if HAS_DADAPY else 'missing'}  • scikit-dimension: {'available' if HAS_SKDIM else 'missing'}")

    # 2) Raw sampling only (no frequency match)
    raw_df = sample_raw(word_df, RAW_MAX_PER_POS)
    print("Sample sizes per POS (raw cap):")
    print(raw_df.pos.value_counts().to_dict())

    # 3) Embed once with BASELINE, one vector per word according to WORD_REP_MODE
    reps, filled = embed_subset(df_all, raw_df, BASELINE, WORD_REP_MODE, BATCH_SIZE)
    raw_df = raw_df.reset_index(drop=True).loc[filled].reset_index(drop=True)
    # raw_df now holds only the embedded words. Drop the unfilled columns from
    # reps as well, otherwise positions found in pos_arr would index different
    # words in reps. Skipped when the shapes already agree, to avoid a copy.
    if reps.shape[1] != len(raw_df):
        reps = reps[:, filled]
    pos_arr = raw_df.pos.values
    L = reps.shape[0]; layers = np.arange(L)
    print(f"✓ embedded {len(raw_df):,} tokens  • layers={L}")

    # 4) Metric-by-metric loop (incremental outputs)
    for metric in ALL_METRICS:
        print(f"\n→ Computing metric: {metric} …")

        # Decide registry & bootstrap settings
        if metric in FAST_ONCE:
            compute_once = FAST_ONCE[metric]
            n_bs = N_BOOTSTRAP_FAST
            Mcap = FAST_BS_MAX_SAMP_PER_POS
        else:
            compute_once = HEAVY_ONCE.get(metric)
            n_bs = N_BOOTSTRAP_HEAVY
            Mcap = HEAVY_BS_MAX_SAMP_PER_POS

        if compute_once is None:
            print(f"  (skipping {metric}: estimator unavailable)")
            continue

        # Per-POS bootstrap
        metric_results: Dict[str, Dict[str, np.ndarray]] = {}
        for p in POS_TAGS:
            idx = np.where(pos_arr == p)[0]
            if idx.size < 3:
                continue
            sub = reps[:, idx]  # (L, n_p, D)
            Np = sub.shape[1]
            M = min(Mcap, Np)

            mu, lo, hi = _bs_layer_loop(sub, M, n_bs, compute_once)
            metric_results[p] = {"mean": mu, "lo": lo, "hi": hi, "n": int(Np)}

        if not metric_results:
            # No POS class had at least three embedded words: nothing to save or plot.
            print(f"  (skipping {metric}: no POS class with enough tokens)")
            continue

        # Save + plot immediately for this metric
        csv_path = CSV_DIR / f"pos_raw_{metric}_{BASELINE}.csv"
        plot_path = PLOT_DIR / f"raw_{metric}_{BASELINE}.png"
        save_metric_csv_all_pos(metric, metric_results, layers, BASELINE, subset_name="raw")
        plot_metric_with_ci(metric_results, layers, metric,
                    title=f"{LABELS.get(metric, metric.upper())} • {BASELINE}",
                    out_path=plot_path)
        # Report the paths actually written (the old message named a plot file
        # with a WORD_REP_MODE suffix that was never created).
        print(f"  ✓ saved: CSV= {csv_path}  plot= {plot_path}")

        del metric_results; gc.collect()
        if device == "cuda": torch.cuda.empty_cache()

    # Cleanup
    del reps; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    print("\n✓ done (incremental outputs produced per metric).")

if __name__ == "__main__":
    run_pos_pipeline()



✓ corpus ready — 179,495 tokens across 13 POS
• DADApy: available  • scikit-dimension: available
Sample sizes per POS (raw cap):
{'NOUN': 33143, 'PUNCT': 21787, 'VERB': 21495, 'ADP': 17040, 'DET': 14938, 'PRON': 14909, 'ADJ': 12350, 'AUX': 11207, 'PROPN': 10478, 'ADV': 9061, 'CCONJ': 6354, 'SCONJ': 3370, 'NUM': 3363}


openai-community/gpt2 (embed subset): 100%|█| 1258/1258 [00:22<00:00, 55.86it/s]


✓ embedded 179,495 tokens  • layers=13

→ Computing metric: gride …
  ✓ saved: CSV= tables/pos_bootstrap/pos_raw_gride_gpt2.csv  plot= AO_POS/raw_gride_gpt2_last.png

✓ done (incremental outputs produced per metric).


# all points

In [None]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple, Callable

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel

# Optional dependency: DADApy. HAS_DADAPY stays False if the import fails for
# any reason; the DADApy-based metrics then return NaN.
HAS_DADAPY = False
try:
    from dadapy import Data  # DADApy ID estimators (TwoNN, GRIDE)
    HAS_DADAPY = True
except Exception:
    pass

# Optional dependency: scikit-dimension. If any one of these names fails to
# import, HAS_SKDIM stays False and every skdim-based metric is disabled.
HAS_SKDIM = False
try:
    # include the fuller set of estimators now in use
    from skdim.id import (
        MOM, TLE, CorrInt, FisherS, lPCA,
        MLE, DANCo, ESS, MiND_ML, MADA, KNN
    )
    HAS_SKDIM = True
except Exception:
    pass

# IsoScore: use library if available, else a simple monotone fallback
# NOTE(review): an earlier cell of this notebook imports the package as
# `from IsoScore import IsoScore` (capitalised). If the lowercase module name
# is not importable here, the fallback below is used silently — check
# _HAS_ISOSCORE before reading the "iso" numbers as real IsoScore values.
try:
    from isoscore import IsoScore
    _HAS_ISOSCORE = True
except Exception:
    _HAS_ISOSCORE = False
    class _IsoScoreFallback:
        # Exposes the same call shape the rest of the notebook uses:
        # IsoScore.IsoScore(X).
        @staticmethod
        def IsoScore(X: np.ndarray) -> float:
            # Population covariance of the columns of X; eigvalsh returns the
            # eigenvalues in ascending order, so ev[-1] is the largest.
            C = np.cov(X.T, ddof=0)
            ev = np.linalg.eigvalsh(C)
            if ev.mean() <= 0 or ev[-1] <= 0:
                return 0.0
            # mean/peak eigenvalue ratio in [0,1]; higher ≈ more isotropic
            return float(np.clip(ev.mean() / ev[-1], 0.0, 1.0))
    IsoScore = _IsoScoreFallback()

# =============================== CONFIG ===============================
# NOTE(review): the file name suggests an Italian treebank while MODELS lists
# English checkpoints — confirm this pairing is intended.
CSV_PATH    = "it_isdt-ud-train_sentences.csv"
MODELS      = ["bert-base-uncased", "gpt2"]  # compare both
REP_MODES   = ["first", "mean", "last"]      # extract these
EXCLUDE_POS = {"X", "SYM", "PART", "INTJ"}   # optional exclusions

# Pooled tokens cap (set to None to use ALL tokens)
ALL_TOKEN_CAP            = None          # e.g., 80_000 for safety
BATCH_SIZE               = 1             # sentences per forward pass
RAND_SEED                = 42

# Bootstrap replicates
# NOTE(review): with 2 replicates the 2.5/97.5 percentiles computed by
# _bs_layer_loop span only two values; raise these for real confidence bands.
N_BOOTSTRAP_FAST         = 2
N_BOOTSTRAP_HEAVY        = 2

# Per-replicate sample size M for pooled tokens (min(cap, N_all))
FAST_BS_MAX_SAMP         = 10000         # for fast metrics
HEAVY_BS_MAX_SAMP        = 10_000          # for TwoNN/GRIDE/skdim (tune for speed)

# GRIDE multi-scale max neighbor rank
DADAPY_GRID_RANGE_MAX    = 64

# Isotropy Monte-Carlo directions / quantiles and epsilon
PFI_DIRS   = 128
PFI_Q_LO   = 5.0
PFI_Q_HI   = 95.0
EPS        = 1e-12

# Output (directories are created as a side effect of running this cell)
OUT_DIR     = Path("metrics_all_token"); OUT_DIR.mkdir(parents=True, exist_ok=True)
PLOT_DIR    = OUT_DIR / "plots";      PLOT_DIR.mkdir(exist_ok=True)
CSV_DIR     = OUT_DIR / "tables";     CSV_DIR.mkdir(exist_ok=True)

# Repro & device
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda": torch.backends.cudnn.benchmark = True

# Style
sns.set_style("darkgrid"); plt.rcParams["figure.dpi"] = 120

# Display names for the model ids in MODELS.
MODEL_LABEL = {
    "bert-base-uncased": "BERT-base",
    "gpt2": "GPT-2",
}

def _model_tag(name: str) -> str:
    return name.split("/")[-1].replace(":", "_")

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _center(X: np.ndarray) -> np.ndarray:
    return X - X.mean(0, keepdims=True)

def _eigvals_from_X(X: np.ndarray) -> np.ndarray:
    """Eigenvalues of covariance up to a constant via SVD of centered X (descending)."""
    Xc = _center(X.astype(np.float32, copy=False))
    try:
        _, S, _ = np.linalg.svd(Xc, full_matrices=False)
        lam = (S**2).astype(np.float64)
        lam.sort()
        return lam[::-1]
    except Exception:
        return np.array([], dtype=np.float64)

def _jitter_unique(X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Add tiny noise if there are duplicate rows (helps NN-based estimators)."""
    try:
        if np.unique(X, axis=0).shape[0] < X.shape[0]:
            X = X + np.random.normal(scale=eps, size=X.shape).astype(X.dtype)
    except Exception:
        pass
    return X

# ========= Per-subsample single-value compute functions (used inside bootstrap) =========
# --- Isotropy (fast) ---
def _iso_once(X: np.ndarray) -> float:
    """Return `IsoScore.IsoScore(X)` (library or fallback, whichever is bound) as a float."""
    raw_score = IsoScore.IsoScore(X)
    return float(raw_score)

def _spect_once(X: np.ndarray) -> float:
    """Ratio of the largest covariance eigenvalue to the mean eigenvalue (EPS-guarded).

    Covariance is the population covariance (ddof=0) of X's columns; eigvalsh
    returns ascending eigenvalues, so the last entry is the largest.
    """
    eigenvalues = np.linalg.eigvalsh(np.cov(X.T, ddof=0))
    peak = eigenvalues[-1]
    average = eigenvalues.mean()
    return float(peak / (average + EPS))

def _rand_once(X: np.ndarray, K: int = 2000) -> float:
    """Mean absolute cosine similarity over randomly drawn pairs of distinct rows.

    Args:
        X: array with one vector per row.
        K: maximum number of pairs to draw; capped at n*(n-1)/2. Pairs are drawn
           with replacement.

    Returns:
        Mean |cos| over the sampled pairs, or NaN when X has fewer than two rows.

    The generator is seeded with RAND_SEED on every call, so the same row
    positions are paired each time for a given n.
    """
    n = X.shape[0]
    if n < 2: return np.nan
    rng = np.random.default_rng(RAND_SEED)
    K_eff = min(K, (n*(n-1))//2)
    i = rng.integers(0, n, size=K_eff)
    # Draw the partner from the n-1 other rows (sample in [0, n-1), then shift
    # values >= i up by one). Re-drawing collisions once could still pair a
    # row with itself and inflate the mean with |cos| = 1 terms.
    j = rng.integers(0, n - 1, size=K_eff)
    j += (j >= i)
    A, B = X[i], X[j]
    num = np.sum(A*B, axis=1)
    den = (np.linalg.norm(A, axis=1)*np.linalg.norm(B, axis=1) + 1e-9)
    return float(np.mean(np.abs(num/den)))

def _sf_once(X: np.ndarray) -> float:
    """Spectral flatness = geometric mean / arithmetic mean of the eigenvalues.

    EPS is added inside the log and to the arithmetic mean. NaN is returned if
    the spectrum is empty.
    """
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    geo_mean = np.exp(np.mean(np.log(eigs + EPS)))
    arith_mean = float(eigs.mean() + EPS)
    flatness = geo_mean / arith_mean
    return float(flatness)

def _pfI_once(X: np.ndarray) -> float:
    """Partition-function isotropy over PFI_DIRS random unit directions.

    log Z(u) = log sum_x exp(u·x) is evaluated with a max-shift for stability;
    the result is exp(P_lo - P_hi), with P_lo / P_hi the PFI_Q_LO / PFI_Q_HI
    percentiles of log Z. NaN for fewer than two rows. Directions come from a
    generator seeded with RAND_SEED, so they are identical on every call.
    """
    n_rows, dim = X.shape
    if n_rows < 2:
        return np.nan
    rng = np.random.default_rng(RAND_SEED)
    units = rng.standard_normal((PFI_DIRS, dim)).astype(np.float32)
    units /= np.linalg.norm(units, axis=1, keepdims=True) + 1e-9
    proj = units @ X.T
    peak = np.max(proj, axis=1, keepdims=True)
    log_partition = (peak + np.log(np.sum(np.exp(proj - peak), axis=1, keepdims=True))).ravel()
    low = np.percentile(log_partition, PFI_Q_LO)
    high = np.percentile(log_partition, PFI_Q_HI)
    return float(np.exp(low - high))  # ≈ min Z / max Z (robust)

def _vmf_kappa_once(X: np.ndarray) -> float:
    if X.shape[0] < 2: return np.nan
    Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-9)
    R = np.linalg.norm(Xn.mean(axis=0))
    d = Xn.shape[1]
    if R < 1e-9: return 0.0
    # standard closed-form approximation
    return float(max(R * (d - R**2) / (1.0 - R**2 + 1e-9), 0.0))

# --- Linear ID (fast) ---
def _pcaXX_once(X: np.ndarray, var_ratio: float) -> float:
    """Count the leading components needed to reach `var_ratio` of the total variance.

    NaN is returned when the spectrum from `_eigvals_from_X` is empty.
    """
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    running_total = np.cumsum(eigs)
    threshold = running_total[-1] * var_ratio
    # first index whose running total reaches the threshold, turned into a count
    first_hit = np.searchsorted(running_total, threshold)
    return float(first_hit + 1)

def _pca95_once(X: np.ndarray) -> float:
    """`_pcaXX_once` at a 0.95 variance ratio."""
    return _pcaXX_once(X, var_ratio=0.95)

def _pca99_once(X: np.ndarray) -> float:
    """`_pcaXX_once` at a 0.99 variance ratio."""
    return _pcaXX_once(X, var_ratio=0.99)

def _erank_once(X: np.ndarray) -> float:
    """Effective rank: exp(H) with H the Shannon entropy of the normalised spectrum.

    EPS guards the normalisation and the log; NaN if the spectrum is empty.
    """
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    probs = eigs / (eigs.sum() + EPS)
    shannon = -(probs * np.log(probs + EPS)).sum()
    return float(np.exp(shannon))

def _pr_once(X: np.ndarray) -> float:
    """Participation ratio of the spectrum, (sum λ)² / (sum λ² + EPS); NaN if empty."""
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    first_moment = eigs.sum()
    second_moment = (eigs**2).sum()
    return float((first_moment**2) / (second_moment + EPS))

def _stable_rank_once(X: np.ndarray) -> float:
    """Stable rank of the spectrum, sum λ / (max λ + EPS); NaN if the spectrum is empty."""
    eigs = _eigvals_from_X(X)
    if eigs.size == 0:
        return np.nan
    total = eigs.sum()
    return float(total / (eigs.max() + EPS))

# --- Non-linear (heavy) ---
def _dadapy_twonn_once(X: np.ndarray) -> float:
    """DADApy TwoNN intrinsic-dimension estimate; NaN when DADApy is unavailable.

    The points are jittered if they contain duplicates; only the first value
    returned by `compute_id_2NN` is kept.
    """
    if not HAS_DADAPY:
        return np.nan
    points = _jitter_unique(X)
    data = Data(coordinates=points)
    twonn_id, _, _ = data.compute_id_2NN()
    return float(twonn_id)

def _dadapy_gride_once(X: np.ndarray) -> float:
    """DADApy GRIDE intrinsic-dimension estimate; NaN when DADApy is unavailable.

    Neighbour distances are computed up to DADAPY_GRID_RANGE_MAX and the
    scaling analysis uses the same `range_max`. The last element of the
    returned id array is reported (presumably the coarsest scale — confirm
    against the DADApy documentation).
    """
    if not HAS_DADAPY:
        return np.nan
    points = _jitter_unique(X)
    data = Data(coordinates=points)
    data.compute_distances(maxk=DADAPY_GRID_RANGE_MAX)
    ids_by_scale, _, _ = data.return_id_scaling_gride(range_max=DADAPY_GRID_RANGE_MAX)
    return float(ids_by_scale[-1])

def _skdim_factory(name: str):
    """Return a zero-argument factory that builds a fresh skdim estimator, or None.

    None is returned when scikit-dimension is unavailable or `name` is unknown.
    A new estimator is built per call so fitted state is never shared.

    lPCA variants: "lpca" uses ver="FO"; "lpca95" / "lpca99" use ver="ratio"
    with alphaRatio 0.95 / 0.99. "lpca95" is requested by the heavy-metric
    registry but was not recognised here, so that metric was always None.
    "ess" is mapped as well, since ESS is already imported.
    """
    if not HAS_SKDIM: return None
    mapping = {
        "mom": MOM, "tle": TLE, "corrint": CorrInt, "fishers": FisherS,
        "lpca": lPCA, "lpca95": lPCA, "lpca99": lPCA,
        "mle": MLE, "danco": DANCo,  "mind_ml": MiND_ML,
        "mada": MADA, "knn": KNN, "ess": ESS,
    }
    # Constructor arguments per variant; names not listed here use the defaults.
    ctor_kwargs = {
        "lpca":   dict(ver="FO"),
        "lpca95": dict(ver="ratio", alphaRatio=0.95),
        "lpca99": dict(ver="ratio", alphaRatio=0.99),
    }
    cls = mapping.get(name)
    if cls is None: return None
    kwargs = ctor_kwargs.get(name, {})

    def _builder():
        return cls(**kwargs)
    return _builder

def _skdim_once_builder(name: str) -> Callable[[np.ndarray], float] | None:
    """Build an `X -> float` wrapper around the skdim estimator called `name`.

    None is returned when `_skdim_factory` cannot supply the estimator. Each
    call of the wrapper fits a new estimator on the jittered data and reads
    its `dimension_` attribute (NaN if that attribute is absent).
    """
    make_estimator = _skdim_factory(name)
    if make_estimator is None:
        return None

    def _once(X: np.ndarray) -> float:
        fitted = make_estimator()
        fitted.fit(_jitter_unique(X))
        return float(getattr(fitted, "dimension_", np.nan))

    return _once

# =============================== DATA ===============================
def load_word_df(csv_path: str, exclude_pos: set[str] = EXCLUDE_POS):
    """Read the sentence CSV and list every (sentence_id, word_id) that passes the POS filter.

    Args:
        csv_path: CSV with columns sentence_id, tokens, pos; tokens and pos are
            stringified lists.
        exclude_pos: POS tags to drop; None keeps every word.

    Returns:
        (df, word_df): the sentence frame with parsed lists and string ids, and
        a frame with columns sentence_id, word_id.
    """
    df = pd.read_csv(csv_path, usecols=["sentence_id","tokens","pos"])
    df["sentence_id"] = df["sentence_id"].astype(str)  # keep string IDs
    df["tokens"] = df["tokens"].apply(_to_list)
    df["pos"] = df["pos"].apply(_to_list)

    keep_all = exclude_pos is None
    rows = [
        (sid, wid)
        for sid, toks, poss in df[["sentence_id","tokens","pos"]].itertuples(index=False)
        # zip stops at the shorter list, so only aligned positions are seen
        for wid, (_tok, tag) in enumerate(zip(toks, poss))
        if keep_all or tag not in exclude_pos
    ]
    word_df = pd.DataFrame(rows, columns=["sentence_id","word_id"])
    return df, word_df

def sample_all(word_df: pd.DataFrame, cap: int | None) -> pd.DataFrame:
    """Return all rows, or a RAND_SEED-seeded random sample of `cap` rows if there are more.

    The result always carries a fresh 0..n-1 index.
    """
    fits_under_cap = cap is None or len(word_df) <= cap
    if fits_under_cap:
        selected = word_df
    else:
        selected = word_df.sample(cap, random_state=RAND_SEED)
    return selected.reset_index(drop=True)

# =============================== EMBEDDING ===============================
def embed_subset(df_all_sentences: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str,
                 word_rep_mode: str,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray]:
    """Embed the selected words with `baseline` and return their per-layer vectors.

    Args:
        df_all_sentences: sentence frame with columns sentence_id and tokens.
            Its sentence_id column is rewritten in place as str.
        subset_df: one row per selected word (sentence_id, word_id); copied
            before being modified.
        baseline: model id passed to AutoTokenizer / AutoModel.
        word_rep_mode: "first" or "last" subtoken; any other value takes the
            mean over the word's subtokens.
        batch_size: sentences per forward pass.

    Returns:
        reps: float16 array (L, n_filled, D) — only the columns of words that
            received a vector, with L = num_hidden_layers + 1.
        filled: bool mask (N,) over the rows of subset_df marking those words.
    """
    df_all_sentences["sentence_id"] = df_all_sentences["sentence_id"].astype(str)
    subset_df = subset_df.copy()
    subset_df["sentence_id"] = subset_df["sentence_id"].astype(str)

    # sid -> list[(global_idx, word_id)]
    by_sid: Dict[str, List[Tuple[int,int]]] = {}
    for gidx, (sid, wid) in enumerate(subset_df[["sentence_id","word_id"]].itertuples(index=False)):
        by_sid.setdefault(str(sid), []).append((gidx, int(wid)))

    # Materialize in the exact order we will batch
    sids = list(by_sid.keys())
    df_sel = (df_all_sentences[df_all_sentences.sentence_id.isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    tokzr = AutoTokenizer.from_pretrained(baseline, use_fast=True,  add_prefix_space=True)
    # GPT-2 has no PAD token → use EOS for padding
    if tokzr.pad_token is None:
        tokzr.pad_token = tokzr.eos_token

    # truncation=True: words cut off by the length limit get no subtokens and
    # end up counted as missing below.
    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True

    model = AutoModel.from_pretrained(baseline, output_hidden_states=True).eval().to(device)
    if getattr(model.config, "pad_token_id", None) is None and tokzr.pad_token_id is not None:
        model.config.pad_token_id = tokzr.pad_token_id
    if device == "cuda": model.half()

    L = model.config.num_hidden_layers + 1  # +1 for the embedding output
    D = model.config.hidden_size
    N = len(subset_df)

    # Stored as float16 to halve memory; _bs_layer_loop casts samples back to float32.
    reps   = np.zeros((L, N, D), np.float16)
    filled = np.zeros(N, dtype=bool)

    tag = _model_tag(baseline)
    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{tag}:{word_rep_mode} (embed all)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # word_id -> token positions for this item
                mp = {}
                for tidx, wid in enumerate(enc_be.word_ids(b)):
                    # word_ids() yields None for positions that belong to no word
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid);
                    if not toks: continue  # word has no subtokens in this encoding
                    if word_rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif word_rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    else:  # "mean"
                        vec = h[:, b, toks, :].mean(axis=1)
                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    filled[gidx] = True

            # free batch buffers
            del enc_be, enc_t, out, h
            if device == "cuda": torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing:
        print(f"⚠ Missing vectors for {missing} of {N} tokens")
    del model; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    # Only the filled columns are returned; use `filled` to align with subset_df.
    return reps[:, filled], filled

# =============================== BOOTSTRAP CORE ===============================
def _bs_layer_loop(rep_sub: np.ndarray, M: int, n_reps: int, compute_once: Callable[[np.ndarray], float]):
    """Bootstrap: sample M with replacement and apply compute_once(X_layer) -> scalar for each layer."""
    L, N, D = rep_sub.shape
    rng = np.random.default_rng(RAND_SEED)
    A = np.full((n_reps, L), np.nan, np.float32)
    for r in range(n_reps):
        idx = rng.integers(0, N, size=M)
        for l in range(L):
            X = rep_sub[l, idx].astype(np.float32, copy=False)
            try:
                A[r, l] = float(compute_once(X))
            except Exception:
                A[r, l] = np.nan
    mu = np.nanmean(A, axis=0).astype(np.float32)
    lo = np.nanpercentile(A, 2.5, axis=0).astype(np.float32)
    hi = np.nanpercentile(A, 97.5, axis=0).astype(np.float32)
    return mu, lo, hi

# Registry of "fast" metrics: metric key -> callable(X) -> float.
# Membership here is what _compute_metric_bootstrap uses to pick the fast bootstrap settings.
FAST_ONCE: Dict[str, Callable[[np.ndarray], float]] = dict(
    iso=_iso_once,
    spect=_spect_once,
    rand=_rand_once,
    sf=_sf_once,
    vmf_kappa=_vmf_kappa_once,
    erank=_erank_once,
    pr=_pr_once,
    stable_rank=_stable_rank_once,
)

# Registry of "heavy" metrics: metric key -> callable(X) -> float, or None
# (the annotation allows None; ALL_METRICS skips keys whose value is None).
HEAVY_ONCE: Dict[str, Callable[[np.ndarray], float] | None] = {
    "twonn": _dadapy_twonn_once,
    "gride": _dadapy_gride_once,
}
# The remaining estimators all come from the same builder, keyed by their own name.
HEAVY_ONCE.update({
    name: _skdim_once_builder(name)
    for name in (
        "mom", "tle", "corrint", "fishers",
        "lpca", "lpca95", "lpca99",
        "mle", "mada",
    )
})

LABELS = {
    # Isotropy
    "iso":"IsoScore","spect":"Spectral Ratio","rand":"RandCos |μ|",
    "sf":"Spectral Flatness","vmf_kappa":"vMF κ",
    # Linear ID
    "erank":"Effective Rank","pr":"Participation Ratio","stable_rank":"Stable Rank", "lpca99":"lPCA 0.95",  "lpca99":"lPCA 0.99", "lpca":"lPCA (FO)",
    # Non-linear
    "twonn":"TwoNN ID","gride":"GRIDE",
    "mom":"MOM","tle":"TLE","corrint":"CorrInt",
    "fishers":"FisherS",,"mle":"MLE","mind_ml":"MiND-ML",
    "mada":"MADA","knn":"KNN",
}

# Choose plotting order.
# Every entry needs its own trailing comma: adjacent string literals are concatenated by
# Python, so a missing comma silently merges two metric names into one unknown key.
# Keys without a registered implementation (e.g. "pfI") are filtered out by ALL_METRICS.
PLOT_ORDER = (
    "iso", "sf", "pfI", "vmf_kappa", "spect", "rand",
    "erank", "pr", "stable_rank",
    "twonn", "gride", "mom", "tle", "corrint", "fishers", "lpca", "lpca99", "lpca95",
    "mle", "mind_ml", "mada", "knn",
)

# Metrics that will actually be computed: those in PLOT_ORDER (order preserved) that have
# a fast implementation or a non-None heavy implementation.
ALL_METRICS = list(filter(
    lambda name: name in FAST_ONCE or HEAVY_ONCE.get(name) is not None,
    PLOT_ORDER,
))

# =============================== SAVE / PLOT ===============================
def save_metric_csv_alltokens(metric: str,
                              results: Dict[tuple, Dict[str, np.ndarray]],
                              layers: np.ndarray):
    """
    Write one long-format CSV for `metric`: one row per (model, rep_mode, layer).

    results key: (model, rep_mode) -> {"mean","lo","hi","n"}; "mean" is required,
    "lo"/"hi"/"n" are optional. The file goes to CSV_DIR / "alltokens_<metric>.csv".
    """
    def _bound_at(bound, pos):
        # CI bounds are optional: a non-array bound or a non-finite entry is stored as NaN.
        if isinstance(bound, np.ndarray) and np.isfinite(bound[pos]):
            return float(bound[pos])
        return np.nan

    records = []
    for (model_name, rep_mode), entry in results.items():
        mean_curve = entry["mean"]
        ci_lo, ci_hi = entry.get("lo"), entry.get("hi")
        model_tag = _model_tag(model_name)
        for pos, layer_mean in enumerate(mean_curve):
            records.append(dict(
                model=model_name,
                model_tag=model_tag,
                word_rep_mode=rep_mode,
                metric=metric,
                layer=int(layers[pos]),
                mean=float(layer_mean) if np.isfinite(layer_mean) else np.nan,
                ci_low=_bound_at(ci_lo, pos),
                ci_high=_bound_at(ci_hi, pos),
                n_tokens=int(entry.get("n", 0)),
                source_csv=Path(CSV_PATH).name,
            ))

    target = CSV_DIR / f"alltokens_{metric}.csv"
    pd.DataFrame(records).to_csv(target, index=False)

def plot_metric_compare(results: Dict[tuple, Dict[str, np.ndarray]],
                        layers: np.ndarray, metric: str):
    """
    Draw every (model, rep_mode) curve of `metric` on one figure, with a shaded CI band
    where both bounds are arrays, and save it to PLOT_DIR / "alltokens_<metric>.png".
    Curves whose mean is None or entirely NaN are skipped.
    """
    fig, ax = plt.subplots(figsize=(10, 5.5))

    for (model, mode), stats in results.items():
        mean_curve = stats["mean"]
        ci_lo, ci_hi = stats.get("lo"), stats.get("hi")
        if mean_curve is None or np.all(np.isnan(mean_curve)):
            continue

        ax.plot(layers, mean_curve, label=f"{MODEL_LABEL.get(model, model)} • {mode}", lw=1.8)

        has_band = (
            isinstance(ci_lo, np.ndarray)
            and isinstance(ci_hi, np.ndarray)
            and not np.all(np.isnan(ci_lo))
        )
        if has_band:
            ax.fill_between(layers, ci_lo, ci_hi, alpha=0.15)

    ax.set_xlabel("Layer")
    ax.set_ylabel(LABELS.get(metric, metric.upper()))
    ax.set_title(f"{LABELS.get(metric, metric.upper())} • all tokens")
    ax.legend(ncol=2, fontsize="small", frameon=False)
    fig.tight_layout()
    fig.savefig(PLOT_DIR / f"alltokens_{metric}.png", dpi=220)
    plt.close(fig)

def plot_metric_deltas(results: Dict[tuple, Dict[str, np.ndarray]],
                       layers: np.ndarray, metric: str):
    """
    Optional companion plot: for each model in MODELS draw (first − mean) and (last − mean),
    i.e. the difference between the bootstrap-mean curves of two pooling modes.
    No CI is drawn (exact CIs would need paired resamples). Models without a "mean"-mode
    entry, and modes without an entry, are skipped.
    Saved to PLOT_DIR / "alltokens_<metric>_deltas.png".
    """
    fig, ax = plt.subplots(figsize=(9, 4.8))

    for model in MODELS:
        reference = results.get((model, "mean"))
        if reference is None:
            continue
        reference_curve = reference["mean"]

        for mode in ("first", "last"):
            other = results.get((model, mode))
            if other is None:
                continue
            ax.plot(
                layers,
                other["mean"] - reference_curve,
                label=f"{MODEL_LABEL.get(model, model)} • {mode} − mean",
                lw=2,
            )

    ax.axhline(0.0, ls="--", lw=1, alpha=0.6)
    ax.set_xlabel("Layer")
    ax.set_ylabel(f"Δ {LABELS.get(metric, metric.upper())}")
    ax.set_title(f"Representation deltas (first/last − mean) • {LABELS.get(metric, metric.upper())}")
    ax.legend(ncol=2, fontsize="small", frameon=False)
    fig.tight_layout()
    fig.savefig(PLOT_DIR / f"alltokens_{metric}_deltas.png", dpi=220)
    plt.close(fig)

# =============================== METRIC COMPUTE ===============================
def _compute_metric_bootstrap(rep: np.ndarray, metric: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Bootstrap one metric over all layers of `rep` (shape (L, N, D)).

    Fast metrics use FAST_BS_MAX_SAMP / N_BOOTSTRAP_FAST, everything else is looked up in
    HEAVY_ONCE and uses HEAVY_BS_MAX_SAMP / N_BOOTSTRAP_HEAVY. The resample size is capped
    at N. A metric with no available implementation yields three NaN curves of length L.

    Returns (mean, ci_low, ci_high).
    """
    n_tokens = rep.shape[1]

    if metric in FAST_ONCE:
        return _bs_layer_loop(
            rep, min(FAST_BS_MAX_SAMP, n_tokens), N_BOOTSTRAP_FAST, FAST_ONCE[metric]
        )

    heavy_fn = HEAVY_ONCE.get(metric)
    if heavy_fn is None:
        # Unknown metric or optional dependency missing.
        return (np.full(rep.shape[0], np.nan),)*3

    return _bs_layer_loop(
        rep, min(HEAVY_BS_MAX_SAMP, n_tokens), N_BOOTSTRAP_HEAVY, heavy_fn
    )

def compute_all_metrics_for_rep(rep: np.ndarray) -> Dict[str, Dict[str, np.ndarray]]:
    """
    Run every metric in ALL_METRICS on `rep` (shape (L, N, D)).

    Returns: metric -> {"mean": mu, "lo": lo, "hi": hi, "n": N}
    """
    by_metric = {}
    for name in ALL_METRICS:
        mean_curve, ci_lo, ci_hi = _compute_metric_bootstrap(rep, name)
        by_metric[name] = dict(mean=mean_curve, lo=ci_lo, hi=ci_hi, n=int(rep.shape[1]))
    return by_metric

# =============================== DRIVER ===============================
def run_alltokens_compare():
    """
    End-to-end driver: load the corpus, pool tokens, then for every (model, rep_mode)
    pair embed the pooled tokens, bootstrap all metrics per layer, and finally write
    one CSV plus two figures per metric.
    """
    # 1) Load & pick pooled tokens
    sentences_df, word_df = load_word_df(CSV_PATH, EXCLUDE_POS)
    pooled_df = sample_all(word_df, ALL_TOKEN_CAP)
    print(f"✓ pooled tokens: {len(pooled_df):,}")

    # 2) For each (model, rep_mode): embed, compute metrics per layer
    # NOTE(review): the layer axis is taken from the first (model, mode) pair only —
    # confirm that every entry of MODELS has the same number of layers.
    layer_axis = None
    collected: Dict[str, Dict[tuple, Dict[str, np.ndarray]]] = dict((name, {}) for name in ALL_METRICS)

    for model in MODELS:
        for mode in REP_MODES:
            reps, filled = embed_subset(sentences_df, pooled_df, model, mode, BATCH_SIZE)

            if layer_axis is None:
                layer_axis = np.arange(reps.shape[0])
            print(f"→ {model} • {mode}: reps shape = {reps.shape}")

            for metric, stats in compute_all_metrics_for_rep(reps).items():
                collected[metric][(model, mode)] = stats

            # free the embedding matrix before the next model is loaded
            del reps
            gc.collect()
            if device == "cuda":
                torch.cuda.empty_cache()

    # 3) Save & plot per metric
    for metric, per_run in collected.items():
        save_metric_csv_alltokens(metric, per_run, layer_axis)
        plot_metric_compare(per_run, layer_axis, metric)
        plot_metric_deltas(per_run, layer_axis, metric)

    print("✓ done; outputs in:", OUT_DIR.resolve())


# Run the full comparison when this cell/script is executed directly.
if __name__ == "__main__":
    run_alltokens_compare()


In [3]:
from __future__ import annotations

import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, GPT2TokenizerFast
from sklearn.decomposition import PCA

# Plotly for interactive 3D
import plotly.graph_objects as go

# =============================== CONFIG ===============================
# Everything a reader is expected to tune for the PCA-3D cell lives here.
CSV_PATH   = "en_ewt-ud-train_sentences.csv"  # columns: sentence_id (str), tokens (list[str])
MODEL_ID   = "gpt2"                           # e.g., "gpt2", "bert-base-uncased"
REP_MODE   = "last"                           # "first" | "last" | "mean"
BATCH_SIZE = 4                                # sentences per forward pass
RAND_SEED  = 42                               # seeds Python, NumPy and torch below, plus PCA and the plot subsample

# For plotting (subsample to keep the browser smooth)
PCA_MAX_POINTS = None                         # None = plot all tokens

# Output
# NOTE: the output directory is created as a side effect of running this cell.
OUT_DIR  = Path("pca3d_all_tokens"); OUT_DIR.mkdir(parents=True, exist_ok=True)
# "/" in a hub-style model id is replaced by "_" so the id can be used as a file name.
HTML_OUT = OUT_DIR / f"{MODEL_ID.replace('/','_')}_pca3d_layers.html"

# Repro + device
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
# Global device string; the functions in this cell read it as a module-level name.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.backends.cudnn.benchmark = True

# =============================== PLOT SETTINGS ===============================
# Font sizes used across the Plotly figure (axis titles, ticks, slider, hover, title).
AXIS_TITLE_FONT_SIZE = 24
AXIS_TICK_FONT_SIZE  = 20
GLOBAL_FONT_SIZE     = 20
SLIDER_FONT_SIZE     = 20
HOVER_FONT_SIZE      = 20
TITLE_FONT_SIZE      = 20



# Marker appearance shared by every layer's scatter trace.
MARKER_SIZE    = 2
MARKER_OPACITY = 0.70

# ---- Color options (easy to change) ----
# Choose one:
#   COLOR_MODE = "constant"   -> single fixed color for all points
#   COLOR_MODE = "per_layer"  -> different color per layer (edit LAYER_COLORS)
#   COLOR_MODE = "colorscale" -> gradient color per point (by pc1/pc2/pc3/radius)
COLOR_MODE = "constant"

# If COLOR_MODE == "constant"
MARKER_COLOR = "orange"  # any CSS color: "red", "#ff8800", "rgb(255,0,0)", etc.

# If COLOR_MODE == "per_layer"
# None (or a list shorter than the number of layers) falls back to MARKER_COLOR.
LAYER_COLORS = None
# Example:
# LAYER_COLORS = ["#636EFA","#EF553B","#00CC96","#AB63FA","#FFA15A","#19D3F3","#FF6692","#B6E880"]

# If COLOR_MODE == "colorscale"
COLOR_BY   = "pc1"       # "pc1" | "pc2" | "pc3" | "radius"
COLOR_SCALE = "Turbo"    # "Viridis", "Cividis", "Plasma", "Inferno", "Turbo", etc.
SHOW_COLORBAR = True

# Optional: show the figure window (useful in notebooks / interactive sessions)
SHOW_FIG = False

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)
    if n is None: raise ValueError("Cannot determine num_hidden_layers from model.config")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)
    if d is None: raise ValueError("Cannot determine hidden_size from model.config")
    return int(d)

def _load_tok_and_model(model_id: str):
    """
    Robust loader for a (fast tokenizer, model) pair.

      - For GPT-2-like ids, force the fast GPT-2 tokenizer and also try the aliases
        'openai-community/gpt2' and 'gpt2' if the requested id fails.
      - For other ids, retry without `add_prefix_space` if the tokenizer rejects it.
      - Pad on the right; set PAD = EOS when the tokenizer has no PAD token.

    Returns:
      (tokenizer, model, loaded_id): the model is in eval mode on `device` (half precision
      on CUDA) and was loaded with output_hidden_states=True. `loaded_id` is the candidate
      that actually loaded, which can differ from `model_id` when a fallback was used.

    Raises:
      RuntimeError: if every candidate fails; the last underlying error is chained.
    """
    cands = [model_id]
    if "gpt2" in model_id.lower():
        cands.extend(alias for alias in ("openai-community/gpt2", "gpt2") if alias != model_id)

    last_err = None
    for pos, mid in enumerate(cands):
        try:
            if "gpt2" in mid.lower():
                tok = GPT2TokenizerFast.from_pretrained(mid, add_prefix_space=True)
            else:
                try:
                    tok = AutoTokenizer.from_pretrained(mid, use_fast=True, add_prefix_space=True)
                except TypeError:
                    # Some tokenizer classes may not accept add_prefix_space; retry without it
                    # instead of failing the whole load.
                    tok = AutoTokenizer.from_pretrained(mid, use_fast=True)

            # pad-right and PAD token if missing (common for GPT-2)
            if getattr(tok, "padding_side", None) != "right":
                tok.padding_side = "right"
            if tok.pad_token is None and getattr(tok, "eos_token", None) is not None:
                tok.pad_token = tok.eos_token

            mdl = AutoModel.from_pretrained(mid, output_hidden_states=True)
            if getattr(mdl.config, "pad_token_id", None) is None and tok.pad_token_id is not None:
                mdl.config.pad_token_id = tok.pad_token_id
            if device == "cuda":
                mdl = mdl.half()
            return tok, mdl.eval().to(device), mid
        except Exception as e:
            last_err = e
            # Do not fall back silently: the caller may end up with a different checkpoint
            # than the one it asked for, so say which candidate failed and why.
            if pos + 1 < len(cands):
                print(f"⚠ Could not load '{mid}' ({type(e).__name__}: {e}); trying '{cands[pos + 1]}'.")
    raise RuntimeError(f"Failed to load model/tokenizer for {cands}: {last_err}") from last_err

def load_word_df(csv_path: str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Read the sentence CSV and enumerate every word position.

    Returns (sentences, word_index): `sentences` has string `sentence_id` and list-valued
    `tokens`; `word_index` has one (sentence_id, word_id) row per token, with no labels.
    """
    sentences = pd.read_csv(csv_path, usecols=["sentence_id","tokens"])
    sentences["sentence_id"] = sentences["sentence_id"].astype(str)
    sentences["tokens"] = sentences["tokens"].apply(_to_list)

    positions = [
        (sid, wid)
        for sid, toks in zip(sentences["sentence_id"], sentences["tokens"])
        for wid in range(len(toks))
    ]
    word_index = pd.DataFrame(positions, columns=["sentence_id","word_id"])
    return sentences, word_index

# =============================== EMBEDDING ===============================
def embed_all_tokens(df_all: pd.DataFrame,
                     token_df: pd.DataFrame,
                     model_id: str,
                     rep_mode: str = "mean",
                     batch_size: int = 4):
    """
    Embed all tokens with per-word piece aggregation.

    Args:
      df_all: sentence table with `sentence_id` and `tokens` columns; duplicate
        sentence ids are dropped before use.
      token_df: one row per word to embed, with `sentence_id` and `word_id`
        (the word's index inside its sentence's `tokens`).
      model_id: id passed to `_load_tok_and_model`.
      rep_mode: "first" or "last" word piece; any other value takes the mean over
        all pieces of the word.
      batch_size: number of sentences per forward pass.

    Returns:
      reps: np.ndarray of shape (L, N, D), stored as float16
      words: list[str] length N (token strings for hover)
      model_tag: str (the id that `_load_tok_and_model` actually loaded)

    Words for which the tokenizer produced no piece are reported and dropped, so N can
    be smaller than len(token_df).
    """
    tokzr, model, model_tag = _load_tok_and_model(model_id)

    # Build per-sentence index: sid -> list[(global_idx, word_id)]
    # global_idx is the row position in token_df and indexes axis 1 of `reps`.
    token_df = token_df.copy()
    token_df["sentence_id"] = token_df["sentence_id"].astype(str)
    by_sid: Dict[str, List[Tuple[int,int]]] = {}
    for gidx, (sid, wid) in enumerate(token_df[["sentence_id","word_id"]].itertuples(index=False)):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))

    # Sentences in first-appearance order, with the sentence table aligned to that order.
    sids = list(by_sid.keys())
    df_sel = (df_all[df_all.sentence_id.isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    # One slot per stacked hidden state; presumably the embedding output plus one per
    # layer — TODO confirm against the model's hidden_states convention.
    L = _num_hidden_layers(model) + 1
    D = _hidden_size(model)
    N = len(token_df)

    # Pre-allocated in float16 to halve memory; `filled` records which rows were written.
    reps   = np.zeros((L, N, D), np.float16)
    words  = [""] * N
    filled = np.zeros(N, dtype=bool)

    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
    # Only pass add_prefix_space at call time if this tokenizer's __call__ names it explicitly.
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True

    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{model_tag}: embed ({rep_mode})"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            # Keep the BatchEncoding (enc_be) around: word_ids() is read from it below.
            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # map word_id -> token positions
                mp: Dict[int, List[int]] = {}
                wids = enc_be.word_ids(b)
                if wids is None:
                    raise RuntimeError("Fast tokenizer required (word_ids() unavailable).")
                for tidx, wid in enumerate(wids):
                    # positions with no word id (None) are skipped
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))

                toks_for_sent = df_sel.loc[sid, "tokens"]
                for gidx, wid in by_sid.get(sid, []):
                    tokpos = mp.get(wid)
                    # No piece for this word (presumably lost to truncation — TODO confirm):
                    # leave the row unfilled so it is dropped after the loop.
                    if not tokpos:
                        continue

                    # choose representation per wordpiece policy
                    if rep_mode == "first":
                        vec = h[:, b, tokpos[0], :]              # (L, D)
                    elif rep_mode == "last":
                        vec = h[:, b, tokpos[-1], :]             # (L, D)
                    else:  # "mean"
                        vec = h[:, b, tokpos, :].mean(axis=1)    # (L, D)

                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    words[gidx] = str(toks_for_sent[wid])
                    filled[gidx] = True

            # free batch buffers before the next forward pass
            del enc_be, enc_t, out, h
            if device == "cuda":
                torch.cuda.empty_cache()

    # Drop rows that never received a vector, keeping reps and words aligned.
    if (~filled).any():
        missing = int((~filled).sum())
        print(f"⚠ Missing vectors for {missing} tokens (skipped in PCA).")
        reps   = reps[:, filled]
        words  = [w for w, f in zip(words, filled) if f]

    return reps, words, model_tag

# =============================== PCA + PLOTLY ===============================
def _layer_color(l: int) -> str:
    """Return LAYER_COLORS[l] when a palette is configured and long enough, else MARKER_COLOR."""
    has_custom_color = LAYER_COLORS is not None and l < len(LAYER_COLORS)
    return LAYER_COLORS[l] if has_custom_color else MARKER_COLOR

def _colorscale_values(Y: np.ndarray) -> np.ndarray:
    """
    Return the per-point numbers that drive the colorscale, as selected by COLOR_BY:
    one of the three projected coordinates, or the Euclidean norm of each row.
    """
    for column, key in enumerate(("pc1", "pc2", "pc3")):
        if COLOR_BY == key:
            return Y[:, column]
    if COLOR_BY == "radius":
        return np.linalg.norm(Y, axis=1)
    raise ValueError("COLOR_BY must be one of: 'pc1', 'pc2', 'pc3', 'radius'.")

def pca3d_per_layer_and_plot(reps: np.ndarray, words: List[str], model_tag: str):
    """
    Project every layer to 3D with its own PCA and write an interactive Plotly figure
    (one scatter trace per layer, switched with a slider) to HTML_OUT.

    Args:
      reps: array of shape (L, N, D); axis 0 is iterated as layers.
      words: list[str] of length N, shown as hover text for each point.
      model_tag: label used in the figure title.

    Notes:
      - If PCA_MAX_POINTS is set and smaller than N, a seeded random subset of points is
        plotted (the same subset for every layer).
      - Each layer is converted to float32 on its own inside the loop, so peak memory is
        one layer's matrix rather than a float32 copy of the whole (L, n, D) selection.
      - PCA is fitted independently per layer, so axes are not comparable across layers.
    """
    L, N, _ = reps.shape

    # Optional subsampling for plotting
    if PCA_MAX_POINTS is None or PCA_MAX_POINTS >= N:
        sel_idx = np.arange(N, dtype=np.int64)
    else:
        sel_idx = np.random.default_rng(RAND_SEED).choice(N, size=PCA_MAX_POINTS, replace=False)
        print(f"⚠ PCA_MAX_POINTS={PCA_MAX_POINTS} < N={N} → plotting a subset.")

    words_sel = [words[i] for i in sel_idx]

    # One trace per layer: PCA to 3D (independently per layer), then a Scatter3d.
    traces = []
    for l in range(L):
        X = reps[l, sel_idx, :].astype(np.float32, copy=False)  # (n, D), this layer only
        Xc = X - X.mean(0, keepdims=True)
        Y = PCA(n_components=3, random_state=RAND_SEED).fit_transform(Xc)  # (n, 3)

        # Choose marker coloring
        if COLOR_MODE == "constant":
            marker = dict(size=MARKER_SIZE, opacity=MARKER_OPACITY, color=MARKER_COLOR)

        elif COLOR_MODE == "per_layer":
            marker = dict(size=MARKER_SIZE, opacity=MARKER_OPACITY, color=_layer_color(l))

        elif COLOR_MODE == "colorscale":
            cvals = _colorscale_values(Y)
            marker = dict(
                size=MARKER_SIZE,
                opacity=MARKER_OPACITY,
                color=cvals,
                colorscale=COLOR_SCALE,
                showscale=SHOW_COLORBAR,
                colorbar=dict(
                    # Nested title.text / title.font, the same form the scene axes use;
                    # the flat `titlefont` attribute is not accepted by current Plotly.
                    title=dict(text=COLOR_BY, font=dict(size=AXIS_TITLE_FONT_SIZE)),
                    tickfont=dict(size=AXIS_TICK_FONT_SIZE),
                ),
            )
        else:
            raise ValueError("COLOR_MODE must be one of: 'constant', 'per_layer', 'colorscale'.")

        hovertemplate = (
            "<b>%{text}</b><br>"
            "x=%{x:.3f}<br>y=%{y:.3f}<br>z=%{z:.3f}"
            "<extra>Layer " + str(l) + "</extra>"
        )

        traces.append(
            go.Scatter3d(
                x=Y[:, 0], y=Y[:, 1], z=Y[:, 2],
                mode="markers",
                marker=marker,
                text=words_sel,                 # <-- enables %{text} in hovertemplate
                hovertemplate=hovertemplate,
                name=f"Layer {l}",
                visible=(l == 0),               # only layer 0 is shown initially
                showlegend=False,
            )
        )

    # One slider step per layer: show exactly that layer's trace and update the title.
    steps = []
    for l in range(L):
        vis = [False] * L
        vis[l] = True
        steps.append(dict(
            method="update",
            args=[
                {"visible": vis},
                {"title": {
                    "text": f"{model_tag} • PCA 3D • layer {l} (drag to rotate)",
                    "font": {"size": TITLE_FONT_SIZE},
                }},
            ],
            label=str(l),
        ))

    sliders = [dict(
        active=0,
        steps=steps,
        currentvalue={"prefix": "Layer: ", "font": {"size": SLIDER_FONT_SIZE}},
        font={"size": SLIDER_FONT_SIZE},
        pad={"t": 10}
    )]

    layout = go.Layout(
        title={"text": f"{model_tag} • PCA 3D • layer 0 (drag to rotate)",
               "font": {"size": TITLE_FONT_SIZE}},
        font={"size": GLOBAL_FONT_SIZE},
        hoverlabel={"font": {"size": HOVER_FONT_SIZE}},

        scene=dict(
            xaxis=dict(
                title=dict(text="PC1", font=dict(size=AXIS_TITLE_FONT_SIZE)),
                tickfont=dict(size=AXIS_TICK_FONT_SIZE),
            ),
            yaxis=dict(
                title=dict(text="PC2", font=dict(size=AXIS_TITLE_FONT_SIZE)),
                tickfont=dict(size=AXIS_TICK_FONT_SIZE),
            ),
            zaxis=dict(
                title=dict(text="PC3", font=dict(size=AXIS_TITLE_FONT_SIZE)),
                tickfont=dict(size=AXIS_TICK_FONT_SIZE),
            ),
            aspectmode="data",
        ),
        margin=dict(l=0, r=0, b=0, t=60),
        sliders=sliders,
    )

    fig = go.Figure(data=traces, layout=layout)

    if SHOW_FIG:
        fig.show()

    fig.write_html(str(HTML_OUT), include_plotlyjs="cdn")
    print("✓ Saved interactive HTML to:", HTML_OUT)

# =============================== DRIVER ===============================
def run_pca3d_all_tokens():
    """Driver: load the corpus, embed every token with MODEL_ID / REP_MODE, then build the 3D PCA figure."""
    sentences, word_index = load_word_df(CSV_PATH)
    embedded = embed_all_tokens(
        sentences, word_index, MODEL_ID, rep_mode=REP_MODE, batch_size=BATCH_SIZE
    )
    # embedded == (reps, words, model_tag), exactly the plot function's positional arguments
    pca3d_per_layer_and_plot(*embedded)

    # Cleanup: release the embedding matrix and any cached GPU memory
    del embedded
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()


# Run the pipeline when this cell/script is executed directly.
if __name__ == "__main__":
    run_pca3d_all_tokens()


openai-community/gpt2: embed (last): 100%|██| 2517/2517 [00:41<00:00, 60.94it/s]


✓ Saved interactive HTML to: pca3d_all_tokens/gpt2_pca3d_layers.html


In [3]:
"""
SAFE POS-by-POS representation-geometry pipeline (BERT + bootstrap + stability)

Key safety design:
- Embed + compute metrics ONE POS at a time (caps memory).
- Avoid global reps[:, idx] slicing that would make giant copies (NumPy advanced indexing).
- Keep M capped for heavy kNN metrics (GRIDE), and keep batch sizes small.

Outputs:
- tables_POS/pos_bootstrap/pos_raw_<metric>_<model>.csv
- results_POS/raw_<metric>_<model>.png
- results_POS/stability/ stability curves CSV+PNG (optional)
- results_POS/run_metadata.json

Notes:
- For GRIDE, requires dadapy. If missing, the script will skip GRIDE gracefully.
"""

from __future__ import annotations

import os, gc, ast, json, time, random, inspect, contextlib
from pathlib import Path
from typing import Dict, List, Tuple, Callable, Optional

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel


# ===========================
# OPTIONAL DEPENDENCIES
# ===========================
HAS_DADAPY = False
try:
    from dadapy import Data
    HAS_DADAPY = True
except Exception:
    HAS_DADAPY = False

HAS_SKDIM = False
try:
    from skdim.id import (
        MOM, TLE, CorrInt, FisherS, lPCA,
        MLE, DANCo, ESS, MiND_ML, MADA, KNN
    )
    HAS_SKDIM = True
except Exception:
    HAS_SKDIM = False

# IsoScore is optional; we keep a fallback.
# The first cell of this notebook imports the package as `IsoScore` (capitalised), so try
# that spelling first and keep the lowercase spelling as a second attempt.
_HAS_ISOSCORE = False
IsoScore = None
try:
    from IsoScore import IsoScore as _IsoScoreObj  # type: ignore
    IsoScore = _IsoScoreObj
    _HAS_ISOSCORE = True
except Exception:
    try:
        from isoscore import IsoScore as _IsoScoreObj  # type: ignore
        IsoScore = _IsoScoreObj
        _HAS_ISOSCORE = True
    except Exception:
        _HAS_ISOSCORE = False

if not _HAS_ISOSCORE:
    # Say so explicitly: results are still reported under the "iso" key, but they come
    # from the proxy below rather than from the IsoScore package.
    print("⚠ IsoScore package not importable; 'iso' uses the mean/peak eigenvalue-ratio fallback.")

    class _IsoScoreFallback:
        """Stand-in exposing the same `.IsoScore(X)` entry point as the optional package."""

        @staticmethod
        def IsoScore(X: np.ndarray) -> float:
            # Covariance over the columns of X (rows are observations).
            C = np.cov(X.T, ddof=0)
            ev = np.linalg.eigvalsh(C)  # ascending, so ev[-1] is the largest
            if ev.mean() <= 0 or ev[-1] <= 0:
                return 0.0
            # mean/peak eigenvalue ratio clipped to [0, 1]; 1 when all eigenvalues are equal
            return float(np.clip(ev.mean() / ev[-1], 0.0, 1.0))
    IsoScore = _IsoScoreFallback()

def _isoscore_call(X: np.ndarray) -> float:
    """
    Call whatever `IsoScore` currently is: an object/module exposing a callable `.IsoScore`,
    or a plain callable. Returns NaN when no usable entry point exists.
    """
    impl = IsoScore
    if impl is None:
        return float("nan")

    if callable(getattr(impl, "IsoScore", None)):
        return float(impl.IsoScore(X))

    return float(impl(X)) if callable(impl) else float("nan")


# ===========================
# CONFIG (SAFE DEFAULTS)
# ===========================
# Input corpus and model.
CSV_PATH = "en_ewt-ud-train_sentences.csv"
BASELINE = "bert-base-uncased"

# Representation choice for wordpieces
WORD_REP_MODE = "first"  # "first" or "mean"

# POS filtering: tokens carrying one of these tags are dropped when the word table is built
EXCLUDE_POS = {"X", "SYM", "PART", "INTJ"}

# SAFETY: cap how many tokens you embed PER POS (very important)
SAFE_MAX_TOKENS_PER_POS = 5000  # <= 5000 recommended if running GRIDE

# Bootstrap: number of replicates for fast / heavy metrics
N_BOOTSTRAP_FAST = 50
N_BOOTSTRAP_HEAVY = 200

# For heavy metrics, cap the resample size M (critical for kNN methods)
HEAVY_BS_MAX_SAMP_PER_POS = 5000

# GRIDE: maximum neighbor rank (maxk) for distances and scaling ID
DADAPY_GRID_RANGE_MAX = 64

# Numerical stability
EPS = 1e-9          # added to denominators / log arguments in the metric functions
JITTER_EPS = 1e-6   # noise scale used by _jitter_unique when duplicate rows exist

# Reproducibility
RAND_SEED = 42

# Embedding throughput (keep small for safety)
BATCH_SIZE = 1

# What metrics to compute
# (You can extend this list with any key of FAST_ONCE / HEAVY_ONCE; the current default
#  computes IsoScore only. "gride" additionally requires dadapy.)
ALL_METRICS = ["iso"]  # e.g. ["iso","sf","gride"]

# Optional: stability curves (how many tokens M and how many bootstraps B)
RUN_STABILITY = True
# NOTE: the stability directory is created as a side effect of running this cell.
STAB_DIR = Path("results_POS") / "stability"
STAB_DIR.mkdir(parents=True, exist_ok=True)

STAB_TOPK_POS = 3              # run stability on top-K most frequent POS
STAB_M_MIN = 200
STAB_M_MAX = 5000
STAB_M_GRID_POINTS = 8
STAB_M_REPS = 20
STAB_M_REPLACE = False         # recommended for kNN ID stability curves
# NOTE(review): this tolerance is phrased in GRIDE units; with ALL_METRICS = ["iso"] a
# half-width of 1.0 is presumably far too loose to be informative — confirm per metric.
STAB_TOL_CI_HALFWIDTH = 1.0    # GRIDE: “stable within ±1 dim” default
STAB_STABLE_FOR = 2

RUN_B_CONVERGENCE = True
STAB_B_MAX = 400
STAB_B_GRID = [25, 50, 100, 200, 300, 400]
STAB_B_REPLACE = True          # bootstrap-with-replacement by definition

# Output dirs (created as a side effect of running this cell)
PLOT_DIR = Path("results_POS"); PLOT_DIR.mkdir(exist_ok=True, parents=True)
CSV_DIR = Path("tables_POS") / "pos_bootstrap"; CSV_DIR.mkdir(exist_ok=True, parents=True)

# Style
sns.set_style("darkgrid")
plt.rcParams["figure.dpi"] = 120

# Environment, seeding and device selection
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED)
np.random.seed(RAND_SEED)
torch.manual_seed(RAND_SEED)
# Global device string; functions in this cell read it as a module-level name.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.backends.cudnn.benchmark = True


# ===========================
# HELPERS
# ===========================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _center(X: np.ndarray) -> np.ndarray:
    return X - X.mean(0, keepdims=True)

def _eigvals_from_X(X: np.ndarray) -> np.ndarray:
    """
    Squared singular values of the centered (float32) X, i.e. the eigenvalues of
    Xc^T Xc up to a constant, returned as float64 in descending order.
    The result is zero-padded to length D (number of columns); if anything in the
    computation fails, an empty array is returned instead.
    """
    centered = _center(X.astype(np.float32, copy=False))
    try:
        singular_values = np.linalg.svd(centered, full_matrices=False)[1]
        spectrum = np.square(singular_values).astype(np.float64)  # length min(N, D)
        n_missing = centered.shape[1] - spectrum.size
        if n_missing > 0:
            spectrum = np.concatenate([spectrum, np.zeros(n_missing, dtype=np.float64)])
        return np.sort(spectrum)[::-1]
    except Exception:
        return np.array([], dtype=np.float64)

def _jitter_unique(X: np.ndarray, eps: float = JITTER_EPS) -> np.ndarray:
    """
    Return X with Gaussian noise of scale `eps` added when X contains duplicate rows
    (helps NN-based estimators); otherwise, or if the check itself fails, return X as is.
    Noise is drawn from NumPy's global random state.
    """
    try:
        n_distinct_rows = np.unique(X, axis=0).shape[0]
        if n_distinct_rows < X.shape[0]:
            noise = np.random.normal(scale=eps, size=X.shape)
            X = X + noise.astype(X.dtype)
    except Exception:
        pass
    return X


# ===========================
# METRICS (single-value on X)
# ===========================
# --- Isotropy (fast) ---
def _iso_once(X: np.ndarray) -> float:
    """IsoScore of X via `_isoscore_call`, coerced to a Python float."""
    score = _isoscore_call(X)
    return float(score)

def _spect_once(X: np.ndarray) -> float:
    """Largest covariance eigenvalue divided by the mean eigenvalue (plus EPS)."""
    covariance = np.cov(X.T, ddof=0)
    eigenvalues = np.linalg.eigvalsh(covariance)  # ascending order
    largest = eigenvalues[-1]
    return float(largest / (eigenvalues.mean() + EPS))

def _rand_once(X: np.ndarray, K: int = 2000) -> float:
    """
    Mean absolute cosine similarity over randomly drawn pairs of distinct rows of X.

    Args:
      X: matrix whose rows are the vectors to compare.
      K: maximum number of pairs; capped at n*(n-1)/2.

    Returns:
      The mean |cos| over the sampled pairs, or NaN when X has fewer than two rows.
      The generator is seeded with RAND_SEED on every call, so the result is a
      deterministic function of X.
    """
    n = X.shape[0]
    if n < 2:
        return np.nan
    rng = np.random.default_rng(RAND_SEED)
    K_eff = min(K, (n * (n - 1)) // 2)
    i = rng.integers(0, n, size=K_eff)
    j = rng.integers(0, n, size=K_eff)
    # Re-draw until no pair compares a row with itself: a single re-draw can collide
    # again, and every surviving self-pair contributes |cos| = 1 and inflates the mean
    # (badly so for small n). n >= 2 here, so each re-draw has a chance to resolve.
    same = i == j
    while same.any():
        j[same] = rng.integers(0, n, size=same.sum())
        same = i == j
    A, B = X[i], X[j]
    num = np.sum(A * B, axis=1)
    den = (np.linalg.norm(A, axis=1) * np.linalg.norm(B, axis=1) + EPS)
    return float(np.mean(np.abs(num / den)))

def _sf_once(X: np.ndarray) -> float:
    """Spectral flatness: geometric mean over arithmetic mean of the (EPS-shifted) eigenvalue spectrum."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    mean_log = np.mean(np.log(spectrum + EPS))
    geometric_mean = np.exp(mean_log)
    arithmetic_mean = float(spectrum.mean() + EPS)
    return float(geometric_mean / arithmetic_mean)

def _vmf_kappa_once(X: np.ndarray) -> float:
    """
    Concentration estimate from the mean resultant length R of the row-normalised data:
    R * (d - R^2) / (1 - R^2 + EPS), floored at 0. Returns NaN for fewer than two rows
    and 0.0 when R is below EPS.
    """
    if X.shape[0] < 2:
        return np.nan

    row_norms = np.linalg.norm(X, axis=1, keepdims=True) + EPS
    unit_rows = X / row_norms
    R = np.linalg.norm(unit_rows.mean(axis=0))
    if R < EPS:
        return 0.0

    d = unit_rows.shape[1]
    kappa = R * (d - R**2) / (1.0 - R**2 + EPS)
    return float(max(kappa, 0.0))

# --- Linear ID (fast) ---
def _erank_once(X: np.ndarray) -> float:
    """Effective rank: exp of the Shannon entropy of the normalised eigenvalue spectrum."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    weights = spectrum / (spectrum.sum() + EPS)
    entropy = -(weights * np.log(weights + EPS)).sum()
    return float(np.exp(entropy))

def _pr_once(X: np.ndarray) -> float:
    """Participation ratio: (sum of eigenvalues)^2 / (sum of squared eigenvalues + EPS)."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    total = spectrum.sum()
    total_of_squares = (spectrum**2).sum()
    return float((total**2) / (total_of_squares + EPS))

def _stable_rank_once(X: np.ndarray) -> float:
    """Stable rank: sum of eigenvalues divided by the largest eigenvalue (plus EPS)."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    total, peak = spectrum.sum(), spectrum.max()
    return float(total / (peak + EPS))

# --- Non-linear (heavy) ---
def _dadapy_gride_once(X: np.ndarray) -> float:
    """
    GRIDE estimate via dadapy: the last entry of the ID-vs-scale curve computed with
    neighbor ranks up to DADAPY_GRID_RANGE_MAX. Returns NaN when dadapy is unavailable.
    """
    if not HAS_DADAPY:
        return np.nan

    # Duplicate rows get a tiny jitter first (see _jitter_unique).
    dataset = Data(coordinates=_jitter_unique(X))
    dataset.compute_distances(maxk=DADAPY_GRID_RANGE_MAX)
    id_per_scale, _id_errors, _scales = dataset.return_id_scaling_gride(range_max=DADAPY_GRID_RANGE_MAX)
    return float(id_per_scale[-1])

# Registries: metric key -> single-value function of one data matrix X.
FAST_ONCE: Dict[str, Callable[[np.ndarray], float]] = dict(
    iso=_iso_once,
    spect=_spect_once,
    rand=_rand_once,
    sf=_sf_once,
    vmf_kappa=_vmf_kappa_once,
    erank=_erank_once,
    pr=_pr_once,
    stable_rank=_stable_rank_once,
)

# Heavy (nearest-neighbour based) metrics; values may be None per the annotation.
HEAVY_ONCE: Dict[str, Optional[Callable[[np.ndarray], float]]] = dict(
    gride=_dadapy_gride_once,
)

# Display names for the metrics registered above, keyed by metric name.
LABELS = dict(
    iso="IsoScore",
    spect="Spectral Ratio",
    rand="RandCos |μ|",
    sf="Spectral Flatness",
    vmf_kappa="vMF κ",
    erank="Effective Rank",
    pr="Participation Ratio",
    stable_rank="Stable Rank",
    gride="GRIDE",
)


# ===========================
# DATA
# ===========================
def load_word_df(csv_path: str, exclude_pos: set[str] = EXCLUDE_POS):
    """Read the sentence CSV and flatten it to one row per kept word.

    Parameters
    ----------
    csv_path : path of a CSV providing ``sentence_id``, ``tokens`` and ``pos``
        columns; ``tokens`` and ``pos`` are parsed per row with ``_to_list``.
    exclude_pos : POS tags whose words are dropped from the word table.

    Returns
    -------
    (df, word_df) : ``df`` is the sentence-level frame (ids cast to ``str``,
        list columns parsed); ``word_df`` has columns ``sentence_id``,
        ``word_id``, ``pos``, ``word``. ``word_id`` is the position of the
        word inside its sentence, counted before the POS filter is applied.
    """
    df = pd.read_csv(csv_path, usecols=["sentence_id", "tokens", "pos"])
    df["sentence_id"] = df["sentence_id"].astype(str)
    df["tokens"] = df["tokens"].apply(_to_list)
    df["pos"] = df["pos"].apply(_to_list)

    # Flatten to token level; zip() pairs each word with its tag positionally.
    token_rows = [
        (sent_id, position, tag, word)
        for sent_id, words, tags in zip(df["sentence_id"], df["tokens"], df["pos"])
        for position, (word, tag) in enumerate(zip(words, tags))
        if tag not in exclude_pos
    ]
    word_df = pd.DataFrame(token_rows, columns=["sentence_id", "word_id", "pos", "word"])
    return df, word_df

def sample_pos(word_df: pd.DataFrame, pos: str, cap: int) -> pd.DataFrame:
    """Draw up to ``cap`` rows of ``word_df`` whose ``pos`` equals ``pos``.

    Sampling is without replacement and seeded with ``RAND_SEED``. The result
    has a fresh 0..n-1 index, except when no row matches: then the empty
    selection is returned as is.
    """
    candidates = word_df.loc[word_df["pos"] == pos]
    if not len(candidates):
        return candidates
    take = min(len(candidates), cap)
    drawn = candidates.sample(n=take, random_state=RAND_SEED, replace=False)
    return drawn.reset_index(drop=True)

def make_class_palette(classes: List[str]) -> Dict[str, Tuple[float, float, float]]:
    """Assign one colour to every class name, in alphabetical order.

    Colours are pooled from the ``tab20``, ``tab20b`` and ``tab20c`` seaborn
    palettes (20 each); a palette that cannot be loaded is skipped. If the
    pool is smaller than the number of classes, a ``husl`` palette with one
    colour per class is used instead. Colours wrap around modulo the pool size.
    """
    pool: List[Tuple[float, float, float]] = []
    for palette_name in ("tab20", "tab20b", "tab20c"):
        try:
            pool += sns.color_palette(palette_name, 20)
        except Exception:
            # Best effort: a missing palette just shrinks the pool.
            continue

    if len(classes) > len(pool):
        pool = list(sns.color_palette("husl", len(classes)))

    return {name: pool[rank % len(pool)] for rank, name in enumerate(sorted(classes))}


# ===========================
# EMBEDDING (POS subset)
# ===========================
def build_tokenizer_and_model(baseline: str):
    """Load the fast tokenizer and the model for ``baseline``.

    The model is created with ``output_hidden_states=True``, switched to eval
    mode and moved to ``device``; when ``device`` is ``"cuda"`` its weights are
    converted to half precision.

    Returns
    -------
    (tokzr, model)
    """
    tokenizer_options = {"use_fast": True}
    try:
        # Some tokenizer versions reject add_prefix_space; retry without it.
        tokzr = AutoTokenizer.from_pretrained(baseline, add_prefix_space=True, **tokenizer_options)
    except TypeError:
        tokzr = AutoTokenizer.from_pretrained(baseline, **tokenizer_options)

    model = AutoModel.from_pretrained(baseline, output_hidden_states=True)
    model = model.eval().to(device)
    if device == "cuda":
        model.half()
    return tokzr, model

def embed_subset(
    df_all_sentences: pd.DataFrame,
    subset_df: pd.DataFrame,
    tokzr: AutoTokenizer,
    model: AutoModel,
    word_rep_mode: str = WORD_REP_MODE,
    batch_size: int = BATCH_SIZE,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Embed the words listed in ``subset_df`` and collect one vector per word and layer.

    Parameters
    ----------
    df_all_sentences : frame with ``sentence_id`` and ``tokens`` columns, where
        ``tokens`` is the pre-split word list of a sentence. Not modified.
    subset_df : frame with ``sentence_id`` and ``word_id`` columns; its row
        order defines the position of each word along axis N of the result.
        Not modified.
    tokzr, model : tokenizer and model; the model must return ``hidden_states``.
    word_rep_mode : ``"first"`` keeps the first sub-token of a word, any other
        value averages all of its sub-tokens.
    batch_size : number of sentences per forward pass.

    Returns
    -------
    reps : float16 array of shape (L, N, D) with L = ``num_hidden_layers + 1``
        and D = ``hidden_size``. Rows of words without a vector stay zero.
    filled : bool array of shape (N,), True where a vector was written.

    Raises
    ------
    KeyError if a sentence id of ``subset_df`` is absent from ``df_all_sentences``.
    """
    # Normalise ids on local Series only: the caller's frames are never written to
    # (assigning into them mutated the arguments and could trigger
    # SettingWithCopyWarning when subset_df is a slice of another frame).
    all_sids = df_all_sentences["sentence_id"].astype(str)
    subset_sids = subset_df["sentence_id"].astype(str)

    # sid -> list[(global_idx, word_id)]
    by_sid: Dict[str, List[Tuple[int, int]]] = {}
    for gidx, (sid, wid) in enumerate(zip(subset_sids, subset_df["word_id"])):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))

    sids = list(by_sid.keys())
    wanted = all_sids.isin(sids)
    df_sel = (
        df_all_sentences.loc[wanted]
        .assign(sentence_id=all_sids[wanted].to_numpy())
        .drop_duplicates("sentence_id")
        .set_index("sentence_id")
        .loc[sids]
    )

    # truncation=True keeps over-long sentences within the model's maximum input
    # length; words that fall in the truncated tail get no sub-tokens and are
    # reported through `filled` instead of failing the forward pass.
    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True

    L = model.config.num_hidden_layers + 1
    D = model.config.hidden_size
    N = len(subset_df)

    reps = np.zeros((L, N, D), np.float16)
    filled = np.zeros(N, dtype=bool)

    amp_ctx = torch.autocast("cuda", dtype=torch.float16) if device == "cuda" else contextlib.nullcontext()

    with torch.inference_mode(), amp_ctx:
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{BASELINE} embed"):
            batch_ids = sids[start: start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t = {k: v.to(device) for k, v in enc_be.items()}

            out = model(**enc_t)
            hs = torch.stack(out.hidden_states).detach()  # (L,B,T,D)
            hs = hs.to(torch.float16).cpu().numpy()       # float16 on the host keeps the batch copy small

            for b, sid in enumerate(batch_ids):
                # word index -> positions of its sub-tokens in the encoded sequence
                mp: Dict[int, List[int]] = {}
                for tidx, wid in enumerate(enc_be.word_ids(b)):
                    if wid is not None:  # None marks special / padding tokens
                        mp.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid)
                    if not toks:
                        # word produced no sub-tokens (e.g. truncated away): leave unfilled
                        continue
                    if word_rep_mode == "first":
                        vec = hs[:, b, toks[0], :]                 # (L,D)
                    else:
                        vec = hs[:, b, toks, :].mean(axis=1)       # (L,D)
                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    filled[gidx] = True

            # cleanup batch
            del enc_be, enc_t, out, hs
            if device == "cuda":
                torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing:
        print(f"⚠ Missing vectors for {missing} of {N} sampled words")

    return reps, filled


# ===========================
# BOOTSTRAP (SAFE, SMALL)
# ===========================
def bs_layer_loop(
    reps_pos: np.ndarray,  # (L,N,D) for ONE POS
    M: int,
    n_reps: int,
    compute_once: Callable[[np.ndarray], float],
    replace: bool = True,
    seed: int = RAND_SEED,
    layers_idx: Optional[List[int]] = None,
    return_reps: bool = False,
):
    """
    Bootstrap a scalar metric per layer and summarise it with a mean and a 95% interval.

    Parameters
    ----------
    reps_pos : array of shape (L, N, D) holding the vectors of one POS class.
    M : rows drawn per replicate (capped at N).
    n_reps : number of bootstrap replicates.
    compute_once : maps an (M, D) float32 matrix to one float.
    replace : draw with replacement when True, without otherwise.
    seed : seed of the random generator; the same row indices are reused for
        every layer within a replicate.
    layers_idx : layers to evaluate; all L layers when None.
    return_reps : also return the raw (n_reps, n_layers) replicate matrix.

    Returns
    -------
    (mu, lo, hi) or (mu, lo, hi, A): float32 arrays with one entry per evaluated
    layer -- NaN-aware mean, 2.5th and 97.5th percentile over the replicates.

    Notes
    -----
    A replicate whose estimator raises is recorded as NaN (best effort). The
    number of such failures is reported once per call through a RuntimeWarning,
    so a systematically failing estimator does not yield an all-NaN result
    without any trace.
    """
    import warnings  # function-scope import: the only name this block adds

    L, N, _ = reps_pos.shape
    if layers_idx is None:
        layers_idx = list(range(L))
    layers_idx = [int(x) for x in layers_idx]

    rng = np.random.default_rng(seed)
    A = np.full((n_reps, len(layers_idx)), np.nan, np.float32)

    n_failed = 0
    first_error: Optional[BaseException] = None

    M_eff = min(int(M), int(N))
    for r in range(n_reps):
        if replace:
            idx = rng.integers(0, N, size=M_eff)
        else:
            idx = rng.choice(N, size=M_eff, replace=False)

        for j, l in enumerate(layers_idx):
            X = reps_pos[l, idx].astype(np.float32, copy=False)  # small (M_eff,D)
            try:
                A[r, j] = float(compute_once(X))
            except Exception as exc:
                # Keep going: the cell stays NaN, but remember that it failed.
                n_failed += 1
                if first_error is None:
                    first_error = exc

    if n_failed:
        warnings.warn(
            f"bs_layer_loop: {n_failed} of {A.size} estimator calls failed and were "
            f"recorded as NaN (first error: {first_error!r})",
            RuntimeWarning,
            stacklevel=2,
        )

    mu = np.nanmean(A, axis=0).astype(np.float32)
    lo = np.nanpercentile(A, 2.5, axis=0).astype(np.float32)
    hi = np.nanpercentile(A, 97.5, axis=0).astype(np.float32)

    if return_reps:
        return mu, lo, hi, A
    return mu, lo, hi


# ===========================
# SAVE / PLOT
# ===========================
def save_metric_csv_all_pos(metric: str,
                            pos_to_stats: Dict[str, Dict[str, np.ndarray]],
                            layers: np.ndarray,
                            baseline: str,
                            subset_name: str = "raw"):
    """Write one long-format CSV for ``metric``: one row per (POS class, layer).

    Each entry of ``pos_to_stats`` must provide ``"mean"`` and may provide
    ``"lo"``, ``"hi"`` (CI bounds) and ``"n"`` (token count, default 0).
    Non-finite means and missing or non-finite bounds are written as NaN.
    Run configuration (seed, caps, source file, ...) is repeated on every row.
    The file goes to ``CSV_DIR / pos_<subset_name>_<metric>_<baseline>.csv``.
    """
    def _bound_at(bounds, i):
        # A CI bound is only usable when it is an array with a finite entry at i.
        if isinstance(bounds, np.ndarray) and np.isfinite(bounds[i]):
            return float(bounds[i])
        return np.nan

    records = []
    for pos_tag, stats in pos_to_stats.items():
        means = stats["mean"]
        ci_lo = stats.get("lo")
        ci_hi = stats.get("hi")
        n_tokens = stats.get("n", 0)
        for i, center in enumerate(means):
            record = dict(subset=subset_name, model=baseline, feature="pos")
            record["class"] = pos_tag  # "class" is a keyword, so it cannot be a dict() kwarg
            record.update(
                metric=metric,
                layer=int(layers[i]),
                mean=float(center) if np.isfinite(center) else np.nan,
                ci_low=_bound_at(ci_lo, i),
                ci_high=_bound_at(ci_hi, i),
                n_tokens=int(n_tokens),
                word_rep_mode=WORD_REP_MODE,
                source_csv=Path(CSV_PATH).name,
                seed=int(RAND_SEED),
                safe_max_tokens_per_pos=int(SAFE_MAX_TOKENS_PER_POS),
                heavy_cap=int(HEAVY_BS_MAX_SAMP_PER_POS),
                gride_range_max=int(DADAPY_GRID_RANGE_MAX),
            )
            records.append(record)

    target = CSV_DIR / f"pos_{subset_name}_{metric}_{baseline}.csv"
    pd.DataFrame(records).to_csv(target, index=False)

def plot_metric_with_ci(pos_to_stats: Dict[str, Dict[str, np.ndarray]],
                        layers: np.ndarray, metric: str, title: str, out_path: Path,
                        palette: Optional[Dict[str, Tuple[float, float, float]]] = None):
    """Plot the per-layer curve of every POS class with a shaded CI band and save it.

    Classes whose mean is missing or entirely NaN are skipped. A band is drawn
    only when both bounds are arrays and the lower bound is not entirely NaN.
    Colours come from ``palette`` when it is a dict, otherwise from
    matplotlib's defaults. The figure is written to ``out_path`` at 220 dpi
    and closed.
    """
    fig, ax = plt.subplots(figsize=(9, 5))
    has_palette = isinstance(palette, dict)

    for pos_tag, stats in pos_to_stats.items():
        center = stats["mean"]
        if center is None or np.isnan(center).all():
            continue
        lower, upper = stats.get("lo"), stats.get("hi")
        line_color = palette.get(pos_tag) if has_palette else None

        ax.plot(layers, center, label=pos_tag, lw=1.8, color=line_color)

        band_ok = (
            isinstance(lower, np.ndarray)
            and isinstance(upper, np.ndarray)
            and not np.isnan(lower).all()
        )
        if band_ok:
            ax.fill_between(layers, lower, upper, alpha=0.15, color=line_color)

    ax.set_xlabel("Layer")
    ax.set_ylabel(LABELS.get(metric, metric.upper()))
    ax.set_title(title)

    # Three legend columns once there are more than a dozen classes.
    legend_cols = 3 if len(pos_to_stats) > 12 else 2
    ax.legend(ncol=legend_cols, fontsize="small", title="POS", frameon=False)

    fig.tight_layout()
    fig.savefig(out_path, dpi=220)
    plt.close(fig)


# ===========================
# STABILITY UTILITIES (M-curve + B-curve)
# ===========================
def make_M_grid(N: int, min_M=200, max_M=5000, n_points=8) -> List[int]:
    """Build an increasing, geometrically spaced grid of subsample sizes.

    Parameters
    ----------
    N : number of available samples; the grid never exceeds it.
    min_M, max_M : requested smallest / largest subsample size.
    n_points : number of geometric grid points requested before rounding
        and de-duplication.

    Returns
    -------
    Sorted list of distinct ints, each >= 2, ending at ``min(max_M, N)``.
    When that upper end is <= 2 or does not exceed ``min_M``, the list holds
    only the upper end.
    """
    max_M = min(int(max_M), int(N))
    min_M = min(int(min_M), max_M)
    # Degenerate range: nothing to sweep, report the single usable size.
    if max_M <= 2 or max_M <= min_M:
        return [max_M]

    # np.geomspace raises on a zero endpoint and is not meaningful for a
    # negative start. Sizes below 2 are discarded further down anyway, so
    # clamping the start to 1 leaves every previously valid call unchanged.
    min_M = max(min_M, 1)

    Ms = np.unique(np.round(np.geomspace(min_M, max_M, n_points)).astype(int))
    Ms = Ms[Ms >= 2]
    if Ms.size == 0:
        return [max_M]
    if Ms[-1] != max_M:
        # Happens when fewer than two points were requested: keep the upper end.
        Ms = np.unique(np.r_[Ms, max_M])
    return Ms.tolist()

def sweep_M(reps_pos: np.ndarray, Ms: List[int], compute_once, n_reps: int, replace: bool, layers_idx: List[int]) -> pd.DataFrame:
    """Run the bootstrap at every subsample size in ``Ms`` and tabulate the results.

    Each size is capped at the number of available rows (axis 1 of
    ``reps_pos``) and bootstrapped with ``bs_layer_loop`` using the seed
    ``RAND_SEED + size``. The returned frame has one row per (size, layer)
    with columns ``M``, ``layer``, ``mean``, ``ci_low``, ``ci_high``,
    ``ci_width``, ``n_reps`` and ``replace``.
    """
    n_available = int(reps_pos.shape[1])
    records = []
    for requested in Ms:
        size = min(int(requested), n_available)
        center, lower, upper = bs_layer_loop(
            reps_pos,
            size,
            n_reps,
            compute_once,
            replace=replace,
            layers_idx=layers_idx,
            seed=RAND_SEED + size,
        )
        records.extend(
            dict(
                M=int(size),
                layer=int(layer),
                mean=float(center[col]),
                ci_low=float(lower[col]),
                ci_high=float(upper[col]),
                ci_width=float(upper[col] - lower[col]),
                n_reps=int(n_reps),
                replace=bool(replace),
            )
            for col, layer in enumerate(layers_idx)
        )
    return pd.DataFrame(records)

def min_M_for_tolerance(dfM: pd.DataFrame, layer: int, tol_ci_halfwidth: float, stable_for: int = 2) -> Optional[int]:
    """Smallest ``M`` at which the CI half-width of ``layer`` is within tolerance.

    Rows of ``dfM`` for the given layer are ordered by ``M``; half-width is
    ``ci_width / 2``. The first ``M`` is returned for which it and the next
    ``stable_for - 1`` grid points all satisfy ``half-width <= tol_ci_halfwidth``.
    Returns None when the layer has no rows or no such run of points exists.
    """
    layer_rows = dfM[dfM["layer"] == int(layer)].sort_values("M")
    if layer_rows.empty:
        return None

    sizes = layer_rows["M"].to_numpy()
    within_tol = (layer_rows["ci_width"].to_numpy() / 2.0) <= float(tol_ci_halfwidth)

    # A run of `stable_for` points can only begin at or before this position.
    last_start = len(sizes) - stable_for
    for start, size in enumerate(sizes):
        if start > last_start:
            break
        if within_tol[start] and within_tol[start:start + stable_for].all():
            return int(size)
    return None

def plot_M_curve(dfM: pd.DataFrame, metric_label: str, out_path: Path, title: str = ""):
    """Plot mean and CI band against subsample size ``M``, one line per layer.

    ``dfM`` must provide ``layer``, ``M``, ``mean``, ``ci_low`` and ``ci_high``.
    The x axis is logarithmic. The figure is saved to ``out_path`` at 220 dpi
    and closed; an empty ``title`` falls back to a default built from
    ``metric_label``.
    """
    fig, ax = plt.subplots(figsize=(7.6, 4.4))
    for layer_id, layer_rows in dfM.groupby("layer"):
        ordered = layer_rows.sort_values("M")
        ax.plot(ordered["M"], ordered["mean"], marker="o", lw=1.6, label=f"layer {layer_id}")
        ax.fill_between(ordered["M"], ordered["ci_low"], ordered["ci_high"], alpha=0.15)

    ax.set_xscale("log")
    ax.set_xlabel("Tokens per POS (M, log scale)")
    ax.set_ylabel(metric_label)
    ax.set_title(title or f"{metric_label} stability vs M")
    ax.legend(frameon=False, fontsize="small")

    fig.tight_layout()
    fig.savefig(out_path, dpi=220)
    plt.close(fig)

def sweep_B(reps_pos: np.ndarray, M: int, compute_once, B_max: int, B_grid: List[int], replace: bool, layers_idx: List[int]) -> pd.DataFrame:
    """Show how the bootstrap summary changes with the number of replicates.

    ``B_max`` replicates are computed once with ``bs_layer_loop`` (seed
    ``RAND_SEED + 999``); for every ``B`` in ``B_grid`` (capped at ``B_max``)
    the first ``B`` replicates are summarised by NaN-aware mean and the 2.5th /
    97.5th percentiles. The returned frame has one row per (B, layer) with
    columns ``B``, ``layer``, ``mean``, ``ci_low``, ``ci_high``, ``ci_width``
    and ``replace``.
    """
    # Only the raw replicate matrix is needed; the full-sample summary is discarded.
    _, _, _, replicates = bs_layer_loop(
        reps_pos,
        M,
        B_max,
        compute_once,
        replace=replace,
        layers_idx=layers_idx,
        seed=RAND_SEED + 999,
        return_reps=True,
    )

    records = []
    for requested in B_grid:
        n_used = min(int(requested), int(B_max))
        prefix = replicates[:n_used]
        center = np.nanmean(prefix, axis=0)
        lower = np.nanpercentile(prefix, 2.5, axis=0)
        upper = np.nanpercentile(prefix, 97.5, axis=0)
        records.extend(
            dict(
                B=int(n_used),
                layer=int(layer),
                mean=float(center[col]),
                ci_low=float(lower[col]),
                ci_high=float(upper[col]),
                ci_width=float(upper[col] - lower[col]),
                replace=bool(replace),
            )
            for col, layer in enumerate(layers_idx)
        )
    return pd.DataFrame(records)

def plot_B_curve(dfB: pd.DataFrame, metric_label: str, out_path: Path, title: str = ""):
    """Plot mean and CI band against the number of bootstrap replicates, one line per layer.

    ``dfB`` must provide ``layer``, ``B``, ``mean``, ``ci_low`` and ``ci_high``.
    The figure is saved to ``out_path`` at 220 dpi and closed; an empty
    ``title`` falls back to a default built from ``metric_label``.
    """
    fig, ax = plt.subplots(figsize=(7.6, 4.4))
    for layer_id, layer_rows in dfB.groupby("layer"):
        ordered = layer_rows.sort_values("B")
        ax.plot(ordered["B"], ordered["mean"], marker="o", lw=1.6, label=f"layer {layer_id}")
        ax.fill_between(ordered["B"], ordered["ci_low"], ordered["ci_high"], alpha=0.15)

    ax.set_xlabel("Bootstrap replicates (B)")
    ax.set_ylabel(metric_label)
    ax.set_title(title or f"{metric_label} stability vs B")
    ax.legend(frameon=False, fontsize="small")

    fig.tight_layout()
    fig.savefig(out_path, dpi=220)
    plt.close(fig)


# ===========================
# METADATA DUMP
# ===========================
def dump_run_metadata(out_path: Path):
    """Write the run configuration and package versions to ``out_path`` as JSON.

    The document records the timestamp, input/model settings, bootstrap,
    GRIDE and stability parameters (all read from module-level constants) and
    a best-effort ``packages`` section: a version that cannot be determined is
    simply left out.
    """
    bootstrap_cfg = {
        "N_BOOTSTRAP_FAST": int(N_BOOTSTRAP_FAST),
        "N_BOOTSTRAP_HEAVY": int(N_BOOTSTRAP_HEAVY),
        "heavy_cap": int(HEAVY_BS_MAX_SAMP_PER_POS),
        "ci_percentiles": [2.5, 97.5],
    }
    gride_cfg = {
        "range_max": int(DADAPY_GRID_RANGE_MAX),
        "summary": "ids[-1] (largest scale)",
        "dadapy_available": bool(HAS_DADAPY),
    }
    stability_cfg = {
        "RUN_STABILITY": bool(RUN_STABILITY),
        "STAB_TOPK_POS": int(STAB_TOPK_POS),
        "M_grid": [int(STAB_M_MIN), int(STAB_M_MAX), int(STAB_M_GRID_POINTS)],
        "M_reps": int(STAB_M_REPS),
        "M_replace": bool(STAB_M_REPLACE),
        "tol_ci_halfwidth": float(STAB_TOL_CI_HALFWIDTH),
        "stable_for": int(STAB_STABLE_FOR),
        "RUN_B_CONVERGENCE": bool(RUN_B_CONVERGENCE),
        "B_max": int(STAB_B_MAX),
        "B_grid": STAB_B_GRID,
    }
    packages: dict = {}
    meta = {
        "time": time.strftime("%Y-%m-%d %H:%M:%S"),
        "csv_path": str(CSV_PATH),
        "baseline": BASELINE,
        "word_rep_mode": WORD_REP_MODE,
        "exclude_pos": sorted(list(EXCLUDE_POS)),
        "seed": int(RAND_SEED),
        "device": device,
        "batch_size": int(BATCH_SIZE),
        "safe_max_tokens_per_pos": int(SAFE_MAX_TOKENS_PER_POS),
        "bootstrap": bootstrap_cfg,
        "gride": gride_cfg,
        "stability": stability_cfg,
        "packages": packages,
    }

    # Versions are informational only: failures here must not abort the run.
    try:
        import numpy, pandas, transformers
        packages["numpy"] = numpy.__version__
        packages["pandas"] = pandas.__version__
        packages["torch"] = torch.__version__
        packages["transformers"] = transformers.__version__
    except Exception:
        pass
    try:
        if HAS_DADAPY:
            import dadapy
            packages["dadapy"] = getattr(dadapy, "__version__", "unknown")
    except Exception:
        pass

    serialized = json.dumps(meta, indent=2)
    out_path.write_text(serialized)


# ===========================
# MAIN (SAFE POS-by-POS)
# ===========================
def run_safe_pos_pipeline():
    """Run the whole analysis one POS class at a time and write all artefacts.

    Steps:
      1. dump the run metadata and load the corpus;
      2. build the tokenizer/model once;
      3. for every POS tag: sample up to SAFE_MAX_TOKENS_PER_POS words, embed
         them, bootstrap every metric in ALL_METRICS per layer and, for the
         STAB_TOPK_POS most frequent tags, optionally write the M- and
         B-stability curves;
      4. save one CSV and one plot per metric.

    Embeddings of a POS class are deleted before the next class is processed,
    so only one class is held in memory at a time. Takes no arguments and
    returns nothing; results go to CSV_DIR, PLOT_DIR and STAB_DIR.
    """
    dump_run_metadata(PLOT_DIR / "run_metadata.json")

    # 1) load data
    df_all, word_df = load_word_df(CSV_PATH, EXCLUDE_POS)
    POS_TAGS = sorted(word_df.pos.unique().tolist())
    palette = make_class_palette(POS_TAGS)

    print(f"✓ corpus ready — {len(word_df):,} tokens across {len(POS_TAGS)} POS")
    print(f"• Device={device}  • DADApy={'yes' if HAS_DADAPY else 'no'}  • IsoScore={'pkg' if _HAS_ISOSCORE else 'fallback'}")
    print(f"• SAFE_MAX_TOKENS_PER_POS={SAFE_MAX_TOKENS_PER_POS}  • BATCH_SIZE={BATCH_SIZE}")

    # 2) build model once
    tokzr, model = build_tokenizer_and_model(BASELINE)
    L = model.config.num_hidden_layers + 1  # +1: hidden_states has one more entry than there are layers
    layers = np.arange(L)

    # 3) compute metrics POS-by-POS (safe)
    # metric -> POS tag -> {"mean", "lo", "hi", "n"}
    metric_to_posstats: Dict[str, Dict[str, Dict[str, np.ndarray]]] = {m: {} for m in ALL_METRICS}

    # for stability: choose top-K POS by frequency
    pos_counts = word_df.pos.value_counts()
    stab_pos_list = pos_counts.head(STAB_TOPK_POS).index.tolist()

    for pos in POS_TAGS:
        # sample for this POS only (safe cap)
        pos_df = sample_pos(word_df, pos, cap=SAFE_MAX_TOKENS_PER_POS)
        if len(pos_df) < 3:
            continue

        print(f"\n→ POS={pos}  (sampled {len(pos_df)} tokens)")

        reps_pos, filled = embed_subset(df_all, pos_df, tokzr, model, WORD_REP_MODE, BATCH_SIZE)
        pos_df = pos_df.loc[filled].reset_index(drop=True)
        reps_pos = reps_pos[:, filled]  # keep only filled vectors
        Np = reps_pos.shape[1]  # number of words that actually received a vector
        if Np < 3:
            print("  (skipping: too few embedded tokens)")
            del reps_pos
            gc.collect()
            if device == "cuda":
                torch.cuda.empty_cache()
            continue

        # compute each metric for this POS
        for metric in ALL_METRICS:
            if metric in FAST_ONCE:
                compute_once = FAST_ONCE[metric]
                n_bs = N_BOOTSTRAP_FAST
                M = Np  # fast metrics can use all
            else:
                compute_once = HEAVY_ONCE.get(metric)
                n_bs = N_BOOTSTRAP_HEAVY
                M = min(Np, HEAVY_BS_MAX_SAMP_PER_POS)  # heavy metrics resample at most this many rows

            if compute_once is None:
                print(f"  (skipping {metric}: estimator unavailable)")
                continue

            # Same seed for every POS and metric.
            mu, lo, hi = bs_layer_loop(reps_pos, M=M, n_reps=n_bs, compute_once=compute_once, replace=True, seed=RAND_SEED)
            metric_to_posstats[metric][pos] = {"mean": mu, "lo": lo, "hi": hi, "n": np.array([Np])}

        # optional stability curves (run on top-K POS only)
        if RUN_STABILITY and pos in stab_pos_list:
            layers_idx = [0, L // 2, L - 1]  # first, middle and last layer
            for metric in [m for m in ALL_METRICS if m in HEAVY_ONCE or m in FAST_ONCE]:
                if metric in FAST_ONCE:
                    compute_once = FAST_ONCE[metric]
                    tol = 0.01  # typical for [0,1] scores
                    # NOTE(review): FAST_ONCE also holds erank / pr / stable_rank / vmf_kappa,
                    # whose values are not confined to [0,1] -- confirm 0.01 is intended for them.
                else:
                    compute_once = HEAVY_ONCE.get(metric)
                    tol = STAB_TOL_CI_HALFWIDTH
                if compute_once is None:
                    continue

                # M-curve: how mean and CI evolve with the subsample size
                Ms = make_M_grid(Np, STAB_M_MIN, STAB_M_MAX, STAB_M_GRID_POINTS)
                dfM = sweep_M(reps_pos, Ms, compute_once, STAB_M_REPS, STAB_M_REPLACE, layers_idx)
                dfM["pos"] = pos
                dfM["metric"] = metric
                dfM["model"] = BASELINE
                dfM.to_csv(STAB_DIR / f"stability_M_{metric}_{pos}_{BASELINE}.csv", index=False)
                plot_M_curve(dfM, LABELS.get(metric, metric), STAB_DIR / f"stability_M_{metric}_{pos}_{BASELINE}.png",
                             title=f"{LABELS.get(metric, metric)} • POS={pos} • stability vs M")

                # report M* for last layer
                m_star = min_M_for_tolerance(dfM, layer=L - 1, tol_ci_halfwidth=tol, stable_for=STAB_STABLE_FOR)
                print(f"  stability({metric}): suggested M* (layer {L-1}) = {m_star}  (tol halfwidth={tol})")

                # B-curve: how mean and CI evolve with the number of replicates at fixed M
                if RUN_B_CONVERGENCE:
                    M_fixed = min(Np, HEAVY_BS_MAX_SAMP_PER_POS) if metric not in FAST_ONCE else Np
                    dfB = sweep_B(reps_pos, M_fixed, compute_once, STAB_B_MAX, STAB_B_GRID, STAB_B_REPLACE, layers_idx)
                    dfB["pos"] = pos
                    dfB["metric"] = metric
                    dfB["model"] = BASELINE
                    dfB["M_fixed"] = int(M_fixed)
                    dfB.to_csv(STAB_DIR / f"stability_B_{metric}_{pos}_{BASELINE}.csv", index=False)
                    plot_B_curve(dfB, LABELS.get(metric, metric), STAB_DIR / f"stability_B_{metric}_{pos}_{BASELINE}.png",
                                 title=f"{LABELS.get(metric, metric)} • POS={pos} • stability vs B (M={M_fixed})")

        # cleanup per POS (critical for kernel safety)
        del reps_pos
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

    # 4) Save + plot per metric (small objects only)
    for metric, pos_to_stats in metric_to_posstats.items():
        if not pos_to_stats:
            continue
        # normalize shape for n: the one-element array stored above becomes a plain int
        for p in list(pos_to_stats.keys()):
            n_arr = pos_to_stats[p].get("n")
            if isinstance(n_arr, np.ndarray) and n_arr.size == 1:
                pos_to_stats[p]["n"] = int(n_arr[0])  # type: ignore

        save_metric_csv_all_pos(metric, pos_to_stats, layers, BASELINE, subset_name="raw")
        plot_metric_with_ci(pos_to_stats, layers, metric,
                            title=f"{LABELS.get(metric, metric)} • {BASELINE}",
                            out_path=PLOT_DIR / f"raw_{metric}_{BASELINE}.png",
                            palette=palette)
        print(f"\n✓ saved metric {metric}:")
        print(f"  CSV:  {CSV_DIR / f'pos_raw_{metric}_{BASELINE}.csv'}")
        print(f"  Plot: {PLOT_DIR / f'raw_{metric}_{BASELINE}.png'}")

    # final cleanup
    del model
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    print("\n✓ done.")


# Entry point: start the full POS pipeline when this code is executed as the
# main module.
if __name__ == "__main__":
    run_safe_pos_pipeline()


✓ corpus ready — 189,167 tokens across 13 POS
• Device=cuda  • DADApy=yes  • IsoScore=fallback
• SAFE_MAX_TOKENS_PER_POS=5000  • BATCH_SIZE=1

→ POS=ADJ  (sampled 5000 tokens)


bert-base-uncased embed:  80%|██████████▍  | 2968/3687 [00:21<00:05, 140.26it/s]


KeyboardInterrupt: 