# Feature Analysis: Get all active prompts for a feature

This notebook analyzes which given SAE features are activated on given prompts and generates both CSV and JSONL outputs in a single pass, optimized for performance.

In [1]:
import csv
import json
import torch
import os
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Dict, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import hf_hub_download
from dictionary_learning.utils import load_dictionary
from tqdm.auto import tqdm
from sae_lens import SAE

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## Configs

In [None]:
from dataclasses import dataclass
from typing import Dict, Tuple, Optional

@dataclass
class ModelConfig:
    """Configuration for model-specific settings"""
    base_model_name: str
    chat_model_name: str
    hf_release: str  # Reference only - actual loading uses saelens_release/sae_id
    assistant_header: str
    token_offsets: Dict[str, int]
    sae_base_path: str
    saelens_release: str  # Template for sae_lens release parameter
    sae_id_template: str  # Template for sae_lens sae_id parameter
    base_url: str  # Base URL for neuronpedia
    
    def get_sae_params(self, sae_layer: int, sae_trainer: str) -> Tuple[str, str]:
        """
        Generate SAE lens release and sae_id parameters.
        
        Args:
            sae_layer: Layer number for the SAE
            sae_trainer: Trainer identifier for the SAE
            
        Returns:
            Tuple of (release, sae_id) for sae_lens.SAE.from_pretrained()
        """
        if self.saelens_release == "llama_scope_lxr_{trainer}":
            release = self.saelens_release.format(trainer=sae_trainer)
            sae_id = self.sae_id_template.format(layer=sae_layer, trainer=sae_trainer)
        elif self.saelens_release == "gemma-scope-9b-pt-res":
            # Parse SAE_TRAINER "131k-l0-34" into components for Gemma
            parts = sae_trainer.split("-")
            width = parts[0]  # "131k"
            l0_value = parts[2]  # "34"
            
            release = self.saelens_release
            sae_id = self.sae_id_template.format(layer=sae_layer, width=width, l0=l0_value)
        elif self.saelens_release == "gemma-scope-9b-pt-res-canonical":
            # Parse SAE_TRAINER "131k-l0-34" into components for Gemma
            parts = sae_trainer.split("-")
            width = parts[0]  # "131k"

            release = self.saelens_release
            sae_id = self.sae_id_template.format(layer=sae_layer, width=width)
        else:
            raise ValueError(f"Unknown SAE lens release template: {self.saelens_release}")
        
        return release, sae_id

# Model configurations
MODEL_CONFIGS = {
    "llama": ModelConfig(
        base_model_name="meta-llama/Llama-3.1-8B",
        chat_model_name="meta-llama/Llama-3.1-8B-Instruct",
        hf_release="fnlp/Llama3_1-8B-Base-LXR-32x",
        assistant_header="<|start_header_id|>assistant<|end_header_id|>",
        token_offsets={"asst": -2, "endheader": -1, "newline": 0},
        sae_base_path="/workspace/sae/llama-3.1-8b/saes",
        saelens_release="llama_scope_lxr_{trainer}",
        sae_id_template="l{layer}r_{trainer}",
        base_url="https://www.neuronpedia.org/llama-3.1-8b/{layer}-llamascope-res-131k"
    ),
    "gemma": ModelConfig(
        base_model_name="google/gemma-2-9b",
        chat_model_name="google/gemma-2-9b-it",
        hf_release="google/gemma-scope-9b-pt-res/layer_{layer}/width_{width}/average_l0_{l0}",
        assistant_header="<start_of_turn>model",
        token_offsets={"model": -1, "newline": 0},
        sae_base_path="/workspace/sae/gemma-2-9b/saes",
        saelens_release="gemma-scope-9b-pt-res-canonical",
        sae_id_template="layer_{layer}/width_{width}/canonical",
        base_url="https://www.neuronpedia.org/gemma-2-9b/{layer}-gemmascope-res-131k"
    )
}

# =============================================================================
# MODEL SELECTION - Change this to switch between models
# =============================================================================
MODEL_TYPE = "gemma"  # Options: "gemma" or "llama"
MODEL_VER = "chat"
SAE_LAYER = 20
SAE_TRAINER = "131k-l0-114"
N_PROMPTS = 1000

# =============================================================================
# CONFIGURATION SETUP
# =============================================================================
if MODEL_TYPE not in MODEL_CONFIGS:
    raise ValueError(f"Unknown MODEL_TYPE: {MODEL_TYPE}. Available: {list(MODEL_CONFIGS.keys())}")

config = MODEL_CONFIGS[MODEL_TYPE]

# Set model name based on version
if MODEL_VER == "chat":
    MODEL_NAME = config.chat_model_name
elif MODEL_VER == "base":
    MODEL_NAME = config.base_model_name
else:
    raise ValueError(f"Unknown MODEL_VER: {MODEL_VER}. Use 'chat' or 'base'")

# Always use chat model for tokenizer (has chat template)
CHAT_MODEL_NAME = config.chat_model_name

# Set up derived configurations
ASSISTANT_HEADER = config.assistant_header
TOKEN_OFFSETS = config.token_offsets
SAE_BASE_PATH = config.sae_base_path

# =============================================================================
# OUTPUT FILE CONFIGURATION
# =============================================================================
OUTPUT_DIR = f"./results/6_active_prompts/{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}/{N_PROMPTS}_prompts"
os.makedirs(os.path.dirname(OUTPUT_DIR), exist_ok=True)

# =============================================================================
# DERIVED CONFIGURATIONS
# =============================================================================
SAE_PATH = f"{SAE_BASE_PATH}/resid_post_layer_{SAE_LAYER}/trainer_{SAE_TRAINER}"
LAYER_INDEX = SAE_LAYER

# Data paths
PROMPTS_HF = "lmsys/lmsys-chat-1m"
SEED = 42
PROMPTS_PATH = f"/workspace/data/{PROMPTS_HF.split('/')[-1]}/chat_{N_PROMPTS}.jsonl"
# PROMPTS_PATH = "./prompts/personal_40/personal.jsonl"
os.makedirs(os.path.dirname(PROMPTS_PATH), exist_ok=True)

# Processing parameters
BATCH_SIZE = 32
MAX_LENGTH = 512

# =============================================================================
# SUMMARY
# =============================================================================
print(f"Configuration Summary:")
print(f"  Model Type: {MODEL_TYPE}")
print(f"  Model to load: {MODEL_NAME}")
print(f"  Tokenizer (chat): {CHAT_MODEL_NAME}")
print(f"  SAE Layer: {SAE_LAYER}, Trainer: {SAE_TRAINER}")
print(f"  Available token types: {list(TOKEN_OFFSETS.keys())}")
print(f"  Assistant header: {ASSISTANT_HEADER}")
print(f"  Output dir: {OUTPUT_DIR}")

## Load Data

In [3]:
def load_lmsys_prompts(prompts_path: str, prompts_hf: str, n_prompts: int, seed: int) -> pd.DataFrame:
    # Check if prompts_path exists
    if os.path.exists(prompts_path):
        print(f"Prompts already exist at {prompts_path}")
        return pd.read_json(prompts_path, lines=True)
    else:
        print(f"Prompts do not exist at {prompts_path}. Loading from {prompts_hf}...")
        dataset = load_dataset(prompts_hf)
        dataset = dataset['train'].shuffle(seed=seed).select(range(n_prompts))
        df = dataset.to_pandas()

        # Extract the prompt from the first conversation item
        df['prompt'] = df['conversation'].apply(lambda x: x[0]['content'])

        # Only keep some columns
        df = df[['conversation_id', 'prompt', 'redacted', 'language']]

        # Save to .jsonl file
        df.to_json(prompts_path, orient='records', lines=True)
        return df

def load_prompts_from_jsonl(file_path: str) -> pd.DataFrame:
    """Load prompts from a JSONL file. Expects each line to have a 'content' field."""
    prompts = []
    with open(file_path, 'r') as f:
        for line in f:
            data = json.loads(line.strip())
            prompts.append(data['content'])
    return pd.DataFrame(prompts, columns=['prompt'])

prompts_df = load_lmsys_prompts(PROMPTS_PATH, PROMPTS_HF, N_PROMPTS, SEED)
# prompts_df = load_prompts_from_jsonl(PROMPTS_PATH)
print(f"Loaded {prompts_df.shape[0]} prompts")
print(f"Prompt keys: {prompts_df.keys()}")


Loaded 140 prompts


## Load Model and SAE

In [4]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print(f"Tokenizer loaded: {tokenizer.__class__.__name__}")

Tokenizer loaded: PreTrainedTokenizerFast


In [5]:
# Load model
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map={"": 0}
)
model.eval()

print(f"Model loaded: {model.__class__.__name__}")
print(f"Model device: {next(model.parameters()).device}")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded: LlamaForCausalLM
Model device: cuda:0


In [14]:
def load_sae(config: ModelConfig, sae_path: str, sae_layer: int, sae_trainer: str) -> SAE:
    """
    Unified SAE loading function that handles both Llama and Gemma models.
    
    Args:
        config: ModelConfig object containing model-specific settings
        sae_path: Local path to store/load SAE files
        sae_layer: Layer number for the SAE
        sae_trainer: Trainer identifier for the SAE
    
    Returns:
        SAE: Loaded SAE model
    """
    # Check if SAE file exists locally
    ae_file_path = os.path.join(sae_path, "sae_weights.safetensors")
    
    if os.path.exists(ae_file_path):
        print(f"✓ Found SAE files at: {os.path.dirname(ae_file_path)}")
        sae = SAE.load_from_disk(sae_path)
        return sae
    
    print(f"SAE not found locally, downloading from HF via sae_lens...")
    os.makedirs(os.path.dirname(sae_path), exist_ok=True)
    
    # Get SAE parameters from config
    release, sae_id = config.get_sae_params(sae_layer, sae_trainer)
    print(f"Loading SAE with release='{release}', sae_id='{sae_id}'")
    
    # Load the SAE using sae_lens
    sae, _, _ = SAE.from_pretrained(
        release=release,
        sae_id=sae_id,
        device="cuda" # Hardcoded because it wants a string
    )
    
    # Save the SAE locally for future use
    sae.save_model(sae_path)
    return sae

# Load SAE using the unified function
sae = load_sae(config, SAE_PATH, SAE_LAYER, SAE_TRAINER)
sae = sae.to(device)  # Move SAE to GPU

print(f"SAE loaded with {sae.cfg.d_sae} features")
print(f"SAE device: {next(sae.parameters()).device}")

✓ Found SAE files at: /workspace/sae/llama-3.1-8b-instruct/saes/resid_post_layer_15/trainer_1
SAE loaded with 131072 features
SAE device: cuda:0


## Activation Extraction Functions

In [15]:
class StopForward(Exception):
    """Exception to stop forward pass after target layer."""
    pass

@torch.no_grad()
def extract_activations_and_metadata(prompts: List[str], layer_idx: int) -> Tuple[torch.Tensor, List[Dict], List[str]]:
    """Extract activations and prepare metadata for all prompts."""
    all_activations = []
    all_metadata = []
    formatted_prompts_list = []
    
    # Get target layer
    target_layer = model.model.layers[layer_idx]
    
    # Process in batches
    for i in tqdm(range(0, len(prompts), BATCH_SIZE), desc="Processing batches"):
        batch_prompts = prompts[i:i+BATCH_SIZE]
        
        # Format prompts as chat messages
        formatted_prompts = []
        for prompt in batch_prompts:
            messages = [{"role": "user", "content": prompt}]
            formatted_prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            formatted_prompts.append(formatted_prompt)
        
        formatted_prompts_list.extend(formatted_prompts)
        
        # Tokenize batch
        batch_inputs = tokenizer(
            formatted_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH
        )
        
        # Move to device
        batch_inputs = {k: v.to(device) for k, v in batch_inputs.items()}
        
        # Hook to capture activations
        activations = None
        
        def hook_fn(module, input, output):
            nonlocal activations
            activations = output[0] if isinstance(output, tuple) else output
            raise StopForward()
        
        # Register hook
        handle = target_layer.register_forward_hook(hook_fn)
        
        try:
            _ = model(**batch_inputs)
        except StopForward:
            pass
        finally:
            handle.remove()
    
    # Find the maximum sequence length across all activations
    max_seq_len = max(act.shape[0] for act in all_activations)
    hidden_dim = all_activations[0].shape[1]
    
    # Pad all activations to the same length
    padded_activations = []
    for act in all_activations:
        if act.shape[0] < max_seq_len:
            padding = torch.zeros(max_seq_len - act.shape[0], hidden_dim)
            padded_act = torch.cat([act, padding], dim=0)
        else:
            padded_act = act
        padded_activations.append(padded_act)
    
    return torch.stack(padded_activations, dim=0), all_metadata, formatted_prompts_list

@torch.no_grad()
def extract_token_activations(full_activations: torch.Tensor, metadata: List[Dict]) -> Dict[str, torch.Tensor]:
    """Extract activations for specific token positions from full sequence activations."""
    results = {}
    

    # Extract activations for each token type
    for i, meta in enumerate(metadata):
        for token_type, position in meta['positions'].items():
            # Extract activation at the specific position
            activation = full_activations[i, position, :]  # [hidden_dim]
            results[token_type].append(activation)
    
    # Convert lists to tensors
    for token_type in TOKEN_OFFSETS.keys():
        results[token_type] = torch.stack(results[token_type], dim=0)
    
    return results

print("Activation extraction functions defined")

Activation extraction functions defined


## Extract Activations

In [16]:
# Extract activations for all positions first, then extract specific token positions
print("Extracting activations for all positions...")
full_activations, metadata, formatted_prompts = extract_activations_and_metadata(prompts_df['prompt'].tolist(), LAYER_INDEX)
print(f"Full activations shape: {full_activations.shape}")

# Extract activations for all token types
print("\nExtracting activations for specific token positions...")
token_activations = extract_token_activations(full_activations, metadata)

for token_type, activations in token_activations.items():
    print(f"Token type '{token_type}' activations shape: {activations.shape}")

Extracting activations for all positions...


Processing batches:   0%|          | 0/18 [00:00<?, ?it/s]

Full activations shape: torch.Size([140, 160, 4096])

Extracting activations for specific token positions...
Token type 'asst' activations shape: torch.Size([140, 4096])
Token type 'endheader' activations shape: torch.Size([140, 4096])
Token type 'newline' activations shape: torch.Size([140, 4096])


## Apply SAE to Get Feature Activations (Optimized)

In [17]:
@torch.no_grad()
def get_sae_features_batched(activations: torch.Tensor) -> torch.Tensor:
    """Apply SAE to get feature activations with proper batching."""
    activations = activations.to(device)

    # Process in batches to avoid memory issues
    feature_activations = []

    for i in range(0, activations.shape[0], BATCH_SIZE):
        batch = activations[i:i+BATCH_SIZE]
        features = sae.encode(batch)  # [batch, num_features]
        feature_activations.append(features.cpu())

    return torch.cat(feature_activations, dim=0)

# Pre-compute SAE features for ALL positions at once
print("Pre-computing SAE features for all positions...")
print(f"Processing {full_activations.shape[0]} prompts with max {full_activations.shape[1]} tokens each...")

# Reshape to [total_positions, hidden_dim]
total_positions = full_activations.shape[0] * full_activations.shape[1]
reshaped_activations = full_activations.view(total_positions, -1)

# Apply SAE to all positions
full_sae_features = get_sae_features_batched(reshaped_activations)

# Reshape back to [num_prompts, seq_len, num_features]
full_sae_features = full_sae_features.view(full_activations.shape[0], full_activations.shape[1], -1)

print(f"Full SAE features shape: {full_sae_features.shape}")

print("✓ SAE features computed")

Computing SAE features for specific token positions...


Processing SAE features for token type 'asst'...
Features shape for 'asst': torch.Size([140, 131072])
Processing SAE features for token type 'endheader'...
Features shape for 'endheader': torch.Size([140, 131072])
Processing SAE features for token type 'newline'...
Features shape for 'newline': torch.Size([140, 131072])

Optimization: Pre-computing SAE features for all positions...
Processing 140 prompts with max 160 tokens each...
Full SAE features shape: torch.Size([140, 160, 131072])
✓ SAE features pre-computed for all positions

Completed SAE feature extraction for 3 token types


## Combined Analysis and Save Results

In [18]:

@torch.no_grad()
def get_filtered_feature_activations(token_features: torch.Tensor, feature_ids: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
    """Get feature activations for filtered features based on feature_ids."""
    # Get feature indices
    feature_indices = torch.tensor(feature_ids, dtype=torch.long)
    
    # Extract activations for these features
    feature_activations = token_features[:, feature_indices]
    
    return feature_indices, feature_activations

@torch.no_grad()
def collect_detailed_tokens(feature_id: int, feature_activations: torch.Tensor, 
                                     feature_idx: int, prompts_df: pd.DataFrame,
                                     full_sae_features: torch.Tensor, metadata: List[Dict],
                                     activation_threshold: float = 0.0) -> List[Dict]:
    """Collect detailed token activations for a specific feature using pre-computed SAE features."""
    # Find which prompts activate this feature
    activations = feature_activations[:, feature_idx]
    active_mask = activations > activation_threshold
    active_indices = torch.where(active_mask)[0]
    
    if len(active_indices) == 0:
        return []
    
    all_token_records = []
    
    for prompt_idx in active_indices:
        prompt_idx = int(prompt_idx)
        
        # Use pre-computed SAE features instead of re-computing
        feature_activations_sequence = full_sae_features[prompt_idx, :, feature_id]  # [seq_len]
        
        # Get tokenized input from metadata (cached)
        input_ids = metadata[prompt_idx]['input_ids']

        # Create tokenized prompt as list of token strings
        tokenized_prompt = [tokenizer.decode([int(token_id)]) for token_id in input_ids]
        
        # Get prompt info
        prompt_text = prompts_df.iloc[prompt_idx]["prompt"]
        
        # Collect token data for this prompt
        tokens = []
        for pos_idx in range(min(len(feature_activations_sequence), len(input_ids))):
            activation_val = float(feature_activations_sequence[pos_idx])
            
            if activation_val > 0:
                token_id = int(input_ids[pos_idx])
                token_text = tokenizer.decode([token_id])
                
                tokens.append({
                    'position': pos_idx,
                    'token_id': token_id,
                    'text': token_text,
                    'activation': activation_val
                })
        
        # Only add if we found tokens
        if tokens:
            # Sort tokens by activation (descending)
            tokens.sort(key=lambda x: x['activation'], reverse=True)
            
            all_token_records.append({
                'prompt_id': prompt_idx,
                'prompt_text': prompt_text,
                'prompt_label': prompts_df.iloc[prompt_idx]["label"],
                'prompt_feature_activation': float(activations[prompt_idx]),
                'tokenized_prompt': tokenized_prompt,
                'tokens': tokens
            })
    
    return all_token_records

# Pre-compute dashboard link template
def create_dashboard_link(feature_id: int) -> str:
    """Create dashboard link for a feature."""
    return f"{config.base_url}/{feature_id}"

print("Optimized analysis functions defined")

Optimized analysis functions defined


In [19]:
# Load target features from file
target_features = [45426]

# Prepare prompt activaton details
jsonl_results = []

# Get filtered feature activations for this token type and source
feature_indices, feature_activations = get_filtered_feature_activations(
    token_features, target_features
)

if len(filtered_features_df) == 0:
    print(f"No features found for token_type='{token_type}', source='{source_name}'. Skipping.")
    continue

# Process each feature
features_processed = 0
features_skipped = 0

for idx, (feature_idx, feature_id) in enumerate(zip(feature_indices, filtered_features_df['feature_id'])):
    feature_id = int(feature_id)
    activations = feature_activations[:, idx]  # [num_prompts]
    
    # Collect detailed token data for JSONL (OPTIMIZED)
    detailed_tokens = collect_detailed_tokens_optimized(
        feature_id, feature_activations, idx, prompts_df, full_sae_features, metadata
    )
    
    # Add to JSONL results if we have detailed data
    if detailed_tokens:
        jsonl_result = {
            'feature_id': feature_id,
            'active_prompts': detailed_tokens
        }
        jsonl_results.append(jsonl_result)
    
    features_processed += 1

print(f"Processed {features_processed} active features for token_type='{token_type}' (skipped {features_skipped} inactive)")



Loaded 92 total features from ./results/1_personal/only_personal.csv
Processing results for source: llama_trainer1_layer15

Processing token type: asst
Found 2 features for token_type='asst', source='llama_trainer1_layer15'
Processed 1 active features for token_type='asst' (skipped 1 inactive)

Processing token type: endheader
Found 6 features for token_type='endheader', source='llama_trainer1_layer15'
Processed 2 active features for token_type='endheader' (skipped 4 inactive)

Processing token type: newline
Found 7 features for token_type='newline', source='llama_trainer1_layer15'
Processed 0 active features for token_type='newline' (skipped 7 inactive)

Total features processed: 3 CSV, 3 JSONL


In [20]:
# Save results
print("Saving results...")

# Save JSONL results
jsonl_results.sort(key=lambda x: x['feature_id'])  # Sort by feature_id
with open(PROMPT_OUTPUT_FILE, 'a') as f:
    for record in jsonl_results:
        f.write(json.dumps(record) + '\n')

print(f"JSONL results saved to {PROMPT_OUTPUT_FILE}")

# Show preview and summary



print(f"\n✓ Analysis complete!")

Saving results...
JSONL results saved to ./results/3_personal_general/3_personal_general_prompts.jsonl

Preview of CSV data:
 feature_id  activation_mean  activation_max  activation_min  num_prompts chat_desc pt_desc type                 source     token                                                                                          link
      27476         0.310405        0.338905        0.288229            5                        llama_trainer1_layer15      asst https://completely-touched-platypus.ngrok-free.app/?model=llama&layer=15&trainer=1&fids=27476
      59035         0.362718        0.435948        0.286962           14                        llama_trainer1_layer15 endheader https://completely-touched-platypus.ngrok-free.app/?model=llama&layer=15&trainer=1&fids=59035
      47776         0.400181        0.400181        0.400181            1                        llama_trainer1_layer15 endheader https://completely-touched-platypus.ngrok-free.app/?model=llama&layer=15&