# Cross-SAE Feature Alignment (FeatureMatch, cosine + top-k)

This example shows how to quantify correspondence between two SAE dictionaries using **FeatureMatch**.

🧪 **v0.1 scope**: cosine similarity matrix, top-k per-feature matches, summary stats, and a simple heatmap.
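
To make the v0.1 scope concrete, the sketch below computes the same kinds of quantities in plain PyTorch. It is not FeatureMatch's implementation: the function name `cosine_topk_sketch`, the signed column-wise cosine without mean-centering, and the returned stat names are all assumptions made for illustration.

```python
import torch

def cosine_topk_sketch(Z_a, Z_b, topk=5, threshold=0.8):
    """Illustrative only: column-wise cosine similarity plus top-k matches.

    Z_a: [N, K_a], Z_b: [N, K_b]; rows must be the same examples in both.
    """
    # Normalize each feature (column) to unit length over the N examples.
    A = torch.nn.functional.normalize(Z_a, dim=0, eps=1e-8)
    B = torch.nn.functional.normalize(Z_b, dim=0, eps=1e-8)
    cosine = A.T @ B                          # [K_a, K_b] similarity matrix
    k = min(topk, cosine.shape[1])
    values, indices = cosine.topk(k, dim=1)   # sorted best-first per feature of Z_a
    best = values[:, 0]
    stats = {
        "mean_best": best.mean().item(),
        "frac_above_threshold": (best >= threshold).float().mean().item(),
    }
    return cosine, values, indices, stats

torch.manual_seed(0)
Z = torch.randn(200, 16)
cosine, values, indices, stats = cosine_topk_sketch(Z, Z, topk=3)
print(cosine.shape, stats)
```

Comparing a matrix with itself is a useful smoke test: every feature's best match should be itself, with cosine 1.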

👉 By default this notebook runs on **synthetic data** (so it works anywhere). To use real models, swap the synthetic block for the **SAELens code-collection** cell in Option B.


In [None]:
# If running on Colab or a fresh env, uncomment the line below to install from GitHub:
# !pip install "git+https://github.com/Course-Correct-Labs/featurematch.git"

import torch
from featurematch.featurematch import align_features
from featurematch.viz import plot_heatmap
import matplotlib.pyplot as plt

torch.manual_seed(0)

## Option A: Synthetic demo (default)
This section creates a permutation-aligned pair of code matrices (`Z_a`, `Z_b_perm`) and a random baseline (`Z_b_rand`).

- Expect **perfect alignment** for the permutation case (mean best-match cosine ≈ 1.0).
- Expect **low alignment** for random codes (mean best-match cosine ≈ 0.15–0.20).
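
A quick way to see where these two numbers come from, independent of FeatureMatch: in the permutation case every column of `Z_a` reappears exactly once in `Z_b_perm`, so each feature's best cosine is 1. For independent Gaussian codes, each pairwise cosine is roughly normal with standard deviation 1/√N (about 0.07 for N = 200), and the maximum over K = 64 candidates lands a couple of standard deviations above zero. The helper `best_cosine` below is hypothetical and assumes signed, uncentered column-wise cosine.

```python
import torch

def best_cosine(Z_a, Z_b):
    """Hypothetical helper: best signed cosine per feature of Z_a, and its index in Z_b."""
    A = torch.nn.functional.normalize(Z_a, dim=0)
    B = torch.nn.functional.normalize(Z_b, dim=0)
    return (A.T @ B).max(dim=1)  # (values, indices)

torch.manual_seed(0)
N, K = 200, 64
Z_a = torch.randn(N, K)
perm = torch.randperm(K)
Z_b_perm = torch.empty_like(Z_a)
Z_b_perm[:, perm] = Z_a          # column perm[i] of Z_b_perm is column i of Z_a
Z_b_rand = torch.randn(N, K)     # independent baseline

best_perm, idx_perm = best_cosine(Z_a, Z_b_perm)
best_rand, _ = best_cosine(Z_a, Z_b_rand)

print("permutation recovered:", torch.equal(idx_perm, perm))
print("mean best (permutation):", best_perm.mean().item())
print("mean best (random):", best_rand.mean().item())
```

The random-baseline value is the noise floor for this N and K; it rises with larger K or smaller N, so it is worth recomputing when shapes change.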

In [None]:
N, K = 200, 64
Z_a = torch.randn(N, K)
perm = torch.randperm(K)
# Permutation matrix with P[i, perm[i]] = 1, so column perm[i] of Z_a @ P equals column i of Z_a.
P = torch.zeros(K, K)
P[torch.arange(K), perm] = 1.0
Z_b_perm = Z_a @ P  # permutation case (perfect alignment)
Z_b_rand = torch.randn(N, K)  # random baseline

res_perm = align_features(Z_a, Z_b_perm, topk=5, threshold=0.8, device="cpu")
print("Permutation case stats:", res_perm.stats)
plot_heatmap(res_perm.cosine, title="FeatureMatch: Cosine (Permutation)")
plt.show()

res_rand = align_features(Z_a, Z_b_rand, topk=5, threshold=0.8, device="cpu")
print("Random case stats:", res_rand.stats)
plot_heatmap(res_rand.cosine, title="FeatureMatch: Cosine (Random)")
plt.show()

res_perm.top_matches[:3]  # preview first 3 rows

## Option B: Collect codes from real SAEs using SAELens (uncomment to use)
Use this section when you have two trained SAEs (same hook/layer) and an evaluation dataset. The **only requirement** is that both code matrices are `[N, K]` and derived from the **same** tokens/examples.

❗️Note: keep batch size modest to avoid OOM during code collection. Subsample if needed.
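
One way to keep memory bounded is to encode in batches, move each batch of codes to CPU, and subsample rows with a fixed seed so both SAEs see the same examples. The sketch below is a hedged illustration: `collect_codes` and its parameters are hypothetical, and the two "encoders" are toy random maps standing in for real SAEs.

```python
import torch

def collect_codes(encode_fn, inputs, batch_size=256, max_rows=None, seed=0):
    """Hypothetical helper: run `encode_fn` over `inputs` in batches, return [N, K] codes on CPU.

    `encode_fn` maps a [B, d_in] tensor to [B, K] codes (for example, a wrapper
    around your SAE's encoder).
    """
    if max_rows is not None and inputs.shape[0] > max_rows:
        # Fixed-seed subsample so repeated calls select the same rows.
        g = torch.Generator().manual_seed(seed)
        keep = torch.randperm(inputs.shape[0], generator=g)[:max_rows]
        inputs = inputs[keep]
    chunks = []
    with torch.no_grad():
        for start in range(0, inputs.shape[0], batch_size):
            batch = inputs[start:start + batch_size]
            chunks.append(encode_fn(batch).cpu())  # move each batch off the accelerator
    return torch.cat(chunks, dim=0)

# Toy stand-ins for two SAE encoders (random linear maps followed by ReLU).
torch.manual_seed(0)
acts = torch.randn(1000, 32)
W_a, W_b = torch.randn(32, 64), torch.randn(32, 64)
Z_a = collect_codes(lambda x: torch.relu(x @ W_a), acts, batch_size=128, max_rows=500)
Z_b = collect_codes(lambda x: torch.relu(x @ W_b), acts, batch_size=128, max_rows=500)
assert Z_a.shape == Z_b.shape == (500, 64)
```

Using the same `seed` and `max_rows` for both calls is what guarantees row `i` of `Z_a` and row `i` of `Z_b` come from the same example.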

In [None]:
# %% Real-SAEs example (template) -----------------------------------------
# from sae_lens import SAE
# import torch
# from featurematch.featurematch import align_features
# from featurematch.viz import plot_heatmap
# import matplotlib.pyplot as plt

# # 1) Load your two SAEs (same hook/layer)
# sae_a = SAE.load_from_pretrained("PATH/OR/ALIAS/TO/SAE_A")
# sae_b = SAE.load_from_pretrained("PATH/OR/ALIAS/TO/SAE_B")
# sae_a.eval(); sae_b.eval()

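# # 2) Prepare evaluation inputs (same for both)
# # tokens: LongTensor [B, T] or as required by your pipeline
# tokens = ...  # your eval batch(es)
# # SAEs encode model activations at the shared hook, not raw tokens.
# # acts: activations for `tokens` at that hook, flattened to [N, d_in] (one row per token position)
# acts = ...  # collect with your base model

# # 3) Collect codes Z_a, Z_b (shape [N, K])
# with torch.no_grad():
#     Z_a = sae_a.encode(acts)  # SAELens SAE.encode; adjust for your SAE class/version
#     Z_b = sae_b.encode(acts)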

# # 4) Align & visualize
# res = align_features(Z_a, Z_b, topk=5, threshold=0.8, device="cuda" if torch.cuda.is_available() else "cpu")
# print("Alignment stats:", res.stats)
# plot_heatmap(res.cosine, title="Cross-SAE Feature Alignment (Cosine)")
# plt.show()

# # 5) Inspect top matches for first few features
# res.top_matches[:5]

### Interpretation (v0.1 heuristics)
- **mean_best ≥ 0.85**: strong reproducibility (dictionaries mostly aligned)
- **0.70 ≤ mean_best < 0.85**: partial alignment (typical when seeds/hparams differ)
- **mean_best < 0.70**: low alignment (different dictionaries)
- `% above threshold` (default 0.8): quick sanity metric; >60% typically indicates similar runs
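
The heuristics above can be written down as a small lookup. This is a sketch, not part of FeatureMatch: `interpret_alignment` is a hypothetical helper, and because the key names inside `res.stats` are not shown in this notebook, it takes plain numbers (`frac_above` as a fraction in [0, 1]).

```python
def interpret_alignment(mean_best, frac_above=None):
    """Hypothetical helper encoding the v0.1 heuristics; thresholds are rules of thumb."""
    if mean_best >= 0.85:
        label = "strong reproducibility"
    elif mean_best >= 0.70:
        label = "partial alignment"
    else:
        label = "low alignment"
    # More than 60% of features above the cosine threshold suggests similar runs.
    similar_runs = None if frac_above is None else frac_above > 0.60
    return label, similar_runs

print(interpret_alignment(0.91, 0.72))  # ('strong reproducibility', True)
print(interpret_alignment(0.78, 0.40))  # ('partial alignment', False)
print(interpret_alignment(0.31))        # ('low alignment', None)
```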

**Important:** Always compare codes from the **same hook/layer** on the **same dataset**.