# Feature Analysis: Diffing Base and Instruct

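This notebook analyzes which SAE features increase in activation between the base and chat (instruct) models. Each run computes per-feature activation statistics for one model version (selected by `MODEL_VER`) and saves them so the two runs can be compared.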

In [None]:
import csv
import json
import torch
import os
import shutil
import numpy as np
import pandas as pd
from pathlib import Path
from typing import List, Dict, Tuple
from datasets import load_dataset
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import hf_hub_download
from sae_lens import SAE
from tqdm.auto import tqdm

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Configs

In [None]:
from dataclasses import dataclass
from typing import Dict, Tuple, Optional

@dataclass
class ModelConfig:
    """Model-specific settings for one base/chat model pair and its SAE release.

    ``saelens_release`` doubles as the dispatch key in ``get_sae_params``, so
    only the release names handled there are supported.
    """
    # Model identifier passed to ``from_pretrained`` for the base model.
    base_model_name: str
    # Model identifier passed to ``from_pretrained`` for the chat model.
    chat_model_name: str
    hf_release: str  # Reference only - actual loading uses saelens_release/sae_id
    # Text of the chat-template header that opens the assistant turn.
    assistant_header: str
    # Named token positions, as offsets from the first token after the header.
    token_offsets: Dict[str, int]
    # Local directory under which SAE weights are cached.
    sae_base_path: str
    saelens_release: str  # Template for sae_lens release parameter
    sae_id_template: str  # Template for sae_lens sae_id parameter

    def get_sae_params(self, sae_layer: int, sae_trainer: str) -> Tuple[str, str]:
        """
        Generate SAE lens release and sae_id parameters.

        Args:
            sae_layer: Layer number for the SAE
            sae_trainer: Trainer identifier for the SAE. For the Llama release
                it is substituted verbatim (e.g. "32x"); for the Gemma releases
                it is split on "-" (e.g. "131k-l0-34" -> width "131k", l0 "34").

        Returns:
            Tuple of (release, sae_id) for sae_lens.SAE.from_pretrained()

        Raises:
            ValueError: If ``saelens_release`` is not one of the supported
                releases, or if ``sae_trainer`` lacks the components the
                non-canonical Gemma release needs.
        """
        release_key = self.saelens_release

        if release_key == "llama_scope_lxr_{trainer}":
            # Both the release name and the sae_id embed the trainer string.
            return (
                release_key.format(trainer=sae_trainer),
                self.sae_id_template.format(layer=sae_layer, trainer=sae_trainer),
            )

        if release_key == "gemma-scope-9b-pt-res":
            # Parse SAE_TRAINER "131k-l0-34" into components for Gemma
            parts = sae_trainer.split("-")
            if len(parts) < 3:
                # Fail with an actionable message instead of an IndexError on parts[2].
                raise ValueError(
                    "Gemma SAE trainer must look like '<width>-l0-<l0>' "
                    f"(e.g. '131k-l0-34'), got: {sae_trainer!r}"
                )
            width, l0_value = parts[0], parts[2]
            return release_key, self.sae_id_template.format(
                layer=sae_layer, width=width, l0=l0_value
            )

        if release_key == "gemma-scope-9b-pt-res-canonical":
            # The canonical release only needs the width ("131k"); any l0 part is ignored.
            width = sae_trainer.split("-")[0]
            return release_key, self.sae_id_template.format(layer=sae_layer, width=width)

        raise ValueError(f"Unknown SAE lens release template: {self.saelens_release}")

# Model configurations
# Registry of supported model families, keyed by the MODEL_TYPE value chosen
# in the "MODEL SELECTION" section below.
MODEL_CONFIGS = {
    "llama": ModelConfig(
        base_model_name="meta-llama/Llama-3.1-8B",
        chat_model_name="meta-llama/Llama-3.1-8B-Instruct",
        hf_release="fnlp/Llama3_1-8B-Base-LXR-32x",
        assistant_header="<|start_header_id|>assistant<|end_header_id|>",
        # Offsets are added to the index of the first token after the header
        # (see find_assistant_position), so 0 is the token right after it.
        # NOTE(review): the key names suggest which header token each offset
        # lands on - confirm against the tokenizer's actual tokenization.
        token_offsets={"asst": -2, "endheader": -1, "newline": 0},
        sae_base_path="/workspace/sae/llama-3.1-8b/saes",
        # Both templates are filled in by ModelConfig.get_sae_params.
        saelens_release="llama_scope_lxr_{trainer}",
        sae_id_template="l{layer}r_{trainer}"
    ),
    "gemma": ModelConfig(
        base_model_name="google/gemma-2-9b",
        chat_model_name="google/gemma-2-9b-it",
        # Reference only; loading goes through the canonical release below.
        hf_release="google/gemma-scope-9b-pt-res/layer_{layer}/width_{width}/average_l0_{l0}",
        assistant_header="<start_of_turn>model",
        # Same convention as above: offsets relative to the first token after the header.
        token_offsets={"model": -1, "newline": 0},
        sae_base_path="/workspace/sae/gemma-2-9b/saes",
        # With the canonical release, get_sae_params only uses the width part
        # of the trainer string (the text before the first "-").
        saelens_release="gemma-scope-9b-pt-res-canonical",
        sae_id_template="layer_{layer}/width_{width}/canonical"
    )
}

# =============================================================================
# MODEL SELECTION - Change this to switch between models
# =============================================================================
MODEL_TYPE = "llama"  # Options: "gemma" or "llama"
MODEL_VER = "chat"  # Options: "chat" or "base"
SAE_LAYER = 11
SAE_TRAINER = "32x"
N_PROMPTS = 1000

# =============================================================================
# CONFIGURATION SETUP
# =============================================================================
# Resolve the per-family settings; an unknown family is rejected up front.
config = MODEL_CONFIGS.get(MODEL_TYPE)
if config is None:
    raise ValueError(f"Unknown MODEL_TYPE: {MODEL_TYPE}. Available: {list(MODEL_CONFIGS.keys())}")

# Choose which checkpoint's activations are analysed in this run.
if MODEL_VER not in ("chat", "base"):
    raise ValueError(f"Unknown MODEL_VER: {MODEL_VER}. Use 'chat' or 'base'")
MODEL_NAME = config.chat_model_name if MODEL_VER == "chat" else config.base_model_name

# The tokenizer always comes from the chat model (it has the chat template).
CHAT_MODEL_NAME = config.chat_model_name

# Flatten the config fields used throughout the notebook into module constants.
ASSISTANT_HEADER, TOKEN_OFFSETS, SAE_BASE_PATH = (
    config.assistant_header,
    config.token_offsets,
    config.sae_base_path,
)

# =============================================================================
# DERIVED CONFIGURATIONS
# =============================================================================
SAE_PATH = f"{SAE_BASE_PATH}/resid_post_layer_{SAE_LAYER}/trainer_{SAE_TRAINER}"
LAYER_INDEX = SAE_LAYER

# Processing parameters
BATCH_SIZE = 32
MAX_LENGTH = 512

# Data paths
PROMPTS_HF = "lmsys/lmsys-chat-1m"
SEED = 42
PROMPTS_PATH = f"/workspace/data/{PROMPTS_HF.split('/')[-1]}/chat_{N_PROMPTS}.jsonl"
# PROMPTS_PATH = "./prompts/personal_40/personal.jsonl"

# =============================================================================
# OUTPUT FILE CONFIGURATION
# =============================================================================
OUTPUT_FILE = f"/workspace/results/4_diffing/{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}/{N_PROMPTS}_prompts/{MODEL_VER}.pt"

# Make sure the directories the notebook writes into exist.
os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
os.makedirs(os.path.dirname(PROMPTS_PATH), exist_ok=True)

# =============================================================================
# SUMMARY
# =============================================================================
print("\n".join([
    f"Configuration Summary:",
    f"  Model Type: {MODEL_TYPE}",
    f"  Model to load: {MODEL_NAME}",
    f"  Tokenizer (chat): {CHAT_MODEL_NAME}",
    f"  SAE Layer: {SAE_LAYER}, Trainer: {SAE_TRAINER}",
    f"  Available token types: {list(TOKEN_OFFSETS.keys())}",
    f"  Assistant header: {ASSISTANT_HEADER}",
    f"  Output file: {OUTPUT_FILE}",
]))

## Load Data

In [None]:
def load_lmsys_prompts(prompts_path: str, prompts_hf: str, n_prompts: int, seed: int) -> pd.DataFrame:
    """Return the prompt table, reading the JSONL cache at ``prompts_path`` when it exists.

    On a cache miss, the ``train`` split of the ``prompts_hf`` dataset is
    shuffled with ``seed`` and its first ``n_prompts`` rows are kept. The first
    message of each conversation becomes the ``prompt`` column, and the
    trimmed table is written to ``prompts_path`` as JSON lines before being
    returned.

    Args:
        prompts_path: Location of the JSONL cache file.
        prompts_hf: Dataset identifier passed to ``load_dataset``.
        n_prompts: Number of rows to keep after shuffling.
        seed: Shuffle seed.

    Returns:
        DataFrame of prompts (the cached file's contents on a cache hit).
    """
    # Cache hit: the file is returned as-is, whatever n_prompts/seed are.
    if os.path.exists(prompts_path):
        print(f"Prompts already exist at {prompts_path}")
        return pd.read_json(prompts_path, lines=True)

    print(f"Prompts do not exist at {prompts_path}. Loading from {prompts_hf}...")
    train_split = load_dataset(prompts_hf)['train']
    sampled = train_split.shuffle(seed=seed).select(range(n_prompts))
    frame = sampled.to_pandas()

    def _first_message_text(conversation):
        # The prompt is the content of the conversation's first item.
        return conversation[0]['content']

    frame['prompt'] = frame['conversation'].apply(_first_message_text)

    # Drop everything except the columns used downstream.
    kept_columns = ['conversation_id', 'prompt', 'redacted', 'language']
    frame = frame[kept_columns]

    # Persist so later runs take the cache-hit path above.
    frame.to_json(prompts_path, orient='records', lines=True)
    return frame

def load_prompts_from_jsonl(file_path: str) -> pd.DataFrame:
    """Load prompts from a JSONL file into a one-column DataFrame.

    Every non-blank line must be a JSON object with a 'content' field. Blank
    lines (for example a trailing empty line) are skipped instead of crashing
    the JSON parser, and the file is always decoded as UTF-8 so non-ASCII
    prompts load the same way on every platform.

    Args:
        file_path: Path to the JSONL file.

    Returns:
        DataFrame with a single 'prompt' column, one row per JSON line.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a JSON object has no 'content' field.
    """
    prompts = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                # json.loads("") would raise; an empty line carries no prompt.
                continue
            prompts.append(json.loads(line)['content'])
    return pd.DataFrame(prompts, columns=['prompt'])

# Load the prompt sample for this run (read from the JSONL cache when present).
prompts_df = load_lmsys_prompts(
    prompts_path=PROMPTS_PATH,
    prompts_hf=PROMPTS_HF,
    n_prompts=N_PROMPTS,
    seed=SEED,
)
# prompts_df = load_prompts_from_jsonl(PROMPTS_PATH)
print(f"Loaded {len(prompts_df)} prompts")
print(f"Prompt keys: {prompts_df.columns}")


## Load Model and SAE

In [None]:
# Tokenizer always comes from the chat model, regardless of which model is analysed.
tokenizer = AutoTokenizer.from_pretrained(CHAT_MODEL_NAME)

# Right padding keeps real tokens at the start of each row; find_assistant_position
# relies on this when it treats attention_mask.sum() - 1 as the last real token.
tokenizer.padding_side = "right"

# Batched tokenization needs a pad token; reuse EOS when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"Tokenizer loaded: {type(tokenizer).__name__}")

In [5]:

# # Test chat template formatting
# test_messages = [{"role": "user", "content": "What's it like to be you?"}]
# formatted_test = tokenizer.apply_chat_template(test_messages, tokenize=False, add_generation_prompt=True)
# print(f"\nChat template test:")
# print(f"Original: What's it like to be you?")
# print(f"Formatted: {repr(formatted_test)}")
# print(f"Formatted (readable):\n{formatted_test}")

# # Test tokenization of assistant header to understand positioning
# print(f"\n" + "="*60)
# print("ASSISTANT HEADER TOKENIZATION ANALYSIS")
# print("="*60)

# assistant_tokens = tokenizer.encode(ASSISTANT_HEADER, add_special_tokens=False)
# assistant_token_texts = [tokenizer.decode([token]) for token in assistant_tokens]

# print(f"Assistant header: {ASSISTANT_HEADER}")
# print(f"Number of tokens: {len(assistant_tokens)}")
# print(f"Token IDs: {assistant_tokens}")
# print(f"Individual tokens: {assistant_token_texts}")

# # Test with a full formatted prompt
# full_tokens = tokenizer.encode(formatted_test, add_special_tokens=False)
# full_token_texts = [tokenizer.decode([token]) for token in full_tokens]

# print(f"\nFull prompt tokens: {len(full_tokens)}")
# print("All tokens with positions:")
# for i, token_text in enumerate(full_token_texts):
#     print(f"  {i:2d}: '{token_text}'")

# # Find where assistant header appears in full prompt
# assistant_start_pos = None
# for i in range(len(full_tokens) - len(assistant_tokens) + 1):
#     if full_tokens[i:i+len(assistant_tokens)] == assistant_tokens:
#         assistant_start_pos = i
#         break

# if assistant_start_pos is not None:
#     assistant_end_pos = assistant_start_pos + len(assistant_tokens) - 1
#     print(f"\nAssistant header found at positions {assistant_start_pos} to {assistant_end_pos}")
#     print(f"Assistant header tokens: {full_token_texts[assistant_start_pos:assistant_end_pos+1]}")
    
#     for t_t, t_o in TOKEN_OFFSETS.items():
#         # Show what the extraction function will actually extract
#         extraction_pos = assistant_start_pos + len(assistant_tokens) + t_o
#         print(f"\nExtraction calculation:")
#         print(f"  assistant_start_pos: {assistant_start_pos}")
#         print(f"  + len(assistant_tokens): {len(assistant_tokens)}")  
#         print(f"  + TOKEN_OFFSET ('{t_t}'): {t_o}")
#         print(f"  = extraction_pos: {extraction_pos}")
        
#         if 0 <= extraction_pos < len(full_token_texts):
#             print(f"✓ Token at extraction position {extraction_pos}: '{full_token_texts[extraction_pos]}'")
#         else:
#             print(f"❌ Extraction position {extraction_pos} is out of bounds (valid range: 0-{len(full_token_texts)-1})")
# else:
#     print("❌ Assistant header not found in full prompt")

In [None]:
# Load model
# The whole model is placed on the device selected at the top of the notebook
# rather than on a hard-coded GPU index, so the CPU fallback chosen there is
# honoured here as well.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map={"": device}
)
# Inference only: switch off dropout and other training-time behaviour.
model.eval()

print(f"Model loaded: {model.__class__.__name__}")
print(f"Model device: {next(model.parameters()).device}")

In [None]:
def load_sae(config: ModelConfig, sae_path: str, sae_layer: int, sae_trainer: str) -> SAE:
    """
    Unified SAE loading function that handles both Llama and Gemma models.

    If ``sae_path`` already contains ``sae_weights.safetensors`` the SAE is
    loaded from disk. Otherwise it is downloaded through sae_lens using the
    release/sae_id derived from ``config`` and then saved to ``sae_path`` so
    the next call takes the local path.

    Args:
        config: ModelConfig object containing model-specific settings
        sae_path: Local path to store/load SAE files
        sae_layer: Layer number for the SAE
        sae_trainer: Trainer identifier for the SAE

    Returns:
        SAE: Loaded SAE model

    Raises:
        ValueError: Propagated from ``config.get_sae_params`` for an
            unsupported release.
    """
    # Check if SAE file exists locally
    ae_file_path = os.path.join(sae_path, "sae_weights.safetensors")

    if os.path.exists(ae_file_path):
        print(f"✓ Found SAE files at: {os.path.dirname(ae_file_path)}")
        return SAE.load_from_disk(sae_path)

    print(f"SAE not found locally, downloading from HF via sae_lens...")

    # Get SAE parameters from config
    release, sae_id = config.get_sae_params(sae_layer, sae_trainer)
    print(f"Loading SAE with release='{release}', sae_id='{sae_id}'")

    # sae_lens wants the device as a string. Pick it from CUDA availability
    # instead of hard-coding "cuda" so the download path also works on CPU.
    load_device = "cuda" if torch.cuda.is_available() else "cpu"
    sae, _, sparsity = SAE.from_pretrained(
        release=release,
        sae_id=sae_id,
        device=load_device
    )

    # Create the target directory itself (and any parents). Creating only
    # os.path.dirname(sae_path) fails when sae_path has no parent component.
    os.makedirs(sae_path, exist_ok=True)

    # Save the SAE locally for future use
    sae.save_model(sae_path, sparsity)
    return sae

# Load the SAE through the unified loader and place it on the compute device.
sae = load_sae(config, SAE_PATH, SAE_LAYER, SAE_TRAINER).to(device)

print(f"SAE loaded with {sae.cfg.d_sae} features")
print(f"SAE device: {next(sae.parameters()).device}")

## Activation Extraction Functions

In [None]:
class StopForward(Exception):
    """Signal raised by the forward hook to abort the forward pass at the hooked layer."""

def find_assistant_position(input_ids: torch.Tensor, attention_mask: torch.Tensor, 
                          assistant_header: str, token_offset: int, tokenizer, device) -> int:
    """Find the position of the assistant token based on the given offset.

    The header is tokenized without special tokens and its first occurrence
    in ``input_ids`` is located. The result is the index of the first token
    after the header plus ``token_offset``, clamped to
    ``[0, number_of_attended_tokens - 1]``. If the header is not present
    (for example because the prompt was truncated), the last attended token
    is used instead.

    Args:
        input_ids: 1-D tensor of token ids for a single sequence.
        attention_mask: 1-D mask for the same sequence; its sum is taken as
            the number of real (non-padding) tokens.
        assistant_header: Header text to search for.
        token_offset: Offset relative to the first token after the header.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.
        device: Unused; kept for backward compatibility. The search runs on
            plain Python lists, so no tensor has to live on a given device.

    Returns:
        The clamped token position as an int.
    """
    header_ids = list(tokenizer.encode(assistant_header, add_special_tokens=False))
    header_len = len(header_ids)

    # Copy the ids to the host once. Comparing list slices avoids rebuilding a
    # header tensor (and a device sync) on every loop iteration, and it cannot
    # fail on a device mismatch between input_ids and the header.
    sequence_ids = input_ids.tolist()

    # Index of the last attended token; also the upper clamp bound.
    last_real_pos = int(attention_mask.sum().item()) - 1

    assistant_pos = None
    for start in range(len(sequence_ids) - header_len + 1):
        if sequence_ids[start:start + header_len] == header_ids:
            assistant_pos = start + header_len + token_offset
            break

    if assistant_pos is None:
        # Fallback to last non-padding token
        assistant_pos = last_real_pos

    # Ensure position is within bounds
    assistant_pos = max(min(assistant_pos, last_real_pos), 0)

    return int(assistant_pos)

@torch.no_grad()
def extract_activations_and_metadata(prompts: List[str], layer_idx: int) -> Tuple[torch.Tensor, List[Dict], List[str]]:
    """Extract activations and prepare metadata for all prompts.

    Each prompt is wrapped as a single user message, rendered with the chat
    template (generation prompt included), tokenized in batches of
    ``BATCH_SIZE`` (truncated to ``MAX_LENGTH``) and run through ``model`` up
    to ``model.model.layers[layer_idx]``, whose output is captured by a
    forward hook that then aborts the rest of the forward pass.

    Reads the module-level ``model``, ``tokenizer``, ``device``,
    ``BATCH_SIZE``, ``MAX_LENGTH``, ``TOKEN_OFFSETS`` and ``ASSISTANT_HEADER``.

    Args:
        prompts: Raw user prompts.
        layer_idx: Index into ``model.model.layers`` of the layer to capture.

    Returns:
        A tuple ``(activations, metadata, formatted_prompts)`` where
        ``activations`` is ``[num_prompts, max_seq_len, hidden_dim]``
        (zero-padded on the right, in the dtype the model produced),
        ``metadata`` holds one dict per prompt with ``prompt_idx``,
        ``positions`` (token type -> position), ``attention_mask`` and
        ``input_ids``, and ``formatted_prompts`` is the templated text.

    Raises:
        ValueError: If ``prompts`` is empty.
        RuntimeError: If the hook on the target layer never fired.
    """
    # Local import: the file does not import this name at module level.
    from torch.nn.utils.rnn import pad_sequence

    if len(prompts) == 0:
        raise ValueError("prompts must contain at least one prompt")

    all_activations = []
    all_metadata = []
    formatted_prompts_list = []

    # Get target layer
    target_layer = model.model.layers[layer_idx]

    # Process in batches
    for i in tqdm(range(0, len(prompts), BATCH_SIZE), desc="Processing batches"):
        batch_prompts = prompts[i:i+BATCH_SIZE]

        # Format prompts as chat messages
        formatted_prompts = [
            tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
            for prompt in batch_prompts
        ]
        formatted_prompts_list.extend(formatted_prompts)

        # Tokenize batch. The chat template has already rendered the special
        # tokens into the text, so the tokenizer must not add them again;
        # otherwise templates that emit BOS (as the Llama 3.1 and Gemma 2 chat
        # templates do) produce a duplicated BOS at the start of every sequence.
        batch_inputs = tokenizer(
            formatted_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            add_special_tokens=False,
        )

        # Move to device
        batch_inputs = {k: v.to(device) for k, v in batch_inputs.items()}

        # Hook to capture activations; raising StopForward skips the layers
        # after the target one.
        captured = []

        def hook_fn(module, input, output):
            captured.append(output[0] if isinstance(output, tuple) else output)
            raise StopForward()

        # Register hook
        handle = target_layer.register_forward_hook(hook_fn)

        try:
            _ = model(**batch_inputs)
        except StopForward:
            pass
        finally:
            # Always detach the hook so a failure cannot leave it installed.
            handle.remove()

        if not captured:
            raise RuntimeError(
                f"Forward hook on layer {layer_idx} did not fire; no activations captured"
            )
        activations = captured[0]

        # For each prompt in the batch, calculate positions for all token types
        for j in range(len(formatted_prompts)):
            attention_mask = batch_inputs["attention_mask"][j]
            input_ids = batch_inputs["input_ids"][j]

            # Calculate positions for all token types
            positions = {
                token_type: find_assistant_position(
                    input_ids, attention_mask, ASSISTANT_HEADER, token_offset, tokenizer, device
                )
                for token_type, token_offset in TOKEN_OFFSETS.items()
            }

            # Store the full activation sequence and metadata
            all_activations.append(activations[j].cpu())  # [seq_len, hidden_dim]
            all_metadata.append({
                'prompt_idx': i + j,
                'positions': positions,
                'attention_mask': attention_mask.cpu(),
                'input_ids': input_ids.cpu()
            })

    # Batches are padded to their own longest sequence, so lengths differ
    # across batches. pad_sequence right-pads with zeros to the global maximum
    # in the activations' own dtype; padding with default float32 zeros would
    # promote bfloat16 activations to float32.
    stacked = pad_sequence(all_activations, batch_first=True)

    return stacked, all_metadata, formatted_prompts_list

@torch.no_grad()
def extract_token_activations(full_activations: torch.Tensor, metadata: List[Dict]) -> Dict[str, torch.Tensor]:
    """Pick out, per token type, the activation vector at each prompt's recorded position.

    Args:
        full_activations: Activations indexed as ``[prompt, position, hidden]``.
        metadata: One dict per prompt; ``meta['positions']`` maps token type
            to the position to read for that prompt.

    Returns:
        Dict mapping every key of the module-level ``TOKEN_OFFSETS`` to a
        tensor stacking that token type's vectors, one row per prompt.
    """
    # One bucket per configured token type, filled prompt by prompt.
    per_type = {name: [] for name in TOKEN_OFFSETS.keys()}

    for prompt_idx, meta in enumerate(metadata):
        prompt_activations = full_activations[prompt_idx]
        for name, pos in meta['positions'].items():
            per_type[name].append(prompt_activations[pos, :])  # [hidden_dim]

    # Stack each bucket into a single [num_prompts, hidden_dim] tensor.
    return {
        name: torch.stack(per_type[name], dim=0)
        for name in TOKEN_OFFSETS.keys()
    }

# Cell-completion marker: printed once the definitions in this cell have run.
print("Activation extraction functions defined")

## Extract Activations

In [None]:
# Extract activations for all positions first, then extract specific token positions
# Step 1: run every prompt through the model and keep the full per-position
# activations of layer LAYER_INDEX together with per-prompt metadata.
print("Extracting activations for all positions...")
full_activations, metadata, formatted_prompts = extract_activations_and_metadata(prompts_df['prompt'].tolist(), LAYER_INDEX)
print(f"Full activations shape: {full_activations.shape}")

# Extract activations for all token types
# Step 2: reduce to one activation vector per prompt for each token type in
# TOKEN_OFFSETS, using the positions stored in the metadata.
print("\nExtracting activations for specific token positions...")
token_activations = extract_token_activations(full_activations, metadata)

# Sanity check: report the shape collected for each token type.
for token_type, activations in token_activations.items():
    print(f"Token type '{token_type}' activations shape: {activations.shape}")

## Apply SAE to Get Feature Activations

In [None]:
@torch.no_grad()
def get_sae_features_batched(activations: torch.Tensor) -> torch.Tensor:
    """Apply SAE to get feature activations with proper batching.

    Rows are encoded ``BATCH_SIZE`` at a time with the module-level ``sae``.
    Only the slice being encoded is moved to ``device``, so peak device
    memory is bounded by one batch rather than by the whole input.

    Args:
        activations: Tensor whose first dimension indexes the rows to encode.

    Returns:
        CPU tensor with the encoded features of all rows, concatenated along
        the first dimension in input order.
    """
    feature_chunks = []

    # Process in batches to avoid memory issues
    for start in range(0, activations.shape[0], BATCH_SIZE):
        # Transfer just this slice; moving the full tensor up front would
        # defeat the purpose of batching for large inputs.
        batch = activations[start:start + BATCH_SIZE].to(device)
        features = sae.encode(batch)  # [batch, num_features]
        # Results go back to the CPU right away to free device memory.
        feature_chunks.append(features.cpu())

    return torch.cat(feature_chunks, dim=0)

@torch.no_grad()
def get_sae_features_all_positions(full_activations: torch.Tensor) -> torch.Tensor:
    """Pre-compute SAE features for ALL positions at once for optimization.

    The first two dimensions (prompt, position) are flattened into one, the
    rows are encoded with ``get_sae_features_batched``, and the result is
    folded back to ``[num_prompts, seq_len, num_features]``.

    Args:
        full_activations: Activations indexed as ``[prompt, position, hidden]``.
            Non-contiguous tensors (e.g. slices) are accepted.

    Returns:
        SAE features with the same two leading dimensions as the input.
    """
    num_prompts, seq_len = full_activations.shape[0], full_activations.shape[1]
    print(f"Processing {num_prompts} prompts with max {seq_len} tokens each...")

    # Reshape to [total_positions, hidden_dim]. reshape (unlike view) also
    # works when the input is not contiguous, copying only if it has to.
    total_positions = num_prompts * seq_len
    reshaped_activations = full_activations.reshape(total_positions, -1)

    # Apply SAE to all positions
    full_sae_features = get_sae_features_batched(reshaped_activations)

    # Reshape back to [num_prompts, seq_len, num_features]
    full_sae_features = full_sae_features.reshape(num_prompts, seq_len, -1)

    print(f"Full SAE features shape: {full_sae_features.shape}")
    print(f"✓ SAE features pre-computed for all positions")

    return full_sae_features

# Get SAE feature activations for specific token positions
print("Computing SAE features for specific token positions...")
# Maps token type -> SAE feature tensor; read later by the save_as_* functions.
token_features = {}

# Encode each token type's per-prompt activations with the SAE.
for token_type, activations in token_activations.items():
    print(f"Processing SAE features for token type '{token_type}'...")
    features = get_sae_features_batched(activations)
    token_features[token_type] = features
    print(f"Features shape for '{token_type}': {features.shape}")

print(f"\nCompleted SAE feature extraction for {len(token_features)} token types")

# Uncomment the lines below if you need all-position features for optimization
# print("\nOptimization: Pre-computing SAE features for all positions...")
# full_sae_features = get_sae_features_all_positions(full_activations)

## Analysis and Save Results

In [None]:
def save_as_csv():
    """Build per-feature summary rows in a CSV-friendly layout (slower but human readable).

    Nothing is written to disk here: the function returns the rows and leaves
    writing to the caller. It reads the module-level ``token_features``,
    ``TOKEN_OFFSETS`` and the model/SAE selection constants.

    Returns:
        List of dicts, one per (token type, feature), with keys
        ``feature_id``, ``all_mean``, ``all_std``, ``max``, ``num_active``,
        ``sparsity``, ``source`` and ``token``.

    NOTE(review): ``all_std`` here comes from numpy's ``std`` while the
    ``save_as_pt_*`` functions use torch's ``std``; the two use different
    default degrees of freedom, so the values will not match exactly.
    """
    source_name = f"{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}_{MODEL_VER}"
    print(f"Processing results for CSV format, source: {source_name}")

    rows = []
    for token_type in TOKEN_OFFSETS.keys():
        print(f"\nProcessing token type: {token_type}")

        # [num_prompts, num_features]; cast to float32 first because numpy
        # has no bfloat16.
        features_np = token_features[token_type].float().numpy()
        n_features = features_np.shape[1]

        print(f"Processing all {n_features} features for token_type='{token_type}'")

        # Every feature gets a row, including ones that never fire.
        for feature_idx in range(n_features):
            column = features_np[:, feature_idx]  # [num_prompts]

            # Number of prompts on which this feature is strictly positive.
            num_active = len(column[column > 0])

            rows.append({
                'feature_id': int(feature_idx),
                'all_mean': float(column.mean()),
                'all_std': float(column.std()),
                'max': float(column.max()),
                'num_active': num_active,
                # Fraction of prompts where the feature is active.
                'sparsity': num_active / len(column),
                'source': source_name,
                'token': token_type,
            })

        print(f"Processed all {n_features} features for token_type='{token_type}'")

    print(f"\nTotal feature records: {len(rows)}")
    return rows

def save_as_pt_cpu():
    """Compute per-feature statistics as PyTorch tensors on the CPU.

    Nothing is written to disk here; the caller saves the returned dict.
    Reads the module-level ``token_features``, ``TOKEN_OFFSETS`` and the
    model/SAE selection constants.

    Returns:
        Dict with one entry per token type holding the tensors ``all_mean``,
        ``all_std``, ``max``, ``num_active`` and ``sparsity`` (each indexed by
        feature), plus a ``metadata`` entry describing the run.
    """
    source_name = f"{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}_{MODEL_VER}"
    print(f"Processing results for PyTorch format using CPU, source: {source_name}")

    results_dict = {}
    for token_type in TOKEN_OFFSETS.keys():
        print(f"\nProcessing token type: {token_type}")

        # [num_prompts, num_features], promoted to float32 for the reductions.
        feats = token_features[token_type].float()
        n_prompts, n_features = feats.shape[0], feats.shape[1]

        print(f"Processing all {n_features} features for token_type='{token_type}' on CPU")

        # A feature is "active" on a prompt when its value is strictly positive.
        active_counts = (feats > 0).sum(dim=0)  # [num_features]

        # All reductions run over the prompt dimension, zeros included.
        results_dict[token_type] = {
            'all_mean': feats.mean(dim=0),
            'all_std': feats.std(dim=0),
            'max': feats.max(dim=0).values,
            'num_active': active_counts,
            'sparsity': active_counts.float() / n_prompts,
        }

        print(f"Processed all {n_features} features for token_type='{token_type}'")

    # Shape information is taken from the last token type processed.
    results_dict['metadata'] = {
        'source': source_name,
        'model_type': MODEL_TYPE,
        'model_ver': MODEL_VER,
        'sae_layer': SAE_LAYER,
        'sae_trainer': SAE_TRAINER,
        'num_prompts': n_prompts,
        'num_features': n_features,
        'token_types': list(TOKEN_OFFSETS.keys())
    }

    print(f"\nTotal token types processed: {len(results_dict) - 1}")  # -1 for metadata
    return results_dict

def save_as_pt_gpu():
    """Compute per-feature statistics as PyTorch tensors on ``device``.

    Same statistics as ``save_as_pt_cpu``, but the reductions run on the
    module-level ``device`` and the results are moved back to the CPU.
    Nothing is written to disk here; the caller saves the returned dict.

    Returns:
        Dict with one entry per token type holding the CPU tensors
        ``all_mean``, ``all_std``, ``max``, ``num_active`` and ``sparsity``
        (each indexed by feature), plus a ``metadata`` entry describing the run.
    """
    source_name = f"{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}_{MODEL_VER}"
    print(f"Processing results for PyTorch format using GPU, source: {source_name}")

    results_dict = {}
    for token_type in TOKEN_OFFSETS.keys():
        print(f"\nProcessing token type: {token_type}")

        # [num_prompts, num_features], moved to the device and promoted to float32.
        feats = token_features[token_type].to(device).float()
        n_prompts, n_features = feats.shape[0], feats.shape[1]

        print(f"Processing all {n_features} features for token_type='{token_type}' on GPU")
        print(f"Features tensor dtype: {feats.dtype}")

        # A feature is "active" on a prompt when its value is strictly positive.
        active_counts = (feats > 0).sum(dim=0)  # [num_features]
        sparsity = active_counts.float() / n_prompts

        # Reductions run over the prompt dimension; results are stored on the CPU.
        results_dict[token_type] = {
            'all_mean': feats.mean(dim=0).cpu(),
            'all_std': feats.std(dim=0).cpu(),
            'max': feats.max(dim=0).values.cpu(),
            'num_active': active_counts.cpu(),
            'sparsity': sparsity.cpu(),
        }

        print(f"Processed all {n_features} features for token_type='{token_type}'")

    # Shape information is taken from the last token type processed.
    results_dict['metadata'] = {
        'source': source_name,
        'model_type': MODEL_TYPE,
        'model_ver': MODEL_VER,
        'sae_layer': SAE_LAYER,
        'sae_trainer': SAE_TRAINER,
        'num_prompts': n_prompts,
        'num_features': n_features,
        'token_types': list(TOKEN_OFFSETS.keys())
    }

    print(f"\nTotal token types processed: {len(results_dict) - 1}")  # -1 for metadata
    return results_dict

# Choose your approach:
# results_dict = save_as_pt_cpu()    # Most accurate, slower
# results_dict = save_as_pt_gpu()    # Faster, potentially less accurate
# Both functions return the same dict layout; neither writes to disk.

print("Using CPU version for accuracy...")
# results_dict is written to OUTPUT_FILE by the next cell.
results_dict = save_as_pt_cpu()

In [None]:
# Save results
print("Saving results...")

# Save as PyTorch file (much faster and more efficient)
# OUTPUT_FILE encodes model type, trainer, layer, prompt count and model
# version, so each configuration is saved to its own file.
pt_output_file = OUTPUT_FILE
torch.save(results_dict, pt_output_file)
print(f"PyTorch results saved to: {pt_output_file}")

# Show preview of PyTorch data structure
print(f"\nPyTorch file structure:")
print(f"Keys: {list(results_dict.keys())}")
print(f"Metadata: {results_dict['metadata']}")

# Per token type: print the shape of every statistic, then min/max summaries.
for token_type in TOKEN_OFFSETS.keys():
    print(f"\n{token_type} statistics shapes:")
    for stat_name, tensor in results_dict[token_type].items():
        print(f"  {stat_name}: {tensor.shape}")
    
    # Show some sample statistics
    print(f"\n{token_type} sample statistics:")
    print(f"  all_mean - min: {results_dict[token_type]['all_mean'].min():.6f}, max: {results_dict[token_type]['all_mean'].max():.6f}")
    print(f"  sparsity - min: {results_dict[token_type]['sparsity'].min():.6f}, max: {results_dict[token_type]['sparsity'].max():.6f}")
    print(f"  num_active - min: {results_dict[token_type]['num_active'].min():.0f}, max: {results_dict[token_type]['num_active'].max():.0f}")

print(f"\n✓ Analysis complete! PyTorch file size is much smaller and loads faster than CSV.")