In [1]:
from __future__ import annotations
import os, ast, random, inspect
from pathlib import Path
from typing import Dict, List

import numpy as np, pandas as pd, torch
import torch.utils.data as torchdata
from tqdm.auto import tqdm
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt, seaborn as sns
from transformers import AutoTokenizer, AutoModel, BertModel, AutoConfig

from IsoScore import IsoScore
from dadapy import Data
from skdim.id import MLE, MOM, TLE, CorrInt, FisherS, lPCA

2025-11-14 09:00:00.822766: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple, Callable

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel

# ===== Optional deps (gracefully skipped if not installed) =====
HAS_DADAPY = False
try:
    from dadapy import Data  # DADApy ID estimators (TwoNN, GRIDE)
    HAS_DADAPY = True
except Exception:
    pass

HAS_SKDIM = False
try:
    from skdim.id import MOM, TLE, CorrInt, FisherS, lPCA, MLE, DANCo, ESS, MiND_ML, MADA, KNN
    HAS_SKDIM = True
except Exception:
    pass

# IsoScore: use the library if available, else a simple monotone fallback.
# The package module is imported as ``IsoScore`` (the first cell of this notebook
# imports it that way successfully); the lowercase ``isoscore`` spelling is tried
# second. Trying only the lowercase name silently replaced a working library with
# the fallback on case-sensitive filesystems.
# NOTE(review): the fallback below treats rows of X as points (n, d); confirm the
# library expects the same orientation.
try:
    from IsoScore import IsoScore
    _HAS_ISOSCORE = True
except Exception:
    try:
        from isoscore import IsoScore
        _HAS_ISOSCORE = True
    except Exception:
        _HAS_ISOSCORE = False

        class _IsoScoreFallback:
            """Stand-in exposing ``IsoScore(X)``: mean / max covariance eigenvalue, clipped to [0, 1]."""

            @staticmethod
            def IsoScore(X: np.ndarray) -> float:
                C = np.cov(X.T, ddof=0)
                ev = np.linalg.eigvalsh(C)          # ascending, so ev[-1] is the largest
                if ev.mean() <= 0 or ev[-1] <= 0:   # degenerate (e.g. constant) data
                    return 0.0
                return float(np.clip(ev.mean() / (ev[-1] + 1e-9), 0.0, 1.0))

        IsoScore = _IsoScoreFallback()

# =============================== CONFIG ===============================
# --- data ---------------------------------------------------------------
CSV_PATH = "en_ewt-ud-train_sentences.csv"
REL_COL_HINT = "relation_type"   # existing list[str] column with one UD relation per token
TOP_K_REL = 10                   # keep only the 10 most frequent relations

# --- model --------------------------------------------------------------
BASELINE = "bert-base-uncased"   # set to "gpt2" for GPT-2
WORD_REP_MODE = "first"          # BERT: {"first","last","mean"}; GPT-2: {"last","mean"}

# --- bootstrap (10**12 effectively means "no cap") ------------------------
RAW_MAX_PER_CLASS = 10**12
N_BOOTSTRAP_FAST = 50
N_BOOTSTRAP_HEAVY = 20
FAST_BS_MAX_SAMP_PER_CLASS = 10**12
HEAVY_BS_MAX_SAMP_PER_CLASS = 5000

# --- reproducibility & outputs ------------------------------------------
RAND_SEED = 42
PLOT_DIR = Path("results_REL")
CSV_DIR = Path("tables_REL") / "relation_bootstrap"
PLOT_DIR.mkdir(parents=True, exist_ok=True)
CSV_DIR.mkdir(parents=True, exist_ok=True)

# --- runtime -------------------------------------------------------------
BATCH_SIZE = 1
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED)
np.random.seed(RAND_SEED)
torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # autotuned cuDNN kernels: faster, but may not be bit-for-bit reproducible
    torch.backends.cudnn.benchmark = True

# --- plotting & numerics -------------------------------------------------
sns.set_style("darkgrid")
plt.rcParams["figure.dpi"] = 120
EPS = 1e-12   # floor added inside logs / denominators by the metric helpers

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _center(X: np.ndarray) -> np.ndarray:
    return X - X.mean(0, keepdims=True)

def _eigvals_from_X(X: np.ndarray) -> np.ndarray:
    Xc = _center(X.astype(np.float32, copy=False))
    try:
        _, S, _ = np.linalg.svd(Xc, full_matrices=False)
        lam = (S**2).astype(np.float64)
        lam.sort()
        return lam[::-1]
    except Exception:
        return np.array([], dtype=np.float64)

def _jitter_unique(X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    try:
        if np.unique(X, axis=0).shape[0] < X.shape[0]:
            X = X + np.random.normal(scale=eps, size=X.shape).astype(X.dtype)
    except Exception:
        pass
    return X

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)
    if n is None: raise ValueError("Cannot determine number of hidden layers")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)
    if d is None: raise ValueError("Cannot determine hidden size")
    return int(d)

# ========= Metric single-call functions =========
# --- Isotropy ---
def _iso_once(X: np.ndarray) -> float:
    return float(IsoScore.IsoScore(X))

def _sf_once(X: np.ndarray) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    gm = np.exp(np.mean(np.log(lam + EPS))); am = float(lam.mean() + EPS)
    return float(gm / am)  # higher = flatter = more isotropic

def _rand_once(X: np.ndarray, K: int = 2000) -> float:
    n = X.shape[0]
    if n < 2: return np.nan
    rng = np.random.default_rng(RAND_SEED)
    K_eff = min(K, (n*(n-1))//2)
    i = rng.integers(0, n, size=K_eff)
    j = rng.integers(0, n, size=K_eff)
    same = i == j
    if same.any():
        j[same] = rng.integers(0, n, size=same.sum())
    A, B = X[i], X[j]
    num = np.sum(A*B, axis=1)
    den = (np.linalg.norm(A, axis=1)*np.linalg.norm(B, axis=1) + 1e-9)
    return float(np.mean(np.abs(num/den)))  # higher ≈ more anisotropic

def _vmf_kappa_once(X: np.ndarray) -> float:
    if X.shape[0] < 2: return np.nan
    Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-9)
    R = np.linalg.norm(Xn.mean(axis=0))
    d = Xn.shape[1]
    if R < 1e-9: return 0.0
    return float(max(R * (d - R**2) / (1.0 - R**2 + 1e-9), 0.0))  # higher = more anisotropic

# --- Linear ID ---
def _pcaXX_once(X: np.ndarray, var_ratio: float) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    c = np.cumsum(lam); thr = c[-1] * var_ratio
    return float(np.searchsorted(c, thr) + 1)

def _pca95_once(X: np.ndarray) -> float: return _pcaXX_once(X, 0.95)
def _pca99_once(X: np.ndarray) -> float: return _pcaXX_once(X, 0.99)

def _erank_once(X: np.ndarray) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    p = lam / (lam.sum() + EPS)
    H = -(p * np.log(p + EPS)).sum()
    return float(np.exp(H))

def _pr_once(X: np.ndarray) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    s1 = lam.sum(); s2 = (lam**2).sum()
    return float((s1**2) / (s2 + EPS))

def _stable_rank_once(X: np.ndarray) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    return float(lam.sum() / (lam.max() + EPS))

# --- Non-linear (DADApy) ---
def _dadapy_twonn_once(X: np.ndarray) -> float:
    if not HAS_DADAPY: return np.nan
    d = Data(coordinates=_jitter_unique(X))
    id_est, _, _ = d.compute_id_2NN()
    return float(id_est)

def _dadapy_gride_once(X: np.ndarray) -> float:
    if not HAS_DADAPY: return np.nan
    d = Data(coordinates=_jitter_unique(X))
    d.compute_distances(maxk=64)
    ids, _, _ = d.return_id_scaling_gride(range_max=64)
    return float(ids[-1])

# --- Non-linear (scikit-dimension) ---
def _skdim_factory(name: str):
    if not HAS_SKDIM: return None
    mapping = {
        "mom": MOM, "tle": TLE, "corrint": CorrInt, "fishers": FisherS,
        "lpca": lPCA, "lpca95": lPCA, "lpca99": lPCA,
        "mle": MLE, "danco": DANCo, "mind_ml": MiND_ML, "ess": ESS,
        "mada": MADA, "knn": KNN,
    }
    cls = mapping.get(name)
    if cls is None: return None
    def _builder():
        if name == "lpca":   return cls(ver="FO")
        if name == "lpca95": return cls(ver="ratio", alphaRatio=0.95)
        if name == "lpca99": return cls(ver="ratio", alphaRatio=0.99)
        return cls()
    return _builder

def _skdim_once_builder(name: str) -> Callable[[np.ndarray], float] | None:
    build = _skdim_factory(name)
    if build is None: return None
    def _once(X: np.ndarray) -> float:
        est = build(); est.fit(_jitter_unique(X)); return float(getattr(est, "dimension_", np.nan))
    return _once

# ---- Metric registries (ALL METRICS) ----
# Cheap metrics (isotropy + linear ID): metric key -> X -> float
FAST_ONCE: Dict[str, Callable[[np.ndarray], float]] = dict(
    # Isotropy
    iso=_iso_once,
    sf=_sf_once,
    rand=_rand_once,
    vmf_kappa=_vmf_kappa_once,
    # Linear ID
    erank=_erank_once,
    pr=_pr_once,
    stable_rank=_stable_rank_once,
    pca95=_pca95_once,
    pca99=_pca99_once,
)
# Non-linear ID estimators; skdim entries are None when scikit-dimension is missing.
HEAVY_ONCE: Dict[str, Callable[[np.ndarray], float] | None] = {
    # DADApy (these return NaN themselves when DADApy is missing)
    "twonn": _dadapy_twonn_once,
    "gride": _dadapy_gride_once,
    # scikit-dimension: registry key -> estimator alias understood by _skdim_factory
    **{
        key: _skdim_once_builder(alias)
        for key, alias in (
            ("mom", "mom"), ("tle", "tle"),
            ("corrint", "corrint"), ("fishers", "fishers"),
            ("lpca", "lpca"), ("lpca95_skdim", "lpca95"),
            ("lpca99_skdim", "lpca99"),
            ("mle", "mle"), ("danco", "danco"),
            ("mind_ml", "mind_ml"), ("ess", "ess"),
            ("mada", "mada"), ("knn", "knn"),
        )
    },
}
# Axis labels used in plots (metric key -> label).
LABELS = {
    # Isotropy
    "iso": "IsoScore",
    "sf": "Spectral Flatness",
    "rand": "RandCos |μ| (anisotropy↑)",
    "vmf_kappa": "vMF κ (anisotropy↑)",
    # Linear ID
    "erank": "Effective Rank", "pr": "Participation Ratio", "stable_rank": "Stable Rank",
    "pca95": "lPCA 0.95", "pca99": "lPCA 0.99",
    # Non-linear (DADApy)
    "twonn": "TwoNN ID", "gride": "GRIDE ID",
    # Non-linear (skdim)
    "mom": "MOM", "tle": "TLE", "corrint": "CorrInt", "fishers": "FisherS",
    "lpca": "lPCA FO", "lpca95_skdim": "lPCA 0.95 (skdim)", "lpca99_skdim": "lPCA 0.99 (skdim)",
    "mle": "MLE", "danco": "DANCo", "mind_ml": "MiND-ML", "ess": "ESS", "mada": "MADA", "knn": "KNN",
}
# Metrics the driver actually runs.
# NOTE(review): currently restricted to two heavy metrics; the commented-out line
# selects every available metric.
#ALL_METRICS = list(FAST_ONCE.keys()) + [k for k, v in HEAVY_ONCE.items() if v is not None]
ALL_METRICS = ["gride", "lpca99_skdim"]

# =============================== DATA: use existing relation_type column ===============================
def _pick_relation_col(df: pd.DataFrame) -> str:
    cands = [REL_COL_HINT, "typed_dependency", "relation", "deprel", "ud_rel", "rel", "REL", "Rel"]
    for c in cands:
        if c in df.columns: return c
    raise ValueError(f"No relation column found. Tried: {', '.join(cands)}")

def load_relations_topk_from_column(csv_path: str, top_k: int = TOP_K_REL):
    """
    Load sentence-level token/relation lists and explode them to token rows.

    Expects a CSV with:
      - sentence_id (cast to str)
      - tokens        (list[str], possibly stringified) — one row per sentence
      - a relation column (see _pick_relation_col) with one UD label per token
    Tokens and labels are paired by position; the longer list is truncated.

    Returns:
        (sentences, tokens, top): `sentences` has sentence_id + tokens (first row per id);
        `tokens` has sentence_id, word_id, relation_class, word, restricted to the
        `top_k` most frequent labels; `top` lists those labels, most frequent first.
    """
    raw = pd.read_csv(csv_path)
    rel_col = _pick_relation_col(raw)

    sentences = raw[["sentence_id", "tokens", rel_col]].copy()
    sentences["sentence_id"] = sentences["sentence_id"].astype(str)
    sentences["tokens"] = sentences["tokens"].apply(_to_list)
    sentences[rel_col] = sentences[rel_col].apply(_to_list)

    records = [
        (sid, wid, str(rel), tok)
        for sid, toks, rels in sentences[["sentence_id", "tokens", rel_col]].itertuples(index=False)
        for wid, (tok, rel) in enumerate(zip(toks, rels))
    ]
    token_rows = pd.DataFrame(records, columns=["sentence_id", "word_id", "relation_class", "word"])
    if token_rows.empty:
        raise ValueError("No token rows constructed—check that your relation column contains lists of strings.")

    # keep only the top-k most frequent relation labels
    top = token_rows["relation_class"].value_counts().nlargest(top_k).index.tolist()
    token_rows = token_rows.loc[token_rows["relation_class"].isin(top)].reset_index(drop=True)

    sentence_rows = sentences[["sentence_id", "tokens"]].drop_duplicates("sentence_id")
    return sentence_rows, token_rows, top

def sample_raw(df_tok: pd.DataFrame, per_class_cap: int = RAW_MAX_PER_CLASS) -> pd.DataFrame:
    """Draw up to `per_class_cap` rows per relation_class (seeded, no replacement).

    Groups keep first-appearance order; the result has a fresh 0..N-1 index.
    """
    per_class = [
        group.sample(min(len(group), per_class_cap), random_state=RAND_SEED, replace=False)
        for _label, group in df_tok.groupby("relation_class", sort=False)
    ]
    return pd.concat(per_class, ignore_index=True)

def make_class_palette(classes: List[str]) -> Dict[str, Tuple[float, float, float]]:
    """Assign each class a qualitative colour, in the given (frequency) order.

    Draws from tab20 + tab20b + tab20c (60 colours). With more than 60 classes it
    switches to a husl palette sized to the class count.
    """
    colours = []
    for palette_name in ("tab20", "tab20b", "tab20c"):
        colours.extend(sns.color_palette(palette_name, 20))
    if len(colours) < len(classes):
        colours = list(sns.color_palette("husl", len(classes)))
    return {cls: colours[pos % len(colours)] for pos, cls in enumerate(classes)}

# =============================== EMBEDDING (BERT & GPT‑2) ===============================
def embed_subset(df_sent: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray]:
    """Collect per-layer word vectors from `baseline` for every row of `subset_df`.

    Args:
        df_sent: sentences with `sentence_id` and `tokens` (list of words).
        subset_df: token rows with `sentence_id` and `word_id` (index into `tokens`);
            row position g of `subset_df` maps to position g of the outputs.
        baseline: Hugging Face model id (BERT- or GPT-2-style).
        word_rep_mode: how a word's sub-word vectors are pooled: "first", "last" or "mean".
        batch_size: sentences per forward pass.

    Returns:
        reps: float16 array (num_hidden_layers + 1, len(subset_df), hidden_size);
            index 0 on the first axis is the embedding-layer output. Rows that
            could not be aligned to any sub-word stay all-zero.
        filled: bool array (len(subset_df),), True where a vector was written.

    Raises:
        ValueError: unsupported `word_rep_mode` (checked before loading the model),
            or layer count / hidden size missing from the model config.

    The caller's DataFrames are not modified.
    """
    if word_rep_mode not in ("first", "last", "mean"):
        raise ValueError("WORD_REP_MODE must be one of {'first','last','mean'} (for GPT-2 use 'last' or 'mean').")

    # Compare sentence ids as strings on local Series (previously the caller's
    # columns were overwritten in place).
    sub_sids = subset_df["sentence_id"].astype(str)
    sent_sids = df_sent["sentence_id"].astype(str)

    # sid -> list[(global_idx, word_id)]; global_idx = row position in subset_df
    by_sid: Dict[str, List[Tuple[int, int]]] = {}
    for gidx, (sid, wid) in enumerate(zip(sub_sids, subset_df["word_id"])):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))

    sids = list(by_sid.keys())
    df_sel = (df_sent.assign(sentence_id=sent_sids)
              .loc[lambda d: d["sentence_id"].isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    tokzr = AutoTokenizer.from_pretrained(baseline, use_fast=True)
    try:
        tokzr([["probe"]], is_split_into_words=True)
    except Exception:
        # Byte-level BPE tokenizers (GPT-2 and similar) reject pre-split words unless
        # add_prefix_space=True was set when the tokenizer was constructed; the
        # signature check below does not detect this case.
        tokzr = AutoTokenizer.from_pretrained(baseline, use_fast=True, add_prefix_space=True)
    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True)
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True
    if tokzr.pad_token is None and getattr(tokzr, "eos_token", None) is not None:
        tokzr.pad_token = tokzr.eos_token   # GPT-2 has no pad token; padding=True needs one

    model = AutoModel.from_pretrained(baseline, output_hidden_states=True).eval().to(device)
    if getattr(model.config, "pad_token_id", None) is None and tokzr.pad_token_id is not None:
        model.config.pad_token_id = tokzr.pad_token_id
    if device == "cuda": model.half()

    L = _num_hidden_layers(model) + 1   # include embedding layer
    D = _hidden_size(model)
    N = len(subset_df)

    reps   = np.zeros((L, N, D), np.float16)
    filled = np.zeros(N, dtype=bool)

    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{baseline} (embed subset)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # word index -> positions of its sub-word tokens in sequence b
                mp: Dict[int, List[int]] = {}
                for tidx, wid in enumerate(enc_be.word_ids(b)):
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid)
                    if not toks: continue   # word produced no sub-word tokens; stays unfilled
                    if word_rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif word_rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    else:  # "mean" (validated above)
                        vec = h[:, b, toks, :].mean(axis=1)
                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    filled[gidx] = True

            del enc_be, enc_t, out, h
            if device == "cuda": torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing: print(f"⚠ Missing vectors for {missing} of {N} tokens")
    del model; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    return reps, filled

# =============================== BOOTSTRAP CORE ===============================
def _bs_layer_loop(rep_sub: np.ndarray, M: int, n_reps: int, compute_once: Callable[[np.ndarray], float]):
    L, N, D = rep_sub.shape
    rng = np.random.default_rng(RAND_SEED)
    A = np.full((n_reps, L), np.nan, np.float32)
    for r in range(n_reps):
        idx = rng.integers(0, N, size=M)
        for l in range(L):
            X = rep_sub[l, idx].astype(np.float32, copy=False)
            try:
                A[r, l] = float(compute_once(X))
            except Exception:
                A[r, l] = np.nan
    mu = np.nanmean(A, axis=0).astype(np.float32)
    lo = np.nanpercentile(A, 2.5, axis=0).astype(np.float32)
    hi = np.nanpercentile(A, 97.5, axis=0).astype(np.float32)
    return mu, lo, hi

# =============================== SAVE / PLOT ===============================
def save_metric_csv_all_classes(metric: str,
                                class_to_stats: Dict[str, Dict[str, np.ndarray]],
                                layers: np.ndarray,
                                baseline: str,
                                subset_name: str = "raw"):
    """Write one metric's per-class, per-layer bootstrap summary as a long-format CSV.

    The file is CSV_DIR / f"relation_{subset_name}_{metric}_{baseline}.csv", with
    one row per (class, layer). Each `class_to_stats[c]` must contain "mean" and
    may contain "lo"/"hi" (per-layer arrays) and "n" (token count). Non-finite
    values, or a missing/non-array "lo"/"hi", are written as NaN.
    """
    def _finite_at(arr, pos):
        # optional CI arrays: anything but a finite array entry becomes NaN
        if isinstance(arr, np.ndarray) and np.isfinite(arr[pos]):
            return float(arr[pos])
        return np.nan

    records = []
    for cls_name, stats in class_to_stats.items():
        means, low, high = stats["mean"], stats.get("lo"), stats.get("hi")
        for pos, value in enumerate(means):
            records.append({
                "subset": subset_name,
                "model": baseline,
                "feature": "relation_type",
                "class": cls_name,
                "metric": metric,
                "layer": int(layers[pos]),
                "mean": float(value) if np.isfinite(value) else np.nan,
                "ci_low": _finite_at(low, pos),
                "ci_high": _finite_at(high, pos),
                "n_tokens": int(stats.get("n", 0)),
                "word_rep_mode": WORD_REP_MODE,
                "source_csv": Path(CSV_PATH).name,
            })
    out_path = CSV_DIR / f"relation_{subset_name}_{metric}_{baseline}.csv"
    pd.DataFrame(records).to_csv(out_path, index=False)

def plot_metric_with_ci(class_to_stats: Dict[str, Dict[str, np.ndarray]],
                        layers: np.ndarray, metric: str, title: str, out_path: Path,
                        palette: Dict[str, Tuple[float, float, float]] | None = None,
                        classes_order: List[str] | None = None):
    """Plot per-class mean curves across layers with shaded CIs and save to `out_path`.

    Args:
        class_to_stats: class -> {"mean": per-layer array, optional "lo"/"hi"}.
        layers: x-axis values (one per layer).
        metric: key into LABELS for the y-axis label.
        title: figure title.
        out_path: destination image path.
        palette: optional class -> colour mapping.
        classes_order: plotting/legend order; defaults to sorted class names.

    Classes that are missing, or whose mean is None/all-NaN, are skipped. The
    figure is always closed, even if saving fails.
    """
    order = classes_order if classes_order is not None else sorted(class_to_stats.keys())
    fig, ax = plt.subplots(figsize=(10.5, 5.5))
    try:
        for c in order:
            stats = class_to_stats.get(c)
            if not stats: continue
            mu, lo, hi = stats["mean"], stats.get("lo"), stats.get("hi")
            if mu is None or np.all(np.isnan(mu)): continue
            color = palette.get(c) if isinstance(palette, dict) else None
            ax.plot(layers, mu, label=c, lw=1.8, color=color)
            if isinstance(lo, np.ndarray) and isinstance(hi, np.ndarray) and not np.all(np.isnan(lo)):
                ax.fill_between(layers, lo, hi, alpha=0.15, color=color)
        ax.set_xlabel("Layer"); ax.set_ylabel(LABELS.get(metric, metric.upper())); ax.set_title(title)
        ncol = 5 if len(order) >= 10 else 3
        # legend title reflects the actual number of relations instead of a fixed "10"
        ax.legend(ncol=ncol, fontsize="small", title=f"UD relation (top‑{len(order)})", frameon=False)
        fig.tight_layout()
        fig.savefig(out_path, dpi=220)
    finally:
        plt.close(fig)

# =============================== DRIVER ===============================
def run_relation_topk_pipeline():
    """Run the relation-type analysis end to end.

    Steps:
    1. Load sentences and keep tokens of the top-K UD relations.
    2. Optionally cap the tokens per relation.
    3. Embed every token once with BASELINE.
    4. For each metric in ALL_METRICS, bootstrap it per relation and per layer,
       writing one CSV (CSV_DIR) and one plot (PLOT_DIR).

    Metrics whose estimator is unavailable are skipped with a message.
    """
    # 1) Load token lists + top-K relations from existing column
    df_sent, rel_df, top_rel = load_relations_topk_from_column(CSV_PATH, top_k=TOP_K_REL)
    classes = list(top_rel)  # keep in frequency order
    palette = make_class_palette(classes)
    print(f"✓ corpus ready — {len(rel_df):,} tokens across relations {classes}")

    # 2) Optional per-class cap (currently unlimited for fast metrics)
    raw_df = sample_raw(rel_df, RAW_MAX_PER_CLASS).reset_index(drop=True)
    print("Sample sizes per relation (raw cap):")
    counts = raw_df.relation_class.value_counts()
    print({k: int(counts[k]) for k in classes})

    # 3) Embed once (BERT: 'first/last/mean'; GPT-2: set WORD_REP_MODE='last' or 'mean')
    reps, filled = embed_subset(df_sent, raw_df, BASELINE, WORD_REP_MODE, BATCH_SIZE)
    assert filled.shape[0] == len(raw_df) == reps.shape[1]
    # Column g of `reps` belongs to row g of the *unfiltered* raw_df. Keep labels in
    # that indexing and mask out tokens without a vector. Filtering raw_df alone
    # shifted every later index by the number of missing tokens, pairing relations
    # with the wrong (or all-zero) vectors.
    cls_arr = raw_df.relation_class.to_numpy()
    L = reps.shape[0]; layers = np.arange(L)
    print(f"✓ embedded {int(filled.sum()):,} tokens  • layers={L}")

    # 4) Metric loop (every metric listed in ALL_METRICS)
    for metric in ALL_METRICS:
        print(f"\n→ Computing metric: {metric} …")
        compute_once = FAST_ONCE.get(metric) or HEAVY_ONCE.get(metric)
        if compute_once is None:
            print(f"  (skipping {metric}: estimator unavailable)")
            continue

        n_bs = N_BOOTSTRAP_FAST if metric in FAST_ONCE else N_BOOTSTRAP_HEAVY
        Mcap = FAST_BS_MAX_SAMP_PER_CLASS if metric in FAST_ONCE else HEAVY_BS_MAX_SAMP_PER_CLASS

        class_results: Dict[str, Dict[str, np.ndarray]] = {}
        for c in classes:
            idx = np.flatnonzero((cls_arr == c) & filled)
            if idx.size < 3:
                continue
            sub = reps[:, idx]  # (L, n_c, D)
            Nc = sub.shape[1]
            M = min(Mcap, Nc)
            mu, lo, hi = _bs_layer_loop(sub, M, n_bs, compute_once)
            class_results[c] = {"mean": mu, "lo": lo, "hi": hi, "n": int(Nc)}

        save_metric_csv_all_classes(metric, class_results, layers, BASELINE, subset_name="raw")
        plot_metric_with_ci(class_results, layers, metric,
                            title=f"{LABELS.get(metric, metric.upper())} • {BASELINE}",
                            out_path=PLOT_DIR / f"relation_raw_{metric}_{BASELINE}.png",
                            palette=palette, classes_order=classes)
        print(f"  ✓ saved: CSV= {CSV_DIR}/relation_raw_{metric}_{BASELINE}.csv  "
              f"plot= {PLOT_DIR}/relation_raw_{metric}_{BASELINE}.png")

        del class_results; gc.collect()
        if device == "cuda": torch.cuda.empty_cache()

    del reps; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    print("\n✓ done (incremental outputs produced per metric).")

if __name__ == "__main__":
    # In a Jupyter kernel __name__ is typically "__main__", so running this cell
    # launches the whole pipeline.
    run_relation_topk_pipeline()



In [6]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple, Callable

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModel

# ===== Optional deps (gracefully skipped if not installed) =====
HAS_DADAPY = False
try:
    from dadapy import Data  # DADApy ID estimators (TwoNN, GRIDE)
    HAS_DADAPY = True
except Exception:
    pass

HAS_SKDIM = False
try:
    from skdim.id import MOM, TLE, CorrInt, FisherS, lPCA, MLE, DANCo, ESS, MiND_ML, MADA, KNN
    HAS_SKDIM = True
except Exception:
    pass

# IsoScore: use the library if available, else a simple monotone fallback.
# The package module is imported as ``IsoScore`` (the first cell of this notebook
# imports it that way successfully); the lowercase ``isoscore`` spelling is tried
# second. Trying only the lowercase name silently replaced a working library with
# the fallback on case-sensitive filesystems.
# NOTE(review): the fallback below treats rows of X as points (n, d); confirm the
# library expects the same orientation.
try:
    from IsoScore import IsoScore
    _HAS_ISOSCORE = True
except Exception:
    try:
        from isoscore import IsoScore
        _HAS_ISOSCORE = True
    except Exception:
        _HAS_ISOSCORE = False

        class _IsoScoreFallback:
            """Stand-in exposing ``IsoScore(X)``: mean / max covariance eigenvalue, clipped to [0, 1]."""

            @staticmethod
            def IsoScore(X: np.ndarray) -> float:
                C = np.cov(X.T, ddof=0)
                ev = np.linalg.eigvalsh(C)          # ascending, so ev[-1] is the largest
                if ev.mean() <= 0 or ev[-1] <= 0:   # degenerate (e.g. constant) data
                    return 0.0
                return float(np.clip(ev.mean() / (ev[-1] + 1e-9), 0.0, 1.0))

        IsoScore = _IsoScoreFallback()

# =============================== CONFIG ===============================
# --- data ---------------------------------------------------------------
CSV_PATH = "en_ewt-ud-train_sentences.csv"
REL_COL_HINT = "relation_type"   # column with UD relations (list[str])
TOP_K_REL = 10                   # keep only the 10 most frequent relations

# --- model: GPT-2 with the last sub-word as the word representation -------
BASELINE = "gpt2"                # "openai-community/gpt2" names the same checkpoint
WORD_REP_MODE = "last"           # for GPT-2 use {"last","mean"}

# --- drop the first word of each sentence (word index 0) ------------------
EXCLUDE_INDEX_0 = True

# --- bootstrap (10**12 effectively means "no cap") ------------------------
RAW_MAX_PER_CLASS = 10**12
N_BOOTSTRAP_FAST = 50
N_BOOTSTRAP_HEAVY = 20
FAST_BS_MAX_SAMP_PER_CLASS = 10**12
HEAVY_BS_MAX_SAMP_PER_CLASS = 5000

# --- reproducibility & outputs ------------------------------------------
RAND_SEED = 42
PLOT_DIR = Path("results_REL_GPT2_no_idx0")
CSV_DIR = Path("tables_REL_GPT2_no_idx0") / "relation_bootstrap"
PLOT_DIR.mkdir(parents=True, exist_ok=True)
CSV_DIR.mkdir(parents=True, exist_ok=True)

# --- runtime -------------------------------------------------------------
BATCH_SIZE = 1
os.environ["TOKENIZERS_PARALLELISM"] = "true"
random.seed(RAND_SEED)
np.random.seed(RAND_SEED)
torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    # autotuned cuDNN kernels: faster, but may not be bit-for-bit reproducible
    torch.backends.cudnn.benchmark = True

# --- plotting & numerics -------------------------------------------------
sns.set_style("darkgrid")
plt.rcParams["figure.dpi"] = 120
EPS = 1e-12   # floor added inside logs / denominators by the metric helpers

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _center(X: np.ndarray) -> np.ndarray:
    return X - X.mean(0, keepdims=True)

def _eigvals_from_X(X: np.ndarray) -> np.ndarray:
    Xc = _center(X.astype(np.float32, copy=False))
    try:
        _, S, _ = np.linalg.svd(Xc, full_matrices=False)
        lam = (S**2).astype(np.float64)
        lam.sort()
        return lam[::-1]
    except Exception:
        return np.array([], dtype=np.float64)

def _jitter_unique(X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    try:
        if np.unique(X, axis=0).shape[0] < X.shape[0]:
            X = X + np.random.normal(scale=eps, size=X.shape).astype(X.dtype)
    except Exception:
        pass
    return X

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)
    if n is None: raise ValueError("Cannot determine number of hidden layers")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)
    if d is None: raise ValueError("Cannot determine hidden size")
    return int(d)

def _is_gpt_like(model) -> bool:
    mt = str(getattr(model.config, "model_type", "")).lower()
    name = str(getattr(getattr(model, "name_or_path", ""), "lower", lambda: "")())
    return ("gpt2" in mt) or ("gpt2" in name)

# ===== robust GPT‑2 loader to avoid NoneType vocab_file bug =====
def _safe_load_tok_and_model(baseline: str):
    """
    Robust tokenizer+model loader.
    - For GPT-2, avoids 'vocab_file NoneType' problems in some transformers setups
      by temporarily guarding os.path.isfile(None).
    - Tries both 'gpt2' and 'openai-community/gpt2' where appropriate.
    Returns: (tokzr, model, model_id_used)
    """
    import os as _os

    bl = baseline.lower()
    if "gpt2" in bl:
        candidates = []
        for mid in (baseline, "openai-community/gpt2", "gpt2"):
            if mid not in candidates:
                candidates.append(mid)
    else:
        candidates = [baseline]

    last_err = None
    for mid in candidates:
        try:
            # patch os.path.isfile just during tokenizer init
            _orig_isfile = _os.path.isfile
            def _patched_isfile(p):
                if p is None:  # avoid TypeError: stat: path should be string, bytes, os.PathLike or integer, not NoneType
                    return False
                return _orig_isfile(p)
            _os.path.isfile = _patched_isfile
            try:
                tokzr = AutoTokenizer.from_pretrained(mid, use_fast=True, add_prefix_space=True)
            finally:
                _os.path.isfile = _orig_isfile

            # GPT‑ish niceties
            if getattr(tokzr, "padding_side", None) != "right":
                tokzr.padding_side = "right"
            if tokzr.pad_token is None and getattr(tokzr, "eos_token", None) is not None:
                tokzr.pad_token = tokzr.eos_token

            # model
            model = AutoModel.from_pretrained(mid, output_hidden_states=True)
            if getattr(model.config, "pad_token_id", None) is None and tokzr.pad_token_id is not None:
                model.config.pad_token_id = tokzr.pad_token_id

            # attach call kwargs (add_prefix_space for GPT‑2 like models)
            call_kwargs = {}
            if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
                call_kwargs["add_prefix_space"] = True
            tokzr._safe_call_kwargs = call_kwargs

            return tokzr, model, mid
        except Exception as e:
            last_err = e
            continue

    raise RuntimeError(f"Failed to load tokenizer/model for {baseline}. Tried {candidates}. Last error: {last_err}")

# ========= Metric single-call functions =========
# --- Isotropy ---
def _iso_once(X: np.ndarray) -> float:
    return float(IsoScore.IsoScore(X))

def _sf_once(X: np.ndarray) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    gm = np.exp(np.mean(np.log(lam + EPS))); am = float(lam.mean() + EPS)
    return float(gm / am)  # higher = flatter = more isotropic

def _rand_once(X: np.ndarray, K: int = 2000) -> float:
    n = X.shape[0]
    if n < 2: return np.nan
    rng = np.random.default_rng(RAND_SEED)
    K_eff = min(K, (n*(n-1))//2)
    i = rng.integers(0, n, size=K_eff)
    j = rng.integers(0, n, size=K_eff)
    same = i == j
    if same.any():
        j[same] = rng.integers(0, n, size=same.sum())
    A, B = X[i], X[j]
    num = np.sum(A*B, axis=1)
    den = (np.linalg.norm(A, axis=1)*np.linalg.norm(B, axis=1) + 1e-9)
    return float(np.mean(np.abs(num/den)))  # higher ≈ more anisotropic

def _vmf_kappa_once(X: np.ndarray) -> float:
    if X.shape[0] < 2: return np.nan
    Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-9)
    R = np.linalg.norm(Xn.mean(axis=0))
    d = Xn.shape[1]
    if R < 1e-9: return 0.0
    return float(max(R * (d - R**2) / (1.0 - R**2 + 1e-9), 0.0))  # higher = more anisotropic

# --- Linear ID ---
def _pcaXX_once(X: np.ndarray, var_ratio: float) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    c = np.cumsum(lam); thr = c[-1] * var_ratio
    return float(np.searchsorted(c, thr) + 1)

def _pca95_once(X: np.ndarray) -> float: return _pcaXX_once(X, 0.95)
def _pca99_once(X: np.ndarray) -> float: return _pcaXX_once(X, 0.99)

def _erank_once(X: np.ndarray) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    p = lam / (lam.sum() + EPS)
    H = -(p * np.log(p + EPS)).sum()
    return float(np.exp(H))

def _pr_once(X: np.ndarray) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    s1 = lam.sum(); s2 = (lam**2).sum()
    return float((s1**2) / (s2 + EPS))

def _stable_rank_once(X: np.ndarray) -> float:
    lam = _eigvals_from_X(X)
    if lam.size == 0: return np.nan
    return float(lam.sum() / (lam.max() + EPS))

# --- Non-linear (DADApy) ---
def _dadapy_twonn_once(X: np.ndarray) -> float:
    if not HAS_DADAPY: return np.nan
    d = Data(coordinates=_jitter_unique(X))
    id_est, _, _ = d.compute_id_2NN()
    return float(id_est)

def _dadapy_gride_once(X: np.ndarray) -> float:
    if not HAS_DADAPY: return np.nan
    d = Data(coordinates=_jitter_unique(X))
    d.compute_distances(maxk=64)
    ids, _, _ = d.return_id_scaling_gride(range_max=64)
    return float(ids[-1])

# --- Non-linear (scikit-dimension) ---
def _skdim_factory(name: str):
    """Return a zero-argument constructor for the scikit-dimension estimator `name`.

    Returns None when scikit-dimension is unavailable or `name` is unknown. The
    'lpca', 'lpca95' and 'lpca99' aliases all build lPCA: 'lpca' with ver="FO",
    the other two with ver="ratio" and alphaRatio 0.95 / 0.99.
    """
    if not HAS_SKDIM:
        return None
    estimators = {
        "mom": MOM, "tle": TLE, "corrint": CorrInt, "fishers": FisherS,
        "lpca": lPCA, "lpca95": lPCA, "lpca99": lPCA,
        "mle": MLE, "danco": DANCo, "mind_ml": MiND_ML, "ess": ESS,
        "mada": MADA, "knn": KNN,
    }
    constructor_kwargs = {
        "lpca": {"ver": "FO"},
        "lpca95": {"ver": "ratio", "alphaRatio": 0.95},
        "lpca99": {"ver": "ratio", "alphaRatio": 0.99},
    }
    est_cls = estimators.get(name)
    if est_cls is None:
        return None
    kwargs = constructor_kwargs.get(name, {})

    def _builder():
        # a fresh estimator per call, so bootstrap replicates never share fitted state
        return est_cls(**kwargs)

    return _builder

def _skdim_once_builder(name: str) -> Callable[[np.ndarray], float] | None:
    """Wrap a scikit-dimension estimator as a single-sample metric function.

    Returns None when the estimator is unavailable, so registries can skip it.
    The returned function fits a fresh estimator on the jittered sample and
    reports its `dimension_` attribute (NaN if the attribute is missing).
    """
    make_estimator = _skdim_factory(name)
    if make_estimator is None:
        return None

    def _once(X: np.ndarray) -> float:
        estimator = make_estimator()
        estimator.fit(_jitter_unique(X))
        dimension = getattr(estimator, "dimension_", np.nan)
        return float(dimension)

    return _once

# ---- Metric registries (ALL METRICS) ----
# Every entry maps a metric key to a function that takes one (M, D) float32 sample
# and returns a float. run_relation_topk_pipeline looks keys up in FAST_ONCE first,
# then HEAVY_ONCE, and picks the bootstrap budget by registry membership:
# FAST_ONCE -> N_BOOTSTRAP_FAST / FAST_BS_MAX_SAMP_PER_CLASS,
# HEAVY_ONCE -> N_BOOTSTRAP_HEAVY / HEAVY_BS_MAX_SAMP_PER_CLASS.
FAST_ONCE: Dict[str, Callable[[np.ndarray], float]] = {
    # Isotropy
    "iso": _iso_once, "sf": _sf_once, "rand": _rand_once, "vmf_kappa": _vmf_kappa_once,
    # Linear ID
    # (pca95/pca99 count eigenvalues via _pcaXX_once; they are not skdim's lPCA)
    "erank": _erank_once, "pr": _pr_once, "stable_rank": _stable_rank_once,
    "pca95": _pca95_once, "pca99": _pca99_once,
}
# skdim entries are built right here, at definition time: _skdim_once_builder returns
# None when HAS_SKDIM is False or the name is unknown, and the driver skips None entries.
# The DADApy entries are always callables; they return NaN when HAS_DADAPY is False.
HEAVY_ONCE: Dict[str, Callable[[np.ndarray], float] | None] = {
    # DADApy
    "twonn": _dadapy_twonn_once, "gride": _dadapy_gride_once,
    # scikit-dimension
    "mom": _skdim_once_builder("mom"), "tle": _skdim_once_builder("tle"),
    "corrint": _skdim_once_builder("corrint"), "fishers": _skdim_once_builder("fishers"),
    "lpca": _skdim_once_builder("lpca"), "lpca95_skdim": _skdim_once_builder("lpca95"),
    "lpca99_skdim": _skdim_once_builder("lpca99"),
    "mle": _skdim_once_builder("mle"), "danco": _skdim_once_builder("danco"),
    "mind_ml": _skdim_once_builder("mind_ml"), "ess": _skdim_once_builder("ess"),
    "mada": _skdim_once_builder("mada"), "knn": _skdim_once_builder("knn"),
}
# Human-readable y-axis / title labels; metrics missing here fall back to metric.upper().
LABELS = {
    # Isotropy
    "iso":"IsoScore", "sf":"Spectral Flatness", "rand":"RandCos |μ| (anisotropy↑)", "vmf_kappa":"vMF κ (anisotropy↑)",
    # Linear ID
    "erank":"Effective Rank", "pr":"Participation Ratio", "stable_rank":"Stable Rank",
    "pca95":"lPCA 0.95", "pca99":"lPCA 0.99",
    # Non-linear (DADApy)
    "twonn":"TwoNN ID", "gride":"GRIDE ID",
    # Non-linear (skdim)
    "mom":"MOM", "tle":"TLE", "corrint":"CorrInt", "fishers":"FisherS",
    "lpca":"lPCA FO", "lpca95_skdim":"lPCA 0.95 (skdim)", "lpca99_skdim":"lPCA 0.99 (skdim)",
    "mle":"MLE", "danco":"DANCo", "mind_ml":"MiND-ML", "ess":"ESS", "mada":"MADA", "knn":"KNN",
}
# Keep it small for now
# Metrics actually computed by run_relation_topk_pipeline; any key of FAST_ONCE or
# HEAVY_ONCE is valid here.
ALL_METRICS = ["gride", "iso", "lpca99_skdim"]

# =============================== DATA: relation_type column ===============================
def _pick_relation_col(df: pd.DataFrame) -> str:
    """Return the column of `df` holding dependency relations.

    REL_COL_HINT is preferred, followed by a fixed list of common aliases.
    Raises ValueError (listing every name tried) when none is present.
    """
    candidates = [REL_COL_HINT, "typed_dependency", "relation", "deprel", "ud_rel", "rel", "REL", "Rel"]
    present = [name for name in candidates if name in df.columns]
    if present:
        return present[0]
    raise ValueError(f"No relation column found. Tried: {', '.join(candidates)}")

def load_relations_topk_from_column(csv_path: str, top_k: int = TOP_K_REL):
    """
    Build sentence- and token-level frames from a CSV of per-sentence lists.

    Expected columns:
      - sentence_id   (cast to str)
      - tokens        (list[str], possibly stored as a list literal string)
      - relation_type (list[str]) or another name accepted by `_pick_relation_col`
    Each sentence is expanded to one row per word, up to the shorter of the two lists.
    Word 0 is skipped when EXCLUDE_INDEX_0 is truthy. Only the `top_k` most frequent
    relation labels are kept.

    Returns (df_sent, df_tok, top):
      df_sent: sentence_id + tokens, one row per sentence_id
      df_tok : sentence_id, word_id, relation_class, word
      top    : kept relation labels, most frequent first
    Raises ValueError when no token rows could be built.
    """
    table = pd.read_csv(csv_path)
    rel_col = _pick_relation_col(table)
    df = table[["sentence_id", "tokens", rel_col]].copy()
    df["sentence_id"] = df["sentence_id"].astype(str)
    df["tokens"] = df["tokens"].apply(_to_list)
    df[rel_col] = df[rel_col].apply(_to_list)

    first_wid = 1 if EXCLUDE_INDEX_0 else 0
    records = [
        (sid, wid, str(rels[wid]), toks[wid])
        for sid, toks, rels in df[["sentence_id", "tokens", rel_col]].itertuples(index=False)
        for wid in range(first_wid, min(len(toks), len(rels)))
    ]
    df_tok = pd.DataFrame(records, columns=["sentence_id", "word_id", "relation_class", "word"])
    if df_tok.empty:
        raise ValueError("No token rows constructed—check that your relation column contains lists of strings.")

    freq = df_tok["relation_class"].value_counts()
    top = freq.nlargest(top_k).index.tolist()
    df_tok = df_tok.loc[df_tok["relation_class"].isin(top)].reset_index(drop=True)

    df_sent = df[["sentence_id", "tokens"]].drop_duplicates("sentence_id")
    return df_sent, df_tok, top

def sample_raw(df_tok: pd.DataFrame, per_class_cap: int | None = RAW_MAX_PER_CLASS) -> pd.DataFrame:
    """
    Subsample at most `per_class_cap` token rows per relation class, without replacement.

    Args:
        df_tok: token rows with a `relation_class` column.
        per_class_cap: maximum rows kept per class. `None` means "no cap" (every row
            kept), the same convention as `sample_per_class`. An infinite or integral
            float cap (e.g. float('inf'), 500.0) also works.

    Returns:
        Concatenated per-class samples with a fresh RangeIndex. Classes appear in
        first-appearance order (groupby sort=False), and rows are shuffled within each
        class, seeded by RAND_SEED. An empty `df_tok` yields an empty frame with the
        same columns instead of a `pd.concat` error.
    """
    if df_tok.empty:
        return df_tok.iloc[:0].reset_index(drop=True)
    picks = []
    for _, sub in df_tok.groupby("relation_class", sort=False):
        # min() keeps len(sub) (an int) whenever the cap is larger, including inf;
        # int() normalises float caps, which DataFrame.sample rejects.
        n = len(sub) if per_class_cap is None else int(min(len(sub), per_class_cap))
        picks.append(sub.sample(n, random_state=RAND_SEED, replace=False))
    return pd.concat(picks, ignore_index=True)

def make_class_palette(classes: List[str]) -> Dict[str, Tuple[float, float, float]]:
    """Map every class label to an RGB tuple.

    The pool is the 60 colours of seaborn's tab20 + tab20b + tab20c. When there are
    more classes than that, an evenly spaced husl palette with one colour per class
    is used instead.
    """
    colours: List[Tuple[float, float, float]] = []
    for palette_name in ("tab20", "tab20b", "tab20c"):
        colours.extend(sns.color_palette(palette_name, 20))
    if len(colours) < len(classes):
        colours = list(sns.color_palette("husl", len(classes)))
    n_colours = len(colours)
    return {label: colours[k % n_colours] for k, label in enumerate(classes)}

# =============================== EMBEDDING (GPT‑2/other) ===============================
def embed_subset(df_sent: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray]:
    """
    Embed the words listed in `subset_df` with every hidden layer of `baseline`.

    Args:
        df_sent: one row per sentence with `sentence_id` and `tokens` (list of words).
            Every sentence_id referenced by `subset_df` must be present (`.loc[sids]`).
        subset_df: rows to embed. Positional row i becomes column i of the output; each row
            is located by (`sentence_id`, `word_id`), where `word_id` indexes `tokens`.
        baseline: model id passed to `_safe_load_tok_and_model`.
        word_rep_mode: how a word's sub-word vectors are pooled ("first", "last", "mean").
            For models flagged by `_is_gpt_like` only "last"/"mean" are honoured (anything
            else -> "last"); for other models an unknown mode falls back to "first".
        batch_size: number of sentences per forward pass.

    Returns:
        reps: float16 array (L, N, D) with L = num_hidden_layers + 1 (embedding output
            included), N = len(subset_df), D = hidden size.
        filled: bool array (N,), True where a vector was written. Columns of `reps` for
            unfilled rows stay all-zero and are NOT removed here; callers must drop them
            (e.g. `reps[:, filled]`) before indexing with rows of `subset_df[filled]`.

    Side effects: casts `sentence_id` to str in BOTH input frames, in place.
    """
    df_sent["sentence_id"]   = df_sent["sentence_id"].astype(str)
    subset_df["sentence_id"] = subset_df["sentence_id"].astype(str)

    # sid -> list[(global_idx, word_id)]
    by_sid: Dict[str, List[Tuple[int,int]]] = {}
    for gidx, (sid, wid) in enumerate(subset_df[["sentence_id","word_id"]].itertuples(index=False)):
        by_sid.setdefault(str(sid), []).append((gidx, int(wid)))

    # one row per needed sentence, in the same order as `sids` (batches follow this order)
    sids = list(by_sid.keys())
    df_sel = (df_sent[df_sent.sentence_id.isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    tokzr, model, model_id_used = _safe_load_tok_and_model(baseline)
    model = model.eval().to(device)
    if device == "cuda":
        model.half()

    # Words are passed pre-split; padding makes ragged sentences batchable.
    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True)
    # extra call kwargs, presumably attached by _safe_load_tok_and_model — TODO confirm
    enc_kwargs.update(getattr(tokzr, "_safe_call_kwargs", {}))

    L = _num_hidden_layers(model) + 1   # include embedding layer
    D = _hidden_size(model)
    N = len(subset_df)

    reps   = np.zeros((L, N, D), np.float16)
    filled = np.zeros(N, dtype=bool)

    # choose rep mode per model family
    gpt_like = _is_gpt_like(model)
    if gpt_like:
        rep_mode = word_rep_mode if word_rep_mode in {"last","mean"} else "last"
    else:
        rep_mode = word_rep_mode if word_rep_mode in {"first","last","mean"} else "first"

    # NOTE(review): torch.cuda.amp.autocast is deprecated in recent PyTorch releases in
    # favour of torch.autocast("cuda", ...); autocast is only enabled on CUDA here.
    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{model_id_used} (embed subset)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # word index -> positions of its sub-word tokens (special/pad tokens map to None)
                mp: Dict[int, List[int]] = {}
                for tidx, wid in enumerate(enc_be.word_ids(b)):
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid)
                    if not toks:
                        # word produced no sub-word tokens: leave zeros, filled stays False
                        continue
                    if rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    else:  # "mean"
                        vec = h[:, b, toks, :].mean(axis=1)
                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    filled[gidx] = True

            del enc_be, enc_t, out, h
            if device == "cuda":
                torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing:
        print(f"⚠ Missing vectors for {missing} of {N} tokens")
    del model; gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()
    return reps, filled

# =============================== BOOTSTRAP CORE ===============================
def _bs_layer_loop(rep_sub: np.ndarray, M: int, n_reps: int, compute_once: Callable[[np.ndarray], float]):
    L, N, D = rep_sub.shape
    rng = np.random.default_rng(RAND_SEED)
    A = np.full((n_reps, L), np.nan, np.float32)
    for r in range(n_reps):
        idx = rng.integers(0, N, size=M)
        for l in range(L):
            X = rep_sub[l, idx].astype(np.float32, copy=False)
            try:
                A[r, l] = float(compute_once(X))
            except Exception:
                A[r, l] = np.nan
    mu = np.nanmean(A, axis=0).astype(np.float32)
    lo = np.nanpercentile(A, 2.5, axis=0).astype(np.float32)
    hi = np.nanpercentile(A, 97.5, axis=0).astype(np.float32)
    return mu, lo, hi

# =============================== SAVE / PLOT ===============================
def save_metric_csv_all_classes(metric: str,
                                class_to_stats: Dict[str, Dict[str, np.ndarray]],
                                layers: np.ndarray,
                                baseline: str,
                                subset_name: str = "raw"):
    """Write one long-format CSV for a single metric: one row per (class, layer).

    The file is CSV_DIR / "relation_{subset_name}_{metric}_{baseline}.csv". Non-finite
    means and CI bounds are written as NaN; the CI columns are NaN when a class has no
    'lo'/'hi' arrays.
    """
    def _finite_or_nan(value) -> float:
        return float(value) if np.isfinite(value) else np.nan

    source_name = Path(CSV_PATH).name
    records = []
    for cls_name, stats in class_to_stats.items():
        means = stats["mean"]
        ci_lo, ci_hi = stats.get("lo"), stats.get("hi")
        has_lo = isinstance(ci_lo, np.ndarray)
        has_hi = isinstance(ci_hi, np.ndarray)
        for li, val in enumerate(means):
            records.append({
                "subset": subset_name, "model": baseline, "feature": "relation_type",
                "class": cls_name, "metric": metric, "layer": int(layers[li]),
                "mean": _finite_or_nan(val),
                "ci_low": _finite_or_nan(ci_lo[li]) if has_lo else np.nan,
                "ci_high": _finite_or_nan(ci_hi[li]) if has_hi else np.nan,
                "n_tokens": int(stats.get("n", 0)), "word_rep_mode": WORD_REP_MODE,
                "source_csv": source_name,
            })
    out_path = CSV_DIR / f"relation_{subset_name}_{metric}_{baseline}.csv"
    pd.DataFrame(records).to_csv(out_path, index=False)

def plot_metric_with_ci(class_to_stats: Dict[str, Dict[str, np.ndarray]],
                        layers: np.ndarray, metric: str, title: str, out_path: Path,
                        palette: Dict[str, Tuple[float, float, float]] | None = None,
                        classes_order: List[str] | None = None):
    """
    Plot one mean-vs-layer curve per class (with a shaded CI band when available) and save it.

    Args:
        class_to_stats: class -> {"mean": (L,), optional "lo"/"hi": (L,)} arrays.
        layers: x values, one per layer.
        metric: key into LABELS for the y-axis label (falls back to metric.upper()).
        title: figure title.
        out_path: destination of the PNG (dpi=220).
        palette: optional class -> colour mapping; unmapped classes use the default cycle.
        classes_order: drawing/legend order; defaults to the sorted class names.

    Classes with no stats or an all-NaN mean are skipped. The legend title states how many
    relations were requested, e.g. "top‑10" when TOP_K_REL=10; it was previously a
    hard-coded "top‑10". The legend is only drawn when at least one curve exists. The
    figure is always closed, even if saving fails, so loops over metrics don't leak figures.
    """
    order = classes_order if classes_order is not None else sorted(class_to_stats.keys())
    fig, ax = plt.subplots(figsize=(10.5, 5.5))
    try:
        n_plotted = 0
        for c in order:
            stats = class_to_stats.get(c)
            if not stats:
                continue
            mu, lo, hi = stats["mean"], stats.get("lo"), stats.get("hi")
            if mu is None or np.all(np.isnan(mu)):
                continue
            color = palette.get(c) if isinstance(palette, dict) else None
            ax.plot(layers, mu, label=c, lw=1.8, color=color)
            n_plotted += 1
            if isinstance(lo, np.ndarray) and isinstance(hi, np.ndarray) and not np.all(np.isnan(lo)):
                ax.fill_between(layers, lo, hi, alpha=0.15, color=color)
        ax.set_xlabel("Layer")
        ax.set_ylabel(LABELS.get(metric, metric.upper()))
        ax.set_title(title)
        if n_plotted:
            ncol = 5 if len(order) >= 10 else 3
            ax.legend(ncol=ncol, fontsize="small", title=f"UD relation (top‑{len(order)})", frameon=False)
        fig.tight_layout()
        fig.savefig(out_path, dpi=220)
    finally:
        plt.close(fig)

# =============================== DRIVER ===============================
def run_relation_topk_pipeline():
    """
    End-to-end bootstrap pipeline over the top-K UD relations.

    1) Load token rows for the TOP_K_REL most frequent relations.
    2) Optionally cap rows per class (RAW_MAX_PER_CLASS).
    3) Embed every kept token once with BASELINE (all layers).
    4) For each metric in ALL_METRICS, bootstrap it per class and per layer, then write
       a CSV to CSV_DIR and a PNG to PLOT_DIR.

    Tokens that received no vector are dropped from BOTH the row table and the
    representation tensor. Previously only the rows were dropped, so class indices could
    point at the wrong columns of `reps` whenever any token was missing.
    """
    # 1) Load token lists + top‑K relations from existing column
    df_sent, rel_df, top_rel = load_relations_topk_from_column(CSV_PATH, top_k=TOP_K_REL)
    classes = list(top_rel)  # keep in frequency order
    palette = make_class_palette(classes)
    print(f"✓ corpus ready — {len(rel_df):,} tokens across relations {classes}")
    first_word_status = "removed" if EXCLUDE_INDEX_0 else "kept"
    print(f"✓ first word of each sentence {first_word_status} (EXCLUDE_INDEX_0={EXCLUDE_INDEX_0})")

    # 2) Optional per-class cap (currently unlimited for fast metrics)
    raw_df = sample_raw(rel_df, RAW_MAX_PER_CLASS)
    print("Sample sizes per relation (raw cap):")
    counts = raw_df.relation_class.value_counts()
    print({k: int(counts.get(k, 0)) for k in classes})

    # 3) Embed once. `[:2]` keeps this working if a later cell redefines embed_subset
    #    to return (reps, filled, model_id).
    reps, filled = embed_subset(df_sent, raw_df, BASELINE, WORD_REP_MODE, BATCH_SIZE)[:2]
    filled = np.asarray(filled, dtype=bool)
    raw_df = raw_df.reset_index(drop=True).loc[filled].reset_index(drop=True)
    # Keep reps aligned with raw_df. Skip the (multi-GB) copy when nothing is missing,
    # or when reps was already filtered by the embedder.
    if reps.shape[1] == filled.size and not filled.all():
        reps = reps[:, filled]
    assert reps.shape[1] == len(raw_df), "reps / raw_df misaligned after dropping missing tokens"
    cls_arr = raw_df.relation_class.values
    L = reps.shape[0]; layers = np.arange(L)
    print(f"✓ embedded {len(raw_df):,} tokens  • layers={L}")

    # 4) Metric loop (ALL METRICS)
    for metric in ALL_METRICS:
        print(f"\n→ Computing metric: {metric} …")
        compute_once = FAST_ONCE.get(metric) or HEAVY_ONCE.get(metric)
        if compute_once is None:
            print(f"  (skipping {metric}: estimator unavailable)")
            continue

        n_bs = N_BOOTSTRAP_FAST if metric in FAST_ONCE else N_BOOTSTRAP_HEAVY
        Mcap = FAST_BS_MAX_SAMP_PER_CLASS if metric in FAST_ONCE else HEAVY_BS_MAX_SAMP_PER_CLASS

        class_results: Dict[str, Dict[str, np.ndarray]] = {}
        for c in classes:
            idx = np.where(cls_arr == c)[0]
            if idx.size < 3:
                continue
            sub = reps[:, idx]  # (L, n_c, D)
            Nc = sub.shape[1]
            M = min(Mcap, Nc)
            mu, lo, hi = _bs_layer_loop(sub, M, n_bs, compute_once)
            class_results[c] = {"mean": mu, "lo": lo, "hi": hi, "n": int(Nc)}

        save_metric_csv_all_classes(metric, class_results, layers, BASELINE, subset_name="raw")
        plot_metric_with_ci(class_results, layers, metric,
                            title=f"{LABELS.get(metric, metric.upper())} • {BASELINE}",
                            out_path=PLOT_DIR / f"relation_raw_{metric}_{BASELINE}.png",
                            palette=palette, classes_order=classes)
        print(f"  ✓ saved: CSV= {CSV_DIR}/relation_raw_{metric}_{BASELINE}.csv  "
              f"plot= {PLOT_DIR}/relation_raw_{metric}_{BASELINE}.png")

        del class_results; gc.collect()
        if device == "cuda": torch.cuda.empty_cache()

    del reps; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    print("\n✓ done (incremental outputs produced per metric).")

# Inside a notebook kernel __name__ is "__main__", so executing this cell runs the
# whole bootstrap pipeline (loading, embedding, all metrics in ALL_METRICS).
if __name__ == "__main__":
    run_relation_topk_pipeline()


✓ corpus ready — 118,011 tokens across relations ['punct', 'case', 'det', 'nsubj', 'obj', 'advmod', 'amod', 'root', 'obl', 'conj']
✓ first word of each sentence removed (EXCLUDE_INDEX_0=True)
Sample sizes per relation (raw cap):
{'punct': 21787, 'case': 16041, 'det': 14417, 'nsubj': 12331, 'obj': 9661, 'advmod': 9485, 'amod': 9076, 'root': 8915, 'obl': 8741, 'conj': 7557}


openai-community/gpt2 (embed subset): 100%|█| 10062/10062 [01:31<00:00, 110.53it


✓ embedded 118,011 tokens  • layers=13

→ Computing metric: gride …
  ✓ saved: CSV= tables_REL_GPT2_no_idx0/relation_bootstrap/relation_raw_gride_gpt2.csv  plot= results_REL_GPT2_no_idx0/relation_raw_gride_gpt2.png

→ Computing metric: iso …
  ✓ saved: CSV= tables_REL_GPT2_no_idx0/relation_bootstrap/relation_raw_iso_gpt2.csv  plot= results_REL_GPT2_no_idx0/relation_raw_iso_gpt2.png

→ Computing metric: lpca99_skdim …
  ✓ saved: CSV= tables_REL_GPT2_no_idx0/relation_bootstrap/relation_raw_lpca99_skdim_gpt2.csv  plot= results_REL_GPT2_no_idx0/relation_raw_lpca99_skdim_gpt2.png

✓ done (incremental outputs produced per metric).


In [None]:
from __future__ import annotations
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel, GPT2TokenizerFast

# Plotly
import plotly.graph_objects as go
import plotly.express as px
import plotly.colors as pc

# =============================== CONFIG ===============================
# NOTE(review): these assignments overwrite same-named globals from earlier cells
# (CSV_PATH, TOP_K_REL, BASELINE, WORD_REP_MODE, BATCH_SIZE, RAND_SEED, device, ...),
# which run_relation_topk_pipeline also reads if it is re-run after this cell.
CSV_PATH       = "en_ewt-ud-train_sentences.csv"   # needs: sentence_id, tokens (list[str]), relation_type (list[str]) or similar
REL_COL_HINT   = "relation_type"                   # will try common alternatives if this is absent
TOP_K_REL      = 10                                # keep top-K relations (legend size / clarity)

BASELINE       = "gpt2"                            # or "bert-base-uncased"
# This cell's embed_subset accepts "first", "last" or "mean" for any model; the
# GPT-2 restriction below is advice, not enforced here.
WORD_REP_MODE  = "last"                            # BERT: {"first","last","mean"}; GPT-2: {"last","mean"}

# Optional: subsample for plotting smoothness (per class)
PCA_MAX_PER_CLASS = None                           # e.g., 2000; None = use all tokens

# Output
# The HTML plot is written here; the directory is created on cell execution.
OUT_DIR  = Path("pca3d_relation_type"); OUT_DIR.mkdir(parents=True, exist_ok=True)

# Throughput / device
BATCH_SIZE = 2
RAND_SEED  = 42
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# Seed every RNG in use (stdlib, legacy numpy global, torch).
random.seed(RAND_SEED); np.random.seed(RAND_SEED); torch.manual_seed(RAND_SEED)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    torch.backends.cudnn.benchmark = True

# Work offline if requested (pre-cached models)
# True when either HF offline env var is exactly "1"; passed as local_files_only when loading.
LOCAL_ONLY = (
    os.environ.get("TRANSFORMERS_OFFLINE") == "1"
    or os.environ.get("HF_HUB_OFFLINE") == "1"
)

# =============================== HELPERS ===============================
def _to_list(x):
    return ast.literal_eval(x) if isinstance(x, str) and x.startswith("[") else x

def _num_hidden_layers(model) -> int:
    n = getattr(model.config, "num_hidden_layers", None)
    if n is None: n = getattr(model.config, "n_layer", None)   # GPT-2
    if n is None: raise ValueError("Cannot determine num_hidden_layers")
    return int(n)

def _hidden_size(model) -> int:
    d = getattr(model.config, "hidden_size", None)
    if d is None: d = getattr(model.config, "n_embd", None)    # GPT-2
    if d is None: raise ValueError("Cannot determine hidden size")
    return int(d)

def _pick_relation_col(df: pd.DataFrame) -> str:
    """Pick the relation column: REL_COL_HINT if present, else the first known alias.

    Raises ValueError listing every name tried when none of them is a column of `df`.
    """
    options = [REL_COL_HINT, "typed_dependency", "relation", "deprel", "ud_rel", "rel", "REL", "Rel"]
    found = next((col for col in options if col in df.columns), None)
    if found is None:
        raise ValueError(f"No relation column found. Tried: {', '.join(options)}")
    return found

def load_relations_topk_from_column(csv_path: str, top_k: int = TOP_K_REL):
    """
    Expand per-sentence lists into token rows, keeping only the `top_k` most frequent relations.

    Every word position (including index 0) is kept, up to the shorter of the
    tokens / relation lists of each sentence.

    Returns:
      df_sent: sentence_id + tokens (one row per sentence_id)
      df_tok : sentence_id, word_id, relation_class, word
      top    : kept relation labels, most frequent first
    Raises ValueError when no token rows could be built.
    """
    table = pd.read_csv(csv_path)
    rel_col = _pick_relation_col(table)
    df = table[["sentence_id", "tokens", rel_col]].copy()
    df["sentence_id"] = df["sentence_id"].astype(str)
    for col in ("tokens", rel_col):
        df[col] = df[col].apply(_to_list)

    rows = [
        (sid, wid, str(rels[wid]), toks[wid])
        for sid, toks, rels in df[["sentence_id", "tokens", rel_col]].itertuples(index=False)
        for wid in range(min(len(toks), len(rels)))
    ]
    df_tok = pd.DataFrame(rows, columns=["sentence_id", "word_id", "relation_class", "word"])
    if df_tok.empty:
        raise ValueError("No token rows constructed—check your relation column contains lists.")

    top = df_tok["relation_class"].value_counts().nlargest(top_k).index.tolist()
    keep = df_tok["relation_class"].isin(top)
    df_tok = df_tok[keep].reset_index(drop=True)
    df_sent = df[["sentence_id", "tokens"]].drop_duplicates("sentence_id")
    return df_sent, df_tok, top

def sample_per_class(df_tok: pd.DataFrame, per_class_cap: int | None) -> pd.DataFrame:
    """Optional per-class subsample for plotting."""
    if per_class_cap is None:
        return df_tok.reset_index(drop=True)
    picks = []
    for c, sub in df_tok.groupby("relation_class", sort=False):
        n = min(len(sub), per_class_cap)
        picks.append(sub.sample(n, random_state=RAND_SEED, replace=False))
    return pd.concat(picks, ignore_index=True)

# ---------- Robust loader (fixes GPT-2 pitfalls) ----------
def _load_tok_and_model(model_id: str):
    """
    Load a tokenizer + model, trying fallback repo ids, and return (tokenizer, model, resolved_id).

    Candidates are tried in order: `model_id` itself, then, when `model_id` is "gpt2" or
    "gpt-2" (case-insensitive), "openai-community/gpt2" and "distilgpt2". For each one:
      - ids containing "gpt2" get GPT2TokenizerFast; others get AutoTokenizer(use_fast=True).
        Both are loaded with add_prefix_space=True.
      - padding is forced to the right; if there is no pad token, EOS is reused as PAD, and
        the model config's pad_token_id is set to match when it was missing.
      - the model is loaded with output_hidden_states=True, put in eval mode, moved to
        `device`, and cast to half precision on CUDA.
    Every kwarg honours LOCAL_ONLY (local_files_only).

    NOTE(review): the "distilgpt2" fallback silently substitutes a different checkpoint than
    the one requested. Check `resolved_id` before interpreting per-layer results as GPT-2's.

    Raises:
        RuntimeError: when every candidate failed; lists the attempts and the last exception.
    """
    candidates = [model_id]
    if model_id.lower() in {"gpt2", "gpt-2"}:
        # The canonical repo is "openai-community/gpt2" on HF Hub; also try distilgpt2
        candidates += ["openai-community/gpt2", "distilgpt2"]

    last_err = None
    for mid in candidates:
        try:
            if "gpt2" in mid.lower():
                tok = GPT2TokenizerFast.from_pretrained(
                    mid, add_prefix_space=True, local_files_only=LOCAL_ONLY
                )
            else:
                tok = AutoTokenizer.from_pretrained(
                    mid, use_fast=True, add_prefix_space=True, local_files_only=LOCAL_ONLY
                )
            # Right-padding + pad token (GPT‑2 has no pad token by default)
            if getattr(tok, "padding_side", None) != "right":
                tok.padding_side = "right"
            if tok.pad_token is None and getattr(tok, "eos_token", None) is not None:
                tok.pad_token = tok.eos_token  # <- critical for GPT‑2 batching
            mdl = AutoModel.from_pretrained(
                mid, output_hidden_states=True, local_files_only=LOCAL_ONLY
            )
            if getattr(mdl.config, "pad_token_id", None) is None and tok.pad_token_id is not None:
                mdl.config.pad_token_id = tok.pad_token_id

            mdl = mdl.eval().to(device)
            if device == "cuda":
                mdl.half()
            return tok, mdl, mid
        except Exception as e:
            # remember the failure and try the next candidate id
            last_err = e
            continue
    raise RuntimeError(
        "Could not load tokenizer/model. Attempts:\n  " +
        "\n  ".join(candidates) +
        f"\nLast error: {repr(last_err)}"
    )

def embed_subset(df_sent: pd.DataFrame,
                 subset_df: pd.DataFrame,
                 baseline: str = BASELINE,
                 word_rep_mode: str = WORD_REP_MODE,
                 batch_size: int = BATCH_SIZE) -> Tuple[np.ndarray, np.ndarray, str]:
    """
    Embed the words listed in `subset_df` with every hidden layer of `baseline`.

    Args:
        df_sent: one row per sentence with `sentence_id` and `tokens` (list of words).
        subset_df: rows to embed, located by (`sentence_id`, `word_id`); `word_id` indexes
            the sentence's `tokens`.
        baseline: model id for `_load_tok_and_model` (which may fall back to another id).
        word_rep_mode: sub-word pooling, one of "first", "last", "mean" (ValueError otherwise,
            raised at the first word that has sub-word tokens).
        batch_size: sentences per forward pass. Sequences are truncated by the tokenizer.

    Returns:
        (reps, filled, resolved_model_id)
        reps: float16 array (L, N', D) with L = num_hidden_layers + 1 (embeddings included).
            When every word got a vector, N' = len(subset_df). Otherwise the columns of the
            missing words are REMOVED, so N' = filled.sum() and column j corresponds to row j
            of `subset_df[filled]`.
        filled: bool array (len(subset_df),), True where a vector was produced.
        resolved_model_id: the candidate id that actually loaded.

    Side effects: casts `sentence_id` to str in BOTH input frames, in place.
    """
    df_sent["sentence_id"]   = df_sent["sentence_id"].astype(str)
    subset_df["sentence_id"] = subset_df["sentence_id"].astype(str)

    # sentence_id -> [(positional row index in subset_df, word_id), ...]
    by_sid: Dict[str, List[Tuple[int,int]]] = {}
    for gidx, (sid, wid) in enumerate(subset_df[["sentence_id","word_id"]].itertuples(index=False)):
        by_sid.setdefault(str(sid), []).append((gidx, int(wid)))

    # one row per needed sentence, ordered like `sids`
    sids = list(by_sid.keys())
    df_sel = (df_sent[df_sent.sentence_id.isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    tokzr, model, resolved_id = _load_tok_and_model(baseline)

    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
    # only forwarded if the tokenizer's __call__ signature names it explicitly
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True

    L = _num_hidden_layers(model) + 1   # include embeddings
    D = _hidden_size(model)
    N = len(subset_df)

    reps   = np.zeros((L, N, D), np.float16)
    filled = np.zeros(N, dtype=bool)

    # NOTE(review): torch.cuda.amp.autocast is deprecated in recent PyTorch releases in
    # favour of torch.autocast("cuda", ...); autocast is only enabled on CUDA here.
    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"{resolved_id} (embed subset)"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # word index -> positions of its sub-word tokens
                mp = {}
                wids = enc_be.word_ids(b)  # needs Fast tokenizer
                if wids is None:
                    raise RuntimeError("Fast tokenizer required (word_ids unavailable).")
                for tidx, wid in enumerate(wids):
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))

                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid)
                    # no sub-word tokens (e.g. truncated away): leave filled[gidx] False
                    if not toks: continue
                    if word_rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif word_rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    elif word_rep_mode == "mean":
                        vec = h[:, b, toks, :].mean(axis=1)
                    else:
                        raise ValueError("WORD_REP_MODE must be in {'first','last','mean'}.")
                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    filled[gidx] = True

            del enc_be, enc_t, out, h
            if device == "cuda":
                torch.cuda.empty_cache()

    missing = int((~filled).sum())
    if missing:
        print(f"⚠ Missing vectors for {missing} tokens (skipped in PCA).")
        # drop empty columns so reps aligns with subset_df[filled]
        reps = reps[:, filled]
    del model; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    return reps, filled, resolved_id

# =============================== PCA 3D PER LAYER ===============================
def _pca3d_layer(X: np.ndarray, n_components: int = 3) -> np.ndarray:
    """Lightweight PCA to 3D via SVD (no sklearn dependency)."""
    X = X.astype(np.float32, copy=False)
    Xc = X - X.mean(0, keepdims=True)
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    return (U[:, :n_components] * S[:n_components]).astype(np.float32, copy=False)  # (n,3)

def _qual_palette_for_classes(classes: List[str]) -> Dict[str, str]:
    """
    Assign one Plotly colour string per class.

    Several qualitative Plotly palettes are concatenated in a fixed order, and the first
    len(classes) entries are used. If that pool is too small, the continuous "Turbo"
    scale is sampled at len(classes) evenly spaced points instead.
    """
    q = px.colors.qualitative
    palettes = (
        q.Alphabet, q.Dark24, q.Light24, q.Set3, q.Bold, q.Set2,
        q.Set1, q.Pastel2, q.Pastel1, q.Safe, q.Vivid, q.D3,
    )
    pool = [colour for palette in palettes for colour in palette]
    k = len(classes)
    if k <= len(pool):
        colours = pool[:k]
    else:
        colours = [
            pc.sample_colorscale("Turbo", [t])[0]
            for t in np.linspace(0.0, 1.0, k, endpoint=True)
        ]
    return dict(zip(classes, colours))

def pca3d_by_relation_and_plot(reps: np.ndarray,
                               words: List[str],
                               classes_arr: np.ndarray,
                               all_classes: List[str],
                               model_tag: str,
                               html_out: Path):
    """
    Interactive 3D PCA scatter of token vectors, one layer at a time, with a layer slider.

    Args:
        reps: (L, N, D) token vectors; PCA is fitted independently for each layer.
        words: N word strings shown in the hover label.
        classes_arr: N relation labels aligned with `words` and reps' second axis.
        all_classes: classes to draw, in legend order. There is one trace per class per
            layer, so the slider toggles blocks of len(all_classes) traces.
        model_tag: model name used in the title.
        html_out: path of the standalone HTML file (Plotly JS loaded from CDN).

    Class masks and hover strings do not depend on the layer. They are built once rather
    than L times, which previously meant re-converting `words` and rebuilding N hover
    strings for every layer.
    """
    L, N, _ = reps.shape
    print(f"PCA plotting on {N:,} tokens across {L} layers…")
    # PCA per layer, each (N, 3)
    Y_layers: List[np.ndarray] = [_pca3d_layer(reps[l]) for l in range(L)]

    cmap = _qual_palette_for_classes(all_classes)
    words_arr = np.asarray(words)
    masks = {c: (classes_arr == c) for c in all_classes}
    hovers = {
        c: ([f"{w} | rel={c}" for w in words_arr[m]] if np.any(m) else [])
        for c, m in masks.items()
    }

    traces = []
    for l, Y in enumerate(Y_layers):
        show_legend = (l == 0)
        for c in all_classes:
            mask = masks[c]
            if np.any(mask):
                x, y, z = Y[mask, 0], Y[mask, 1], Y[mask, 2]
            else:
                x = y = z = []
            traces.append(
                go.Scatter3d(
                    x=x, y=y, z=z,
                    mode="markers",
                    marker=dict(size=2, opacity=0.75, color=cmap[c]),
                    name=f"{c} (Layer {l})",
                    hovertext=hovers[c],
                    # NOTE: not an f-string; contains %{...} tokens for Plotly
                    hovertemplate=(
                        "<b>%{hovertext}</b><br>"
                        "x=%{x:.3f}<br>y=%{y:.3f}<br>z=%{z:.3f}"
                        "<extra></extra>"
                    ),
                    visible=(l == 0),
                    showlegend=show_legend,
                    legendgroup=c
                )
            )

    # Slider: toggle visibility for the traces of the selected layer
    n_per_layer = len(all_classes)
    n_total = n_per_layer * L
    steps = []
    for l in range(L):
        vis = [False] * n_total
        s = l * n_per_layer
        vis[s : s + n_per_layer] = [True] * n_per_layer
        steps.append(dict(
            method="update",
            args=[{"visible": vis},
                  {"title": f"{model_tag} • PCA 3D by relation_type • Layer {l} (drag to rotate)"}],
            label=str(l),
        ))

    sliders = [dict(
        active=0, steps=steps,
        currentvalue={"prefix": "Layer: "},
        pad={"t": 10}
    )]

    layout = go.Layout(
        title=f"{model_tag} • PCA 3D by relation_type • Layer 0 (drag to rotate)",
        scene=dict(xaxis_title="PC1", yaxis_title="PC2", zaxis_title="PC3", aspectmode="data"),
        margin=dict(l=0, r=0, b=0, t=40),
        sliders=sliders,
        showlegend=True,
        legend=dict(title="relation_type (Top‑K)", itemsizing="trace")
    )

    fig = go.Figure(data=traces, layout=layout)
    fig.show()
    html_out = Path(html_out)  # ensure Path-like
    fig.write_html(str(html_out), include_plotlyjs="cdn")
    print("✓ Saved interactive HTML to:", html_out.resolve())

# =============================== DRIVER ===============================
def run_pca3d_relation_type():
    """
    Driver: embed top-K relation tokens with BASELINE and save an interactive per-layer 3D PCA.

    Steps:
      1. Load the top-K relation rows.
      2. Apply the optional per-class cap (PCA_MAX_PER_CLASS).
      3. Embed once.
      4. Run PCA to 3D per layer and write the Plotly HTML to OUT_DIR.

    `reps` is handed to the plotting code in its stored float16 dtype. `_pca3d_layer`
    casts each layer to float32 itself, so the previous full float32 copy of the
    (L, N, D) tensor was redundant. That copy doubled peak memory for identical results.
    """
    # 1) Load top‑K relation classes
    df_sent, rel_df, top_rel = load_relations_topk_from_column(CSV_PATH, top_k=TOP_K_REL)
    classes = list(top_rel)  # keep frequency order
    print(f"✓ plotting subset — {len(rel_df):,} tokens across relation classes {classes}")

    # 2) Optional per-class cap
    raw_df = sample_per_class(rel_df, PCA_MAX_PER_CLASS)

    # 3) Embed once (reps already has missing-token columns removed)
    reps, filled, resolved_id = embed_subset(df_sent, raw_df, BASELINE, WORD_REP_MODE, BATCH_SIZE)
    raw_df = raw_df.reset_index(drop=True).loc[filled].reset_index(drop=True)
    assert reps.shape[1] == len(raw_df), "reps / raw_df misaligned after dropping missing tokens"

    # Labels and words for hover
    cls_arr = raw_df.relation_class.values.astype(str)
    words   = raw_df.word.astype(str).tolist()

    # 4) PCA→3D per layer + Plotly (per-layer float32 cast happens inside _pca3d_layer)
    html_out = OUT_DIR / f"{resolved_id.replace('/','_')}_relation_type_pca3d_layers.html"
    pca3d_by_relation_and_plot(
        reps, words, cls_arr, classes,
        model_tag=resolved_id, html_out=html_out
    )

    # Cleanup
    del reps; gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

# Inside a notebook kernel __name__ is "__main__", so executing this cell embeds the
# corpus and writes the interactive 3D PCA HTML.
if __name__ == "__main__":
    run_pca3d_relation_type()
