# Feature Diffing Plotting

This notebook creates an interactive scatterplot comparing SAE feature activations between a base model and its instruct-tuned counterpart (model diffing).

In [None]:
import torch
import pandas as pd
import numpy as np
import os
from pathlib import Path

# Configuration — Llama 3.1 8B run (active)
MODEL_TYPE = "llama"
MODEL_NAME_READABLE = "Llama 3.1 8B"
SAE_LAYER = 13
SAE_TRAINER = "32x"
# Token positions analysed. Only the keys (position names) are used in this
# notebook, for plot styling; the offsets presumably locate each position
# relative to the end of the chat template — TODO confirm.
TOKEN_OFFSETS = {"asst": -2, "endheader": -1, "newline": 0}
N_PROMPTS = 1000
# Gemma 2 9B run — swap these in to plot the Gemma results instead
# MODEL_TYPE = "gemma"
# MODEL_NAME_READABLE = "Gemma 2 9B"
# SAE_LAYER = 20
# SAE_TRAINER = "131k-l0-114"
# TOKEN_OFFSETS = {"model": -1, "newline": 0}
# N_PROMPTS = 40
# Activity threshold, as a percentage of N_PROMPTS, used to select features
# when no pre-filtered feature list is available.
PERCENT_ACTIVE = 1

# Human-readable label for each metric key, used in the plot title and axis
# labels. The metric that is actually plotted is chosen via SELECTED_METRIC.
METRIC_SUBTITLE = {
    'all_mean': 'Mean Activation',
    'sparsity': 'Activation Sparsity'
}

# Run identifier shared by every path below, e.g. "llama_trainer32x_layer13".
# Building it once keeps the input, explanation and output paths in sync.
RUN_NAME = f"{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}"

# Input files
DIFFING_DIR = f"/workspace/results/4_diffing/{RUN_NAME}/{N_PROMPTS}_prompts"
BASE_FILE = f"{DIFFING_DIR}/base.pt"
CHAT_FILE = f"{DIFFING_DIR}/chat.pt"
EXPLANATIONS_PATH = f"../../explanations/{RUN_NAME}.csv"

# Output directory (relative to the notebook's working directory); created eagerly.
SOURCE = f"{RUN_NAME}/{N_PROMPTS}_prompts"
OUTPUT_DIR = Path(f"./{SOURCE}")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Neuronpedia base URLs; a feature ID is appended to form a feature page link.
# NOTE(review): both hard-code the "131k" SAE width independently of SAE_TRAINER.
LLAMA_LINK_FORMAT = f"https://www.neuronpedia.org/llama3.1-8b/{SAE_LAYER}-llamascope-res-131k/"
GEMMA_LINK_FORMAT = f"https://www.neuronpedia.org/gemma-2-9b/{SAE_LAYER}-gemmascope-res-131k/"

print(f"Loading base model data from: {BASE_FILE}")
print(f"Loading chat model data from: {CHAT_FILE}")
print(f"Output directory: {OUTPUT_DIR}")

# Load the PyTorch result files onto the CPU. Later cells call `.numpy()` on
# these tensors, which requires CPU tensors; map_location also lets files that
# were saved from GPU tensors load on CPU-only machines.
base_data = torch.load(BASE_FILE, map_location="cpu")
chat_data = torch.load(CHAT_FILE, map_location="cpu")

print(f"\nBase data keys: {list(base_data.keys())}")
print(f"Chat data keys: {list(chat_data.keys())}")
print(f"Base metadata: {base_data['metadata']}")
print(f"Chat metadata: {chat_data['metadata']}")

# Every key other than 'metadata' is a token type. Base and chat must list the
# same token types in the same order so they can be compared side by side.
base_tokens = [k for k in base_data if k != 'metadata']
chat_tokens = [k for k in chat_data if k != 'metadata']
print(f"\nBase token types: {base_tokens}")
print(f"Chat token types: {chat_tokens}")
if base_tokens != chat_tokens:
    raise ValueError(
        f"Token types don't match between base and chat! base={base_tokens}, chat={chat_tokens}"
    )

token_types = base_tokens
print(f"Processing {len(token_types)} token types: {token_types}")

## Plot Results

In [None]:
# Plotly builds the interactive scatterplot in the next cell.
import plotly.graph_objects as go


# Claude-generated feature descriptions for this SAE; the CSV is read for its
# 'feature_id' and 'claude_desc' columns.
explanations_df = pd.read_csv(EXPLANATIONS_PATH)
print(f"Loaded {len(explanations_df)} explanations")

# feature_id -> description lookup used for hover text (on duplicate IDs the
# later row wins).
explanations_dict = {
    feature_id: description
    for feature_id, description in zip(explanations_df['feature_id'], explanations_df['claude_desc'])
}

In [None]:

# Metric to plot. It must have a label in METRIC_SUBTITLE (used for the title
# and axis labels), so validate it now instead of failing after all the
# per-feature processing below.
SELECTED_METRIC = 'all_mean'
if SELECTED_METRIC not in METRIC_SUBTITLE:
    raise ValueError(
        f"SELECTED_METRIC={SELECTED_METRIC!r} has no subtitle; choose one of {sorted(METRIC_SUBTITLE)}"
    )

print(f"Creating interactive scatterplot for {SELECTED_METRIC} metric with all token types...")

# Generate colors and symbols dynamically based on TOKEN_OFFSETS keys
def generate_plot_styling(token_offset_keys):
    """Assign a marker colour and a marker symbol to each token type.

    Colours and symbols come from fixed ten-entry palettes, handed out in the
    order the keys are given and wrapping around after the tenth key. If a key
    repeats, its later position decides its style.

    Args:
        token_offset_keys: Iterable of token-type names.

    Returns:
        A ``(colors, symbols)`` pair of dicts mapping each key to a hex colour
        string and a marker-symbol name, respectively.
    """
    palette_colors = (
        '#FF6B6B',  # Red
        '#4ECDC4',  # Teal
        '#45B7D1',  # Blue
        '#96CEB4',  # Green
        '#FFEAA7',  # Yellow
        '#DDA0DD',  # Plum
        '#FFA07A',  # Light Salmon
        '#98D8C8',  # Mint
        '#F7DC6F',  # Light Yellow
        '#BB8FCE',  # Light Purple
    )
    palette_symbols = (
        'circle', 'square', 'diamond', 'triangle-up', 'triangle-down',
        'star', 'hexagon', 'pentagon', 'cross', 'x',
    )

    # Materialise once so both comprehensions see the same sequence.
    keys = list(token_offset_keys)
    n_colors = len(palette_colors)
    n_symbols = len(palette_symbols)

    colors = {key: palette_colors[idx % n_colors] for idx, key in enumerate(keys)}
    symbols = {key: palette_symbols[idx % n_symbols] for idx, key in enumerate(keys)}
    return colors, symbols

# Styling keys: TOKEN_OFFSETS order first (so colours stay stable across runs),
# then any token types present in the data but missing from TOKEN_OFFSETS, so
# that every plotted trace has a colour and symbol (otherwise a KeyError).
style_keys = list(TOKEN_OFFSETS) + [t for t in token_types if t not in TOKEN_OFFSETS]
colors, symbols = generate_plot_styling(style_keys)
print(f"Generated styling for token types: {style_keys}")
print(f"Colors: {colors}")
print(f"Symbols: {symbols}")

# Optional pre-filtered feature list; when it exists it replaces the activity threshold.
active_file = f"./{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}/personal_{N_PROMPTS}/explanations.csv"
if os.path.exists(active_file):
    active_features_df = pd.read_csv(active_file)
    active_feature_ids = set(active_features_df['feature_id'].tolist())
    print(f"Using {len(active_feature_ids):,} pre-filtered features from {active_file}")
else:
    # A feature is kept for a token type if its 'num_active' count exceeds
    # int(N_PROMPTS * PERCENT_ACTIVE / 100) in either the base or the chat model.
    print("Pre-calculating active masks...")
    min_active_prompts = int(N_PROMPTS * PERCENT_ACTIVE / 100)
    active_masks = {}
    for token_type in token_types:
        base_values = base_data[token_type]['num_active'].numpy()
        chat_values = chat_data[token_type]['num_active'].numpy()
        active_masks[token_type] = (base_values > min_active_prompts) | (chat_values > min_active_prompts)
        print(f"  {token_type}: {active_masks[token_type].sum():,} active features")

# Accumulators filled by the per-token-type processing loop.
print("Pre-processing data...")
processed_data = {}
all_base_values = []
all_chat_values = []

def format_explanation_efficient(claude_explanation, max_line_length=80, wrap_threshold=50):
    """Format a feature explanation for display in a Plotly hover label.

    Non-string input (None, NaN from an empty CSV cell, ...) yields
    "No explanation available". Text of at most ``wrap_threshold`` characters
    stays on one line. Longer text is greedily word-wrapped so that each line
    holds at most ``max_line_length`` characters; a single longer word gets a
    line of its own. Lines are joined with ``<br>``.

    Each line is HTML-escaped (``&``, ``<``, ``>``). Plotly interprets hover
    text as HTML-like markup, so a description that mentions e.g. ``<b>``,
    ``<sub>`` or an ``&...;`` sequence would otherwise be rendered as
    formatting instead of shown literally.

    Args:
        claude_explanation: Explanation text; may be None/NaN for missing rows.
        max_line_length: Maximum number of characters per wrapped line.
        wrap_threshold: Texts this short or shorter are returned unwrapped.

    Returns:
        The hover-ready string.
    """
    import html

    if not isinstance(claude_explanation, str) or pd.isna(claude_explanation):
        return "No explanation available"

    if len(claude_explanation) <= wrap_threshold:
        return html.escape(claude_explanation, quote=False)

    lines = []
    current_line = ""
    for word in claude_explanation.split():
        # current_line carries a trailing space, so this is the length of the
        # line with `word` appended.
        if len(current_line + word) <= max_line_length:
            current_line += word + " "
        else:
            if current_line:
                lines.append(current_line.strip())
            current_line = word + " "

    if current_line:
        lines.append(current_line.strip())

    # Escape after wrapping so widths count visible characters and the <br>
    # separators stay live markup.
    return "<br>".join(html.escape(line, quote=False) for line in lines)

# Process all token types: restrict to the selected features, then precompute
# the hover text and Neuronpedia link for every plotted point.
no_explanation = set()  # IDs of plotted features without a usable explanation

# Loop-invariant setup, done once rather than once per token type / per point.
use_feature_list = os.path.exists(active_file)
if use_feature_list:
    # Pre-filtered feature IDs from CSV: the same mask applies to every token type.
    all_feature_ids = np.arange(base_data['metadata']['num_features'])
    feature_list_mask = np.isin(all_feature_ids, list(active_feature_ids))
# Any MODEL_TYPE other than "llama" uses the Gemma link format.
link_format = LLAMA_LINK_FORMAT if MODEL_TYPE == "llama" else GEMMA_LINK_FORMAT

for token_type in token_types:
    # Without a feature list, fall back to this token type's activity mask.
    mask = feature_list_mask if use_feature_list else active_masks[token_type]

    base_values = base_data[token_type][SELECTED_METRIC].numpy()[mask]
    chat_values = chat_data[token_type][SELECTED_METRIC].numpy()[mask]
    base_num_active = base_data[token_type]['num_active'].numpy()[mask]
    chat_num_active = chat_data[token_type]['num_active'].numpy()[mask]
    feature_ids = np.flatnonzero(mask)  # indices of the True entries

    differences = chat_values - base_values

    hover_texts = []
    for fid, base_val, chat_val, diff, base_active, chat_active in zip(
        feature_ids, base_values, chat_values, differences, base_num_active, chat_num_active
    ):
        # Missing IDs and non-string values (e.g. NaN from an empty CSV cell)
        # both count as "no explanation".
        claude_explanation = explanations_dict.get(fid)
        if not isinstance(claude_explanation, str):
            no_explanation.add(fid)
            claude_explanation = "No explanation available"

        formatted_explanation = format_explanation_efficient(claude_explanation)

        hover_texts.append(
            f"<b>Feature {fid}</b><br>"
            f"Base: {base_val:.4f} ({base_active}/{N_PROMPTS} prompts)<br>"
            f"Instruct: {chat_val:.4f} ({chat_active}/{N_PROMPTS} prompts)<br>"
            f"Difference: {diff:.4f}<br><br>"
            "<b>Description:</b><br>"
            f"{formatted_explanation}<extra></extra>"
        )

    processed_data[token_type] = {
        'base_values': base_values,
        'chat_values': chat_values,
        'feature_ids': feature_ids,
        'hover_texts': hover_texts,
        'neuronpedia_urls': [f"{link_format}{fid}" for fid in feature_ids],
    }

    # Collected for the axis range of the "no change" line.
    all_base_values.extend(base_values)
    all_chat_values.extend(chat_values)

if no_explanation:
    print(f"{len(no_explanation):,} plotted features have no explanation")

# Create the scatterplot: one trace per token type plus a y = x reference line.
fig = go.Figure()

total_features = 0

for token_type in token_types:
    data = processed_data[token_type]

    # Scattergl (WebGL) keeps large point clouds responsive.
    fig.add_trace(
        go.Scattergl(
            x=data['base_values'],
            y=data['chat_values'],
            mode='markers',
            name=f'{token_type}',
            marker=dict(
                size=6,
                color=colors[token_type],
                symbol=symbols[token_type],
                line=dict(width=0.3, color='black'),  # Thinner lines
                opacity=0.7
            ),
            text=[f"Feature {fid}" for fid in data['feature_ids']],
            customdata=data['neuronpedia_urls'],  # read by the click handler
            hovertemplate=data['hover_texts'],
            hoverlabel=dict(
                bgcolor=colors[token_type],
                bordercolor="black",
                font_size=12,
                font_family="Arial",
                font_color="white"
            )
        )
    )

    total_features += len(data['base_values'])

# Diagonal "no change" line spanning the finite range of both axes. Non-finite
# values are ignored; plain max()/min() would raise on empty lists (no feature
# passed the filter) and give order-dependent results with NaNs.
plotted_values = np.concatenate([
    np.asarray(all_base_values, dtype=float),
    np.asarray(all_chat_values, dtype=float),
])
plotted_values = plotted_values[np.isfinite(plotted_values)]

if plotted_values.size:
    min_val = float(plotted_values.min())
    max_val = float(plotted_values.max())
    fig.add_trace(
        go.Scatter(
            x=[min_val, max_val],
            y=[min_val, max_val],
            mode='lines',
            line=dict(color='gray', dash='dash', width=2),
            name='No Change',
            hovertemplate="No change line<extra></extra>",
            hoverlabel=dict(
                bgcolor="gray",
                bordercolor="black",
                font_size=11,
                font_family="Arial",
                font_color="white"
            )
        )
    )
else:
    print("No finite values to plot; skipping the no-change line")

# Title and axis labels are all derived from the selected metric's label.
metric_label = METRIC_SUBTITLE[SELECTED_METRIC]
plot_title = (
    f'Base → Instruct SAE Features: {metric_label}'
    f'<br><sub>{MODEL_NAME_READABLE}, Residual Stream Post-Layer {SAE_LAYER}</sub>'
)

fig.update_layout(
    title=dict(text=plot_title, x=0.5, xanchor='center', font=dict(size=16)),
    xaxis_title=f'Base: {metric_label}',
    yaxis_title=f'Instruct: {metric_label}',
    height=800,
    width=900,
    showlegend=True,
    hovermode='closest',
    # Legend sits just right of the plot area, top-aligned.
    legend=dict(
        title="Activation Position",
        orientation="v",
        xanchor="left",
        x=1.02,
        yanchor="top",
        y=1,
    ),
)

# Same light grid on both axes.
grid_style = dict(showgrid=True, gridwidth=1, gridcolor='lightgray')
fig.update_xaxes(**grid_style)
fig.update_yaxes(**grid_style)

# Save the interactive plot as standalone HTML with a click handler.
output_html = OUTPUT_DIR / f"{SELECTED_METRIC}.html"

# Click handler: open the clicked point's Neuronpedia page (stored in the
# trace's customdata). Passed through plotly's `post_script` hook, which runs it
# right after the figure is created and substitutes {plot_id} with the figure
# div's id. This replaces patching '</body>' and grabbing the first plotly div
# on DOMContentLoaded.
click_script = """
var plotDiv = document.getElementById('{plot_id}');

// Debounce click events to prevent rapid clicking
var clickTimeout;
plotDiv.on('plotly_click', function(data) {
    clearTimeout(clickTimeout);
    clickTimeout = setTimeout(function() {
        var point = data.points[0];
        if (point.customdata) {
            window.open(point.customdata, '_blank');
        }
    }, 100);
});
"""

html_with_script = fig.to_html(
    include_plotlyjs='cdn',  # Use CDN for smaller file size
    config={
        'displayModeBar': True,
        'showTips': False,
        'scrollZoom': True,
        'doubleClick': 'reset',
    },
    post_script=click_script,
)

# Write UTF-8 explicitly: the title contains "→" and explanations may contain
# other non-ASCII text that the platform's default encoding may not support.
output_html.write_text(html_with_script, encoding='utf-8')

print(f"\nInteractive scatterplot saved to: {output_html}")
print(f"File size: {output_html.stat().st_size / 1024:.1f} KB")

# Show the plot
fig.show()

print("\nScatterplot features:")
print(f"- {total_features:,} total active features from all token types")
print("- Different colors/symbols for each token type")
print("- Hover shows feature ID, values, num_active, and Claude explanation")
print("- Interactive legend to show/hide token types")
print("- Gray diagonal line = no change reference")
print("- Click any point to open its Neuronpedia page")
print("- Optimized for performance with 7k+ data points")
print("- Can easily change SELECTED_METRIC above")