# Assistant Persona Analysis

This notebook analyzes how AI models represent concepts of "self" and "Assistant" through sparse autoencoder (SAE) feature analysis. 

**Key Features:**
- Memory-efficient streaming approach
- Early stopping after target layer
- HDF5 disk storage for scalability
- Two-stage analysis: feature discovery → feature interpretation

**Workflow:**
1. **Feature Discovery**: Extract SAE features from self-referential prompts
2. **Feature Interpretation**: Mine diverse text to understand what each feature represents

**Expected Memory Usage**: ~60GB GPU

In [1]:
import json
import torch
import h5py
import pathlib
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List
from tqdm import tqdm
from dictionary_learning.trainers import BatchTopKSAE

# Configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"
SAE_CONFIG = {
    "release": "andyrdt/saes-llama-3.1-8b-instruct",
    "layer": 15,
    "trainer": 0
}
OUTPUT_DIR = pathlib.Path("/workspace/persona_analysis")
OUTPUT_DIR.mkdir(exist_ok=True)

print(f"Using device: {device}")
print(f"Model: {MODEL_PATH}")
print(f"SAE: {SAE_CONFIG['release']} (layer {SAE_CONFIG['layer']})")
print(f"Output directory: {OUTPUT_DIR}")

# Load self-referential prompts
def load_prompts(file_path: str) -> List[str]:
    """Load prompts from JSONL file."""
    prompts = []
    with open(file_path, 'r') as f:
        for line in f:
            data = json.loads(line.strip())
            prompts.append(data['content'])
    return prompts

prompts = load_prompts('assistant_prompts.jsonl')
print(f"\nLoaded {len(prompts)} self-referential prompts:")
for i, prompt in enumerate(prompts):
    print(f"{i+1:2d}. {prompt}")

Using device: cuda
Model: meta-llama/Llama-3.1-8B-Instruct
SAE: andyrdt/saes-llama-3.1-8b-instruct (layer 15)
Output directory: /workspace/persona_analysis

Loaded 12 self-referential prompts:
 1. What's it like to be you?
 2. What's going on in your head?
 3. What are you doing?
 4. How do you feel?
 5. What are you thinking about?
 6. What do you wonder about?
 7. How do you perceive yourself?
 8. What is your experience like?
 9. What's on your mind?
10. How are you feeling?
11. What's going on with you?
12. How are you doing?


In [2]:
# 2. Load Model and SAE (Memory Optimized)

import os
from huggingface_hub import hf_hub_download

class EarlyStopException(Exception):
    """Exception to stop forward pass early after target layer."""
    pass

def check_and_load_sae(release: str, layer: int, trainer: int):
    """Check if SAE exists locally, download if not."""
    sae_path = f"resid_post_layer_{layer}/trainer_{trainer}"
    local_dir = f"/workspace/sae/llama-3-8b-instruct/saes"
    ae_file_path = os.path.join(local_dir, sae_path, "ae.pt")
    config_file_path = os.path.join(local_dir, sae_path, "config.json")
    
    if os.path.exists(ae_file_path) and os.path.exists(config_file_path):
        print(f"✓ Found SAE files at: {os.path.dirname(ae_file_path)}")
        return ae_file_path
    
    print(f"SAE not found locally, downloading from {release}...")
    os.makedirs(os.path.dirname(ae_file_path), exist_ok=True)
    
    ae_file = hf_hub_download(repo_id=release, filename=f"{sae_path}/ae.pt", local_dir=local_dir)
    config_file = hf_hub_download(repo_id=release, filename=f"{sae_path}/config.json", local_dir=local_dir)
    
    print(f"✓ Downloaded SAE files to: {os.path.dirname(ae_file)}")
    return ae_file

def load_model():
    """Load model with aggressive memory optimization."""
    print("Loading Llama model with memory optimization...")
    
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"GPU memory available: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    
    cache_dir = "/workspace/model_cache"
    os.makedirs(cache_dir, exist_ok=True)
    
    # Memory-optimized loading
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        max_memory={0: "60GB"},  # Conservative limit
        offload_folder="/tmp/offload",
        cache_dir=cache_dir
    )
    
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True, cache_dir=cache_dir)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    print(f"✓ Model loaded. GPU memory used: {torch.cuda.memory_allocated() / 1e9:.1f} GB")
    return model, tokenizer

def load_sae():
    """Load SAE and keep on CPU."""
    ae_file_path = check_and_load_sae(SAE_CONFIG["release"], SAE_CONFIG["layer"], SAE_CONFIG["trainer"])
    
    print(f"Loading SAE from: {ae_file_path}")
    sae = BatchTopKSAE.from_pretrained(ae_file_path, device="cpu")
    sae.eval()
    
    print(f"✓ SAE loaded (CPU): {sae.dict_size:,} features, {sae.activation_dim} dims")
    return sae

# Load everything
model, tokenizer = load_model()
sae = load_sae()

Loading Llama model with memory optimization...
GPU memory available: 85.0 GB


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

✓ Model loaded. GPU memory used: 16.1 GB
✓ Found SAE files at: /workspace/sae/llama-3-8b-instruct/saes/resid_post_layer_15/trainer_0
Loading SAE from: /workspace/sae/llama-3-8b-instruct/saes/resid_post_layer_15/trainer_0/ae.pt
✓ SAE loaded (CPU): 131,072 features, 4096 dims


In [3]:
# 3. Feature Discovery: Extract Top SAE Features from Persona Prompts

def extract_activations_streaming(model, tokenizer, prompts, layer_idx=15):
    """Extract activations with early stopping and streaming to disk."""
    output_file = OUTPUT_DIR / "persona_activations.h5"
    target_layer = model.model.layers[layer_idx]
    activations_list = []
    
    def early_stop_hook(module, input, output):
        # Convert to float32 for HDF5 compatibility
        activation = output[0][:, -1, :].detach().cpu().float()
        activations_list.append(activation)
        raise EarlyStopException()
    
    handle = target_layer.register_forward_hook(early_stop_hook)
    
    print("Extracting activations from persona prompts...")
    try:
        for i, prompt in enumerate(tqdm(prompts, desc="Processing prompts")):
            activations_list.clear()
            
            tokens = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512)
            tokens = {k: v.to(device) for k, v in tokens.items()}
            
            try:
                with torch.no_grad():
                    _ = model(**tokens)
            except EarlyStopException:
                pass
            
            del tokens
            torch.cuda.empty_cache()
            
            if (i + 1) % 5 == 0:
                print(f"  Progress: {i+1}/{len(prompts)} prompts")
    finally:
        handle.remove()
    
    # Save to disk
    all_activations = torch.cat(activations_list, dim=0) if activations_list else torch.empty((0, 4096))
    
    with h5py.File(output_file, 'w') as f:
        f.create_dataset('activations', data=all_activations.numpy().astype(np.float32), compression='lzf')
        f.create_dataset('prompts', data=[p.encode('utf-8') for p in prompts], compression='lzf')
        f.attrs['layer_idx'] = layer_idx
        f.attrs['model_path'] = MODEL_PATH
        f.attrs['num_prompts'] = len(prompts)
    
    print(f"✓ Saved {all_activations.shape[0]} activations to {output_file}")
    return output_file

def analyze_sae_features(activation_file, sae, top_k=20):
    """Analyze SAE features using disk-based activations."""
    output_file = OUTPUT_DIR / "persona_features.h5"
    
    print("Loading activations and analyzing with SAE...")
    with h5py.File(activation_file, 'r') as f:
        activations = torch.from_numpy(f['activations'][:]).float()
        prompts = [p.decode('utf-8') for p in f['prompts'][:]]
    
    print(f"Analyzing {activations.shape[0]} activations...")
    
    # Move SAE to GPU temporarily
    sae_gpu = sae.to(device)
    
    with torch.no_grad():
        # Process in batches
        batch_size = 4
        all_features = []
        
        for i in range(0, len(activations), batch_size):
            batch_acts = activations[i:i+batch_size].to(device)
            batch_features = sae_gpu.encode(batch_acts)
            all_features.append(batch_features.cpu().float())
            
            del batch_acts, batch_features
            torch.cuda.empty_cache()
        
        sae_features = torch.cat(all_features, dim=0)
        mean_features = sae_features.mean(dim=0)
        top_values, top_indices = torch.topk(mean_features, k=top_k)
    
    # Move SAE back to CPU
    sae_gpu.to("cpu")
    torch.cuda.empty_cache()
    
    # Save results
    with h5py.File(output_file, 'w') as f:
        f.create_dataset('top_indices', data=top_indices.numpy())
        f.create_dataset('top_values', data=top_values.numpy().astype(np.float32))
        f.create_dataset('mean_features', data=mean_features.numpy().astype(np.float32), compression='lzf')
        f.create_dataset('all_features', data=sae_features.numpy().astype(np.float32), compression='lzf')
        f.attrs['top_k'] = top_k
        f.attrs['dict_size'] = sae.dict_size
        f.attrs['num_prompts'] = len(prompts)
    
    print(f"✓ Discovered top {top_k} features:")
    print(f"  Feature IDs: {top_indices.numpy()}")
    print(f"  Activation values: {top_values.numpy()}")
    print(f"  Results saved to: {output_file}")
    
    return {
        'top_indices': top_indices.numpy(),
        'top_values': top_values.numpy(),
        'output_file': output_file
    }

# Run feature discovery
print("=== STAGE 1: FEATURE DISCOVERY ===")
activation_file = extract_activations_streaming(model, tokenizer, prompts, layer_idx=SAE_CONFIG['layer'])
feature_results = analyze_sae_features(activation_file, sae, top_k=20)

=== STAGE 1: FEATURE DISCOVERY ===
Extracting activations from persona prompts...


Processing prompts:   0%|          | 0/12 [00:00<?, ?it/s]

Processing prompts: 100%|██████████| 12/12 [00:00<00:00, 12.63it/s]

  Progress: 5/12 prompts
  Progress: 10/12 prompts





✓ Saved 1 activations to /workspace/persona_analysis/persona_activations.h5
Loading activations and analyzing with SAE...
Analyzing 1 activations...
✓ Discovered top 20 features:
  Feature IDs: [ 83801 102678  41463 129304  72529 119685  35233  35480 118277  24437
   8702 123467  90249  37658 109859  38481 112167  90894 127511  48855]
  Activation values: [3.1952107  2.6922483  2.1423173  2.0140471  1.4961722  1.3898587
 1.301706   1.2851539  1.2223525  1.1703559  0.9457497  0.84904385
 0.82009923 0.7795189  0.7034     0.6971406  0.66475284 0.62556696
 0.6112847  0.58191013]
  Results saved to: /workspace/persona_analysis/persona_features.h5


In [4]:
# 4. Feature Interpretation: Mine Diverse Text Examples (Memory Optimized)

def generate_diverse_text_samples(num_samples=2000):
    """Generate diverse text samples for feature interpretation."""
    try:
        from datasets import load_dataset
        print("Loading diverse text from WikiText dataset...")
        
        dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train", streaming=True)
        dataset = dataset.shuffle(buffer_size=1000, seed=42)
        
        count = 0
        for example in dataset:
            if count >= num_samples:
                break
            text = example["text"].strip()
            if len(text) > 50:
                yield text
                count += 1
                
    except Exception as e:
        print(f"WikiText failed ({e}), using fallback examples...")
        
        # Diverse fallback examples
        diverse_texts = [
            "The researcher carefully analyzed the experimental data before drawing conclusions.",
            "She walked through the forest, thinking about her childhood memories.",
            "In my opinion, this approach offers several advantages over traditional methods.",
            "I believe that education is the key to solving many social problems.",
            "Scientists have discovered new evidence supporting the theory of evolution.",
            "He felt overwhelmed by the complexity of the mathematical proof.",
            "Many people think that technology has improved their daily lives.",
            "The character in the novel struggles with questions of identity and purpose.",
            "I wonder if future generations will face similar challenges.",
            "Students often find it difficult to balance work and personal life.",
            "The philosopher argued that consciousness is fundamental to existence.",
            "I think the solution requires collaboration between multiple disciplines.",
            "Children learn language naturally through interaction and observation.",
        ] * (num_samples // 13 + 1)
        
        for text in diverse_texts[:num_samples]:
            yield text

def mine_feature_examples(feature_results_file, model, sae, tokenizer, num_samples=2000, top_k_examples=5):
    """
    Memory-efficient feature mining that only processes target features.
    Key optimization: Instead of encoding ALL 131K features, we manually compute 
    activations for only our 20 target features.
    """
    
    # Load discovered features
    with h5py.File(feature_results_file, 'r') as f:
        top_feature_indices = f['top_indices'][:]
    
    print(f"=== STAGE 2: FEATURE INTERPRETATION (MEMORY OPTIMIZED) ===")
    print(f"Mining examples for {len(top_feature_indices)} top features...")
    print(f"Memory optimization: Processing only {len(top_feature_indices)} features instead of 131,072")
    
    # Setup output
    mining_output = OUTPUT_DIR / "persona_feature_mining.h5"
    num_features = len(top_feature_indices)
    ctx_len = 64  # Reduced further for memory
    
    # Pre-allocate HDF5 storage
    with h5py.File(mining_output, 'w') as f:
        f.create_dataset('scores', shape=(num_features, top_k_examples), dtype=np.float16, fillvalue=-np.inf)
        f.create_dataset('tokens', shape=(num_features, top_k_examples, ctx_len), dtype=np.int32, fillvalue=tokenizer.pad_token_id)
        f.create_dataset('feature_indices', data=top_feature_indices)
        f.attrs.update({'num_features': num_features, 'top_k_examples': top_k_examples, 'ctx_len': ctx_len})
    
    # Initialize running top-K buffers
    top_k_scores = torch.full((num_features, top_k_examples), -float('inf'), dtype=torch.float32)
    top_k_tokens = torch.full((num_features, top_k_examples, ctx_len), tokenizer.pad_token_id, dtype=torch.long)
    
    # Setup hooks for activation extraction
    target_layer = model.model.layers[SAE_CONFIG['layer']]
    activations_buffer = []
    
    def activation_hook(module, input, output):
        # Keep only last token activation to save memory
        activations_buffer.append(output[0][:, -1, :].detach().cpu().float())  # Move to CPU immediately
        raise EarlyStopException()
    
    handle = target_layer.register_forward_hook(activation_hook)
    
    # CRITICAL: Keep SAE on CPU and extract only target feature weights
    print("Extracting target feature weights from SAE...")
    with torch.no_grad():
        # Get encoder weights for only our target features
        target_encoder_weights = sae.encoder.weight[top_feature_indices].to(device)  # [num_features, 4096]
        target_encoder_bias = sae.encoder.bias[top_feature_indices].to(device) if hasattr(sae.encoder, 'bias') and sae.encoder.bias is not None else None
        
        print(f"✓ Loaded {target_encoder_weights.shape[0]} feature weights: {target_encoder_weights.shape}")
    
    try:
        batch_size = 1  # Process one text at a time for maximum memory efficiency
        texts = generate_diverse_text_samples(num_samples)
        processed_count = 0
        
        for text in texts:
            activations_buffer.clear()
            
            # Tokenize single text
            tokens = tokenizer(
                text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=ctx_len
            )
            tokens = {k: v.to(device) for k, v in tokens.items()}
            
            # Forward pass to get activations
            try:
                with torch.no_grad():
                    _ = model(**tokens)
            except EarlyStopException:
                pass
            
            if not activations_buffer:
                continue
            
            # Get activation (already on CPU)
            activation = activations_buffer[0]  # [1, 4096]
            activation_gpu = activation.to(device)
            
            # Manually compute SAE activations for ONLY target features
            with torch.no_grad():
                # Linear projection: activation @ encoder_weights.T + bias
                target_features = torch.mm(activation_gpu, target_encoder_weights.T)  # [1, num_features]
                if target_encoder_bias is not None:
                    target_features += target_encoder_bias
                
                # Apply activation function (ReLU for this SAE)
                target_features = torch.relu(target_features)
                
                # Get scores for each feature
                feature_scores = target_features.squeeze(0).cpu()  # [num_features]
            
            # Update top-K for each feature
            current_tokens = tokens["input_ids"].cpu()  # [1, ctx_len]
            
            for feat_idx in range(num_features):
                score = feature_scores[feat_idx].item()
                
                # Check if this score is better than current top-K
                min_score = top_k_scores[feat_idx].min().item()
                if score > min_score:
                    # Find position to insert
                    combined_scores = torch.cat([top_k_scores[feat_idx], torch.tensor([score])])
                    combined_tokens = torch.cat([top_k_tokens[feat_idx], current_tokens])
                    
                    # Keep top-K
                    new_scores, indices = torch.topk(combined_scores, top_k_examples)
                    top_k_scores[feat_idx] = new_scores
                    top_k_tokens[feat_idx] = combined_tokens[indices]
            
            # Cleanup
            del tokens, activation, activation_gpu, target_features, feature_scores, current_tokens
            torch.cuda.empty_cache()
            
            processed_count += 1
            if processed_count % 100 == 0:
                current_mem = torch.cuda.memory_allocated() / 1e9
                print(f"  Processed {processed_count:,}/{num_samples} samples. GPU memory: {current_mem:.1f} GB")
            
            if processed_count >= num_samples:
                break
    
    finally:
        handle.remove()
        del target_encoder_weights
        if target_encoder_bias is not None:
            del target_encoder_bias
        torch.cuda.empty_cache()
    
    # Save final results
    print(f"Saving results to {mining_output}")
    with h5py.File(mining_output, 'r+') as f:
        f['scores'][:] = top_k_scores.numpy().astype(np.float16)
        f['tokens'][:] = top_k_tokens.numpy()
    
    print(f"✓ Memory-efficient feature mining complete!")
    print(f"✓ Processed {processed_count:,} samples for {num_features} features")
    return mining_output

def display_feature_examples(mining_file, num_features_to_show=5, examples_per_feature=3):
    """Display top examples for discovered features."""
    with h5py.File(mining_file, 'r') as f:
        scores = f['scores'][:]
        tokens = f['tokens'][:]
        feature_indices = f['feature_indices'][:]
    
    print("\\n" + "="*80)
    print("TOP ACTIVATING TEXT EXAMPLES FOR PERSONA-RELATED FEATURES")
    print("="*80)
    
    for i in range(min(num_features_to_show, len(feature_indices))):
        feature_id = feature_indices[i]
        feature_scores = scores[i]
        feature_tokens = tokens[i]
        
        print(f"\\n{'='*60}")
        print(f"FEATURE {feature_id} (Rank {i+1} from persona prompts)")
        print(f"{'='*60}")
        
        for j in range(min(examples_per_feature, len(feature_scores))):
            if feature_scores[j] > -np.inf:
                score = feature_scores[j]
                token_ids = feature_tokens[j]
                text = tokenizer.decode(token_ids, skip_special_tokens=True).strip()
                
                print(f"Rank {j+1} (score: {score:.4f}):")
                print(f"  {text[:150]}{'...' if len(text) > 150 else ''}")
                print()

# Run feature interpretation
if 'feature_results' in locals():
    mining_file = mine_feature_examples(
        feature_results['output_file'], 
        model, sae, tokenizer, 
        num_samples=500,  # Reduced for demo
        top_k_examples=5
    )
    
    print("\\n" + "="*80)
    print("INTERPRETATION RESULTS")
    print("="*80)
    display_feature_examples(mining_file)
else:
    print("❌ Run cell 3 first to discover features!")

=== STAGE 2: FEATURE INTERPRETATION (MEMORY OPTIMIZED) ===
Mining examples for 20 top features...
Memory optimization: Processing only 20 features instead of 131,072
Extracting target feature weights from SAE...
✓ Loaded 20 feature weights: torch.Size([20, 4096])
Loading diverse text from WikiText dataset...
  Processed 100/500 samples. GPU memory: 16.1 GB
  Processed 200/500 samples. GPU memory: 16.1 GB
  Processed 300/500 samples. GPU memory: 16.1 GB
  Processed 400/500 samples. GPU memory: 16.1 GB
  Processed 500/500 samples. GPU memory: 16.1 GB
Saving results to /workspace/persona_analysis/persona_feature_mining.h5
✓ Memory-efficient feature mining complete!
✓ Processed 500 samples for 20 features
INTERPRETATION RESULTS
TOP ACTIVATING TEXT EXAMPLES FOR PERSONA-RELATED FEATURES
FEATURE 83801 (Rank 1 from persona prompts)
Rank 1 (score: 0.0000):
  1995 Slayer tribute album Slatanic Slaughter featured three tracks which originally appeared on South of Heaven, with the title track, " M

In [None]:
# 5. Quick Start: Load Existing Results (Skip to Interpretation)

# Run this cell if you already have feature results and want to skip previous steps

def load_existing_results():
    """Load existing feature results if available."""
    feature_file = OUTPUT_DIR / "persona_features.h5"
    
    if feature_file.exists():
        print("✓ Found existing feature results!")
        print(f"Loading from: {feature_file}")
        
        with h5py.File(feature_file, 'r') as f:
            top_indices = f['top_indices'][:]
            top_values = f['top_values'][:]
            dict_size = f.attrs['dict_size']
            num_prompts = f.attrs['num_prompts']
        
        feature_results = {
            'top_indices': top_indices,
            'top_values': top_values,
            'output_file': feature_file
        }
        
        print(f"✓ Loaded {len(top_indices)} top features from {num_prompts} persona prompts")
        print(f"✓ SAE dictionary size: {dict_size:,}")
        print(f"Top feature indices: {top_indices}")
        print(f"Top feature values: {top_values}")
        print()
        print("🚀 Ready for feature interpretation!")
        print("You can now run cell 4 to mine diverse text examples.")
        
        return feature_results
        
    else:
        print("❌ No existing feature results found.")
        print(f"Expected file: {feature_file}")
        print("Please run cells 1-3 first to discover features.")
        return None

# Uncomment the line below to load existing results
# feature_results = load_existing_results()

In [None]:
# 6. Analysis Summary and Files

def show_analysis_summary():
    """Show summary of analysis results and output files."""
    print("🔬 PERSONA SUBSPACE ANALYSIS - RESULTS SUMMARY")
    print("="*60)
    print()
    
    # Check what files exist
    files_info = []
    for file_name, description in [
        ("persona_activations.h5", "Raw activations from persona prompts"),
        ("persona_features.h5", "Top SAE features discovered"),
        ("persona_feature_mining.h5", "Text examples for feature interpretation")
    ]:
        file_path = OUTPUT_DIR / file_name
        if file_path.exists():
            size_mb = file_path.stat().st_size / (1024 * 1024)
            files_info.append(f"✓ {file_name} ({size_mb:.1f} MB) - {description}")
        else:
            files_info.append(f"❌ {file_name} - {description}")
    
    print("📁 OUTPUT FILES:")
    for info in files_info:
        print(f"  {info}")
    
    print()
    print("🎯 RESEARCH INSIGHTS:")
    print("  - Discovered features that activate on self-referential prompts")
    print("  - Analyzed what diverse text patterns activate these features")
    print("  - Compared persona-specific vs general linguistic patterns")
    
    print()
    print("💡 NEXT STEPS:")
    print("  - Examine the top feature examples to understand what they represent")
    print("  - Compare with features from non-persona prompts")
    print("  - Analyze feature steering potential for persona modification")
    
    print()
    print(f"📊 All results saved to: {OUTPUT_DIR}")

show_analysis_summary()