# Feature Diffing Analysis

This notebook analyzes the differences between base and chat models by comparing feature activations.

In [3]:
import torch
import pandas as pd
import numpy as np
import os
from pathlib import Path

# Configuration
MODEL_TYPE = "llama"
SAE_LAYER = 15
SAE_TRAINER = "32x"
TOKEN_OFFSETS = {"asst": -2, "endheader": -1, "newline": 0}

# File paths
BASE_FILE = f"/workspace/results/4_diffing/1_{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}_base.pt"
CHAT_FILE = f"/workspace/results/4_diffing/1_{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}_chat.pt"

# Output directory
SOURCE = f"{MODEL_TYPE}_trainer{SAE_TRAINER}_layer{SAE_LAYER}"
OUTPUT_DIR = Path(f"./{SOURCE}")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Link
LLAMA_LINK_FORMAT = f"https://www.neuronpedia.org/llama3.1-8b/{SAE_LAYER}-llamascope-res-131k/"

print(f"Loading base model data from: {BASE_FILE}")
print(f"Loading chat model data from: {CHAT_FILE}")
print(f"Output directory: {OUTPUT_DIR}")

# Load the PyTorch files
base_data = torch.load(BASE_FILE)
chat_data = torch.load(CHAT_FILE)

print(f"\nBase data keys: {list(base_data.keys())}")
print(f"Chat data keys: {list(chat_data.keys())}")
print(f"Base metadata: {base_data['metadata']}")
print(f"Chat metadata: {chat_data['metadata']}")

# Verify token types match
base_tokens = [k for k in base_data.keys() if k != 'metadata']
chat_tokens = [k for k in chat_data.keys() if k != 'metadata']
print(f"\nBase token types: {base_tokens}")
print(f"Chat token types: {chat_tokens}")
assert base_tokens == chat_tokens, "Token types don't match between base and chat!"

token_types = base_tokens
print(f"Processing {len(token_types)} token types: {token_types}")

Loading base model data from: /workspace/results/4_diffing/1_llama_trainer32x_layer15_base.pt
Loading chat model data from: /workspace/results/4_diffing/1_llama_trainer32x_layer15_chat.pt
Output directory: llama_trainer32x_layer15

Base data keys: ['asst', 'endheader', 'newline', 'metadata']
Chat data keys: ['asst', 'endheader', 'newline', 'metadata']
Base metadata: {'source': 'llama_trainer32x_layer15_base', 'model_type': 'llama', 'model_ver': 'base', 'sae_layer': 15, 'sae_trainer': '32x', 'num_prompts': 1000, 'num_features': 131072, 'token_types': ['asst', 'endheader', 'newline']}
Chat metadata: {'source': 'llama_trainer32x_layer15_chat', 'model_type': 'llama', 'model_ver': 'chat', 'sae_layer': 15, 'sae_trainer': '32x', 'num_prompts': 1000, 'num_features': 131072, 'token_types': ['asst', 'endheader', 'newline']}

Base token types: ['asst', 'endheader', 'newline']
Chat token types: ['asst', 'endheader', 'newline']
Processing 3 token types: ['asst', 'endheader', 'newline']


In [6]:
def find_top_increases(base_data, chat_data, token_types, metric_name, top_k=100):
    """
    Find top features with greatest increase in specified metric from base to chat.
    
    Args:
        base_data: Base model data dictionary
        chat_data: Chat model data dictionary  
        token_types: List of token types to process
        metric_name: Name of metric to analyze ('all_mean', 'active_mean', 'sparsity')
        top_k: Number of top features to return
    
    Returns:
        DataFrame with top features and their differences
    """
    all_results = []
    
    for token_type in token_types:
        print(f"Processing {metric_name} for token type: {token_type}")
        
        # Get base and chat tensors for this metric
        base_tensor = base_data[token_type][metric_name]
        chat_tensor = chat_data[token_type][metric_name]
        
        # Calculate difference (chat - base)
        diff_tensor = chat_tensor - base_tensor
        
        # Get top k features with largest increases
        top_values, top_indices = torch.topk(diff_tensor, top_k)
        
        # Convert to lists for DataFrame
        top_values = top_values.tolist()
        top_indices = top_indices.tolist()
        
        # Get corresponding base and chat values
        base_values = base_tensor[top_indices].tolist()
        chat_values = chat_tensor[top_indices].tolist()
        
        # Create records for this token type
        for i, (feat_idx, diff_val, base_val, chat_val) in enumerate(zip(top_indices, top_values, base_values, chat_values)):
            record = {
                'rank': i + 1,
                'feature_id': feat_idx,
                'token': token_type,
                f'{metric_name}_base': base_val,
                f'{metric_name}_chat': chat_val,
                f'{metric_name}_diff': diff_val,
                'link': LLAMA_LINK_FORMAT + str(feat_idx)
            }
            all_results.append(record)
    
    # Convert to DataFrame and sort by difference (descending)
    df = pd.DataFrame(all_results)
    df = df.sort_values(f'{metric_name}_diff', ascending=False).reset_index(drop=True)
    
    print(f"Found {len(df)} total records for {metric_name}")
    return df

# Analyze all three metrics
metrics_to_analyze = ['all_mean', 'active_mean', 'sparsity']
results = {}

for metric in metrics_to_analyze:
    print(f"\n{'='*50}")
    print(f"Analyzing metric: {metric}")
    print(f"{'='*50}")
    
    df = find_top_increases(base_data, chat_data, token_types, metric, top_k=100)
    results[metric] = df
    
    # Show preview
    print(f"\nTop 5 features for {metric}:")
    print(df.head().to_string(index=False))
    
    print(f"\nSummary statistics for {metric}:")
    print(f"Max difference: {df[f'{metric}_diff'].max():.6f}")
    print(f"Min difference: {df[f'{metric}_diff'].min():.6f}")
    print(f"Mean difference: {df[f'{metric}_diff'].mean():.6f}")

print(f"\n{'='*50}")
print("Analysis complete!")
print(f"{'='*50}")


Analyzing metric: all_mean
Processing all_mean for token type: asst
Processing all_mean for token type: endheader
Processing all_mean for token type: newline
Found 300 total records for all_mean

Top 5 features for all_mean:
 rank  feature_id     token  all_mean_base  all_mean_chat  all_mean_diff                                                                  link
    1       11904      asst       0.005578       3.741617       3.736039  https://www.neuronpedia.org/llama3.1-8b/15-llamascope-res-131k/11904
    2       97377      asst       0.006156       3.023328       3.017172  https://www.neuronpedia.org/llama3.1-8b/15-llamascope-res-131k/97377
    1       92801   newline       6.986360       9.858937       2.872578  https://www.neuronpedia.org/llama3.1-8b/15-llamascope-res-131k/92801
    1      112879 endheader       0.000000       2.262328       2.262328 https://www.neuronpedia.org/llama3.1-8b/15-llamascope-res-131k/112879
    2       96419   newline       3.824219       6.014953  

In [7]:
# Save results to CSV files
print("Saving results to CSV files...")

for metric, df in results.items():
    # Create output filename
    output_file = OUTPUT_DIR / f"top_{metric}.csv"
    
    # Save to CSV
    df.to_csv(output_file, index=False)
    print(f"Saved {len(df)} records to: {output_file}")
    
    # Show file info
    file_size = output_file.stat().st_size / 1024  # KB
    print(f"  File size: {file_size:.1f} KB")

print(f"\n{'='*50}")
print("All results saved!")
print(f"{'='*50}")

# Summary of what was saved
print(f"\nSummary of saved files:")
for metric in metrics_to_analyze:
    output_file = OUTPUT_DIR / f"top_{metric}.csv"
    print(f"  {output_file.name}: Top 100 features per token type with greatest {metric} increase")

print(f"\nEach file contains:")
print(f"  - feature_id: SAE feature index")
print(f"  - token: Token position type (asst, endheader, newline)")
print(f"  - {metric}_base: Base model value")
print(f"  - {metric}_chat: Chat model value") 
print(f"  - {metric}_diff: Difference (chat - base)")
print(f"  - {metric}_ratio: Ratio (chat / base)")
print(f"  - rank: Rank within token type (1-100)")
print(f"  - model_type, sae_layer, sae_trainer: Configuration info")

print(f"\nTotal features analyzed: {base_data['metadata']['num_features']:,}")
print(f"Total records per file: {len(results['all_mean'])}")
print(f"Files ready for further analysis!")

Saving results to CSV files...
Saved 300 records to: llama_trainer32x_layer15/top_all_mean.csv
  File size: 39.6 KB
Saved 300 records to: llama_trainer32x_layer15/top_active_mean.csv
  File size: 32.0 KB
Saved 300 records to: llama_trainer32x_layer15/top_sparsity.csv
  File size: 39.8 KB

All results saved!

Summary of saved files:
  top_all_mean.csv: Top 100 features per token type with greatest all_mean increase
  top_active_mean.csv: Top 100 features per token type with greatest active_mean increase
  top_sparsity.csv: Top 100 features per token type with greatest sparsity increase

Each file contains:
  - feature_id: SAE feature index
  - token: Token position type (asst, endheader, newline)
  - sparsity_base: Base model value
  - sparsity_chat: Chat model value
  - sparsity_diff: Difference (chat - base)
  - sparsity_ratio: Ratio (chat / base)
  - rank: Rank within token type (1-100)
  - model_type, sae_layer, sae_trainer: Configuration info

Total features analyzed: 131,072
Total

## Plot Results

In [None]:
# Create interactive Plotly scatterplot
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px

print("Creating interactive Plotly scatterplot...")

# Create subplot with 3 panels
fig = make_subplots(
    rows=1, cols=3,
    subplot_titles=['Mean All', 'Mean Active', 'Sparsity'],
    specs=[[{"secondary_y": False}] * 3],
    horizontal_spacing=0.08
)

# Color mapping for token types
colors = {'asst': '#FF6B6B', 'endheader': '#4ECDC4', 'newline': '#45B7D1'}
symbols = {'asst': 'circle', 'endheader': 'square', 'newline': 'diamond'}

# Calculate unified ranges for mean_all and mean_active
all_mean_data = results['all_mean']
active_mean_data = results['active_mean']

# Find min and max values across both mean datasets
unified_min = min(
    all_mean_data['all_mean_base'].min(),
    all_mean_data['all_mean_chat'].min(),
    active_mean_data['active_mean_base'].min(),
    active_mean_data['active_mean_chat'].min()
)

unified_max = max(
    all_mean_data['all_mean_base'].max(),
    all_mean_data['all_mean_chat'].max(),
    active_mean_data['active_mean_base'].max(),
    active_mean_data['active_mean_chat'].max()
)

# Add small padding
padding = (unified_max - unified_min) * 0.05
unified_min -= padding
unified_max += padding

for i, metric in enumerate(['all_mean', 'active_mean', 'sparsity']):
    df = results[metric]
    
    for token in df['token'].unique():
        token_data = df[df['token'] == token]
        
        # Create hover text with detailed information
        hover_text = []
        for _, row in token_data.iterrows():
            hover_text.append(
                f"<b>Feature {row['feature_id']}</b><br>" +
                f"Token: {row['token']}<br>" +
                f"Rank: {row['rank']}<br>" +
                f"Base: {row[f'{metric}_base']:.6f}<br>" +
                f"Chat: {row[f'{metric}_chat']:.6f}<br>" +
                f"Difference: {row[f'{metric}_diff']:.6f}<br>"
            )
        
        fig.add_trace(
            go.Scatter(
                x=token_data[f'{metric}_base'],
                y=token_data[f'{metric}_chat'],
                mode='markers',
                name=f'{token}' if i == 0 else f'{token}',  # Only show legend for first subplot
                marker=dict(
                    color=colors[token],
                    size=8,
                    symbol=symbols[token],
                    line=dict(width=1, color='white')
                ),
                text=[f"Feature {fid}" for fid in token_data['feature_id']],
                hovertemplate=hover_text,
                showlegend=(i == 0),  # Only show legend for first subplot
                legendgroup=token  # Group legend items
            ),
            row=1, col=i+1
        )
    
    # Add diagonal "no change" line for each subplot
    if i < 2:  # Use unified range for first two subplots (mean_all and mean_active)
        line_min, line_max = unified_min, unified_max
    else:  # Use original range for sparsity
        max_val = max(df[f'{metric}_base'].max(), df[f'{metric}_chat'].max())
        min_val = min(df[f'{metric}_base'].min(), df[f'{metric}_chat'].min())
        line_min, line_max = min_val, max_val
    
    fig.add_trace(
        go.Scatter(
            x=[line_min, line_max],
            y=[line_min, line_max],
            mode='lines',
            line=dict(color='gray', dash='dash', width=2),
            name='No Change' if i == 0 else 'No Change',
            showlegend=(i == 0),
            legendgroup='nochange',
            hovertemplate="No change line<extra></extra>"
        ),
        row=1, col=i+1
    )

# Update layout
fig.update_layout(
    title={
        'text': f'Llama 3.1 8B Base → Chat Feature Activations<br><sub>Top 100 features per activation token position</sub>',
        'x': 0.5,
        'xanchor': 'center',
        'font': {'size': 16}
    },
    height=500,
    width=1400,
    showlegend=True,
    legend=dict(
        orientation="v",
        yanchor="top",
        y=1,
        xanchor="left",
        x=1.02
    ),
    margin=dict(r=120)  # Make room for legend
)

# Update axes labels
fig.update_xaxes(title_text="Base Activation", row=1, col=2)

fig.update_yaxes(title_text="Chat Activation", row=1, col=1)

# Set unified axis ranges for mean_all and mean_active subplots
fig.update_xaxes(range=[unified_min, unified_max], row=1, col=1)  # Mean All
fig.update_yaxes(range=[unified_min, unified_max], row=1, col=1)
fig.update_xaxes(range=[unified_min, unified_max], row=1, col=2)  # Mean Active  
fig.update_yaxes(range=[unified_min, unified_max], row=1, col=2)

# Add grid
fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')

# Save the interactive plot
output_html = OUTPUT_DIR / "feature_increases.html"
fig.write_html(output_html)

print(f"Interactive plot saved to: {output_html}")
print(f"File size: {output_html.stat().st_size / 1024:.1f} KB")

# Show the plot
fig.show()

print(f"\nPlot features:")
print(f"- 3 subplots for all_mean, active_mean, and sparsity")
print(f"- Different colors and symbols for each token type")
print(f"- Hover info shows feature details")
print(f"- Gray diagonal line shows 'no change' reference")
print(f"- Points above the line increased from base to chat")
print(f"- Interactive zoom, pan, and selection capabilities")
print(f"- Mean All and Mean Active subplots now use unified axis ranges")