# Feature Diffing Plotting

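This notebook builds an interactive Plotly scatterplot from the model-diffing results, comparing each SAE feature's activation metric on the base model against the instruct (chat) model.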

In [None]:
import torch
import pandas as pd
import numpy as np
import os
from pathlib import Path

# Configuration
MODEL_TYPE = "gemma"
MODEL_NAME_READABLE = "Gemma 2 9B"
SAE_LAYER = 20
SAE_TRAINER = "131k-l0-114"
TOKEN_OFFSETS = {"model": -1, "newline": 0}
N_PROMPTS = 40
PERCENT_ACTIVE = 1

# Metrics for analysis: maps a metric key read from the result files to the
# human-readable label used in hover text, axis titles and the plot title.
METRIC_SUBTITLE = {
    'target_all_mean': 'Mean Activation',
    'target_sparsity': 'Activation Sparsity'
}

# File paths
SOURCE = f"{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}/personal_40"
BASE_FILE = f"/workspace/results/5_diffing_comp/{SOURCE}/base.pt"
CHAT_FILE = f"/workspace/results/5_diffing_comp/{SOURCE}/chat.pt"
EXPLANATIONS_PATH = f"../../explanations/{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}.csv"
OUTPUT_DIR = Path(f"./{SOURCE}")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Neuronpedia link (the feature index is appended to build a per-feature URL)
NEURONPEDIA_BASE = f"https://www.neuronpedia.org/gemma-2-9b/{SAE_LAYER}-gemmascope-res-131k/"

# Load data
# map_location="cpu" keeps every tensor on the CPU regardless of the device it
# was saved from; later cells convert the tensors with `.numpy()`, which only
# works for CPU tensors.
print(f"Loading data from {SOURCE}")
base_data = torch.load(BASE_FILE, map_location="cpu")
chat_data = torch.load(CHAT_FILE, map_location="cpu")

# Verify data consistency
# Explicit exceptions instead of `assert`, so the checks survive `python -O`.
token_types = [key for key in base_data if key != 'metadata']
if base_data['metadata']['token_types'] != chat_data['metadata']['token_types']:
    raise ValueError(
        "Base and chat result files were computed for different token types: "
        f"{base_data['metadata']['token_types']} vs {chat_data['metadata']['token_types']}"
    )
missing_in_chat = [token_type for token_type in token_types if token_type not in chat_data]
if missing_in_chat:
    raise ValueError(f"Chat result file is missing token types: {missing_in_chat}")
print(f"Processing {len(token_types)} token types: {token_types}")

# Load explanations (feature_id -> description text from the `claude_desc` column)
explanations_df = pd.read_csv(EXPLANATIONS_PATH)
explanations_dict = dict(zip(explanations_df['feature_id'], explanations_df['claude_desc']))
print(f"Loaded {len(explanations_df)} explanations")

## Plot Results

In [None]:
# Pre-compute all feature data for performance optimization
import plotly.graph_objects as go
# Metric compared between the two models. It must be a key of METRIC_SUBTITLE,
# where it is looked up for the hover text, axis titles and plot title.
SELECTED_METRIC = 'target_sparsity'

def _wrap_explanation(text, trigger_len=50, line_len=80):
    """Greedily wrap `text` into `<br>`-joined lines of at most `line_len` characters.

    Text of `trigger_len` characters or fewer is returned unchanged.
    """
    if len(text) <= trigger_len:
        return text

    lines = []
    current_line = ""
    for word in text.split():
        # `current_line` carries a trailing space, so this bounds the stripped line length.
        if len(current_line + word) <= line_len:
            current_line += word + " "
        else:
            if current_line:
                lines.append(current_line.strip())
            current_line = word + " "
    if current_line:
        lines.append(current_line.strip())
    return "<br>".join(lines)


def _format_explanation(explanations_dict, feature_id):
    """Return the wrapped explanation for `feature_id`, or a placeholder if none is usable."""
    explanation = explanations_dict.get(feature_id, "No explanation available")
    # Missing CSV cells arrive as NaN (a float), not as a string.
    if not isinstance(explanation, str) or pd.isna(explanation):
        explanation = "No explanation available"
    return _wrap_explanation(explanation)


def precompute_feature_data(base_data, chat_data, token_types, explanations_dict,
                            *, metric=None, metric_label=None, neuronpedia_base=None):
    """Pre-compute all feature classifications and data for fast plotting.

    For every token type, keeps only the features that are "target-exclusive"
    in at least one model (active on at least one target prompt and on no
    control prompt) and prepares everything the scatterplot needs for them.
    Progress is printed per token type.

    Args:
        base_data: Base-model results. `base_data['metadata']` must provide
            'num_target_prompts' and 'num_control_prompts'; `base_data[token_type]`
            must provide 'target_num_active', 'control_num_active' and `metric`,
            each exposing `.numpy()` (assumed to be 1-D per-feature tensors).
        chat_data: Chat/instruct-model results with the same per-token-type layout.
        token_types: Iterable of token-type keys to process.
        explanations_dict: Mapping from feature index to description text.
        metric: Key of the metric to extract. Defaults to the notebook-level
            `SELECTED_METRIC`.
        metric_label: Human-readable metric name for hover text. Defaults to
            `METRIC_SUBTITLE[metric]`.
        neuronpedia_base: URL prefix to which the feature index is appended.
            Defaults to the notebook-level `NEURONPEDIA_BASE`.

    Returns:
        dict mapping each token type to a dict with keys 'feature_ids',
        'base_values', 'chat_values', 'exclusivity_categories' (object array of
        'both' / 'base' / 'chat'), 'inconsistencies' (bool array), 'hover_texts'
        (list of str) and 'neuronpedia_urls' (list of str), all aligned by position.
    """
    # Resolve the notebook-level defaults lazily so that existing four-argument
    # calls behave exactly as before, while callers may override them explicitly.
    if metric is None:
        metric = SELECTED_METRIC
    if metric_label is None:
        metric_label = METRIC_SUBTITLE[metric]
    if neuronpedia_base is None:
        neuronpedia_base = NEURONPEDIA_BASE

    num_target_prompts = base_data['metadata']['num_target_prompts']
    num_control_prompts = base_data['metadata']['num_control_prompts']

    exclusivity_text_by_category = {
        "both": "Introspective prompt-exclusive feature for both models",
        "base": "Introspective prompt-exclusive feature for Base",
        "chat": "Introspective prompt-exclusive feature for Instruct",
    }

    precomputed = {}
    for token_type in token_types:
        print(f"Pre-computing data for {token_type}...")

        # Get all activation data as numpy arrays
        base_target = base_data[token_type]['target_num_active'].numpy()
        base_control = base_data[token_type]['control_num_active'].numpy()
        chat_target = chat_data[token_type]['target_num_active'].numpy()
        chat_control = chat_data[token_type]['control_num_active'].numpy()

        base_control_active = base_control > 0
        chat_control_active = chat_control > 0

        # Exclusive = active on some target prompt and on no control prompt.
        base_exclusive = (base_target > 0) & ~base_control_active
        chat_exclusive = (chat_target > 0) & ~chat_control_active

        # Get metric values
        base_values = base_data[token_type][metric].numpy()
        chat_values = chat_data[token_type][metric].numpy()

        # Keep features that are target-exclusive in at least one model.
        filtered_indices = np.where(base_exclusive | chat_exclusive)[0]
        base_exclusive_kept = base_exclusive[filtered_indices]
        chat_exclusive_kept = chat_exclusive[filtered_indices]

        # Exclusivity categories via mask assignment; later assignments take
        # precedence, so 'both' overrides 'base', which overrides 'chat'.
        exclusivity_categories = np.full(len(filtered_indices), 'neither', dtype=object)
        exclusivity_categories[chat_exclusive_kept] = 'chat'
        exclusivity_categories[base_exclusive_kept] = 'base'
        exclusivity_categories[base_exclusive_kept & chat_exclusive_kept] = 'both'

        # Cross-model inconsistency: exclusive in one model, but active on the
        # other model's control prompts.
        active_on_chat_control = base_exclusive_kept & chat_control_active[filtered_indices]
        active_on_base_control = chat_exclusive_kept & base_control_active[filtered_indices]
        inconsistencies = active_on_chat_control | active_on_base_control

        explanations = [_format_explanation(explanations_dict, idx) for idx in filtered_indices]

        # Pre-compute hover texts
        hover_texts = []
        for i, idx in enumerate(filtered_indices):
            base_val = base_values[idx]
            chat_val = chat_values[idx]
            diff = chat_val - base_val

            # Share of prompts on which the feature was active, per model and prompt set.
            base_target_pct = (base_target[idx] / num_target_prompts) * 100
            base_control_pct = (base_control[idx] / num_control_prompts) * 100
            chat_target_pct = (chat_target[idx] / num_target_prompts) * 100
            chat_control_pct = (chat_control[idx] / num_control_prompts) * 100

            inconsistency_note = ""
            if active_on_chat_control[i]:
                inconsistency_note = "<br>This feature is also active on the Instruct model's control prompts"
            elif active_on_base_control[i]:
                inconsistency_note = "<br>This feature is also active on the Base model's control prompts"

            exclusivity_text = exclusivity_text_by_category[exclusivity_categories[i]]

            # The trailing <extra></extra> suppresses plotly's secondary hover box.
            hover_texts.append(
                f"<b>Feature {idx}</b><br>"
                f"Base {metric_label}: {base_val:.4f}<br>Instruct {metric_label}: {chat_val:.4f}<br>"
                f"Difference: {diff:.4f}<br><br>"
                f"<b>{exclusivity_text}{inconsistency_note}</b><br>"
                f"Base Introspective: {base_target_pct:.1f}% ({base_target[idx]}/{num_target_prompts})<br>"
                f"Base Control: {base_control_pct:.1f}% ({base_control[idx]}/{num_control_prompts})<br>"
                f"Instruct Introspective: {chat_target_pct:.1f}% ({chat_target[idx]}/{num_target_prompts})<br>"
                f"Instruct Control: {chat_control_pct:.1f}% ({chat_control[idx]}/{num_control_prompts})<br><br>"
                f"<br><b>Description:</b><br>{explanations[i]}<extra></extra>"
            )

        precomputed[token_type] = {
            'feature_ids': filtered_indices,
            'base_values': base_values[filtered_indices],
            'chat_values': chat_values[filtered_indices],
            'exclusivity_categories': exclusivity_categories,
            'inconsistencies': inconsistencies,
            'hover_texts': hover_texts,
            'neuronpedia_urls': [f"{neuronpedia_base}{idx}" for idx in filtered_indices]
        }

        print(f"  {len(filtered_indices)} target-exclusive features")

    return precomputed

# Build the per-token-type plot data once; the plotting cell reads it from
# `precomputed_data[token_type]`.
precomputed_data = precompute_feature_data(base_data, chat_data, token_types, explanations_dict)

In [None]:


# Styling configuration
colors = {'model': '#FF6B6B', 'newline': '#4ECDC4'}
fallback_color = '#888888'  # used for token types without a dedicated colour above
exclusivity_symbols = {'both': 'star', 'base': 'circle', 'chat': 'square'}
exclusivity_names = {'both': 'Both Introspective', 'base': 'Base Introspective', 'chat': 'Instruct Introspective'}


def _add_feature_trace(fig, token_type, data, urls, hovers, mask, name, symbol, outline):
    """Add one WebGL scatter trace for the features of `data` selected by `mask`.

    `urls` and `hovers` are array versions of data['neuronpedia_urls'] and
    data['hover_texts']; `outline` is the marker-border style dict.
    """
    color = colors.get(token_type, fallback_color)
    fig.add_trace(
        go.Scattergl(
            x=data['base_values'][mask],
            y=data['chat_values'][mask],
            mode='markers',
            name=name,
            legendgroup=token_type,
            legendgrouptitle_text=token_type,
            marker=dict(
                size=6,
                color=color,
                symbol=symbol,
                line=outline,
                opacity=0.7
            ),
            text=[f"Feature {fid}" for fid in data['feature_ids'][mask]],
            # customdata carries the Neuronpedia URL opened by the click handler below.
            customdata=urls[mask],
            hovertemplate=hovers[mask],
            hoverlabel=dict(
                bgcolor=color,
                bordercolor="black",
                font_size=12,
                font_family="Arial",
                font_color="white"
            )
        )
    )


print(f"Creating interactive scatterplot for {SELECTED_METRIC}...")

# Create figure
fig = go.Figure()

# Add traces grouped by token type and exclusivity
for token_type in token_types:
    data = precomputed_data[token_type]
    # Convert once per token type so the lists can be indexed with boolean masks.
    urls = np.array(data['neuronpedia_urls'])
    hovers = np.array(data['hover_texts'])

    for exclusivity_type in ['both', 'base', 'chat']:
        # Filter data for this exclusivity type
        mask = data['exclusivity_categories'] == exclusivity_type
        if not np.any(mask):
            continue

        # Separate consistent and inconsistent features
        consistent_mask = mask & ~data['inconsistencies']
        inconsistent_mask = mask & data['inconsistencies']

        trace_name = f"  {exclusivity_names[exclusivity_type]}"
        symbol = exclusivity_symbols[exclusivity_type]

        # Add consistent features
        if np.any(consistent_mask):
            _add_feature_trace(
                fig, token_type, data, urls, hovers, consistent_mask,
                name=trace_name, symbol=symbol,
                outline=dict(width=0.3, color='black'),
            )

        # Add inconsistent features with red border
        if np.any(inconsistent_mask):
            other_model_text = "Instruct" if exclusivity_type == "base" else "Base"
            _add_feature_trace(
                fig, token_type, data, urls, hovers, inconsistent_mask,
                name=trace_name + f" + {other_model_text} Control", symbol=symbol,
                outline=dict(width=1, color='red'),
            )

# Get all values for axis scaling (both axes share one range)
value_arrays = [
    np.asarray(precomputed_data[token_type][key], dtype=float)
    for token_type in token_types
    for key in ('base_values', 'chat_values')
]
all_values = np.concatenate(value_arrays) if value_arrays else np.empty(0)
all_values = all_values[np.isfinite(all_values)]  # NaN/inf would poison the axis range

if all_values.size:
    min_val = float(all_values.min())
    max_val = float(all_values.max())
else:
    # Nothing to plot: fall back to a unit range instead of crashing on min()/max().
    print("No finite feature values to plot; using a default axis range.")
    min_val, max_val = 0.0, 1.0

# Add diagonal "no change" line
fig.add_trace(
    go.Scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode='lines',
        line=dict(color='gray', dash='dash', width=2),
        name='No Change',
        hovertemplate="No change line<extra></extra>",
        hoverlabel=dict(bgcolor="gray", bordercolor="black", font_size=11, font_family="Arial", font_color="white")
    )
)

span = max_val - min_val
buffer = span * 0.05 if span > 0 else 0.5  # 5% buffer; fixed pad when all values coincide
axis_min = min_val - buffer
axis_max = max_val + buffer

# Update layout
metric_display = METRIC_SUBTITLE[SELECTED_METRIC]

fig.update_layout(
    title={
        'text': f'Base → Instruct Introspective SAE Features: {metric_display}<br><sub>{MODEL_NAME_READABLE}, Residual Stream Post-Layer {SAE_LAYER}</sub>',
        'x': 0.5,
        'xanchor': 'center',
        'font': {'size': 16}
    },
    xaxis_title=f'Base: {metric_display}',
    yaxis_title=f'Instruct: {metric_display}',
    height=800,
    width=1005,
    showlegend=True,
    hovermode='closest',
    legend=dict(
        title="Activation Position",
        orientation="v",
        yanchor="top",
        y=1,
        xanchor="left",
        x=1.02,
        groupclick="togglegroup"
    ),
    # Anchor x to y at ratio 1 so the "no change" diagonal is drawn at 45 degrees.
    xaxis=dict(scaleanchor="y", scaleratio=1, range=[axis_min, axis_max]),
    yaxis=dict(range=[axis_min, axis_max]),
)

fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')

# Save with click handler
out_name = SELECTED_METRIC.removeprefix('target_')
output_html = OUTPUT_DIR / f"{out_name}.html"
html_content = fig.to_html(
    include_plotlyjs='cdn',
    config={'displayModeBar': True, 'showTips': False, 'scrollZoom': True, 'doubleClick': 'reset'}
)

# Opens the clicked point's customdata (its Neuronpedia URL) in a new tab.
click_script = """
<script>
document.addEventListener('DOMContentLoaded', function() {
    var plotDiv = document.getElementsByClassName('plotly-graph-div')[0];
    let clickTimeout;
    plotDiv.on('plotly_click', function(data) {
        clearTimeout(clickTimeout);
        clickTimeout = setTimeout(function() {
            var point = data.points[0];
            if (point.customdata) {
                window.open(point.customdata, '_blank');
            }
        }, 100);
    });
});
</script>
"""

# Write as UTF-8 explicitly: the title contains a non-ASCII arrow and feature
# descriptions may too, so the platform's default encoding cannot be relied on.
with open(output_html, 'w', encoding='utf-8') as f:
    f.write(html_content.replace('</body>', click_script + '</body>'))

print(f"Interactive scatterplot saved to: {output_html}")
fig.show()