# ICA Dark Matter — Automated Interpretability

Run automated interpretability ("autointerp") on all 100 ICA components and 100 randomly sampled active baseline SAE features.
Uses Claude Sonnet for the explain and score passes, following the Bills et al. (2023) methodology.

**Run this on the RunPod instance where `outputs/` already exists.**

In [None]:
# Cell 1: Setup & Load Data
# !pip install anthropic sae-lens transformer-lens

import os
import json
import numpy as np
import torch
from scipy.stats import kurtosis
from collections import defaultdict
import time
import re

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUTPUT_DIR = "outputs"
AUTOINTERP_DIR = "autointerp_results"
os.makedirs(AUTOINTERP_DIR, exist_ok=True)

# Seeds the global NumPy RNG used later for negative sampling, baseline SAE
# feature selection and the score-pass shuffle.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Load saved arrays.
# `activations` is only read in contiguous row chunks (the SAE encoding loops
# in Cell 4), so memory-map it rather than loading the whole matrix into RAM.
print("Loading saved arrays...")
activations = np.load(os.path.join(OUTPUT_DIR, "activations.npy"), mmap_mode="r")
tokens_flat = np.load(os.path.join(OUTPUT_DIR, "tokens.npy"))
ica_directions = np.load(os.path.join(OUTPUT_DIR, "ica_directions.npy"))
ica_activations = np.load(os.path.join(OUTPUT_DIR, "ica_activations.npy"))

n_ica_components = ica_directions.shape[0]
n_tokens = activations.shape[0]

# Later cells index tokens_flat, activations and ica_activations with the same
# row (token-position) index. Fail fast if the saved files disagree instead of
# silently pairing activations with the wrong tokens.
if tokens_flat.shape[0] != n_tokens:
    raise ValueError(
        f"tokens.npy has {tokens_flat.shape[0]} entries but activations.npy has {n_tokens} rows"
    )
if ica_activations.ndim != 2 or ica_activations.shape[0] != n_tokens:
    raise ValueError(
        f"ica_activations.npy has shape {ica_activations.shape}; expected {n_tokens} rows"
    )
if ica_activations.shape[1] < n_ica_components:
    raise ValueError(
        f"ica_activations.npy has {ica_activations.shape[1]} columns but there are "
        f"{n_ica_components} ICA directions"
    )

print(f"  activations: {activations.shape}")
print(f"  tokens: {tokens_flat.shape}")
print(f"  ica_directions: {ica_directions.shape}")
print(f"  ica_activations: {ica_activations.shape}")
print(f"  Device: {DEVICE}")

In [None]:
# Cell 2: Load Model (for tokenizer) and SAE
from sae_lens import SAE

try:
    from sae_lens import HookedSAETransformer
    model = HookedSAETransformer.from_pretrained("gpt2", device=DEVICE)
except Exception:
    # Fall back to plain TransformerLens if HookedSAETransformer is unavailable
    # or fails to load; the model is only needed here for its tokenizer.
    from transformer_lens import HookedTransformer
    model = HookedTransformer.from_pretrained("gpt2", device=DEVICE)

sae = SAE.from_pretrained(
    release="gpt2-small-res-jb",
    sae_id="blocks.6.hook_resid_pre",
    device=DEVICE,
)
# Some sae-lens releases return a (sae, cfg_dict, sparsity) tuple instead of the
# SAE itself; keep just the SAE so `sae.W_dec` / `sae.encode` work either way.
if isinstance(sae, tuple):
    sae = sae[0]

tokenizer = model.tokenizer
print("Model and SAE loaded.")

# Compute ICA-SAE cosine similarities (needed to identify novel components).
# Both sides are row-normalized so the dot products are true cosines even if
# the saved ICA directions are not unit length.
sae_decoder = sae.W_dec.detach().cpu().float().numpy()
sae_decoder_norms = np.maximum(np.linalg.norm(sae_decoder, axis=1, keepdims=True), 1e-8)
sae_decoder_normed = sae_decoder / sae_decoder_norms
ica_norms = np.maximum(np.linalg.norm(ica_directions, axis=1, keepdims=True), 1e-8)
ica_directions_normed = ica_directions / ica_norms
cos_sim = ica_directions_normed @ sae_decoder_normed.T
# Sign-agnostic: an ICA direction matching a negated SAE decoder row still overlaps.
max_cos_sim = np.max(np.abs(cos_sim), axis=1)
del cos_sim, ica_directions_normed

n_novel = (max_cos_sim < 0.3).sum()
print(f"Novel ICA components (cosine < 0.3): {n_novel}")
print(f"Mean max cosine sim: {max_cos_sim.mean():.3f}")

In [None]:
# Cell 3: Set your API key
import anthropic

# The client reads ANTHROPIC_API_KEY from the environment. To set the key for
# this session only, uncomment the line below and paste it in:
# os.environ["ANTHROPIC_API_KEY"] = "sk-ant-..."

client = anthropic.Anthropic()
print("Anthropic client initialized.")

In [None]:
# Cell 4: Collect Rich Activation Examples

CONTEXT_BEFORE = 50  # tokens before highlighted token
CONTEXT_AFTER = 20   # tokens after highlighted token
TOP_K = 40           # collect more to survive deduplication (need 20 usable)


def get_context(idx, tokens, tokenizer, ctx_before=CONTEXT_BEFORE, ctx_after=CONTEXT_AFTER):
    """Return the text around ``tokens[idx]`` with the target token in [brackets].

    Args:
        idx: Position of the target token in ``tokens``.
        tokens: 1-D array of token ids (slices must support ``.tolist()``).
        tokenizer: Object with a ``decode(list_of_ids) -> str`` method.
        ctx_before: Maximum number of tokens shown before the target.
        ctx_after: Maximum number of tokens shown after the target.

    Returns:
        ``(highlighted, target_text)``: ``highlighted`` is
        ``before + "[" + target + "]" + after`` and ``target_text`` is the
        decoded target token alone. The window is clipped at both array ends.
    """
    start = max(0, idx - ctx_before)
    end = min(len(tokens), idx + ctx_after + 1)

    before_text = tokenizer.decode(tokens[start:idx].tolist())
    target_text = tokenizer.decode([int(tokens[idx])])
    after_text = tokenizer.decode(tokens[idx + 1:end].tolist())

    highlighted = f"{before_text}[{target_text}]{after_text}"
    return highlighted, target_text


def collect_examples_for_activations(act_values, tokens, tokenizer, top_k=TOP_K, min_examples=20):
    """Collect the top-|activation| examples with context, deduplicated.

    Candidates are ranked by absolute activation, so both tails of a signed
    direction are considered. Contexts whose first 80 characters match an
    already-kept example are dropped. Positions with zero activation are never
    returned, so a sparse feature that fires on fewer tokens than requested
    yields fewer examples instead of being padded with non-firing tokens.

    Args:
        act_values: 1-D array of activations, one per token position.
        tokens: 1-D array of token ids aligned with ``act_values``.
        tokenizer: Object with a ``decode`` method.
        top_k: Maximum number of examples to return.
        min_examples: Currently unused; kept for backward compatibility.

    Returns:
        List of dicts with keys ``context``, ``token``, ``activation`` and
        ``position``, ordered by decreasing |activation|.
    """
    abs_acts = np.abs(np.asarray(act_values))
    # Oversample candidates so enough survive deduplication.
    n_candidates = min(top_k * 2, len(abs_acts))
    if n_candidates <= 0:
        return []
    # argpartition selects the candidate set in O(n); only those are sorted.
    kth = len(abs_acts) - n_candidates
    candidates = np.argpartition(abs_acts, kth)[kth:]
    top_indices = candidates[np.argsort(abs_acts[candidates])[::-1]]

    examples = []
    seen_contexts = set()
    for idx in top_indices:
        if len(examples) >= top_k:
            break
        if abs_acts[idx] == 0:
            # Sorted by |activation|: every remaining candidate is non-firing.
            break
        context, token_text = get_context(idx, tokens, tokenizer)
        # Deduplicate near-identical contexts
        context_key = context[:80]
        if context_key in seen_contexts:
            continue
        seen_contexts.add(context_key)
        examples.append({
            "context": context,
            "token": token_text,
            "activation": float(act_values[idx]),
            "position": int(idx),
        })
    return examples


def collect_negative_examples(act_values, tokens, tokenizer, n=10):
    """Sample up to ``n`` random low-activation examples to use as negatives.

    Candidates are positions whose |activation| is at or below the 25th
    percentile. For sparse features that percentile is 0, so the candidates
    are exactly the positions where the feature does not fire. Sampling is
    without replacement and uses the global NumPy RNG.

    Args:
        act_values: 1-D array of activations, one per token position.
        tokens: 1-D array of token ids aligned with ``act_values``.
        tokenizer: Object with a ``decode`` method.
        n: Maximum number of negatives to return.

    Returns:
        List of example dicts (same keys as the positives) in sampled order;
        empty if ``act_values`` is empty.
    """
    abs_acts = np.abs(np.asarray(act_values))
    if abs_acts.size == 0:
        return []

    # Works for dense and sparse activations alike: when the 25th percentile is
    # 0, ``abs_acts <= threshold`` selects exactly the zero (non-firing) positions.
    threshold = np.percentile(abs_acts, 25)
    low_indices = np.flatnonzero(abs_acts <= threshold)
    if len(low_indices) == 0:
        # Only possible when the activations contain NaN.
        return []

    chosen = np.random.choice(low_indices, size=min(n, len(low_indices)), replace=False)

    examples = []
    for idx in chosen:
        context, token_text = get_context(idx, tokens, tokenizer)
        examples.append({
            "context": context,
            "token": token_text,
            "activation": float(act_values[idx]),
            "position": int(idx),
        })
    return examples


# --- Collect ICA component examples ---
# Column i of ica_activations holds component i's activation at every token
# position; each component gets up to TOP_K deduplicated top-|activation|
# positives plus 10 random low-activation negatives.
print(f"Collecting examples for {n_ica_components} ICA components...")
ica_examples = {}
for i in range(n_ica_components):
    acts = ica_activations[:, i]
    positives = collect_examples_for_activations(acts, tokens_flat, tokenizer)
    negatives = collect_negative_examples(acts, tokens_flat, tokenizer, n=10)
    ica_examples[i] = {"positives": positives, "negatives": negatives}
    if (i + 1) % 20 == 0:
        print(f"  {i+1}/{n_ica_components} ICA components done "
              f"(last: {len(positives)} pos, {len(negatives)} neg)")

# --- Select and collect SAE baseline features ---
# NOTE: negative sampling above and np.random.choice below both draw from the
# global NumPy RNG seeded in Cell 1, so reordering these steps changes which
# negatives and baseline features are selected.
print(f"\nSelecting 100 baseline SAE features...")

# Compute SAE feature activation density in chunks
# (chunking bounds the size of each SAE activation matrix held on DEVICE).
chunk_size = 50_000
n_chunks = (n_tokens + chunk_size - 1) // chunk_size  # ceil division
n_sae_features = sae.W_dec.shape[0]
fire_counts = np.zeros(n_sae_features, dtype=np.int64)

for c in range(n_chunks):
    s, e = c * chunk_size, min((c + 1) * chunk_size, n_tokens)
    chunk = torch.tensor(activations[s:e], device=DEVICE, dtype=torch.float32)
    with torch.no_grad():
        sae_acts = sae.encode(chunk)
        # Per feature: number of tokens in this chunk with activation > 0.
        fire_counts += (sae_acts > 0).sum(dim=0).cpu().numpy()
    del sae_acts, chunk

fire_rate = fire_counts / n_tokens
# Pick 100 active SAE features (fire on >0.1% of tokens), randomly sampled
active_sae_indices = np.where(fire_rate > 0.001)[0]
baseline_sae_indices = np.random.choice(active_sae_indices, size=min(100, len(active_sae_indices)), replace=False)
baseline_sae_indices.sort()
print(f"  Selected {len(baseline_sae_indices)} SAE features from {len(active_sae_indices)} active features")

# Collect SAE feature activations for baseline features
# (second encoding pass; only the selected columns are kept, on the CPU, as an
# (n_tokens, n_baseline) float32 array).
print(f"Collecting SAE feature activations for baseline...")
sae_feature_acts = np.zeros((n_tokens, len(baseline_sae_indices)), dtype=np.float32)
for c in range(n_chunks):
    s, e = c * chunk_size, min((c + 1) * chunk_size, n_tokens)
    chunk = torch.tensor(activations[s:e], device=DEVICE, dtype=torch.float32)
    with torch.no_grad():
        sae_acts = sae.encode(chunk)
        sae_feature_acts[s:e] = sae_acts[:, baseline_sae_indices].cpu().numpy()
    del sae_acts, chunk

if DEVICE == "cuda":
    torch.cuda.empty_cache()

# Collect examples for SAE features
# sae_examples is keyed by the SAE feature index, not by the 0..N-1 column j.
print(f"Collecting examples for {len(baseline_sae_indices)} SAE features...")
sae_examples = {}
for j, feat_idx in enumerate(baseline_sae_indices):
    acts = sae_feature_acts[:, j]
    positives = collect_examples_for_activations(acts, tokens_flat, tokenizer)
    negatives = collect_negative_examples(acts, tokens_flat, tokenizer, n=10)
    sae_examples[int(feat_idx)] = {"positives": positives, "negatives": negatives}
    if (j + 1) % 20 == 0:
        print(f"  {j+1}/{len(baseline_sae_indices)} SAE features done "
              f"(last: {len(positives)} pos, {len(negatives)} neg)")

del sae_feature_acts

# Save examples
# JSON object keys must be strings, hence the str(k) conversions.
with open(os.path.join(AUTOINTERP_DIR, "ica_examples.json"), "w") as f:
    json.dump({str(k): v for k, v in ica_examples.items()}, f)
with open(os.path.join(AUTOINTERP_DIR, "sae_examples.json"), "w") as f:
    json.dump({str(k): v for k, v in sae_examples.items()}, f)
with open(os.path.join(AUTOINTERP_DIR, "baseline_sae_indices.json"), "w") as f:
    json.dump(baseline_sae_indices.tolist(), f)

print(f"\nSaved ica_examples.json and sae_examples.json to {AUTOINTERP_DIR}/")
print(f"Total ICA positives: {sum(len(v['positives']) for v in ica_examples.values())}")
print(f"Total ICA negatives: {sum(len(v['negatives']) for v in ica_examples.values())}")
print(f"Total SAE positives: {sum(len(v['positives']) for v in sae_examples.values())}")
print(f"Total SAE negatives: {sum(len(v['negatives']) for v in sae_examples.values())}")

In [None]:
# Cell 5: Explain Pass — Generate interpretations via Claude

EXPLAIN_MODEL = "claude-sonnet-4-20250514"
N_EXPLAIN_EXAMPLES = 10  # Use top 10 for explanation, reserve rest for scoring

# System prompt for the explain pass: asks for a single JSON object with
# "label", "explanation" and an integer "confidence" from 1 to 5. The
# [brackets] convention matches the highlighting produced in Cell 4.
EXPLAIN_SYSTEM = """You are an interpretability researcher analyzing directions in a neural network's activation space. You will be shown text excerpts where a particular direction activates strongly. The token in [brackets] is where the direction fires. Your job is to identify the common pattern, concept, or feature these examples share.

Respond with:
1. A short label (2-5 words) for what this direction represents
2. A one-sentence explanation
3. A confidence score from 1-5:
   - 5: Clear, monosemantic concept
   - 4: Coherent theme with minor variation
   - 3: Plausible pattern but somewhat noisy
   - 2: Weak or ambiguous pattern
   - 1: No discernible pattern / random

Respond ONLY with valid JSON, no other text:
{"label": "...", "explanation": "...", "confidence": N}"""


def build_explain_prompt(examples, component_id, component_type="direction"):
    """Format the user message for the explain pass.

    The message starts with a header naming the component, followed by one
    numbered section per example: its activation (3 decimals) and its
    highlighted context, each section separated by a blank line.
    """
    header = f"Here are the top {len(examples)} activating examples for {component_type} #{component_id}:\n"
    sections = [
        f"Example {number} (activation: {example['activation']:.3f}):\n{example['context']}\n"
        for number, example in enumerate(examples, start=1)
    ]
    return "\n".join([header, *sections])


def call_explain(examples, component_id, component_type="direction"):
    """Ask Claude to label a component from its top activating examples.

    Only the first ``N_EXPLAIN_EXAMPLES`` examples are shown; the remainder
    are reserved for the score pass.

    Args:
        examples: Example dicts (``activation``, ``context``) in rank order.
        component_id: Identifier shown to the model.
        component_type: Human-readable kind of component, e.g. "SAE feature".

    Returns:
        Dict with ``label`` (str), ``explanation`` (str) and ``confidence``
        (int). If the reply is not a JSON object, or it has no usable
        explanation, ``confidence`` is 0; an unparseable reply also gets
        ``label == "PARSE_ERROR"`` with the raw reply as its explanation.
    """
    prompt = build_explain_prompt(examples[:N_EXPLAIN_EXAMPLES], component_id, component_type)
    response = client.messages.create(
        model=EXPLAIN_MODEL,
        max_tokens=256,
        temperature=0,
        system=EXPLAIN_SYSTEM,
        messages=[{"role": "user", "content": prompt}],
    )
    text = response.content[0].text.strip()
    # Strip an optional Markdown code fence (with or without a language tag),
    # then fall back to the outermost {...} span if prose surrounds the JSON.
    cleaned = re.sub(r"^```[A-Za-z]*\s*", "", text)
    cleaned = re.sub(r"\s*```$", "", cleaned)
    match = re.search(r"\{.*\}", cleaned, flags=re.DOTALL)
    if match:
        cleaned = match.group(0)

    parse_error = {"label": "PARSE_ERROR", "explanation": text, "confidence": 0}
    try:
        result = json.loads(cleaned)
    except json.JSONDecodeError:
        return parse_error
    if not isinstance(result, dict):
        return parse_error

    # Later cells compare confidence numerically; the model may return it as a
    # string or float, or omit it.
    try:
        result["confidence"] = int(result.get("confidence", 0))
    except (TypeError, ValueError):
        result["confidence"] = 0
    result["label"] = str(result.get("label", ""))
    result["explanation"] = str(result.get("explanation", ""))
    if not result["explanation"]:
        # Nothing to score without an explanation; confidence 0 makes later cells skip it.
        result["confidence"] = 0
    return result


# --- Run explain pass on ICA components ---
explanations = {"ica": {}, "sae": {}}


def _explain_or_error(examples, component_id, component_type):
    """Run call_explain, recording a failed API call as a confidence-0 result.

    The SDK already retries transient failures; whatever still fails is kept
    as an ``API_ERROR`` entry so one bad call does not abort the whole run.
    """
    try:
        return call_explain(examples, component_id, component_type)
    except anthropic.APIError as err:
        return {"label": "API_ERROR", "explanation": str(err), "confidence": 0}


print(f"Running explain pass on {n_ica_components} ICA components...")
for i in range(n_ica_components):
    positives = ica_examples[i]["positives"]
    if len(positives) < 3:
        explanations["ica"][str(i)] = {"label": "TOO_FEW_EXAMPLES", "explanation": "", "confidence": 0}
        continue
    result = _explain_or_error(positives, i, "ICA component")
    result["max_cos_sim"] = float(max_cos_sim[i])
    explanations["ica"][str(i)] = result
    if (i + 1) % 10 == 0:
        print(f"  {i+1}/{n_ica_components} — last: [{result.get('confidence', '?')}] {result.get('label', '?')}")
    time.sleep(0.5)  # Rate limit

# --- Run explain pass on SAE features ---
print(f"\nRunning explain pass on {len(baseline_sae_indices)} SAE features...")
for j, feat_idx in enumerate(baseline_sae_indices):
    key = str(int(feat_idx))
    positives = sae_examples[int(feat_idx)]["positives"]
    if len(positives) < 3:
        explanations["sae"][key] = {"label": "TOO_FEW_EXAMPLES", "explanation": "", "confidence": 0}
        continue
    result = _explain_or_error(positives, int(feat_idx), "SAE feature")
    explanations["sae"][key] = result
    if (j + 1) % 10 == 0:
        print(f"  {j+1}/{len(baseline_sae_indices)} — last: [{result.get('confidence', '?')}] {result.get('label', '?')}")
    time.sleep(0.5)

# Save explanations
with open(os.path.join(AUTOINTERP_DIR, "explanations.json"), "w") as f:
    json.dump(explanations, f, indent=2)
print(f"\nSaved explanations.json")


def _positive_confidences(records):
    """Confidence scores > 0 from a dict of explanation records."""
    return [v.get("confidence", 0) for v in records.values() if v.get("confidence", 0) > 0]


# Quick summary (confidence 0 marks failed / skipped explanations)
ica_confs = _positive_confidences(explanations["ica"])
sae_confs = _positive_confidences(explanations["sae"])
for name, confs in (("ICA", ica_confs), ("SAE", sae_confs)):
    mean_str = f"{np.mean(confs):.2f}" if confs else "n/a"
    print(f"{name} mean confidence: {mean_str} (n={len(confs)})")

In [None]:
# Cell 6: Score Pass — Detection accuracy via Claude

SCORE_MODEL = "claude-sonnet-4-20250514"
# NOTE: must equal the value used in Cell 5. The score pass treats
# positives[N_EXPLAIN_EXAMPLES:] as held out, so a smaller value here than in
# Cell 5 would score on examples the explainer has already seen.
N_EXPLAIN_EXAMPLES = 10  # first 10 used for explain, rest for scoring

# System prompt for the score pass: one YES/NO prediction per excerpt,
# returned as a JSON array keyed by 1-based example_id.
SCORE_SYSTEM = """You are evaluating whether a proposed explanation for a neural network direction is correct. You will be given an explanation and a set of text excerpts. For each excerpt, predict whether this direction would activate strongly on the [bracketed] token (YES or NO).

Respond ONLY with a valid JSON array of predictions, no other text:
[{"example_id": 1, "prediction": "YES"}, {"example_id": 2, "prediction": "NO"}, ...]"""


def build_score_prompt(explanation, examples_with_labels):
    """Format the user message for the score pass.

    ``examples_with_labels`` is a sequence of ``(example, is_positive)``
    pairs; the labels are deliberately NOT shown to the model. Each example
    becomes a numbered section with its highlighted context, and sections are
    separated by blank lines.
    """
    header = f'The proposed explanation is: "{explanation}"\n'
    sections = [
        f"Example {number}:\n{example['context']}\n"
        for number, (example, _hidden_label) in enumerate(examples_with_labels, start=1)
    ]
    return "\n".join([header, *sections])


def call_score(explanation, positives, negatives):
    """Score an explanation by having Claude detect held-out activations.

    Held-out positives (``positives[N_EXPLAIN_EXAMPLES:]``, i.e. examples the
    explain pass did not see) are shuffled together with ``negatives`` using
    the global NumPy RNG, and Claude predicts YES/NO for each excerpt.

    Args:
        explanation: Explanation text produced by the explain pass.
        positives: Top-activating example dicts, in rank order.
        negatives: Low-activation example dicts.

    Returns:
        On success, a dict with ``balanced_accuracy``, ``tpr``, ``tnr``,
        ``n_pos`` and ``n_neg``. On failure, ``balanced_accuracy`` is None
        and ``reason`` says why: too few examples, an unparseable reply, the
        wrong number of predictions, or example ids that don't cover 1..n.
    """
    held_out_pos = positives[N_EXPLAIN_EXAMPLES:]
    neg = negatives

    # Lower threshold to 3 to handle sparse features with fewer examples
    if len(held_out_pos) < 3 or len(neg) < 3:
        return {
            "balanced_accuracy": None,
            "reason": f"too_few_examples (pos={len(held_out_pos)}, neg={len(neg)})",
        }

    # Combine and shuffle so the label cannot be inferred from position.
    examples_with_labels = [(ex, True) for ex in held_out_pos] + [(ex, False) for ex in neg]
    np.random.shuffle(examples_with_labels)
    ground_truth = [label for _, label in examples_with_labels]
    n_total = len(examples_with_labels)

    prompt = build_score_prompt(explanation, examples_with_labels)
    # SCORE_SYSTEM does not say how many excerpts there are, so state the
    # expected count explicitly; the reply is rejected below unless it matches.
    system = (
        f"{SCORE_SYSTEM}\n\nThere are {n_total} excerpts. Return exactly {n_total} "
        f"predictions, one for each example_id from 1 to {n_total}."
    )
    response = client.messages.create(
        model=SCORE_MODEL,
        max_tokens=1024,
        temperature=0,
        system=system,
        messages=[{"role": "user", "content": prompt}],
    )
    text = response.content[0].text.strip()
    # Strip an optional Markdown code fence (with or without a language tag),
    # then fall back to the outermost [...] span if prose surrounds the array.
    cleaned = re.sub(r"^```[A-Za-z]*\s*", "", text)
    cleaned = re.sub(r"\s*```$", "", cleaned)
    match = re.search(r"\[.*\]", cleaned, flags=re.DOTALL)
    if match:
        cleaned = match.group(0)

    try:
        predictions = json.loads(cleaned)
        if not isinstance(predictions, list):
            raise TypeError("expected a JSON array of predictions")
        is_yes = [str(p["prediction"]).strip().upper() == "YES" for p in predictions]
        ids = [p.get("example_id") for p in predictions]
        if all(i is not None for i in ids):
            ids = [int(i) for i in ids]
    except (json.JSONDecodeError, KeyError, TypeError, ValueError, AttributeError):
        return {"balanced_accuracy": None, "reason": "parse_error", "raw": text[:200]}

    if len(is_yes) != len(ground_truth):
        return {
            "balanced_accuracy": None,
            "reason": f"length_mismatch ({len(is_yes)} vs {len(ground_truth)})",
        }

    # Align predictions with excerpts by example_id when every entry has one;
    # otherwise fall back to the order of the reply.
    if all(i is not None for i in ids):
        if sorted(ids) != list(range(1, n_total + 1)):
            return {"balanced_accuracy": None, "reason": "bad_example_ids"}
        by_id = dict(zip(ids, is_yes))
        pred_yes = [by_id[k] for k in range(1, n_total + 1)]
    else:
        pred_yes = is_yes

    # Compute balanced accuracy (mean of true-positive and true-negative rates),
    # which is insensitive to the positive/negative ratio.
    tp = sum(1 for p, g in zip(pred_yes, ground_truth) if p and g)
    tn = sum(1 for p, g in zip(pred_yes, ground_truth) if not p and not g)
    n_pos = sum(ground_truth)
    n_neg = len(ground_truth) - n_pos
    tpr = tp / n_pos if n_pos > 0 else 0
    tnr = tn / n_neg if n_neg > 0 else 0
    balanced_acc = (tpr + tnr) / 2

    return {
        "balanced_accuracy": float(balanced_acc),
        "tpr": float(tpr),
        "tnr": float(tnr),
        "n_pos": n_pos,
        "n_neg": n_neg,
    }


# --- Run score pass on ICA components ---
scores = {"ica": {}, "sae": {}}


def _score_or_error(exp, examples):
    """Score one explanation; a failed API call becomes a None-accuracy record.

    The SDK already retries transient failures; whatever still fails is kept
    with an ``api_error`` reason so one bad call does not abort the whole run.
    """
    try:
        return call_score(
            exp.get("explanation", ""),
            examples["positives"],
            examples["negatives"],
        )
    except anthropic.APIError as err:
        return {"balanced_accuracy": None, "reason": f"api_error ({type(err).__name__})"}


def _progress(result):
    """Short progress description of one score result."""
    acc = result.get("balanced_accuracy")
    if acc is not None:
        return f"last acc: {acc:.2f}"
    return f"last: {result.get('reason', '')}"


def _failure_reasons(records):
    """Count the ``reason`` of every record that has no balanced accuracy."""
    reasons = defaultdict(int)
    for v in records.values():
        if v.get("balanced_accuracy") is None:
            reasons[v.get("reason", "unknown")] += 1
    return reasons


print(f"Running score pass on {n_ica_components} ICA components...")
for i in range(n_ica_components):
    key = str(i)
    exp = explanations["ica"].get(key, {})
    # Confidence 0 marks explanations that failed (e.g. parse errors or too few examples).
    if exp.get("confidence", 0) == 0:
        scores["ica"][key] = {"balanced_accuracy": None, "reason": "skipped_low_confidence"}
        continue
    result = _score_or_error(exp, ica_examples[i])
    scores["ica"][key] = result
    if (i + 1) % 10 == 0:
        print(f"  {i+1}/{n_ica_components} — {_progress(result)}")
    time.sleep(0.5)

# --- Run score pass on SAE features ---
print(f"\nRunning score pass on {len(baseline_sae_indices)} SAE features...")
n_sae_skipped = 0
for j, feat_idx in enumerate(baseline_sae_indices):
    key = str(int(feat_idx))
    exp = explanations["sae"].get(key, {})
    if exp.get("confidence", 0) == 0:
        scores["sae"][key] = {"balanced_accuracy": None, "reason": "skipped_low_confidence"}
        n_sae_skipped += 1
        continue
    result = _score_or_error(exp, sae_examples[int(feat_idx)])
    scores["sae"][key] = result
    if (j + 1) % 10 == 0:
        print(f"  {j+1}/{len(baseline_sae_indices)} — {_progress(result)}")
    time.sleep(0.5)

# Save scores
with open(os.path.join(AUTOINTERP_DIR, "scores.json"), "w") as f:
    json.dump(scores, f, indent=2)

# Diagnostic summary
ica_scored = sum(1 for v in scores["ica"].values() if v.get("balanced_accuracy") is not None)
sae_scored = sum(1 for v in scores["sae"].values() if v.get("balanced_accuracy") is not None)
ica_reasons = _failure_reasons(scores["ica"])
sae_reasons = _failure_reasons(scores["sae"])

print(f"\nScoring complete:")
print(f"  ICA: {ica_scored}/{n_ica_components} scored successfully")
print(f"  SAE: {sae_scored}/{len(baseline_sae_indices)} scored successfully")
for name, reasons in (("ICA", ica_reasons), ("SAE", sae_reasons)):
    if reasons:
        print(f"  {name} failures:")
        for reason, count in reasons.items():
            print(f"    {reason}: {count}")

In [None]:
# Cell 7: Aggregate Results & Generate Tables

# Reload saved results when the variables from earlier cells are missing
# (e.g. after a kernel restart). Each artifact is checked on its own, so a
# partially populated namespace still works. max_cos_sim and n_ica_components
# still come from Cells 1-2.
if "explanations" not in globals():
    with open(os.path.join(AUTOINTERP_DIR, "explanations.json")) as f:
        explanations = json.load(f)
if "scores" not in globals():
    with open(os.path.join(AUTOINTERP_DIR, "scores.json")) as f:
        scores = json.load(f)
if "baseline_sae_indices" not in globals():
    with open(os.path.join(AUTOINTERP_DIR, "baseline_sae_indices.json")) as f:
        baseline_sae_indices = np.array(json.load(f), dtype=np.int64)

# --- Gather ICA stats ---
# Confidence 0 and accuracy None mark failed / skipped components and are excluded.
ica_confs = []
ica_accs = []
ica_novel_confs = []
ica_novel_accs = []
ica_overlap_confs = []
ica_overlap_accs = []

for i in range(n_ica_components):
    exp = explanations["ica"].get(str(i), {})
    sc = scores["ica"].get(str(i), {})
    conf = exp.get("confidence", 0)
    acc = sc.get("balanced_accuracy")

    if conf > 0:
        ica_confs.append(conf)
    if acc is not None:
        ica_accs.append(acc)

    # "Novel" = max |cosine| to every SAE decoder direction below 0.3.
    is_novel = max_cos_sim[i] < 0.3
    if is_novel:
        if conf > 0:
            ica_novel_confs.append(conf)
        if acc is not None:
            ica_novel_accs.append(acc)
    else:
        if conf > 0:
            ica_overlap_confs.append(conf)
        if acc is not None:
            ica_overlap_accs.append(acc)

# --- Gather SAE stats ---
sae_confs = []
sae_accs = []
for feat_idx in baseline_sae_indices:
    key = str(int(feat_idx))
    exp = explanations["sae"].get(key, {})
    sc = scores["sae"].get(key, {})
    conf = exp.get("confidence", 0)
    acc = sc.get("balanced_accuracy")
    if conf > 0:
        sae_confs.append(conf)
    if acc is not None:
        sae_accs.append(acc)


def safe_mean(lst):
    """Arithmetic mean of ``lst``; NaN when ``lst`` is empty."""
    if not lst:
        return float('nan')
    return np.mean(lst)

def safe_median(lst):
    """Median of ``lst``; NaN when ``lst`` is empty."""
    if not lst:
        return float('nan')
    return np.median(lst)

def frac_above(lst, threshold):
    """Fraction of values in ``lst`` that are >= ``threshold`` (inclusive); NaN if empty."""
    if not lst:
        return float('nan')
    return np.mean([threshold <= value for value in lst])


# --- Print Table 1 ---
# Column headers report the actual number of components in each group.
ica_col = f"ICA (n={n_ica_components})"
sae_col = f"SAE (n={len(baseline_sae_indices)})"
print("=" * 70)
print("TABLE 1: Interpretability Comparison")
print("=" * 70)
print(f"{'Metric':<40s} {ica_col:>14s} {sae_col:>14s}")
print("-" * 70)
print(f"{'Mean confidence (1-5)':<40s} {safe_mean(ica_confs):>14.2f} {safe_mean(sae_confs):>14.2f}")
print(f"{'Median confidence':<40s} {safe_median(ica_confs):>14.2f} {safe_median(sae_confs):>14.2f}")
print(f"{'Detection balanced acc (mean)':<40s} {safe_mean(ica_accs):>14.3f} {safe_mean(sae_accs):>14.3f}")
print(f"{'Detection balanced acc (median)':<40s} {safe_median(ica_accs):>14.3f} {safe_median(sae_accs):>14.3f}")
print(f"{'Fraction confidence >= 3':<40s} {frac_above(ica_confs, 3):>14.2f} {frac_above(sae_confs, 3):>14.2f}")
# frac_above is inclusive, so the label says ">=".
print(f"{'Fraction detection acc >= 0.7':<40s} {frac_above(ica_accs, 0.7):>14.2f} {frac_above(sae_accs, 0.7):>14.2f}")
print()

# --- Print Table 2 ---
print("=" * 70)
print("TABLE 2: ICA Breakdown by Novelty")
print("=" * 70)
n_novel_total = (max_cos_sim < 0.3).sum()
n_overlap_total = (max_cos_sim >= 0.3).sum()
print(f"{'Metric':<40s} {'Novel (<0.3)':>14s} {'Overlapping':>14s}")
print(f"{'':40s} {'n='+str(n_novel_total):>14s} {'n='+str(n_overlap_total):>14s}")
print("-" * 70)
print(f"{'Mean detection accuracy':<40s} {safe_mean(ica_novel_accs):>14.3f} {safe_mean(ica_overlap_accs):>14.3f}")
print(f"{'Mean confidence':<40s} {safe_mean(ica_novel_confs):>14.2f} {safe_mean(ica_overlap_confs):>14.2f}")
print(f"{'Fraction acc >= 0.7':<40s} {frac_above(ica_novel_accs, 0.7):>14.2f} {frac_above(ica_overlap_accs, 0.7):>14.2f}")
print()

# --- Interpretation ---
print("=" * 70)
print("INTERPRETATION")
print("=" * 70)
novel_mean = safe_mean(ica_novel_accs)
if not np.isnan(novel_mean):
    if novel_mean > 0.7:
        print(f"Novel ICA components have mean detection acc = {novel_mean:.3f} (> 0.7)")
        print("=> Strong evidence: dark matter contains interpretable features comparable to SAE features!")
    elif novel_mean > 0.6:
        print(f"Novel ICA components have mean detection acc = {novel_mean:.3f} (> 0.6)")
        print("=> Good evidence: dark matter has interpretable structure, somewhat noisier than SAE features.")
    else:
        # Balanced-accuracy chance level is 0.5; this branch covers everything <= 0.6.
        print(f"Novel ICA components have mean detection acc = {novel_mean:.3f} (<= 0.6)")
        print("=> Weak/null result: novel ICA components may not be robustly interpretable.")
else:
    print("Could not compute novel ICA detection accuracy (insufficient data).")

In [None]:
# Cell 8: Qualitative Showcase — Top Novel ICA Components


def _finite_or_none(value):
    """Return ``float(value)``, or None for NaN/inf.

    ``json.dump`` would otherwise write a bare ``NaN``, which is not valid JSON.
    """
    value = float(value)
    return value if np.isfinite(value) else None


def _trim_context(ctx, token, max_len=150, before=60, after=90):
    """Shorten a highlighted context to a window around its ``[token]`` marker."""
    if len(ctx) <= max_len:
        return ctx
    # Prefer the exact "[token]" marker: a bare "[" may occur in the text itself.
    pos = ctx.find(f"[{token}]")
    if pos < 0:
        pos = ctx.find("[")
    if pos < 0:
        return ctx
    start = max(0, pos - before)
    end = min(len(ctx), pos + after)
    prefix = "..." if start > 0 else ""
    suffix = "..." if end < len(ctx) else ""
    return prefix + ctx[start:end] + suffix


# Find novel ICA components with highest detection accuracy
novel_results = []
for i in range(n_ica_components):
    if max_cos_sim[i] >= 0.3:
        continue
    exp = explanations["ica"].get(str(i), {})
    sc = scores["ica"].get(str(i), {})
    acc = sc.get("balanced_accuracy")
    if acc is None:
        continue
    novel_results.append({
        "component": i,
        "label": exp.get("label", "?"),
        "explanation": exp.get("explanation", "?"),
        "confidence": exp.get("confidence", 0),
        "detection_acc": acc,
        "max_cos_sim": float(max_cos_sim[i]),
    })

novel_results.sort(key=lambda x: x["detection_acc"], reverse=True)

print("=" * 80)
print("SHOWCASE: Top Novel ICA Components (cosine < 0.3 to any SAE feature)")
print("=" * 80)

for rank, r in enumerate(novel_results[:10]):
    comp = r["component"]
    print(f"\n--- #{rank+1}: ICA Component {comp} ---")
    print(f"  Label:          {r['label']}")
    print(f"  Explanation:    {r['explanation']}")
    print(f"  Confidence:     {r['confidence']}/5")
    print(f"  Detection Acc:  {r['detection_acc']:.3f}")
    print(f"  SAE Cosine Sim: {r['max_cos_sim']:.3f}")
    print(f"  Top 3 activating examples:")
    positives = ica_examples[comp]["positives"]
    for j, ex in enumerate(positives[:3]):
        ctx = _trim_context(ex["context"], ex.get("token", ""))
        print(f"    {j+1}. [{ex['activation']:+.2f}] {ctx}")

# Save showcase
showcase = {"novel_top_components": novel_results[:10]}
with open(os.path.join(AUTOINTERP_DIR, "showcase.json"), "w") as f:
    json.dump(showcase, f, indent=2)

# Save full results summary (NaN metrics from empty groups are stored as null)
summary = {
    "table1": {
        "ica": {
            "mean_confidence": _finite_or_none(safe_mean(ica_confs)),
            "median_confidence": _finite_or_none(safe_median(ica_confs)),
            "mean_detection_acc": _finite_or_none(safe_mean(ica_accs)),
            "frac_confidence_gte_3": _finite_or_none(frac_above(ica_confs, 3)),
            "frac_acc_gt_0.7": _finite_or_none(frac_above(ica_accs, 0.7)),
        },
        "sae": {
            "mean_confidence": _finite_or_none(safe_mean(sae_confs)),
            "median_confidence": _finite_or_none(safe_median(sae_confs)),
            "mean_detection_acc": _finite_or_none(safe_mean(sae_accs)),
            "frac_confidence_gte_3": _finite_or_none(frac_above(sae_confs, 3)),
            "frac_acc_gt_0.7": _finite_or_none(frac_above(sae_accs, 0.7)),
        },
    },
    "table2": {
        "novel": {
            "n": int(n_novel_total),
            "mean_detection_acc": _finite_or_none(safe_mean(ica_novel_accs)),
            "mean_confidence": _finite_or_none(safe_mean(ica_novel_confs)),
        },
        "overlapping": {
            "n": int(n_overlap_total),
            "mean_detection_acc": _finite_or_none(safe_mean(ica_overlap_accs)),
            "mean_confidence": _finite_or_none(safe_mean(ica_overlap_confs)),
        },
    },
}
with open(os.path.join(AUTOINTERP_DIR, "autointerp_summary.json"), "w") as f:
    json.dump(summary, f, indent=2)

print(f"\nSaved showcase.json and autointerp_summary.json to {AUTOINTERP_DIR}/")