# Feature Diffing Analysis

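This notebook compares sparse autoencoder (SAE) feature activations between a base model and its chat (instruct) counterpart at three token positions (`asst`, `endheader`, `newline`), and plots which features change between the two.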

In [8]:
import torch
import pandas as pd
import numpy as np
import os
from pathlib import Path

# Configuration: which model family / SAE trainer / layer's diffing results to analyze.
MODEL_TYPE = "llama"
SAE_LAYER = 17
SAE_TRAINER = "32x"
# NOTE(review): not referenced by the cells in this notebook; presumably records the
# token offsets at which the activations in the .pt files were collected — confirm
# against the diffing script.
TOKEN_OFFSETS = {"asst": -2, "endheader": -1, "newline": 0}

# Run identifier shared by the input results, the output directory and the
# explanations CSV (e.g. "llama_trainer32x_layer17").
SOURCE = f"{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}"

# Input files produced by the diffing step, one per model version.
RESULTS_DIR = Path("/workspace/results/4_diffing") / SOURCE / "1000_prompts"
BASE_FILE = str(RESULTS_DIR / "base.pt")
CHAT_FILE = str(RESULTS_DIR / "chat.pt")

# Output directory
OUTPUT_DIR = Path(f"./{SOURCE}/1000_prompts")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Neuronpedia page prefix for this layer's SAE; a feature id is appended to it.
LLAMA_LINK_FORMAT = f"https://www.neuronpedia.org/llama3.1-8b/{SAE_LAYER}-llamascope-res-131k/"

print(f"Loading base model data from: {BASE_FILE}")
print(f"Loading chat model data from: {CHAT_FILE}")
print(f"Output directory: {OUTPUT_DIR}")

# Load the PyTorch files. map_location="cpu" because later cells call .numpy(),
# which only works on CPU tensors; it also lets GPU-saved results load on a CPU box.
base_data = torch.load(BASE_FILE, map_location="cpu")
chat_data = torch.load(CHAT_FILE, map_location="cpu")

print(f"\nBase data keys: {list(base_data.keys())}")
print(f"Chat data keys: {list(chat_data.keys())}")
print(f"Base metadata: {base_data['metadata']}")
print(f"Chat metadata: {chat_data['metadata']}")

# Both runs must come from the same SAE, otherwise feature ids are not comparable.
for _meta_key in ("sae_layer", "sae_trainer", "num_features"):
    _base_meta = base_data['metadata'].get(_meta_key)
    _chat_meta = chat_data['metadata'].get(_meta_key)
    assert _base_meta == _chat_meta, (
        f"Metadata mismatch for {_meta_key!r}: base={_base_meta!r}, chat={_chat_meta!r}"
    )

# Verify token types match (every top-level key except 'metadata' is a token type)
base_tokens = [k for k in base_data.keys() if k != 'metadata']
chat_tokens = [k for k in chat_data.keys() if k != 'metadata']
print(f"\nBase token types: {base_tokens}")
print(f"Chat token types: {chat_tokens}")
assert base_tokens == chat_tokens, "Token types don't match between base and chat!"

token_types = base_tokens
print(f"Processing {len(token_types)} token types: {token_types}")

Loading base model data from: /workspace/results/4_diffing/llama_trainer32x_layer17/1000_prompts/base.pt
Loading chat model data from: /workspace/results/4_diffing/llama_trainer32x_layer17/1000_prompts/chat.pt
Output directory: llama_trainer32x_layer17/1000_prompts

Base data keys: ['asst', 'endheader', 'newline', 'metadata']
Chat data keys: ['asst', 'endheader', 'newline', 'metadata']
Base metadata: {'source': 'llama_trainer32x_layer17_base', 'model_type': 'llama', 'model_ver': 'base', 'sae_layer': 17, 'sae_trainer': '32x', 'num_prompts': 1000, 'num_features': 131072, 'token_types': ['asst', 'endheader', 'newline']}
Chat metadata: {'source': 'llama_trainer32x_layer17_chat', 'model_type': 'llama', 'model_ver': 'chat', 'sae_layer': 17, 'sae_trainer': '32x', 'num_prompts': 1000, 'num_features': 131072, 'token_types': ['asst', 'endheader', 'newline']}

Base token types: ['asst', 'endheader', 'newline']
Chat token types: ['asst', 'endheader', 'newline']
Processing 3 token types: ['asst', '

## Plot Results

In [None]:
# Create interactive scatterplot for one metric with all 3 token types.
# One trace per token position; clicking a point opens the feature on Neuronpedia.
# Uses base_data, chat_data, token_types, SOURCE, OUTPUT_DIR, SAE_LAYER and
# LLAMA_LINK_FORMAT from the configuration cell.
import html
import textwrap

import plotly.graph_objects as go

# Load Claude explanations
explanations_path = f"../../explanations/{SOURCE}.csv"
explanations_df = pd.read_csv(explanations_path)
print(f"Loaded {len(explanations_df)} explanations")

# Create a dictionary for fast lookup of explanations by feature_id
explanations_dict = dict(zip(explanations_df['feature_id'], explanations_df['claude_desc']))

# Choose one metric for detailed analysis
selected_metric = 'all_mean'  # Change this to 'active_mean' or 'sparsity' if desired
metric_subtitle = {
    'all_mean': 'Mean Activation',
    'active_mean': 'Mean Activation (When Active)',
    'sparsity': 'Activation Sparsity'
}
HOVER_WRAP_WIDTH = 80  # max characters per line of the hover description

print(f"Creating interactive scatterplot for {selected_metric} metric with all token types...")

# Color mapping for token types
colors = {'asst': '#FF6B6B', 'endheader': '#4ECDC4', 'newline': '#45B7D1'}
symbols = {'asst': 'circle', 'endheader': 'square', 'newline': 'diamond'}


def _lookup_explanation(fid, missing):
    """Return the Claude explanation for feature ``fid``, or a placeholder.

    Ids with no row in the CSV, or whose cell is not a string (pandas reads
    empty cells as NaN), are added to the ``missing`` set.
    """
    explanation = explanations_dict.get(fid)
    if not isinstance(explanation, str):
        missing.add(fid)
        return "No explanation available"
    return explanation


def _format_explanation(text, width=HOVER_WRAP_WIDTH):
    """Wrap ``text`` at word boundaries, HTML-escape each line, join with <br>.

    Escaping matters because an explanation may quote special tokens such as
    "<|eot_id|>", which Plotly's hover renderer would otherwise treat as markup.
    Words longer than ``width`` are kept whole on their own line.
    """
    lines = textwrap.wrap(text, width=width, break_long_words=False, break_on_hyphens=False)
    return "<br>".join(html.escape(line, quote=False) for line in lines)


# Create the scatterplot
fig = go.Figure()

total_features = 0
no_explanation = set()      # feature ids lacking a usable explanation (all token types)
active_value_chunks = []    # plotted base/chat values, used to size the diagonal

for token_type in token_types:
    # Get all feature values for this token/metric combination
    base_values = base_data[token_type][selected_metric].numpy()
    chat_values = chat_data[token_type][selected_metric].numpy()

    # Filter out features that are inactive in both base and chat
    active_mask = (base_values > 0) | (chat_values > 0)
    feature_ids = np.flatnonzero(active_mask)
    filtered_base = base_values[active_mask]
    filtered_chat = chat_values[active_mask]
    differences = filtered_chat - filtered_base
    active_value_chunks.extend([filtered_base, filtered_chat])

    # Per-point hover text (values + wrapped explanation) and Neuronpedia link
    hover_text = []
    neuronpedia_urls = []
    for fid, base_val, chat_val, diff in zip(feature_ids, filtered_base, filtered_chat, differences):
        formatted_explanation = _format_explanation(_lookup_explanation(fid, no_explanation))
        neuronpedia_urls.append(f"{LLAMA_LINK_FORMAT}{fid}")
        hover_text.append(
            f"<b>Feature {fid}</b><br>"
            f"Base: {base_val:.4f}<br>"
            f"Instruct: {chat_val:.4f}<br>"
            f"Difference: {diff:.4f}<br><br>"
            "<b>Description:</b><br>"
            f"{formatted_explanation}"
        )

    # Scattergl (WebGL) keeps thousands of points responsive
    fig.add_trace(
        go.Scattergl(
            x=filtered_base,
            y=filtered_chat,
            mode='markers',
            name=f'{token_type}',
            marker=dict(
                size=6,
                color=colors[token_type],
                symbol=symbols[token_type],
                line=dict(width=0.3, color='black'),  # Thinner lines
                opacity=0.7
            ),
            # The hover content is substituted through %{text} rather than being
            # used as the template itself, so "%{...}" inside an explanation is
            # shown verbatim instead of being parsed as a template variable.
            text=hover_text,
            customdata=neuronpedia_urls,
            hovertemplate="%{text}<extra></extra>",
            hoverlabel=dict(
                bgcolor=colors[token_type],
                bordercolor="black",
                font_size=12,
                font_family="Arial",
                font_color="white"
            )
        )
    )

    total_features += len(filtered_base)
    print(f"  {token_type}: {len(filtered_base):,} active features")

# Add diagonal "no change" line spanning the range of all plotted values.
# nanmin/nanmax ignore NaNs in case a metric is undefined for some features.
all_active_values = np.concatenate(active_value_chunks) if active_value_chunks else np.array([])
if all_active_values.size:
    min_val = float(np.nanmin(all_active_values))
    max_val = float(np.nanmax(all_active_values))
    fig.add_trace(
        go.Scatter(
            x=[min_val, max_val],
            y=[min_val, max_val],
            mode='lines',
            line=dict(color='gray', dash='dash', width=2),
            name='No Change',
            hovertemplate="No change line<extra></extra>",
            hoverlabel=dict(
                bgcolor="gray",
                bordercolor="black",
                font_size=11,
                font_family="Arial",
                font_color="white"
            )
        )
    )
else:
    print("No active features found; skipping the no-change line")

# Update layout
fig.update_layout(
    title={
        'text': f'Base → Instruct SAE Features: {metric_subtitle[selected_metric]}<br><sub>Llama 3.1 8B, Residual Stream Post-Layer {SAE_LAYER}</sub>',
        'x': 0.5,
        'xanchor': 'center',
        'font': {'size': 16}
    },
    xaxis_title=f'Base: {metric_subtitle[selected_metric]}',
    yaxis_title=f'Instruct: {metric_subtitle[selected_metric]}',
    height=800,
    width=900,
    showlegend=True,
    hovermode='closest',
    legend=dict(
        title="Activation Position",
        orientation="v",
        yanchor="top",
        y=1,
        xanchor="left",
        x=1.02
    ),
)

# Add grid
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')

# Save the interactive plot with JavaScript for click handling
output_html = OUTPUT_DIR / f"{selected_metric}.html"

# Open the clicked feature's Neuronpedia page (stored in customdata). Plotly
# replaces {plot_id} with this figure's div id and runs the script after the
# plot is created, so the handler is bound to the right div. Clicks are
# debounced by 100 ms to avoid opening duplicate tabs.
click_script = """
var plotDiv = document.getElementById('{plot_id}');
var clickTimeout;
plotDiv.on('plotly_click', function(data) {
    clearTimeout(clickTimeout);
    clickTimeout = setTimeout(function() {
        var point = data.points[0];
        if (point.customdata) {
            window.open(point.customdata, '_blank');
        }
    }, 100);
});
"""

html_content = fig.to_html(
    include_plotlyjs='cdn',  # Use CDN for smaller file size
    config={
        'displayModeBar': True,
        'showTips': False,
        'scrollZoom': True,
        'doubleClick': 'reset'
    },
    post_script=click_script,
)

# Explicit UTF-8: explanations may contain non-ASCII text, and the platform's
# default encoding may not be able to represent it.
with open(output_html, 'w', encoding='utf-8') as f:
    f.write(html_content)

print(f"\nInteractive scatterplot saved to: {output_html}")
print(f"File size: {output_html.stat().st_size / 1024:.1f} KB")

# Show the plot
fig.show()

print(f"\nScatterplot features:")
print(f"- {total_features:,} total active features from all token types")
print(f"- Different colors/symbols for each token type")
print(f"- Hover shows feature ID, values, and Claude explanation")
print(f"- Interactive legend to show/hide token types")
print(f"- Gray diagonal line = no change reference")
print(f"- Click any point to open its Neuronpedia page")
print(f"- Optimized for performance with 7k+ data points")
print(f"- Can easily change selected_metric above")

Loaded 914 explanations
Creating interactive scatterplot for all_mean metric with all token types...
  asst: 1,755 active features
  endheader: 2,206 active features
  newline: 3,507 active features

Interactive scatterplot saved to: llama_trainer32x_layer17/1000_prompts/all_mean.html
File size: 3299.6 KB



Scatterplot features:
- 7,468 total active features from all token types
- Different colors/symbols for each token type
- Hover shows feature ID, values, and Claude explanation
- Interactive legend to show/hide token types
- Gray diagonal line = no change reference
- Click any point to open its Neuronpedia page
- Optimized for performance with 7k+ data points
- Can easily change selected_metric above


In [None]:
# Number of distinct feature ids (across all token types) that lack an explanation
missing_count = len(no_explanation)
print(missing_count)

# Write every id in no_explanation to a CSV with its feature_id and Neuronpedia URL

def dump_no_explanation(no_explanation, output_path=None, link_format=None):
    """Write feature ids that lack a Claude explanation to a CSV file.

    Each row holds the feature id and its Neuronpedia URL, so the missing
    explanations can be looked up or generated later. Rows are sorted by id.

    Args:
        no_explanation: Iterable of feature ids (Python or numpy integers).
        output_path: Destination CSV. Defaults to
            ``./{SOURCE}/missing_explanations.csv``, so it follows the configured
            model/trainer/layer instead of a hard-coded layer.
        link_format: URL prefix the feature id is appended to. Defaults to
            ``LLAMA_LINK_FORMAT``, which encodes the configured ``SAE_LAYER``.

    Returns:
        The ``Path`` of the CSV that was written.
    """
    import csv

    # Resolve defaults lazily so the config globals are read at call time.
    if output_path is None:
        output_path = Path(f"./{SOURCE}") / "missing_explanations.csv"
    if link_format is None:
        link_format = LLAMA_LINK_FORMAT
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    with open(output_path, 'w', newline='') as f:
        # lineterminator="\n" keeps the original '\n' line endings on every platform.
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["feature_id", "link"])
        for fid in sorted(int(x) for x in no_explanation):
            writer.writerow([fid, f"{link_format}{fid}"])
    return output_path

# Persist the ids without explanations (plus their Neuronpedia links) to CSV
dump_no_explanation(no_explanation=no_explanation)


3959
