## Token-Level Activation Analysis

# SAE Feature Analysis Visualization

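This notebook builds an interactive Plotly heatmap of the mean activation of each SAE feature, broken down by prompt label (category). It reads feature descriptions from `personal_general_desc.csv` and per-feature prompt activations from `personal_general_prompts.jsonl`.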

In [19]:
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import json
from collections import defaultdict

## Load and Process Data

In [20]:
# Read the per-feature description table (one row per feature) and preview it
desc_df = pd.read_csv('personal_general_desc.csv')

n_descriptions = desc_df.shape[0]
print(f"Loaded {n_descriptions} feature descriptions")
print("\nColumns:", list(desc_df.columns))
# Header line and the frame preview go out as two newline-separated items
print("\nFirst few rows:", desc_df.head(), sep="\n")

Loaded 19 feature descriptions

Columns: ['feature_id', 'general_activation_mean', 'general_activation_max', 'general_activation_min', 'personal_mean', 'personal_cohens_d', 'chat_desc', 'pt_desc', 'type', 'source', 'token', 'num_prompts', 'link', 'claude_completion', 'claude_desc', 'claude_type']

First few rows:
   feature_id  general_activation_mean  general_activation_max  \
0       18703                 1.614838                2.291193   
1       30068                 1.336165                1.336165   
2       81528                 5.142055               10.429818   
3      130794                 1.645730                1.965041   
4       49925                 1.565177                1.917678   

   general_activation_min  personal_mean  personal_cohens_d  \
0                1.302976       0.832151           1.319064   
1                1.336165       0.478012           0.946825   
2                1.297420       0.234832           0.602961   
3                1.246574       0.12

In [21]:
# Load the JSONL file with prompt data: one JSON object (feature record) per line
prompt_data = []
# JSON is UTF-8 by specification, so do not depend on the platform default encoding
with open('personal_general_prompts.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Blank lines (e.g. an extra newline at end of file) are not records;
            # json.loads('') would raise, so skip them.
            continue
        prompt_data.append(json.loads(line))

print(f"Loaded {len(prompt_data)} feature records")

# Schema preview. Guarded so that an empty file, or a first record without any
# active prompts, does not abort the cell with an IndexError.
if prompt_data:
    print("\nFirst record keys:", list(prompt_data[0].keys()))
    first_prompts = prompt_data[0].get('active_prompts') or []
    if first_prompts:
        print("\nFirst active prompt keys:", list(first_prompts[0].keys()))

Loaded 19 feature records

First record keys: ['feature_id', 'token', 'source', 'active_prompts']

First active prompt keys: ['prompt_id', 'prompt_text', 'prompt_label', 'prompt_feature_activation', 'tokenized_prompt', 'tokens']


## Data Processing for Heatmap

In [22]:
# Survey the prompt labels before aggregating.
# The heatmap will need: feature_id, prompt_label, mean_activation

# Distinct labels seen across every feature's active prompts
all_labels = {
    active_prompt['prompt_label']
    for feature_record in prompt_data
    for active_prompt in feature_record['active_prompts']
}

print("Available prompt labels:", sorted(all_labels))

# Number of active prompts carrying each label (summed over all features)
label_counts = defaultdict(int)
for feature_record in prompt_data:
    for active_prompt in feature_record['active_prompts']:
        label_counts[active_prompt['prompt_label']] += 1

print("\nPrompt counts by label:")
for label in sorted(label_counts):
    print(f"  {label}: {label_counts[label]}")

Available prompt labels: ['analysis', 'code', 'creative', 'math', 'medical', 'therapy', 'trivia']

Prompt counts by label:
  analysis: 69
  code: 41
  creative: 74
  math: 41
  medical: 31
  therapy: 44
  trivia: 31


In [23]:
# Build the long-format table behind the heatmap:
# one row per (feature_id, prompt_label) pair that has at least one active prompt
heatmap_data = []

for feature_record in prompt_data:
    fid = feature_record['feature_id']

    # Bucket this feature's prompt-level activations by prompt label
    activations_by_label = defaultdict(list)
    for active_prompt in feature_record['active_prompts']:
        bucket = activations_by_label[active_prompt['prompt_label']]
        bucket.append(active_prompt['prompt_feature_activation'])

    # One summary row per label, in the order the labels were first seen
    heatmap_data.extend(
        {
            'feature_id': fid,
            'prompt_label': label,
            'mean_activation': np.mean(values),
            'num_prompts': len(values),
        }
        for label, values in activations_by_label.items()
    )

# Convert to DataFrame
heatmap_df = pd.DataFrame(heatmap_data)
print(f"Created heatmap data with {len(heatmap_df)} rows")
print("\nSample data:", heatmap_df.head(10), sep="\n")

Created heatmap data with 42 rows

Sample data:
   feature_id prompt_label  mean_activation  num_prompts
0        6704       trivia         1.513002            3
1       16030         code         1.417526            1
2       18703     creative         1.680974            9
3       18703       trivia         1.317224            2
4       30068     creative         1.336165            1
5       49925     analysis         1.537387            6
6       49925         code         1.488580            2
7       49925     creative         1.616846           10
8       49925         math         1.620788            2
9       49925      medical         1.411191            2


In [24]:
# Reshape long -> wide: rows are features, columns are prompt labels, cells hold
# the mean activation (NaN where a feature has no active prompt for a label)
pivot_df = heatmap_df.pivot(
    index='feature_id',
    columns='prompt_label',
    values='mean_activation',
)

sorted_feature_ids = pivot_df.index.sort_values().tolist()
sorted_prompt_labels = pivot_df.columns.sort_values().tolist()

print(f"Pivot table shape: {pivot_df.shape}")
print("\nFeature IDs:", sorted_feature_ids)
print("\nPrompt labels:", sorted_prompt_labels)
print("\nData summary:", pivot_df.describe(), sep="\n")

Pivot table shape: (19, 7)

Feature IDs: [6704, 9953, 16030, 18703, 27476, 29717, 30068, 47776, 48045, 49123, 49925, 59035, 68574, 81528, 88910, 90235, 91607, 126716, 130794]

Prompt labels: ['analysis', 'code', 'creative', 'math', 'medical', 'therapy', 'trivia']

Data summary:
prompt_label  analysis      code   creative      math   medical   therapy  \
count         8.000000  5.000000  12.000000  3.000000  4.000000  3.000000   
mean          3.227787  2.340470   2.215122  3.858388  1.636622  2.372225   
std           2.916427  1.234842   1.490797  2.372260  1.053856  3.502863   
min           0.353672  1.417526   0.416995  1.620788  0.242600  0.310405   
25%           1.253086  1.488580   1.379261  2.614796  1.119044  0.349975   
50%           2.519060  1.609134   1.648910  3.608804  1.871319  0.389546   
75%           4.328356  2.942980   2.727870  4.977188  2.388897  3.403135   
max           8.808014  4.244131   6.233634  6.345573  2.561250  6.416724   

prompt_label    trivia  
co

## Enhanced Heatmap with Feature Descriptions

In [None]:
# Fixed-size heatmap with square cells, the mean activation printed in each cell,
# and feature descriptions in the hover text.
# Axes are swapped relative to pivot_df: features on the x-axis, prompt
# categories on the y-axis.
feature_labels = [str(fid) for fid in pivot_df.index]
prompt_labels = list(pivot_df.columns)

# Transpose so rows are prompt categories and columns are features
transposed_data = pivot_df.T.values
transposed_text = np.round(transposed_data, 3)

# Cell annotations: the rounded value, or "" where the pivot has no data (NaN)
text_display = [
    ["" if np.isnan(value) else str(value) for value in row]
    for row in transposed_text
]

# Feature metadata keyed by feature_id. No earlier cell builds this lookup, so it
# is derived from desc_df here; without it the hover-text loop raises NameError
# on a fresh kernel. If a feature_id is repeated in the CSV, the last row wins.
desc_dict = {row['feature_id']: row for row in desc_df.to_dict('records')}


def _cell_hover_text(feature_id, prompt_label, activation):
    """Return the HTML hover text for one (feature, prompt category) cell.

    Args:
        feature_id: Feature id as found in ``pivot_df.index``.
        prompt_label: Prompt category as found in ``pivot_df.columns``.
        activation: Mean activation for the cell; NaN when there is no data.

    Returns:
        A ``<br>``-separated string with the feature id, category, activation
        (or a "no data" notice), feature type, and a description truncated to
        150 characters.
    """
    meta = desc_dict.get(feature_id, {})
    desc = meta.get('claude_desc')
    feature_type = meta.get('claude_type')

    # Empty CSV cells are read as NaN (a float); slicing or len() on that would
    # raise TypeError, so fall back to the same placeholders used for unknown ids.
    if not isinstance(desc, str):
        desc = 'No description available' if pd.isna(desc) else str(desc)
    if not isinstance(feature_type, str) and pd.isna(feature_type):
        feature_type = 'Unknown'

    if pd.notna(activation):
        activation_line = f"Mean Activation: {activation:.4f}"
    else:
        activation_line = "No activation data"

    short_desc = f"{desc[:150]}{'...' if len(desc) > 150 else ''}"
    return (
        f"Feature ID: {feature_id}<br>"
        f"Prompt Category: {prompt_label}<br>"
        f"{activation_line}<br>"
        f"Feature Type: {feature_type}<br>"
        f"Description: {short_desc}"
    )


# Create the heatmap with text annotations.
# The colorbar title is set on the trace itself: this trace is not attached to a
# layout coloraxis, so a title placed under layout.coloraxis.colorbar would never
# be displayed. The nested title.side form replaces the deprecated `titleside`.
fig_enhanced = go.Figure(data=go.Heatmap(
    z=transposed_data,
    x=feature_labels,
    y=prompt_labels,
    colorscale='Viridis',
    hoverongaps=False,
    text=text_display,
    texttemplate="%{text}",
    textfont={"size": 10, "color": "white"},
    showscale=True,
    colorbar={'title': {'text': 'Mean Activation', 'side': 'right'}}
))

# Calculate dimensions for square cells
num_features = len(feature_labels)
num_labels = len(prompt_labels)
cell_size = 60  # pixels per cell
width = num_features * cell_size + 200  # extra space for labels and colorbar
height = num_labels * cell_size + 150  # extra space for title and labels

# Update layout for square cells and remove grid
fig_enhanced.update_layout(
    title={
        'text': 'Personal Feature Activation Heatmap by Prompt Category',
        'x': 0.5,
        'xanchor': 'center',
        'font': {'size': 16}
    },
    xaxis={
        'title': 'Feature ID',
        'tickangle': 45,
        'tickfont': {'size': 10},
        'side': 'bottom',
        'showgrid': False,
        'zeroline': False
    },
    yaxis={
        'title': 'Prompt Category',
        'tickfont': {'size': 12},
        'autorange': 'reversed',  # Put first category at top
        'showgrid': False,
        'zeroline': False
    },
    width=width,
    height=height,
    plot_bgcolor='white'
)

# Hover text laid out like the transposed z matrix: rows are prompt categories,
# columns are features. With hoverongaps=False, Plotly shows no hover on NaN
# cells, so the "No activation data" text only appears if that flag is changed.
hover_text = [
    [
        _cell_hover_text(feature_id, prompt_label, pivot_df.loc[feature_id, prompt_label])
        for feature_id in pivot_df.index
    ]
    for prompt_label in prompt_labels
]

# Update hover template
fig_enhanced.update_traces(
    hovertemplate='%{hovertext}<extra></extra>',
    hovertext=hover_text
)

fig_enhanced.show()

# Save the figure as an interactive HTML file and as a static PNG.
# Static image export needs Plotly's image-export backend (the kaleido package).
fig_enhanced.write_html('prompt_activation_heatmap.html')
fig_enhanced.write_image('prompt_activation_heatmap.png')