In [1]:
import sys
sys.path.append('..')
import torch
from transformer_lens import HookedTransformer
from trainers.scae import SCAESuite
from interp_utils import prepare_streaming_dataset

# Model settings
model_name = "roneneldan/TinyStories-33M"
device = 'cuda' if torch.cuda.is_available() else 'cpu'
expansion = 2
k = 30

# Load the language model once and reuse its tokenizer everywhere below
model = HookedTransformer.from_pretrained(model_name, device=device)
tokenizer = model.tokenizer

# Dataset settings (TinyStories)
dataset_name = "roneneldan/TinyStories"
get_first_n = 10  # get first n features of each key

# Debug runs use a smaller batch and far fewer datapoints; the sequence length
# is the same in both modes.
debug = True
max_length = 64
batch_size, num_datapoints = (16, 100) if debug else (64, 5_000)

  from .autonotebook import tqdm as notebook_tqdm


Loaded pretrained model roneneldan/TinyStories-33M into HookedTransformer


In [2]:
import pickle
c = 10  # 10 30 100 300
mse = False

# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only point this at connection files from a trusted source.
connections_name = f"../tinystories_connections/top_connections_{c}.pkl"
with open(connections_name, "rb") as f:
    connections = pickle.load(f)

# The checkpoint repo name encodes `c` and which loss variant is selected by `mse`.
scae_repo_id = (
    f"jacobcd52/TinyStories-33M_scae_{c}_mse"
    if mse
    else f"jacobcd52/TinyStories-33M_scae_{c}_ce"
)
suite = SCAESuite.from_pretrained(repo_id=scae_repo_id, model=model, device=device)

# Last path component of the repo id, used later to name output directories
model_save_name = scae_repo_id.rsplit("/", 1)[-1]

In [3]:
# Build one float tensor of connection values per (key, key2) pair, shaped like the
# matching tensor in `connections`. Fresh tensors are allocated because
# connections.copy() would only make a shallow copy.
connection_vals = {
    key: {
        key2: torch.zeros(index_tensor.shape).to(device)
        for key2, index_tensor in targets.items()
    }
    for key, targets in connections.items()
}
tensor_shape = connection_vals["mlp_0"]["attn_0"].shape
# A single random tensor is reused as the placeholder value source for every pair.
dummy_connection_vals = torch.rand(tensor_shape).to(device)
for key, targets in connections.items():
    for key2, index_tensor in targets.items():
        nz_connections = index_tensor.nonzero()
        rows, cols = nz_connections[:, 0], nz_connections[:, 1]
        # Positions holding a nonzero entry in `connections` get a random placeholder value
        connection_vals[key][key2][rows, cols] = dummy_connection_vals[rows, cols]
connections.keys(), connections["mlp_0"].keys(), connections["mlp_0"]["attn_0"][3], connection_vals["mlp_0"]["attn_0"][3]

(dict_keys(['attn_0', 'mlp_0', 'attn_1', 'mlp_1', 'attn_2', 'mlp_2', 'attn_3', 'mlp_3']),
 dict_keys(['attn_0']),
 tensor([  98, 1458, 2317,  274,  650, 2883,  336, 2147,  438, 1844],
        device='cuda:0'),
 tensor([0.4143, 0.5092, 0.3140, 0.4781, 0.0855, 0.6693, 0.8219, 0.4123, 0.4081,
         0.4898], device='cuda:0'))

In [4]:
dataset = prepare_streaming_dataset(
    tokenizer,
    dataset_name,
    max_length,
    batch_size,
    num_datapoints=num_datapoints,
)

n_layers = suite.model.cfg.n_layers

# Create keys in desired order: all attention modules first, then all MLP modules
attn_keys = [f"attn_{layer}" for layer in range(n_layers)]
mlp_keys = [f"mlp_{layer}" for layer in range(n_layers)]
keys = attn_keys + mlp_keys

# Just test w/ the first `get_first_n` features of each key
features_to_save = {key: list(range(get_first_n)) for key in keys}

# Per-key lists of activation batches, plus the token batches they came from
saved_feature_act_list = {key: [] for key in keys}
saved_token_list = []

In [5]:
from tqdm import tqdm
# Run the model over the dataset and keep, per key, the activations of the selected
# features together with the tokens of every batch.
#
# Batches are accumulated in cell-local lists: `saved_feature_act_list` and
# `saved_token_list` are rebound to concatenated tensors at the end, so appending to
# those names directly would fail with AttributeError the second time this cell runs.
feature_act_batches = {key: [] for key in keys}
token_batches = []

# Round the batch count up for the progress bar; the recorded run yields 7 batches for
# 100 datapoints at batch size 16, which floor division under-reports as 6.
num_batches = (num_datapoints + batch_size - 1) // batch_size

with torch.no_grad():
    for batch in tqdm(dataset, total=num_batches):
        _, cache = model.run_with_cache(batch)
        _, features = suite.forward_pruned(cache, return_features=True)
        for key in keys:
            # Keep only the requested feature columns and move them off the device
            feature_act_batches[key].append(features[key][..., features_to_save[key]].cpu())
        token_batches.append(batch.cpu())

# Concatenate along the first (batch) dimension
saved_feature_act_list = {key: torch.cat(feature_act_batches[key], dim=0) for key in keys}
saved_token_list = torch.cat(token_batches, dim=0)

  0%|          | 0/6 [00:00<?, ?it/s]

7it [00:10,  1.47s/it]                       


In [6]:
import numpy as np
from IPython.display import display, HTML
from einops import rearrange

def make_colorbar(min_value, max_value, white = 255, red_blue_ness = 250, positive_threshold = 0.01, negative_threshold = 0.01):
    """Build the HTML legend strip for the activation colour scale.

    The strip always contains a white "0.0" swatch. Four red swatches (strongest
    first) are placed before it when `min_value` is below `-negative_threshold`,
    and four blue swatches (weakest first) after it when `max_value` is above
    `positive_threshold`. Swatch labels are the corresponding fraction of
    `min_value` / `max_value`, rounded to one decimal.

    Returns:
        A string of concatenated `<span>` elements.
    """
    num_colors = 4
    swatches = []

    # Negative side: walk from the full minimum down towards zero
    if min_value < -negative_threshold:
        for step in reversed(range(1, num_colors + 1)):
            ratio = step / (num_colors)
            label = round((min_value*ratio),1)
            shade = int(red_blue_ness-(red_blue_ness*ratio))
            # Dark swatches need light text to stay readable
            text_color = "255,255,255" if ratio > 0.5 else "0,0,0"
            swatches.append(
                f'<span style="background-color:rgba(255, {shade},{shade},1); color:rgb({text_color})">&nbsp{label}&nbsp</span>'
            )

    # Zero swatch
    swatches.append(
        f'<span style="background-color:rgba({white},{white},{white},1);color:rgb(0,0,0)">&nbsp0.0&nbsp</span>'
    )

    # Positive side: walk from just above zero up to the full maximum
    if max_value > positive_threshold:
        for step in range(1, num_colors + 1):
            ratio = step / (num_colors)
            label = round((max_value*ratio),1)
            shade = int(red_blue_ness-(red_blue_ness*ratio))
            text_color = "255,255,255" if ratio > 0.5 else "0,0,0"
            swatches.append(
                f'<span style="background-color:rgba({shade},{shade},255,1);color:rgb({text_color})">&nbsp{label}&nbsp</span>'
            )

    return "".join(swatches)

def value_to_color(activation, max_value, min_value, white = 255, red_blue_ness = 250, positive_threshold = 0.01, negative_threshold = 0.01):
    """Map one activation to a (text colour, background colour) pair.

    Values above `positive_threshold` are shaded blue in proportion to
    `activation / max_value`; values below `-negative_threshold` are shaded red in
    proportion to `activation / min_value`; everything in between is white.

    Returns:
        Tuple `(text_color, background_color)` where `text_color` is an "r,g,b"
        string and `background_color` is a CSS `rgba(...)` string.
    """
    def _channel(ratio):
        # The non-dominant colour channels fade towards 0 as the ratio approaches 1
        return int(red_blue_ness-(red_blue_ness*ratio))

    def _text_for(ratio):
        # Light text on the darker half of the scale, dark text otherwise
        return "0,0,0" if ratio <= 0.5 else "255,255,255"

    if activation > positive_threshold:
        strength = activation/max_value
        level = _channel(strength)
        return _text_for(strength), f'rgba({level},{level},255,1)'

    if activation < -negative_threshold:
        strength = activation/min_value
        level = _channel(strength)
        return _text_for(strength), f'rgba(255, {level},{level},1)'

    return "0,0,0", f'rgba({white},{white},{white},1)'

def convert_token_array_to_list(array):
    """Normalise tokens or activations into a list of sequences.

    Args:
        array: a 1-D or 2-D torch tensor, or a Python list. A flat list of numbers
            is treated as a single sequence; any other list is returned unchanged
            (for example a list of lists, or a list of per-sequence tensors).

    Returns:
        A list with one entry per sequence. Non-tensor, non-list inputs are
        returned as-is.

    Raises:
        NotImplementedError: for tensors with a rank other than 1 or 2.
    """
    if isinstance(array, torch.Tensor):
        if array.dim() == 1:
            return [array.tolist()]
        if array.dim() == 2:
            return array.tolist()
        raise NotImplementedError("tokens must be 1 or 2 dimensional")

    # Wrap a flat sequence so callers can always iterate over sequences. This helper
    # is used for activations as well as token ids, so floats must count as scalars
    # too; the emptiness check avoids an IndexError on [].
    if isinstance(array, list) and array and isinstance(array[0], (int, float)):
        return [array]
    return array

def tokens_and_activations_to_html(toks, activations, tokenizer, logit_diffs=None, model_type="causal", text_above_each_act=None):
    """Render token sequences as HTML with every token shaded by its activation.

    Args:
        toks: token ids, as a 1-D/2-D tensor or (nested) list; normalised with
            convert_token_array_to_list.
        activations: per-token values laid out like `toks`.
        tokenizer: object whose `decode` turns a single token id into text.
        logit_diffs: optional. With model_type "reward_model", one value per
            sequence (read with `.item()`); otherwise one value per token, drawn
            as a coloured bar under that token.
        model_type: "reward_model" selects the per-sequence reward display.
        text_above_each_act: optional per-sequence labels, inserted as-is
            (not escaped).

    Returns:
        A single HTML string: an opening dark-themed `<body>` tag, the colour
        legend(s), then one block per sequence.
    """
    from html import escape  # stdlib; imported locally so the function is self-contained

    # text_spacing = "0.07em"
    text_spacing = "0.00em"
    # Per-token logit-diff bars are drawn for every model type except reward models
    show_logit_bars = logit_diffs is not None and model_type != "reward_model"

    toks = convert_token_array_to_list(toks)
    activations = convert_token_array_to_list(activations)
    # Escape first so '<', '>' and '&' coming from the data are shown as text instead of
    # being parsed as markup; the display substitutions below then add raw '&nbsp'.
    toks = [[escape(tokenizer.decode(t), quote=False).replace('Ġ', '&nbsp').replace('\n', '\\n') for t in tok] for tok in toks]
    highlighted_text = []
    # Make background black
    highlighted_text.append("""
<body style="background-color: black; color: white;">
""")
    # The colour scale is shared across all sequences
    max_value = max([max(activ) for activ in activations])
    min_value = min([min(activ) for activ in activations])
    if show_logit_bars:
        logit_max_value = max([max(activ) for activ in logit_diffs])
        logit_min_value = min([min(activ) for activ in logit_diffs])

    # Add color bar
    highlighted_text.append("Token Activations: " + make_colorbar(min_value, max_value))
    if show_logit_bars:
        highlighted_text.append('<div style="margin-top: 0.1em;"></div>')
        highlighted_text.append("Logit Diff: " + make_colorbar(logit_min_value, logit_max_value))

    highlighted_text.append('<div style="margin-top: 0.5em;"></div>')
    for seq_ind, (act, tok) in enumerate(zip(activations, toks)):
        if text_above_each_act is not None:
            highlighted_text.append(f'<span>{text_above_each_act[seq_ind]}</span>')
        for act_ind, (a, t) in enumerate(zip(act, tok)):
            if show_logit_bars:
                # Wrapper keeps the token and its logit-diff bar stacked together
                highlighted_text.append('<div style="display: inline-block;">')
            text_color, background_color = value_to_color(a, max_value, min_value)
            highlighted_text.append(f'<span style="background-color:{background_color};margin-right: {text_spacing}; color:rgb({text_color})">{t.replace(" ", "&nbsp")}</span>')
            if show_logit_bars:
                logit_diffs_act = logit_diffs[seq_ind][act_ind]
                _, logit_background_color = value_to_color(logit_diffs_act, logit_max_value, logit_min_value)
                highlighted_text.append(f'<div style="display: block; margin-right: {text_spacing}; height: 10px; background-color:{logit_background_color}; text-align: center;"></div></div>')
        if logit_diffs is not None and model_type == "reward_model":
            reward_change = logit_diffs[seq_ind].item()
            # Rewards are coloured on a fixed -10..10 scale
            text_color, background_color = value_to_color(reward_change, 10, -10)
            highlighted_text.append(f'<br><span>Reward: </span><span style="background-color:{background_color};margin-right: {text_spacing}; color:rgb({text_color})">{reward_change:.2f}</span>')
        highlighted_text.append('<div style="margin-top: 0.2em;"></div>')
    return ''.join(highlighted_text)
def save_token_display(tokens, activations, tokenizer, path, save=True, logit_diffs=None, show=False, model_type="causal"):
    """Render the token/activation HTML and show it in the notebook.

    `path`, `save` and `show` are accepted but not used by the current body: the
    image-export step that would have consumed them is disabled.

    Returns:
        Whatever `IPython.display.display` returns for the rendered HTML object.
    """
    rendered = HTML(
        tokens_and_activations_to_html(tokens, activations, tokenizer, logit_diffs, model_type=model_type)
    )
    return display(rendered)

def get_feature_indices(feature_activations, k=10, setting="max"):
    """Pick positions of one feature's activations to display.

    Args:
        feature_activations: 2-D tensor; the second dimension is used as the
            sequence length when splitting flat indices back into
            (row, position) pairs.
        k: number of positions to return ("max"/random) or number of value bins
            ("uniform").
        setting: "max" returns the k largest activations in descending order.
            "uniform" splits the value range into k equal-width bins and samples
            one position per non-empty bin, highest bin first. Any other value
            returns up to k randomly chosen positions with nonzero activation.

    Returns:
        Tuple `(d_indices, s_indices)` of tensors: row index and position within
        the row for each selected activation.
    """
    _, seq_len = feature_activations.shape
    # Flatten (rows, positions) into one axis; equivalent to rearranging 'b s -> (b s)'
    flat_activations = feature_activations.reshape(-1)

    if setting == "max":
        found_indices = torch.argsort(flat_activations, descending=True)[:k]
    elif setting == "uniform":
        min_value = torch.min(flat_activations).item()
        max_value = torch.max(flat_activations).item()

        # Define the number of bins
        num_bins = k

        # Bin boundaries are a linear interpolation between min and max; they are
        # created on the activations' device so bucketize also works off-CPU.
        bin_boundaries = torch.linspace(
            min_value, max_value, num_bins + 1, device=flat_activations.device
        )

        # Assign each activation to its respective bin
        bins = torch.bucketize(flat_activations, bin_boundaries)

        sampled_indices = []
        for bin_idx in torch.unique(bins):
            # Bin 0 holds the values that do not exceed the lowest boundary (the
            # minimum itself), i.e. the weakest activations; skip it.
            if bin_idx == 0:
                continue
            # Flat indices that fall into the current bin
            bin_indices = torch.nonzero(bins == bin_idx, as_tuple=False).squeeze(dim=1)

            # Randomly sample one index from the bin (NumPy RNG needs host memory)
            picked = np.random.choice(bin_indices.cpu().numpy(), size=1, replace=False)
            sampled_indices.append(int(picked[0]))

        # Convert to a tensor and reverse so the highest bin comes first
        found_indices = torch.tensor(sampled_indices).long().flip(dims=[0])
    else:  # random
        # Shuffle the positions with nonzero activation and keep the first k
        nonzero_indices = torch.nonzero(flat_activations)[:, 0]
        shuffled_indices = nonzero_indices[torch.randperm(nonzero_indices.shape[0])]
        found_indices = shuffled_indices[:k]

    d_indices = found_indices // seq_len
    s_indices = found_indices % seq_len
    return d_indices, s_indices

def get_feature_datapoints(d_idx, seq_pos_idx, all_activations, all_tokens, tokenizer):
    """Collect tokens, text and activations for the selected (row, position) pairs.

    For every pair the full row is gathered, together with the prefix that ends
    at (and includes) the selected position.

    Args:
        d_idx: row indices into `all_tokens` / `all_activations`.
        seq_pos_idx: position within each selected row.
        all_activations: indexable by row; each row supports slicing and `.tolist()`.
        all_tokens: indexable by row; each row supports slicing.
        tokenizer: object whose `decode` turns a token sequence into text.

    Returns:
        Tuple `(text_list, full_text, token_list, full_token_list,
        partial_activations, full_activations)`, each a list with one entry per pair.
    """
    prefix_texts, whole_texts = [], []
    prefix_tokens, whole_tokens = [], []
    prefix_acts, whole_acts = [], []

    for row, pos in zip(d_idx, seq_pos_idx):
        row, stop = int(row), int(pos) + 1
        row_tokens = all_tokens[row]
        row_acts = all_activations[row]

        # Whole row first, then the prefix up to the selected position
        whole_texts.append(tokenizer.decode(row_tokens))
        whole_acts.append(row_acts.tolist())
        whole_tokens.append(row_tokens)

        head_tokens = row_tokens[:stop]
        prefix_acts.append(row_acts[:stop].tolist())
        prefix_texts.append(tokenizer.decode(head_tokens))
        prefix_tokens.append(head_tokens)

    return prefix_texts, whole_texts, prefix_tokens, whole_tokens, prefix_acts, whole_acts


In [7]:
from IPython.display import display, HTML
import torch

def create_logit_lens_html(top_ind, top_val, bot_ind, bot_val, tokenizer, k=10):
    """
    Create an HTML table of top and bottom tokens with their values.

    Parameters:
    - top_ind: indices of top tokens
    - top_val: values of top tokens (each element must support `.item()`)
    - bot_ind: indices of bottom tokens
    - bot_val: values of bottom tokens (each element must support `.item()`)
    - tokenizer: tokenizer to decode indices
    - k: maximum number of rows to display (default 10); fewer rows are emitted
      when fewer tokens or values are supplied

    Returns:
    - The HTML as a string (style block, title and table).
    """
    from html import escape  # stdlib; imported locally so the function is self-contained

    def _label(tok):
        # Escape markup characters from the vocabulary, then make whitespace visible
        return escape(tokenizer.decode(tok), quote=False).replace(" ", "_").replace("\n", "\\newline")

    # Decode tokens
    top_text = [_label(tok) for tok in top_ind[:k]]
    bot_text = [_label(tok) for tok in bot_ind[:k]]

    # Plain (non-f) string, so CSS braces are written singly; doubled braces would be
    # emitted literally and invalidate the style rules.
    html_template = """
    <style>
        .token-table {
            font-family: Arial, sans-serif;
            border-collapse: collapse;
            width: 100%;
            margin-top: 10px;
        }
        .token-table td {
            padding: 8px;
            border-bottom: 1px solid #ddd;
        }
        .title {
            font-size: 20px;
            font-weight: bold;
            text-align: center;
            margin-bottom: 10px;
        }
    </style>
    
    <div class="title">Logit Lens</div>
    
    <table class="token-table">
        <tr>
            <td><b>Top Token</b></td>
            <td><b>Value</b></td>
            <td><b>Bottom Token</b></td>
            <td><b>Value</b></td>
        </tr>
    """

    # Never index past the shortest input, even when k is larger
    num_rows = min(k, len(top_text), len(bot_text), len(top_val), len(bot_val))

    # Add rows with colored backgrounds applied to spans instead of cells
    for i in range(num_rows):
        html_template += f"""
        <tr>
            <td>{top_text[i]}</td>
            <td><span style="background-color: #0000FF; color: white; padding: 2px 4px; display: inline-block;"><b>{top_val[i].item():.3f}</b></span></td>
            <td>{bot_text[i]}</td>
            <td><span style="background-color: #FF0000; color: white; padding: 2px 4px; display: inline-block;"><b>{bot_val[i].item():.3f}</b></span></td>
        </tr>
        """

    html_template += "</table>"

    return html_template

# Usage example (the function returns an HTML string, so wrap it in HTML to render it):
# display(HTML(create_logit_lens_html(top_ind, top_val, bot_ind, bot_val, tokenizer)))

In [8]:
import matplotlib.pyplot as plt
import numpy as np
import io
import base64
from IPython.display import display, HTML

def display_matplotlib_figure(fig, width=None, height=None):
    """
    Convert a matplotlib figure to an HTML <img> tag with the PNG embedded inline.

    Parameters:
    - fig: figure to render; its `savefig` is called with an in-memory buffer
    - width: optional width (in pixels)
    - height: optional height (in pixels)

    Returns:
    - The `<img .../>` tag as a string, using a base64 `data:` URI as its source.
    """
    # Render the figure to PNG bytes without touching the filesystem
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png', bbox_inches='tight')
    encoded_png = base64.b64encode(png_buffer.getvalue()).decode('utf-8')

    # Only emit a style attribute when at least one dimension was requested
    css_rules = []
    if width is not None:
        css_rules.append(f"width:{width}px;")
    if height is not None:
        css_rules.append(f"height:{height}px;")
    css = "".join(css_rules)
    style_attr = f' style="{css}"' if css else ''

    return f'<img src="data:image/png;base64,{encoded_png}"{style_attr}/>'

# Histogram helper: draws a histogram and returns it as an HTML <img> string
def create_histogram_html(data, bins=30, title="Histogram", width=600, height=400):
    """Draw a histogram of `data` and return it as an HTML <img> string.

    Parameters:
    - data: values passed to `Axes.hist`
    - bins: number of histogram bins
    - title: plot title
    - width, height: pixel size written into the <img> style attribute

    Returns:
    - The HTML string produced by display_matplotlib_figure.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.hist(data, bins=bins, alpha=0.7, color='skyblue', edgecolor='black')
    ax.set(title=title, xlabel='Value', ylabel='Frequency')
    ax.grid(alpha=0.3)

    # Close the figure so the notebook does not also display it on its own
    plt.close(fig)
    return display_matplotlib_figure(fig, width=width, height=height)

# Example with random data (the function returns an HTML string, so wrap it in HTML to render it)
# random_data = np.random.normal(0, 1, 1000)
# display(HTML(create_histogram_html(random_data, title="Normal Distribution")))

In [40]:
import json
import os
import torch

def generate_enhanced_viewer(keys, features_to_save, saved_feature_act_list, saved_token_list, tokenizer, model, suite, connections, connection_vals, output_dir="llm_feature_viewer", model_save_name=""):
    """
    Generate an enhanced HTML viewer with separate data files for each feature.
    
    This creates:
    1. A main index.html viewer file
    2. Separate HTML data files for each feature to avoid loading everything at once
    3. Connection data between features with links that support ctrl+click to open in new tabs
    
    Args:
        keys: List of model keys or identifiers
        features_to_save: Dictionary mapping keys to lists of feature indices
        saved_feature_act_list: Dictionary mapping keys to feature activation tensors
        saved_token_list: List of tokens for each example
        tokenizer: The tokenizer used to convert tokens to text
        model: The model object (for logit lens)
        suite: The suite object containing feature decoders
        connections: Dictionary of module-to-module connections (indices)
        connection_vals: Dictionary of module-to-module connection values
        output_dir: Directory where the viewer files will be saved
        model_save_name: Optional prefix for the output directory
        
    Returns:
        The path to the generated index.html file
    """
    
    # Number of examples to show per feature
    num_feature_datapoints = 10
    
    # Prepare output directory
    if model_save_name:
        output_dir = f"{model_save_name}_{output_dir}"
    
    # Create directory structure
    os.makedirs(output_dir, exist_ok=True)
    data_dir = os.path.join(output_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    
    # Create manifest data
    manifest = {
        "keys": keys,
        "features": {}
    }
    
    # Prepare string versions of features for manifest
    for key in keys:
        manifest["features"][key] = [str(f) for f in features_to_save[key]]
        # Create directory for each key
        key_dir = os.path.join(data_dir, key)
        os.makedirs(key_dir, exist_ok=True)
    # Function to get connections for a specific feature
    def get_feature_connections(source_key, source_feature_idx):
        feature_connections = []
        
        # Check if this key is in connections
        if source_key not in connections:
            return feature_connections
            
        # Check connections to all other modules
        for target_key in connections[source_key]:
            # Get connection indices tensor for this module pair
            connection_tensor = connections[source_key][target_key]
            
            # Get connection values tensor for this module pair
            value_tensor = connection_vals[source_key][target_key][source_feature_idx]
            
            # Ensure the feature index is valid
            if source_feature_idx >= connection_tensor.shape[0]:
                continue
                
            # Get row for this feature's connections
            connection_row = connection_tensor[source_feature_idx]
            
            # Find non-zero connections (where connected)
            # e.g. [0,1,4,6] for idx over the top-c connections (so 0-c indexed)
            non_zero_indices = (connection_row > 0).nonzero()[:, 0]
            
            # Handle various tensor dimensions
            if non_zero_indices.dim() == 0 and non_zero_indices.nelement() > 0:
                # Single non-zero value case
                non_zero_indices = [non_zero_indices.item()]
            elif non_zero_indices.nelement() > 0:
                non_zero_indices = non_zero_indices.tolist()
            else:
                non_zero_indices = []
                
            # For each connected feature, get the connection value
            for nz_idx in non_zero_indices:
                # Get the connection value from value_tensor
                target_feature_idx = int(connection_row[nz_idx].item())
                connection_value = float(value_tensor[nz_idx].item())
                
                # Only add if target feature is in features_to_save
                # if target_feature_idx in features_to_save.get(target_key, []):
                feature_connections.append({
                    "target_key": target_key,
                    "target_feature": target_feature_idx,
                    "value": connection_value
                })
        
        # Sort connections by strength (absolute value) in descending order
        feature_connections.sort(key=lambda x: abs(x["value"]), reverse=True)
        return feature_connections
    
    # Generate HTML for feature connections with URL links
    def generate_connections_html(connections_list):
        if not connections_list:
            return "<div class='no-connections'>No significant connections found</div>"
            
        html = ["<div class='connections-container'>",
                "<h3>Feature Connections</h3>",
                "<table class='connections-table'>",
                "<tr><th>Connected Feature</th><th>Connection Strength</th></tr>"]
                
        for conn in connections_list:
            target_key = conn["target_key"]
            target_feature = conn["target_feature"]
            conn_value = conn["value"]
            
            # Determine CSS class based on connection value
            if conn_value > 0:
                value_class = "positive-connection"
            else:
                value_class = "negative-connection"
                
            # Create link to target feature with URL parameters
            # This supports ctrl+click to open in new tab
            html.append(f"<tr>")
            html.append(f"<td><a href='index.html?key={target_key}&feature={target_feature}' class='feature-link'>{target_key} - Feature {target_feature}</a></td>")
            html.append(f"<td class='{value_class}'>{conn_value:.4f}</td>")
            html.append(f"</tr>")
            
        html.append("</table></div>")
        return "\n".join(html)
    
    # Generate data files for each key-feature pair
    for key in keys:
        features_for_this_key = features_to_save[key]
        
        for feature in features_for_this_key:
            # Cast feature to int if it's a string
            feature_idx = int(feature) if isinstance(feature, str) else feature
            feature_str = str(feature)
            
            # Initialize content container for this feature
            feature_data = {
                "tokenActivations": "",
                "logitLens": "",
                "histogram": "",
                "connections": []
            }
            
            # Get feature activations for this feature
            feature_activations = saved_feature_act_list[key][..., feature_idx]
            
            # Get indices of examples with highest activations
            d_idx, seq_idx = get_feature_indices(feature_activations, k=num_feature_datapoints, setting="max")
            
            # Get data for these examples
            text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(
                d_idx, seq_idx, feature_activations, saved_token_list, tokenizer
            )
            
            # Generate token activations HTML
            token_html = tokens_and_activations_to_html(token_list, partial_activations, tokenizer)
            feature_data["tokenActivations"] = token_html
            
            # Generate logit lens HTML
            try:
                feature_decoder = suite.aes[key].decoder.weight[:, feature_idx]
                unembd = model.W_U
                final_ln = model.ln_final
                logit_lens = final_ln(feature_decoder) @ unembd
                top_val, top_ind = torch.topk(logit_lens, k=10, dim=-1)
                bot_val, bot_ind = torch.topk(logit_lens, k=10, dim=-1, largest=False)
                
                # Create logit lens HTML
                logit_lens_html = create_logit_lens_html(top_ind, top_val, bot_ind, bot_val, tokenizer)
                feature_data["logitLens"] = logit_lens_html
            except Exception as e:
                feature_data["logitLens"] = f"<div class='error-panel'>Logit lens visualization unavailable: {str(e)}</div>"
            
            # Generate histogram HTML
            try:
                nz_feature_act = feature_activations[feature_activations != 0]
                frequency = nz_feature_act.numel() / feature_activations.numel()
                
                # Create histogram HTML
                hist_html = create_histogram_html(nz_feature_act.numpy(), 
                                               title=f"Activation Frequency {frequency*100:.2f}%")
                feature_data["histogram"] = hist_html
            except Exception as e:
                feature_data["histogram"] = f"<div class='error-panel'>Histogram visualization unavailable: {str(e)}</div>"
            
            # Get feature connections
            feature_connections = get_feature_connections(key, feature_idx)
            feature_data["connections"] = feature_connections
            
            # Generate connections HTML
            connections_html = generate_connections_html(feature_connections)
            
            # Create an HTML data file for this feature
            feature_html = f'''<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>Feature Data</title>
    <style>
        /* Connection styles */
        .connections-container {{
            margin-top: 20px;
            padding: 15px;
            background-color: #f8f9fa;
            border-radius: 4px;
            border: 1px solid #e9ecef;
        }}
        
        .connections-container h3 {{
            margin-top: 0;
            margin-bottom: 10px;
            font-size: 18px;
            color: #495057;
        }}
        
        .connections-table {{
            width: 100%;
            border-collapse: collapse;
            font-family: Arial, sans-serif;
        }}
        
        .connections-table th {{
            background-color: #e9ecef;
            padding: 8px;
            text-align: left;
            border-bottom: 2px solid #dee2e6;
        }}
        
        .connections-table td {{
            padding: 8px;
            border-bottom: 1px solid #dee2e6;
        }}
        
        .feature-link {{
            color: #007bff;
            text-decoration: none;
        }}
        
        .feature-link:hover {{
            text-decoration: underline;
        }}
        
        .positive-connection {{
            color: #28a745;
            font-weight: bold;
        }}
        
        .negative-connection {{
            color: #dc3545;
            font-weight: bold;
        }}
        
        .no-connections {{
            font-style: italic;
            color: #6c757d;
            padding: 10px;
            text-align: center;
        }}
    </style>
</head>
<body>
    <script>
        // Feature data as a JavaScript object
        const featureData = {json.dumps({
            "tokenActivations": feature_data["tokenActivations"],
            "logitLens": feature_data["logitLens"],
            "histogram": feature_data["histogram"],
            "connections": connections_html
        })};
        
        // Send data to parent window via postMessage
        window.addEventListener('message', function(event) {{
            if (event.data === 'requestData') {{
                window.parent.postMessage({{
                    type: 'featureData',
                    data: featureData
                }}, '*');
            }}
        }});
    </script>
</body>
</html>'''
            
            # Save feature HTML file
            feature_path = os.path.join(data_dir, key, f"{feature_str}.html")
            with open(feature_path, 'w') as f:
                f.write(feature_html)
    
    # Save manifest as a JavaScript file to avoid CORS issues
    manifest_js = f'''// Feature manifest data
const manifestData = {json.dumps(manifest)};
'''
    
    manifest_path = os.path.join(output_dir, "manifest.js")
    with open(manifest_path, 'w') as f:
        f.write(manifest_js)
    
    # Create main viewer HTML file
    viewer_html = '''<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Enhanced Token Activation Viewer</title>
    <style>
        /* Reset and base styles */
        * {
            box-sizing: border-box;
            margin: 0;
            padding: 0;
        }
        
        body {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            line-height: 1.6;
            color: #333;
            background-color: #f5f5f5;
            padding: 20px;
        }
        
        .container {
            max-width: 1200px;
            margin: 0 auto;
            background-color: #fff;
            border-radius: 8px;
            box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
            overflow: hidden;
        }
        
        /* Header styles */
        .header {
            background-color: #2c3e50;
            color: white;
            padding: 20px;
            text-align: center;
        }
        
        .header h1 {
            margin: 0;
            font-size: 24px;
        }
        
        /* Controls section */
        .controls {
            background-color: #f8f9fa;
            padding: 20px;
            border-bottom: 1px solid #e9ecef;
            display: flex;
            flex-wrap: wrap;
            gap: 20px;
        }
        
        .control-group {
            flex: 1;
            min-width: 200px;
        }
        
        label {
            display: block;
            margin-bottom: 8px;
            font-weight: 600;
            color: #495057;
        }
        
        select {
            width: 100%;
            padding: 10px;
            border: 1px solid #ced4da;
            border-radius: 4px;
            background-color: #fff;
            font-size: 16px;
        }
        
        /* Navigation section */
        .navigation {
            display: flex;
            justify-content: space-between;
            padding: 15px 20px;
            background-color: #f8f9fa;
            border-bottom: 1px solid #e9ecef;
        }
        
        .button {
            padding: 8px 16px;
            background-color: #007bff;
            color: white;
            border: none;
            border-radius: 4px;
            cursor: pointer;
            font-weight: 500;
        }
        
        .button:hover {
            background-color: #0069d9;
        }
        
        .button:disabled {
            background-color: #cccccc;
            cursor: not-allowed;
        }
        
        /* Multi-panel layout */
        .upper-panels-table {
            width: 100%;
            border-collapse: collapse;
            margin: 20px 0;
        }
        
        .panel-cell {
            width: 50%;
            padding: 0 10px;
            vertical-align: top;
        }
        
        .panel {
            border: 1px solid #ddd;
            border-radius: 4px;
            height: 100%;
        }
        
        .panel-header {
            background-color: #f1f1f1;
            padding: 10px;
            font-weight: bold;
            border-bottom: 1px solid #ddd;
        }
        
        .panel-content {
            padding: 15px;
            overflow: auto;
            min-height: 300px;
        }
        
        /* Content section */
        .token-activations-container {
            background-color: black;
            color: white;
            min-height: 400px;
            padding: 20px;
            overflow: auto;
            margin: 0 20px 20px 20px;
            border-radius: 4px;
        }
        
        /* Connections section */
        .connections-section {
            margin: 0 20px 20px 20px;
        }
        
        /* Loading indicator */
        .loading {
            display: flex;
            justify-content: center;
            align-items: center;
            height: 300px;
            font-size: 18px;
            color: #6c757d;
        }
        
        /* Footer section */
        .footer {
            background-color: #f8f9fa;
            text-align: center;
            padding: 15px;
            color: #6c757d;
            border-top: 1px solid #e9ecef;
        }
        
        /* Content styles for token activations */
        .content {
            font-family: monospace;
            line-height: 1.4;
            white-space: pre-wrap;
        }

        /* Error panel */
        .error-panel {
            background-color: #ffe6e6;
            border: 1px solid #ffcccc;
            color: #990000;
            padding: 15px;
            border-radius: 4px;
            text-align: center;
        }
        
        /* Show/hide elements */
        .hidden {
            display: none !important;
        }
        
        /* Token table (for logit lens) */
        .token-table {
            font-family: Arial, sans-serif;
            border-collapse: collapse;
            width: 100%;
            margin-top: 10px;
        }
        
        .token-table td {
            padding: 8px;
            border-bottom: 1px solid #ddd;
        }
        
        .title {
            font-size: 20px;
            font-weight: bold;
            text-align: center;
            margin-bottom: 10px;
        }
        
        /* Fix for histogram images */
        .histogram-container {
            display: flex;
            justify-content: center;
            align-items: center;
        }
        
        .histogram-container img {
            max-width: 100%;
            max-height: 280px;
            width: auto;
            height: auto;
            object-fit: contain;
        }
        
        /* Data iframe for loading feature data */
        #dataFrame {
            display: none;
            width: 0;
            height: 0;
            border: 0;
        }
    </style>
    <!-- Load manifest data -->
    <script src="manifest.js"></script>
</head>
<body>
    <div class="container">
        <div class="header">
            <h1>Enhanced Token Activation Viewer</h1>
        </div>
        
        <div class="controls">
            <div class="control-group">
                <label for="keySelect">Select Model/Key:</label>
                <select id="keySelect">
                    <option value="" disabled selected>Choose a model/key</option>
                    <!-- Options will be populated by JavaScript -->
                </select>
            </div>
            <div class="control-group">
                <label for="featureSelect">Select Feature:</label>
                <select id="featureSelect" disabled>
                    <option value="" disabled selected>Select a model/key first</option>
                    <!-- Options will be populated by JavaScript -->
                </select>
            </div>
        </div>
        
        <div class="navigation">
            <button id="prevFeature" class="button" disabled>Previous Feature</button>
            <div>
                <span id="featureInfo">No feature selected</span>
            </div>
            <button id="nextFeature" class="button" disabled>Next Feature</button>
        </div>
        
        <!-- Upper panels for logit lens and histogram -->
        <table class="upper-panels-table">
            <tr>
                <!-- Logit lens panel -->
                <td class="panel-cell">
                    <div class="panel">
                        <div class="panel-header">Logit Lens</div>
                        <div id="logitLensPanel" class="panel-content">
                            <div class="loading">Select a feature to view logit lens</div>
                        </div>
                    </div>
                </td>
                
                <!-- Histogram panel -->
                <td class="panel-cell">
                    <div class="panel">
                        <div class="panel-header">Activation Histogram</div>
                        <div id="histogramPanel" class="panel-content">
                            <div class="loading">Select a feature to view histogram</div>
                        </div>
                    </div>
                </td>
            </tr>
        </table>
        
        <!-- Token activations section -->
        <div class="token-activations-container">
            <div id="loadingIndicator" class="loading">
                <p>Select a model/key and feature to view token activations</p>
            </div>
            <div id="contentDisplay" class="content hidden"></div>
        </div>
        
        <!-- Connections section -->
        <div class="connections-section">
            <div id="connectionsPanel"></div>
        </div>
        
        <div class="footer">
            <p>Enhanced Token Activation Viewer | Features Explorer</p>
        </div>
    </div>
    
    <!-- Hidden iframe for loading feature data -->
    <iframe id="dataFrame" title="Data Frame"></iframe>
    
    <script>
        // ===== DOM ELEMENTS =====
        // Get references to DOM elements
        const keySelect = document.getElementById('keySelect');
        const featureSelect = document.getElementById('featureSelect');
        const prevFeatureBtn = document.getElementById('prevFeature');
        const nextFeatureBtn = document.getElementById('nextFeature');
        const featureInfo = document.getElementById('featureInfo');
        const loadingIndicator = document.getElementById('loadingIndicator');
        const contentDisplay = document.getElementById('contentDisplay');
        const logitLensPanel = document.getElementById('logitLensPanel');
        const histogramPanel = document.getElementById('histogramPanel');
        const connectionsPanel = document.getElementById('connectionsPanel');
        const dataFrame = document.getElementById('dataFrame');
        
        // Current feature data
        let currentFeatureData = null;
        
        // ===== URL PARAMETER HANDLING =====
        // Parse URL parameters
        function getUrlParams() {
            const params = {};
            const searchParams = new URLSearchParams(window.location.search);
            
            for (const [key, value] of searchParams) {
                params[key] = value;
            }
            
            return params;
        }
        
        // ===== MESSAGE HANDLING =====
        // Handle messages from the data iframe
        window.addEventListener('message', function(event) {
            if (event.data && event.data.type === 'featureData') {
                // Store received feature data
                currentFeatureData = event.data.data;
                
                // Display the data
                displayFeatureData();
            }
        });
        
        // ===== DATA LOADING =====
        // Load feature data using iframe (avoids CORS issues)
        function loadFeatureData(key, feature) {
            return new Promise((resolve, reject) => {
                // Set up a timeout for loading
                const timeout = setTimeout(() => {
                    reject(new Error('Timeout loading feature data'));
                }, 10000); // 10 second timeout
                
                // Handle message from iframe
                const messageHandler = function(event) {
                    if (event.data && event.data.type === 'featureData') {
                        // Clear timeout and remove event listener
                        clearTimeout(timeout);
                        window.removeEventListener('message', messageHandler);
                        
                        // Resolve with the data
                        resolve(event.data.data);
                    }
                };
                
                // Listen for message from iframe
                window.addEventListener('message', messageHandler);
                
                // Set iframe src to load the feature data file
                dataFrame.src = `data/${key}/${feature}.html`;
                
                // Request data after a short delay to ensure iframe is loaded
                setTimeout(() => {
                    dataFrame.contentWindow.postMessage('requestData', '*');
                }, 100);
            });
        }
        
        // ===== INITIALIZATION =====
        // Initialize the viewer
        function initViewer() {
            // Check if manifest data is available
            if (!manifestData || !manifestData.keys || !manifestData.features) {
                showError('Failed to load manifest data. Please refresh the page.');
                return;
            }
            
            // Populate the keys dropdown
            manifestData.keys.forEach(key => {
                const option = document.createElement('option');
                option.value = key;
                option.textContent = key;
                keySelect.appendChild(option);
            });
            
            // Set up event listeners
            keySelect.addEventListener('change', handleKeyChange);
            featureSelect.addEventListener('change', handleFeatureChange);
            prevFeatureBtn.addEventListener('click', showPreviousFeature);
            nextFeatureBtn.addEventListener('click', showNextFeature);
            
            // Check URL parameters for direct feature loading
            const params = getUrlParams();
            const urlKey = params.key;
            const urlFeature = params.feature;
            
            if (urlKey && urlFeature && 
                manifestData.keys.includes(urlKey) && 
                manifestData.features[urlKey] && 
                manifestData.features[urlKey].includes(urlFeature)) {
                
                // Select the key from URL
                keySelect.value = urlKey;
                handleKeyChange();
                
                // Allow a small delay for the feature dropdown to update
                setTimeout(() => {
                    // Select the feature from URL
                    featureSelect.value = urlFeature;
                    handleFeatureChange();
                }, 100);
            }
            // If no URL params, select first key and feature if available
            else if (manifestData.keys.length > 0) {
                // Select first key
                keySelect.value = manifestData.keys[0];
                handleKeyChange();
                
                // Select first feature if available
                if (manifestData.features[manifestData.keys[0]] && manifestData.features[manifestData.keys[0]].length > 0) {
                    setTimeout(() => {
                        featureSelect.value = manifestData.features[manifestData.keys[0]][0];
                        handleFeatureChange();
                    }, 100); // Small delay to ensure feature dropdown is populated
                }
            }
        }
        
        // Show error message
        function showError(message) {
            loadingIndicator.classList.remove('hidden');
            loadingIndicator.innerHTML = `<p class="error-panel">${message}</p>`;
            contentDisplay.classList.add('hidden');
        }
        
        // ===== EVENT HANDLERS =====
        // Handle key selection change
        function handleKeyChange() {
            const selectedKey = keySelect.value;
            
            // Clear and reset feature dropdown
            featureSelect.innerHTML = '';
            featureSelect.disabled = !selectedKey;
            
            if (selectedKey && manifestData.features[selectedKey]) {
                // Add default option
                const defaultOption = document.createElement('option');
                defaultOption.value = '';
                defaultOption.textContent = `Select a feature`;
                defaultOption.disabled = true;
                defaultOption.selected = true;
                featureSelect.appendChild(defaultOption);
                
                // Populate feature dropdown
                manifestData.features[selectedKey].forEach(feature => {
                    const option = document.createElement('option');
                    option.value = feature;
                    option.textContent = `Feature ${feature}`;
                    featureSelect.appendChild(option);
                });
                
                // Clear content display
                contentDisplay.classList.add('hidden');
                loadingIndicator.classList.remove('hidden');
                loadingIndicator.innerHTML = '<p>Select a feature to view token activations</p>';
                
                // Reset feature info
                featureInfo.textContent = 'No feature selected';
                
                // Clear panels
                logitLensPanel.innerHTML = '<div class="loading">Select a feature to view logit lens</div>';
                histogramPanel.innerHTML = '<div class="loading">Select a feature to view histogram</div>';
                connectionsPanel.innerHTML = '';
                
                // Disable navigation buttons
                prevFeatureBtn.disabled = true;
                nextFeatureBtn.disabled = true;
                
                // Reset current feature data
                currentFeatureData = null;
            }
        }
        
        // Handle feature selection change
        async function handleFeatureChange() {
            const selectedKey = keySelect.value;
            const selectedFeature = featureSelect.value;
            
            if (selectedKey && selectedFeature) {
                // Show loading indicators
                loadingIndicator.classList.remove('hidden');
                loadingIndicator.innerHTML = '<p>Loading content...</p>';
                contentDisplay.classList.add('hidden');
                
                logitLensPanel.innerHTML = '<div class="loading">Loading logit lens...</div>';
                histogramPanel.innerHTML = '<div class="loading">Loading histogram...</div>';
                connectionsPanel.innerHTML = '<div class="loading">Loading connections...</div>';
                
                // Update feature info
                featureInfo.textContent = `${selectedKey} - Feature ${selectedFeature}`;
                
                // Update browser URL without reloading the page
                const newUrl = new URL(window.location.href);
                newUrl.searchParams.set('key', selectedKey);
                newUrl.searchParams.set('feature', selectedFeature);
                window.history.pushState({ key: selectedKey, feature: selectedFeature }, '', newUrl.href);
                
                try {
                    // Load feature data
                    currentFeatureData = await loadFeatureData(selectedKey, selectedFeature);
                    
                    // Display the data
                    displayFeatureData();
                    
                    // Update navigation buttons
                    updateNavigationButtons();
                } catch (error) {
                    console.error('Error loading feature data:', error);
                    showError(`Failed to load data for ${selectedKey} - Feature ${selectedFeature}`);
                }
            }
        }
        
        // Display the loaded feature data
        function displayFeatureData() {
            if (currentFeatureData) {
                // Display token activations
                contentDisplay.innerHTML = currentFeatureData.tokenActivations;
                contentDisplay.classList.remove('hidden');
                loadingIndicator.classList.add('hidden');
                
                // Display logit lens
                logitLensPanel.innerHTML = currentFeatureData.logitLens;
                
                // Display histogram
                histogramPanel.innerHTML = currentFeatureData.histogram;
                
                // Display connections
                connectionsPanel.innerHTML = currentFeatureData.connections;
            }
        }
        
        // ===== NAVIGATION CONTROLS =====
        // Show previous feature
        function showPreviousFeature() {
            const selectedKey = keySelect.value;
            const features = manifestData.features[selectedKey];
            const currentFeature = featureSelect.value;
            
            const currentIndex = features.indexOf(currentFeature);
            if (currentIndex > 0) {
                featureSelect.value = features[currentIndex - 1];
                handleFeatureChange();
            }
        }
        
        // Show next feature
        function showNextFeature() {
            const selectedKey = keySelect.value;
            const features = manifestData.features[selectedKey];
            const currentFeature = featureSelect.value;
            
            const currentIndex = features.indexOf(currentFeature);
            if (currentIndex < features.length - 1) {
                featureSelect.value = features[currentIndex + 1];
                handleFeatureChange();
            }
        }
        
        // Update navigation button states
        function updateNavigationButtons() {
            const selectedKey = keySelect.value;
            const features = manifestData.features[selectedKey];
            const currentFeature = featureSelect.value;
            
            const currentIndex = features.indexOf(currentFeature);
            
            prevFeatureBtn.disabled = currentIndex <= 0;
            nextFeatureBtn.disabled = currentIndex >= features.length - 1;
        }
        
        // Handle browser back/forward navigation
        window.addEventListener('popstate', function(event) {
            const params = getUrlParams();
            
            if (params.key && params.feature) {
                // Only update if values actually changed
                if (keySelect.value !== params.key) {
                    keySelect.value = params.key;
                    handleKeyChange();
                    
                    setTimeout(() => {
                        featureSelect.value = params.feature;
                        handleFeatureChange();
                    }, 100);
                } else if (featureSelect.value !== params.feature) {
                    featureSelect.value = params.feature;
                    handleFeatureChange();
                }
            }
        });
        
        // Initialize the viewer when the page loads
        window.addEventListener('DOMContentLoaded', initViewer);
    </script>
</body>
</html>'''
    
    # Save main viewer HTML file
    index_path = os.path.join(output_dir, "index.html")
    with open(index_path, 'w') as f:
        f.write(viewer_html)
    
    print(f"Enhanced token activation viewer created at: {output_dir}")
    print(f"Open {index_path} in your browser to use the viewer")
    
    # If running in notebook, provide a clickable link
    try:
        from IPython.display import HTML, display
        display(HTML(f'<a href="{index_path}" target="_blank">Open Enhanced Token Activation Viewer</a>'))
    except:
        pass
    
    return index_path

In [None]:
# Build the static feature viewer (per-feature data files, manifest.js and
# index.html) for the loaded suite. The call is left as the last expression so
# the returned index.html path is shown as the cell result.
generate_enhanced_viewer(
    model=model,
    suite=suite,
    keys=keys,
    features_to_save=features_to_save,
    saved_feature_act_list=saved_feature_act_list,
    saved_token_list=saved_token_list,
    tokenizer=tokenizer,
    model_save_name=model_save_name,
    connections=connections,
    connection_vals=connection_vals,
)

Enhanced token activation viewer created at: TinyStories-33M_scae_10_ce_llm_feature_viewer
Open TinyStories-33M_scae_10_ce_llm_feature_viewer/index.html in your browser to use the viewer


'TinyStories-33M_scae_10_ce_llm_feature_viewer/index.html'

In [10]:
from IPython.display import display, HTML

# Inline preview: for each selected key, show the top-activating examples and
# the logit lens table of the first few features.
num_feature_datapoints = 10  # how many top-activating examples per feature
logit_lens_top_k = 10        # rows in the top / bottom logit lens table
# The previous check (`feature_idx > 10`, evaluated after rendering) stopped
# after 12 features; the limit is now explicit and keeps that behaviour.
max_features_shown = 12
# Keys to preview. This used to be a `for key in keys` loop that immediately
# overwrote `key` with "mlp_0" and then broke out; list the keys explicitly.
keys_to_show = ["mlp_0"]

# Loop-invariant pieces of the logit lens computation.
W_U = model.W_U
final_ln = model.ln_final

# NOTE: `key` and `feature` are intentionally kept as the loop variable names;
# the following cell inspects their final values.
for key in keys_to_show:
    print(f"Key: {key}")
    features_for_this_key = features_to_save[key]
    for feature_idx, feature in enumerate(features_for_this_key):
        # Activations of this single feature over all saved datapoints.
        feature_activations = saved_feature_act_list[key][..., int(feature)]
        d_idx, seq_idx = get_feature_indices(feature_activations, k=num_feature_datapoints, setting="max")
        text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(
            d_idx, seq_idx, feature_activations, saved_token_list, tokenizer
        )
        html = tokens_and_activations_to_html(token_list, partial_activations, tokenizer)
        display(HTML(html))

        # Logit lens: project the feature's decoder direction through the final
        # layer norm and the unembedding. Values are only displayed, so skip
        # building an autograd graph.
        with torch.no_grad():
            feature_decoder = suite.aes[key].decoder.weight[:, int(feature)]
            logit_lens = final_ln(feature_decoder) @ W_U
            top_val, top_ind = torch.topk(logit_lens, k=logit_lens_top_k, dim=-1)
            bot_val, bot_ind = torch.topk(logit_lens, k=logit_lens_top_k, dim=-1, largest=False)
        logit_lens_html = create_logit_lens_html(top_ind, top_val, bot_ind, bot_val, tokenizer)
        display(HTML(logit_lens_html))

        # Stop after rendering the last allowed feature (checked at the end of
        # the iteration so `feature` stays bound to the last one shown).
        if feature_idx + 1 >= max_features_shown:
            break

Key: mlp_0


0,1,2,3
Top Token,Value,Bottom Token,Value
_suddenly,15.125,nd,-11.130
_we,12.100,Four,-10.588
_you,12.096,_Mostly,-10.065
_Humans,12.069,_kale,-9.926
_come,11.732,_amounts,-9.913
_finally,11.604,_then,-9.489
_exhaustion,11.275,_rate,-9.448
_Enter,11.127,mos,-9.355
_introduce,11.056,_gunshots,-9.026


0,1,2,3
Top Token,Value,Bottom Token,Value
_cat,13.659,ling,-12.546
fly,10.926,ink,-12.477
_Zig,10.849,ering,-12.035
_fly,10.792,_fr,-11.838
jo,10.710,out,-11.718
_Ana,10.554,aming,-11.160
_audition,10.497,et,-11.157
_musician,10.393,er,-11.062
_Rico,10.272,hole,-10.951


0,1,2,3
Top Token,Value,Bottom Token,Value
urry,14.061,_gratification,-13.370
_oily,13.920,_motivation,-11.781
_pink,13.646,ms,-11.226
_fluffy,12.550,_ampl,-10.972
bing,11.556,_dominated,-10.623
_cute,11.520,beh,-9.798
_but,11.216,Things,-9.776
_graceful,11.093,azard,-9.638
_mush,11.011,_securing,-9.638


0,1,2,3
Top Token,Value,Bottom Token,Value
_director,12.904,_his,-16.349
_happily,12.376,_her,-13.949
_Tee,11.914,ged,-11.924
_Oct,11.100,His,-11.461
"_""",10.769,pt,-11.393
_Ma,10.410,Her,-11.353
_Age,10.319,_him,-11.288
Cle,10.177,bite,-11.202
_Whe,10.151,ared,-11.145


0,1,2,3
Top Token,Value,Bottom Token,Value
_upon,18.459,',-11.491
_before,13.482,_worse,-11.345
_early,12.589,sm,-10.563
_atop,12.255,_fluids,-10.428
_maj,12.089,_disliked,-10.245
_yourselves,11.886,_burg,-9.976
_itself,11.586,amp,-9.922
_downward,11.422,ize,-9.906
_unlikely,11.184,rd,-9.889


0,1,2,3
Top Token,Value,Bottom Token,Value
_ground,13.237,_YOU,-10.780
_smoking,12.833,_swear,-10.598
_until,12.609,cong,-10.370
_down,12.304,_titles,-10.058
_out,12.181,_jerk,-10.052
_bumper,12.034,played,-9.826
_ax,11.772,_referred,-9.729
_throughout,11.231,_rejoice,-9.727
_upright,10.877,_comple,-9.598


0,1,2,3
Top Token,Value,Bottom Token,Value
_when,22.717,hetically,-12.465
!,22.313,_improvement,-11.629
.,18.902,_bitterness,-11.622
!.,16.015,_believe,-10.665
"!""",16.003,_coincidence,-10.655
_it,14.196,_sweetness,-10.566
"!"".",13.689,_misunderstanding,-10.459
_whenever,13.659,_Happ,-10.320
?,12.855,_origin,-9.949


0,1,2,3
Top Token,Value,Bottom Token,Value
y,15.621,_blaze,-10.512
pped,14.033,_proposal,-10.221
c,13.492,_setback,-10.003
_announcing,12.819,_fend,-9.921
is,11.922,_Among,-9.441
lessly,11.778,_miracle,-9.271
ler,11.700,_differences,-9.174
fl,11.658,_particular,-9.093
v,11.359,Sony,-9.071


0,1,2,3
Top Token,Value,Bottom Token,Value
_that,13.333,_disgust,-11.818
_little,12.382,abb,-11.142
_place,12.328,Av,-10.045
_well,11.687,keys,-9.939
_luck,11.452,Design,-9.813
_the,11.230,_murm,-9.658
_game,11.106,_const,-9.170
_games,10.935,_unravel,-9.109
_her,10.830,_fails,-9.030


0,1,2,3
Top Token,Value,Bottom Token,Value
_upstream,11.564,upp,-13.161
_downstream,10.777,ep,-11.499
_emotional,10.689,_Play,-11.450
_determined,10.339,ci,-11.397
_fertilizer,10.181,room,-11.285
_American,9.923,_fists,-11.177
tree,9.862,_play,-10.445
_past,9.818,AY,-10.324
"!'""",9.759,_doors,-10.248


0,1,2,3
Top Token,Value,Bottom Token,Value
rm,11.075,_grips,-13.360
_orphan,10.754,Dis,-11.128
aze,10.278,ner,-10.726
_Circus,9.826,sed,-10.308
_disapp,9.709,_deserve,-10.269
_Broken,9.681,_cores,-10.212
_result,9.516,_thirteen,-10.089
_from,9.487,aying,-10.069
_combo,9.432,d,-10.025


0,1,2,3
Top Token,Value,Bottom Token,Value
_work,13.825,_sadd,-13.262
_soak,13.709,imp,-10.887
_dry,13.502,die,-10.859
_action,12.719,_rushes,-10.473
_places,11.298,_receiving,-10.078
_inspire,11.263,_anew,-9.744
_haul,10.870,tw,-9.737
_drown,10.667,ents,-9.576
_lure,10.626,_Homer,-9.501


In [11]:
# Show the key and feature the loop in the previous cell stopped on; both are
# loop variables left behind in the notebook namespace, so this cell only
# works after that cell has been run.
key, feature

('mlp_0', '11')