In [1]:
import sys
sys.path.append('..')
import torch
from transformer_lens import HookedTransformer
from trainers.scae import SCAESuite
from interp_utils import prepare_streaming_dataset

# Model and dictionary hyperparameters.
model_name = "roneneldan/TinyStories-33M"
expansion = 2  # n_features is this multiple of model.cfg.d_model
k = 30  # forwarded unchanged to SCAESuite

# Use the GPU whenever torch can see one.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Load the pretrained model and reuse the tokenizer it ships with.
model = HookedTransformer.from_pretrained(model_name, device=device)
tokenizer = model.tokenizer

# TinyStories is the dataset streamed by the next cell.
dataset_name = "roneneldan/TinyStories"

# Build the autoencoder suite, sized relative to the model width.
n_features = expansion * model.cfg.d_model
suite = SCAESuite(model=model, k=k, n_features=n_features, connections=None, device=device)

# Signature kept here for reference:
# prepare_streaming_dataset(tokenizer, dataset_name, max_length, batch_size, num_datapoints=None, num_cpu_cores=6)
debug = True
# (max_length, batch_size, num_datapoints): small values for debugging, larger otherwise.
max_length, batch_size, num_datapoints = (64, 16, 100) if debug else (128, 16, 1000)

  from .autonotebook import tqdm as notebook_tqdm


Loaded pretrained model roneneldan/TinyStories-33M into HookedTransformer


In [2]:
# Stream the tokenised dataset with the settings chosen in the first cell.
dataset = prepare_streaming_dataset(
    tokenizer,
    dataset_name,
    max_length,
    batch_size,
    num_datapoints=num_datapoints,
)

# Module keys in display order: every attention entry first, then every MLP entry.
keys = [f"{kind}_{layer}" for kind in ("attn", "mlp") for layer in range(4)]

# Smoke test: keep only feature indices 0..9 for each key.
features_to_save = {key: list(range(10)) for key in keys}

# Empty accumulators, one activation list per key plus one token list.
saved_feature_act_list = {key: [] for key in keys}
saved_token_list = []

In [3]:
# Accumulate into buffers that are created inside this cell. The original
# loop appended to `saved_feature_act_list` / `saved_token_list` and then
# rebound those same names to concatenated tensors, so running the cell a
# second time failed with "'Tensor' object has no attribute 'append'".
# NOTE(review): a re-run still needs `dataset` to be iterable again - confirm
# that prepare_streaming_dataset returns a re-iterable object.
feature_act_chunks = {key: [] for key in keys}
token_chunks = []

with torch.no_grad():
    for batch in dataset:
        _, cache = model.run_with_cache(batch)
        _, features = suite.forward_pruned(cache, return_features=True)
        for key in keys:
            # Keep only the selected feature columns and move them off the device.
            feature_act_chunks[key].append(features[key][..., features_to_save[key]].cpu())
        token_chunks.append(batch.cpu())

# Stitch the per-batch pieces together along the leading (example) dimension.
saved_feature_act_list = {key: torch.cat(chunks, dim=0) for key, chunks in feature_act_chunks.items()}
saved_token_list = torch.cat(token_chunks, dim=0)

In [4]:
# Sanity check: the activation tensor for one key and the token tensor should
# share their first two dimensions; the activations carry one extra trailing
# dimension with one entry per saved feature.
saved_feature_act_list["attn_0"].shape, saved_token_list.shape

(torch.Size([100, 64, 10]), torch.Size([100, 64]))

In [5]:
import numpy as np
from IPython.display import display, HTML
from einops import rearrange

def make_colorbar(min_value, max_value, white = 255, red_blue_ness = 250, positive_threshold = 0.01, negative_threshold = 0.01):
    """Build the HTML legend strip that explains the activation colours.

    The strip always contains a neutral "0.0" swatch. Four red swatches
    (strongest first) are placed before it when ``min_value`` is below
    ``-negative_threshold``, and four blue swatches (weakest first) are placed
    after it when ``max_value`` is above ``positive_threshold``. Each swatch is
    labelled with the corresponding fraction of ``min_value`` / ``max_value``
    rounded to one decimal.

    Parameters:
    - min_value / max_value: extremes of the values being visualised
    - white: channel value used for the neutral swatch
    - red_blue_ness: channel value that is scaled down as the colour saturates
    - positive_threshold / negative_threshold: magnitudes below which the
      positive / negative half of the bar is omitted

    Returns the concatenated ``<span>`` elements as a single string.
    """
    num_colors = 4
    ratios = [step / num_colors for step in range(1, num_colors + 1)]

    def _faded_channel(ratio):
        # The two non-dominant channels shrink towards 0 as ratio approaches 1.
        return int(red_blue_ness - (red_blue_ness * ratio))

    def _label_rgb(ratio):
        # Switch to a white label once the swatch becomes dark.
        return "255,255,255" if ratio > 0.5 else "0,0,0"

    swatches = []

    # Negative half: most saturated red first.
    if min_value < -negative_threshold:
        for ratio in reversed(ratios):
            channel = _faded_channel(ratio)
            label = round((min_value * ratio), 1)
            swatches.append(
                f'<span style="background-color:rgba(255, {channel},{channel},1); color:rgb({_label_rgb(ratio)})">&nbsp{label}&nbsp</span>'
            )

    # Neutral swatch for zero.
    swatches.append(
        f'<span style="background-color:rgba({white},{white},{white},1);color:rgb(0,0,0)">&nbsp0.0&nbsp</span>'
    )

    # Positive half: lightest blue first.
    if max_value > positive_threshold:
        for ratio in ratios:
            channel = _faded_channel(ratio)
            label = round((max_value * ratio), 1)
            swatches.append(
                f'<span style="background-color:rgba({channel},{channel},255,1);color:rgb({_label_rgb(ratio)})">&nbsp{label}&nbsp</span>'
            )

    return "".join(swatches)

def value_to_color(activation, max_value, min_value, white = 255, red_blue_ness = 250, positive_threshold = 0.01, negative_threshold = 0.01):
    """Map one value to a ``(text_color, background_color)`` pair.

    Values above ``positive_threshold`` are shaded blue in proportion to
    ``activation / max_value``; values below ``-negative_threshold`` are shaded
    red in proportion to ``activation / min_value``; everything in between gets
    the neutral ``white`` background with black text.

    Returns:
    - text_color: an ``"r,g,b"`` string (white text once the ratio exceeds 0.5)
    - background_color: a complete ``rgba(...)`` CSS value
    """
    is_positive = activation > positive_threshold

    # Guard clause: values inside the dead zone are rendered neutrally.
    if not is_positive and not (activation < -negative_threshold):
        return "0,0,0", f'rgba({white},{white},{white},1)'

    # Fraction of the relevant extreme that this value reaches.
    if is_positive:
        ratio = activation/max_value
    else:
        ratio = activation/min_value

    text_color = "0,0,0" if ratio <= 0.5 else "255,255,255"
    faded = int(red_blue_ness-(red_blue_ness*ratio))

    if is_positive:
        background_color = f'rgba({faded},{faded},255,1)'
    else:
        background_color = f'rgba(255, {faded},{faded},1)'
    return text_color, background_color

def convert_token_array_to_list(array):
    """Normalise tokens or activations to a list of per-sequence lists.

    Parameters:
    - array: a 1-D or 2-D ``torch.Tensor``, a flat list of numbers, or a list
      that already holds one entry per sequence

    Returns:
    - a list with one inner sequence per example. A 1-D tensor or a flat list
      of ints/floats is wrapped in an outer list; a 2-D tensor is converted
      with ``tolist()``; anything else (including an empty list) is returned
      unchanged.

    Raises:
    - NotImplementedError: for tensors that are not 1- or 2-dimensional
    """
    if isinstance(array, torch.Tensor):
        if array.dim() == 1:
            return [array.tolist()]
        if array.dim() == 2:
            return array.tolist()
        raise NotImplementedError("tokens must be 1 or 2 dimensional")

    # This helper is applied to activations as well as token ids, so floats
    # must count as scalars too; the emptiness check avoids indexing into [].
    if isinstance(array, list) and array and isinstance(array[0], (int, float)):
        return [array]
    return array

def tokens_and_activations_to_html(toks, activations, tokenizer, logit_diffs=None, model_type="causal", text_above_each_act=None):
    """Render sequences of tokens as HTML, shading each token by its activation.

    Parameters:
    - toks: token ids; a tensor or list, normalised by convert_token_array_to_list
    - activations: one value per token, same layout as ``toks``
    - tokenizer: object whose ``decode`` turns a single token id into text
    - logit_diffs: optional extra values. When ``model_type`` is not
      ``"reward_model"`` they are indexed as ``logit_diffs[seq][pos]`` and drawn
      as a coloured bar under each token; when it is ``"reward_model"``,
      ``logit_diffs[seq].item()`` is printed once per sequence on a fixed
      -10..10 colour scale
    - model_type: selects between the two ``logit_diffs`` renderings above
    - text_above_each_act: optional per-sequence labels, inserted as given
      (not escaped) in a ``<span>`` before each sequence

    Returns:
    - the HTML fragment as a single string: colour legend(s) followed by one
      block per sequence

    Decoded token text is HTML-escaped before insertion, so tokens containing
    ``<``, ``>`` or ``&`` are shown literally instead of corrupting the markup.
    """
    # Function-scope import keeps this definition independent of other cells.
    from html import escape

    text_spacing = "0.00em"
    toks = convert_token_array_to_list(toks)
    activations = convert_token_array_to_list(activations)

    # Escape first; the replacements afterwards add the only intended entities.
    toks = [
        [escape(tokenizer.decode(t)).replace('Ġ', '&nbsp').replace('\n', '\\n') for t in tok]
        for tok in toks
    ]

    # True when every token gets its own logit-diff bar beneath it.
    per_token_logits = logit_diffs is not None and model_type != "reward_model"

    highlighted_text = []
    # Dark background for the whole fragment.
    highlighted_text.append("""
<body style="background-color: black; color: white;">
""")

    # Colour scales are shared across all sequences so shades are comparable.
    max_value = max([max(activ) for activ in activations])
    min_value = min([min(activ) for activ in activations])
    if per_token_logits:
        logit_max_value = max([max(activ) for activ in logit_diffs])
        logit_min_value = min([min(activ) for activ in logit_diffs])

    # Legend(s)
    highlighted_text.append("Token Activations: " + make_colorbar(min_value, max_value))
    if per_token_logits:
        highlighted_text.append('<div style="margin-top: 0.1em;"></div>')
        highlighted_text.append("Logit Diff: " + make_colorbar(logit_min_value, logit_max_value))

    highlighted_text.append('<div style="margin-top: 0.5em;"></div>')
    for seq_ind, (act, tok) in enumerate(zip(activations, toks)):
        if text_above_each_act is not None:
            highlighted_text.append(f'<span>{text_above_each_act[seq_ind]}</span>')
        for act_ind, (a, t) in enumerate(zip(act, tok)):
            if per_token_logits:
                # Wrapper that stacks the token above its logit-diff bar.
                highlighted_text.append('<div style="display: inline-block;">')
            text_color, background_color = value_to_color(a, max_value, min_value)
            highlighted_text.append(f'<span style="background-color:{background_color};margin-right: {text_spacing}; color:rgb({text_color})">{t.replace(" ", "&nbsp")}</span>')
            if per_token_logits:
                logit_diffs_act = logit_diffs[seq_ind][act_ind]
                _, logit_background_color = value_to_color(logit_diffs_act, logit_max_value, logit_min_value)
                highlighted_text.append(f'<div style="display: block; margin-right: {text_spacing}; height: 10px; background-color:{logit_background_color}; text-align: center;"></div></div>')
        if logit_diffs is not None and model_type == "reward_model":
            reward_change = logit_diffs[seq_ind].item()
            # Rewards use a fixed scale rather than the data-dependent one.
            text_color, background_color = value_to_color(reward_change, 10, -10)
            highlighted_text.append(f'<br><span>Reward: </span><span style="background-color:{background_color};margin-right: {text_spacing}; color:rgb({text_color})">{reward_change:.2f}</span>')
        highlighted_text.append('<div style="margin-top: 0.2em;"></div>')
    return ''.join(highlighted_text)
def save_token_display(tokens, activations, tokenizer, path, save=True, logit_diffs=None, show=False, model_type="causal"):
    """Render token activations as HTML and show them inline in the notebook.

    ``path``, ``save`` and ``show`` are accepted but not used by the current
    body: nothing is written to disk and the HTML is always displayed.

    Returns whatever ``display(...)`` returns.
    """
    rendered = tokens_and_activations_to_html(
        tokens,
        activations,
        tokenizer,
        logit_diffs,
        model_type=model_type,
    )
    return display(HTML(rendered))

def get_feature_indices(feature_activations, k=10, setting="max"):
    """Choose up to ``k`` (example, position) pairs for one feature.

    Parameters:
    - feature_activations: 2-D tensor; its two dimensions are unpacked as
      (batch, sequence position)
    - k: number of entries to pick (also the number of bins for "uniform")
    - setting: "max" takes the k largest activations; "uniform" draws one
      random entry from each occupied bin between the minimum and maximum;
      any other value draws k random entries with non-zero activation

    Returns:
    - (d_indices, s_indices): row indices and in-row positions obtained by
      dividing the flat index by the sequence length
    """
    _, seq_len = feature_activations.shape
    flat_acts = rearrange(feature_activations, 'b s -> (b s)')

    if setting == "max":
        # Largest activations first.
        found_indices = torch.argsort(flat_acts, descending=True)[:k]
    elif setting == "uniform":
        lowest = torch.min(flat_acts)
        highest = torch.max(flat_acts)

        # k equal-width bins spanning [lowest, highest].
        edges = torch.linspace(lowest, highest, k + 1)
        bucket_ids = torch.bucketize(flat_acts, edges)

        picks = []
        for bucket in torch.unique(bucket_ids):
            # Bucket 0 holds the values that do not exceed the first edge
            # (i.e. the minimum); it is left out of the sample.
            if bucket != 0:
                members = torch.nonzero(bucket_ids == bucket, as_tuple=False).squeeze(dim=1)
                # One random representative per occupied bucket.
                picks.extend(np.random.choice(members, size=1, replace=False))

        # Buckets were visited in ascending order; flip so the highest comes first.
        found_indices = torch.tensor(picks).long().flip(dims=[0])
    else:
        # Random subset of the positions where the feature is non-zero.
        active = torch.nonzero(flat_acts)[:, 0]
        shuffled = active[torch.randperm(active.shape[0])]
        found_indices = shuffled[:k]

    # Convert flat offsets back to (row, position) coordinates.
    return found_indices // seq_len, found_indices % seq_len

def get_feature_datapoints(d_idx, seq_pos_idx, all_activations, all_tokens, tokenizer):
    """Gather tokens, text and activations for the selected (row, position) pairs.

    For every pair the full row is collected, together with the prefix of that
    row up to and including the selected position.

    Parameters:
    - d_idx: row indices into ``all_tokens`` / ``all_activations``
    - seq_pos_idx: position inside each row, paired element-wise with ``d_idx``
    - all_activations: indexable by row; each row supports slicing and ``tolist()``
    - all_tokens: indexable by row; each row supports slicing
    - tokenizer: object whose ``decode`` accepts a row (or prefix) of tokens

    Returns a 6-tuple of parallel lists:
    (text_list, full_text, token_list, full_token_list,
     partial_activations, full_activations)
    """
    text_list, full_text = [], []
    token_list, full_token_list = [], []
    partial_activations, full_activations = [], []

    for row, pos in zip(d_idx, seq_pos_idx):
        row = int(row)
        end = int(pos) + 1  # exclusive end, so the chosen position is included

        row_tokens = all_tokens[row]
        row_acts = all_activations[row]
        prefix_tokens = row_tokens[:end]

        # Whole-row views.
        full_text.append(tokenizer.decode(row_tokens))
        full_activations.append(row_acts.tolist())
        full_token_list.append(row_tokens)

        # Views truncated at the chosen position.
        partial_activations.append(row_acts[:end].tolist())
        text_list.append(tokenizer.decode(prefix_tokens))
        token_list.append(prefix_tokens)

    return text_list, full_text, token_list, full_token_list, partial_activations, full_activations


In [6]:
from IPython.display import display, HTML
import torch

def create_token_display_html(top_ind, top_val, bot_ind, bot_val, tokenizer, k=10):
    """
    Create an HTML table of top and bottom tokens with their values.

    Parameters:
    - top_ind: indices of top tokens
    - top_val: values of top tokens (each entry must support ``.item()``)
    - bot_ind: indices of bottom tokens
    - bot_val: values of bottom tokens (each entry must support ``.item()``)
    - tokenizer: tokenizer whose ``decode`` turns one index into text
    - k: maximum number of rows to display (default 10)

    Returns:
    - the HTML (a ``<style>`` block, a title and the table) as a string

    If fewer than ``k`` entries are supplied, only the available rows are
    rendered. Decoded token text is HTML-escaped so that tokens containing
    ``<``, ``>`` or ``&`` are displayed literally.
    """
    # Function-scope import keeps this definition independent of other cells.
    from html import escape

    def _token_label(tok):
        # Escape first, then make spaces and newlines visible.
        return escape(tokenizer.decode(tok)).replace(" ", "_").replace("\n", "\\newline")

    top_text = [_token_label(tok) for tok in top_ind[:k]]
    bot_text = [_token_label(tok) for tok in bot_ind[:k]]

    # Never index past the shortest input.
    n_rows = min(k, len(top_text), len(bot_text), len(top_val), len(bot_val))

    # Plain (non-formatted) string, so CSS braces must be written singly.
    html_template = """
    <style>
        .token-table {
            font-family: Arial, sans-serif;
            border-collapse: collapse;
            width: 100%;
            margin-top: 10px;
        }
        .token-table td {
            padding: 8px;
            border-bottom: 1px solid #ddd;
        }
        .title {
            font-size: 20px;
            font-weight: bold;
            text-align: center;
            margin-bottom: 10px;
        }
    </style>
    
    <div class="title">Logit Lens</div>
    
    <table class="token-table">
        <tr>
            <td><b>Top Token</b></td>
            <td><b>Value</b></td>
            <td><b>Bottom Token</b></td>
            <td><b>Value</b></td>
        </tr>
    """

    # Coloured backgrounds are applied to spans rather than to the cells.
    for i in range(n_rows):
        html_template += f"""
        <tr>
            <td>{top_text[i]}</td>
            <td><span style="background-color: #0000FF; color: white; padding: 2px 4px; display: inline-block;"><b>{top_val[i].item():.3f}</b></span></td>
            <td>{bot_text[i]}</td>
            <td><span style="background-color: #FF0000; color: white; padding: 2px 4px; display: inline-block;"><b>{bot_val[i].item():.3f}</b></span></td>
        </tr>
        """

    html_template += "</table>"

    return html_template

# Usage example (the function returns an HTML string, so wrap it in HTML() to render it):
# display(HTML(create_token_display_html(top_ind, top_val, bot_ind, bot_val, tokenizer)))

In [7]:
import matplotlib.pyplot as plt
import numpy as np
import io
import base64
from IPython.display import display, HTML

def display_matplotlib_figure(fig, width=None, height=None):
    """
    Serialise a figure to an inline ``<img>`` tag with a base64 PNG payload.

    Parameters:
    - fig: object with a matplotlib-style ``savefig(buffer, format=..., bbox_inches=...)``
    - width: optional width in pixels, emitted as an inline CSS rule
    - height: optional height in pixels, emitted as an inline CSS rule

    Returns:
    - the ``<img .../>`` tag as a string (nothing is displayed here)
    """
    # Render the PNG into memory and base64-encode the bytes.
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png', bbox_inches='tight')
    encoded_png = base64.b64encode(png_buffer.getvalue()).decode('utf-8')

    # Only dimensions that were explicitly given produce a CSS rule.
    css_rules = [
        f"{prop}:{pixels}px;"
        for prop, pixels in (("width", width), ("height", height))
        if pixels is not None
    ]
    css = "".join(css_rules)
    style_attr = f' style="{css}"' if css else ''

    return f'<img src="data:image/png;base64,{encoded_png}"{style_attr}/>'

# Example usage with a histogram
def create_histogram_html(data, bins=30, title="Histogram", width=600, height=400):
    """Plot ``data`` as a histogram and return it as an inline ``<img>`` tag.

    Parameters:
    - data: values handed to ``Axes.hist``
    - bins: number of histogram bins
    - title: title drawn above the plot
    - width / height: pixel dimensions forwarded to display_matplotlib_figure

    Returns:
    - the HTML string produced by display_matplotlib_figure
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(data, bins=bins, alpha=0.7, color='skyblue', edgecolor='black')
    ax.set(title=title, xlabel='Value', ylabel='Frequency')
    ax.grid(alpha=0.3)

    # Close the figure so the notebook does not also render it inline;
    # the figure object is still handed to the PNG encoder afterwards.
    plt.close(fig)
    return display_matplotlib_figure(fig, width=width, height=height)

# Example with random data (create_histogram_html returns an <img> tag as a string):
# random_data = np.random.normal(0, 1, 1000)
# display(HTML(create_histogram_html(random_data, title="Normal Distribution")))

In [8]:
import json
import os

def generate_enhanced_viewer(keys, features_to_save, saved_feature_act_list, saved_token_list, tokenizer, model, suite, output_path="enhanced_token_viewer.html"):
    """
    Generate an enhanced HTML viewer for token activations with multiple visualizations.
    
    This creates a standalone HTML file with embedded JavaScript that shows:
    1. Logit lens and histogram visualizations side-by-side in the upper section
    2. Token activations in the bottom section
    
    Args:
        keys: List of model keys or identifiers
        features_to_save: Dictionary mapping keys to lists of feature indices
        saved_feature_act_list: Dictionary mapping keys to feature activation tensors
        saved_token_list: List of tokens for each example
        tokenizer: The tokenizer used to convert tokens to text
        model: The model object (for logit lens)
        suite: The suite object containing feature decoders
        output_path: Path where the HTML file will be saved
        
    Returns:
        The path to the generated HTML file
    """
    
    # Number of examples to show per feature
    num_feature_datapoints = 10
    
    # Create the data structure for the JavaScript
    data = {
        "keys": keys,
        "features": features_to_save,
        "content": {}
    }
    
    # Generate HTML content for each key-feature pair
    for key in keys:
        data["content"][key] = {}
        features_for_this_key = features_to_save[key]
        
        for feature in features_for_this_key:
            # Cast feature to int if it's a string (fix for the TypeError)
            feature_idx = int(feature) if isinstance(feature, str) else feature
            
            # Initialize content container for this feature
            data["content"][key][str(feature)] = {
                "tokenActivations": "",
                "logitLens": "",
                "histogram": ""
            }
            
            # Get feature activations for this feature (use feature_idx for tensor indexing)
            feature_activations = saved_feature_act_list[key][..., feature_idx]
            
            # Get indices of examples with highest activations
            d_idx, seq_idx = get_feature_indices(feature_activations, k=num_feature_datapoints, setting="max")
            
            # Get data for these examples
            text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(
                d_idx, seq_idx, feature_activations, saved_token_list, tokenizer
            )
            
            # Generate token activations HTML
            token_html = tokens_and_activations_to_html(token_list, partial_activations, tokenizer)
            data["content"][key][str(feature)]["tokenActivations"] = token_html
            
            # Generate logit lens HTML
            try:
                feature_decoder = suite.aes["attn_0"].decoder.weight[:, feature_idx]
                unembd = model.W_U
                logit_lens = feature_decoder @ unembd
                top_val, top_ind = torch.topk(logit_lens, k=10, dim=-1)
                bot_val, bot_ind = torch.topk(logit_lens, k=10, dim=-1, largest=False)
                
                # Create logit lens HTML
                logit_lens_html = create_token_display_html(top_ind, top_val, bot_ind, bot_val, tokenizer)
                data["content"][key][str(feature)]["logitLens"] = logit_lens_html
            except Exception as e:
                # In case of any errors, use a placeholder
                data["content"][key][str(feature)]["logitLens"] = f"<div class='error-panel'>Logit lens visualization unavailable: {str(e)}</div>"
            
            # Generate histogram HTML
            try:
                nz_feature_act = feature_activations[feature_activations != 0]
                frequency = nz_feature_act.numel() / feature_activations.numel()
                
                # Create histogram HTML
                hist_html = create_histogram_html(nz_feature_act.numpy(), 
                                                title=f"Activation Frequency {frequency*100:.2f}%")
                data["content"][key][str(feature)]["histogram"] = hist_html
            except Exception as e:
                # In case of any errors, use a placeholder
                data["content"][key][str(feature)]["histogram"] = f"<div class='error-panel'>Histogram visualization unavailable: {str(e)}</div>"
    
    # Convert features to strings for JSON serialization
    for key in data["features"]:
        data["features"][key] = [str(f) for f in data["features"][key]]
    
    # Template for the HTML file with fixed CSS for side-by-side panels
    html_template = """<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Enhanced Token Activation Viewer</title>
    <style>
        /* Reset and base styles */
        * {{
            box-sizing: border-box;
            margin: 0;
            padding: 0;
        }}
        
        body {{
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            line-height: 1.6;
            color: #333;
            background-color: #f5f5f5;
            padding: 20px;
        }}
        
        .container {{
            max-width: 1200px;
            margin: 0 auto;
            background-color: #fff;
            border-radius: 8px;
            box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
            overflow: hidden;
        }}
        
        /* Header styles */
        .header {{
            background-color: #2c3e50;
            color: white;
            padding: 20px;
            text-align: center;
        }}
        
        .header h1 {{
            margin: 0;
            font-size: 24px;
        }}
        
        /* Controls section - UPDATED for side-by-side layout */
        .controls {{
            background-color: #f8f9fa;
            padding: 20px;
            border-bottom: 1px solid #e9ecef;
            display: flex;
            flex-wrap: wrap;
            gap: 20px;
        }}
        
        .control-group {{
            flex: 1;
            min-width: 200px;
        }}
        
        label {{
            display: block;
            margin-bottom: 8px;
            font-weight: 600;
            color: #495057;
        }}
        
        select {{
            width: 100%;
            padding: 10px;
            border: 1px solid #ced4da;
            border-radius: 4px;
            background-color: #fff;
            font-size: 16px;
        }}
        
        /* Navigation section */
        .navigation {{
            display: flex;
            justify-content: space-between;
            padding: 15px 20px;
            background-color: #f8f9fa;
            border-bottom: 1px solid #e9ecef;
        }}
        
        .button {{
            padding: 8px 16px;
            background-color: #007bff;
            color: white;
            border: none;
            border-radius: 4px;
            cursor: pointer;
            font-weight: 500;
        }}
        
        .button:hover {{
            background-color: #0069d9;
        }}
        
        .button:disabled {{
            background-color: #cccccc;
            cursor: not-allowed;
        }}
        
        /* Multi-panel layout - FIXED for side-by-side display */
        .upper-panels {{
            display: flex;
            width: 100%;
            padding: 20px;
            margin: 0;
        }}
        
        .panel-row {{
            display: flex;
            width: 100%;
            gap: 20px;
        }}
        
        .panel {{
            flex: 1;
            border: 1px solid #ddd;
            border-radius: 4px;
        }}
        
        .panel-header {{
            background-color: #f1f1f1;
            padding: 10px;
            font-weight: bold;
            border-bottom: 1px solid #ddd;
        }}
        
        .panel-content {{
            padding: 15px;
            overflow: auto;
            min-height: 300px;
        }}
        
        /* Content section */
        .token-activations-container {{
            background-color: black;
            color: white;
            min-height: 400px;
            padding: 20px;
            overflow: auto;
            margin: 0 20px 20px 20px;
            border-radius: 4px;
        }}
        
        /* Loading indicator */
        .loading {{
            display: flex;
            justify-content: center;
            align-items: center;
            height: 300px;
            font-size: 18px;
            color: #6c757d;
        }}
        
        /* Footer section */
        .footer {{
            background-color: #f8f9fa;
            text-align: center;
            padding: 15px;
            color: #6c757d;
            border-top: 1px solid #e9ecef;
        }}
        
        /* Content styles for your token activations */
        .content {{
            font-family: monospace;
            line-height: 1.4;
            white-space: pre-wrap;
        }}

        /* Error panel */
        .error-panel {{
            background-color: #ffe6e6;
            border: 1px solid #ffcccc;
            color: #990000;
            padding: 15px;
            border-radius: 4px;
            text-align: center;
        }}
        
        /* Show/hide elements */
        .hidden {{
            display: none !important;
        }}
        
        /* Token table (for logit lens) */
        .token-table {{
            font-family: Arial, sans-serif;
            border-collapse: collapse;
            width: 100%;
            margin-top: 10px;
        }}
        
        .token-table td {{
            padding: 8px;
            border-bottom: 1px solid #ddd;
        }}
        
        .title {{
            font-size: 20px;
            font-weight: bold;
            text-align: center;
            margin-bottom: 10px;
        }}
        
        /* Fix for histogram images */
        .histogram-container {{
            display: flex;
            justify-content: center;
            align-items: center;
        }}
        
        .histogram-container img {{
            max-width: 100%;
            max-height: 280px;
            width: auto;
            height: auto;
            object-fit: contain;
        }}
    </style>
</head>
<body>
    <div class="container">
        <div class="header">
            <h1>Enhanced Token Activation Viewer</h1>
        </div>
        
        <div class="controls">
            <div class="control-group">
                <label for="keySelect">Select Model/Key:</label>
                <select id="keySelect">
                    <option value="" disabled selected>Choose a model/key</option>
                    <!-- Options will be populated by JavaScript -->
                </select>
            </div>
            <div class="control-group">
                <label for="featureSelect">Select Feature:</label>
                <select id="featureSelect" disabled>
                    <option value="" disabled selected>Select a model/key first</option>
                    <!-- Options will be populated by JavaScript -->
                </select>
            </div>
        </div>
        
        <div class="navigation">
            <button id="prevFeature" class="button" disabled>Previous Feature</button>
            <div>
                <span id="featureInfo">No feature selected</span>
            </div>
            <button id="nextFeature" class="button" disabled>Next Feature</button>
        </div>
        
        <!-- Upper panels for logit lens and histogram - SIMPLE TABLE LAYOUT -->
        <table class="upper-panels-table">
            <tr>
                <!-- Logit lens panel -->
                <td class="panel-cell">
                    <div class="panel">
                        <div class="panel-header">Logit Lens</div>
                        <div id="logitLensPanel" class="panel-content">
                            <div class="loading">Select a feature to view logit lens</div>
                        </div>
                    </div>
                </td>
                
                <!-- Histogram panel -->
                <td class="panel-cell">
                    <div class="panel">
                        <div class="panel-header">Activation Histogram</div>
                        <div id="histogramPanel" class="panel-content">
                            <div class="loading">Select a feature to view histogram</div>
                        </div>
                    </div>
                </td>
            </tr>
        </table>
        
        <!-- Token activations section -->
        <div class="token-activations-container">
            <div id="loadingIndicator" class="loading">
                <p>Select a model/key and feature to view token activations</p>
            </div>
            <div id="contentDisplay" class="content hidden"></div>
        </div>
        
        <div class="footer">
            <p>Enhanced Token Activation Viewer | Features Explorer</p>
        </div>
    </div>
    
    <script>
        // DOM elements
        const keySelect = document.getElementById('keySelect');
        const featureSelect = document.getElementById('featureSelect');
        const prevFeatureBtn = document.getElementById('prevFeature');
        const nextFeatureBtn = document.getElementById('nextFeature');
        const featureInfo = document.getElementById('featureInfo');
        const loadingIndicator = document.getElementById('loadingIndicator');
        const contentDisplay = document.getElementById('contentDisplay');
        const logitLensPanel = document.getElementById('logitLensPanel');
        const histogramPanel = document.getElementById('histogramPanel');
        
        // Data structure with all token activation content
        const data = {DATA_PLACEHOLDER};
        
        // Initialize the viewer
        function initViewer() {
            // Populate the keys dropdown
            data.keys.forEach(key => {
                const option = document.createElement('option');
                option.value = key;
                option.textContent = key;
                keySelect.appendChild(option);
            });
            
            // Event listeners
            keySelect.addEventListener('change', handleKeyChange);
            featureSelect.addEventListener('change', handleFeatureChange);
            prevFeatureBtn.addEventListener('click', showPreviousFeature);
            nextFeatureBtn.addEventListener('click', showNextFeature);
            
            // Auto-select first key and feature
            if (data.keys.length > 0) {
                // Select first key
                keySelect.value = data.keys[0];
                handleKeyChange();
                
                // Select first feature if available
                if (data.features[data.keys[0]] && data.features[data.keys[0]].length > 0) {
                    setTimeout(() => {
                        featureSelect.value = data.features[data.keys[0]][0];
                        handleFeatureChange();
                    }, 100); // Small delay to ensure feature dropdown is populated
                }
            }
        }
        
        /**
         * React to a change of the key dropdown.
         *
         * The feature dropdown is emptied and is enabled only while a key is
         * chosen. When the chosen key has an entry in data.features, the
         * dropdown is refilled (a disabled placeholder first, then one option
         * per feature) and the rest of the UI is returned to its
         * "no feature selected" state.
         */
        function handleKeyChange() {
            const key = keySelect.value;

            // Start from an empty dropdown; it is usable only when a key is set
            featureSelect.innerHTML = '';
            featureSelect.disabled = !key;

            // Nothing else to do without a key or without features for that key
            if (!(key && data.features[key])) {
                return;
            }

            // Placeholder entry shown until the user picks a real feature
            const placeholder = document.createElement('option');
            Object.assign(placeholder, {
                value: '',
                textContent: 'Select a feature',
                disabled: true,
                selected: true
            });
            featureSelect.appendChild(placeholder);

            // One <option> per feature of the chosen key
            for (const feature of data.features[key]) {
                const entry = document.createElement('option');
                entry.value = feature;
                entry.textContent = `Feature ${feature}`;
                featureSelect.appendChild(entry);
            }

            // Hide any previously rendered content and show the prompt instead
            contentDisplay.classList.add('hidden');
            loadingIndicator.classList.remove('hidden');
            loadingIndicator.innerHTML = '<p>Select a feature to view token activations</p>';

            featureInfo.textContent = 'No feature selected';

            // Both side panels fall back to their prompt text
            const idleNotice = (what) => `<div class="loading">Select a feature to view ${what}</div>`;
            logitLensPanel.innerHTML = idleNotice('logit lens');
            histogramPanel.innerHTML = idleNotice('histogram');

            // Navigation makes no sense until a feature is selected
            for (const button of [prevFeatureBtn, nextFeatureBtn]) {
                button.disabled = true;
            }
        }
        
        /**
         * Render the feature currently chosen in the feature dropdown.
         *
         * Does nothing unless a key and a feature are selected and
         * data.content holds an entry for that pair. Otherwise it shows
         * loading placeholders and, after a short delay, fills the token
         * activation area, the logit lens panel and the histogram panel with
         * the HTML strings stored in data.content[key][feature].
         *
         * The deferred render is dropped when the selection has changed in
         * the meantime, so a slow callback can never paint content for a key
         * or feature the user has already left.
         */
        function handleFeatureChange() {
            const selectedKey = keySelect.value;
            const selectedFeature = featureSelect.value;

            const hasContent = selectedKey && selectedFeature &&
                data.content[selectedKey] && data.content[selectedKey][selectedFeature];
            if (!hasContent) {
                return;
            }

            // Show loading indicators
            loadingIndicator.classList.remove('hidden');
            loadingIndicator.innerHTML = '<p>Loading content...</p>';
            contentDisplay.classList.add('hidden');

            logitLensPanel.innerHTML = '<div class="loading">Loading logit lens...</div>';
            histogramPanel.innerHTML = '<div class="loading">Loading histogram...</div>';

            // Simulated loading delay (remove in production)
            setTimeout(() => {
                // Stale render: the key or feature changed while we were waiting.
                // Whoever changed it is responsible for the UI state now.
                if (keySelect.value !== selectedKey || featureSelect.value !== selectedFeature) {
                    return;
                }

                // Update feature info
                featureInfo.textContent = `${selectedKey} - Feature ${selectedFeature}`;

                // Pre-rendered HTML fragments for this feature
                const featureContent = data.content[selectedKey][selectedFeature];

                // Display token activations
                contentDisplay.innerHTML = featureContent.tokenActivations;
                contentDisplay.classList.remove('hidden');
                loadingIndicator.classList.add('hidden');

                // Display logit lens and histogram
                logitLensPanel.innerHTML = featureContent.logitLens;
                histogramPanel.innerHTML = featureContent.histogram;

                // Prev/next availability depends on the feature now shown
                updateNavigationButtons();

                // Balance panel heights once the new content is in place
                equalizeHeights();
            }, 100);
        }
        
        /**
         * Select and render the feature that precedes the current one in
         * data.features for the selected key. Does nothing when the current
         * feature is the first one, is not in the list, or the key has no
         * feature list.
         */
        function showPreviousFeature() {
            const features = data.features[keySelect.value] || [];

            // A <select> always reports its value as a string, while the feature
            // list may hold numbers; compare as strings so the lookup cannot miss
            // on type alone.
            const currentIndex = features.map(String).indexOf(String(featureSelect.value));
            if (currentIndex > 0) {
                featureSelect.value = String(features[currentIndex - 1]);
                handleFeatureChange();
            }
        }
        
        /**
         * Select and render the feature that follows the current one in
         * data.features for the selected key. Does nothing when the current
         * feature is the last one or the key has no feature list. When the
         * current value is not in the list, the first feature is selected.
         */
        function showNextFeature() {
            const features = data.features[keySelect.value] || [];

            // A <select> always reports its value as a string, while the feature
            // list may hold numbers; compare as strings so the lookup cannot miss
            // on type alone.
            const currentIndex = features.map(String).indexOf(String(featureSelect.value));
            if (currentIndex < features.length - 1) {
                featureSelect.value = String(features[currentIndex + 1]);
                handleFeatureChange();
            }
        }
        
        /**
         * Enable or disable the previous/next buttons according to where the
         * current feature sits in data.features for the selected key:
         * "previous" is disabled at (or before) the first entry, "next" at
         * (or after) the last one.
         */
        function updateNavigationButtons() {
            const features = data.features[keySelect.value] || [];

            // A <select> always reports its value as a string, while the feature
            // list may hold numbers; compare as strings so the lookup cannot miss
            // on type alone.
            const currentIndex = features.map(String).indexOf(String(featureSelect.value));

            prevFeatureBtn.disabled = currentIndex <= 0;
            nextFeatureBtn.disabled = currentIndex >= features.length - 1;
        }
        
        /**
         * Give the logit lens panel and the histogram panel the same height.
         *
         * Both panels are first released to their natural height. The
         * histogram image, if present, is capped at 80% of the logit lens
         * panel's scroll height. After a 50 ms pause the larger of the two
         * scroll heights is applied to both panels.
         * Returns without doing anything when either panel is missing.
         */
        function equalizeHeights() {
            const [lensBox, histBox] = ['logitLensPanel', 'histogramPanel']
                .map((id) => document.getElementById(id));

            if (!lensBox || !histBox) {
                return;
            }

            // Drop any previously forced height so natural sizes can be measured
            for (const box of [lensBox, histBox]) {
                box.style.height = 'auto';
            }

            const lensHeight = lensBox.scrollHeight;

            // Cap the histogram image at 80% of the logit lens height,
            // leaving the remainder for padding
            const plot = histBox.querySelector('.histogram-container img');
            if (plot) {
                plot.style.maxHeight = `${lensHeight * 0.8}px`;
            }

            // Measure again after a short pause, then pin both panels to the taller one
            setTimeout(() => {
                const tallest = `${Math.max(lensBox.scrollHeight, histBox.scrollHeight)}px`;
                lensBox.style.height = tallest;
                histBox.style.height = tallest;
            }, 50);
        }
        
        // Re-balance the two side panels on every window resize: equalizeHeights
        // writes fixed pixel heights computed from the panels' current scrollHeight,
        // so they have to be recomputed when the window size changes.
        window.addEventListener('resize', equalizeHeights);
        
        // Run initViewer once the document's DOM has been parsed (DOMContentLoaded)
        window.addEventListener('DOMContentLoaded', initViewer);
    </script>
</body>
</html>
    """
    
    # Insert data into the template
    json_data = json.dumps(data)
    html_content = html_template.replace('{DATA_PLACEHOLDER}', json_data)
    
    # Create the output directory if it doesn't exist
    os.makedirs(os.path.dirname(output_path) if os.path.dirname(output_path) else '.', exist_ok=True)
    
    # Write the HTML file
    with open(output_path, 'w') as f:
        f.write(html_content)
    
    print(f"Enhanced token activation viewer created at: {output_path}")
    
    # If running in notebook, provide a clickable link
    try:
        from IPython.display import HTML, display
        display(HTML(f'<a href="{output_path}" target="_blank">Open Enhanced Token Activation Viewer</a>'))
    except:
        pass
    
    return output_path

In [9]:
# Build the standalone HTML viewer from the feature activations and tokens
# collected earlier in the notebook. No output path is passed here; the recorded
# output shows the file being written to 'enhanced_token_viewer.html'. The
# function returns that path, which becomes this cell's displayed value.
generate_enhanced_viewer(
    model=model,
    suite=suite,
    keys=keys,
    features_to_save=features_to_save,
    saved_feature_act_list=saved_feature_act_list,
    saved_token_list=saved_token_list,
    tokenizer=tokenizer
)

Enhanced token activation viewer created at: enhanced_token_viewer.html


'enhanced_token_viewer.html'

In [10]:
from IPython.display import display, HTML
# Inline preview: for each saved feature of the first key, pick example positions
# with get_feature_indices, fetch the matching token contexts and render them as
# HTML directly in the notebook.
#
# NOTE(review): the recorded run of this cell stops with
# `TypeError: new(): invalid data type 'str'` right after printing the first key.
# The raising line is not part of the saved output; presumably one of the helpers
# called below is handed strings where it tries to build a tensor. Confirm against
# the helper definitions before relying on this cell.
#
# Earlier ways of choosing which features to show (kept for reference):
# min_idx = 0
# max_idx = min_idx + 10
# features = [i for i in range(min_idx, max_idx)]
# features = top_i.tolist()
num_feature_datapoints = 10 # number of examples requested per feature (passed as k below)
for key in keys:
    print(f"Key: {key}")
    features_for_this_key = features_to_save[key]
    for feature in features_for_this_key:
        # NOTE(review): saved_feature_act_list[key] was already restricted to the
        # columns listed in features_to_save[key] when it was collected, so `feature`
        # acts here as a position inside that saved subset. That equals the feature
        # id only while features_to_save[key] is 0..n-1; confirm before saving
        # arbitrary feature ids.
        feature_activations = saved_feature_act_list[key][..., feature]
        # setting="max" presumably selects the k highest-activation positions; TODO confirm
        d_idx, seq_idx = get_feature_indices(feature_activations, k=num_feature_datapoints, setting="max")
        # Older call form of get_feature_indices, kept for reference:
        # uniform_indices = get_feature_indices(feature, feature_activations, k=num_feature_datapoints, setting="max")
        text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(d_idx, seq_idx, feature_activations, saved_token_list, tokenizer)
        # Only token_list and partial_activations are used for the rendering below
        html = tokens_and_activations_to_html(token_list, partial_activations, tokenizer)
        display(HTML(html))
    # Only the first key is previewed; remove this break to go through every key
    break

Key: attn_0


TypeError: new(): invalid data type 'str'