In [3]:
from __future__ import annotations
import os, ast, random, inspect
from pathlib import Path
from typing import Dict, List
import numpy as np, pandas as pd, torch
import torch.utils.data as torchdata
from tqdm.auto import tqdm
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt, seaborn as sns
from transformers import AutoTokenizer, AutoModel, BertModel, AutoConfig
from IsoScore import IsoScore
from dadapy import Data
from skdim.id import MLE, MOM, TLE, CorrInt, FisherS, lPCA

In [None]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple, Callable

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel, GPT2TokenizerFast

# =========================== Optional deps ===========================
# DADApy provides the intrinsic-dimension estimators (TwoNN, GRIDE). It is
# optional: HAS_DADAPY records whether the import worked.
HAS_DADAPY = False
try:
    from dadapy import Data  # DADApy ID estimators (TwoNN, GRIDE)
    HAS_DADAPY = True
except Exception:
    # Deliberate best effort: any import-time failure only disables DADApy.
    pass

# IsoScore: library if available, else a monotone fallback.
# The package is imported as `IsoScore` (capitalised, as in the first cell);
# the lower-case spelling is kept as a second attempt for backward
# compatibility. Trying only the lower-case name made the real library
# unreachable and silently selected the fallback.
try:
    from IsoScore import IsoScore
    _HAS_ISOSCORE = True
except Exception:
    try:
        from isoscore import IsoScore
        _HAS_ISOSCORE = True
    except Exception:
        _HAS_ISOSCORE = False

if not _HAS_ISOSCORE:
    class _IsoScoreFallback:
        """Stand-in exposing the same `IsoScore.IsoScore(X)` call shape.

        NOTE: this is NOT the published IsoScore; it is the ratio
        mean(eigenvalue) / max(eigenvalue) of the feature covariance.
        """

        @staticmethod
        def IsoScore(X: np.ndarray) -> float:
            """Return mean/max covariance eigenvalue of X (rows = samples), clipped to [0, 1]."""
            # atleast_2d: np.cov yields a 0-d array for a single feature,
            # which eigvalsh cannot handle.
            C = np.atleast_2d(np.cov(X.T, ddof=0))
            ev = np.linalg.eigvalsh(C)  # ascending, so ev[-1] is the largest
            if ev.mean() <= 0 or ev[-1] <= 0:
                return 0.0
            # mean / max eigenvalue in [0,1] (↑ ~ more isotropic)
            return float(np.clip(ev.mean() / ev[-1], 0.0, 1.0))

    IsoScore = _IsoScoreFallback()

# =============================== CONFIG ===============================
# Sentence-level CSV; load_arity_df reads sentence_id, tokens and an arity column from it.
CSV_PATH   = "en_ewt-ud-train_sentences.csv"

# Set to "gpt2" (decoder) or "bert-base-uncased" (encoder)
BASELINE   = "bert-base-uncased"          # currently the BERT encoder
WORD_REP_MODE = "first"       # sub-token pooling; GPT-2-like models fall back to "last" unless "last"/"mean"

# Default per-class cap used by sample_raw (effectively unlimited)
RAW_MAX_PER_CLASS = int(1e12)

# Cap overrides: cap only arity "0" to 50_000; others use RAW_MAX_PER_CLASS
PER_CLASS_CAPS: Dict[str, int] = {"0": 50_000}

# Bootstrap replicates (tune down if slow)
N_BOOTSTRAP_FAST   = 50
N_BOOTSTRAP_HEAVY  = 200

# Per-replicate sample size (M = min(cap, N_class))
FAST_BS_MAX_SAMP_PER_CLASS  = int(1e12)
HEAVY_BS_MAX_SAMP_PER_CLASS = 5000

RAND_SEED=42
# Output folders are created at import time (side effect of running this cell).
PLOT_DIR     = Path("results_ARITY"); PLOT_DIR.mkdir(exist_ok=True, parents=True)
CSV_DIR      = Path("tables_ARITY") / "arity_bootstrap"; CSV_DIR.mkdir(exist_ok=True, parents=True)

# Throughput (raise if you have more GPU memory)
BATCH_SIZE = 1

# Reproducibility & device
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Seeds the global Python, NumPy and torch RNGs; _bs_layer_loop builds its own generator from RAND_SEED.
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda": torch.backends.cudnn.benchmark = True

# Seaborn style
sns.set_style("darkgrid")
plt.rcParams["figure.dpi"] = 120

# Small constant added to denominators / log arguments in the spectral metrics.
EPS = 1e-12

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _center(X: np.ndarray) -> np.ndarray:
    return X - X.mean(0, keepdims=True)

def _eigvals_from_X(X: np.ndarray) -> np.ndarray:
    """Eigenvalues of covariance up to a constant via SVD of centered X (descending)."""
    Xc = _center(X.astype(np.float32, copy=False))
    try:
        _, S, _ = np.linalg.svd(Xc, full_matrices=False)
        lam = (S**2).astype(np.float64)
        lam.sort()
        return lam[::-1]
    except Exception:
        return np.array([], dtype=np.float64)

def _jitter_unique(X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Add tiny noise if there are duplicate rows (helps NN-based estimators)."""
    try:
        if np.unique(X, axis=0).shape[0] < X.shape[0]:
            X = X + np.random.normal(scale=eps, size=X.shape).astype(X.dtype)
    except Exception:
        pass
    return X

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)
    if n is None: raise ValueError("Cannot determine num_hidden_layers from model.config")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)
    if d is None: raise ValueError("Cannot determine hidden_size from model.config")
    return int(d)

def _is_gpt_like(model) -> bool:
    mt = str(getattr(model.config, "model_type", "")).lower()
    name = str(getattr(getattr(model, "name_or_path", ""), "lower", lambda: "")())
    return ("gpt2" in mt) or ("gpt2" in name)

def _pick_arity_col(df: pd.DataFrame) -> str:
    for cand in ["arity", "ariety", "ARITY", "Arity"]:
        if cand in df.columns: return cand
    raise ValueError("No arity column found. Expected one of: arity, ariety, ARITY, Arity")

# ========= Per-subsample single-value compute functions (used inside bootstrap) =========
def _iso_once(X: np.ndarray) -> float:
    return float(IsoScore.IsoScore(X))

def _pcaXX_once(X: np.ndarray, var_ratio: float) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    c = np.cumsum(lam); thr = c[-1] * var_ratio
    return float(np.searchsorted(c, thr) + 1)

def _pca99_once(X: np.ndarray) -> float:
    return _pcaXX_once(X, 0.99)

def _erank_once(X: np.ndarray) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    p = lam / (lam.sum() + EPS)
    H = -(p * np.log(p + EPS)).sum()
    return float(np.exp(H))

def _pr_once(X: np.ndarray) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    s1 = lam.sum(); s2 = (lam**2).sum()
    return float((s1**2) / (s2 + EPS))

def _stable_rank_once(X: np.ndarray) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    return float(lam.sum() / (lam.max() + EPS))

def _dadapy_twonn_once(X: np.ndarray) -> float:
    if not HAS_DADAPY: return np.nan
    d = Data(coordinates=_jitter_unique(X))
    id_est, _, _ = d.compute_id_2NN()
    return float(id_est)

def _dadapy_gride_once(X: np.ndarray) -> float:
    if not HAS_DADAPY: return np.nan
    d = Data(coordinates=_jitter_unique(X))
    d.compute_distances(maxk=64)
    ids, _, _ = d.return_id_scaling_gride(range_max=64)
    return float(ids[-1])

# =============================== DATA (arity already in CSV) ===============================
def load_arity_df(csv_path: str):
    """
    Reads sentence-level rows with list columns: tokens, <arity>.
    Expands to one row per token with an arity class in {0,1,2,3,4} (4 = 4+).

    The CSV is parsed once (previously it was read twice and the arity column
    was re-attached by index alignment). `sentence_id` is read as a string.

    Returns:
      df_sent (sent-level with tokens),
      df_tok  (token-level: sentence_id, word_id, arity_class, word)

    Raises:
      ValueError: if sentence_id/tokens or every accepted arity spelling is missing.
    """
    df_all = pd.read_csv(csv_path, dtype={"sentence_id": str})
    missing = [col for col in ("sentence_id", "tokens") if col not in df_all.columns]
    if missing:
        raise ValueError(f"CSV is missing required column(s): {missing}")
    ar_col = _pick_arity_col(df_all)

    df = df_all[["sentence_id", "tokens", ar_col]].copy()
    df["sentence_id"] = df["sentence_id"].astype(str)
    df["tokens"] = df["tokens"].apply(_to_list)
    df[ar_col] = df[ar_col].apply(_to_list)

    rows = []
    for sid, toks, A in df[["sentence_id", "tokens", ar_col]].itertuples(index=False):
        # zip stops at the shorter list, so length mismatches are truncated.
        for wid, (tok, a) in enumerate(zip(toks, A)):
            try:
                ai = int(a)
            except (TypeError, ValueError, OverflowError):
                # Best effort: an unparsable arity is counted as 0.
                ai = 0
            cl = str(min(max(ai, 0), 4))   # clamp to [0, 4]; 4 means 4+
            rows.append((sid, wid, cl, tok))
    df_tok = pd.DataFrame(rows, columns=["sentence_id", "word_id", "arity_class", "word"])
    return df[["sentence_id", "tokens"]].drop_duplicates("sentence_id"), df_tok

def sample_raw(df_tok: pd.DataFrame,
               per_class_cap: int = RAW_MAX_PER_CLASS,
               per_class_caps: Dict[str, int] | None = None) -> pd.DataFrame:
    """Draw at most `cap` rows per arity class, without replacement.

    `per_class_caps` maps a class label (as str) to its own cap; classes not
    listed use `per_class_cap`. Sampling is seeded with RAND_SEED and groups
    are visited in order of first appearance.
    """
    overrides = per_class_caps if per_class_caps else {}
    sampled = [
        group.sample(
            min(len(group), overrides.get(str(label), per_class_cap)),
            random_state=RAND_SEED,
            replace=False,
        )
        for label, group in df_tok.groupby("arity_class", sort=False)
    ]
    return pd.concat(sampled, ignore_index=True)

def make_class_palette(classes: List[str]) -> Dict[str, Tuple[float, float, float]]:
    """Assign one colour per class label.

    Colours are pooled from seaborn's tab20, tab20b and tab20c palettes
    (a palette that fails to load is skipped); if the pool is smaller than the
    number of classes, a "husl" palette of that size replaces it. Digit labels
    are ordered by integer value, all other labels share sort key 99.
    """
    pool: List[Tuple[float, float, float]] = []
    for palette_name in ("tab20", "tab20b", "tab20c"):
        try:
            pool += sns.color_palette(palette_name, 20)
        except Exception:
            continue
    if len(pool) < len(classes):
        pool = list(sns.color_palette("husl", len(classes)))

    def _sort_key(label):
        return int(label) if label.isdigit() else 99

    ordered = sorted(classes, key=_sort_key)
    return {label: pool[pos % len(pool)] for pos, label in enumerate(ordered)}

# =============================== TOKENIZER/MODEL LOADING ===============================
def _load_tokenizer_and_model(baseline: str):
    """
    Robustly load tokenizer+model. For GPT-2, prefer GPT2TokenizerFast and
    try both 'gpt2' and 'openai-community/gpt2' repo IDs.
    """
    candidates = [baseline]
    b = baseline.lower()
    if "gpt2" in b:
        # try both IDs (some environments only have one cached / accessible)
        if baseline != "openai-community/gpt2":
            candidates.append("openai-community/gpt2")
        if baseline != "gpt2":
            candidates.append("gpt2")

    last_err = None
    for mid in candidates:
        try:
            if "gpt2" in mid.lower():
                tokzr = GPT2TokenizerFast.from_pretrained(mid, add_prefix_space=True)
            else:
                tokzr = AutoTokenizer.from_pretrained(mid, use_fast=True, add_prefix_space=True)
            model = AutoModel.from_pretrained(mid, output_hidden_states=True)
            return tokzr, model, mid
        except Exception as e:
            last_err = e
            continue
    # If we got here, surface the most informative error
    raise RuntimeError(f"Failed to load tokenizer/model for any of {candidates}: {last_err}")

# =============================== EMBEDDING (BERT & GPT‑2) ===============================
def embed_subset(df_sent: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray, str]:
    """
    Embed the words listed in `subset_df` and return (reps, filled, rep_mode_used).

    reps shape: (L, N, D) where L includes the embedding layer (layer 0), N is
    len(subset_df) and D the hidden size; stored as float16. `reps` is NOT
    filtered: columns whose `filled` flag is False stay all-zero, so callers
    must mask with `filled`.

    Args:
        df_sent: sentence-level frame with `sentence_id` and `tokens`.
        subset_df: word-level frame with `sentence_id` and `word_id`
            (position of the word in its sentence's token list).
        baseline: model id handed to `_load_tokenizer_and_model`.
        word_rep_mode: "first" / "last" / "mean" pooling over a word's
            sub-tokens. GPT-2-like models accept only "last"/"mean" (else
            "last"); other models fall back to "first" on unknown values.
        batch_size: sentences per forward pass.

    Neither input frame is modified (ids are string-normalised on local
    copies). Inputs are tokenized with truncation, so words beyond the model's
    maximum length end up with filled=False instead of crashing the forward pass.
    """
    # String-normalise ids on local objects; the caller's frames stay untouched.
    sent_ids = df_sent["sentence_id"].astype(str)
    word_sids = subset_df["sentence_id"].astype(str)

    # sid -> list[(global_idx, word_id)]
    by_sid: Dict[str, List[Tuple[int, int]]] = {}
    for gidx, (sid, wid) in enumerate(zip(word_sids, subset_df["word_id"])):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))

    sids = list(by_sid.keys())
    df_sel = (df_sent.assign(sentence_id=sent_ids)
              .loc[lambda d: d["sentence_id"].isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    # Robust tokenizer/model load (fixes GPT-2 "vocab_file NoneType" crashes)
    tokzr, model, model_id_used = _load_tokenizer_and_model(baseline)

    # Pre-tokenized input with word mapping
    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)

    # GPT‑2 specifics: add_prefix_space + pad token
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True
    if tokzr.pad_token is None and getattr(tokzr, "eos_token", None) is not None:
        tokzr.pad_token = tokzr.eos_token  # safe default for GPT-2

    if getattr(model.config, "pad_token_id", None) is None and tokzr.pad_token_id is not None:
        model.config.pad_token_id = tokzr.pad_token_id
    if device == "cuda": model.half()
    model = model.eval().to(device)

    L = _num_hidden_layers(model) + 1   # include embeddings
    D = _hidden_size(model)
    N = len(subset_df)
    reps   = np.zeros((L, N, D), np.float16)
    filled = np.zeros(N, dtype=bool)

    # Choose/validate rep mode depending on model family
    gpt_like = _is_gpt_like(model)
    if gpt_like:
        rep_mode = word_rep_mode if word_rep_mode in {"last","mean"} else "last"
    else:
        rep_mode = word_rep_mode if word_rep_mode in {"first","mean","last"} else "first"

    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{model_id_used} (embed subset)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # Map word_id -> token positions for this item
                word_map = {}
                # NOTE: word_ids() requires a *fast* tokenizer; ensured by GPT2TokenizerFast above.
                for tidx, wid in enumerate(enc_be.word_ids(b)):
                    if wid is not None:
                        word_map.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = word_map.get(wid)
                    if not toks: continue   # word produced no sub-token (or was truncated)
                    if rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    else:  # "mean"
                        vec = h[:, b, toks, :].mean(axis=1)
                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    filled[gidx] = True

            # Free per-batch tensors eagerly to keep peak memory low.
            del enc_be, enc_t, out, h
            if device == "cuda": torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing: print(f"⚠ Missing vectors for {missing} of {N} sampled words")
    del model; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    return reps, filled, rep_mode

# =============================== BOOTSTRAP CORE ===============================
def _bs_layer_loop(rep_sub: np.ndarray, M: int, n_reps: int, compute_once: Callable[[np.ndarray], float]):
    """Bootstrap: sample M with replacement and apply compute_once(X_layer) -> scalar for each layer."""
    L, N, D = rep_sub.shape
    rng = np.random.default_rng(RAND_SEED)
    A = np.full((n_reps, L), np.nan, np.float32)
    for r in range(n_reps):
        idx = rng.integers(0, N, size=M)
        for l in range(L):
            X = rep_sub[l, idx].astype(np.float32, copy=False)
            try:    A[r, l] = float(compute_once(X))
            except Exception: A[r, l] = np.nan
    mu = np.nanmean(A, axis=0).astype(np.float32)
    lo = np.nanpercentile(A, 2.5, axis=0).astype(np.float32)
    hi = np.nanpercentile(A, 97.5, axis=0).astype(np.float32)
    return mu, lo, hi

# ---- Metric registries ----
# FAST_ONCE metrics use N_BOOTSTRAP_FAST / FAST_BS_MAX_SAMP_PER_CLASS in the
# driver; HEAVY_ONCE metrics use the HEAVY counterparts.
FAST_ONCE: Dict[str, Callable[[np.ndarray], float]] = {
    "iso": _iso_once,     # IsoScore
    "pca99": _pca99_once  # lPCA @ 0.99 explained variance
}

# A None entry marks an unavailable estimator so the driver skips the metric.
# Registering the function unconditionally made the driver bootstrap
# all-NaN values when DADApy is not installed.
HEAVY_ONCE: Dict[str, Callable[[np.ndarray], float] | None] = {
    "gride": _dadapy_gride_once if HAS_DADAPY else None
}

# Axis / title labels; metrics without an entry are shown as metric.upper().
LABELS = {
    "iso": "IsoScore" if _HAS_ISOSCORE else "IsoScore (fallback: mean/max eigenvalue)",
    "pca99":"lPCA 0.99",
    "gride":"GRIDE"
}

# Keep runtime reasonable (add more if you want)
ALL_METRICS = ["gride", "pca99", "iso"]

# =============================== SAVE / PLOT ===============================
def save_metric_csv_all_classes(metric: str,
                                class_to_stats: Dict[str, Dict[str, np.ndarray]],
                                layers: np.ndarray,
                                baseline: str,
                                subset_name: str = "raw",
                                rep_mode_used: str = ""):
    """Write one long-format CSV (one row per class × layer) for `metric`.

    Args:
        metric: metric key, used in the `metric` column and the file name.
        class_to_stats: class label -> {"mean", "lo", "hi", "n"}; "lo"/"hi"
            may be absent, in which case the CI columns are NaN.
        layers: layer ids, indexed in parallel with the per-class arrays.
        baseline: model id; stored verbatim in the `model` column.
        subset_name: tag stored in `subset` and used in the file name.
        rep_mode_used: pooling mode; falls back to WORD_REP_MODE when empty.

    The file is written to CSV_DIR as
    arity_<subset>_<metric>_<baseline>.csv, with "/" in `baseline` replaced by
    "_" so hub ids such as "openai-community/gpt2" do not point into a
    non-existent sub-directory.
    """
    def _finite_or_nan(arr, pos):
        # Non-array (missing) or non-finite bounds are stored as NaN.
        if isinstance(arr, np.ndarray) and np.isfinite(arr[pos]):
            return float(arr[pos])
        return np.nan

    rows = []
    for c, stats in class_to_stats.items():
        mu, lo, hi = stats["mean"], stats.get("lo"), stats.get("hi")
        for l, val in enumerate(mu):
            rows.append({
                "subset": subset_name, "model": baseline, "feature": "arity",
                "class": c, "metric": metric, "layer": int(layers[l]),
                "mean": float(val) if np.isfinite(val) else np.nan,
                "ci_low": _finite_or_nan(lo, l),
                "ci_high": _finite_or_nan(hi, l),
                "n_tokens": int(stats.get("n", 0)),
                "word_rep_mode": rep_mode_used or WORD_REP_MODE,
                "source_csv": Path(CSV_PATH).name,
            })
    df = pd.DataFrame(rows)
    safe_model = baseline.replace("/", "_")
    out = CSV_DIR / f"arity_{subset_name}_{metric}_{safe_model}.csv"
    df.to_csv(out, index=False)

def plot_metric_with_ci(class_to_stats: Dict[str, Dict[str, np.ndarray]],
                        layers: np.ndarray, metric: str, title: str, out_path: Path,
                        palette: Dict[str, Tuple[float, float, float]] | None = None):
    """Draw one mean-vs-layer line per arity class, with a shaded CI band, and save it.

    Classes are plotted in increasing integer order; a class whose mean is
    missing or all-NaN is skipped. The band is drawn only when both bounds are
    arrays and the lower bound is not all-NaN. The figure is written to
    `out_path` at 220 dpi and then closed.
    """
    fig, ax = plt.subplots(figsize=(9, 5))
    use_palette = isinstance(palette, dict)
    for label in sorted(class_to_stats, key=int):
        entry = class_to_stats[label]
        mean_curve = entry["mean"]
        lower, upper = entry.get("lo"), entry.get("hi")
        if mean_curve is None or np.all(np.isnan(mean_curve)):
            continue
        line_color = palette.get(label) if use_palette else None
        ax.plot(layers, mean_curve, label=label, lw=1.8, color=line_color)
        has_band = (isinstance(lower, np.ndarray)
                    and isinstance(upper, np.ndarray)
                    and not np.all(np.isnan(lower)))
        if has_band:
            ax.fill_between(layers, lower, upper, alpha=0.15, color=line_color)
    ax.set_xlabel("Layer")
    ax.set_ylabel(LABELS.get(metric, metric.upper()))
    ax.set_title(title)
    ax.legend(ncol=3, fontsize="small", title="Arity (4 = 4+)", frameon=False)
    fig.tight_layout()
    fig.savefig(out_path, dpi=220)
    plt.close(fig)

# =============================== DRIVER ===============================
def run_arity_pipeline():
    """End-to-end driver: load → sample → embed → bootstrap each metric → save CSV + plot.

    Reads the module-level configuration (CSV_PATH, BASELINE, WORD_REP_MODE,
    caps, bootstrap budgets) and writes one CSV and one PNG per metric in
    ALL_METRICS. Output file names use BASELINE with "/" replaced by "_".
    """
    # 1) Load tokens + arity classes
    df_sent, ar_df = load_arity_df(CSV_PATH)
    classes = sorted(ar_df.arity_class.unique(), key=lambda s: int(s))
    palette = make_class_palette(classes)
    print(f"✓ corpus ready — {len(ar_df):,} tokens across arity classes {classes}")

    # 2) Optional per-class cap (overrides come from PER_CLASS_CAPS)
    raw_df = sample_raw(ar_df, RAW_MAX_PER_CLASS, PER_CLASS_CAPS)
    print("Sample sizes per arity (raw cap):")
    print(raw_df.arity_class.value_counts().sort_index().to_dict())

    type_counts = (
        raw_df.groupby("arity_class")["word"]
              .nunique()
              .sort_index()
              .to_dict()
    )
    # Report the caps actually in force instead of a hard-coded figure.
    print(f"Unique word types per arity (per-class token caps: {PER_CLASS_CAPS}):")
    print(type_counts)

    # 3) Embed once (encoder or decoder; rep mode auto-validated)
    reps, filled, rep_mode_used = embed_subset(df_sent, raw_df, BASELINE, WORD_REP_MODE, BATCH_SIZE)
    raw_df = raw_df.reset_index(drop=True).loc[filled].reset_index(drop=True)
    # Keep reps aligned with raw_df: embed_subset returns one column per
    # sampled word, including unfilled (all-zero) ones. Without this mask the
    # class indices below would point at the wrong columns.
    if not filled.all():
        reps = reps[:, filled]
    cls_arr = raw_df.arity_class.values
    L = reps.shape[0]; layers = np.arange(L)
    print(f"✓ embedded {len(raw_df):,} tokens  • layers={L}  • rep_mode={rep_mode_used}")

    # Hub ids may contain "/", which must not end up inside a file name.
    model_tag = BASELINE.replace("/", "_")

    # 4) Metric loop
    for metric in ALL_METRICS:
        print(f"\n→ Computing metric: {metric} …")
        compute_once = FAST_ONCE.get(metric) or HEAVY_ONCE.get(metric)
        if compute_once is None:
            print(f"  (skipping {metric}: estimator unavailable)")
            continue

        # choose bootstrap budget
        is_fast = metric in FAST_ONCE
        n_bs = N_BOOTSTRAP_FAST if is_fast else N_BOOTSTRAP_HEAVY
        Mcap = FAST_BS_MAX_SAMP_PER_CLASS if is_fast else HEAVY_BS_MAX_SAMP_PER_CLASS

        class_results: Dict[str, Dict[str, np.ndarray]] = {}
        for c in classes:
            idx = np.where(cls_arr == c)[0]
            if idx.size < 3:
                # Too few tokens for a meaningful estimate.
                continue
            sub = reps[:, idx]  # (L, n_c, D)
            Nc = sub.shape[1]
            M = min(Mcap, Nc)
            mu, lo, hi = _bs_layer_loop(sub, M, n_bs, compute_once)
            class_results[c] = {"mean": mu, "lo": lo, "hi": hi, "n": int(Nc)}

        plot_path = PLOT_DIR / f"arity_raw_{metric}_{model_tag}_rep-{rep_mode_used}.png"
        save_metric_csv_all_classes(metric, class_results, layers, BASELINE,
                                    subset_name="raw", rep_mode_used=rep_mode_used)
        plot_metric_with_ci(class_results, layers, metric,
                            title=f"{LABELS.get(metric, metric.upper())} • {BASELINE} • rep={rep_mode_used}",
                            out_path=plot_path,
                            palette=palette)
        print(f"  ✓ saved: CSV= {CSV_DIR}/arity_raw_{metric}_{model_tag}.csv  "
              f"plot= {plot_path}")

        del class_results; gc.collect()
        if device == "cuda": torch.cuda.empty_cache()

    del reps; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    print("\n✓ done (incremental outputs produced per metric).")

# Entry point: runs the whole pipeline whenever this code executes as __main__.
if __name__ == "__main__":
    run_arity_pipeline()


✓ corpus ready — 194,916 tokens across arity classes ['0', '1', '2', '3', '4']
Sample sizes per arity (raw cap):
{'0': 50000, '1': 20239, '2': 18037, '3': 13546, '4': 14030}
Unique word types per arity (with class 0 capped at 30k tokens):
{'0': 4936, '1': 7099, '2': 6408, '3': 4932, '4': 4798}



bert-base-uncased (embed subset):   0%|               | 0/10067 [00:00<?, ?it/s][A
bert-base-uncased (embed subset):   0%|     | 11/10067 [00:00<01:35, 105.22it/s][A
bert-base-uncased (embed subset):   0%|     | 22/10067 [00:00<01:35, 105.19it/s][A
bert-base-uncased (embed subset):   0%|     | 33/10067 [00:00<01:34, 106.72it/s][A
bert-base-uncased (embed subset):   0%|     | 44/10067 [00:00<01:35, 104.60it/s][A
bert-base-uncased (embed subset):   1%|     | 55/10067 [00:00<01:36, 103.98it/s][A
bert-base-uncased (embed subset):   1%|     | 66/10067 [00:00<01:35, 104.54it/s][A
bert-base-uncased (embed subset):   1%|     | 78/10067 [00:00<01:33, 106.91it/s][A
bert-base-uncased (embed subset):   1%|     | 89/10067 [00:00<01:33, 106.49it/s][A
bert-base-uncased (embed subset):   1%|    | 100/10067 [00:00<01:36, 102.80it/s][A
bert-base-uncased (embed subset):   1%|    | 111/10067 [00:01<01:35, 104.48it/s][A
bert-base-uncased (embed subset):   1%|    | 122/10067 [00:01<01:35, 103.75

✓ embedded 115,852 tokens  • layers=13  • rep_mode=first

→ Computing metric: gride …


In [None]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, GPT2TokenizerFast

# Plotly
import plotly.graph_objects as go
import plotly.express as px
import plotly.colors as pc

# =============================== CONFIG ===============================
CSV_PATH       = "en_ewt-ud-train_sentences.csv"   # needs: sentence_id, tokens (list[str]), arity (list[int])
BASELINE       = "bert-base-uncased"                            # alternatively "gpt2" / "openai-community/gpt2"
WORD_REP_MODE  = "first"                            # BERT: {"first","last","mean"}; GPT-2: {"last","mean"}

# Optional: subsample for plotting smoothness (per class)
PLOT_MAX_PER_CLASS = None                          # e.g., 4000; None = use all tokens

# Output (directory is created as a side effect of running this cell)
OUT_DIR  = Path("pca3d_arity"); OUT_DIR.mkdir(parents=True, exist_ok=True)
# "/" in a hub id is replaced so the id can be used inside a file name.
HTML_OUT = OUT_DIR / f"{BASELINE.replace('/','_')}_arity_pca3d_layers.html"

# Throughput / device
BATCH_SIZE = 2
RAND_SEED  = 42
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Seeds the global Python, NumPy and torch RNGs.
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.backends.cudnn.benchmark = True

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)   # GPT-2
    if n is None: raise ValueError("Cannot determine num_hidden_layers")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)    # GPT-2
    if d is None: raise ValueError("Cannot determine hidden size")
    return int(d)

def _is_gpt_like(model) -> bool:
    mt = str(getattr(model.config, "model_type", "")).lower()
    name = str(getattr(getattr(model, "name_or_path", ""), "lower", lambda: "")())
    return ("gpt2" in mt) or ("gpt2" in name)

def _load_tok_and_model(model_id: str):
    """
    Robust loader:
    - GPT‑2: force *fast* tokenizer, try both 'gpt2' and 'openai-community/gpt2'
      (fixes NoneType vocab path issues).
    - Right padding; set PAD=EOS if missing (safe for batched inference).
    """
    cands = [model_id]
    if "gpt2" in model_id.lower():
        if model_id != "openai-community/gpt2": cands.append("openai-community/gpt2")
        if model_id != "gpt2": cands.append("gpt2")

    last_err = None
    for mid in cands:
        try:
            if "gpt2" in mid.lower():
                tok = GPT2TokenizerFast.from_pretrained(mid, add_prefix_space=True)
            else:
                tok = AutoTokenizer.from_pretrained(mid, use_fast=True, add_prefix_space=True)

            # Right padding + PAD token (GPT‑2 has no pad by default)
            if getattr(tok, "padding_side", None) != "right":
                tok.padding_side = "right"
            if tok.pad_token is None and getattr(tok, "eos_token", None) is not None:
                tok.pad_token = tok.eos_token

            mdl = AutoModel.from_pretrained(mid, output_hidden_states=True)
            if getattr(mdl.config, "pad_token_id", None) is None and tok.pad_token_id is not None:
                mdl.config.pad_token_id = tok.pad_token_id
            mdl = mdl.eval().to(device)
            if device == "cuda":
                mdl.half()
            return tok, mdl, mid
        except Exception as e:
            last_err = e
            continue
    raise RuntimeError(f"Failed to load tokenizer/model for {cands}: {last_err}")

# =============================== DATA (arity already in CSV) ===============================
def load_arity_df(csv_path: str):
    """
    Load the sentence-level CSV and explode it to one row per token.

    The arity column may be spelled arity / ariety / ARITY / Arity; it is
    renamed to "arity". List-valued cells stored as strings are parsed with
    `_to_list`. Each token gets arity_class = str(arity clamped to [0, 4])
    (4 = 4+); an arity that cannot be converted to int counts as 0. Token and
    arity lists of different length are truncated to the shorter one.

    Returns:
      df_sent: sentence_id + tokens (first row per sentence_id)
      df_tok : sentence_id, word_id, arity_class, word

    Raises:
      ValueError: if no arity column exists or no token rows were produced.
    """
    df_all = pd.read_csv(csv_path)
    if "arity" not in df_all.columns:
        # try a few common misspellings/variants
        variant = next((c for c in ("ariety", "ARITY", "Arity") if c in df_all.columns), None)
        if variant is not None:
            df_all = df_all.rename(columns={variant: "arity"})
    if "arity" not in df_all.columns:
        raise ValueError("CSV must contain a list[int] column named 'arity' (or ariety/ARITY/Arity).")

    wanted = ["sentence_id", "tokens", "arity"]
    df = df_all[wanted].copy()
    df["sentence_id"] = df["sentence_id"].astype(str)
    df["tokens"] = df["tokens"].apply(_to_list)
    df["arity"] = df["arity"].apply(_to_list)

    def _arity_class(values, pos):
        try:
            number = int(values[pos])
        except Exception:
            number = 0
        return str(min(max(number, 0), 4))   # cap at 4 (means 4+)

    rows = [
        (sid, wid, _arity_class(A, wid), toks[wid])
        for sid, toks, A in df[wanted].itertuples(index=False)
        for wid in range(min(len(toks), len(A)))
    ]
    df_tok = pd.DataFrame(rows, columns=["sentence_id", "word_id", "arity_class", "word"])
    df_sent = df[["sentence_id", "tokens"]].drop_duplicates("sentence_id")
    if df_tok.empty:
        raise ValueError("No token rows constructed—check your 'arity' column contents.")
    return df_sent, df_tok

def sample_per_class(df_tok: pd.DataFrame, per_class_cap: int | None) -> pd.DataFrame:
    """Optionally subsample each arity class for plotting.

    With `per_class_cap=None` the frame is returned with a fresh index.
    Otherwise at most `per_class_cap` rows per class are drawn without
    replacement (seeded with RAND_SEED), classes in order of first appearance.
    """
    if per_class_cap is None:
        return df_tok.reset_index(drop=True)
    chunks = [
        group.sample(min(len(group), per_class_cap), random_state=RAND_SEED, replace=False)
        for _label, group in df_tok.groupby("arity_class", sort=False)
    ]
    return pd.concat(chunks, ignore_index=True)

# =============================== EMBEDDING ===============================
def embed_subset(df_sent: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray, str, str]:
    """
    Embed the words listed in `subset_df`.

    Returns (reps, filled, rep_mode_used, model_tag):
      reps   : float16, shape (L, N_filled, D) — L includes the embedding
               layer. Columns of words that received no vector are DROPPED,
               so reps aligns with `subset_df[filled]`, not with `subset_df`.
      filled : bool mask of length len(subset_df).
      rep_mode_used : pooling mode applied ("last" replaces unsupported modes
               on GPT-2-like models).
      model_tag : model id that was actually loaded.

    Neither input frame is modified (ids are string-normalised on local copies).
    """
    # String-normalise ids on local objects; the caller's frames stay untouched.
    sent_ids = df_sent["sentence_id"].astype(str)
    word_sids = subset_df["sentence_id"].astype(str)

    # sid -> list[(global_idx, word_id)]
    by_sid: Dict[str, List[Tuple[int, int]]] = {}
    for gidx, (sid, wid) in enumerate(zip(word_sids, subset_df["word_id"])):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))

    sids = list(by_sid.keys())
    df_sel = (df_sent.assign(sentence_id=sent_ids)
              .loc[lambda d: d["sentence_id"].isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    tokzr, model, model_tag = _load_tok_and_model(baseline)
    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True  # needed for byte-level BPE when pre-tokenized

    # Choose/validate rep mode depending on model family
    rep_mode = word_rep_mode
    if _is_gpt_like(model) and rep_mode not in {"last","mean"}:
        rep_mode = "last"

    L = _num_hidden_layers(model) + 1   # include embeddings
    D = _hidden_size(model)
    N = len(subset_df)

    reps   = np.zeros((L, N, D), np.float16)
    filled = np.zeros(N, dtype=bool)

    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{model_tag} (embed subset)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                mp = {}   # word_id -> sub-token positions of this sentence
                wids = enc_be.word_ids(b)  # fast tokenizer is required for word_ids()
                if wids is None:
                    raise RuntimeError("Fast tokenizer required: word_ids() unavailable.")
                for tidx, wid in enumerate(wids):
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid)
                    if not toks: continue   # word produced no sub-token (or was truncated)
                    if rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    else:  # "mean"
                        vec = h[:, b, toks, :].mean(axis=1)
                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    filled[gidx] = True

            # Free per-batch tensors eagerly to keep peak memory low.
            del enc_be, enc_t, out, h
            if device == "cuda":
                torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing:
        print(f"⚠ Missing vectors for {missing} tokens (skipped in PCA).")
        reps = reps[:, filled]
    del model; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    return reps, filled, rep_mode, model_tag

# =============================== PCA 3D PER LAYER ===============================
def _pca3d_layer(X: np.ndarray, n_components: int = 3) -> np.ndarray:
    """Lightweight PCA to 3D via SVD (no sklearn dependency)."""
    X = X.astype(np.float32, copy=False)
    Xc = X - X.mean(0, keepdims=True)
    U, S, _ = np.linalg.svd(Xc, full_matrices=False)
    return (U[:, :n_components] * S[:n_components]).astype(np.float32, copy=False)  # (n,3)

def _qual_palette_for_classes(classes: List[str]) -> Dict[str, str]:
    """
    Map every class label to a discrete colour.

    Several Plotly qualitative colour sequences are concatenated into a single
    pool and the first ``len(classes)`` entries are used. If the pool is too
    short, colours are instead sampled at evenly spaced points of the
    continuous "Turbo" scale.
    """
    qualitative = px.colors.qualitative
    stacked_sequences = (
        qualitative.Bold,
        qualitative.D3,
        qualitative.Vivid,
        qualitative.Safe,
        qualitative.Alphabet,
        qualitative.Dark24,
        qualitative.Light24,
        qualitative.Set3,
        qualitative.Set2,
        qualitative.Set1,
        qualitative.Pastel2,
        qualitative.Pastel1,
    )
    pool = [colour for sequence in stacked_sequences for colour in sequence]

    k = len(classes)
    if k > len(pool):
        # Not enough discrete colours: fall back to the continuous scale.
        positions = np.linspace(0.0, 1.0, k, endpoint=True)
        colors = [pc.sample_colorscale("Turbo", [pos])[0] for pos in positions]
    else:
        colors = pool[:k]
    return dict(zip(classes, colors))

def pca3d_by_arity_and_plot(reps: np.ndarray,
                            words: List[str],
                            classes_arr: np.ndarray,
                            all_classes: List[str],
                            model_tag: str,
                            html_out: Path):
    """
    Build one 3D scatter trace per class per layer; a slider toggles layers.

    Parameters
    ----------
    reps : array of shape (L, N, D); one embedding matrix per layer.
    words : N hover labels, aligned with the second axis of `reps`.
    classes_arr : N class labels, compared with `==` against `all_classes`.
    all_classes : class labels in legend order; each gets one trace per layer.
    model_tag : text used in the figure title.
    html_out : destination of the standalone interactive HTML file.

    Side effects: prints progress, displays the figure and writes `html_out`.
    """
    L, N, _ = reps.shape
    print(f"PCA plotting on {N:,} tokens across {L} layers…")

    cmap = _qual_palette_for_classes(all_classes)
    n_per_layer = len(all_classes)

    # Class membership and hover labels are identical for every layer, so they
    # are built once here instead of once per (layer, class) pair.
    words_arr = np.asarray(words)
    masks = {}
    hovers = {}
    for c in all_classes:
        mask = (classes_arr == c)
        masks[c] = mask
        hovers[c] = [f"{w} | arity={c}" for w in words_arr[mask]] if np.any(mask) else []

    traces = []
    for l in range(L):
        # Project one layer at a time so only a single (N,3) array is live.
        Y = _pca3d_layer(reps[l])  # (N,3)
        show_legend = (l == 0)  # keep legend only for layer 0
        for c in all_classes:
            mask = masks[c]
            hov = hovers[c]
            if not np.any(mask):
                # Empty trace keeps the trace count per layer constant,
                # which the slider's visibility vectors rely on.
                x = y = z = []
            else:
                x, y, z = Y[mask, 0], Y[mask, 1], Y[mask, 2]

            traces.append(
                go.Scatter3d(
                    x=x, y=y, z=z,
                    mode="markers",
                    marker=dict(size=2, opacity=0.75, color=cmap[c]),
                    name=c,
                    legendgroup=c,
                    showlegend=show_legend,
                    hovertext=hov,
                    # NOTE: do NOT format this string (contains %{...} placeholders)
                    hovertemplate=(
                        "<b>%{hovertext}</b><br>"
                        "x=%{x:.3f}<br>y=%{y:.3f}<br>z=%{z:.3f}"
                        "<extra></extra>"
                    ),
                    visible=(l == 0),
                )
            )

    # Slider: toggle visibility per selected layer
    n_total = n_per_layer * L
    steps = []
    for l in range(L):
        vis = [False] * n_total
        start = l * n_per_layer
        vis[start : start + n_per_layer] = [True] * n_per_layer
        steps.append(dict(
            method="update",
            args=[{"visible": vis},
                  {"title": f"{model_tag} • PCA 3D by arity • Layer {l} (drag to rotate)"}],
            label=str(l),
        ))

    sliders = [dict(
        active=0,
        steps=steps,
        currentvalue={"prefix": "Layer: "},
        pad={"t": 10}
    )]

    layout = go.Layout(
        title=f"{model_tag} • PCA 3D by arity • Layer 0 (drag to rotate)",
        scene=dict(xaxis_title="PC1", yaxis_title="PC2", zaxis_title="PC3", aspectmode="data"),
        margin=dict(l=0, r=0, b=0, t=40),
        sliders=sliders,
        showlegend=True,
        legend=dict(title="arity (4 = 4+)", itemsizing="trace")
    )

    fig = go.Figure(data=traces, layout=layout)
    fig.show()
    fig.write_html(str(html_out), include_plotlyjs="cdn")
    print("✓ Saved interactive HTML to:", html_out.resolve())

# =============================== DRIVER ===============================
def run_pca3d_arity():
    """
    End-to-end driver for the interactive PCA figure.

    Loads the token-level arity table, samples it per class, embeds the sample
    once, drops tokens without an embedding, and hands everything to
    `pca3d_by_arity_and_plot`. Releases the embedding array afterwards.
    """
    # Load sentence table + token-level arity table
    sent_df, arity_df = load_arity_df(CSV_PATH)
    arity_classes = sorted(arity_df.arity_class.unique(), key=int)
    print(f"✓ plotting subset — {len(arity_df):,} tokens across arity classes {arity_classes}")

    # Per-class cap for plotting
    plot_df = sample_per_class(arity_df, PLOT_MAX_PER_CLASS)

    # One embedding pass; keep only the rows that received a vector
    reps, filled, rep_mode, model_tag = embed_subset(
        sent_df, plot_df, BASELINE, WORD_REP_MODE, BATCH_SIZE
    )
    plot_df = plot_df.reset_index(drop=True).loc[filled].reset_index(drop=True)

    # Class labels and hover words aligned with the kept rows
    class_labels = plot_df.arity_class.values.astype(str)
    hover_words = plot_df.word.astype(str).tolist()

    # Per-layer PCA → Plotly figure
    pca3d_by_arity_and_plot(
        reps.astype(np.float32, copy=False),
        hover_words,
        class_labels,
        arity_classes,
        model_tag=f"{model_tag} (rep={rep_mode})",
        html_out=HTML_OUT,
    )

    # Free host (and, if used, GPU) memory
    del reps
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

# Run the plotting driver when this cell/module is executed as the main program.
if __name__ == "__main__":
    run_pca3d_arity()


✓ plotting subset — 194,916 tokens across arity classes ['0', '1', '2', '3', '4']



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.

bert-base-uncased (embed subset):  57%|██▊  | 2850/5034 [00:28<00:22, 96.80it/s]

In [1]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple, Callable

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel

# =========================== Optional deps ===========================
HAS_DADAPY = False
try:
    from dadapy import Data  # DADApy ID estimators (TwoNN, GRIDE)
    HAS_DADAPY = True
except Exception:
    # Best-effort optional dependency: the DADApy-based estimators return NaN
    # when HAS_DADAPY is False.
    pass

# IsoScore: library if available, else a monotone fallback.
# The module is imported as `IsoScore` elsewhere in this notebook; the
# lower-case spelling alone fails on case-sensitive filesystems and would
# silently select the fallback, so both spellings are tried.
# NOTE(review): confirm the row/column orientation the library expects for X.
try:
    try:
        from IsoScore import IsoScore
    except ImportError:
        from isoscore import IsoScore
    _HAS_ISOSCORE = True
except Exception:
    _HAS_ISOSCORE = False
    class _IsoScoreFallback:
        """Stand-in exposing the same `IsoScore.IsoScore(X)` call shape."""

        @staticmethod
        def IsoScore(X: np.ndarray) -> float:
            """Return mean/max covariance eigenvalue of X (rows = samples), clipped to [0, 1]."""
            C = np.cov(X.T, ddof=0)
            ev = np.linalg.eigvalsh(C)
            if ev.mean() <= 0 or ev[-1] <= 0:
                return 0.0
            # mean / max eigenvalue in [0,1] (↑ ~ more isotropic)
            return float(np.clip(ev.mean() / ev[-1], 0.0, 1.0))
    IsoScore = _IsoScoreFallback()

# =============================== CONFIG ===============================
# Input CSV with one row per sentence (list-valued `tokens` and arity columns).
# Update this path if your file has a different name (e.g., "en_ewt-ud-train_sentences (2).csv")
CSV_PATH   = "en_ewt-ud-train_sentences.csv"

# Model to embed with. Set to "gpt2" (decoder) or "bert-base-uncased" (encoder)
BASELINE   = "gpt2"
# Which sub-token(s) represent a word; embed_subset falls back to a default
# when the value is not valid for the model family.
WORD_REP_MODE = "last"  # GPT-2: {"last","mean"}

# Exclude FIRST word (index 0) of each sentence from analysis
EXCLUDE_FIRST_WORD = True

# Default per-class sample cap; 1e12 is effectively "no cap"
RAW_MAX_PER_CLASS = int(1e12)

# Cap overrides: cap only arity "0" to 30_000; others use RAW_MAX_PER_CLASS
PER_CLASS_CAPS: Dict[str, int] = {"0": 30_000}

# Bootstrap replicates (tune down if slow): FAST applies to metrics registered
# in FAST_ONCE, HEAVY to every other metric.
N_BOOTSTRAP_FAST   = 50
N_BOOTSTRAP_HEAVY  = 20

# Per-replicate sample size (M = min(cap, N_class))
FAST_BS_MAX_SAMP_PER_CLASS  = int(1e12)
HEAVY_BS_MAX_SAMP_PER_CLASS = 5000

# Seed shared by python/numpy/torch RNGs and by the sampling helpers
RAND_SEED=42
# Output locations; the directories are created as a side effect of running this cell
PLOT_DIR     = Path("results_ARITY_no_index"); PLOT_DIR.mkdir(exist_ok=True, parents=True)
CSV_DIR      = Path("tables_ARITY_no_index") / "arity_bootstrap"; CSV_DIR.mkdir(exist_ok=True, parents=True)

# Throughput: sentences per forward pass (raise if you have more GPU memory)
BATCH_SIZE = 1

# Reproducibility & device
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): cudnn.benchmark trades determinism for speed in some setups —
# confirm this is acceptable alongside the seeding above.
if device == "cuda": torch.backends.cudnn.benchmark = True

# Seaborn style
sns.set_style("darkgrid")
plt.rcParams["figure.dpi"] = 120

# Small constant added to denominators / log arguments in the spectral metrics
EPS = 1e-12

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _center(X: np.ndarray) -> np.ndarray:
    return X - X.mean(0, keepdims=True)

def _eigvals_from_X(X: np.ndarray) -> np.ndarray:
    """Eigenvalues of covariance up to a constant via SVD of centered X (descending).

    Parameters
    ----------
    X : 2-D array with one sample per row; cast to float32 and column-centered.

    Returns
    -------
    float64 array of squared singular values sorted in descending order, or an
    empty float64 array when the decomposition raises.
    """
    Xc = _center(X.astype(np.float32, copy=False))
    try:
        # Only the singular values are used, so skip building U and Vt; this
        # avoids allocating an (n_samples x rank) matrix on every call.
        S = np.linalg.svd(Xc, compute_uv=False)
        lam = (S**2).astype(np.float64)
        lam.sort()
        return lam[::-1]
    except Exception:
        # Best-effort: callers treat an empty spectrum as "metric unavailable".
        return np.array([], dtype=np.float64)

def _jitter_unique(X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Add tiny noise if there are duplicate rows (helps NN-based estimators)."""
    try:
        if np.unique(X, axis=0).shape[0] < X.shape[0]:
            X = X + np.random.normal(scale=eps, size=X.shape).astype(X.dtype)
    except Exception:
        pass
    return X

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)
    if n is None: raise ValueError("Cannot determine num_hidden_layers from model.config")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)
    if d is None: raise ValueError("Cannot determine hidden_size from model.config")
    return int(d)

def _is_gpt_like(model) -> bool:
    mt = str(getattr(model.config, "model_type", "")).lower()
    name = str(getattr(getattr(model, "name_or_path", ""), "lower", lambda: "")())
    return ("gpt2" in mt) or ("gpt2" in name)

def _pick_arity_col(df: pd.DataFrame) -> str:
    for cand in ["arity", "ariety", "ARITY", "Arity"]:
        if cand in df.columns: return cand
    raise ValueError("No arity column found. Expected one of: arity, ariety, ARITY, Arity")

# -------- Safe tokenizer+model loader (fixes GPT-2 NoneType path error) --------
def _safe_load_tok_and_model(baseline: str):
    """
    Robust tokenizer+model loader for GPT-2 and others.

    - Patches os.path.isfile during tokenizer load to avoid NoneType path errors.
    - Falls back between 'gpt2' and 'openai-community/gpt2'.

    Parameters
    ----------
    baseline : model identifier. If it contains "gpt2" (case-insensitive) the
        candidates tried, in order and de-duplicated, are `baseline`,
        "openai-community/gpt2" and "gpt2"; otherwise only `baseline`.

    Returns
    -------
    (tokzr, model, model_id_used). The tokenizer is set to right padding, gets
    its EOS token as pad token when it has none, and carries a private
    `_safe_call_kwargs` dict of extra call-time arguments. The model is loaded
    with `output_hidden_states=True` and inherits the tokenizer's pad id when
    its config has none.

    Raises
    ------
    RuntimeError if every candidate fails; the message carries the last error.
    """
    import os as _os

    bl = baseline.lower()
    if "gpt2" in bl:
        # Ordered, de-duplicated list of ids to try.
        candidates = []
        for mid in (baseline, "openai-community/gpt2", "gpt2"):
            if mid not in candidates:
                candidates.append(mid)
    else:
        candidates = [baseline]

    last_err = None
    for mid in candidates:
        try:
            # Patch during tokenizer load only. `_os` is the real `os` module,
            # so the patch is process-wide until the `finally` restores it.
            _orig_isfile = _os.path.isfile
            _os.path.isfile = (lambda p: False if p is None else _orig_isfile(p))
            try:
                tokzr = AutoTokenizer.from_pretrained(mid, use_fast=True, add_prefix_space=True)
            finally:
                _os.path.isfile = _orig_isfile

            # Right padding + pad token (for GPT-2)
            if getattr(tokzr, "padding_side", None) != "right":
                tokzr.padding_side = "right"
            if tokzr.pad_token is None and getattr(tokzr, "eos_token", None) is not None:
                tokzr.pad_token = tokzr.eos_token

            # Pass add_prefix_space=True at call time if supported, i.e. only
            # when it is an explicitly named parameter of the tokenizer's __call__.
            tok_call_kwargs = {}
            try:
                if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
                    tok_call_kwargs["add_prefix_space"] = True
            except Exception:
                # Signature introspection is best-effort; fall back to no extra kwargs.
                pass
            tokzr._safe_call_kwargs = tok_call_kwargs

            model = AutoModel.from_pretrained(mid, output_hidden_states=True)
            if getattr(model.config, "pad_token_id", None) is None and tokzr.pad_token_id is not None:
                model.config.pad_token_id = tokzr.pad_token_id

            return tokzr, model, mid
        except Exception as e:
            # Remember the failure and move on to the next candidate id.
            last_err = e
            continue

    raise RuntimeError(f"Failed to load tokenizer/model for {baseline} "
                       f"(tried {candidates}). Last error: {last_err}")

# ========= Per-subsample single-value compute functions (used inside bootstrap) =========
def _iso_once(X: np.ndarray) -> float:
    """Score one sample matrix with the module-level `IsoScore` object and return it as a Python float."""
    score = IsoScore.IsoScore(X)
    return float(score)

def _pcaXX_once(X: np.ndarray, var_ratio: float) -> float:
    """
    Count the leading spectrum entries (from `_eigvals_from_X`, descending)
    needed for their cumulative sum to reach `var_ratio` of the total.
    Returns NaN when the spectrum is empty.
    """
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    cumulative = np.cumsum(spectrum)
    threshold = cumulative[-1] * var_ratio
    first_reaching = np.searchsorted(cumulative, threshold)
    return float(first_reaching + 1)

def _pca99_once(X: np.ndarray) -> float:
    """`_pcaXX_once` evaluated at a 0.99 variance ratio."""
    n_components = _pcaXX_once(X, 0.99)
    return n_components

def _erank_once(X: np.ndarray) -> float:
    """
    Exponential of the Shannon entropy of the normalised spectrum returned by
    `_eigvals_from_X` (EPS guards the division and the log). NaN when the
    spectrum is empty.
    """
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    probs = spectrum / (spectrum.sum() + EPS)
    entropy = -(probs * np.log(probs + EPS)).sum()
    return float(np.exp(entropy))

def _pr_once(X: np.ndarray) -> float:
    """
    Participation ratio of the spectrum from `_eigvals_from_X`:
    (sum of entries)^2 / (sum of squared entries + EPS). NaN when empty.
    """
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    total = spectrum.sum()
    total_sq = (spectrum**2).sum()
    return float((total**2) / (total_sq + EPS))

def _stable_rank_once(X: np.ndarray) -> float:
    """Sum of the spectrum from `_eigvals_from_X` divided by (its maximum + EPS); NaN when empty."""
    spectrum = _eigvals_from_X(X)
    if spectrum.size == 0:
        return np.nan
    total = spectrum.sum()
    largest = spectrum.max()
    return float(total / (largest + EPS))

def _dadapy_twonn_once(X: np.ndarray) -> float:
    """First value returned by DADApy's `compute_id_2NN` on the (de-duplicated) sample; NaN when DADApy is unavailable."""
    if not HAS_DADAPY:
        return np.nan
    dataset = Data(coordinates=_jitter_unique(X))
    estimate = dataset.compute_id_2NN()[0:3][0]
    return float(estimate)

def _dadapy_gride_once(X: np.ndarray) -> float:
    """
    Last entry of the ID sequence returned by DADApy's
    `return_id_scaling_gride(range_max=64)` after computing distances with
    maxk=64; NaN when DADApy is unavailable.
    """
    if not HAS_DADAPY:
        return np.nan
    dataset = Data(coordinates=_jitter_unique(X))
    dataset.compute_distances(maxk=64)
    id_sequence, _, _ = dataset.return_id_scaling_gride(range_max=64)
    final_id = id_sequence[-1]
    return float(final_id)

# =============================== DATA (arity already in CSV) ===============================
def load_arity_df(csv_path: str):
    """
    Reads sentence-level rows with list columns: tokens, <arity>.
    Expands to one row per token with an arity class in {0,1,2,3,4} (4 = 4+).
    EXCLUDES the first token (word_id == 0) of each sentence if EXCLUDE_FIRST_WORD=True.

    The CSV is parsed once. Token and arity lists are paired position by
    position and truncated to the shorter of the two. Arity values that cannot
    be converted to int are assigned class "0"; negative values are clamped to 0.

    Returns:
      df_sent (sent-level with tokens),
      df_tok  (token-level: sentence_id, word_id, arity_class, word)

    Raises:
      TypeError  if csv_path is None.
      ValueError if a required column (sentence_id, tokens, arity) is missing.
    """
    if csv_path is None:
        raise TypeError("CSV_PATH is None — set CSV_PATH to your dataset filename.")

    # One pass over the file; sentence ids stay strings (no numeric coercion).
    df_all = pd.read_csv(csv_path, dtype={"sentence_id": str})
    missing = [col for col in ("sentence_id", "tokens") if col not in df_all.columns]
    if missing:
        raise ValueError(f"CSV is missing required column(s): {missing}")
    ar_col = _pick_arity_col(df_all)

    df = df_all[["sentence_id", "tokens", ar_col]].copy()
    df["sentence_id"] = df["sentence_id"].astype(str)
    df["tokens"] = df["tokens"].apply(_to_list)
    df[ar_col] = df[ar_col].apply(_to_list)

    rows = []
    for sid, toks, arities in df.itertuples(index=False):
        # zip() stops at the shorter list, so no explicit length alignment is needed.
        for wid, (tok, a) in enumerate(zip(toks, arities)):
            # ---- EXCLUDE FIRST WORD ----
            if EXCLUDE_FIRST_WORD and wid == 0:
                continue
            try:
                ai = int(a)
            except Exception:
                # Unparsable arity is treated as 0 (kept from the original behavior).
                ai = 0
            cl = str(min(max(ai, 0), 4))   # cap at 4 (means 4+)
            rows.append((sid, wid, cl, tok))

    df_tok = pd.DataFrame(rows, columns=["sentence_id","word_id","arity_class","word"])

    df_sent = df[["sentence_id","tokens"]].drop_duplicates("sentence_id")
    return df_sent, df_tok

def sample_raw(df_tok: pd.DataFrame,
               per_class_cap: int = RAW_MAX_PER_CLASS,
               per_class_caps: Dict[str, int] | None = None) -> pd.DataFrame:
    """
    Draw, without replacement and with a fixed seed, at most `cap` rows from
    each `arity_class` group. `cap` is `per_class_caps[str(class)]` when that
    key exists, otherwise `per_class_cap`. Groups are concatenated in
    first-appearance order with a fresh index.
    """
    overrides = per_class_caps or {}

    def _draw(label, group: pd.DataFrame) -> pd.DataFrame:
        limit = overrides.get(str(label), per_class_cap)
        size = min(len(group), limit)
        return group.sample(size, random_state=RAND_SEED, replace=False)

    sampled = [_draw(label, group)
               for label, group in df_tok.groupby("arity_class", sort=False)]
    return pd.concat(sampled, ignore_index=True)

def make_class_palette(classes: List[str]) -> Dict[str, Tuple[float, float, float]]:
    """
    Assign one colour per class. Colours come from the seaborn palettes
    tab20, tab20b and tab20c (20 each, palettes that fail to load are
    skipped); if that pool is smaller than the number of classes a "husl"
    palette of exactly that size replaces it. Classes are ordered by integer
    value, non-numeric labels last, and colours wrap around if needed.
    """
    pool: List[Tuple[float, float, float]] = []
    for palette_name in ("tab20", "tab20b", "tab20c"):
        try:
            pool.extend(sns.color_palette(palette_name, 20))
        except Exception:
            pass

    if len(classes) > len(pool):
        pool = list(sns.color_palette("husl", len(classes)))

    def _sort_key(label: str) -> int:
        return int(label) if label.isdigit() else 99

    palette: Dict[str, Tuple[float, float, float]] = {}
    for position, label in enumerate(sorted(classes, key=_sort_key)):
        palette[label] = pool[position % len(pool)]
    return palette

# =============================== TOKENIZER/MODEL (robust) ===============================
def _load_tokenizer_and_model(baseline: str):
    """Backward-compatible alias using the safe loader above.

    Forwards `baseline` unchanged to `_safe_load_tok_and_model` and returns its
    result, the tuple (tokenizer, model, model_id_used).
    """
    return _safe_load_tok_and_model(baseline)

# =============================== EMBEDDING (BERT & GPT‑2) ===============================
def embed_subset(df_sent: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray, str]:
    """
    Embed the words listed in `subset_df` with every hidden layer of `baseline`.

    Parameters
    ----------
    df_sent : sentence table with columns `sentence_id` and `tokens`
        (pre-split words).
    subset_df : one row per word to embed, with columns `sentence_id` and
        `word_id` (position of the word inside its sentence).
    baseline : model identifier passed to the tokenizer/model loader.
    word_rep_mode : which sub-token(s) represent a word: "first", "last" or
        "mean". Invalid values fall back to "last" for GPT-2-like models and
        to "first" otherwise ("first" is not accepted for GPT-2-like models).
    batch_size : sentences per forward pass.

    Returns
    -------
    (reps, filled, rep_mode_used).
    reps : float32 array (L, N, D) where L includes the embedding layer
        (layer 0) and N == len(subset_df). Rows are NOT dropped: columns whose
        `filled` flag is False stay zero, so callers must apply `filled`
        themselves.
    filled : bool array (N,), True where a vector was written.
    rep_mode_used : the pooling mode actually applied.

    The input DataFrames are not modified.
    """
    # Work on local copies so the caller's frames are not mutated in place.
    df_sent = df_sent.assign(sentence_id=df_sent["sentence_id"].astype(str))
    subset_df = subset_df.assign(sentence_id=subset_df["sentence_id"].astype(str))

    # sid -> list[(global_idx, word_id)]
    by_sid: Dict[str, List[Tuple[int,int]]] = {}
    for gidx, (sid, wid) in enumerate(subset_df[["sentence_id","word_id"]].itertuples(index=False)):
        by_sid.setdefault(str(sid), []).append((gidx, int(wid)))

    sids = list(by_sid.keys())
    df_sel = (df_sent[df_sent.sentence_id.isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    # Robust tokenizer/model load (fixes GPT-2 "NoneType path" crashes)
    tokzr, model, model_id_used = _load_tokenizer_and_model(baseline)

    # Pre-tokenized input with word mapping
    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True)
    # If tokenizer supports add_prefix_space, supply it
    enc_kwargs.update(getattr(tokzr, "_safe_call_kwargs", {}))

    # GPT‑2 specifics: ensure pad token is set
    if tokzr.pad_token is None and getattr(tokzr, "eos_token", None) is not None:
        tokzr.pad_token = tokzr.eos_token
    if getattr(model.config, "pad_token_id", None) is None and tokzr.pad_token_id is not None:
        model.config.pad_token_id = tokzr.pad_token_id

    if device == "cuda": model.half()
    model = model.eval().to(device)

    L = _num_hidden_layers(model) + 1   # include embeddings
    D = _hidden_size(model)
    N = len(subset_df)

    reps   = np.zeros((L, N, D), np.float32)  # float32 for stable SVD/stats
    filled = np.zeros(N, dtype=bool)

    # Choose/validate rep mode depending on model family
    gpt_like = _is_gpt_like(model)
    if gpt_like:
        rep_mode = word_rep_mode if word_rep_mode in {"last","mean"} else "last"
    else:
        rep_mode = word_rep_mode if word_rep_mode in {"first","mean","last"} else "first"

    # torch.cuda.amp.autocast(...) is deprecated; torch.autocast with an
    # explicit device type is the replacement and is a no-op when disabled.
    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=(device == "cuda")):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{model_id_used} (embed subset)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # Map word_id -> token positions for this item
                word_map = {}
                wids = enc_be.word_ids(b)  # fast tokenizer needed
                if wids is None:
                    raise RuntimeError("Fast tokenizer required (word_ids unavailable).")
                for tidx, wid in enumerate(wids):
                    if wid is not None:
                        word_map.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = word_map.get(wid)
                    # Word produced no sub-tokens: leave its column unfilled.
                    if not toks: continue
                    if rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    else:  # "mean"
                        vec = h[:, b, toks, :].mean(axis=1)
                    reps[:, gidx, :] = vec.astype(np.float32, copy=False)
                    filled[gidx] = True

            del enc_be, enc_t, out, h
            if device == "cuda": torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing: print(f"⚠ Missing vectors for {missing} of {N} sampled words")
    del model; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    return reps, filled, rep_mode

# =============================== BOOTSTRAP CORE ===============================
def _bs_layer_loop(rep_sub: np.ndarray, M: int, n_reps: int, compute_once: Callable[[np.ndarray], float]):
    """
    Bootstrap a scalar metric per layer.

    For each of `n_reps` replicates, M row indices are drawn with replacement
    (generator seeded with RAND_SEED) and reused for every layer;
    `compute_once` is applied to the resampled (M, D) matrix of each layer.
    A failing evaluation is recorded as NaN.

    Returns (mean, 2.5th percentile, 97.5th percentile), each a float32 array
    of length L computed with NaN-ignoring reductions over replicates.
    """
    n_layers, n_items, _ = rep_sub.shape
    rng = np.random.default_rng(RAND_SEED)
    draws = np.full((n_reps, n_layers), np.nan, np.float32)

    for rep in range(n_reps):
        picks = rng.integers(0, n_items, size=M)
        for layer in range(n_layers):
            sample = rep_sub[layer, picks].astype(np.float32, copy=False)
            try:
                draws[rep, layer] = float(compute_once(sample))
            except Exception:
                draws[rep, layer] = np.nan

    mean = np.nanmean(draws, axis=0).astype(np.float32)
    lower = np.nanpercentile(draws, 2.5, axis=0).astype(np.float32)
    upper = np.nanpercentile(draws, 97.5, axis=0).astype(np.float32)
    return mean, lower, upper

# ---- Metric registries ----
# Metrics bootstrapped with the "fast" budget (N_BOOTSTRAP_FAST replicates,
# FAST_BS_MAX_SAMP_PER_CLASS samples per replicate).
FAST_ONCE: Dict[str, Callable[[np.ndarray], float]] = {
    "iso": _iso_once,     # IsoScore
    "pca99": _pca99_once  # lPCA @ 0.99 explained variance
}

# Metrics bootstrapped with the "heavy" budget (N_BOOTSTRAP_HEAVY replicates,
# HEAVY_BS_MAX_SAMP_PER_CLASS samples per replicate).
HEAVY_ONCE: Dict[str, Callable[[np.ndarray], float] | None] = {
    "gride": _dadapy_gride_once
}

# Display names for axis labels / titles; metrics without an entry
# (e.g. "iso") are shown as their upper-cased key.
LABELS = {
    "pca99":"lPCA 0.99",
    "gride":"GRIDE"
}

# Metrics actually computed by the driver.
# Keep runtime reasonable (add more if you want)
ALL_METRICS = ["pca99"]

# =============================== SAVE / PLOT ===============================
def save_metric_csv_all_classes(metric: str,
                                class_to_stats: Dict[str, Dict[str, np.ndarray]],
                                layers: np.ndarray,
                                baseline: str,
                                subset_name: str = "raw",
                                rep_mode_used: str = ""):
    """
    Write one long-format CSV for `metric`: one row per (class, layer) with the
    bootstrap mean, confidence bounds, token count and run metadata.

    Non-finite means and bounds (or bounds that are not arrays) are stored as
    NaN. The file is written to CSV_DIR as
    ``arity_<subset_name>_<metric>_<baseline>.csv`` without the index.
    """
    def _bound_at(bound, position):
        # A bound is usable only if it is an array with a finite entry here.
        if isinstance(bound, np.ndarray) and np.isfinite(bound[position]):
            return float(bound[position])
        return np.nan

    records = []
    for cls, stats in class_to_stats.items():
        means = stats["mean"]
        lows = stats.get("lo")
        highs = stats.get("hi")
        for position, value in enumerate(means):
            records.append({
                "subset": subset_name,
                "model": baseline,
                "feature": "arity",
                "class": cls,
                "metric": metric,
                "layer": int(layers[position]),
                "mean": float(value) if np.isfinite(value) else np.nan,
                "ci_low": _bound_at(lows, position),
                "ci_high": _bound_at(highs, position),
                "n_tokens": int(stats.get("n", 0)),
                "word_rep_mode": rep_mode_used or WORD_REP_MODE,
                "source_csv": Path(CSV_PATH).name,
            })

    target = CSV_DIR / f"arity_{subset_name}_{metric}_{baseline}.csv"
    pd.DataFrame(records).to_csv(target, index=False)

def plot_metric_with_ci(class_to_stats: Dict[str, Dict[str, np.ndarray]],
                        layers: np.ndarray, metric: str, title: str, out_path: Path,
                        palette: Dict[str, Tuple[float, float, float]] | None = None):
    """
    Draw one mean curve per class over `layers`, shade its confidence band
    when both bounds are arrays and the lower bound is not all-NaN, and save
    the figure to `out_path`. Classes are drawn in integer order; classes
    whose mean is missing or all-NaN are skipped. The figure is closed after
    saving.
    """
    fig, ax = plt.subplots(figsize=(9, 5))
    use_palette = isinstance(palette, dict)

    for cls in sorted(class_to_stats.keys(), key=int):
        stats = class_to_stats[cls]
        mean = stats["mean"]
        lower = stats.get("lo")
        upper = stats.get("hi")
        if mean is None or np.all(np.isnan(mean)):
            continue
        color = palette.get(cls) if use_palette else None
        ax.plot(layers, mean, label=cls, lw=1.8, color=color)
        has_band = (isinstance(lower, np.ndarray)
                    and isinstance(upper, np.ndarray)
                    and not np.all(np.isnan(lower)))
        if has_band:
            ax.fill_between(layers, lower, upper, alpha=0.15, color=color)

    ax.set_xlabel("Layer")
    ax.set_ylabel(LABELS.get(metric, metric.upper()))
    ax.set_title(title)
    ax.legend(ncol=3, fontsize="small", title="Arity (4 = 4+)", frameon=False)
    fig.tight_layout()
    fig.savefig(out_path, dpi=220)
    plt.close(fig)

# =============================== DRIVER ===============================
def run_arity_pipeline():
    """
    End-to-end driver: load the arity table, sample per class, embed once, then
    for every metric in ALL_METRICS bootstrap a per-layer estimate for each
    arity class and write one CSV and one plot per metric.

    Tokens for which no embedding was produced are removed from BOTH the token
    table and the embedding array, so class indices derived from the table
    always address the matching embedding rows.

    Raises
    ------
    RuntimeError if the token table and embedding array cannot be aligned.
    """
    # 1) Load tokens + arity classes (with first-word exclusion)
    df_sent, ar_df = load_arity_df(CSV_PATH)
    classes = sorted(ar_df.arity_class.unique(), key=lambda s: int(s))
    palette = make_class_palette(classes)
    print(f"✓ corpus ready — {len(ar_df):,} tokens across arity classes {classes}")

    # 2) Optional per-class cap (cap arity "0" to 30k via PER_CLASS_CAPS)
    raw_df = sample_raw(ar_df, RAW_MAX_PER_CLASS, PER_CLASS_CAPS)
    print("Sample sizes per arity (raw cap):")
    print(raw_df.arity_class.value_counts().sort_index().to_dict())

    # 3) Embed once (encoder or decoder; rep mode auto-validated)
    reps, filled, rep_mode_used = embed_subset(df_sent, raw_df, BASELINE, WORD_REP_MODE, BATCH_SIZE)
    filled = np.asarray(filled, dtype=bool)
    raw_df = raw_df.reset_index(drop=True).loc[filled].reset_index(drop=True)
    # Keep reps aligned with the filtered table. The shape guard makes this a
    # no-op if the embedding step already dropped the unfilled columns.
    if reps.shape[1] == filled.shape[0] and not filled.all():
        reps = reps[:, filled]
    if reps.shape[1] != len(raw_df):
        raise RuntimeError(
            f"Embedding/table misalignment: {reps.shape[1]} vectors vs {len(raw_df)} tokens"
        )
    cls_arr = raw_df.arity_class.values
    L = reps.shape[0]; layers = np.arange(L)
    print(f"✓ embedded {len(raw_df):,} tokens  • layers={L}  • rep_mode={rep_mode_used}")

    # 4) Metric loop
    for metric in ALL_METRICS:
        print(f"\n→ Computing metric: {metric} …")
        compute_once = FAST_ONCE.get(metric) or HEAVY_ONCE.get(metric)
        if compute_once is None:
            print(f"  (skipping {metric}: estimator unavailable)")
            continue

        # choose bootstrap budget
        n_bs = N_BOOTSTRAP_FAST if metric in FAST_ONCE else N_BOOTSTRAP_HEAVY
        Mcap = FAST_BS_MAX_SAMP_PER_CLASS if metric in FAST_ONCE else HEAVY_BS_MAX_SAMP_PER_CLASS

        class_results: Dict[str, Dict[str, np.ndarray]] = {}
        for c in classes:
            idx = np.where(cls_arr == c)[0]
            # Too few tokens to bootstrap a meaningful estimate.
            if idx.size < 3:
                continue
            sub = reps[:, idx]  # (L, n_c, D)
            Nc = sub.shape[1]
            M = min(Mcap, Nc)
            mu, lo, hi = _bs_layer_loop(sub, M, n_bs, compute_once)
            class_results[c] = {"mean": mu, "lo": lo, "hi": hi, "n": int(Nc)}

        save_metric_csv_all_classes(metric, class_results, layers, BASELINE,
                                    subset_name="raw", rep_mode_used=rep_mode_used)
        plot_metric_with_ci(class_results, layers, metric,
                            title=f"{LABELS.get(metric, metric.upper())} • {BASELINE} • rep={rep_mode_used}",
                            out_path=PLOT_DIR / f"arity_raw_{metric}_{BASELINE}_rep-{rep_mode_used}.png",
                            palette=palette)
        print(f"  ✓ saved: CSV= {CSV_DIR}/arity_raw_{metric}_{BASELINE}.csv  "
              f"plot= {PLOT_DIR}/arity_raw_{metric}_{BASELINE}_rep-{rep_mode_used}.png")

        del class_results; gc.collect()
        if device == "cuda": torch.cuda.empty_cache()

    del reps; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    print("\n✓ done (incremental outputs produced per metric).")

# Run the bootstrap pipeline when this cell/module is executed as the main program.
if __name__ == "__main__":
    run_arity_pipeline()


✓ corpus ready — 184,849 tokens across arity classes ['0', '1', '2', '3', '4']
Sample sizes per arity (raw cap):
{'0': 30000, '1': 19369, '2': 17546, '3': 13383, '4': 13911}


2025-11-12 11:48:07.472787: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
openai-community/gpt2 (embed subset): 100%|█| 10067/10067 [01:30<00:00, 110.65it


✓ embedded 94,209 tokens  • layers=13  • rep_mode=last

→ Computing metric: pca99 …
  ✓ saved: CSV= tables_ARITY_no_index/arity_bootstrap/arity_raw_pca99_gpt2.csv  plot= results_ARITY_no_index/arity_raw_pca99_gpt2_rep-last.png

✓ done (incremental outputs produced per metric).


## Fine Grained

In [11]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple, Callable

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel

# =========================== Optional deps ===========================
HAS_DADAPY = False
try:
    from dadapy import Data  # DADApy ID estimators (TwoNN, GRIDE)
    HAS_DADAPY = True
except Exception:
    # Best-effort optional dependency: HAS_DADAPY stays False on any failure.
    pass

# IsoScore: library if available, else a monotone fallback.
# The module is imported as `IsoScore` elsewhere in this notebook; the
# lower-case spelling alone fails on case-sensitive filesystems and would
# silently select the fallback, so both spellings are tried.
# NOTE(review): confirm the row/column orientation the library expects for X.
try:
    try:
        from IsoScore import IsoScore
    except ImportError:
        from isoscore import IsoScore
    _HAS_ISOSCORE = True
except Exception:
    _HAS_ISOSCORE = False
    class _IsoScoreFallback:
        """Stand-in exposing the same `IsoScore.IsoScore(X)` call shape."""

        @staticmethod
        def IsoScore(X: np.ndarray) -> float:
            """Return mean/max covariance eigenvalue of X (rows = samples), clipped to [0, 1]."""
            C = np.cov(X.T, ddof=0)
            ev = np.linalg.eigvalsh(C)
            if ev.mean() <= 0 or ev[-1] <= 0:
                return 0.0
            # mean / max eigenvalue in [0,1] (↑ ~ more isotropic)
            return float(np.clip(ev.mean() / ev[-1], 0.0, 1.0))
    IsoScore = _IsoScoreFallback()

# =============================== CONFIG ===============================
# Input CSV with one row per sentence.
CSV_PATH   = "en_ewt-ud-train_sentences.csv"

# Model to embed with, and which sub-token(s) represent a word.
BASELINE   = "bert-base-uncased"
WORD_REP_MODE = "first"  # sub-token pooling for this baseline; for GPT-2 use "last" or "mean"

# Exclude the first word (index 0) of each sentence from analysis
EXCLUDE_FIRST_WORD = True

POS_COL      = "pos"     # column in your CSV with POS tags
KEEP_POS_TAG = "VERB"    # only keep tokens whose POS == "VERB"


# Default per-class sample cap; 1e12 is effectively "no cap"
RAW_MAX_PER_CLASS = int(1e12)

# Cap overrides: cap only arity "0" to 30_000; others use RAW_MAX_PER_CLASS
PER_CLASS_CAPS: Dict[str, int] = {"0": 30_000}

# Bootstrap replicates (tune down if slow)
N_BOOTSTRAP_FAST   = 50
N_BOOTSTRAP_HEAVY  = 20

# Per-replicate sample size (M = min(cap, N_class))
FAST_BS_MAX_SAMP_PER_CLASS  = int(1e12)
HEAVY_BS_MAX_SAMP_PER_CLASS = 5000

# Seed shared by the python/numpy/torch RNGs seeded below
RAND_SEED=42
# Output locations; the directories are created as a side effect of running this cell.
# NOTE(review): same directories as the previous cell's run — confirm that
# overwriting those outputs is intended.
PLOT_DIR     = Path("results_ARITY_no_index"); PLOT_DIR.mkdir(exist_ok=True, parents=True)
CSV_DIR      = Path("tables_ARITY_no_index") / "arity_bootstrap"; CSV_DIR.mkdir(exist_ok=True, parents=True)

# Throughput: sentences per forward pass (raise if you have more GPU memory)
BATCH_SIZE = 1

# Reproducibility & device
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): cudnn.benchmark trades determinism for speed in some setups —
# confirm this is acceptable alongside the seeding above.
if device == "cuda": torch.backends.cudnn.benchmark = True

# Seaborn style
sns.set_style("darkgrid")
plt.rcParams["figure.dpi"] = 120

# Small constant for numerically safe divisions / logs
EPS = 1e-12

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _center(X: np.ndarray) -> np.ndarray:
    return X - X.mean(0, keepdims=True)

def _eigvals_from_X(X: np.ndarray) -> np.ndarray:
    """Squared singular values of the centered data, largest first.

    These equal the covariance eigenvalues up to a constant factor. The input
    is cast to float32 before centering; the result is float64. If the SVD
    raises, an empty float64 array is returned instead.
    """
    centered = _center(X.astype(np.float32, copy=False))
    try:
        singular_values = np.linalg.svd(centered, full_matrices=False)[1]
        spectrum = np.square(singular_values).astype(np.float64)
        # Sort ascending, then flip so the largest eigenvalue comes first.
        return np.sort(spectrum)[::-1]
    except Exception:
        return np.array([], dtype=np.float64)

def _jitter_unique(X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Add tiny noise if there are duplicate rows (helps NN-based estimators)."""
    try:
        if np.unique(X, axis=0).shape[0] < X.shape[0]:
            X = X + np.random.normal(scale=eps, size=X.shape).astype(X.dtype)
    except Exception:
        pass
    return X

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)
    if n is None: raise ValueError("Cannot determine num_hidden_layers from model.config")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)
    if d is None: raise ValueError("Cannot determine hidden_size from model.config")
    return int(d)

def _is_gpt_like(model) -> bool:
    mt = str(getattr(model.config, "model_type", "")).lower()
    name = str(getattr(getattr(model, "name_or_path", ""), "lower", lambda: "")())
    return ("gpt2" in mt) or ("gpt2" in name)

def _pick_arity_col(df: pd.DataFrame) -> str:
    for cand in ["arity", "ariety", "ARITY", "Arity"]:
        if cand in df.columns: return cand
    raise ValueError("No arity column found. Expected one of: arity, ariety, ARITY, Arity")

# -------- Safe tokenizer+model loader (fixes GPT-2 NoneType path error) --------
def _safe_load_tok_and_model(baseline: str):
    """
    Robust tokenizer+model loader for GPT-2 and others.
    - Patches os.path.isfile during tokenizer load to avoid NoneType path errors.
    - Falls back between 'gpt2' and 'openai-community/gpt2'.
    Returns: (tokzr, model, model_id_used)
    """
    import os as _os

    bl = baseline.lower()
    if "gpt2" in bl:
        candidates = []
        for mid in (baseline, "openai-community/gpt2", "gpt2"):
            if mid not in candidates:
                candidates.append(mid)
    else:
        candidates = [baseline]

    last_err = None
    for mid in candidates:
        try:
            # Patch during tokenizer load only
            _orig_isfile = _os.path.isfile
            _os.path.isfile = (lambda p: False if p is None else _orig_isfile(p))
            try:
                tokzr = AutoTokenizer.from_pretrained(mid, use_fast=True, add_prefix_space=True)
            finally:
                _os.path.isfile = _orig_isfile

            # Right padding + pad token (for GPT-2)
            if getattr(tokzr, "padding_side", None) != "right":
                tokzr.padding_side = "right"
            if tokzr.pad_token is None and getattr(tokzr, "eos_token", None) is not None:
                tokzr.pad_token = tokzr.eos_token

            # Pass add_prefix_space=True at call time if supported
            tok_call_kwargs = {}
            try:
                if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
                    tok_call_kwargs["add_prefix_space"] = True
            except Exception:
                pass
            tokzr._safe_call_kwargs = tok_call_kwargs

            model = AutoModel.from_pretrained(mid, output_hidden_states=True)
            if getattr(model.config, "pad_token_id", None) is None and tokzr.pad_token_id is not None:
                model.config.pad_token_id = tokzr.pad_token_id

            return tokzr, model, mid
        except Exception as e:
            last_err = e
            continue

    raise RuntimeError(f"Failed to load tokenizer/model for {baseline} "
                       f"(tried {candidates}). Last error: {last_err}")

# ========= Per-subsample single-value compute functions (used inside bootstrap) =========
def _iso_once(X: np.ndarray) -> float:
    """Score one sample matrix with ``IsoScore.IsoScore`` and return a plain float."""
    score = IsoScore.IsoScore(X)
    return float(score)

def _pcaXX_once(X: np.ndarray, var_ratio: float) -> float:
    """Number of leading components whose cumulative spectrum reaches ``var_ratio``.

    Returns NaN when no eigenvalues are available.
    """
    eigvals = _eigvals_from_X(X)
    if eigvals.size == 0:
        return np.nan
    cumulative = np.cumsum(eigvals)
    target = cumulative[-1] * var_ratio
    # searchsorted gives the 0-based index of the first entry >= target;
    # +1 converts that index into a component count.
    n_components = np.searchsorted(cumulative, target) + 1
    return float(n_components)

def _pca99_once(X: np.ndarray) -> float:
    """Component count at the 0.99 cumulative-variance threshold."""
    return _pcaXX_once(X, var_ratio=0.99)

def _erank_once(X: np.ndarray) -> float:
    """Effective rank: exp of the Shannon entropy of the normalized spectrum.

    Returns NaN when no eigenvalues are available. EPS guards both the
    normalization and the logarithm.
    """
    eigvals = _eigvals_from_X(X)
    if eigvals.size == 0:
        return np.nan
    weights = eigvals / (eigvals.sum() + EPS)
    entropy = -(weights * np.log(weights + EPS)).sum()
    return float(np.exp(entropy))

def _pr_once(X: np.ndarray) -> float:
    """Participation ratio: (sum of eigenvalues)^2 / (sum of squared eigenvalues + EPS).

    Returns NaN when no eigenvalues are available.
    """
    eigvals = _eigvals_from_X(X)
    if eigvals.size == 0:
        return np.nan
    total = eigvals.sum()
    total_of_squares = (eigvals ** 2).sum()
    ratio = (total ** 2) / (total_of_squares + EPS)
    return float(ratio)

def _stable_rank_once(X: np.ndarray) -> float:
    """Sum of eigenvalues divided by (largest eigenvalue + EPS).

    Returns NaN when no eigenvalues are available.
    """
    eigvals = _eigvals_from_X(X)
    if eigvals.size == 0:
        return np.nan
    largest = eigvals.max()
    ratio = eigvals.sum() / (largest + EPS)
    return float(ratio)

def _dadapy_twonn_once(X: np.ndarray) -> float:
    """DADApy TwoNN intrinsic-dimension estimate; NaN when DADApy is missing.

    Duplicate rows are jittered before the data is handed to DADApy.
    """
    if HAS_DADAPY:
        dataset = Data(coordinates=_jitter_unique(X))
        # compute_id_2NN returns a 3-tuple; only the first entry is used here.
        estimate, _, _ = dataset.compute_id_2NN()
        return float(estimate)
    return np.nan

def _dadapy_gride_once(X: np.ndarray) -> float:
    """DADApy GRIDE estimate taken at the last scale; NaN when DADApy is missing.

    Distances are computed for up to 64 neighbours and the scaling analysis
    runs up to range_max=64; the final entry of the returned ids is reported.
    Duplicate rows are jittered before the data is handed to DADApy.
    """
    if HAS_DADAPY:
        dataset = Data(coordinates=_jitter_unique(X))
        dataset.compute_distances(maxk=64)
        id_by_scale, _, _ = dataset.return_id_scaling_gride(range_max=64)
        return float(id_by_scale[-1])
    return np.nan

# =============================== DATA (arity already in CSV) ===============================
def load_arity_df(csv_path: str, max_arity_class: int = 6):
    """
    Read sentence-level rows with list columns (tokens, <arity>, POS) and expand
    them to one row per kept token.

    Each kept token gets an arity class: its arity parsed as an int and clamped
    to the range [0, max_arity_class], stored as a string. With the default the
    classes are "0".."6", where "6" means 6 or more. An arity value that cannot
    be parsed as an int is treated as 0.

    Filtering applied while expanding:
    - the first token (word_id == 0) of each sentence is skipped if
      EXCLUDE_FIRST_WORD is True;
    - only tokens whose POS tag equals KEEP_POS_TAG (e.g. "VERB") are kept.

    Args:
        csv_path: path of the CSV to read; must not be None.
        max_arity_class: highest class label; larger arities fall into it.

    Returns:
        df_sent: sentence-level frame (sentence_id, tokens), one row per sentence_id.
        df_tok:  token-level frame (sentence_id, word_id, arity_class, word, pos).

    Raises:
        TypeError: if csv_path is None.
        ValueError: if the arity column or another required column is missing.
    """
    if csv_path is None:
        raise TypeError("CSV_PATH is None — set CSV_PATH to your dataset filename.")

    # Read full CSV once
    df_all = pd.read_csv(csv_path)

    # Find the correct arity column name ("arity" / "ariety" / etc.)
    ar_col = _pick_arity_col(df_all)

    # We need sentence_id, tokens, arity and POS
    needed_cols = ["sentence_id", "tokens", ar_col, POS_COL]
    for c in needed_cols:
        if c not in df_all.columns:
            raise ValueError(f"Expected column {c!r} in CSV, but it is missing.")
    df = df_all[needed_cols].copy()

    # List-valued cells arrive as strings from the CSV; parse them back.
    df["sentence_id"] = df["sentence_id"].astype(str)
    df["tokens"]      = df["tokens"].apply(_to_list)
    df[ar_col]        = df[ar_col].apply(_to_list)
    df[POS_COL]       = df[POS_COL].apply(_to_list)

    rows = []
    for sid, toks, A, poss in df[["sentence_id", "tokens", ar_col, POS_COL]].itertuples(index=False):
        # Be safe if the three lists differ in length: use the common prefix.
        L = min(len(toks), len(A), len(poss))
        for wid in range(L):
            # ---- EXCLUDE FIRST WORD ----
            if EXCLUDE_FIRST_WORD and wid == 0:
                continue

            pos_tag = poss[wid]
            # ---- POS FILTER: keep only KEEP_POS_TAG ----
            if pos_tag != KEEP_POS_TAG:
                continue

            tok = toks[wid]
            a   = A[wid]

            # Best-effort parse: anything that is not int-like counts as arity 0.
            try:
                ai = int(a)
            except Exception:
                ai = 0

            # Arity class: clamp into 0..max_arity_class (top class = "that many or more").
            cl = str(min(max(ai, 0), max_arity_class))
            rows.append((sid, wid, cl, tok, pos_tag))

    df_tok = pd.DataFrame(
        rows,
        columns=["sentence_id", "word_id", "arity_class", "word", "pos"]
    )

    # Safety: enforce again (should be redundant)
    if EXCLUDE_FIRST_WORD and not df_tok.empty:
        df_tok = df_tok[df_tok.word_id != 0].reset_index(drop=True)

    df_sent = df[["sentence_id", "tokens"]].drop_duplicates("sentence_id")
    return df_sent, df_tok


def sample_raw(df_tok: pd.DataFrame,
               per_class_cap: int = RAW_MAX_PER_CLASS,
               per_class_caps: Dict[str, int] | None = None) -> pd.DataFrame:
    """Down-sample each arity class to at most its cap (no frequency matching).

    Args:
        df_tok: token-level frame with an ``arity_class`` column.
        per_class_cap: cap used for any class without an explicit override.
        per_class_caps: optional ``{class_label: cap}`` overrides, keyed by str(label).

    Returns:
        A new frame with a fresh 0..n-1 index, holding the sampled rows of every
        class in first-appearance order. Sampling is without replacement and
        seeded with RAND_SEED. An input with no rows yields an empty frame with
        the same columns.
    """
    caps = per_class_caps or {}
    picks = []
    for c, sub in df_tok.groupby("arity_class", sort=False):
        cap = caps.get(str(c), per_class_cap)
        n = min(len(sub), cap)
        picks.append(sub.sample(n, random_state=RAND_SEED, replace=False))
    if not picks:
        # pd.concat([]) raises "No objects to concatenate"; an empty token table
        # (e.g. the POS filter matched nothing) should simply stay empty.
        return df_tok.iloc[0:0].reset_index(drop=True)
    return pd.concat(picks, ignore_index=True)

def make_class_palette(classes: List[str]) -> Dict[str, Tuple[float, float, float]]:
    """Assign one colour to each class label.

    Colours are drawn from the tab20, tab20b and tab20c palettes (20 each);
    if that pool is smaller than the number of classes, a "husl" palette of
    the right size is used instead. Labels are ordered numerically, with
    non-digit labels sorted as 99, and colours are handed out in that order.
    """
    color_pool: List[Tuple[float, float, float]] = []
    for palette_name in ("tab20", "tab20b", "tab20c"):
        try:
            color_pool += sns.color_palette(palette_name, 20)
        except Exception:
            continue

    if len(color_pool) < len(classes):
        color_pool = list(sns.color_palette("husl", len(classes)))

    def _numeric_rank(label):
        return int(label) if label.isdigit() else 99

    ordered_labels = sorted(classes, key=_numeric_rank)
    pool_size = len(color_pool)
    return {label: color_pool[pos % pool_size] for pos, label in enumerate(ordered_labels)}

# =============================== TOKENIZER/MODEL (robust) ===============================
def _load_tokenizer_and_model(baseline: str):
    """Backward-compatible alias for ``_safe_load_tok_and_model``.

    Forwards ``baseline`` unchanged and returns whatever the safe loader
    returns, i.e. the tuple (tokenizer, model, model_id_used).
    """
    return _safe_load_tok_and_model(baseline)

# =============================== EMBEDDING (BERT & GPT‑2) ===============================
def embed_subset(df_sent: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray, str]:
    """
    Embed every word listed in ``subset_df`` and collect one vector per layer.

    The tokenizer and model are loaded inside this call and released before
    returning. Sentences are fed pre-split into words; each requested word is
    represented by its first sub-token, its last sub-token, or the mean of all
    its sub-tokens, depending on the rep mode.

    Args:
        df_sent: sentence-level frame with ``sentence_id`` and ``tokens`` columns.
        subset_df: token-level frame with ``sentence_id`` and ``word_id`` columns;
            row order defines the N axis of the result.
        baseline: model id handed to the loader.
        word_rep_mode: requested rep mode. GPT-2-like models accept
            {"last", "mean"} (anything else becomes "last"); other models accept
            {"first", "mean", "last"} (anything else becomes "first").
        batch_size: number of sentences per forward pass.

    Returns:
        (reps, filled, rep_mode_used) where
        reps   is float32 of shape (L, N, D), L including the embedding layer (layer 0);
        filled is a bool mask of length N, False where no vector could be extracted
               (the matching reps column is left at zero);
        rep_mode_used is the rep mode actually applied.

    Neither input frame is modified.
    """
    # Normalise sentence ids to str on private copies; assigning into the
    # caller's frames would mutate them (and can raise SettingWithCopyWarning).
    sent_tbl = df_sent.assign(sentence_id=df_sent["sentence_id"].astype(str))
    subset_keys = subset_df[["sentence_id", "word_id"]].assign(
        sentence_id=subset_df["sentence_id"].astype(str)
    )

    # sid -> list[(global_idx, word_id)]
    by_sid: Dict[str, List[Tuple[int, int]]] = {}
    for gidx, (sid, wid) in enumerate(subset_keys.itertuples(index=False)):
        by_sid.setdefault(str(sid), []).append((gidx, int(wid)))

    sids = list(by_sid.keys())
    df_sel = (sent_tbl[sent_tbl.sentence_id.isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    # Robust tokenizer/model load (fixes GPT-2 "NoneType path" crashes)
    tokzr, model, model_id_used = _load_tokenizer_and_model(baseline)

    # Pre-tokenized input with word mapping. Truncation keeps an over-long
    # sentence within the tokenizer's maximum length instead of crashing the
    # forward pass; words cut off that way are reported as missing.
    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt",
                      padding=True, truncation=True)
    # If the loader recorded extra call-time kwargs, supply them
    enc_kwargs.update(getattr(tokzr, "_safe_call_kwargs", {}))

    # GPT-2 specifics: ensure pad token is set
    if tokzr.pad_token is None and getattr(tokzr, "eos_token", None) is not None:
        tokzr.pad_token = tokzr.eos_token
    if getattr(model.config, "pad_token_id", None) is None and tokzr.pad_token_id is not None:
        model.config.pad_token_id = tokzr.pad_token_id

    if device == "cuda": model.half()
    model = model.eval().to(device)

    L = _num_hidden_layers(model) + 1   # include embeddings
    D = _hidden_size(model)
    N = len(subset_df)

    reps   = np.zeros((L, N, D), np.float32)  # float32 for stable SVD/stats
    filled = np.zeros(N, dtype=bool)

    # Choose/validate rep mode depending on model family
    gpt_like = _is_gpt_like(model)
    if gpt_like:
        rep_mode = word_rep_mode if word_rep_mode in {"last", "mean"} else "last"
    else:
        rep_mode = word_rep_mode if word_rep_mode in {"first", "mean", "last"} else "first"

    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{model_id_used} (embed subset)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # Map word_id -> sub-token positions for this batch item
                word_map = {}
                wids = enc_be.word_ids(b)  # fast tokenizer needed
                if wids is None:
                    raise RuntimeError("Fast tokenizer required (word_ids unavailable).")
                for tidx, wid in enumerate(wids):
                    if wid is not None:  # None marks special/padding positions
                        word_map.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = word_map.get(wid)
                    if not toks: continue  # word has no sub-tokens: leave unfilled
                    if rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    else:  # "mean"
                        vec = h[:, b, toks, :].mean(axis=1)
                    reps[:, gidx, :] = vec.astype(np.float32, copy=False)
                    filled[gidx] = True

            # Release per-batch tensors before the next forward pass
            del enc_be, enc_t, out, h
            if device == "cuda": torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing: print(f"⚠ Missing vectors for {missing} of {N} sampled words")
    del model; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    return reps, filled, rep_mode

# =============================== BOOTSTRAP CORE ===============================
def _bs_layer_loop(rep_sub: np.ndarray, M: int, n_reps: int, compute_once: Callable[[np.ndarray], float]):
    """Bootstrap a scalar metric per layer.

    For each of ``n_reps`` replicates, M row indices are drawn with replacement
    (the same draw is reused for every layer of that replicate) and
    ``compute_once`` is applied to each layer's resampled matrix.

    Args:
        rep_sub: array of shape (L, N, D).
        M: number of rows drawn per replicate.
        n_reps: number of bootstrap replicates.
        compute_once: maps a resampled float32 matrix to a scalar.

    Returns:
        (mu, lo, hi): float32 arrays of length L holding the NaN-aware mean and
        the 2.5th / 97.5th percentiles across replicates.

    A ``compute_once`` call that raises is recorded as NaN. Failures are no
    longer silent: one RuntimeWarning per call reports how many evaluations
    failed and what the first error was.
    """
    import warnings

    L, N, D = rep_sub.shape
    # Fixed seed: every call draws the same index sequence for a given (N, M).
    rng = np.random.default_rng(RAND_SEED)
    A = np.full((n_reps, L), np.nan, np.float32)
    n_failed = 0
    first_err = None
    for r in range(n_reps):
        idx = rng.integers(0, N, size=M)
        for l in range(L):
            X = rep_sub[l, idx].astype(np.float32, copy=False)
            try:
                A[r, l] = float(compute_once(X))
            except Exception as exc:
                A[r, l] = np.nan
                n_failed += 1
                if first_err is None:
                    first_err = exc
    if n_failed:
        warnings.warn(
            f"{n_failed} of {n_reps * L} bootstrap evaluations failed and were set to NaN; "
            f"first error: {type(first_err).__name__}: {first_err}",
            RuntimeWarning,
            stacklevel=2,
        )
    mu = np.nanmean(A, axis=0).astype(np.float32)
    lo = np.nanpercentile(A, 2.5, axis=0).astype(np.float32)
    hi = np.nanpercentile(A, 97.5, axis=0).astype(np.float32)
    return mu, lo, hi

# ---- Metric registries ----
# Cheap metrics: bootstrapped with N_BOOTSTRAP_FAST replicates.
FAST_ONCE: Dict[str, Callable[[np.ndarray], float]] = {
    "iso": _iso_once,     # IsoScore
    "pca99": _pca99_once  # lPCA @ 0.99 explained variance
}

# Expensive metrics: bootstrapped with N_BOOTSTRAP_HEAVY replicates.
# GRIDE needs DADApy. Registering None when it is missing lets the driver's
# "estimator unavailable" branch skip the metric, instead of bootstrapping a
# function that can only return NaN and writing an all-NaN CSV and empty plot.
HEAVY_ONCE: Dict[str, Callable[[np.ndarray], float] | None] = {
    "gride": _dadapy_gride_once if HAS_DADAPY else None
}

# Axis/title labels; metrics without an entry fall back to their upper-cased key.
LABELS = {
    "pca99":"lPCA 0.99",
    "gride":"GRIDE"
}

# Keep runtime reasonable (add more if you want)
ALL_METRICS = ["iso","gride","pca99"]

# =============================== SAVE / PLOT ===============================
def save_metric_csv_all_classes(metric: str,
                                class_to_stats: Dict[str, Dict[str, np.ndarray]],
                                layers: np.ndarray,
                                baseline: str,
                                subset_name: str = "raw",
                                rep_mode_used: str = ""):
    """Write one long-format CSV for a metric: one row per (class, layer).

    Args:
        metric: metric key, stored in the ``metric`` column and the file name.
        class_to_stats: ``{class: {"mean": arr, "lo": arr, "hi": arr, "n": int}}``;
            "lo", "hi" and "n" are optional.
        layers: layer ids, indexed in parallel with the "mean" array.
        baseline: model id, stored verbatim in the ``model`` column.
        subset_name: stored in the ``subset`` column and the file name.
        rep_mode_used: stored in ``word_rep_mode``; falls back to WORD_REP_MODE if empty.

    Returns:
        The path of the CSV written under CSV_DIR.

    Non-finite means and bounds are written as NaN. Path separators in
    ``baseline`` are replaced by "_" in the file name, so hub ids such as
    "openai-community/gpt2" do not point into a sub-directory that does not
    exist.
    """
    columns = ["subset", "model", "feature", "class", "metric", "layer",
               "mean", "ci_low", "ci_high", "n_tokens", "word_rep_mode", "source_csv"]
    rows = []
    for c, stats in class_to_stats.items():
        mu, lo, hi = stats["mean"], stats.get("lo"), stats.get("hi")
        for l, val in enumerate(mu):
            rows.append({
                "subset": subset_name, "model": baseline, "feature": "arity",
                "class": c, "metric": metric, "layer": int(layers[l]),
                "mean": float(val) if np.isfinite(val) else np.nan,
                "ci_low": float(lo[l]) if isinstance(lo, np.ndarray) and np.isfinite(lo[l]) else np.nan,
                "ci_high": float(hi[l]) if isinstance(hi, np.ndarray) and np.isfinite(hi[l]) else np.nan,
                "n_tokens": int(stats.get("n", 0)),
                "word_rep_mode": rep_mode_used or WORD_REP_MODE,
                "source_csv": Path(CSV_PATH).name,
            })
    # Explicit columns keep the header present even when there are no rows.
    df = pd.DataFrame(rows, columns=columns)
    file_tag = str(baseline).replace("/", "_").replace("\\", "_")
    out = CSV_DIR / f"arity_{subset_name}_{metric}_{file_tag}.csv"
    df.to_csv(out, index=False)
    return out

def plot_metric_with_ci(class_to_stats: Dict[str, Dict[str, np.ndarray]],
                        layers: np.ndarray, metric: str, title: str, out_path: Path,
                        palette: Dict[str, Tuple[float, float, float]] | None = None,
                        legend_title: str = "Arity (6 = 6+)"):
    """Plot the per-layer mean of a metric for every class, with a shaded CI band.

    Args:
        class_to_stats: ``{class: {"mean": arr, "lo": arr, "hi": arr}}``; classes
            whose mean is None or all-NaN are skipped, and the band is drawn
            only when both bounds are arrays and "lo" is not all-NaN.
        layers: x-axis values.
        metric: metric key; the y-label is LABELS[metric] or the upper-cased key.
        title: figure title.
        out_path: where the figure is saved (dpi=220).
        palette: optional ``{class: colour}`` mapping.
        legend_title: legend heading. The default matches the arity clamp used
            when the token table is built (classes 0..6, top class = 6 or more).

    The figure is closed after saving. The legend is omitted when nothing was
    drawn.
    """
    def _class_rank(label):
        # Numeric ordering; labels that are not integers sort last.
        try:
            return int(label)
        except (TypeError, ValueError):
            return 99

    fig, ax = plt.subplots(figsize=(9, 5))
    n_drawn = 0
    for c in sorted(class_to_stats.keys(), key=_class_rank):
        stats = class_to_stats[c]
        mu, lo, hi = stats["mean"], stats.get("lo"), stats.get("hi")
        if mu is None or np.all(np.isnan(mu)):
            continue
        color = palette.get(c) if isinstance(palette, dict) else None
        ax.plot(layers, mu, label=c, lw=1.8, color=color)
        n_drawn += 1
        if isinstance(lo, np.ndarray) and isinstance(hi, np.ndarray) and not np.all(np.isnan(lo)):
            ax.fill_between(layers, lo, hi, alpha=0.15, color=color)
    ax.set_xlabel("Layer")
    ax.set_ylabel(LABELS.get(metric, metric.upper()))
    ax.set_title(title)
    if n_drawn:
        # Calling legend() with no labelled artists only triggers a warning.
        ax.legend(ncol=3, fontsize="small", title=legend_title, frameon=False)
    fig.tight_layout()
    fig.savefig(out_path, dpi=220)
    plt.close(fig)

# =============================== DRIVER ===============================
def run_arity_pipeline():
    """Run the whole analysis: load, cap, embed, bootstrap each metric, save.

    Steps:
      1. Build the token table from CSV_PATH.
      2. Apply the per-class caps.
      3. Embed the sampled tokens once with BASELINE.
      4. For every metric in ALL_METRICS, bootstrap it per arity class and layer,
         then write one CSV (under CSV_DIR) and one plot (under PLOT_DIR).

    Classes with fewer than 3 embedded tokens are skipped. Metrics whose
    estimator is registered as None are skipped.
    """
    # 1) Load tokens + arity classes (with first-word exclusion)
    df_sent, ar_df = load_arity_df(CSV_PATH)
    classes = sorted(ar_df.arity_class.unique(), key=lambda s: int(s))
    palette = make_class_palette(classes)
    print(f"✓ corpus ready — {len(ar_df):,} tokens across arity classes {classes}")

    # 2) Optional per-class cap (cap arity "0" to 30k via PER_CLASS_CAPS)
    raw_df = sample_raw(ar_df, RAW_MAX_PER_CLASS, PER_CLASS_CAPS)
    print("Sample sizes per arity (raw cap):")
    print(raw_df.arity_class.value_counts().sort_index().to_dict())

    # 3) Embed once (encoder or decoder; rep mode auto-validated)
    reps, filled, rep_mode_used = embed_subset(df_sent, raw_df, BASELINE, WORD_REP_MODE, BATCH_SIZE)
    raw_df = raw_df.reset_index(drop=True).loc[filled].reset_index(drop=True)
    if not filled.all():
        # Drop the same columns from reps that were just dropped from raw_df.
        # Otherwise the class indices computed below (positions in the filtered
        # raw_df) would select the wrong embedding columns.
        reps = reps[:, filled]
    cls_arr = raw_df.arity_class.values
    L = reps.shape[0]; layers = np.arange(L)
    print(f"✓ embedded {len(raw_df):,} tokens  • layers={L}  • rep_mode={rep_mode_used}")

    # Hub ids may contain "/" (e.g. "openai-community/gpt2"); keep file names flat.
    file_tag = str(BASELINE).replace("/", "_").replace("\\", "_")

    # 4) Metric loop
    for metric in ALL_METRICS:
        print(f"\n→ Computing metric: {metric} …")
        compute_once = FAST_ONCE.get(metric) or HEAVY_ONCE.get(metric)
        if compute_once is None:
            print(f"  (skipping {metric}: estimator unavailable)")
            continue

        # choose bootstrap budget
        n_bs = N_BOOTSTRAP_FAST if metric in FAST_ONCE else N_BOOTSTRAP_HEAVY
        Mcap = FAST_BS_MAX_SAMP_PER_CLASS if metric in FAST_ONCE else HEAVY_BS_MAX_SAMP_PER_CLASS

        class_results: Dict[str, Dict[str, np.ndarray]] = {}
        for c in classes:
            idx = np.where(cls_arr == c)[0]
            if idx.size < 3:
                continue
            sub = reps[:, idx]  # (L, n_c, D)
            Nc = sub.shape[1]
            M = min(Mcap, Nc)
            mu, lo, hi = _bs_layer_loop(sub, M, n_bs, compute_once)
            class_results[c] = {"mean": mu, "lo": lo, "hi": hi, "n": int(Nc)}

        plot_path = PLOT_DIR / f"arity_raw_{metric}_{file_tag}_rep-{rep_mode_used}.png"
        save_metric_csv_all_classes(metric, class_results, layers, BASELINE,
                                    subset_name="raw", rep_mode_used=rep_mode_used)
        plot_metric_with_ci(class_results, layers, metric,
                            title=f"{LABELS.get(metric, metric.upper())} • {BASELINE} • rep={rep_mode_used}",
                            out_path=plot_path,
                            palette=palette)
        print(f"  ✓ saved: CSV= {CSV_DIR}/arity_raw_{metric}_{file_tag}.csv  "
              f"plot= {plot_path}")

        del class_results; gc.collect()
        if device == "cuda": torch.cuda.empty_cache()

    del reps; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    print("\n✓ done (incremental outputs produced per metric).")

# Entry point: run the full pipeline when this code is executed directly
# (not when it is imported as a module).
if __name__ == "__main__":
    run_arity_pipeline()


✓ corpus ready — 21,495 tokens across arity classes ['0', '1', '2', '3', '4', '5', '6']
Sample sizes per arity (raw cap):
{'0': 822, '1': 2011, '2': 5006, '3': 5978, '4': 4281, '5': 2201, '6': 1196}


bert-base-uncased (embed subset): 100%|████| 8607/8607 [01:02<00:00, 137.89it/s]


✓ embedded 21,495 tokens  • layers=13  • rep_mode=first

→ Computing metric: iso …
  ✓ saved: CSV= tables_ARITY_no_index/arity_bootstrap/arity_raw_iso_bert-base-uncased.csv  plot= results_ARITY_no_index/arity_raw_iso_bert-base-uncased_rep-first.png

→ Computing metric: gride …
  ✓ saved: CSV= tables_ARITY_no_index/arity_bootstrap/arity_raw_gride_bert-base-uncased.csv  plot= results_ARITY_no_index/arity_raw_gride_bert-base-uncased_rep-first.png

→ Computing metric: pca99 …
  ✓ saved: CSV= tables_ARITY_no_index/arity_bootstrap/arity_raw_pca99_bert-base-uncased.csv  plot= results_ARITY_no_index/arity_raw_pca99_bert-base-uncased_rep-first.png

✓ done (incremental outputs produced per metric).
