# Metric correlation analysis (Spearman)

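This notebook computes the **pairwise Spearman rank correlations** between the geometric metrics used in the paper and saves the following to `outputs/metric_correlations/` (by default):

- `spearman_corr.csv` (correlation coefficients)
- `spearman_pvals.csv` (p-values)
- `spearman_nobs.csv` (number of jointly finite observations behind each coefficient)
- `spearman_corr_with_stars.csv` (pretty table with significance stars: `*` p < 0.05, `**` p < 0.01, `***` p < 0.001)
- `metric_correlation_heatmap.pdf` and `metric_correlation_heatmap.png` (figure)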

**Input:** a `metric_points.csv` file (one row = one random subsample / “point”; columns = metric values).  
The notebook is intentionally lightweight: it does *not* re-embed a model unless you want it to.

> Tip: if your `metric_points.csv` contains multiple sample sizes (column `n`), set `N_FILTER` below to select one.


In [None]:
from pathlib import Path
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

# seaborn is optional; matplotlib-only also works
try:
    import seaborn as sns
    _HAS_SNS = True
except Exception:
    _HAS_SNS = False

from scipy.stats import spearmanr

# -------------------- paths --------------------
# Location of the exported metric points -- edit to match your setup.
POINTS_CSV = Path("metric_points.csv")

# Every table and figure produced below is written into this directory (created on demand).
OUT_DIR = Path("outputs", "metric_correlations")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# -------------------- filtering --------------------
# Subsample size to keep when the file mixes several values of n (e.g. 2000).
# None keeps every row.
N_FILTER = None

# -------------------- plot style --------------------
AXIS_FONTSIZE, TICK_FONTSIZE, TITLE_FONTSIZE = 12, 10, 12

HEATMAP_FIGSIZE = (10, 8)  # adjust for number of metrics
HEATMAP_VMIN = -1.0
HEATMAP_VMAX = 1.0

# Echo the resolved configuration so the run log records where data came from / went to.
for _label, _value in (
    ("POINTS_CSV:", POINTS_CSV.resolve()),
    ("OUT_DIR   :", OUT_DIR.resolve()),
    ("seaborn   :", _HAS_SNS),
):
    print(_label, _value)


In [None]:
# -------------------- load metric points --------------------
# Explicit raise instead of `assert`: assertions are stripped when Python runs with -O.
if not POINTS_CSV.exists():
    raise FileNotFoundError(f"Could not find {POINTS_CSV} (set POINTS_CSV to the right path).")

points = pd.read_csv(POINTS_CSV)

# Optional: if a point_id column exists, make it the index
for idx_col in ["point_id", "id"]:
    if idx_col in points.columns:
        points = points.set_index(idx_col)
        break

# If your generator wrote an 'n' column, optionally filter by n
if "n" in points.columns:
    uniq_n = sorted(points["n"].dropna().unique().tolist())
    print("n values in file:", uniq_n)
    if N_FILTER is not None:
        points = points[points["n"] == int(N_FILTER)].copy()
        print("Filtered to n =", N_FILTER, "-> rows:", len(points))
        # An unmatched N_FILTER would otherwise propagate an empty frame into every later cell.
        if points.empty:
            raise ValueError(f"No rows with n == {N_FILTER}; available n values: {uniq_n}")
    elif len(uniq_n) > 1:
        # Multiple n and no filter: keep all rows but warn
        print("Warning: multiple n present. Consider setting N_FILTER to avoid mixing scales.")
elif N_FILTER is not None:
    # Don't silently ignore a filter the user explicitly asked for.
    print(f"Warning: N_FILTER={N_FILTER} ignored because {POINTS_CSV} has no 'n' column.")

print("rows:", len(points), "cols:", len(points.columns))
points.head()


In [None]:
# -------------------- metric name cleanup --------------------
# These are *display* names used in the plot/table. They don't change your raw data.
RENAME_METRICS = {
    # PCA naming
    "PCA": "PCA@99",
    "pca": "PCA@99",
    "pca99": "PCA@99",
    "PCA (99%)": "PCA@99",

    # skdim lPCA global-fit labels (your usage) -> clearer names
    "lpca99": "PCA@99",
    "lpca95": "PCA@95",
    "lpca": "PCA FO",  # FO rule / eigen-threshold variant
    "lPCA": "PCA FO",
    "lPCA FO": "PCA FO",
    "Local PCA (lPCA FO)": "PCA FO",

    # cosmetics
    "vmf_kappa": "vMF κ",
    "iso": "IsoScore",
    "sf": "Spectral flatness",
    "erank": "Effective rank",
    "pr": "Participation ratio",
    "stable_rank": "Stable rank",
    "twonn": "TwoNN",
    "gride": "GRIDE",
    "corrint": "CorrInt",
    "fishers": "FisherS",
    "mom": "MOM",
    "tle": "TLE",
    "mle": "MLE",
    "ess": "ESS",
    "mada": "MADA",
    "knn": "KNN",
}

# Several raw spellings map to the same display name (e.g. "pca99" and "lpca99" -> "PCA@99").
# If the file contains more than one of them, renaming would create duplicate column labels;
# `points[c]` would then return a DataFrame, fail the numeric-dtype test below, and the metric
# would be silently dropped. Fail loudly instead.
_display_names = [RENAME_METRICS.get(c, c) for c in points.columns]
_raw_by_display = {}
for _raw, _disp in zip(points.columns, _display_names):
    _raw_by_display.setdefault(_disp, []).append(_raw)
_clashes = {d: raws for d, raws in _raw_by_display.items() if len(raws) > 1}
if _clashes:
    raise ValueError(
        f"Renaming would create duplicate metric columns: {_clashes}. "
        "Drop the redundant raw columns or adjust RENAME_METRICS."
    )

points = points.rename(columns=dict(zip(points.columns, _display_names)))

# Keep only numeric columns that look like metrics (drop metadata columns if present)
META_COLS = {"n", "N", "seed", "layer", "layer_idx", "model", "subset", "class"}
metric_cols = [c for c in points.columns if c not in META_COLS and pd.api.types.is_numeric_dtype(points[c])]

metrics = points[metric_cols].copy()

# Drop all-NaN metrics
all_nan = [c for c in metrics.columns if metrics[c].isna().all()]
if all_nan:
    print("Dropping all-NaN metrics:", all_nan)
    metrics = metrics.drop(columns=all_nan)

# `describe()` raises a cryptic error on a column-less frame; explain what went wrong instead.
if metrics.shape[1] == 0:
    raise ValueError("No numeric metric columns left after removing metadata and all-NaN columns.")

print("Using metric columns:", list(metrics.columns))
metrics.describe().T[["count","mean","std","min","max"]]


In [None]:
# -------------------- metric ordering / families --------------------
# Canonical display order, grouped by family: isotropy scores, linear ID
# estimates, nonlinear ID estimates. Unknown metrics are appended at the end.
ISO_METRICS = ["IsoScore", "Spectral flatness", "vMF κ"]

LINEAR_ID_METRICS = [
    "Effective rank",
    "Participation ratio",
    "Stable rank",
    "PCA FO",
    "PCA@95",
    "PCA@99",
]

NONLINEAR_ID_METRICS = [
    "TwoNN",
    "GRIDE",
    "CorrInt",
    "FisherS",
    "MLE",
    "MOM",
    "TLE",
    "ESS",
    "MADA",
    "KNN",
]


def _keep_present(order, cols):
    """Return the entries of *order* that occur in *cols*, preserving *order*'s sequence.

    Lets the notebook run even when some estimators were not computed.
    """
    available = set(cols)
    return [name for name in order if name in available]


iso_order = _keep_present(ISO_METRICS, metrics.columns)
lin_order = _keep_present(LINEAR_ID_METRICS, metrics.columns)
nl_order  = _keep_present(NONLINEAR_ID_METRICS, metrics.columns)

_family_members = {*iso_order, *lin_order, *nl_order}
_unclassified = [c for c in metrics.columns if c not in _family_members]
ORDER = [*iso_order, *lin_order, *nl_order, *_unclassified]

metrics = metrics[ORDER]

print(f"Order used (len={len(ORDER)}):")
print(ORDER)

# Cumulative family boundaries (column positions), used to draw separators on the heatmap.
_n_iso, _n_lin, _n_nl = len(iso_order), len(lin_order), len(nl_order)
CUTS = {
    "iso_end": _n_iso,
    "lin_end": _n_iso + _n_lin,
    "nl_end":  _n_iso + _n_lin + _n_nl,
}
CUTS


In [None]:
# -------------------- Spearman correlation (pairwise with NaN handling) --------------------
def spearman_pairwise(df: pd.DataFrame):
    """Compute Spearman rank correlations between every pair of columns of *df*.

    Each pair (a, b) uses only the rows where both a and b are finite
    (pairwise deletion), so different cells may rest on different row counts.

    Args:
        df: Frame whose columns are the variables to correlate; values must be
            convertible to float.

    Returns:
        ``(corr_df, pval_df, nobs_df)`` -- three square DataFrames indexed and
        labelled by ``df.columns`` holding Spearman's rho, its p-value, and the
        number of jointly finite rows. All three are symmetric. ``corr_df`` and
        ``pval_df`` are NaN for pairs with fewer than 3 jointly finite rows.
    """
    cols = list(df.columns)
    k = len(cols)
    corr = np.full((k, k), np.nan, dtype=float)
    pval = np.full((k, k), np.nan, dtype=float)
    nobs = np.zeros((k, k), dtype=int)

    # Convert each column and compute its finiteness mask once, outside the pair loop,
    # instead of re-converting both columns for every one of the k*(k+1)/2 pairs.
    values = [df[c].to_numpy(dtype=float) for c in cols]
    finite = [np.isfinite(v) for v in values]

    for i in range(k):
        # Only the upper triangle (incl. diagonal) is computed; results are mirrored.
        for j in range(i, k):
            m = finite[i] & finite[j]
            n = int(m.sum())
            nobs[i, j] = nobs[j, i] = n
            if n < 3:
                continue
            r, p = spearmanr(values[i][m], values[j][m])
            corr[i, j] = corr[j, i] = float(r)
            pval[i, j] = pval[j, i] = float(p)

    corr_df = pd.DataFrame(corr, index=cols, columns=cols)
    pval_df = pd.DataFrame(pval, index=cols, columns=cols)
    nobs_df = pd.DataFrame(nobs, index=cols, columns=cols)
    return corr_df, pval_df, nobs_df

# Correlate every pair of ordered metric columns; the three frames share ORDER as index/columns.
corr, pval, nobs = spearman_pairwise(metrics)
n_metrics_correlated = corr.shape[0]
print("Computed correlations for", n_metrics_correlated, "metrics.")
corr.head(5).iloc[:, :5]


In [None]:
# -------------------- significance stars + save tables --------------------
def p_to_stars(p: float) -> str:
    """Translate a p-value into conventional significance stars.

    Returns ``"***"`` for p < 0.001, ``"**"`` for p < 0.01, ``"*"`` for
    p < 0.05 and ``""`` otherwise -- including non-finite p-values such as
    the NaN produced for pairs with too few observations.
    """
    if np.isfinite(p):
        # Cutoffs listed strictest-first, so the first one that matches wins.
        for cutoff, marker in ((1e-3, "***"), (1e-2, "**"), (5e-2, "*")):
            if p < cutoff:
                return marker
    return ""

# Element-wise map: DataFrame.applymap is deprecated since pandas 2.1 in favour of
# DataFrame.map, which older pandas lacks -- use whichever the installed version provides.
_elementwise = pval.map if hasattr(pd.DataFrame, "map") else pval.applymap
stars = _elementwise(p_to_stars)

# pretty table: "0.87***"
# Built as a fresh object-dtype frame. Writing strings cell-by-cell into a copy of the
# float64 `corr` frame relies on silent dtype upcasting, which pandas >= 2.1 deprecates.
# `corr` and `stars` share the same index/columns (both come from spearman_pairwise).
corr_star = pd.DataFrame(
    [
        [f"{rho:.2f}{star}" if np.isfinite(rho) else "" for rho, star in zip(rho_row, star_row)]
        for rho_row, star_row in zip(corr.to_numpy(), stars.to_numpy())
    ],
    index=corr.index,
    columns=corr.columns,
    dtype=object,
)

# save
_tables = {
    "spearman_corr.csv": corr,
    "spearman_pvals.csv": pval,
    "spearman_nobs.csv": nobs,
    "spearman_corr_with_stars.csv": corr_star,
}
for _fname, _frame in _tables.items():
    _frame.to_csv(OUT_DIR / _fname)

print("Saved:")
for _fname in _tables:
    print(" ", OUT_DIR / _fname)

corr_star.iloc[:8, :8]


In [None]:
# -------------------- plot heatmap --------------------
# Use seaborn if available (nicer labels); otherwise fallback to matplotlib.
heat_df = corr.copy()
n_metrics = len(heat_df.columns)

# Raw string: in a plain literal "\r" is a carriage-return escape, which mangles "$\rho$".
CBAR_LABEL = r"Spearman $\rho$"

if _HAS_SNS:
    # set_style changes rcParams, so it must run before the axes are created to affect them.
    sns.set_style("white")

fig, ax = plt.subplots(figsize=HEATMAP_FIGSIZE)

if _HAS_SNS:
    sns.heatmap(
        heat_df,
        vmin=HEATMAP_VMIN, vmax=HEATMAP_VMAX,
        cmap="coolwarm",
        square=True,
        linewidths=0.5,
        # Label every metric; seaborn's default ("auto") may thin labels out on large grids.
        xticklabels=True,
        yticklabels=True,
        cbar_kws={"label": CBAR_LABEL},
        ax=ax,
    )
    # seaborn draws cell k over [k, k+1], so family boundaries sit on integer coordinates.
    cell_edge_offset = 0.0
else:
    im = ax.imshow(
        heat_df.to_numpy(),
        vmin=HEATMAP_VMIN, vmax=HEATMAP_VMAX,
        cmap="coolwarm",
        aspect="equal",
    )
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label=CBAR_LABEL)
    ax.set_xticks(np.arange(n_metrics))
    ax.set_yticks(np.arange(len(heat_df.index)))
    ax.set_xticklabels(heat_df.columns)
    ax.set_yticklabels(heat_df.index)
    # imshow centres cell k on k (spanning [k-0.5, k+0.5]), so boundaries are shifted by -0.5.
    cell_edge_offset = -0.5

ax.set_title("Metric correlations (Spearman)", fontsize=TITLE_FONTSIZE)

ax.tick_params(axis="x", labelsize=TICK_FONTSIZE)
ax.tick_params(axis="y", labelsize=TICK_FONTSIZE)
# tick_params rejects alignment keywords such as `ha` (ValueError), so rotate and
# right-align the x tick label Text objects directly.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

# family separators (if we have at least 2 families present); dedupe coinciding cuts
for cut in sorted({CUTS["iso_end"], CUTS["lin_end"], CUTS["nl_end"]}):
    if 0 < cut < n_metrics:
        ax.axhline(cut + cell_edge_offset, color="black", linewidth=1.2)
        ax.axvline(cut + cell_edge_offset, color="black", linewidth=1.2)

fig.tight_layout()
fig.savefig(OUT_DIR / "metric_correlation_heatmap.pdf", bbox_inches="tight")
fig.savefig(OUT_DIR / "metric_correlation_heatmap.png", dpi=300, bbox_inches="tight")
plt.show()

print("Saved:")
print(" ", OUT_DIR / "metric_correlation_heatmap.pdf")
print(" ", OUT_DIR / "metric_correlation_heatmap.png")


### Optional: inspect strongest (absolute) correlations

This is useful for sanity-checking that metrics expected to be related (e.g., different spectrum-based proxies) actually cluster together.


In [None]:
# Rank the off-diagonal metric pairs by |rho| and show the 25 strongest.
# Only the upper triangle is scanned, so each unordered pair appears once;
# pairs whose rho is NaN (too few observations) are skipped.
cols = list(corr.columns)
_rho_mat = corr.to_numpy()
_nobs_mat = nobs.to_numpy()
pairs = [
    (abs(_rho_mat[i, j]), _rho_mat[i, j], cols[i], cols[j], int(_nobs_mat[i, j]))
    for i in range(len(cols))
    for j in range(i + 1, len(cols))
    if np.isfinite(_rho_mat[i, j])
]
pairs.sort(reverse=True)
pd.DataFrame(pairs[:25], columns=["abs_rho", "rho", "metric_a", "metric_b", "n_obs"])
