In [1]:
from __future__ import annotations
import os, ast, random, inspect
from pathlib import Path
from typing import Dict, List

import numpy as np, pandas as pd, torch
import torch.utils.data as torchdata
from tqdm.auto import tqdm
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt, seaborn as sns
from transformers import AutoTokenizer, AutoModel, BertModel, AutoConfig

from IsoScore import IsoScore
from dadapy import Data
from skdim.id import MLE, MOM, TLE, CorrInt, FisherS, lPCA

2025-10-31 09:55:19.165274: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Correct

In [10]:
# -*- coding: utf-8 -*-
from __future__ import annotations

# ------------------------- CONFIG YOU LIKELY WANT TO TWEAK -------------------------
CSV_PATH         = "en_ewt-ud-train_sentences.csv"
BASELINE         = "bert-base-uncased"
WORD_REP_MODE    = "first"               # {"first","last","mean"}
EXCLUDE_POS      = {"X","SYM","INTJ","PART"}

MIN_PER_SHARD    = 5000                  # words per shard (the "metrics need ≥5k samples" requirement)
# --- correlation points: want ≥ 47 ---
# For BERT-base, L = 13 layers  => set shards to 4 → 13 * 4 = 52 points
SHARDS_K = 4

# --- heavy (kNN-based) metrics are subsampled to at most this many rows per layer ---
HEAVY_MAX_PER_LAYER = 5000

# Size of the sampled word pool. It has to cover every shard, i.e. be at least
# MIN_PER_SHARD * SHARDS_K. It is assigned outright here: the earlier form
# `max(N_TOTAL_CAP, ...)` read N_TOTAL_CAP before it had ever been assigned and
# therefore raised NameError on a fresh kernel.
N_TOTAL_CAP = MIN_PER_SHARD * SHARDS_K   # e.g., 20_000 for 4 shards

BATCH_SIZE       = 16                    # embedding batch size (tune for your GPU/CPU)
RAND_SEED        = 42
DADAPY_GRID_RANGE= 32                    # GRIDE neighborhood range
HEAVY_MIN_PER_L  = 5000                  # per-layer sample floor for heavy metrics
HEAVY_MAX_PER_L  = HEAVY_MAX_PER_LAYER   # legacy alias, kept so both spellings always agree
DUP_JITTER_EPS   = 1e-6                  # tiny noise to avoid duplicate-row warnings

# Plot / output
NORMALIZE_HEATMAP_01 = False             # set True to map [-1,1] -> [0,1] in heatmaps
OUT_DIR  = Path("metric_corr_pos")
OUT_DIR.mkdir(parents=True, exist_ok=True)
PLOT_DIR = OUT_DIR / "plots"
PLOT_DIR.mkdir(exist_ok=True)
CSV_DIR  = OUT_DIR / "tables"
CSV_DIR.mkdir(exist_ok=True)

# -----------------------------------------------------------------------------------

# Stats + multiple testing
from scipy.stats import spearmanr
HAS_STATSMODELS = False
try:
    from statsmodels.stats.multitest import multipletests
    HAS_STATSMODELS = True
except Exception:
    pass

# System / libs
import os, gc, ast, random, inspect
from pathlib import Path
from typing import Dict, List, Tuple, Callable

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
from transformers import AutoTokenizer, AutoModel

# Optional deps
HAS_ISOSCORE = False
try:
    # The first notebook cell imports this package under the capitalised module
    # name `IsoScore`; Python module names are case-sensitive, so try that
    # spelling first instead of silently dropping to the fallback below.
    from IsoScore import IsoScore
    HAS_ISOSCORE = True
except Exception:
    try:
        # Lower-case spelling kept for environments that expose it that way.
        from isoscore import IsoScore
        HAS_ISOSCORE = True
    except Exception:
        class _IsoScoreFallback:
            """Stand-in exposing the same `IsoScore.IsoScore(X)` call shape.

            NOTE: this is NOT the IsoScore algorithm. It returns the ratio of the
            mean covariance eigenvalue to the largest one, clipped to [0, 1].
            """

            @staticmethod
            def IsoScore(X: np.ndarray) -> float:
                # X is treated as (n_samples, n_features): np.cov gets X.T,
                # i.e. one row per feature.
                C = np.cov(X.T, ddof=0)
                ev = np.linalg.eigvalsh(C)  # ascending order, so ev[-1] is the largest
                if ev.mean() <= 0 or ev[-1] <= 0:
                    return 0.0
                return float(np.clip(ev.mean() / ev[-1], 0.0, 1.0))

        IsoScore = _IsoScoreFallback()

HAS_DADAPY = False
try:
    from dadapy import Data
    HAS_DADAPY = True
except Exception:
    pass

HAS_SKDIM = False
try:
    from skdim.id import (
        MOM, TLE, CorrInt, FisherS, lPCA, MLE, ESS, MADA, KNN
    )
    HAS_SKDIM = True
except Exception:
    pass

# Repro & device
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Seed every RNG this notebook touches with the same configured value.
random.seed(RAND_SEED)
np.random.seed(RAND_SEED)
torch.manual_seed(RAND_SEED)

# Prefer the GPU when one is visible; cuDNN autotuning is only enabled there.
if torch.cuda.is_available():
    device = "cuda"
    torch.backends.cudnn.benchmark = True
else:
    device = "cpu"

# Global plot appearance.
sns.set_style("darkgrid")
plt.rcParams["figure.dpi"] = 120


# ------------------------------ utility helpers ------------------------------
def _model_tag(name: str) -> str:
    return name.split("/")[-1].replace(":", "_")

def _to_list(x):
    if isinstance(x, str) and x.startswith("["):
        try:
            return ast.literal_eval(x)
        except Exception:
            return []
    return x

def _center(X): return X - X.mean(0, keepdims=True)


def _has_dupes(X: np.ndarray) -> bool:
    try:
        return np.unique(X, axis=0).shape[0] < X.shape[0]
    except Exception:
        return False

def _jitter_unique(X: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """
    Return X with exact duplicates removed and (if needed) tiny noise added
    to break remaining ties. All operations in float32 to avoid float16 ties.
    """
    X = X.astype(np.float32, copy=False)

    # Remove exact duplicates first (cheap and stable)
    try:
        Xu = np.unique(X, axis=0)
    except Exception:
        Xu = X  # fallback if axis=0 unique unsupported (older numpy)

    # If we dropped many rows or still suspect ties, add tiny jitter
    if Xu.shape[0] < X.shape[0] or _has_dupes(Xu):
        noise = np.random.normal(scale=eps, size=Xu.shape).astype(np.float32)
        Xu = Xu + noise
        # Optionally enforce uniqueness again (mostly for peace of mind)
        try:
            Xu = np.unique(Xu, axis=0)
        except Exception:
            pass
    return Xu
def _prep_for_knn(
    X: np.ndarray,
    cap: int | None,
    rng: np.random.Generator = np.random.default_rng(42),
    jitter_scale: float = 1e-6,
) -> np.ndarray:
    """
    1) Make contiguous float32
    2) Drop exact duplicate rows
    3) Optionally subsample to 'cap'
    4) Add tiny jitter big enough for float32 to break remaining ties
    """
    X = np.ascontiguousarray(X, dtype=np.float32)

    # Drop exact duplicates first (fast, low‑mem)
    X = np.unique(X, axis=0)

    # Subsample AFTER dedup (so cap doesn't reintroduce dup indices)
    if (cap is not None) and (X.shape[0] > cap):
        idx = rng.choice(X.shape[0], cap, replace=False)
        X = X[idx]

    # Add small jitter (scaled to global std; ensures it survives float32)
    std = float(X.std()) or 1.0
    eps = jitter_scale * std
    X = X + rng.normal(0.0, eps, size=X.shape).astype(np.float32)
    return X


def load_all_words(csv_path: str, exclude_pos: set[str] = EXCLUDE_POS) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Read the sentence CSV and list every word whose POS tag is kept.

    Args:
        csv_path: CSV file providing the columns sentence_id, tokens and pos.
        exclude_pos: POS tags to skip; None keeps every word.

    Returns:
        (df, all_words): the sentence table with sentence_id cast to str and
        tokens/pos passed through `_to_list`, and a table with one
        (sentence_id, word_id) row per kept word, where word_id is the
        position of the tag inside the sentence's pos list.
    """
    df = pd.read_csv(csv_path, usecols=["sentence_id","tokens","pos"])
    df["sentence_id"] = df["sentence_id"].astype(str)
    for column in ("tokens", "pos"):
        df[column] = df[column].apply(_to_list)

    kept = [
        (sid, wid)
        for sid, poss in zip(df["sentence_id"], df["pos"])
        for wid, tag in enumerate(poss)
        if (exclude_pos is None) or (tag not in exclude_pos)
    ]
    all_words = pd.DataFrame(kept, columns=["sentence_id","word_id"])
    return df, all_words

def pick_words_for_shards(all_words: pd.DataFrame,
                          shards_k: int,
                          min_per_shard: int,
                          cap: int | None = N_TOTAL_CAP) -> list[pd.DataFrame]:
    """Return a list of shard DataFrames, each with exactly min_per_shard rows.

    Args:
        all_words: candidate rows to draw from.
        shards_k: number of shards to build.
        min_per_shard: number of rows in every shard.
        cap: size of the random pool drawn before sharding. It is raised to
            shards_k * min_per_shard when smaller or None, and lowered to
            len(all_words) when larger.

    Returns:
        shards_k DataFrames with a fresh 0..n-1 index, taken from disjoint
        slices of the shuffled pool.

    Raises:
        ValueError: if all_words has fewer than shards_k * min_per_shard rows.
    """
    total_needed = shards_k * min_per_shard
    if cap is None or cap < total_needed:
        cap = total_needed
    if len(all_words) < total_needed:
        raise ValueError(f"Not enough tokens: need {total_needed}, have {len(all_words)}.")
    # Sampling without replacement cannot draw more rows than exist; an
    # oversized cap would otherwise make .sample() raise although there are
    # enough tokens for every shard.
    cap = min(cap, len(all_words))

    pool = all_words.sample(cap, random_state=RAND_SEED).reset_index(drop=True)
    idx = np.arange(len(pool))
    np.random.default_rng(RAND_SEED).shuffle(idx)
    shards = []
    for s in range(shards_k):
        start = s * min_per_shard
        end   = start + min_per_shard
        sel   = pool.iloc[idx[start:end]].reset_index(drop=True)
        shards.append(sel)
    return shards

# ------------------------------ embedding (per shard) ------------------------------
def embed_words(df_sent: pd.DataFrame, words_df: pd.DataFrame,
                baseline: str = BASELINE, word_rep_mode: str = WORD_REP_MODE,
                batch_size: int = BATCH_SIZE) -> np.ndarray:
    """
    Embed ONLY the words in words_df and return reps of shape (L, N, D) for this shard.

    Args:
        df_sent: sentence table with a `sentence_id` column and a `tokens`
            column holding each sentence as a list of words.
        words_df: one row per word to embed. Must have a `sentence_id` column
            and exactly two columns, read positionally as (sentence id, word
            index). The frame is not modified.
        baseline: model name passed to AutoTokenizer / AutoModel.
        word_rep_mode: "first" or "last" picks that sub-token of the word; any
            other value averages over the word's sub-tokens.
        batch_size: number of sentences per forward pass.

    Returns:
        float16 array of shape (L, N_filled, D). L is the configured number of
        hidden layers plus one, D the hidden size. Words for which the
        tokenizer produced no sub-token are dropped, so N_filled can be smaller
        than len(words_df).
    """
    # Build sid -> list[(gidx, wid)]
    # Cast the ids on a private copy: assigning into `words_df` itself would
    # silently change the caller's frame.
    words_local = words_df.assign(sentence_id=words_df["sentence_id"].astype(str))
    by_sid: Dict[str, List[Tuple[int,int]]] = {}
    for gidx, (sid, wid) in enumerate(words_local.itertuples(index=False)):
        by_sid.setdefault(sid, []).append((gidx, int(wid)))

    # One row per needed sentence, in the same order as `sids`.
    sids = list(by_sid.keys())
    df_sel = (df_sent[df_sent.sentence_id.isin(sids)]
              .drop_duplicates("sentence_id")
              .set_index("sentence_id")
              .loc[sids])

    tokzr = AutoTokenizer.from_pretrained(baseline, use_fast=True)
    enc_kwargs = dict(is_split_into_words=True, return_tensors="pt", padding=True)
    if "add_prefix_space" in inspect.signature(tokzr.__call__).parameters:
        enc_kwargs["add_prefix_space"] = True

    model = AutoModel.from_pretrained(baseline, output_hidden_states=True).eval().to(device)
    if device == "cuda": model.half()

    # model dims
    L = (getattr(model.config, "num_hidden_layers", None) or getattr(model.config, "n_layer", 0)) + 1
    D = (getattr(model.config, "hidden_size", None)     or getattr(model.config, "n_embd", 0))
    N = len(words_df)

    reps   = np.zeros((L, N, D), np.float16)   # stored in half precision to bound memory
    filled = np.zeros(N, dtype=bool)           # which word slots actually received a vector

    tag = _model_tag(baseline)
    with torch.no_grad(), torch.cuda.amp.autocast(device == "cuda"):
        for start in tqdm(range(0, len(sids), batch_size), desc=f"Embed ({tag}) shard"):
            batch_ids    = sids[start : start + batch_size]
            batch_tokens = df_sel.loc[batch_ids, "tokens"].tolist()

            enc_be = tokzr(batch_tokens, **enc_kwargs)
            enc_t  = {k: v.to(device) for k, v in enc_be.items()}
            out = model(**enc_t)
            h = torch.stack(out.hidden_states).detach().cpu().numpy().astype(np.float32)  # (L,B,T,D)

            for b, sid in enumerate(batch_ids):
                # word index -> positions of its sub-tokens in the encoded sequence
                mp = {}
                for tidx, wid in enumerate(enc_be.word_ids(b)):
                    if wid is not None:
                        mp.setdefault(int(wid), []).append(int(tidx))
                for gidx, wid in by_sid.get(sid, []):
                    toks = mp.get(wid)
                    if not toks: continue   # no sub-token for this word; slot stays unfilled
                    if word_rep_mode == "first":
                        vec = h[:, b, toks[0], :]
                    elif word_rep_mode == "last":
                        vec = h[:, b, toks[-1], :]
                    else:
                        vec = h[:, b, toks, :].mean(axis=1)
                    reps[:, gidx, :] = vec.astype(np.float16, copy=False)
                    filled[gidx] = True

            del enc_be, enc_t, out, h
            if device == "cuda": torch.cuda.empty_cache()

    miss = int((~filled).sum())
    if miss:
        print(f"⚠ Missing vectors for {miss} of {N} sampled words in this shard (dropping them).")
        reps = reps[:, filled]
    del model; gc.collect()
    if device == "cuda": torch.cuda.empty_cache()
    return reps  # (L, N_filled, D)

# ------------------------------ spectral helpers & metrics ------------------------------
# Small positive constant added inside logarithms and to denominators of the
# spectral metrics defined after it, so they stay finite for zero eigenvalues.
EPS = 1e-12

def _eigvals_from_X(X: np.ndarray) -> np.ndarray:
    """Squared singular values of the centred data, largest first.

    Returns a float64 vector, or an empty float64 array if the decomposition
    raises. The values are not divided by the number of samples.
    """
    centred = _center(X.astype(np.float32, copy=False))
    try:
        singular_values = np.linalg.svd(centred, full_matrices=False)[1]
        squared = (singular_values**2).astype(np.float64)
        descending = np.sort(squared)[::-1]
    except Exception:
        descending = np.array([], dtype=np.float64)
    return descending

def iso_per_layer(rep: np.ndarray) -> np.ndarray:
    """Compute `IsoScore.IsoScore` for every layer of `rep`.

    Args:
        rep: array whose first axis indexes layers; each `rep[l]` is handed to
            the scorer as float32.

    Returns:
        float32 vector with one score per layer; NaN where the scorer raised.

    NOTE(review): `rep[l]` is passed as produced by the embedding step (one row
    per word). Confirm this is the orientation the installed IsoScore package
    expects; the fallback class in this file treats rows as samples.
    """
    L = rep.shape[0]
    out = np.full(L, np.nan, np.float32)
    for l in range(L):
        try:
            out[l] = float(IsoScore.IsoScore(rep[l].astype(np.float32)))
        except Exception:
            # Best effort: a failing layer stays NaN. Only `Exception` is
            # caught, so Ctrl-C (KeyboardInterrupt) still stops a long run.
            pass
    return out

def spectral_flatness_per_layer(rep: np.ndarray) -> np.ndarray:
    """Per-layer ratio of the geometric to the arithmetic mean of the spectrum.

    EPS is added inside the log and to the arithmetic mean. Layers whose
    spectrum could not be computed stay NaN.
    """
    n_layers = rep.shape[0]
    out = np.full(n_layers, np.nan, np.float32)
    for layer in range(n_layers):
        lam = _eigvals_from_X(rep[layer])
        if not lam.size:
            continue
        geometric_mean = np.exp(np.mean(np.log(lam + EPS)))
        arithmetic_mean = float(lam.mean() + EPS)
        out[layer] = float(geometric_mean / arithmetic_mean)
    return out

def vmf_kappa_per_layer(rep: np.ndarray) -> np.ndarray:
    """Per-layer concentration estimate kappa = R (d - R^2) / (1 - R^2).

    Rows are scaled to unit length first; R is the norm of their mean and d
    the number of columns. Layers with fewer than two rows stay NaN, and
    kappa is reported as 0 when R is below 1e-9.
    """
    n_layers = rep.shape[0]
    out = np.full(n_layers, np.nan, np.float32)
    for layer in range(n_layers):
        X = rep[layer].astype(np.float32, copy=False)
        if X.shape[0] >= 2:
            Xn = X / (np.linalg.norm(X, axis=1, keepdims=True) + 1e-9)
            R = np.linalg.norm(Xn.mean(axis=0))
            d = Xn.shape[1]
            if R < 1e-9:
                kappa = 0.0
            else:
                kappa = R * (d - R**2) / (1.0 - R**2 + 1e-9)
            out[layer] = float(max(kappa, 0.0))
    return out

def erank_per_layer(rep: np.ndarray) -> np.ndarray:
    """Per-layer exp(entropy) of the spectrum normalised to sum to one.

    Layers whose spectrum could not be computed stay NaN.
    """
    n_layers = rep.shape[0]
    out = np.full(n_layers, np.nan, np.float32)
    for layer in range(n_layers):
        lam = _eigvals_from_X(rep[layer])
        if not lam.size:
            continue
        weights = lam / (lam.sum() + EPS)
        entropy = -(weights * np.log(weights + EPS)).sum()
        out[layer] = float(np.exp(entropy))
    return out

def pr_per_layer(rep: np.ndarray) -> np.ndarray:
    """Per-layer participation ratio (sum lam)^2 / (sum lam^2 + EPS).

    Layers whose spectrum could not be computed stay NaN.
    """
    n_layers = rep.shape[0]
    out = np.full(n_layers, np.nan, np.float32)
    for layer in range(n_layers):
        lam = _eigvals_from_X(rep[layer])
        if not lam.size:
            continue
        total = lam.sum()
        total_of_squares = (lam**2).sum()
        out[layer] = float((total**2) / (total_of_squares + EPS))
    return out

def stable_rank_per_layer(rep: np.ndarray) -> np.ndarray:
    """Per-layer sum of the spectrum divided by its largest value (plus EPS).

    Layers whose spectrum could not be computed stay NaN.
    """
    n_layers = rep.shape[0]
    out = np.full(n_layers, np.nan, np.float32)
    for layer in range(n_layers):
        lam = _eigvals_from_X(rep[layer])
        if not lam.size:
            continue
        out[layer] = float(lam.sum() / (lam.max() + EPS))
    return out

def _cap_layer(X: np.ndarray, cap: int | None = HEAVY_MAX_PER_LAYER) -> np.ndarray:
    """Return X as float32, randomly subsampled to `cap` rows when it has more.

    The subsample is drawn without replacement from a generator seeded with
    RAND_SEED on each call. `cap=None` disables subsampling.
    """
    n_rows = X.shape[0]
    if cap is not None and n_rows > cap:
        keep = np.random.default_rng(RAND_SEED).choice(n_rows, int(cap), replace=False)
        X = X[keep]
    return X.astype(np.float32, copy=False)

def twonn_per_layer(rep: np.ndarray) -> np.ndarray:
    """Per-layer intrinsic dimension from dadapy's `compute_id_2NN`.

    Returns all-NaN when dadapy is unavailable. A layer stays NaN when fewer
    than three points remain after preprocessing or when the estimator raises.
    """
    n_layers = rep.shape[0]
    out = np.full(n_layers, np.nan, np.float32)
    if not HAS_DADAPY:
        return out

    # One generator for the whole call, so successive layers draw different samples.
    rng = np.random.default_rng(RAND_SEED)
    for layer in range(n_layers):
        X = _prep_for_knn(rep[layer], cap=HEAVY_MAX_PER_LAYER, rng=rng)
        if X.shape[0] >= 3:
            try:
                id_est, _, _ = Data(coordinates=X).compute_id_2NN()
                out[layer] = float(id_est)
            except Exception:
                pass
    return out

def gride_per_layer(rep: np.ndarray, range_max: int = DADAPY_GRID_RANGE) -> np.ndarray:
    """Per-layer intrinsic dimension from dadapy's `return_id_scaling_gride`.

    The last entry of the returned id sequence is kept. Returns all-NaN when
    dadapy is unavailable. A layer stays NaN when fewer than 20 points remain
    after preprocessing or when the estimator raises.
    """
    n_layers = rep.shape[0]
    out = np.full(n_layers, np.nan, np.float32)
    if not HAS_DADAPY:
        return out

    # One generator for the whole call, so successive layers draw different samples.
    rng = np.random.default_rng(RAND_SEED)
    for layer in range(n_layers):
        X = _prep_for_knn(rep[layer], cap=HEAVY_MAX_PER_LAYER, rng=rng)
        if X.shape[0] >= 20:
            try:
                cloud = Data(coordinates=X)
                cloud.compute_distances(maxk=range_max)
                ids, _, _ = cloud.return_id_scaling_gride(range_max=range_max)
                out[layer] = float(ids[-1])
            except Exception:
                pass
    return out

def skdim_layer(rep: np.ndarray, est_name: str) -> np.ndarray:
    """Per-layer intrinsic dimension from the scikit-dimension estimator `est_name`.

    Returns all-NaN when skdim is unavailable or `est_name` is not a known key.
    A layer stays NaN when fewer than three points remain after preprocessing
    or when fitting raises. The value read is the fitted `dimension_` attribute.
    """
    n_layers = rep.shape[0]
    out = np.full(n_layers, np.nan, np.float32)
    if not HAS_SKDIM:
        return out

    # Factories rather than instances: every layer gets a fresh estimator.
    factories = {
        "mom":     lambda: MOM(),
        "tle":     lambda: TLE(),
        "corrint": lambda: CorrInt(),
        "fishers": lambda: FisherS(),
        "lpca":    lambda: lPCA(ver="FO"),
        "lpca99":  lambda: lPCA(ver="ratio", alphaRatio=0.99),
        "lpca95":  lambda: lPCA(ver="ratio", alphaRatio=0.95),
        "mle":     lambda: MLE(),
        "ess":     lambda: ESS(),
        "mada":    lambda: MADA(),
    }
    if est_name not in factories:
        return out
    build = factories[est_name]

    rng = np.random.default_rng(RAND_SEED)
    for layer in range(n_layers):
        X = _prep_for_knn(rep[layer], cap=HEAVY_MAX_PER_LAYER, rng=rng)
        if X.shape[0] >= 3:
            try:
                estimator = build()
                estimator.fit(X)
                out[layer] = float(getattr(estimator, "dimension_", np.nan))
            except Exception:
                pass
    return out


# registry
# Metric key -> function mapping a (layers, points, dims) array to one value per layer.
# The skdim-backed entries all go through skdim_layer with their own key bound in.
METRIC_FUNS: Dict[str, Callable[[np.ndarray], np.ndarray]] = {
    "iso": iso_per_layer,
    "sf": spectral_flatness_per_layer,
    "vmf_kappa": vmf_kappa_per_layer,   # (anisotropy↑)
    "erank": erank_per_layer,
    "pr": pr_per_layer,
    "stable_rank": stable_rank_per_layer,
    "twonn": twonn_per_layer,
    "gride": gride_per_layer,
    **{
        _est: (lambda R, _est=_est: skdim_layer(R, _est))
        for _est in ("mom", "tle", "corrint", "fishers", "mle",
                     "ess", "mada", "lpca", "lpca99", "lpca95")
    },
}

# Metric key -> human-readable label used on plot axes.
LABELS = {
    "iso":         "IsoScore",
    "sf":          "Spectral Flatness",
    "vmf_kappa":   "vMF κ",
    "erank":       "Effective Rank",
    "pr":          "Participation Ratio",
    "stable_rank": "Stable Rank",
    "twonn":       "TwoNN ID",
    "gride":       "GRIDE ID",
    "mom":         "MOM",
    "tle":         "TLE",
    "corrint":     "CorrInt",
    "fishers":     "FisherS",
    "mle":         "MLE",
    "ess":         "ESS",
    "mada":        "MADA",
    "lpca":        "lPCA (FO)",
    "lpca99":      "lPCA 0.99",
    "lpca95":      "lPCA 0.95",
}

# Family name -> ordered metric keys; drives the grouping of heatmap rows/columns.
FAMILIES = {
    "isotropy": [
        "iso", "sf", "vmf_kappa",
    ],
    "linear_id": [
        "erank", "pr", "stable_rank",
        "lpca", "lpca95", "lpca99",
    ],
    "nonlinear": [
        "twonn", "gride", "mom", "tle", "corrint",
        "fishers", "mle", "ess", "mada",
    ],
}

# ------------------------------ correlation & plotting ------------------------------
def compute_metrics_for_shard(rep: np.ndarray, shard_idx: int) -> pd.DataFrame:
    """Evaluate every registered metric on one shard.

    Returns a long-format table with the columns layer, shard, metric and
    value (one row per metric and layer). Values that are None or non-finite
    are stored as NaN.
    """
    records = [
        {
            "layer": layer,
            "shard": shard_idx,
            "metric": metric,
            "value": (float(v) if (v is not None and np.isfinite(v)) else np.nan),
        }
        for metric, fn in METRIC_FUNS.items()
        for layer, v in enumerate(fn(rep).tolist())
    ]
    return pd.DataFrame(records)

def build_spearman_corr_and_pvals(df_long: pd.DataFrame) -> tuple[pd.DataFrame, pd.DataFrame, int]:
    """Spearman correlation between metrics across (layer, shard) points.

    Args:
        df_long: long-format table with the columns layer, shard, metric, value.

    Returns:
        (rho, p, n_points): square metric-by-metric tables of correlation
        coefficients and p-values, and the number of (layer, shard) rows used.
        Metrics that are entirely NaN are dropped, then every row with any NaN
        is dropped. With fewer than two metrics or two rows, both tables are
        filled with NaN.
    """
    dfp = df_long.pivot_table(index=["layer","shard"], columns="metric",
                              values="value", aggfunc="mean")
    dfp = dfp.dropna(axis=1, how="all").dropna(axis=0, how="any")
    n_points = dfp.shape[0]
    if dfp.shape[1] < 2 or n_points < 2:
        cols = dfp.columns
        return (pd.DataFrame(np.nan, index=cols, columns=cols),
                pd.DataFrame(np.nan, index=cols, columns=cols), n_points)

    res = spearmanr(dfp.values, axis=0, nan_policy='omit')
    rho_arr = np.asarray(res.statistic, dtype=float)
    p_arr = np.asarray(res.pvalue, dtype=float)
    if rho_arr.ndim == 0:
        # With exactly two columns spearmanr returns scalars, not 2x2 matrices.
        # Handing a scalar to DataFrame would broadcast it over the diagonal
        # too, so rebuild the matrices explicitly (diagonal: rho = 1, p = 0).
        r, pv = float(rho_arr), float(p_arr)
        rho_arr = np.array([[1.0, r], [r, 1.0]])
        p_arr = np.array([[0.0, pv], [pv, 0.0]])

    rho = pd.DataFrame(rho_arr, index=dfp.columns, columns=dfp.columns)
    p   = pd.DataFrame(p_arr,   index=dfp.columns, columns=dfp.columns)
    return rho, p, n_points

def fdr_bh_matrix(pval_df: pd.DataFrame, alpha: float = 0.05) -> tuple[pd.DataFrame, pd.DataFrame | None]:
    """Benjamini-Hochberg adjustment of the off-diagonal entries of a p-value matrix.

    Args:
        pval_df: square table of p-values.
        alpha: level passed to `multipletests`.

    Returns:
        (q, reject): adjusted values (NaN on the diagonal and wherever the
        input p-value is not finite) and a boolean rejection table. Without
        statsmodels the result is (copy of pval_df, None).
    """
    if not HAS_STATSMODELS:
        return pval_df.copy(), None

    p = pval_df.to_numpy(dtype=float, copy=True)
    mask = np.ones_like(p, dtype=bool)
    np.fill_diagonal(mask, False)
    # Non-finite p-values (e.g. from a constant metric) are left out of the
    # correction so they cannot distort the adjusted values of valid tests.
    mask &= np.isfinite(p)
    r, c = np.where(mask)

    q = np.full(p.shape, np.nan, dtype=float)
    rej = np.zeros(p.shape, dtype=bool)
    if r.size:
        reject, q_vec, *_ = multipletests(p[r, c], alpha=alpha, method="fdr_bh")
        q[r, c] = q_vec
        rej[r, c] = reject
    return (pd.DataFrame(q, index=pval_df.index, columns=pval_df.columns),
            pd.DataFrame(rej, index=pval_df.index, columns=pval_df.columns))

def _grouped_order(cols: list[str]) -> tuple[list[str], list[int]]:
    """Order metric keys family by family.

    Returns (order, bounds): the keys grouped as isotropy, linear_id,
    nonlinear and finally any remaining keys, plus the cumulative length of
    `order` after each non-empty group. Non-string entries are ignored.
    """
    present = [c for c in cols if isinstance(c, str)]
    order: list[str] = []
    bounds: list[int] = []

    for family in ("isotropy","linear_id","nonlinear"):
        members = [m for m in FAMILIES[family] if m in present]
        if not members:
            continue
        order += members
        bounds.append(len(order))

    # Keys that belong to no family go last, in their incoming order.
    leftovers = [c for c in present if c not in order]
    if leftovers:
        order += leftovers
        bounds.append(len(order))
    return order, bounds

def _reorder_grouped(rho_df: pd.DataFrame, pval_df: pd.DataFrame | None):
    """Reindex the correlation (and optional p-value) table into family order.

    Returns (rho, p_or_None, pretty, bounds), where `pretty` maps each metric
    key to its display label (the key itself when no label is registered) and
    `bounds` are the cumulative group sizes from `_grouped_order`.
    """
    order, bounds = _grouped_order(list(rho_df.columns))
    rho_sorted = rho_df.loc[order, order]
    if pval_df is None:
        p_sorted = None
    else:
        p_sorted = pval_df.loc[order, order]
    pretty = {}
    for key in order:
        pretty[key] = LABELS.get(key, key)
    return rho_sorted, p_sorted, pretty, bounds

def _draw_family_boxes(ax, bounds: list[int], lw: float=2.0):
    """Outline each diagonal block delimited by the cumulative `bounds`.

    An unfilled square is added to `ax` for every block of positive size.
    """
    previous = 0
    for edge in bounds:
        side = edge - previous
        if side > 0:
            ax.add_patch(
                Rectangle((previous, previous), width=side, height=side,
                          fill=False, lw=lw)
            )
        previous = edge

def plot_heatmap_grouped(corr: pd.DataFrame, title: str, out: Path,
                         pvals: pd.DataFrame | None = None, alpha: float = 0.05,
                         boxes: list[int] | None = None, normalize01: bool = NORMALIZE_HEATMAP_01):
    """Draw an annotated correlation heatmap and save it to `out`.

    Args:
        corr: square correlation table to plot.
        title: axes title.
        out: destination image path.
        pvals: optional table aligned with `corr`; off-diagonal cells whose
            value is finite and below `alpha` get one to three stars.
        alpha: significance threshold for the stars.
        boxes: optional cumulative block sizes to outline along the diagonal.
        normalize01: rescale values from [-1, 1] to [0, 1] before plotting.
    """
    data = corr.copy()
    vmin, vmax = -1.0, 1.0
    if normalize01:
        data = (data + 1.0) / 2.0  # map [-1,1] -> [0,1]
        vmin, vmax = 0.0, 1.0

    side = 1.0+0.55*len(data.columns)
    plt.figure(figsize=(side, side))
    ax = sns.heatmap(data, vmin=vmin, vmax=vmax, annot=True, fmt=".2f",
                     square=True, cbar=True, annot_kws={"size":8})
    ax.set_title(title)

    # significance stars
    if pvals is not None and not pvals.empty:
        n_cells = len(data)
        for row in range(n_cells):
            for col in range(n_cells):
                if row == col:
                    continue
                p = pvals.iat[row, col]
                if not (np.isfinite(p) and p < alpha):
                    continue
                if p < 1e-3:
                    stars = "***"
                elif p < 1e-2:
                    stars = "**"
                else:
                    stars = "*"
                ax.text(col + 0.5, row + 0.5, stars, ha="center", va="bottom",
                        fontsize=9, fontweight="bold")

    # family boxes
    if boxes:
        _draw_family_boxes(ax, boxes, lw=2.0)

    plt.tight_layout()
    plt.savefig(out, dpi=220)
    plt.close()

# ------------------------------ main driver (streaming by shards) ------------------------------
def run() -> None:
    """Run the whole pipeline: sample words, embed shard by shard, score, correlate, plot.

    Steps, in order:
      1. Load the corpus and list the words that survive the POS filter.
      2. Draw SHARDS_K shards of MIN_PER_SHARD words each.
      3. For every shard: embed it, compute all metrics per layer, then free
         the embeddings before moving on.
      4. Save the long-format metric table, the Spearman rho / p-value tables
         and, when statsmodels is available, the FDR-adjusted values.
      5. Save two heatmaps, without and with significance stars.

    Raises:
        RuntimeError: if a shard comes back from embedding with fewer than
            HEAVY_MIN_PER_L words.

    NOTE(review): shards hold exactly MIN_PER_SHARD rows and embed_words drops
    words it could not map to a sub-token, so with MIN_PER_SHARD equal to
    HEAVY_MIN_PER_L a single dropped word triggers the RuntimeError — confirm
    that this strictness is intended.
    """
    tag = _model_tag(BASELINE)
    print("• Loading corpus …")
    df_sent, all_words = load_all_words(CSV_PATH, EXCLUDE_POS)
    print(f"  tokens available after POS filter: {len(all_words):,}")

    # Build shards with exactly MIN_PER_SHARD words each
    shards = pick_words_for_shards(all_words, SHARDS_K, MIN_PER_SHARD, cap=N_TOTAL_CAP)
    print(f"• Using {SHARDS_K} shards × {MIN_PER_SHARD} words = {SHARDS_K*MIN_PER_SHARD:,} tokens")

    # One long-format metric table per shard; concatenated after the loop.
    df_long_parts = []
    for s, words_s in enumerate(shards):
        print(f"\n=== Shard {s+1}/{SHARDS_K} (N={len(words_s):,}) ===")
        reps = embed_words(df_sent, words_s, BASELINE, WORD_REP_MODE, BATCH_SIZE)  # (L, N, D) for this shard
        L, N, D = reps.shape
        # N can be smaller than len(words_s): embed_words drops unmapped words.
        if N < HEAVY_MIN_PER_L:
            raise RuntimeError(f"Shard {s}: need ≥{HEAVY_MIN_PER_L} tokens for heavy metrics; got {N}.")
        df_s = compute_metrics_for_shard(reps, s)
        df_long_parts.append(df_s)
        # free shard memory
        del reps, df_s; gc.collect()
        if device == "cuda": torch.cuda.empty_cache()

    # concat all shards
    df_long = pd.concat(df_long_parts, ignore_index=True)
    df_long["model"] = BASELINE
    df_long["word_rep_mode"] = WORD_REP_MODE
    df_long.to_csv(CSV_DIR / f"metrics_all_long_{tag}.csv", index=False)
    print(f"\n✓ metrics saved → {CSV_DIR / f'metrics_all_long_{tag}.csv'}")

    # Correlation (global, across all metrics)
    rho_df, pval_df, n_pts = build_spearman_corr_and_pvals(df_long)
    rho_df.to_csv(CSV_DIR / f"corr_spearman_{tag}.csv")
    pval_df.to_csv(CSV_DIR / f"corr_spearman_pvals_{tag}.csv")
    print(f"Correlation points (layers × shards with complete rows): n = {n_pts}")

    # Optional FDR
    q_df, rej_mat = fdr_bh_matrix(pval_df, alpha=0.05)
    if HAS_STATSMODELS:
        q_df.to_csv(CSV_DIR / f"corr_spearman_qvals_fdrbh_{tag}.csv")

    # Grouped + pretty labels + boxes around families
    # Stars are based on the FDR-adjusted values when available, else on raw p-values.
    rho_g, p_for_plot, pretty, bounds = _reorder_grouped(rho_df, (q_df if HAS_STATSMODELS else pval_df))
    rho_g_pretty = rho_g.rename(index=pretty, columns=pretty)
    p_plot_pretty = (p_for_plot.rename(index=pretty, columns=pretty) if p_for_plot is not None else None)

    # Plain heatmap
    plot_heatmap_grouped(
        rho_g_pretty,
        f"Spearman correlation • {tag}",
        PLOT_DIR / f"corr_spearman_{tag}.png",
        pvals=None, boxes=bounds, normalize01=NORMALIZE_HEATMAP_01
    )
    # With significance
    plot_heatmap_grouped(
        rho_g_pretty,
        f"Spearman correlation • {tag} (significance)",
        PLOT_DIR / f"corr_spearman_with_sig_{tag}.png",
        pvals=p_plot_pretty, boxes=bounds, normalize01=NORMALIZE_HEATMAP_01
    )
    print("✓ plots stored in", PLOT_DIR.resolve())


import math




if __name__ == "__main__":
    run()
MIN_POINTS = 47
# ...
L, N, D = reps.shape
shards_needed = max(1, math.ceil(MIN_POINTS / L))
df_long = compute_metrics_over_shards(reps, shards_needed)

• Loading corpus …
  tokens available after POS filter: 189,167
• Using 4 shards × 5000 words = 20,000 tokens

=== Shard 1/4 (N=5,000) ===


Embed (bert-base-uncased) shard: 100%|████████| 231/231 [00:05<00:00, 41.25it/s]





KeyboardInterrupt

