# Show SAE Concepts

This notebook loads the exported top texts for each sparse autoencoder (SAE) neuron and displays them to show what concepts the SAE has learned.

In [None]:
%load_ext autoreload
%autoreload 2

import json
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from collections import Counter

# Location of the exported store, relative to the directory the kernel runs in.
STORE_DIR = Path("experiments/verify_sae_training/store")
TOP_TEXTS_FILE = STORE_DIR / "top_texts.json"

# The emoji was stored as mojibake (UTF-8 bytes decoded with the wrong codec); restored here.
print(f"📁 Top texts file: {TOP_TEXTS_FILE}")

In [None]:
# Load the export: a mapping from neuron index (string key) to that neuron's list of text entries.
with open(TOP_TEXTS_FILE, 'r', encoding='utf-8') as f:
    top_texts_data = json.load(f)

print(f"✅ Loaded top texts for {len(top_texts_data)} neurons")

# A neuron counts as "with texts" when its list is non-empty.
neurons_with_texts = sum(1 for texts in top_texts_data.values() if texts)
total_texts = sum(len(texts) for texts in top_texts_data.values())
# Guard the average: an empty export would otherwise raise ZeroDivisionError.
avg_texts_per_neuron = total_texts / len(top_texts_data) if top_texts_data else 0.0

print("📊 Statistics:")
print(f"   Neurons with texts: {neurons_with_texts} / {len(top_texts_data)}")
print(f"   Total texts: {total_texts}")
print(f"   Average texts per neuron: {avg_texts_per_neuron:.2f}")

## 1. Most Active Neurons

In [None]:
# Activity of a neuron = number of text entries exported for it; keys become ints for display.
neuron_activity = {int(neuron_key): len(texts) for neuron_key, texts in top_texts_data.items()}
# Most active first. The sort is stable, so ties keep the order they had in the export.
sorted_neurons = sorted(neuron_activity.items(), key=lambda pair: pair[1], reverse=True)

# The emoji was stored as mojibake; restored here.
print("🔥 Top 20 most active neurons:")
for neuron_idx, count in sorted_neurons[:20]:
    print(f"   Neuron {neuron_idx}: {count} texts")

In [None]:
# Gather the data for both panels before touching matplotlib.
activity_counts = list(neuron_activity.values())
busiest = sorted_neurons[:20]
top_20_neurons = [pair[0] for pair in busiest]
top_20_counts = [pair[1] for pair in busiest]

fig, (hist_ax, bar_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: how many neurons have a given number of texts.
hist_ax.hist(activity_counts, bins=50, edgecolor='black', alpha=0.7)
hist_ax.set_xlabel('Number of Texts')
hist_ax.set_ylabel('Number of Neurons')
hist_ax.set_title('Distribution of Neuron Activity')
hist_ax.grid(True, alpha=0.3)

# Right panel: text counts of the 20 most active neurons, labelled by neuron index.
bar_positions = list(range(len(top_20_neurons)))
bar_ax.bar(bar_positions, top_20_counts)
bar_ax.set_xlabel('Neuron Rank')
bar_ax.set_ylabel('Number of Texts')
bar_ax.set_title('Top 20 Most Active Neurons')
bar_ax.set_xticks(bar_positions)
bar_ax.set_xticklabels([f"N{idx}" for idx in top_20_neurons], rotation=45)
bar_ax.grid(True, alpha=0.3, axis='y')

fig.tight_layout()
plt.show()

## 2. Display Concepts for Specific Neurons

In [None]:
def show_neuron_concepts(neuron_idx, top_k=10, data=None, max_chars=200):
    """Print the highest-scoring text entries recorded for one neuron.

    Args:
        neuron_idx: Neuron index; it is looked up in the mapping by ``str(neuron_idx)``.
        top_k: Maximum number of entries to print, highest score first.
            ``None`` prints every entry; negative values print nothing.
        data: Mapping from neuron-index strings to lists of entry dicts, each
            with the keys ``'score'``, ``'token_str'``, ``'token_idx'`` and
            ``'text'``. Defaults to the notebook-level ``top_texts_data``.
        max_chars: Texts longer than this many characters are cut to that
            length and suffixed with ``...``.

    Returns:
        None. All output goes to stdout.
    """
    # Fall back to the notebook global only when no mapping is passed explicitly.
    source = top_texts_data if data is None else data
    key = str(neuron_idx)

    if key not in source:
        print(f"❌ Neuron {neuron_idx} not found")
        return

    texts = source[key]
    if not texts:
        print(f"⚠️  Neuron {neuron_idx} has no texts")
        return

    print(f"\n🧠 Neuron {neuron_idx} ({len(texts)} texts):")
    print("=" * 80)

    ranked = sorted(texts, key=lambda entry: entry['score'], reverse=True)
    # A negative top_k would otherwise slice off the *last* entries instead of limiting the count.
    limit = None if top_k is None else max(top_k, 0)

    for rank, item in enumerate(ranked[:limit], 1):
        text = item['text']
        shown = f"{text[:max_chars]}..." if len(text) > max_chars else text
        print(f"\n{rank}. Score: {item['score']:.4f}")
        print(f"   Token: '{item['token_str']}' (idx: {item['token_idx']})")
        print(f"   Text: {shown}")

# Show the five strongest texts for each of the five most active neurons.
top_5_neurons = [pair[0] for pair in sorted_neurons[:5]]
for neuron_idx in top_5_neurons:
    show_neuron_concepts(neuron_idx, top_k=5)

## 3. Token Analysis

In [None]:
# Collect every token string across all neurons, plus the scores seen for each token.
all_tokens = []
token_scores = {}

for texts in top_texts_data.values():
    for item in texts:
        token_str = item['token_str']
        all_tokens.append(token_str)
        token_scores.setdefault(token_str, []).append(item['score'])

token_counts = Counter(all_tokens)
token_avg_scores = {token: np.mean(scores) for token, scores in token_scores.items()}

# The emoji was stored as mojibake; restored here.
print("🔤 Top 30 most frequent tokens:")
for token, count in token_counts.most_common(30):
    avg_score = token_avg_scores.get(token, 0.0)
    print(f"   '{token}': {count} occurrences (avg score: {avg_score:.4f})")

In [None]:
# Horizontal bar chart of the 30 most frequent tokens, most frequent at the top.
most_frequent = token_counts.most_common(30)
top_tokens = [token for token, _ in most_frequent]
top_counts = [token_counts[token] for token in top_tokens]

fig, ax = plt.subplots(figsize=(15, 6))

row_positions = list(range(len(top_tokens)))
ax.barh(row_positions, top_counts)
ax.set_yticks(row_positions)
ax.set_yticklabels(top_tokens)
ax.set_xlabel('Number of Occurrences')
ax.set_title('Top 30 Most Frequent Tokens Across All Neurons')
# barh draws the first row at the bottom; flip so rank 1 is at the top.
ax.invert_yaxis()
ax.grid(True, alpha=0.3, axis='x')

fig.tight_layout()
plt.show()

## 4. Interactive Neuron Exploration

In [None]:
print("\n💡 To explore a specific neuron, use:")
print("   show_neuron_concepts(neuron_idx, top_k=10)")
print("\n📊 Available neurons with texts:")
print(f"   Total: {neurons_with_texts} neurons")
# Report the real index bounds: the keys need not form a contiguous 0..N-1 range.
neuron_ids = [int(key) for key in top_texts_data]
if neuron_ids:
    print(f"   Range: {min(neuron_ids)} to {max(neuron_ids)}")
else:
    print("   Range: (no neurons loaded)")

print("\n🔍 Example - showing neuron 0:")
if top_texts_data.get('0'):
    show_neuron_concepts(0, top_k=5)
else:
    print("   Neuron 0 has no texts, trying first neuron with texts...")
    # sorted_neurons holds (neuron index, text count) pairs; take the first with a non-zero count.
    fallback_idx = next((idx for idx, count in sorted_neurons if count > 0), None)
    if fallback_idx is None:
        print("   No neuron has any texts.")
    else:
        show_neuron_concepts(fallback_idx, top_k=5)

## 5. Summary Statistics

In [None]:
# Flatten every entry's score across all neurons.
all_scores = [item['score'] for texts in top_texts_data.values() for item in texts]

if not all_scores:
    # np.min / np.max raise ValueError on an empty sequence, so skip the section instead.
    print("⚠️  No activation scores found; skipping score statistics and histogram.")
else:
    mean_score = np.mean(all_scores)

    print("📊 Score Statistics:")
    print(f"   Mean score: {mean_score:.4f}")
    print(f"   Std score: {np.std(all_scores):.4f}")
    print(f"   Min score: {np.min(all_scores):.4f}")
    print(f"   Max score: {np.max(all_scores):.4f}")

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.hist(all_scores, bins=100, edgecolor='black', alpha=0.7)
    ax.set_xlabel('Activation Score')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of Activation Scores')
    # Dashed vertical line marks the mean so skew is visible at a glance.
    ax.axvline(mean_score, color='r', linestyle='--', label=f'Mean: {mean_score:.4f}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()