# Feature Analysis: Diffing Base and Instruct

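This notebook runs the **base** model on chat-formatted prompts, encodes its residual-stream activations at the assistant-header token positions with a sparse autoencoder (SAE), and records per-feature activation statistics. These statistics are the base-model half of the comparison used to find SAE features whose activations increase between the base and chat (instruct) models.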

In [32]:
import csv
import json
import torch
import os
import shutil
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Dict, Tuple
from datasets import load_dataset
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import hf_hub_download
from sae_lens import SAE
from tqdm.auto import tqdm

# Set device
# Pick the compute device once: CUDA when torch reports it is available, CPU otherwise.
# Later cells reuse this module-level `device` for the model's device_map and for
# moving tokenized batches / activations.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


## Configs

In [41]:
# =============================================================================
# MODEL SELECTION
# =============================================================================
# Only "llama" has an active configuration in this notebook; any other value
# raises a ValueError in the auto-configuration section below.
MODEL_TYPE = "llama"
SAE_LAYER = 15
SAE_TRAINER = "32x"

# =============================================================================
# OUTPUT FILE CONFIGURATION
# =============================================================================
OUTPUT_FILE = "./results/4_diffing/3_personal_general.csv"
PROMPT_OUTPUT_FILE = "./results/4_diffing/3_personal_general_prompts.jsonl"

# Create the parent directory of every output file up front so later cells can
# write without checking.
for _output_path in (OUTPUT_FILE, PROMPT_OUTPUT_FILE):
    os.makedirs(os.path.dirname(_output_path), exist_ok=True)

# =============================================================================
# FEATURE DASHBOARD URL - Global variable for links
# =============================================================================
FEATURE_DASHBOARD_BASE_URL = "https://completely-touched-platypus.ngrok-free.app/"

# =============================================================================
# AUTO-CONFIGURED SETTINGS BASED ON MODEL TYPE
# =============================================================================
# Inactive single-model (instruct-only) configurations, kept for reference.
# They are NOT wired into the if/else below.
#
# if MODEL_TYPE == "qwen":
#     MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
#     SAE_RELEASE = "andyrdt/saes-qwen2.5-7b-instruct"
#     ASSISTANT_HEADER = "<|im_start|>assistant"
#     TOKEN_OFFSETS = {"asst": -1, "newline": 0}
#     SAE_BASE_PATH = "/workspace/sae/qwen-2.5-7b-instruct/saes"
#
# if MODEL_TYPE == "llama":
#     MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
#     SAE_RELEASE = "andyrdt/saes-llama-3.1-8b-instruct"
#     ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>"
#     TOKEN_OFFSETS = {"asst": -2, "endheader": -1, "newline": 0}
#     SAE_BASE_PATH = "/workspace/sae/llama-3.1-8b/saes"

if MODEL_TYPE == "llama":
    # Weights come from the base model; the tokenizer / chat template come from
    # the chat model (see the tokenizer and model loading cells).
    BASE_MODEL_NAME = "meta-llama/Llama-3.1-8B"
    CHAT_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
    # Follows SAE_TRAINER so the printed release matches the SAE that is loaded.
    SAE_RELEASE = f"fnlp/Llama3_1-8B-Base-LXR-{SAE_TRAINER}"
    ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>"
    # Offsets are added to the index just past the assistant header, so 0 is the
    # first token after the header and negative values step back into it.
    TOKEN_OFFSETS = {"asst": -2, "endheader": -1, "newline": 0}
    SAE_BASE_PATH = "/workspace/sae/llama-3.1-8b/saes"
else:
    # The message previously advertised 'qwen', which has no active branch here.
    raise ValueError(
        f"Unknown MODEL_TYPE: {MODEL_TYPE}. Only 'llama' is configured in this notebook"
    )

# =============================================================================
# DERIVED CONFIGURATIONS
# =============================================================================
SAE_PATH = f"{SAE_BASE_PATH}/resid_post_layer_{SAE_LAYER}/trainer_{SAE_TRAINER}"
LAYER_INDEX = SAE_LAYER

# Data paths
PROMPTS_HF = "lmsys/lmsys-chat-1m"
N_PROMPTS = 1000
SEED = 42
# The cache file name encodes N_PROMPTS, so changing the sample size creates a new cache.
PROMPTS_PATH = f"/workspace/data/{PROMPTS_HF.split('/')[-1]}/chat_{N_PROMPTS}.jsonl"
os.makedirs(os.path.dirname(PROMPTS_PATH), exist_ok=True)

# Processing parameters
BATCH_SIZE = 8
MAX_LENGTH = 512
TOP_FEATURES = 100

# =============================================================================
# SUMMARY
# =============================================================================
print("Configuration Summary:")
print(f"  Model: {BASE_MODEL_NAME}")
print(f"  SAE: {SAE_RELEASE}")
print(f"  SAE Layer: {SAE_LAYER}, Trainer: {SAE_TRAINER}")
print(f"  Available token types: {list(TOKEN_OFFSETS.keys())}")
print(f"  Assistant header: {ASSISTANT_HEADER}")
print(f"  Output file: {OUTPUT_FILE}")
print(f"  Prompt output file: {PROMPT_OUTPUT_FILE}")

Configuration Summary:
  Model: meta-llama/Llama-3.1-8B
  SAE: fnlp/Llama3_1-8B-Base-LXR-32x
  SAE Layer: 15, Trainer: 32x
  Available token types: ['asst', 'endheader', 'newline']
  Assistant header: <|start_header_id|>assistant<|end_header_id|>
  Output file: ./results/4_diffing/3_personal_general.csv
  Prompt output file: ./results/4_diffing/3_personal_general_prompts.jsonl


## Load Data

In [None]:
def load_lmsys_prompts(prompts_path: str, prompts_hf: str, n_prompts: int, seed: int) -> pd.DataFrame:
    """Return a sample of first-turn prompts, using a local JSONL cache when present.

    Args:
        prompts_path: Path of the JSONL cache file. If it exists it is read and
            returned as-is; `prompts_hf`, `n_prompts` and `seed` are then ignored
            (the cache is not validated against them).
        prompts_hf: Hugging Face dataset identifier passed to `load_dataset`.
        n_prompts: Number of rows to keep after shuffling the `train` split.
        seed: Shuffle seed, so the same sample is drawn on every fresh download.

    Returns:
        DataFrame with columns `conversation_id`, `prompt`, `redacted`, `language`,
        where `prompt` is the `content` of the first item of each `conversation`.
    """
    # Fast path: reuse the cached sample.
    if os.path.exists(prompts_path):
        print(f"Prompts already exist at {prompts_path}")
        return pd.read_json(prompts_path, lines=True)

    print(f"Prompts do not exist at {prompts_path}. Loading from {prompts_hf}...")
    dataset = load_dataset(prompts_hf)
    sampled = dataset['train'].shuffle(seed=seed).select(range(n_prompts))
    df = sampled.to_pandas()

    # Extract the prompt from the first conversation item
    df['prompt'] = df['conversation'].apply(lambda turns: turns[0]['content'])

    # Only keep some columns
    df = df[['conversation_id', 'prompt', 'redacted', 'language']]

    # Make sure the cache directory exists rather than relying on the caller
    # having created it; a bare file name has no parent to create.
    parent_dir = os.path.dirname(prompts_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Save to .jsonl file
    df.to_json(prompts_path, orient='records', lines=True)
    return df

# Load the prompt sample (from the JSONL cache at PROMPTS_PATH when it exists,
# otherwise from the PROMPTS_HF dataset) and report its size and columns.
prompts_df = load_lmsys_prompts(PROMPTS_PATH, PROMPTS_HF, N_PROMPTS, SEED)
print(f"Loaded {prompts_df.shape[0]} prompts")
print(f"Prompt keys: {prompts_df.keys()}")


Prompts already exist at /workspace/data/lmsys-chat-1m/chat_1000.jsonl
Loaded 1000 prompts
Prompt keys: Index(['conversation_id', 'prompt', 'redacted', 'language'], dtype='object')


## Load Model and SAE

In [42]:
# Load tokenizer (from chat model)
# The tokenizer comes from CHAT_MODEL_NAME because prompts are formatted with
# tokenizer.apply_chat_template later on; the model weights are loaded separately
# from BASE_MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME)
# Batched tokenization with padding=True needs a pad token; fall back to EOS when
# the tokenizer does not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Right padding keeps real tokens at the start of each row. The position lookup
# relies on this: it treats attention_mask.sum() - 1 as the last real token index.
tokenizer.padding_side = "right"

print(f"Tokenizer loaded: {tokenizer.__class__.__name__}")

Tokenizer loaded: PreTrainedTokenizerFast


In [38]:
# Load model
# device_map value: the integer CUDA index when `device` carries one, otherwise the
# device's string form (e.g. "cuda" or "cpu").
device_map_value = device.index if device.type == 'cuda' and device.index is not None else str(device)

# Load the BASE model weights in bfloat16. The "" key in device_map is the root
# module, so the whole model is placed on the single device chosen above.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map={"": device_map_value}
)
# Inference only: switch off training-mode behaviour.
model.eval()

print(f"Model loaded: {model.__class__.__name__}")
print(f"Model device: {next(model.parameters()).device}")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Model loaded: LlamaForCausalLM
Model device: cuda:0


In [34]:
# Load SAE
def load_llamascope_sae(SAE_PATH, SAE_LAYER, SAE_TRAINER, device=None):
    """Load a LlamaScope residual-stream SAE, preferring a local copy over a download.

    If `<SAE_PATH>/sae_weights.safetensors` exists the SAE is loaded from disk.
    Otherwise it is fetched through sae_lens from the release
    `llama_scope_lxr_<SAE_TRAINER>` with id `l<SAE_LAYER>r_<SAE_TRAINER>` and
    saved to `SAE_PATH` so the next call takes the local path.

    Args:
        SAE_PATH: Directory holding (or that will hold) the saved SAE.
        SAE_LAYER: Layer number used to build the sae_lens `sae_id`.
        SAE_TRAINER: Expansion tag (e.g. "32x") used in the release and id.
        device: Device to place the SAE on (str or torch.device). Defaults to
            CUDA when available, else CPU. Previously the download path
            hard-coded "cuda" and the local path left sae_lens' default in place,
            so the two paths could yield SAEs on different devices.

    Returns:
        The loaded sae_lens `SAE`.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = str(device)

    # Check if SAE file exist locally
    ae_file_path = os.path.join(SAE_PATH, "sae_weights.safetensors")

    if os.path.exists(ae_file_path):
        print(f"✓ Found SAE files at: {os.path.dirname(ae_file_path)}")
        # Pass the device explicitly so the cached SAE lands where the
        # activations will be.
        sae = SAE.load_from_disk(SAE_PATH, device=device)
    else:
        print(f"SAE not found locally, downloading from HF via sae_lens...")
        # save_model is expected to create the leaf directory itself; only the
        # parent must exist beforehand. A bare directory name has no parent.
        parent_dir = os.path.dirname(SAE_PATH)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        # NOTE(review): the 3-tuple return and save_model(path, sparsity) match the
        # sae_lens version this notebook was written against — confirm on upgrade.
        sae, _, sparsity = SAE.from_pretrained(
            release=f"llama_scope_lxr_{SAE_TRAINER}",  # see other options in sae_lens/pretrained_saes.yaml
            sae_id=f"l{SAE_LAYER}r_{SAE_TRAINER}",  # won't always be a hook point
            device=device,
        )
        sae.save_model(SAE_PATH, sparsity)

    return sae

# Load the SAE for the configured layer/trainer (from SAE_PATH when cached there,
# otherwise via download), then report its dictionary size and the device its
# parameters live on.
sae = load_llamascope_sae(SAE_PATH, SAE_LAYER, SAE_TRAINER)
print(f"SAE loaded with {sae.cfg.d_sae} features")
print(f"SAE device: {next(sae.parameters()).device}")

SAE loaded with 131072 features
SAE device: cuda:0


## Activation Extraction Functions

In [43]:
class StopForward(Exception):
    """Signal raised inside a forward hook to abort the forward pass early.

    The hook raises it right after capturing the target layer's output, and the
    caller catches it, so layers after the target are never executed.
    """

def find_assistant_position(input_ids: torch.Tensor, attention_mask: torch.Tensor,
                          assistant_header: str, token_offset: int, tokenizer, device) -> int:
    """Locate a token position relative to the first assistant header in one sequence.

    Args:
        input_ids: 1-D tensor of token ids for a single (possibly padded) sequence.
        attention_mask: 1-D mask for the same sequence; its sum is taken as the
            number of real tokens, which assumes padding is on the right.
        assistant_header: Header text; it is tokenized without special tokens and
            searched for as a contiguous run inside `input_ids`.
        token_offset: Added to the index just past the header, so 0 is the first
            token after the header and -1 is the header's last token.
        tokenizer: Object exposing `encode(text, add_special_tokens=False)`.
        device: Unused; kept for backward compatibility. The comparison tensor is
            built on `input_ids.device` so a mismatch can no longer raise.

    Returns:
        The position, clamped to `[0, last_real_token_index]`. If the header is
        not found (e.g. it was truncated away) or tokenizes to nothing, the last
        real token index is returned.
    """
    # Build the header tensor once, matching the ids' dtype and device, instead of
    # re-creating and transferring it on every loop iteration.
    header_ids = torch.tensor(
        tokenizer.encode(assistant_header, add_special_tokens=False),
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    header_len = header_ids.numel()

    # Index of the last non-padding token; also the upper clamp bound.
    last_real = int(attention_mask.sum().item()) - 1

    # Default to the fallback; overwritten on the first match.
    position = last_real
    if header_len > 0:
        for start in range(len(input_ids) - header_len + 1):
            if torch.equal(input_ids[start:start + header_len], header_ids):
                position = start + header_len + token_offset
                break

    # Ensure position is within bounds
    return int(max(min(position, last_real), 0))

@torch.no_grad()
def extract_activations_and_metadata(prompts: List[str], layer_idx: int) -> Tuple[torch.Tensor, List[Dict], List[str]]:
    """Run chat-formatted prompts through the model and capture one layer's output.

    Each prompt is wrapped as a single user message with the tokenizer's chat
    template (generation prompt included), tokenized in batches of BATCH_SIZE
    with truncation at MAX_LENGTH, and passed through the module-level `model`.
    A forward hook on `model.model.layers[layer_idx]` stores that layer's output
    and aborts the rest of the forward pass.

    Args:
        prompts: Raw user prompts.
        layer_idx: Index into `model.model.layers` of the layer to capture.

    Returns:
        A tuple `(activations, metadata, formatted_prompts)`:
        - activations: tensor `[num_prompts, max_seq_len, hidden_dim]` on CPU,
          zero-padded on the right to the longest batch and kept in the model's
          activation dtype. (The previous float32 padding made the result dtype
          depend on whether any padding was needed.)
        - metadata: one dict per prompt with `prompt_idx`, `positions` (token type
          -> index, one per TOKEN_OFFSETS entry), `attention_mask` and `input_ids`
          (CPU tensors at that prompt's batch length, which can be shorter than
          `max_seq_len`).
        - formatted_prompts: the chat-templated strings, in input order.

    Raises:
        ValueError: If `prompts` is empty.
        RuntimeError: If the hook never fired for a batch.
    """
    if len(prompts) == 0:
        raise ValueError("extract_activations_and_metadata requires at least one prompt")

    all_activations = []
    all_metadata = []
    formatted_prompts_list = []

    # Get target layer
    target_layer = model.model.layers[layer_idx]

    # Single-slot holder filled by the hook; defined once rather than per batch.
    captured = {}

    def hook_fn(module, input, output):
        # Layers may return a tuple whose first element is the hidden states, or
        # the hidden states directly.
        captured["hidden"] = output[0] if isinstance(output, tuple) else output
        # Nothing past this layer is needed, so stop the forward pass here.
        raise StopForward()

    # Process in batches
    for i in tqdm(range(0, len(prompts), BATCH_SIZE), desc="Processing batches"):
        batch_prompts = prompts[i:i+BATCH_SIZE]

        # Format prompts as chat messages
        formatted_prompts = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False, add_generation_prompt=True
            )
            for prompt in batch_prompts
        ]
        formatted_prompts_list.extend(formatted_prompts)

        # Tokenize batch
        # NOTE(review): the chat template is understood to emit its own BOS token
        # and tokenizer(...) adds special tokens by default, so sequences may start
        # with two BOS tokens — confirm, and consider add_special_tokens=False.
        # Truncation cuts from the end of the sequence by default, so for prompts
        # longer than MAX_LENGTH the assistant header can be cut off;
        # find_assistant_position then falls back to the last real token.
        batch_inputs = tokenizer(
            formatted_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH
        )

        # Move to device
        batch_inputs = {k: v.to(device) for k, v in batch_inputs.items()}

        # Register the hook only for this forward pass and always remove it.
        captured.clear()
        handle = target_layer.register_forward_hook(hook_fn)
        try:
            _ = model(**batch_inputs)
        except StopForward:
            pass
        finally:
            handle.remove()

        if "hidden" not in captured:
            raise RuntimeError(f"Forward hook on layer {layer_idx} did not capture any output")
        activations = captured["hidden"]  # [batch, seq_len, hidden_dim]

        # For each prompt in the batch, calculate positions for all token types
        for j in range(len(formatted_prompts)):
            attention_mask = batch_inputs["attention_mask"][j]
            input_ids = batch_inputs["input_ids"][j]

            positions = {
                token_type: find_assistant_position(
                    input_ids, attention_mask, ASSISTANT_HEADER, token_offset, tokenizer, device
                )
                for token_type, token_offset in TOKEN_OFFSETS.items()
            }

            # Store the full activation sequence and metadata
            all_activations.append(activations[j].cpu())  # [seq_len, hidden_dim]
            all_metadata.append({
                'prompt_idx': i + j,
                'positions': positions,
                'attention_mask': attention_mask.cpu(),
                'input_ids': input_ids.cpu()
            })

    # Batches differ in padded length; right-pad with zeros to the longest one.
    # pad_sequence keeps the activations' dtype.
    padded_activations = torch.nn.utils.rnn.pad_sequence(all_activations, batch_first=True)

    return padded_activations, all_metadata, formatted_prompts_list

@torch.no_grad()
def extract_token_activations(full_activations: torch.Tensor, metadata: List[Dict]) -> Dict[str, torch.Tensor]:
    """Pick out, per token type, the activation vector at each prompt's stored position.

    Args:
        full_activations: Tensor indexed as `[prompt, position, hidden]`.
        metadata: One dict per prompt whose `'positions'` entry maps token type
            to the position to read for that prompt.

    Returns:
        Dict with one key per TOKEN_OFFSETS entry, each a tensor stacking the
        selected vectors over prompts along dim 0.
    """
    # One accumulator list per configured token type.
    collected = {name: [] for name in TOKEN_OFFSETS.keys()}

    # Walk prompts by row index and gather the vector at each recorded position.
    for row_idx in range(len(metadata)):
        for name, pos in metadata[row_idx]['positions'].items():
            collected[name].append(full_activations[row_idx, pos, :])  # [hidden_dim]

    # Stack each list into a single tensor, keeping TOKEN_OFFSETS key order.
    return {name: torch.stack(collected[name], dim=0) for name in TOKEN_OFFSETS.keys()}

# Marker output confirming this cell ran and the helpers above are defined.
print("Activation extraction functions defined")

Activation extraction functions defined


## Extract Activations

In [44]:
# Extract activations for all positions first, then extract specific token positions
# Step 1: one forward pass per batch, keeping the target layer's output for every
# token of every prompt, plus per-prompt metadata (including the token positions).
print("Extracting activations for all positions...")
full_activations, metadata, formatted_prompts = extract_activations_and_metadata(prompts_df['prompt'].tolist(), LAYER_INDEX)
print(f"Full activations shape: {full_activations.shape}")

# Extract activations for all token types
# Step 2: index the full tensor at the stored positions; gives one tensor per
# TOKEN_OFFSETS key.
print("\nExtracting activations for specific token positions...")
token_activations = extract_token_activations(full_activations, metadata)

# NOTE: the loop variables `token_type` and `activations` stay bound at notebook
# scope after this loop and are reassigned by later cells.
for token_type, activations in token_activations.items():
    print(f"Token type '{token_type}' activations shape: {activations.shape}")

Extracting activations for all positions...


Processing batches:   0%|          | 0/125 [00:00<?, ?it/s]

Full activations shape: torch.Size([1000, 512, 4096])

Extracting activations for specific token positions...
Token type 'asst' activations shape: torch.Size([1000, 4096])
Token type 'endheader' activations shape: torch.Size([1000, 4096])
Token type 'newline' activations shape: torch.Size([1000, 4096])


## Apply SAE to Get Feature Activations

In [46]:
@torch.no_grad()
def get_sae_features_batched(activations: torch.Tensor, batch_size=None) -> torch.Tensor:
    """Encode activations with the module-level SAE in fixed-size row batches.

    Args:
        activations: Tensor whose first dimension is the row/batch dimension;
            rows are passed to `sae.encode` in slices.
        batch_size: Rows per `sae.encode` call. Defaults to the notebook's
            BATCH_SIZE when omitted or None.

    Returns:
        The per-batch `sae.encode` outputs moved to CPU and concatenated along
        dim 0. An input with zero rows produces no batches and `torch.cat` then
        raises, as before.
    """
    step = BATCH_SIZE if batch_size is None else batch_size

    feature_activations = []
    for i in range(0, activations.shape[0], step):
        # Move only the current slice to the device. Transferring the whole input
        # up front defeated the purpose of batching for large inputs.
        batch = activations[i:i+step].to(device)
        features = sae.encode(batch)  # [batch, num_features]
        # Offload immediately so device memory holds one batch of features at a time.
        feature_activations.append(features.cpu())

    return torch.cat(feature_activations, dim=0)

@torch.no_grad()
def get_sae_features_all_positions(full_activations: torch.Tensor) -> torch.Tensor:
    """Encode every (prompt, position) activation with the SAE.

    Flattens the first two dimensions, encodes all rows through
    `get_sae_features_batched`, and restores the `[num_prompts, seq_len, ...]`
    layout. Padded positions are encoded too; nothing here masks them out.

    WARNING: the result holds num_prompts * seq_len * num_features values on CPU.
    With the shapes printed in this notebook (1000 x 512 positions, 131072
    features) that is roughly 6.7e10 values, so only call this on a subset.

    Args:
        full_activations: Tensor `[num_prompts, seq_len, hidden_dim]`.

    Returns:
        Tensor `[num_prompts, seq_len, num_features]` on CPU.
    """
    num_prompts = full_activations.shape[0]
    seq_len = full_activations.shape[1]
    print(f"Processing {num_prompts} prompts with max {seq_len} tokens each...")

    # Reshape to [total_positions, hidden_dim]. reshape (unlike view) also accepts
    # non-contiguous inputs such as slices of a larger tensor.
    flat_activations = full_activations.reshape(num_prompts * seq_len, -1)

    # Apply SAE to all positions
    full_sae_features = get_sae_features_batched(flat_activations)

    # Reshape back to [num_prompts, seq_len, num_features]
    full_sae_features = full_sae_features.reshape(num_prompts, seq_len, -1)

    print(f"Full SAE features shape: {full_sae_features.shape}")
    print("✓ SAE features pre-computed for all positions")

    return full_sae_features

# Get SAE feature activations for specific token positions
# Encodes only the vectors at the selected token positions (one row per prompt and
# token type), not every position of every prompt.
print("Computing SAE features for specific token positions...")
token_features = {}  # token type -> tensor returned by get_sae_features_batched

for token_type, activations in token_activations.items():
    print(f"Processing SAE features for token type '{token_type}'...")
    features = get_sae_features_batched(activations)
    token_features[token_type] = features
    print(f"Features shape for '{token_type}': {features.shape}")

print(f"\nCompleted SAE feature extraction for {len(token_features)} token types")

# Uncomment the lines below if you need all-position features for optimization
# NOTE(review): all-position features scale with prompts x positions x SAE features;
# check available memory before enabling.
# print("\nOptimization: Pre-computing SAE features for all positions...")
# full_sae_features = get_sae_features_all_positions(full_activations)

Computing SAE features for specific token positions...
Processing SAE features for token type 'asst'...
Features shape for 'asst': torch.Size([1000, 131072])
Processing SAE features for token type 'endheader'...
Features shape for 'endheader': torch.Size([1000, 131072])
Processing SAE features for token type 'newline'...
Features shape for 'newline': torch.Size([1000, 131072])

Completed SAE feature extraction for 3 token types


## Analysis and Save Results

In [19]:
# Collect per-feature activation statistics, one row per (feature, token type),
# restricted to prompts where the feature is active (> 0).
#
# NOTE(review): this cell reads `feature_indices`, `filtered_features_df`,
# `feature_activations` and `jsonl_results`, none of which is defined anywhere in
# the visible notebook, so it raises NameError under Restart & Run All.
# NOTE(review): `feature_activations` is not re-derived inside the token-type loop
# and `token_features[token_type]` is never read here, so as written every token
# type would be summarised from the same matrix. Presumably the removed code
# selected `token_features[token_type][:, feature_indices]` per token type —
# confirm and restore.
# NOTE(review): the saved output below this cell mentions loading a feature CSV and
# a "Found N features" message that this code no longer produces; the output looks
# stale relative to the code.
csv_results = []
# Tag stored in the `source` column of every row.
source_name = f"{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}_base"

print(f"Processing results")

# Process each token type
for token_type in TOKEN_OFFSETS.keys():
    print(f"\nProcessing token type: {token_type}")
    
    # Process each feature
    features_processed = 0
    features_skipped = 0
    
    # `idx` is the column within `feature_activations`; `feature_idx` is unpacked
    # but not used in the loop body.
    for idx, (feature_idx, feature_id) in enumerate(zip(feature_indices, filtered_features_df['feature_id'])):
        feature_id = int(feature_id)
        activations = feature_activations[:, idx]  # [num_prompts]
        
        # Calculate statistics only on active features (activation > 0)
        active_mask = activations > 0
        active_activations = activations[active_mask]
        
        # Skip features that aren't active on any prompt
        if len(active_activations) == 0:
            features_skipped += 1
            continue
        
        # CSV statistics
        # Mean/max/min are taken over active prompts only; `num_prompts` is the
        # number of prompts on which the feature is active, not the total.
        activation_mean = float(active_activations.mean())
        activation_max = float(active_activations.max())
        activation_min = float(active_activations.min())
        num_prompts = len(active_activations)
        
        # Add to CSV results
        csv_result = {
            'feature_id': feature_id,
            'activation_mean': activation_mean,
            'activation_max': activation_max,
            'activation_min': activation_min,
            'num_prompts': num_prompts,
            'source': source_name,
            'token': token_type,
        }
        csv_results.append(csv_result)

        
        features_processed += 1
    
    print(f"Processed {features_processed} active features for token_type='{token_type}' (skipped {features_skipped} inactive)")

print(f"\nTotal features processed: {len(csv_results)} CSV, {len(jsonl_results)} JSONL")

Loaded 92 total features from ./results/1_personal/only_personal.csv
Processing results for source: llama_trainer1_layer15

Processing token type: asst
Found 2 features for token_type='asst', source='llama_trainer1_layer15'
Processed 1 active features for token_type='asst' (skipped 1 inactive)

Processing token type: endheader
Found 6 features for token_type='endheader', source='llama_trainer1_layer15'
Processed 2 active features for token_type='endheader' (skipped 4 inactive)

Processing token type: newline
Found 7 features for token_type='newline', source='llama_trainer1_layer15'
Processed 0 active features for token_type='newline' (skipped 7 inactive)

Total features processed: 3 CSV, 3 JSONL


In [20]:
# Save results
print("Saving results...")

# Convert CSV results to DataFrame
column_order = ['feature_id', 'activation_mean', 'activation_max', 'activation_min',
                'num_prompts', 'source', 'token']
# Passing columns= fixes the schema and column order even when csv_results is
# empty. Selecting with [column_order] afterwards raised KeyError in that case.
csv_df = pd.DataFrame(csv_results, columns=column_order)

# Save CSV results
# Rows are appended when the file already exists, so results from several runs
# accumulate in one file (the `source` column tells them apart).
# NOTE(review): re-running this cell appends the same rows again, and appended rows
# carry only the columns in column_order — confirm that an existing file has the
# same columns before appending.
if os.path.exists(OUTPUT_FILE):
    csv_df.to_csv(OUTPUT_FILE, mode='a', header=False, index=False)
    print(f"CSV results appended to existing file: {OUTPUT_FILE}")
else:
    csv_df.to_csv(OUTPUT_FILE, index=False)
    print(f"CSV results saved to new file: {OUTPUT_FILE}")

# Show preview and summary
if len(csv_df) > 0:
    print("\nPreview of CSV data:")
    print(csv_df.head(10).to_string(index=False))

    print("\nSummary by token type:")
    summary = csv_df.groupby('token').agg({
        'activation_mean': ['count', 'mean', 'max'],
        'activation_max': 'max',
        'num_prompts': ['mean', 'max'],
        'feature_id': 'nunique'
    }).round(4)
    print(summary)
else:
    print("\nNo active features were recorded; nothing to preview.")

print("\n✓ Analysis complete!")

Saving results...
JSONL results saved to ./results/3_personal_general/3_personal_general_prompts.jsonl

Preview of CSV data:
 feature_id  activation_mean  activation_max  activation_min  num_prompts chat_desc pt_desc type                 source     token                                                                                          link
      27476         0.310405        0.338905        0.288229            5                        llama_trainer1_layer15      asst https://completely-touched-platypus.ngrok-free.app/?model=llama&layer=15&trainer=1&fids=27476
      59035         0.362718        0.435948        0.286962           14                        llama_trainer1_layer15 endheader https://completely-touched-platypus.ngrok-free.app/?model=llama&layer=15&trainer=1&fids=59035
      47776         0.400181        0.400181        0.400181            1                        llama_trainer1_layer15 endheader https://completely-touched-platypus.ngrok-free.app/?model=llama&layer=15&