# Feature Diffing Analysis

This notebook analyzes the differences between base and chat models by comparing feature activations.

In [1]:
import os
import textwrap
from pathlib import Path

import numpy as np
import pandas as pd
import torch

# --- Run configuration: which model / SAE pair to analyze ---
MODEL_TYPE = "llama"
SAE_LAYER = 15
SAE_TRAINER = "32x"
TOKEN_OFFSETS = dict(asst=-2, endheader=-1, newline=0)

# Identifier shared by the input file names and the output directory
SOURCE = f"{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}"

# Input files: one .pt file per model variant (base / chat)
BASE_FILE = f"/workspace/results/4_diffing/1_{SOURCE}_base.pt"
CHAT_FILE = f"/workspace/results/4_diffing/1_{SOURCE}_chat.pt"

# Results are written to a directory named after SOURCE, relative to the
# current working directory; it is created on first run.
OUTPUT_DIR = Path(".") / SOURCE
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Neuronpedia URL prefix; a feature id is appended to form the feature page URL
LLAMA_LINK_FORMAT = f"https://www.neuronpedia.org/llama3.1-8b/{SAE_LAYER}-llamascope-res-131k/"

# Echo the resolved paths so the cell output records what was used
for description, location in (
    ("Loading base model data from", BASE_FILE),
    ("Loading chat model data from", CHAT_FILE),
    ("Output directory", OUTPUT_DIR),
):
    print(f"{description}: {location}")

# Load the PyTorch files.
# map_location="cpu" keeps every tensor on the CPU regardless of the device it
# was saved from, so the notebook also runs on machines without a GPU and the
# `.numpy()` calls in the plotting cell work directly.
# NOTE(review): torch.load unpickles the file contents; only load files from a
# trusted source.
base_data = torch.load(BASE_FILE, map_location="cpu")
chat_data = torch.load(CHAT_FILE, map_location="cpu")

print(f"\nBase data keys: {list(base_data.keys())}")
print(f"Chat data keys: {list(chat_data.keys())}")
print(f"Base metadata: {base_data['metadata']}")
print(f"Chat metadata: {chat_data['metadata']}")

# Every key other than 'metadata' is treated as a token type. Both files must
# list the same token types in the same order, because later cells index both
# dictionaries with the same keys.
base_tokens = [k for k in base_data.keys() if k != 'metadata']
chat_tokens = [k for k in chat_data.keys() if k != 'metadata']
print(f"\nBase token types: {base_tokens}")
print(f"Chat token types: {chat_tokens}")
assert base_tokens == chat_tokens, "Token types don't match between base and chat!"

token_types = base_tokens
print(f"Processing {len(token_types)} token types: {token_types}")

Loading base model data from: /workspace/results/4_diffing/1_llama_trainer32x_layer15_base.pt
Loading chat model data from: /workspace/results/4_diffing/1_llama_trainer32x_layer15_chat.pt
Output directory: llama_trainer32x_layer15

Base data keys: ['asst', 'endheader', 'newline', 'metadata']
Chat data keys: ['asst', 'endheader', 'newline', 'metadata']
Base metadata: {'source': 'llama_trainer32x_layer15_base', 'model_type': 'llama', 'model_ver': 'base', 'sae_layer': 15, 'sae_trainer': '32x', 'num_prompts': 1000, 'num_features': 131072, 'token_types': ['asst', 'endheader', 'newline']}
Chat metadata: {'source': 'llama_trainer32x_layer15_chat', 'model_type': 'llama', 'model_ver': 'chat', 'sae_layer': 15, 'sae_trainer': '32x', 'num_prompts': 1000, 'num_features': 131072, 'token_types': ['asst', 'endheader', 'newline']}

Base token types: ['asst', 'endheader', 'newline']
Chat token types: ['asst', 'endheader', 'newline']
Processing 3 token types: ['asst', 'endheader', 'newline']


## Plot Results

In [None]:
# Create interactive scatterplot for one metric with all 3 token types
import plotly.graph_objects as go
import plotly.express as px

# Load Claude explanations for the configured model / SAE.
# The directory name is derived from SOURCE (set in the configuration cell) so
# the descriptions always belong to the same model, layer and trainer as the
# activation data being plotted.
explanations_path = f"/root/git/persona-subspace/sae_feature_analysis/results/4_diffing/{SOURCE}/explanations_with_claude.csv"
explanations_df = pd.read_csv(explanations_path)
print(f"Loaded {len(explanations_df)} explanations")

# Create a dictionary for fast lookup of explanations by feature_id
explanations_dict = dict(zip(explanations_df['feature_id'], explanations_df['claude_desc']))

# Metric to plot. It is used as a key into base_data[token_type] and
# chat_data[token_type]; 'active_mean' is presumably another stored key
# (confirm against the .pt files before switching).
selected_metric = 'sparsity'

print(f"Creating interactive scatterplot for {selected_metric} metric with all token types...")

# Marker colour and symbol per token type (keys must cover every entry of token_types)
colors = {'asst': '#FF6B6B', 'endheader': '#4ECDC4', 'newline': '#45B7D1'}
symbols = {'asst': 'circle', 'endheader': 'square', 'newline': 'diamond'}

# Empty figure; one trace per token type is added by the code that follows
fig = go.Figure()

# Maximum number of characters per hover line before a <br> break is inserted
HOVER_WRAP_WIDTH = 80
# Shown when a feature has no (string) explanation in the CSV
NO_EXPLANATION = "No explanation available"


def _wrap_for_hover(text, width=HOVER_WRAP_WIDTH):
    """Return `text` with `<br>` breaks so that no hover line exceeds `width` characters.

    Text that already fits within `width` is returned unchanged. Longer text is
    re-flowed at word boundaries; a single word longer than `width` is kept
    whole on its own line rather than being split.
    """
    if len(text) <= width:
        return text
    # Collapse whitespace runs first so wrapping operates on single-space-separated words
    normalized = " ".join(text.split())
    return "<br>".join(
        textwrap.wrap(normalized, width=width, break_long_words=False, break_on_hyphens=False)
    )


total_features = 0
for token_type in token_types:
    # Get all feature values for this token/metric combination
    base_values = base_data[token_type][selected_metric].numpy()
    chat_values = chat_data[token_type][selected_metric].numpy()

    # Keep only features with a positive value in at least one of the two models
    active_mask = (base_values > 0) | (chat_values > 0)
    filtered_base = base_values[active_mask]
    filtered_chat = chat_values[active_mask]

    # Feature id = position of the feature in the unfiltered array
    feature_ids = np.arange(len(base_values))[active_mask]

    # Chat minus base, shown in the hover text
    differences = filtered_chat - filtered_base

    # Build one hover string and one Neuronpedia URL per plotted feature
    hover_text = []
    neuronpedia_urls = []
    for fid, base_val, chat_val, diff in zip(feature_ids, filtered_base, filtered_chat, differences):
        # Missing ids and non-string cells (e.g. NaN from empty CSV fields)
        # both fall back to the placeholder text
        claude_explanation = explanations_dict.get(fid, NO_EXPLANATION)
        if not isinstance(claude_explanation, str):
            claude_explanation = NO_EXPLANATION

        formatted_explanation = _wrap_for_hover(claude_explanation)

        # Neuronpedia URL for this feature; opened by the click handler in the saved HTML
        neuronpedia_urls.append(f"{LLAMA_LINK_FORMAT}{fid}")

        hover_text.append(
            f"<b>Feature {fid}</b><br>" +
            f"Base: {base_val:.6f}<br>" +
            f"Chat: {chat_val:.6f}<br>" +
            f"Difference: {diff:.6f}<br><br>" +
            f"<b>Description:</b><br>" +
            f"{formatted_explanation}<extra></extra>"
        )

    # Add scatter points for this token type
    fig.add_trace(
        go.Scatter(
            x=filtered_base,
            y=filtered_chat,
            mode='markers',
            name=f'{token_type}',
            marker=dict(
                size=6,
                color=colors[token_type],
                symbol=symbols[token_type],
                line=dict(width=0.5, color='black'),
                opacity=0.7
            ),
            text=[f"Feature {fid}" for fid in feature_ids],
            customdata=neuronpedia_urls,
            hovertemplate=hover_text
        )
    )

    total_features += len(filtered_base)
    print(f"  {token_type}: {len(filtered_base):,} active features")

# Add diagonal "no change" line spanning the full range of plotted values.
# The active values are collected as numpy arrays and reduced with numpy,
# rather than being expanded into Python lists element by element.
base_chunks = []
chat_chunks = []
for token_type in token_types:
    base_values = base_data[token_type][selected_metric].numpy()
    chat_values = chat_data[token_type][selected_metric].numpy()
    active_mask = (base_values > 0) | (chat_values > 0)
    base_chunks.append(base_values[active_mask])
    chat_chunks.append(chat_values[active_mask])

all_base_values = np.concatenate(base_chunks) if base_chunks else np.array([])
all_chat_values = np.concatenate(chat_chunks) if chat_chunks else np.array([])

if all_base_values.size:
    # Both arrays come from the same mask, so they are non-empty together
    max_val = float(max(all_base_values.max(), all_chat_values.max()))
    min_val = float(min(all_base_values.min(), all_chat_values.min()))
else:
    # Nothing is active: fall back to a unit-length reference line instead of
    # failing on an empty reduction
    min_val, max_val = 0.0, 1.0

fig.add_trace(
    go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode='lines',
        line=dict(color='gray', dash='dash', width=2),
        name='No Change',
        hovertemplate="No change line<extra></extra>"
    )
)

# Human-readable label for the plotted metric, so the title and axis labels
# follow `selected_metric` instead of always reading "Activation Sparsity".
# Metrics without an entry fall back to a title-cased version of their key.
METRIC_LABELS = {'sparsity': 'Activation Sparsity'}
metric_label = METRIC_LABELS.get(selected_metric, selected_metric.replace('_', ' ').title())

# Update layout
fig.update_layout(
    title={
        # The layer number comes from the configuration so the subtitle matches the loaded data
        'text': f'Base → Instruct SAE Features: {metric_label}<br><sub>Llama 3.1 8B, Residual Stream Post-Layer {SAE_LAYER}</sub>',
        'x': 0.5,
        'xanchor': 'center',
        'font': {'size': 16}
    },
    xaxis_title=f'Base: {metric_label}',
    yaxis_title=f'Instruct: {metric_label}',
    height=800,
    width=900,
    showlegend=True,
    hovermode='closest',
    # Legend sits just outside the right edge of the plotting area
    legend=dict(
        title="Activation Position",
        orientation="v",
        yanchor="top",
        y=1,
        xanchor="left",
        x=1.02
    ),
    # Configure hover appearance
    hoverlabel=dict(
        bgcolor="white",
        bordercolor="black",
        font_size=12,
        font_family="Arial"
    )
)

# Add grid
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')

# Save the interactive plot with JavaScript for click handling
output_html = OUTPUT_DIR / f"{selected_metric}.html"

# Standalone HTML page for the figure; plotly.js is loaded from the CDN
html_content = fig.to_html(include_plotlyjs='cdn')

# JavaScript that opens a point's `customdata` (its Neuronpedia URL) in a new
# tab when the point is clicked. Traces without customdata are ignored.
click_script = """
<script>
document.addEventListener('DOMContentLoaded', function() {
    var plotDiv = document.getElementsByClassName('plotly-graph-div')[0];
    plotDiv.on('plotly_click', function(data) {
        var point = data.points[0];
        if (point.customdata) {
            window.open(point.customdata, '_blank');
        }
    });
});
</script>
"""

# Insert the script before the closing body tag
html_with_script = html_content.replace('</body>', click_script + '</body>')

# Write explicitly as UTF-8: the page contains non-ASCII text (e.g. the arrow
# in the title), and the platform default encoding is not guaranteed to be UTF-8.
with open(output_html, 'w', encoding='utf-8') as f:
    f.write(html_with_script)

print(f"\nInteractive scatterplot saved to: {output_html}")
print(f"File size: {output_html.stat().st_size / 1024:.1f} KB")

# Render the figure inline in the notebook
fig.show()

# Short usage summary for the plot, printed one line at a time
summary_lines = (
    "\nScatterplot features:",
    f"- {total_features:,} total active features from all token types",
    "- Different colors/symbols for each token type",
    "- Hover shows feature ID, values, and Claude explanation",
    "- Interactive legend to show/hide token types",
    "- Gray diagonal line = no change reference",
    "- Click any point to open its Neuronpedia page",
    "- Can easily change selected_metric above",
)
for summary_line in summary_lines:
    print(summary_line)