# ICA Analysis of SAE Residuals ("Dark Matter")

Use ICA to decompose the ~35% reconstruction error that remains after SAE decomposition of GPT-2 activations,
and determine whether this SAE residual contains interpretable structure that the SAE misses.

In [None]:
# Cell 1: Environment Setup & Imports
# Uncomment the following line to install dependencies:
# !pip install sae-lens transformer-lens torch scikit-learn datasets matplotlib scipy numpy

import os
import json
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import kurtosis
from sklearn.decomposition import PCA, FastICA

# Pick the compute device: CUDA when torch reports one, otherwise CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")
if DEVICE == "cuda":
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {gpu_props.total_memory / 1e9:.1f} GB")

# Every array, figure and JSON file produced by the notebook lands here.
OUTPUT_DIR = "outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Experiment constants
RANDOM_SEED = 42
LAYER = 6
CONTEXT_LEN = 128
TARGET_TOKENS = 1_000_000
HOOK_NAME = f"blocks.{LAYER}.hook_resid_pre"

# Seed both RNG families used by the notebook.
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

# Echo the configuration in one go (same text as line-by-line prints).
config_lines = [
    "\nConfiguration:",
    f"  Target tokens: {TARGET_TOKENS:,}",
    f"  Context length: {CONTEXT_LEN}",
    f"  Layer: {LAYER}",
    f"  Hook: {HOOK_NAME}",
    f"  Random seed: {RANDOM_SEED}",
]
print("\n".join(config_lines))

In [None]:
# Cell 2: Load Model and SAE
from sae_lens import SAE

# Load GPT-2 small — try HookedSAETransformer first, fall back to HookedTransformer.
# The fallback is a deliberate best effort: any failure to import or construct
# HookedSAETransformer is reported and the plain HookedTransformer is used.
# ImportError is an Exception subclass, so one `except Exception` covers both.
try:
    from sae_lens import HookedSAETransformer
    model = HookedSAETransformer.from_pretrained("gpt2", device=DEVICE)
    print("Loaded model via HookedSAETransformer")
except Exception as e:
    print(f"HookedSAETransformer not available ({e}), falling back to HookedTransformer")
    from transformer_lens import HookedTransformer
    model = HookedTransformer.from_pretrained("gpt2", device=DEVICE)
    print("Loaded model via HookedTransformer")

# Load the residual-stream SAE for the hooked layer.
# The SAE id is taken from HOOK_NAME so that the SAE always matches the hook
# point whose activations are collected (changing LAYER can no longer pair
# layer-N activations with a layer-6 SAE).
# NOTE(review): assumes the release names its SAEs after the hook point
# ("blocks.<layer>.hook_resid_pre") — confirm before using another LAYER.
sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id=HOOK_NAME,
    device=DEVICE,
)
# sae-lens v6+ returns just the SAE object; older versions returned a 3-tuple
# whose first element is the SAE, so accept both.
if isinstance(sae, tuple):
    sae = sae[0]
print("\nSAE loaded successfully")

# Print model and SAE info
d_model = model.cfg.d_model
print(f"\nModel info:")
print(f"  d_model: {d_model}")
print(f"  n_layers: {model.cfg.n_layers}")
print(f"  n_heads: {model.cfg.n_heads}")

# W_dec is indexed (feature, model dimension), as the prints below label it.
print(f"\nSAE info:")
print(f"  W_dec shape: {sae.W_dec.shape}")
print(f"  n_features: {sae.W_dec.shape[0]}")
print(f"  d_model: {sae.W_dec.shape[1]}")

In [None]:
# Cell 2b: RELOAD — Load saved data and recompute derived variables
# Run this (after Cell 1 + Cell 2) to skip Cells 3-8 and jump to Cell 9+.
# Requires: outputs/*.npy from a previous run, model and sae from Cell 2.
# Everything that later cells read (variance stats, PCA counts, pca_pre, ica,
# ica_sources, SAE-similarity stats) is rebuilt here from the saved arrays.

print("Loading saved arrays from outputs/...")
residuals = np.load(os.path.join(OUTPUT_DIR, "residuals.npy"))
activations = np.load(os.path.join(OUTPUT_DIR, "activations.npy"))
tokens_flat = np.load(os.path.join(OUTPUT_DIR, "tokens.npy"))
ica_directions = np.load(os.path.join(OUTPUT_DIR, "ica_directions.npy"))
ica_activations = np.load(os.path.join(OUTPUT_DIR, "ica_activations.npy"))

print(f"  residuals:       {residuals.shape}  ({residuals.nbytes/1e9:.2f} GB)")
print(f"  activations:     {activations.shape}  ({activations.nbytes/1e9:.2f} GB)")
print(f"  tokens_flat:     {tokens_flat.shape}")
print(f"  ica_directions:  {ica_directions.shape}")
print(f"  ica_activations: {ica_activations.shape}")

# --- Recompute Cell 4 variables (variance & kurtosis) ---
total_var = np.var(activations, axis=0).sum()
residual_var = np.var(residuals, axis=0).sum()
frac_unexplained = residual_var / total_var
kurt = kurtosis(residuals, axis=0)
print(f"\nVariance: {frac_unexplained*100:.1f}% unexplained by SAE")
print(f"Residual kurtosis mean: {kurt.mean():.3f}")

# --- Recompute Cell 5 variables (PCA component counts) ---
pca_diagnostic = PCA(n_components=min(100, residuals.shape[1]))
pca_diagnostic.fit(residuals)
cumvar = np.cumsum(pca_diagnostic.explained_variance_ratio_)
n_components_90 = int(np.searchsorted(cumvar, 0.90) + 1)
n_components_95 = int(np.searchsorted(cumvar, 0.95) + 1)
n_components_99 = int(np.searchsorted(cumvar, 0.99) + 1)
n_ica_components = ica_directions.shape[0]
print(f"PCA: 90%={n_components_90}, 95%={n_components_95}, 99%={n_components_99}")
# searchsorted returns len(cumvar) when the fitted components never reach a
# threshold, so the "+ 1" above yields len(cumvar) + 1. Flag that case instead
# of letting the sentinel read like a measured count.
n_pca_fitted = len(cumvar)
for pct_label, threshold in (("90%", 0.90), ("95%", 0.95), ("99%", 0.99)):
    if cumvar[-1] < threshold:
        print(f"  NOTE: the {n_pca_fitted} fitted PCs explain only {cumvar[-1]*100:.1f}% of the variance; "
              f"the {pct_label} count is a lower bound (more than {n_pca_fitted} components needed).")
print(f"ICA components (from saved): {n_ica_components}")

# --- Recompute Cell 6 variables (PCA whitening + ICA objects) ---
pca_pre = PCA(n_components=n_ica_components, whiten=True, random_state=RANDOM_SEED)
residuals_whitened = pca_pre.fit_transform(residuals)

# Same hyper-parameters and seed as Cell 6, so that `ica` / `ica_sources`
# correspond to the run that produced the saved directions.
# NOTE(review): this only holds if the saved files came from the same data,
# seed and library versions — confirm when mixing runs.
ica = FastICA(
    n_components=n_ica_components,
    algorithm='parallel',
    whiten=False,
    max_iter=1000,
    tol=1e-4,
    random_state=RANDOM_SEED,
)
ica_sources = ica.fit_transform(residuals_whitened)
del residuals_whitened
print(f"ICA refit: converged in {ica.n_iter_} iterations")

# --- Recompute Cell 7 variables (SAE similarity) ---
sae_decoder = sae.W_dec.detach().cpu().float().numpy()
# Floor the norms so an all-zero decoder row cannot cause a division by zero.
sae_decoder_norms = np.maximum(np.linalg.norm(sae_decoder, axis=1, keepdims=True), 1e-8)
sae_decoder_normed = sae_decoder / sae_decoder_norms
cos_sim = ica_directions @ sae_decoder_normed.T
max_cos_sim = np.max(np.abs(cos_sim), axis=1)
best_sae_match = np.argmax(np.abs(cos_sim), axis=1)
n_high = int((max_cos_sim > 0.8).sum())
n_medium = int(((max_cos_sim > 0.3) & (max_cos_sim <= 0.8)).sum())
n_low = int((max_cos_sim <= 0.3).sum())
del sae_decoder, sae_decoder_normed, cos_sim
print(f"SAE similarity: {n_low} novel / {n_medium} medium / {n_high} high")

d_model = model.cfg.d_model
print(f"\nReload complete — ready to run Cell 9+")

In [None]:
# Cell 3: Collect Activations and Compute Residuals
from datasets import load_dataset

# Stream OpenWebText
# NOTE(review): trust_remote_code=True allows the dataset repository's loading
# script to run locally — confirm it is still needed/accepted by the installed
# `datasets` version.
dataset = load_dataset("Skylion007/openwebtext", split="train", streaming=True, trust_remote_code=True)

all_residuals = []
all_activations = []
all_tokens = []
tokens_collected = 0
batches_done = 0
batch_texts = []
BATCH_SIZE = 32
PROGRESS_EVERY_BATCHES = 25
# Roughly every 500 examples, the cadence the per-example index check aimed for.
EMPTY_CACHE_EVERY_BATCHES = 16

# Get the BOS/EOT token id so we can exclude it
BOS_TOKEN_ID = model.tokenizer.bos_token_id  # same as eos for GPT-2: 50256

# Set up tokenizer for proper truncation and padding.
# Padding reuses the EOS token, so pad positions are removed by the same mask
# that removes BOS/EOT below.
model.tokenizer.pad_token = model.tokenizer.eos_token

print(f"Collecting ~{TARGET_TOKENS:,} tokens...")
print(f"Excluding BOS/EOT token (id={BOS_TOKEN_ID}) from all stored data")

for example in dataset:
    text = example["text"]
    # Skip near-empty documents.
    if len(text.strip()) < 20:
        continue
    batch_texts.append(text)

    # Keep accumulating until a full batch is available.
    if len(batch_texts) < BATCH_SIZE:
        continue

    # Tokenize with explicit truncation to avoid huge intermediate tensors
    tokenized = model.tokenizer(
        batch_texts,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=CONTEXT_LEN - 1,  # leave room for BOS
    )
    # Prepend BOS token; the slice guards against sequences longer than
    # CONTEXT_LEN (e.g. if the tokenizer itself already added a BOS).
    input_ids = tokenized["input_ids"].to(DEVICE)
    bos = torch.full((input_ids.shape[0], 1), BOS_TOKEN_ID, dtype=input_ids.dtype, device=DEVICE)
    tokens = torch.cat([bos, input_ids], dim=1)[:, :CONTEXT_LEN]

    with torch.no_grad():
        # Run model, cache activations at the hook point only
        _, cache = model.run_with_cache(
            tokens,
            names_filter=[HOOK_NAME],
        )

        # Get activations: shape (batch, seq_len, d_model)
        acts = cache[HOOK_NAME]

        # Reshape to (batch * seq_len, d_model)
        acts_flat = acts.reshape(-1, acts.shape[-1])
        tokens_flat_batch = tokens.reshape(-1)

        # Build mask to exclude BOS/EOT/PAD tokens (all token id 50256 in GPT-2)
        keep_mask = tokens_flat_batch != BOS_TOKEN_ID

        # Apply mask — only keep real content tokens
        acts_kept = acts_flat[keep_mask]

        # Run through SAE encode/decode to get reconstruction
        feature_acts = sae.encode(acts_kept)
        reconstructed = sae.decode(feature_acts)

        # Compute residual = original - reconstruction
        residual = acts_kept - reconstructed

        # Store on CPU as numpy
        all_residuals.append(residual.cpu().float().numpy())
        all_activations.append(acts_kept.cpu().float().numpy())
        all_tokens.append(tokens_flat_batch[keep_mask].cpu().numpy())

    # Count tokens kept
    tokens_collected += int(keep_mask.sum())
    batch_texts = []

    # Progress report, keyed on processed batches. The raw example index is
    # only sampled at batch boundaries and shifts whenever a document is
    # skipped, so it is not a reliable cadence.
    if batches_done % PROGRESS_EVERY_BATCHES == 0:
        mem_gb = sum(r.nbytes for r in all_residuals) / 1e9
        print(f"  {tokens_collected:,} / {TARGET_TOKENS:,} tokens | {mem_gb:.1f} GB stored")
    batches_done += 1

    if tokens_collected >= TARGET_TOKENS:
        break

    # Clear GPU cache periodically
    if DEVICE == "cuda" and batches_done % EMPTY_CACHE_EVERY_BATCHES == 0:
        torch.cuda.empty_cache()

# Fail with a clear message rather than an opaque np.concatenate error when
# the stream ended before a single full batch was processed.
if not all_residuals:
    raise RuntimeError(
        f"No tokens were collected: the dataset yielded fewer than {BATCH_SIZE} usable documents."
    )

# Concatenate all collected data
residuals = np.concatenate(all_residuals, axis=0)
activations = np.concatenate(all_activations, axis=0)
tokens_flat = np.concatenate(all_tokens, axis=0)

# Free the lists
del all_residuals, all_activations, all_tokens
if DEVICE == "cuda":
    torch.cuda.empty_cache()

print(f"\nCollection complete!")
print(f"  Tokens kept (excl. BOS/EOT/PAD): {residuals.shape[0]:,}")
print(f"  Residuals shape: {residuals.shape}")
print(f"  Activations shape: {activations.shape}")
print(f"  Tokens shape: {tokens_flat.shape}")
print(f"  Residuals memory: {residuals.nbytes / 1e9:.2f} GB")
print(f"  Activations memory: {activations.nbytes / 1e9:.2f} GB")

# Save intermediates
np.save(os.path.join(OUTPUT_DIR, "residuals.npy"), residuals)
np.save(os.path.join(OUTPUT_DIR, "activations.npy"), activations)
np.save(os.path.join(OUTPUT_DIR, "tokens.npy"), tokens_flat)
print(f"\nSaved residuals.npy, activations.npy, tokens.npy to {OUTPUT_DIR}/")

In [None]:
# Cell 4: Basic Diagnostics — Variance & Kurtosis

# Share of the activation variance that is left in the SAE residual
# (per-dimension variances summed over all dimensions).
residual_var = np.var(residuals, axis=0).sum()
total_var = np.var(activations, axis=0).sum()
frac_unexplained = residual_var / total_var

print("\n".join([
    "Variance Analysis:",
    f"  Total activation variance: {total_var:.2f}",
    f"  Residual variance: {residual_var:.2f}",
    f"  Fraction unexplained by SAE: {frac_unexplained:.3f} ({frac_unexplained*100:.1f}%)",
]))

# Excess kurtosis of every residual dimension (a Gaussian scores 0).
kurt = kurtosis(residuals, axis=0)
kurt_mean = kurt.mean()
n_dims = len(kurt)

print("\n".join([
    "\nResidual Kurtosis (excess):",
    f"  Mean: {kurt_mean:.3f}",
    f"  Std: {kurt.std():.3f}",
    f"  Min: {kurt.min():.3f}",
    f"  Max: {kurt.max():.3f}",
    f"  Dims with kurtosis > 1: {(kurt > 1).sum()} / {n_dims}",
    f"  Dims with kurtosis > 3: {(kurt > 3).sum()} / {n_dims}",
]))

# Decision point: is there enough non-Gaussianity for ICA to work with?
print("\n--- DECISION POINT ---")
looks_gaussian = abs(kurt_mean) < 0.5
if looks_gaussian:
    verdict_lines = [
        "Mean kurtosis is near zero — residual is approximately Gaussian.",
        "ICA may not find much structure. This would mean SAEs capture",
        "most non-Gaussian structure, and the dark matter is noise.",
        "(Still a valid and interesting finding!)",
    ]
else:
    verdict_lines = [
        f"Mean kurtosis = {kurt_mean:.3f} — substantial non-Gaussianity detected!",
        "The residual contains non-Gaussian structure that ICA should be able to decompose.",
        "Proceeding with ICA is well-motivated.",
    ]
print("\n".join(verdict_lines))

# Histogram of the per-dimension kurtosis with the Gaussian reference and the mean.
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(kurt, bins=50, edgecolor='black', alpha=0.7)
reference_lines = (
    (0, 'r', '--', 'Gaussian (kurtosis=0)'),
    (kurt_mean, 'orange', '-', f'Mean={kurt_mean:.2f}'),
)
for x_pos, line_color, line_style, line_label in reference_lines:
    ax.axvline(x=x_pos, color=line_color, linestyle=line_style, linewidth=2, label=line_label)
ax.set_title("Per-Dimension Kurtosis of SAE Residual", fontsize=14)
ax.set_xlabel("Excess Kurtosis", fontsize=12)
ax.set_ylabel("Count (dimensions)", fontsize=12)
ax.legend(fontsize=11)
plt.tight_layout()
kurtosis_png = os.path.join(OUTPUT_DIR, "residual_kurtosis.png")
plt.savefig(kurtosis_png, dpi=150, bbox_inches='tight')
plt.show()
print("Saved residual_kurtosis.png")

In [None]:
# Cell 5: PCA on Residuals

# Diagnostic PCA: only the leading (at most 100) components are fitted, so the
# cumulative curve below may stop short of the variance thresholds.
pca_diagnostic = PCA(n_components=min(100, residuals.shape[1]))
pca_diagnostic.fit(residuals)

cumvar = np.cumsum(pca_diagnostic.explained_variance_ratio_)
n_pca_fitted = len(cumvar)

# Number of leading components needed to reach each variance threshold.
# searchsorted returns len(cumvar) when a threshold is never reached, so a
# value of n_pca_fitted + 1 means "more than n_pca_fitted" (a lower bound),
# not a measured count.
n_components_90 = int(np.searchsorted(cumvar, 0.90) + 1)
n_components_95 = int(np.searchsorted(cumvar, 0.95) + 1)
n_components_99 = int(np.searchsorted(cumvar, 0.99) + 1)


def _count_label(n_components):
    """Format a component count, marking out-of-range sentinels as lower bounds."""
    if n_components > n_pca_fitted:
        return f">={n_components}"
    return f"{n_components}"


print(f"PCA on Residuals:")
print(f"  Components for 90% variance: {_count_label(n_components_90)}")
print(f"  Components for 95% variance: {_count_label(n_components_95)}")
print(f"  Components for 99% variance: {_count_label(n_components_99)}")
if cumvar[-1] < 0.99:
    print(f"  NOTE: the {n_pca_fitted} fitted components explain {cumvar[-1]*100:.1f}% of the variance; "
          f"counts shown as '>=' are lower bounds.")
# Both indices are clamped so a short spectrum cannot raise IndexError.
print(f"  Variance in first 10 components: {cumvar[min(9, len(cumvar)-1)]*100:.1f}%")
print(f"  Variance in first 50 components: {cumvar[min(49, len(cumvar)-1)]*100:.1f}%")

# Plot
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(cumvar) + 1), cumvar, 'b-', linewidth=2)
ax.axhline(y=0.90, color='r', linestyle='--', alpha=0.7, label=f'90% ({_count_label(n_components_90)} components)')
ax.axhline(y=0.95, color='orange', linestyle='--', alpha=0.7, label=f'95% ({_count_label(n_components_95)} components)')
# Only mark the 90% crossing when it actually lies on the fitted curve.
if n_components_90 <= n_pca_fitted:
    ax.axvline(x=n_components_90, color='r', linestyle=':', alpha=0.5)
ax.set_xlabel("Number of PCA Components", fontsize=12)
ax.set_ylabel("Cumulative Variance Explained", fontsize=12)
ax.set_title("PCA of SAE Residual — Cumulative Variance Explained", fontsize=14)
ax.legend(fontsize=11)
ax.set_ylim([0, 1.02])
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "residual_pca.png"), dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved residual_pca.png")

# Choose n_components for ICA — use the 90% variance threshold, capped at the
# number of fitted components (itself at most 100).
n_ica_components = min(n_components_90, n_pca_fitted)
print(f"\nRecommended ICA component count: {n_ica_components}")

In [None]:
# Cell 6: Run FastICA

print(f"Running ICA with {n_ica_components} components...")

# PCA pre-whitening
pca_pre = PCA(n_components=n_ica_components, whiten=True, random_state=RANDOM_SEED)
residuals_whitened = pca_pre.fit_transform(residuals)
print(f"PCA whitening complete. Whitened shape: {residuals_whitened.shape}")


def _unit_mixing_directions(fitted_ica, fitted_pca):
    """Return the unit-norm ICA mixing directions expressed in the original space.

    ``fitted_ica.mixing_`` maps sources to whitened PCA coordinates. With
    ``PCA(whiten=True)`` the rows of ``components_`` stay unit-norm and the
    whitening division by ``sqrt(explained_variance_)`` happens inside
    ``transform()``; going back to the original space therefore needs that
    scale re-applied (this mirrors ``PCA.inverse_transform``). Dropping it
    distorts the directions, not just their length, because every principal
    axis has a different scale.

    Returns an array with one row per ICA component and one column per
    original dimension.
    """
    unwhitening = np.sqrt(fitted_pca.explained_variance_)[:, np.newaxis] * fitted_pca.components_
    directions = fitted_ica.mixing_.T @ unwhitening
    return directions / np.linalg.norm(directions, axis=1, keepdims=True)


# Run FastICA on whitened data
ica = FastICA(
    n_components=n_ica_components,
    algorithm='parallel',
    whiten=False,  # Already whitened via PCA
    max_iter=1000,
    tol=1e-4,
    random_state=RANDOM_SEED,
)
ica_sources = ica.fit_transform(residuals_whitened)
print(f"ICA converged in {ica.n_iter_} iterations")

# ICA directions in the original activation space, normalized to unit length
ica_directions = _unit_mixing_directions(ica, pca_pre)
print(f"ICA directions shape: {ica_directions.shape}")

# --- Robustness check: run with 3 seeds and compare ---
print(f"\nRunning robustness check with 3 seeds...")
seeds = [42, 123, 999]
all_directions = []

for seed in seeds:
    ica_check = FastICA(
        n_components=n_ica_components,
        algorithm='parallel',
        whiten=False,
        max_iter=1000,
        tol=1e-4,
        random_state=seed,
    )
    ica_check.fit(residuals_whitened)
    all_directions.append(_unit_mixing_directions(ica_check, pca_pre))
    print(f"  Seed {seed}: converged in {ica_check.n_iter_} iterations")

# Compare seeds pairwise: for each component in seed_i, find best match in seed_j.
# Absolute cosine is used because the sign of an ICA component is arbitrary.
print(f"\nRobustness (max cosine similarity of best-matched components across seeds):")
for i in range(len(seeds)):
    for j in range(i + 1, len(seeds)):
        cos_matrix = np.abs(all_directions[i] @ all_directions[j].T)
        # For each component in seed_i, find best match in seed_j
        best_matches = cos_matrix.max(axis=1)
        print(f"  Seeds {seeds[i]} vs {seeds[j]}: "
              f"mean best-match cosine = {best_matches.mean():.3f}, "
              f"min = {best_matches.min():.3f}, "
              f"components with match > 0.9: {(best_matches > 0.9).sum()}/{n_ica_components}")

del all_directions, ica_check, residuals_whitened

# Save ICA directions
np.save(os.path.join(OUTPUT_DIR, "ica_directions.npy"), ica_directions)
print(f"\nSaved ica_directions.npy")

In [None]:
# Cell 7: Compare ICA Directions to SAE Decoder

# Decoder matrix as a float32 NumPy array on the host, one row per SAE feature.
sae_decoder = sae.W_dec.detach().cpu().float().numpy()
print(f"SAE decoder shape: {sae_decoder.shape}")

# Row-normalize; the floor keeps all-zero rows from dividing by zero.
sae_decoder_norms = np.maximum(
    np.linalg.norm(sae_decoder, axis=1, keepdims=True),
    1e-8,
)
sae_decoder_normed = sae_decoder / sae_decoder_norms

# Cosine similarity of every ICA direction against every SAE feature:
# rows index ICA components, columns index SAE features.
cos_sim = ica_directions @ sae_decoder_normed.T

# Sign is irrelevant for a direction, so rank matches by absolute cosine.
abs_cos_sim = np.abs(cos_sim)
max_cos_sim = abs_cos_sim.max(axis=1)
best_sae_match = abs_cos_sim.argmax(axis=1)
del abs_cos_sim

# Bucket the components by how closely the nearest SAE feature matches.
is_high = max_cos_sim > 0.8
is_low = max_cos_sim <= 0.3
is_medium = (max_cos_sim > 0.3) & (max_cos_sim <= 0.8)
n_high = int(is_high.sum())
n_medium = int(is_medium.sum())
n_low = int(is_low.sum())

print("\n".join([
    "\nICA-to-SAE Similarity:",
    f"  Mean max cosine similarity: {max_cos_sim.mean():.3f}",
    f"  Median: {np.median(max_cos_sim):.3f}",
    f"  High similarity (>0.8):   {n_high:3d} components — SAE already knows these directions",
    f"  Medium similarity (0.3-0.8): {n_medium:3d} components — partially overlapping with SAE",
    f"  Low similarity (<0.3):    {n_low:3d} components — genuinely NEW directions!",
]))

# Five components at each end of the similarity ranking.
sorted_by_sim = np.argsort(max_cos_sim)
ranking_sections = (
    ("\nMost novel ICA components (lowest SAE similarity):", sorted_by_sim[:5]),
    ("\nMost SAE-similar ICA components:", sorted_by_sim[::-1][:5]),
)
for heading, component_ids in ranking_sections:
    print(heading)
    for idx in component_ids:
        print(f"  Component {idx}: max cosine sim = {max_cos_sim[idx]:.3f} (best SAE match: feature {best_sae_match[idx]})")

# Histogram of the best-match similarity with the two bucket boundaries.
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(max_cos_sim, bins=30, edgecolor='black', alpha=0.7, color='steelblue')
boundaries = (
    (0.3, 'green', 'Low/Medium boundary (0.3)'),
    (0.8, 'red', 'Medium/High boundary (0.8)'),
)
for x_pos, line_color, line_label in boundaries:
    ax.axvline(x=x_pos, color=line_color, linestyle='--', linewidth=2, alpha=0.7, label=line_label)
ax.set_title("ICA Component Similarity to Nearest SAE Feature", fontsize=14)
ax.set_xlabel("Max Cosine Similarity with Any SAE Feature", fontsize=12)
ax.set_ylabel("Count (ICA Components)", fontsize=12)
ax.legend(fontsize=11)
plt.tight_layout()
similarity_png = os.path.join(OUTPUT_DIR, "ica_sae_similarity.png")
plt.savefig(similarity_png, dpi=150, bbox_inches='tight')
plt.show()
print("Saved ica_sae_similarity.png")

In [None]:
# Cell 8: Interpret ICA Components

# Read out each token's full activation along every ICA direction:
# one row per token, one column per ICA component.
ica_activations = np.matmul(activations, ica_directions.T)
print(f"ICA activations shape: {ica_activations.shape}")


def get_top_activating_examples(component_idx: int, k: int = 20) -> list:
    """Find the top-k tokens that most strongly activate a given ICA component.

    Tokens are ranked by the absolute value of their activation, strongest
    first. Reads the module-level ``ica_activations``, ``tokens_flat`` and
    ``model``.

    Args:
        component_idx: Column of ``ica_activations`` to inspect.
        k: Number of examples to return. Values <= 0 yield an empty list
            (a plain ``[-k:]`` slice would return *every* token for k == 0).

    Returns:
        A list of dicts with keys ``'token'`` (decoded token), ``'context'``
        (decoded window of up to 5 tokens on either side), ``'activation'``
        (signed activation) and ``'position'`` (index into ``tokens_flat``).
    """
    if k <= 0:
        return []

    acts = ica_activations[:, component_idx]
    top_indices = np.argsort(np.abs(acts))[-k:][::-1]

    results = []
    for idx in top_indices:
        token_id = int(tokens_flat[idx])
        token_str = model.to_string([token_id])
        # Surrounding context (+-5 tokens), clipped to the ends of tokens_flat.
        # NOTE(review): tokens_flat is one flat array, so a window near the
        # start or end of a sequence can include tokens from a neighbouring
        # sequence — only the array bounds are respected here.
        start = max(0, idx - 5)
        end = min(len(tokens_flat), idx + 6)
        context_ids = tokens_flat[start:end].tolist()
        context_str = model.to_string(context_ids)
        results.append({
            'token': token_str,
            'context': context_str,
            'activation': float(acts[idx]),
            'position': int(idx),
        })
    return results


# Rank components from least to most SAE-like and inspect the ten most novel.
sorted_by_novelty = np.argsort(max_cos_sim)
top_components = sorted_by_novelty[:10]

rule_80 = "=" * 80
print("\nInterpreting top 10 most novel ICA components (lowest SAE similarity):")
print(rule_80)

for comp_idx in top_components:
    # Twenty examples are fetched; only the first ten are printed.
    examples = get_top_activating_examples(comp_idx, k=20)
    comp_kurtosis = kurtosis(ica_activations[:, comp_idx])
    cos = max_cos_sim[comp_idx]

    print(f"\n{rule_80}")
    print(f"ICA Component {comp_idx} | Max SAE cosine sim: {cos:.3f} | Kurtosis: {comp_kurtosis:.2f}")
    print(rule_80)
    for example in examples[:10]:
        # Signed activation, right-aligned in an 8-character field.
        print(f"  [{example['activation']:>+8.3f}]  token='{example['token']}'")
        print(f"             context: ...{example['context']!r}...")

# Persist the projected activations for the reload cell.
ica_activations_path = os.path.join(OUTPUT_DIR, "ica_activations.npy")
np.save(ica_activations_path, ica_activations)
print("\nSaved ica_activations.npy")

In [None]:
# Cell 9: Statistical Comparison — ICA vs SAE

# Kurtosis of ICA component activations
ica_kurtosis = kurtosis(ica_activations, axis=0)
print(f"ICA component kurtosis:")
print(f"  Mean: {ica_kurtosis.mean():.2f}")
print(f"  Median: {np.median(ica_kurtosis):.2f}")
print(f"  Min: {ica_kurtosis.min():.2f}, Max: {ica_kurtosis.max():.2f}")

# Compute kurtosis of active SAE features — STREAMING to avoid OOM.
# Raw moments (sums of x, x^2, x^3, x^4) are accumulated per feature across
# chunks in a single pass, so the full (tokens x features) activation matrix
# is never stored and the SAE encoder runs only once over the data.
print(f"\nComputing SAE feature statistics (streaming)...")
chunk_size = 50_000
n_samples = activations.shape[0]
n_chunks = (n_samples + chunk_size - 1) // chunk_size

n_features = sae.W_dec.shape[0]
fire_counts = np.zeros(n_features, dtype=np.float64)
sum_x = np.zeros(n_features, dtype=np.float64)
sum_x2 = np.zeros(n_features, dtype=np.float64)
sum_x3 = np.zeros(n_features, dtype=np.float64)
sum_x4 = np.zeros(n_features, dtype=np.float64)

for i in range(n_chunks):
    start = i * chunk_size
    end = min(start + chunk_size, n_samples)
    chunk = torch.tensor(activations[start:end], device=DEVICE, dtype=torch.float32)
    with torch.no_grad():
        sae_acts_chunk = sae.encode(chunk).cpu().numpy().astype(np.float64)
    fire_counts += (sae_acts_chunk > 0).sum(axis=0)
    sum_x += sae_acts_chunk.sum(axis=0)
    # Build x^2, x^3, x^4 in one scratch buffer, multiplying in place, so at
    # most one extra chunk-sized array is alive at any time.
    running_power = np.square(sae_acts_chunk)
    sum_x2 += running_power.sum(axis=0)
    running_power *= sae_acts_chunk
    sum_x3 += running_power.sum(axis=0)
    running_power *= sae_acts_chunk
    sum_x4 += running_power.sum(axis=0)
    if (i + 1) % 5 == 0:
        print(f"  Chunk {i+1}/{n_chunks} processed")
    del sae_acts_chunk, running_power, chunk

if DEVICE == "cuda":
    torch.cuda.empty_cache()

# Identify active features (fire on >0.1% of tokens)
feature_fire_rate = fire_counts / n_samples
active_mask = feature_fire_rate > 0.001
n_active = int(active_mask.sum())
print(f"  Active SAE features (fire >0.1%): {n_active} / {n_features}")

# Excess kurtosis from raw moments, for active features only:
#   var = E[X^2] - mu^2
#   m4  = E[(X-mu)^4] = E[X^4] - 4*mu*E[X^3] + 6*mu^2*E[X^2] - 3*mu^4
#   excess kurtosis = m4 / var^2 - 3
mean = sum_x[active_mask] / n_samples
ex2 = sum_x2[active_mask] / n_samples
ex3 = sum_x3[active_mask] / n_samples
ex4 = sum_x4[active_mask] / n_samples

var = ex2 - mean**2
m4 = ex4 - 4*mean*ex3 + 6*(mean**2)*ex2 - 3*mean**4

# Avoid division by zero for features with near-zero variance
valid = var > 1e-12
sae_kurtosis = np.full(n_active, np.nan)
sae_kurtosis[valid] = (m4[valid] / var[valid]**2) - 3.0

# Drop NaN values for stats
sae_kurtosis_clean = sae_kurtosis[~np.isnan(sae_kurtosis)]
print(f"\nSAE feature kurtosis (active features, {len(sae_kurtosis_clean)} valid):")
print(f"  Mean: {sae_kurtosis_clean.mean():.2f}")
print(f"  Median: {np.median(sae_kurtosis_clean):.2f}")
print(f"  Min: {sae_kurtosis_clean.min():.2f}, Max: {sae_kurtosis_clean.max():.2f}")

del sum_x, sum_x2, sum_x3, sum_x4, fire_counts

# Plot overlapping histograms
fig, ax = plt.subplots(figsize=(10, 5))
# Clip extreme values for better visualization
ica_kurt_clipped = np.clip(ica_kurtosis, -10, 200)
sae_kurt_clipped = np.clip(sae_kurtosis_clean, -10, 200)
ax.hist(ica_kurt_clipped, bins=50, alpha=0.5, label=f'ICA components (n={len(ica_kurtosis)})', density=True, color='steelblue')
ax.hist(sae_kurt_clipped, bins=50, alpha=0.5, label=f'SAE features (n={len(sae_kurtosis_clean)})', density=True, color='orange')
ax.set_xlabel("Excess Kurtosis", fontsize=12)
ax.set_ylabel("Density", fontsize=12)
ax.set_title("Kurtosis: ICA Components vs Active SAE Features", fontsize=14)
ax.legend(fontsize=11)
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, "ica_vs_sae_kurtosis.png"), dpi=150, bbox_inches='tight')
plt.show()
print(f"Saved ica_vs_sae_kurtosis.png")

# How much of the residual variance does ICA explain? (sanity check)
# Sources -> whitened PCA coordinates via the ICA mixing matrix, then back to
# the original space with PCA.inverse_transform. With whiten=True that call
# re-applies the sqrt(explained_variance_) scale and adds the mean back; a
# bare `@ pca_pre.components_` would skip the scale and make this check
# disagree with the expected PCA ratio.
ica_reconstruction = pca_pre.inverse_transform(ica_sources @ ica.mixing_.T)
ica_recon_var = np.var(residuals - ica_reconstruction, axis=0).sum()
ica_frac_explained = 1.0 - ica_recon_var / residual_var
print(f"\nICA explains {ica_frac_explained*100:.1f}% of residual variance")
print(f"(Expected ~{pca_pre.explained_variance_ratio_.sum()*100:.1f}% based on PCA components kept)")

In [None]:
# Cell 10: Save Results & Summary Report

# Build the JSON payload one section at a time, then assemble it in the
# order the report file is expected to list its keys.
variance_section = {
    "total_activation_variance": float(total_var),
    "residual_variance": float(residual_var),
    "fraction_unexplained_by_sae": float(frac_unexplained),
}
residual_kurtosis_section = {
    "mean": float(kurt.mean()),
    "std": float(kurt.std()),
    "min": float(kurt.min()),
    "max": float(kurt.max()),
}
pca_section = {
    "n_components_90pct_variance": int(n_components_90),
    "n_components_95pct_variance": int(n_components_95),
    "n_components_99pct_variance": int(n_components_99),
}
ica_section = {
    "n_components": int(n_ica_components),
    "iterations_to_converge": int(ica.n_iter_),
    "fraction_residual_variance_explained": float(ica_frac_explained),
}
similarity_section = {
    "mean_max_cosine": float(max_cos_sim.mean()),
    "median_max_cosine": float(np.median(max_cos_sim)),
    "high_similarity_count_gt_0.8": int(n_high),
    "medium_similarity_count_0.3_to_0.8": int(n_medium),
    "low_similarity_count_lt_0.3": int(n_low),
}
kurtosis_section = {
    "ica_component_kurtosis_mean": float(ica_kurtosis.mean()),
    "ica_component_kurtosis_median": float(np.median(ica_kurtosis)),
    "sae_feature_kurtosis_mean": float(sae_kurtosis_clean.mean()),
    "sae_feature_kurtosis_median": float(np.median(sae_kurtosis_clean)),
}

results = {
    "model": "gpt2-small",
    "sae": "gpt2-small-res-jb",
    "sae_id": "blocks.6.hook_resid_pre",
    "layer": LAYER,
    "tokens_analyzed": int(residuals.shape[0]),
    "d_model": int(d_model),
    "n_ica_components": int(n_ica_components),
    "variance_analysis": variance_section,
    "residual_kurtosis": residual_kurtosis_section,
    "pca": pca_section,
    "ica": ica_section,
    "ica_sae_similarity": similarity_section,
    "kurtosis_comparison": kurtosis_section,
}

# Write the JSON report.
results_path = os.path.join(OUTPUT_DIR, "ica_dark_matter_results.json")
with open(results_path, "w") as report_file:
    json.dump(results, report_file, indent=2)
print(f"Saved results to {results_path}")


def _share_of_components(count):
    """Return `count` as a percentage of the ICA component total."""
    return count / n_ica_components * 100


# Wording that depends on the numbers is resolved before the report is laid out.
mean_residual_kurtosis = kurt.mean()
if abs(mean_residual_kurtosis) < 0.5:
    kurtosis_verdict = " (near-Gaussian — limited non-Gaussian structure)"
else:
    kurtosis_verdict = " (non-Gaussian — ICA found meaningful structure!)"

if n_low > n_ica_components * 0.3:
    interpretation_lines = [
        f"  A substantial fraction ({n_low}/{n_ica_components}) of ICA components represent",
        "  genuinely novel directions not captured by the SAE. The dark matter contains",
        "  interpretable structure that SAEs miss.",
    ]
elif n_high > n_ica_components * 0.5:
    interpretation_lines = [
        f"  Most ICA components ({n_high}/{n_ica_components}) closely match SAE features.",
        "  The dark matter is mostly due to SAE under-reconstruction of known features,",
        "  not genuinely missing features. Better SAE training could reduce this.",
    ]
else:
    interpretation_lines = [
        "  The ICA components show a mix of novel and SAE-overlapping directions.",
        "  Some dark matter is structured, some overlaps with known features.",
    ]

# Formatted summary, emitted as one block of text.
rule_70 = "=" * 70
summary_lines = [
    f"\n{rule_70}",
    "  ICA DARK MATTER ANALYSIS — SUMMARY REPORT",
    rule_70,
    "",
    f"  Model: GPT-2 small (d_model={d_model})",
    f"  SAE: gpt2-small-res-jb, layer {LAYER}",
    f"  Tokens analyzed: {residuals.shape[0]:,}",
    "",
    "  --- Variance ---",
    f"  SAE reconstruction error: {frac_unexplained*100:.1f}% of total variance",
    f"  ICA captures {ica_frac_explained*100:.1f}% of that residual variance",
    "",
    "  --- Residual Structure ---",
    f"  Mean residual kurtosis: {mean_residual_kurtosis:.3f}{kurtosis_verdict}",
    f"  PCA components for 90% residual variance: {n_components_90}",
    "",
    "  --- ICA vs SAE ---",
    f"  ICA components: {n_ica_components}",
    f"  Mean max cosine similarity to SAE: {max_cos_sim.mean():.3f}",
    f"  Novel directions (cosine <0.3): {n_low} ({_share_of_components(n_low):.0f}%)",
    f"  Partially overlapping (0.3-0.8): {n_medium} ({_share_of_components(n_medium):.0f}%)",
    f"  Redundant with SAE (>0.8):       {n_high} ({_share_of_components(n_high):.0f}%)",
    "",
    "  --- Sparsity ---",
    f"  ICA component kurtosis (mean): {ica_kurtosis.mean():.2f}",
    f"  SAE feature kurtosis (mean):   {sae_kurtosis_clean.mean():.2f}",
    "",
    "  --- Interpretation ---",
    *interpretation_lines,
    "",
    rule_70,
    "\nOutput files:",
]
print("\n".join(summary_lines))

# List every expected artifact with its size, or flag it as missing.
expected_outputs = (
    "residuals.npy", "activations.npy", "tokens.npy",
    "ica_directions.npy", "ica_activations.npy",
    "residual_kurtosis.png", "residual_pca.png",
    "ica_sae_similarity.png", "ica_vs_sae_kurtosis.png",
    "ica_dark_matter_results.json",
)
for f_name in expected_outputs:
    full_path = os.path.join(OUTPUT_DIR, f_name)
    if os.path.exists(full_path):
        size_mb = os.path.getsize(full_path) / 1e6
        status = f"{size_mb:.1f} MB"
    else:
        status = "MISSING"
    print(f"  {f_name:40s} {status}")