# SAE Feature Interpretability Multitool

This notebook provides tools for interpreting features extracted by Sparse Autoencoders (SAEs) trained on the Gemma-2B large language model. Specifically, it:

*   Loads weights from [Joseph Bloom's SAEs trained on Gemma-2B at layers 0, 6, 10, and 12](https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs).
*   Constructs a "feature vector" from selected SAE feature weights (encoder or decoder).
*   Uses this feature vector for two interpretation tasks:
    1. **Generating "Definition Trees"**: Constructs trees for "ghost tokens" derived from the feature vector.
    2. **Producing Token Lists**: Generates lists of tokens based on cosine similarity to the feature vector and the token embedding centroid.

**Setup:** For the first cell to run, you'll need a HuggingFace account and access token (free). Get yours at: https://huggingface.co/settings/tokens

Add your access token to Colab as a "secret":
1. Click the key icon in the left sidebar.
2. Click **+ Add new secret** and name it `HF_READ_TOKEN`.
3. Run the first cell to log into HuggingFace.


In [2]:
# @title HuggingFace login
# Log into HuggingFace with your access token stored in Colab secrets
import huggingface_hub
from google.colab import userdata

huggingface_hub.login(userdata.get('HF_READ_TOKEN'))

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [3]:
# @title Preparation: Install dependencies, load model, tokeniser, embeddings, and SAE weights; calculate PCA components

# Install necessary libraries quietly
!pip install -q nnsight accelerate datasets tqdm

# Import necessary libraries
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
import torch
from huggingface_hub import hf_hub_download
import os
import sys
import io
from safetensors import safe_open
import requests
from tqdm.auto import tqdm
from sklearn.decomposition import PCA
import torch.nn.functional as F

# Initialize global variables
global model, tokenizer, token_embeddings, sae_weights

# Suppress Hugging Face token warning in Colab
def filter_hf_token_warning(message, category, filename, lineno, file=None, line=None):
    """Replacement for ``warnings.showwarning`` that hides one noisy warning.

    The Colab "HF_TOKEN secret does not exist" UserWarning is dropped; every
    other warning is formatted with ``warnings.formatwarning`` and written to
    ``file`` (``sys.stderr`` when ``file`` is None). The previous version
    returned without writing anything, which silently discarded *all*
    warnings routed through this hook.

    Args:
        message: Warning instance or message text.
        category: Warning class.
        filename: File the warning was raised from.
        lineno: Line number the warning was raised from.
        file: Optional writable stream; defaults to ``sys.stderr``.
        line: Optional source line to include in the formatted output.

    Returns:
        None when the warning is suppressed, True when it was passed through
        (the ``warnings`` machinery ignores the return value).
    """
    if issubclass(category, UserWarning) and "The secret `HF_TOKEN` does not exist in your Colab secrets" in str(message):
        return None

    stream = sys.stderr if file is None else file
    if stream is not None:  # sys.stderr can be None in some embedded interpreters
        try:
            stream.write(warnings.formatwarning(message, category, filename, lineno, line))
        except OSError:
            # Reporting a warning must never raise; drop it if the stream is unusable.
            pass
    return True

warnings.showwarning = filter_hf_token_warning

# Suppress FutureWarning for torch.load regarding weights_only
warnings.filterwarnings("ignore", category=FutureWarning, module="torch")

# Function to load token embeddings
def load_token_embeddings():
    """Fetch the pre-extracted Gemma-2B token embeddings from the HF Hub.

    Reads the module-level ``model`` (for its device and as a fallback).

    Returns:
        The object stored in ``gemma_2b_embeddings.pt``, mapped onto
        ``model.device``. If anything in the download/load path raises, the
        error is printed and the model's own input-embedding weight data is
        returned instead.
    """
    try:
        print("\nDownloading token embeddings...")
        local_path = hf_hub_download(
            filename="gemma_2b_embeddings.pt",
            repo_id="mwatkins1970/gemma-2b-embeddings",
        )
        loaded = torch.load(local_path, map_location=model.device)
        shape_as_list = list(loaded.shape)
        print(f"Token embeddings loaded successfully. Shape: {shape_as_list}")
        return loaded
    except Exception as err:
        # Best-effort: any failure degrades to the embeddings inside the model.
        print(f"Error loading token embeddings: {err}")
        print("Falling back to model's token embeddings.")
        embedding_layer = model.get_input_embeddings()
        return embedding_layer.weight.data

# Function to download and load SAE weights
def load_sae_weights(sae_name):
    """Download (if not already cached) and load one SAE's weight matrices.

    The safetensors file is cached in the current working directory as
    ``sae_<name>.safetensors``. The download is written to a ``.part`` file
    and only renamed to the cache name once it completed successfully, so an
    HTTP error page or an interrupted transfer is never mistaken for a valid
    cached file on the next run.

    Args:
        sae_name: Key into the module-level ``sae_options`` mapping.

    Returns:
        dict with keys ``"encoder"`` (tensor ``W_enc``) and ``"decoder"``
        (tensor ``W_dec``) read from the safetensors file.

    Raises:
        KeyError: ``sae_name`` is not in ``sae_options``.
        requests.HTTPError: the server answered with an error status.
    """
    base_url = 'https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs/resolve/main/'
    url = f'{base_url}{sae_options[sae_name]}?download=true'
    local_filename = f'sae_{sae_name.replace(" ", "_").lower()}.safetensors'

    if not os.path.exists(local_filename):
        partial_filename = f'{local_filename}.part'
        try:
            # Context manager releases the connection even if the transfer fails.
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                # Fail loudly instead of caching a 4xx/5xx body as "weights".
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                with open(partial_filename, 'wb') as f, tqdm(
                    desc=sae_name,
                    total=total_size,
                    unit='iB',
                    unit_scale=True,
                    unit_divisor=1024,
                ) as progress_bar:
                    # 1 MiB chunks: the files are hundreds of MB, 1 KiB chunks
                    # mean hundreds of thousands of Python-level iterations.
                    for data in response.iter_content(chunk_size=1024 * 1024):
                        size = f.write(data)
                        progress_bar.update(size)
            # Atomic promotion of the finished download to the cache name.
            os.replace(partial_filename, local_filename)
        finally:
            # Only still present when the download did not complete.
            if os.path.exists(partial_filename):
                os.remove(partial_filename)

    with safe_open(local_filename, framework="pt") as f:
        return {
            "encoder": f.get_tensor("W_enc"),
            "decoder": f.get_tensor("W_dec")
        }

# Function to calculate the first PCA component (PCA1) from embeddings
def perform_pca(embeddings):
    """Return the unit-length first principal direction of ``embeddings``.

    The tensor is detached, moved to CPU and converted to NumPy for
    scikit-learn's PCA; the resulting component is turned back into a tensor
    with the dtype/device from the module-level ``config`` and L2-normalised.
    """
    matrix = embeddings.detach().cpu().numpy()
    fitted = PCA(n_components=1)
    fitted.fit(matrix)
    leading_component = fitted.components_[0]
    direction = torch.tensor(leading_component, dtype=config.DTYPE, device=config.DEVICE)
    return F.normalize(direction, p=2, dim=0)

# Initialize configuration with data type and device settings
class Config:
    """Mutable bag of settings shared by the notebook cells."""

    def __init__(self):
        # Tensor placement / precision used when building derived tensors.
        if torch.cuda.is_available():
            self.DEVICE = torch.device('cuda')
        else:
            self.DEVICE = torch.device('cpu')
        self.DTYPE = torch.float32

        # Definition-tree settings.
        self.SUB_TOKEN_ID = 23070   # token id whose embedding is substituted
        self.TOPK = 5               # candidates requested per expansion step
        self.CUTOFF = 0.00001       # cumulative-probability pruning threshold
        self.LOG_BASE = 10          # log base for edge-width scaling
        self.running = True         # cleared to stop tree generation


config = Config()

# Load the model with explicit configuration
# (float16 weights, automatic device placement via device_map="auto").
# NOTE(review): trust_remote_code=True permits executing code shipped with the
# checkpoint repo; presumably not required for google/gemma-2b -- confirm and drop.
print("\n\nLoading Gemma-2B model...")
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True
)
# Set the hidden activation function explicitly
# NOTE(review): this is assigned after from_pretrained() has already built the
# model, so it may only change the config object and not the instantiated
# layers -- confirm whether it has any effect beyond documenting intent.
model.config.hidden_activation = "gelu_pytorch_tanh"  # value the loader reports it falls back to by default

# Inference only: switch off training-mode behaviour.
model.eval()
print("Gemma-2B model loaded successfully.")

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
print("Tokenizer loaded successfully.")

# Load token embeddings
# NOTE(review): this result is overwritten further down in this cell (see the
# "Precompute PCA1" step) before any visible use, so the download appears to be
# wasted here -- confirm which embedding source is intended.
token_embeddings = load_token_embeddings()

# Define SAE layers and load weights
# Maps the display name (also used as the dropdown option and as the cache
# filename stem) to the file path inside the jbloom/Gemma-2b-Residual-Stream-SAEs repo.
sae_options = {
    "Gemma-2B layer 0": "gemma_2b_blocks.0.hook_resid_post_16384_anthropic/sae_weights.safetensors",
    "Gemma-2B layer 6": "gemma_2b_blocks.6.hook_resid_post_16384_anthropic_fast_lr/sae_weights.safetensors",
    "Gemma-2B layer 10": "gemma_2b_blocks.10.hook_resid_post_16384/sae_weights.safetensors",
    "Gemma-2B layer 12": "gemma_2b_blocks.12.hook_resid_post_16384/sae_weights.safetensors"
}
print("\nLoading SAE weights...")
# name -> {"encoder": W_enc, "decoder": W_dec}
sae_weights = {name: load_sae_weights(name) for name in sae_options.keys()}
print("All SAE weights loaded successfully.")

# Precompute PCA1 for token embeddings
# The PCA input is the model's own input-embedding matrix (moved to
# config.DEVICE); this rebinding replaces the value loaded above.
token_embeddings = model.get_input_embeddings().weight.data.to(config.DEVICE)
print("\nCalculating 1st PCA component for token embeddings...")
# Unit-length first principal direction; read as a global by create_feature_vector.
PCA1_direction = perform_pca(token_embeddings)
print("PCA component calculated successfully.")


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.3/3.5 MB[0m [31m10.9 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━[0m [32m2.2/3.5 MB[0m [31m33.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.5/3.5 MB[0m [31m36.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/471.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m471.6/471.6 kB[0m [31m36.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m81.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━

config.json:   0%|          | 0.00/627 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

Gemma-2B model loaded successfully.


tokenizer_config.json:   0%|          | 0.00/33.6k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

Tokenizer loaded successfully.

Downloading token embeddings...


gemma_2b_embeddings.pt:   0%|          | 0.00/2.10G [00:00<?, ?B/s]

Token embeddings loaded successfully. Shape: [256000, 2048]

Loading SAE weights...


Gemma-2B layer 0:   0%|          | 0.00/256M [00:00<?, ?iB/s]

Gemma-2B layer 6:   0%|          | 0.00/256M [00:00<?, ?iB/s]

Gemma-2B layer 10:   0%|          | 0.00/256M [00:00<?, ?iB/s]

Gemma-2B layer 12:   0%|          | 0.00/256M [00:00<?, ?iB/s]

All SAE weights loaded successfully.

Calculating 1st PCA component for token embeddings...
PCA component calculated successfully.


### To use functionality 1:

1.   select an SAE from the dropdown (layer 0, 6, 10 or 12)
2.   choose a feature number (in the range 0..16383)
3.   choose between encoder and decoder weights for that feature
4.   decide whether to use token centroid offset (recommended)
5.   choose a scaling factor (L2 norm of vector used to construct "ghost token", the default 3.8 being the approximate mean token embedding distance-from-centroid; disregarding the centroid, the mean L2 norm of token embeddings is ~7.9)
6.   choose a cumulative probability cutoff (this controls the point at which a sequence of tokens output as a "typical definition of" the ghost token is terminated, hence smaller values produce richer trees, which take longer to generate)
7.   if you want to nudge the feature direction towards the direction of the first PCA component for the token embeddings, choose an appropriate nonzero value for "PCA 1st component weighting"








In [None]:
# @title functionality 1: generate definition trees for SAE features
import torch
import ipywidgets as widgets
import math
import uuid  # For generating unique image filenames
from graphviz import Digraph
from IPython.display import display, clear_output, HTML
from PIL import Image as PILImage
import os
import requests
import time
from safetensors import safe_open
from tqdm.auto import tqdm
import warnings

# Suppress DecompressionBombWarning
# (create_tree_diagram below opens and builds images of roughly 5040x5000 pixels)
warnings.filterwarnings("ignore", category=PILImage.DecompressionBombWarning)

# Clear GPU cache
torch.cuda.empty_cache()

# NOTE(review): a `global` statement at module level has no effect; presumably
# `explorer` is assigned later in this cell, outside the visible range.
global explorer

# Tree generation functions
def update_token_embedding(model, token_id, new_embedding):
    """Overwrite, in place, the input-embedding row of ``token_id``.

    ``new_embedding`` is first moved to the device of the model's
    input-embedding weight. The previous row is not saved or restored.
    """
    target_device = model.get_input_embeddings().weight.device
    relocated = new_embedding.to(target_device)
    embedding_table = model.get_input_embeddings().weight.data
    embedding_table[token_id] = relocated

def produce_next_token_ids(input_ids, model, topk, sub_token_id):
    """Return the ``topk`` most likely next tokens after ``input_ids``.

    Runs ``model`` without gradient tracking, takes the logits at the last
    position of the first sequence, masks out ``sub_token_id`` (so it can never
    be proposed) and applies a softmax.

    Returns:
        (token_ids, probabilities) -- two tensors of length ``topk``, ordered
        by decreasing probability.
    """
    with torch.no_grad():
        all_logits = model(input_ids).logits
    final_step = all_logits[0, -1, :]
    final_step[sub_token_id] = float('-inf')  # Avoid circular definitions
    distribution = F.softmax(final_step, dim=-1)
    best = torch.topk(distribution, k=topk)
    return best.indices, best.values

def build_def_tree(input_ids, data, base_prompt, model, tokenizer, config, output_widget,
                   depth=0, max_depth=25, cumulative_prob=1.0, output_lines=None):
    """Recursively expand the definition tree rooted at ``data``.

    At every node the current continuation is printed to ``output_widget``
    (a scrolling window of the last 50 lines), the top ``config.TOPK`` next
    tokens are requested, and each one whose cumulative probability stays at or
    above ``config.CUTOFF`` is appended to ``data['children']`` and expanded in
    turn.

    Args:
        input_ids: Token ids of the prompt plus the continuation so far
            (indexed as ``input_ids[0]``).
        data: Node dict whose ``'children'`` list is filled in place.
        base_prompt: Passed through to recursive calls; not otherwise used.
        model, tokenizer: Language model and its tokenizer.
        config: Object providing ``CUTOFF``, ``TOPK``, ``SUB_TOKEN_ID`` and the
            ``running`` stop flag.
        output_widget: Context manager capturing the progress printout.
        depth: Current recursion depth.
        max_depth: Depth at which expansion stops.
        cumulative_prob: Product of token probabilities along the current path.
        output_lines: Shared progress buffer; created on the first call.

    Returns:
        None. Results are written into ``data``.
    """
    if output_lines is None:
        output_lines = []
    if depth >= max_depth or cumulative_prob < config.CUTOFF or not config.running:
        return

    current_prompt = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    # Keep everything after the FIRST " would be " so a generated definition
    # that itself contains that phrase is not cut short.
    _, separator, after_prompt = current_prompt.partition(" would be ")
    current_definition = after_prompt if separator else current_prompt
    current_definition = current_definition.replace("\n", " | ")

    line = f"depth {depth:<2}: {current_definition:<75} cumulative prob.: {cumulative_prob:.8f}"

    output_lines.append(line)
    # Keep only the last 50 lines. Trim IN PLACE: the list is shared by every
    # recursion level, and rebinding a local name here would leave the callers
    # holding a stale, ever-growing copy and drop lines from the display.
    del output_lines[:-50]

    with output_widget:
        clear_output(wait=True)
        print('\n'.join(output_lines))

    top_k_ids, top_k_probs = produce_next_token_ids(input_ids, model, config.TOPK, config.SUB_TOKEN_ID)

    for token_id, token_prob in zip(top_k_ids.tolist(), top_k_probs.tolist()):
        if token_id == config.SUB_TOKEN_ID:
            continue  # Skip the substitute token to avoid circular definitions

        new_cumulative_prob = cumulative_prob * token_prob
        if new_cumulative_prob < config.CUTOFF:
            continue

        new_input_ids = torch.cat([input_ids, torch.tensor([[token_id]], device=input_ids.device)], dim=-1)
        token_str = tokenizer.decode([token_id], skip_special_tokens=True)

        new_child = {
            "token_id": token_id,
            "token": token_str,
            "cumulative_prob": new_cumulative_prob,
            "children": []
        }
        data['children'].append(new_child)

        build_def_tree(new_input_ids, new_child, base_prompt, model, tokenizer, config, output_widget,
                       depth=depth+1, max_depth=max_depth, cumulative_prob=new_cumulative_prob,
                       output_lines=output_lines)

def generate_definition_tree(base_prompt, embedding, model, tokenizer, config, output_widget):
    """Install ``embedding`` as the substitute token and build its definition tree.

    The embedding row of ``config.SUB_TOKEN_ID`` is overwritten with
    ``embedding``, ``base_prompt`` is tokenised, and ``build_def_tree`` fills a
    fresh root node while printing progress into ``output_widget``.

    Returns:
        The root node dict (``token``, ``cumulative_prob``, ``children``).
    """
    root = {"token": "", "cumulative_prob": 1, "children": []}
    progress_lines = []

    # Update the token embedding (a leading batch-like axis is added first)
    update_token_embedding(model, config.SUB_TOKEN_ID, embedding.unsqueeze(0))

    prompt_ids = tokenizer.encode(base_prompt, return_tensors="pt")
    prompt_ids = prompt_ids.to(model.device)

    # Generate the tree and print scrolling text
    build_def_tree(prompt_ids, root, base_prompt, model, tokenizer, config,
                   output_widget, output_lines=progress_lines)

    with output_widget:
        print("\ndefinition tree visualisation incoming...\n")

    return root


# Visualisation functions
def find_max_min_cumulative_weight(node, current_max=0, current_min=float('inf')):
    """Return (largest, smallest strictly positive) ``cumulative_prob`` in a tree.

    A node without ``cumulative_prob`` counts as 0 for the maximum and as 1 for
    the minimum. ``current_max`` / ``current_min`` seed the running extremes.
    """
    # Explicit stack instead of recursion; children are pushed reversed so
    # nodes are visited in the same pre-order as a recursive walk.
    pending = [node]
    while pending:
        visiting = pending.pop()
        current_max = max(current_max, visiting.get('cumulative_prob', 0))
        candidate = visiting.get('cumulative_prob', 1)
        if candidate > 0:
            current_min = min(current_min, candidate)
        pending.extend(reversed(visiting.get('children', [])))
    return current_max, current_min

def create_tree_diagram(data, directory, name, config, max_weight, min_weight,
                        trim_cutoff=0, image_height=5000, dpi=120, margin=None,
                        font_size=36, ranksep=5):
    """Render a definition tree to a PNG centred on a 5040x5000 white canvas.

    The tree is laid out left-to-right with Graphviz; edge thickness grows with
    the (log-scaled) cumulative probability of the child node. The rendered
    image is resized to fit the canvas and saved as ``<name>_resized.png``; the
    intermediate Graphviz render is always deleted afterwards.

    Args:
        data: Root node dict (``token``, ``cumulative_prob``, ``children``).
        directory: Directory the image is written to.
        name: Output file stem.
        config: Object providing ``LOG_BASE`` for edge-width scaling.
        max_weight, min_weight: Largest / smallest cumulative probability in
            the tree (see ``find_max_min_cumulative_weight``).
        trim_cutoff: Non-root nodes below this cumulative probability are
            omitted together with their subtrees.
        image_height: Height component of the Graphviz ``size`` attribute.
        dpi: Graphviz ``dpi`` attribute.
        margin: Graphviz ``margin`` attribute; left at the Graphviz default
            when None (previously the literal string "None" was sent).
        font_size: Node label font size.
        ranksep: Graphviz ``ranksep`` attribute.

    Returns:
        Path of the resized PNG, or None if rendering or image processing
        failed (a message is printed in that case).
    """
    import os

    def scale_edge_width(cumulative_weight, max_weight, min_weight, log_base, max_thickness=33, min_thickness=1):
        """Map a cumulative probability to a pen width on a log scale."""
        cumulative_weight = max(cumulative_weight, min_weight)
        log_weight = math.log(cumulative_weight, log_base) - math.log(min_weight, log_base)
        log_max = math.log(max_weight, log_base) - math.log(min_weight, log_base)
        if log_max <= 0:
            # All weights are equal: there is no range to scale over, and the
            # division below would raise ZeroDivisionError.
            return max_thickness
        amplified_weight = (log_weight / log_max) ** 2.5
        scaled_weight = (amplified_weight * (max_thickness - min_thickness)) + min_thickness
        return scaled_weight

    def add_nodes_edges(dot, node, name, max_weight, min_weight, parent=None, is_root=True, depth=0,
                        branch='', excluded_tokens=None):
        """Add ``node`` and, recursively, its children to the graph."""
        if excluded_tokens is None:
            excluded_tokens = {"'.", ".'", "()", "''"}

        node_id = str(id(node))
        token = node.get('token', '').strip()
        cumulative_prob = node.get('cumulative_prob', 1)

        if token in excluded_tokens:
            return

        if cumulative_prob < trim_cutoff and not is_root:
            return

        if is_root or (token and token not in excluded_tokens):
            if parent and not is_root:
                edge_weight = scale_edge_width(cumulative_prob, max_weight, min_weight, config.LOG_BASE)
                dot.edge(parent, node_id, arrowhead='dot', arrowsize='1', color='darkblue', penwidth=str(edge_weight))

            label = "*" if is_root else (token if token else "[EMPTY]")
            dot.node(node_id, label=label, shape='plaintext', fontsize=str(font_size), fontname='Helvetica')

            for child in node.get('children', []):
                add_nodes_edges(dot, child, name, max_weight, min_weight, parent=node_id, is_root=False, depth=depth+1, branch=branch + token + ' ', excluded_tokens=excluded_tokens)

    dot = Digraph(comment='Definition Tree', format='png')
    graph_attrs = {
        'rankdir': 'LR',
        'size': f'5040,{image_height}',
        'nodesep': '0.06',
        'ranksep': str(ranksep),
        'dpi': str(dpi),
        'bgcolor': 'white',
    }
    if margin is not None:
        # Only set the attribute when a value was supplied; str(None) is not a
        # valid Graphviz margin.
        graph_attrs['margin'] = str(margin)
    dot.attr(**graph_attrs)

    add_nodes_edges(dot, data, name, max_weight, min_weight)

    output_file_path = os.path.join(directory, f'{name}')

    try:
        output_path = dot.render(filename=output_file_path, cleanup=True)
    except Exception as e:
        print(f"Error rendering dot graph: {e}")
        return None

    if not os.path.exists(output_path):
        print(f"Error: The output image {output_path} does not exist.")
        return None

    output_path_with_white_bg = os.path.splitext(output_path)[0] + '_resized.png'

    try:
        with PILImage.open(output_path) as tree_img:
            # NOTE(review): with a white (non-zero) background getbbox() may
            # cover the whole image, making this crop a no-op -- confirm.
            bbox = tree_img.getbbox()
            if not bbox:
                print("The image appears to be blank.")
                return None

            cropped_img = tree_img.crop(bbox)
            aspect_ratio = cropped_img.width / cropped_img.height
            new_width = 5000
            # max(1, ...) keeps extremely wide images from producing a zero height.
            new_height = max(1, int(new_width / aspect_ratio))

            if new_height > 5000:
                new_height = 5000
                new_width = max(1, int(new_height * aspect_ratio))

            resized_img = cropped_img.resize((new_width, new_height), PILImage.LANCZOS)

            bg = PILImage.new("RGB", (5040, 5000), (255, 255, 255))
            paste_x = (5040 - new_width) // 2
            paste_y = (5000 - new_height) // 2
            bg.paste(resized_img, (paste_x, paste_y))

            bg.save(output_path_with_white_bg, 'PNG')
    except Exception as e:
        print(f"Error processing image: {e}")
        return None
    finally:
        # Remove the intermediate render on every path, not only on success.
        if os.path.exists(output_path):
            os.remove(output_path)

    return output_path_with_white_bg

# Function to calculate token centroid
def calculate_token_centroid(model):
    """Return the mean of the model's input-embedding rows (mean over dim 0)."""
    embedding_matrix = model.get_input_embeddings().weight.data
    return embedding_matrix.mean(dim=0)

def create_feature_vector(weights, feature_number, token_centroid, use_token_centroid, scaling_factor, weight_type, pca_weighting):
    """Build the vector that stands in for a token embedding for one SAE feature.

    The feature's weights are taken from row ``feature_number`` for
    ``weight_type == 'decoder'`` and from column ``feature_number`` otherwise,
    then L2-normalised. The result is blended with the module-level
    ``PCA1_direction`` as ``(1 - w) * feature + w * PCA1`` (the blend is not
    re-normalised), multiplied by ``scaling_factor`` and, if requested, offset
    by ``token_centroid``.
    """
    if weight_type == 'decoder':
        selected = weights[feature_number]
    else:
        selected = weights[:, feature_number]
    home_device = selected.device

    unit_feature = F.normalize(selected, p=2, dim=0)
    centroid_local = token_centroid.to(home_device)
    pca_local = PCA1_direction.to(home_device)

    blended = (1 - pca_weighting) * unit_feature + pca_weighting * pca_local

    if not use_token_centroid:
        return scaling_factor * blended
    return centroid_local + scaling_factor * blended

def format_probability(value):
    """Format ``value`` with up to 8 decimals, without trailing zeros or a bare '.'."""
    fixed = f"{value:.8f}"
    whole, _, decimals = fixed.partition('.')
    significant = decimals.rstrip('0')
    if significant:
        return f"{whole}.{significant}"
    return whole


class Config:
    """Settings object read by the tree-generation and drawing helpers."""

    def __init__(self):
        # Generation controls
        self.running = True
        self.TOPK = 5
        self.CUTOFF = 0.00001
        self.SUB_TOKEN_ID = 23070  # Token ID for "OSS"

        # Drawing controls
        self.LOG_BASE = 10

        # Tensor precision and placement
        self.DTYPE = torch.float32
        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.DEVICE = torch.device(device_name)


config = Config()


class DefinitionTreeExplorer:
    def __init__(self):
        """Create the widgets, wire the callbacks and display the explorer.

        Side effects: builds the control column via ``setup_ui``, renders the
        initial Neuronpedia embed and displays the whole layout once.
        """
        # Independent output areas used by the different parts of the UI.
        for area_name in ('controls_output', 'treegen_output',
                          'neuronpedia_output', 'visualization_and_controls_output'):
            setattr(self, area_name, widgets.Output())

        trim_slider_options = dict(
            value=1e-4,
            base=10,
            min=-8,
            max=-2,
            step=0.2,
            description='trim cutoff:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px'),
        )
        self.trim_cutoff = widgets.FloatLogSlider(**trim_slider_options)

        self.regenerate_button = widgets.Button(
            description="regenerate visualisation", layout=widgets.Layout(width='200px'))
        self.exit_button = widgets.Button(
            description="exit", layout=widgets.Layout(width='100px'))

        # setup_ui creates self.feature_number, so it must run before the
        # observer below is attached.
        control_column = self.setup_ui()

        # NOTE(review): '_keydown' does not look like a trait of IntText, so
        # this observer presumably never fires -- confirm.
        self.feature_number.observe(self.on_feature_number_keydown, names='_keydown')

        # Controls on the left, a 20px gap, Neuronpedia embed on the right.
        horizontal_gap = widgets.Box(layout=widgets.Layout(width='20px'))
        self.layout = widgets.HBox(
            [control_column, horizontal_gap, self.neuronpedia_output],
            layout=widgets.Layout(padding='10px'),
        )

        # Generate initial Neuronpedia embed
        self.generate_neuronpedia_embed()

        # Display the main layout only once, with a spacer below the top row.
        vertical_gap = widgets.Box(layout=widgets.Layout(height='20px'))
        page = widgets.VBox([
            self.layout,
            vertical_gap,
            self.treegen_output,
            self.visualization_and_controls_output,
        ])
        display(page)

    def setup_ui(self):
        """Create the input widgets and return them stacked in a VBox.

        The SAE dropdown, feature number and weight type each refresh the
        Neuronpedia embed when their value changes; the generate button starts
        tree generation.
        """
        def natural_label():
            # Fresh dict per widget.
            return {'description_width': 'initial'}

        def fixed_width(width):
            # Fresh Layout per widget so no layout object is shared.
            return widgets.Layout(width=width)

        refresh_embed = self.on_sae_or_feature_change

        self.sae_dropdown = widgets.Dropdown(
            options=list(sae_options.keys()),
            description='SAE layer:',
            layout=fixed_width('250px'),
        )
        self.sae_dropdown.observe(refresh_embed, names='value')

        self.feature_number = widgets.IntText(
            value=0,
            description='feature number:',
            style=natural_label(),
            layout=fixed_width('200px'),
            continuous_update=False,
        )
        self.feature_number.observe(refresh_embed, names='value')

        self.weight_type = widgets.RadioButtons(
            options=['encoder', 'decoder'],
            description='weight type:',
            style=natural_label(),
            layout=widgets.Layout(width='200px', padding='0 0 0 20px'),
        )
        self.weight_type.observe(refresh_embed, names='value')

        self.use_token_centroid = widgets.Checkbox(
            value=True,
            description='use token centroid offset',
            style=natural_label(),
            layout=fixed_width('400px'),
        )

        self.scaling_factor = widgets.FloatSlider(
            value=3.8,
            min=0.1,
            max=10.0,
            step=0.1,
            description='scaling factor:',
            style=natural_label(),
            layout=fixed_width('400px'),
        )

        self.prob_cutoff = widgets.FloatLogSlider(
            value=0.00005,
            base=10,
            min=-8,
            max=-2,
            step=0.2,
            description='cumulative probability cutoff:',
            style=natural_label(),
            layout=fixed_width('400px'),
            format="0.5g",
        )

        self.pca_weight_slider = widgets.FloatSlider(
            value=0.0,
            min=0.0,
            max=1.0,
            step=0.01,
            description='PCA 1st component weighting',
            style=natural_label(),
            layout=fixed_width('400px'),
        )

        self.generate_button = widgets.Button(
            description="generate definition tree",
            layout=fixed_width('200px'),
        )
        self.generate_button.on_click(self.on_generate_clicked)

        # A 20px spacer sits above the generate button.
        top_padding = widgets.Box(layout=widgets.Layout(height='20px'))
        padded_generate_button = widgets.VBox([top_padding, self.generate_button])

        stacked_rows = [
            self.sae_dropdown,
            self.feature_number,
            self.weight_type,
            self.use_token_centroid,
            self.scaling_factor,
            self.prob_cutoff,
            self.pca_weight_slider,
            padded_generate_button,
        ]
        return widgets.VBox(stacked_rows)

    def on_feature_number_keydown(self, event):
        if event['type'] == 'keydown' and event['key'] == 'Enter':
            # Manually update the value and trigger the change
            self.feature_number.value = self.feature_number.value
            self.on_sae_or_feature_change(None)

    def update_output_container(self):
        """Re-display the controls next to the Neuronpedia output.

        NOTE(review): reads ``self.output_container``, which is not assigned
        anywhere in the visible part of this class -- confirm it exists before
        calling this method.
        """
        # Controls and Neuronpedia widget side by side.
        side_by_side = [self.output_container, self.neuronpedia_output]
        self.control_and_neuronpedia = widgets.HBox(side_by_side)

        # Replace whatever is currently shown with the refreshed layout.
        with self.controls_output:
            clear_output(wait=True)
            display(self.control_and_neuronpedia)
            print("Control panel and Neuronpedia widget displayed.")

    def on_select_clicked(self, b):
        """Select the weight matrix for the chosen SAE and rebuild the controls.

        Stores the selected matrix in the module-level ``weights`` global, then
        recreates the additional controls, shows the instruction panel and
        refreshes the Neuronpedia embed. ``b`` (the clicked button) is unused.
        """
        global weights

        self.controls_output.clear_output(wait=True)

        sae_name = self.sae_dropdown.value
        matrix_kind = self.weight_type.value
        weights = sae_weights[sae_name][matrix_kind]

        # The additional controls must exist before the success panel uses
        # generate_button / additional_controls.
        self.init_additional_controls()
        self.display_success_message_and_controls()

        # Generate Neuronpedia iframe now that feature number is initialized
        self.generate_neuronpedia_embed()

    def hide_generate_button(self):
        self.generate_button.layout.display = 'none'

    def display_success_message_and_controls(self):
        """Show the "weights selected" instructions above the additional controls.

        Stores the instruction panel in ``self.loaded_message``, enables the
        generate button and replaces the content of ``controls_output``.
        """
        layer_num = self.sae_dropdown.value.split("layer ")[1]
        message_text = (
            f"Layer {layer_num} SAE {self.weight_type.value} weights selected.\n\n"
            "To proceed:\n"
            "1. select a feature number;\n"
            "2. choose whether to use the token centroid offset to build the feature vector (displacing the origin to the mean token embedding);\n"
            "3. choose a scaling factor for the (normalised) feature vector;\n"
            "4. set the cumulative probability cutoff for definition rollouts;\n"
            "5. set the PCA 1st component weighting w, to replace the normalised feature_vector with (1-w) * feature_vector + w * PCA1, before scaling (default w = 0 has no effect).\n\n"
            "Finally, click 'generate definition tree' to build and visualize the tree."
        )

        panel_layout = widgets.Layout(width='auto', padding='10px', border='1px solid black')
        self.loaded_message = widgets.HTML(
            value=f"<p style='white-space: pre-wrap;'>{message_text}</p>",
            layout=panel_layout,
        )

        # The button is created disabled; enable it now that weights are chosen.
        self.generate_button.disabled = False

        stacked = widgets.VBox([self.loaded_message, self.additional_controls])

        with self.controls_output:
            clear_output(wait=True)
            display(stacked)

    def init_initial_controls(self):
        """Create a compact label + dropdown + radio-button row.

        Replaces ``self.sae_dropdown`` and ``self.weight_type`` and stores the
        row in ``self.initial_controls``. The dropdown options are placeholder
        values, not the keys of ``sae_options``.
        """
        placeholder_layers = ['layer 0', 'layer 1', 'layer 2']  # Example layers, replace with real options
        self.sae_dropdown = widgets.Dropdown(
            options=placeholder_layers,
            layout=widgets.Layout(width='150px'),
        )

        radio_layout = widgets.Layout(width='200px', padding='0 0 0 20px')
        self.weight_type = widgets.RadioButtons(
            options=['encoder', 'decoder'],
            description='weight type:',
            style={'description_width': 'initial'},
            layout=radio_layout,
        )

        row = [widgets.Label("SAE layer:"), self.sae_dropdown, self.weight_type]
        self.initial_controls = widgets.HBox(row)

        # Print to confirm controls are initialized
        print("Initial controls set up.")


    def init_additional_controls(self):
        """Create the feature/scaling/cutoff/PCA controls and display them.

        Replaces ``self.feature_number``, ``self.use_token_centroid``,
        ``self.scaling_factor``, ``self.prob_cutoff``, ``self.pca_weight_slider``
        and ``self.generate_button`` (created disabled), stores them stacked in
        ``self.additional_controls`` and displays that in ``controls_output``.
        """
        self.feature_number = widgets.IntText(
            value=0,
            description='feature number:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='200px'),
            continuous_update=False
        )

        # Refresh the Neuronpedia embed when the committed value changes
        # (Enter / focus loss, because continuous_update is off). This mirrors
        # setup_ui. The previous `on_submit(...)` call used a method that only
        # the text widget provides, and its handler subscripts its argument,
        # which fails for the widget that `on_submit` passes in.
        self.feature_number.observe(self.on_sae_or_feature_change, names='value')

        self.use_token_centroid = widgets.Checkbox(
            value=True,
            description='use token centroid offset',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px')
        )

        self.scaling_factor = widgets.FloatSlider(
            value=3.8,
            min=0.1,
            max=10.0,
            step=0.1,
            description='scaling factor:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px')
        )

        self.prob_cutoff = widgets.FloatLogSlider(
            value=0.00005,
            base=10,
            min=-8,
            max=-2,
            step=0.2,
            description='cumulative probability cutoff:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px'),
            format="0.5g"  # This will display the value in a cleaner format
        )

        self.pca_weight_slider = widgets.FloatSlider(
            value=0.0,
            min=0.0,
            max=1.0,
            step=0.01,
            description='PCA 1st component weighting',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='400px')
        )

        # Move the generate button to the bottom
        self.generate_button = widgets.Button(
            description="generate definition tree",
            layout=widgets.Layout(width='200px')
        )
        self.generate_button.on_click(self.on_generate_clicked)
        self.generate_button.disabled = True  # Disabled until weights are loaded

        # Stack the controls in the desired order
        self.additional_controls = widgets.VBox([
            widgets.HTML('<br/>'),
            self.feature_number,
            self.use_token_centroid,
            self.scaling_factor,
            self.prob_cutoff,
            self.pca_weight_slider,
            widgets.HTML('<br/>'),
            self.generate_button,
            widgets.HTML('<br/>')
        ])

        with self.controls_output:
            display(self.additional_controls)  # Now display the additional controls



    def generate_neuronpedia_embed(self):
        """Show the Neuronpedia page of the selected feature as link plus iframe.

        The layer is the last whitespace-separated word of the dropdown value;
        the feature index comes from the ``feature_number`` widget. The content
        of ``neuronpedia_output`` is replaced.
        """
        layer_id = self.sae_dropdown.value.split()[-1]
        feature_id = self.feature_number.value

        page_url = f"https://neuronpedia.org/gemma-2b/{layer_id}-res-jb/{feature_id}"
        embed_url = f"{page_url}?embed=true&embedexplanation=true&embedplots=false&embedtest=false&height=300"
        link_html = f'<p>feature interpretation from <a href="{page_url}" target="_blank">{page_url}</a>:</p>'
        iframe_html = f'<iframe src="{embed_url}" title="Neuronpedia" style="height: 300px; width: 540px;"></iframe>'

        with self.neuronpedia_output:
            self.neuronpedia_output.clear_output(wait=True)
            display(HTML(link_html))
            display(HTML(iframe_html))


    def on_sae_or_feature_change(self, change):
        """Refresh the Neuronpedia embed after a selection change.

        ``change`` is accepted to fit a one-argument callback signature but is
        not used; the current widget values are re-read instead.
        """
        self.generate_neuronpedia_embed()

    def on_feature_number_submit(self, event):
        """Refresh the Neuronpedia embed once a feature number has been submitted.

        ``event`` is accepted to fit a one-argument callback signature but is
        not used.
        """
        # When Enter is pressed after typing the feature number, trigger the Neuronpedia update
        self.generate_neuronpedia_embed()

    def create_feature_vector(self):
        """Build the feature vector described by the current widget settings.

        Looks up the selected SAE's encoder/decoder weights in the notebook-level
        ``sae_weights`` mapping, computes the token centroid of the notebook-level
        ``model``, and forwards both, together with the remaining widget values,
        to the module-level ``create_feature_vector`` function.

        Returns:
            Whatever the module-level ``create_feature_vector`` returns for
            these settings.
        """
        sae_name = self.sae_dropdown.value
        weight_kind = self.weight_type.value

        # Arguments are positional to match the module-level helper's signature.
        return create_feature_vector(
            sae_weights[sae_name][weight_kind],
            self.feature_number.value,
            calculate_token_centroid(model),
            self.use_token_centroid.value,
            self.scaling_factor.value,
            weight_kind,
            self.pca_weight_slider.value,
        )

    def on_generate_clicked(self, b):
        """Generate a definition tree for the selected feature and display it.

        Args:
            b: The clicked button (unused).

        Side effects: hides the generate button, clears both output areas, sets
        ``config.running`` and ``config.CUTOFF``, and stores the generated tree
        in ``self.results_dict``.
        """
        # Take the button out of the layout once a run has been started.
        self.generate_button.layout.display = 'none'

        # Drop whatever a previous run left in the two output areas.
        for stale_area in (self.treegen_output, self.visualization_and_controls_output):
            stale_area.clear_output(wait=True)

        # Generation settings are communicated through the shared config object.
        config.running = True
        config.CUTOFF = self.prob_cutoff.value

        ghost_vector = self.create_feature_vector()

        with self.treegen_output:
            print(f"Generating definition tree for feature number {self.feature_number.value}...\n")
            sub_token_text = tokenizer.decode([config.SUB_TOKEN_ID])
            prompt = f'A typical definition of "{sub_token_text}" would be "'

        self.results_dict = generate_definition_tree(
            prompt, ghost_vector, model, tokenizer, config, self.treegen_output
        )

        # First render: no trimming is applied.
        self.display_tree_and_controls(self.results_dict, initial=True)


    def display_tree_and_controls(self, results_dict, initial=False):
        """Render the definition-tree image together with the trim controls.

        Args:
            results_dict: Tree data; passed unchanged to
                ``find_max_min_cumulative_weight`` and ``create_tree_diagram``.
            initial: ``True`` for the first render after generation (trim cutoff
                fixed at 0); ``False`` to use the current value of the
                ``trim_cutoff`` widget.

        The diagram is written to a uniquely named file in the current
        directory, read back into an image widget, and the file is removed
        again. If no file was produced, a failure message is printed into the
        visualisation output area and nothing else is shown.
        """
        with self.treegen_output:
            clear_output(wait=True)

        with self.visualization_and_controls_output:
            clear_output(wait=True)

            max_weight, min_weight = find_max_min_cumulative_weight(results_dict)
            current_trim_cutoff = 0 if initial else self.trim_cutoff.value

            # A fresh name per render avoids clashing with any earlier file.
            image_name = f'definition_tree_{uuid.uuid4().hex}'

            output_path = create_tree_diagram(
                results_dict,
                '.',
                image_name,
                config,
                max_weight,
                min_weight,
                trim_cutoff=current_trim_cutoff,
                image_height=5000,
                font_size=48,
                ranksep=1
            )

            if output_path is None or not os.path.exists(output_path):
                # Already inside this Output's context, so a plain print lands in it.
                print("Failed to generate tree diagram.")
                return

            # The file is only a hand-off from the renderer to the image widget:
            # delete it as soon as its bytes are in memory, even if reading
            # fails, so a later display error cannot leave it behind.
            try:
                with open(output_path, 'rb') as f:
                    image_data = f.read()
            finally:
                os.remove(output_path)

            image_widget = widgets.Image(
                value=image_data,
                format='png',
                width='800'
            )

            if initial:
                status_html = f"<p>tree visualisation complete (no trimming applied, cumulative probability cutoff = {format_probability(self.prob_cutoff.value)})</p>"
            else:
                status_html = f"<p>tree visualisation complete (trim cutoff: {format_probability(current_trim_cutoff)})</p>"

            vbox = widgets.VBox([
                image_widget,
                widgets.HTML(value=status_html),
                widgets.HTML(value=""),
                widgets.HTML(value="<h3>adjust visualisation:</h3>"),
                self.trim_cutoff,
                widgets.HTML(value=""),
                widgets.HBox([self.regenerate_button, self.exit_button])
            ])

            display(vbox)

            # NOTE(review): these are re-registered on every render; ipywidgets'
            # callback dispatcher appears to skip callbacks that are already
            # registered, so handlers should not stack up - confirm.
            self.regenerate_button.on_click(self.on_regenerate_clicked)
            self.exit_button.on_click(self.on_exit_clicked)

    def on_regenerate_clicked(self, b):
        """Redraw the stored tree using the current trim-cutoff value.

        Args:
            b: The clicked button (unused).

        Does nothing once ``config.running`` is false. Otherwise both output
        areas are reset, ``config.CUTOFF`` is overwritten with the trim-cutoff
        widget value, and ``self.results_dict`` is rendered again.
        """
        if not config.running:
            return

        with self.treegen_output:
            clear_output(wait=True)
            print("regenerating tree with new cutoff, please wait...")

        with self.visualization_and_controls_output:
            clear_output(wait=True)
            # Empty placeholder shown while the new diagram is being produced.
            display(HTML("<p></p>"))

        config.CUTOFF = self.trim_cutoff.value
        self.display_tree_and_controls(self.results_dict, initial=False)

    def on_exit_clicked(self, b):
        """Stop accepting regenerate requests and report it in the tree output.

        Args:
            b: The clicked button (unused).
        """
        # on_regenerate_clicked checks this flag and becomes a no-op once it is False.
        config.running = False
        with self.treegen_output:
            print("Tree regeneration process terminated.")


# Initialize and run the explorer.
# The widget tree rendered as this cell's output is presumably built and
# displayed by the constructor - confirm against DefinitionTreeExplorer.__init__.
explorer = DefinitionTreeExplorer()


VBox(children=(HBox(children=(VBox(children=(Dropdown(description='SAE layer:', layout=Layout(width='250px'), …

In [None]:
# @title functionality 2: generate token lists for SAE features
import torch
import torch.nn.functional as F
import ipywidgets as widgets
from IPython.display import display, HTML, clear_output, Math
import html
import time

@torch.no_grad()
def find_closest_tokens(_emb, token_embeddings, tokenizer, top_k=500, num_exp=1.4):
    """Rank tokens by a ratio of cosine distances.

    For every row of ``token_embeddings`` the score is::

        (cosine distance to _emb) ** num_exp / (cosine distance to the mean embedding)

    and the tokens with the smallest scores are returned.

    Args:
        _emb: Feature vector; flattened to a single row and moved to the device
            and dtype of ``token_embeddings``.
        token_embeddings: 2-D floating-point tensor with one row per token id.
        tokenizer: Object exposing ``decode(list_of_ids) -> str``.
        top_k: Maximum number of tokens to return. Capped at the number of rows
            in ``token_embeddings`` so a small vocabulary does not raise.
        num_exp: Exponent applied to the numerator distance.

    Returns:
        A list of ``(decoded_token, ratio)`` tuples in ascending order of ratio.
    """
    # Ensure _emb and token_embeddings are on the same device and dtype
    _emb = _emb.to(token_embeddings.device, dtype=token_embeddings.dtype)

    # Unit-normalise everything so the matrix products below are cosine similarities.
    query = F.normalize(_emb.view(1, -1), p=2, dim=1)
    unit_embeddings = F.normalize(token_embeddings, p=2, dim=1)
    centroid = F.normalize(torch.mean(token_embeddings, dim=0, keepdim=True), p=2, dim=1)

    # squeeze(0) (not squeeze()) keeps the result 1-D for a one-token vocabulary.
    similarities = torch.mm(query, unit_embeddings.t()).squeeze(0)
    centroid_similarities = torch.mm(centroid, unit_embeddings.t()).squeeze(0)

    # Rounding can push a cosine similarity marginally above 1. A negative base
    # with a fractional exponent is NaN, and topk(largest=False) would then
    # drop exactly the best-matching token, so clamp the distance at zero.
    distances = torch.pow((1 - similarities).clamp(min=0), num_exp)

    # Keep the denominator strictly positive so a token aligned with the
    # centroid cannot produce a 0/0, an infinity or a negative ratio.
    eps = torch.finfo(centroid_similarities.dtype).eps
    centroid_distances = (1 - centroid_similarities).clamp(min=eps)

    ratios = distances / centroid_distances

    # topk raises if k exceeds the number of candidates.
    k = max(0, min(int(top_k), ratios.numel()))
    top_ratios, top_indices = torch.topk(ratios, k=k, largest=False)

    # Decode tokens from indices
    closest_tokens = [tokenizer.decode([token_id]) for token_id in top_indices.tolist()]

    return list(zip(closest_tokens, top_ratios.tolist()))

# Mean of the model's input-embedding rows
def calculate_token_centroid(model):
    """Return the input-embedding matrix of ``model`` averaged over dim 0.

    ``model`` must expose ``get_input_embeddings()`` whose result has a
    ``weight`` tensor; the mean is taken over that tensor's first dimension.
    """
    embedding_matrix = model.get_input_embeddings().weight.data
    return embedding_matrix.mean(dim=0)

def create_feature_vector(weights, feature_number, token_centroid, use_token_centroid, scaling_factor, weight_type, pca_weighting):
    """Turn one SAE feature into a scaled (and optionally centroid-offset) vector.

    Args:
        weights: 2-D weight tensor of the selected SAE.
        feature_number: Index of the feature to extract.
        token_centroid: Tensor added as an offset when ``use_token_centroid`` is true.
        use_token_centroid: Whether to add ``token_centroid`` to the result.
        scaling_factor: Multiplier applied to the blended direction.
        weight_type: ``'decoder'`` indexes ``weights`` by row; any other value
            (the encoder case) indexes by column.
        pca_weighting: Weight ``w`` of the notebook-level ``PCA1_direction`` in
            the blend ``(1 - w) * unit_feature + w * PCA1_direction``.

    Returns:
        ``scaling_factor * blend``, plus ``token_centroid`` when requested.
        The blend itself is not re-normalised.
    """
    # Decoder features are read as rows, encoder features as columns.
    raw_feature_vector = weights[feature_number] if weight_type == 'decoder' else weights[:, feature_number]
    target_device = raw_feature_vector.device

    unit_feature = F.normalize(raw_feature_vector, p=2, dim=0)

    # Bring the other operands onto the feature's device before combining them.
    token_centroid = token_centroid.to(target_device)
    pca_component = PCA1_direction.to(target_device)

    blended_direction = (1 - pca_weighting) * unit_feature + pca_weighting * pca_component

    if not use_token_centroid:
        return scaling_factor * blended_direction
    return token_centroid + scaling_factor * blended_direction

class FeatureExplorer:
    """Widget UI that lists the tokens closest to a selected SAE feature.

    Constructing an instance builds and displays the controls, shows the
    Neuronpedia page for the initial selection and renders an initial token
    list. The class relies on names defined elsewhere in the notebook:
    ``sae_options``, ``sae_weights``, ``token_embeddings``, ``tokenizer`` and
    (through ``create_feature_vector``) ``PCA1_direction``.
    """

    def __init__(self):
        # Most recent (token, ratio) results; reset to None when a parameter changes.
        self.closest_tokens_with_values = None

        # SAE layer selection, populated from the notebook-level sae_options mapping
        self.sae_dropdown = widgets.Dropdown(
            options=list(sae_options.keys()),
            description='SAE layer:',
            layout=widgets.Layout(width='250px')
        )
        self.sae_dropdown.observe(self.on_sae_or_feature_change, names='value')

        # Encoder vs decoder weights
        self.weight_type = widgets.RadioButtons(
            options=['encoder', 'decoder'],
            description='weight type:',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='200px', padding='0 0 0 20px')
        )
        self.weight_type.observe(self.on_sae_or_feature_change, names='value')

        # Initialize other controls (feature number, scaling factor, etc.)
        self.setup_ui()

        # Automatically trigger token generation on initialization
        self.preload_token_list()

    def setup_ui(self):
        """Create the remaining controls and output areas and display them."""
        # BoundedIntText enforces min/max; plain IntText has no such traits, so
        # the bounds would be ignored and out-of-range indices could be entered.
        self.feature_number = widgets.BoundedIntText(value=0, min=0, max=16383, description='feature number:', style={'description_width': 'initial'}, layout=widgets.Layout(width='200px'))
        self.feature_number.observe(self.on_sae_or_feature_change, names='value')

        self.use_token_centroid = widgets.Checkbox(value=True, description='use token centroid offset', style={'description_width': 'initial'})
        self.use_token_centroid.observe(self.on_parameter_change, names='value')

        self.scaling_factor = widgets.FloatSlider(min=0.1, max=10.0, step=0.025, value=3.8, description='scaling factor:', style={'description_width': 'initial'}, layout=widgets.Layout(width='500px'))
        self.scaling_factor.observe(self.on_parameter_change, names='value')

        self.num_exp = widgets.FloatSlider(min=0.1, max=5.0, step=0.025, value=1.4, description='numerator exponent (m):', style={'description_width': 'initial'}, layout=widgets.Layout(width='500px'))
        self.num_exp.observe(self.on_parameter_change, names='value')

        # Weight of the first PCA component blended into the feature direction
        self.pca_weight_slider = widgets.FloatSlider(
            value=0.0,
            min=0.0,
            max=1.0,
            step=0.01,
            description='PCA 1st component weighting',
            style={'description_width': 'initial'},
            layout=widgets.Layout(width='500px')
        )
        self.pca_weight_slider.observe(self.on_parameter_change, names='value')

        self.generate_button = widgets.Button(
            description="generate top 100 token list",
            layout=widgets.Layout(width='400px')
        )
        self.generate_button.on_click(self.on_generate_clicked)

        self.neuronpedia_output = widgets.Output()

        # Separate output widget for token list
        self.token_list_output = widgets.Output()

        self.main_output = widgets.Output()

        # Arrange the controls in a box layout
        self.control_box = widgets.VBox([
            self.sae_dropdown,
            self.weight_type,
            self.feature_number,
            self.use_token_centroid,
            self.scaling_factor,
            self.num_exp,
            self.pca_weight_slider,
            self.generate_button
        ])

        # Horizontal layout to organize Neuronpedia iframe next to the controls
        self.layout_container = widgets.HBox([self.control_box, self.neuronpedia_output], layout=widgets.Layout(width='100%'))

        display(self.layout_container)
        display(self.token_list_output)  # Display the token list output separately
        display(self.main_output)

        # Display the iframe right away based on initial selections
        self.generate_neuronpedia_embed()

    def preload_token_list(self):
        """Preload the token list to avoid the first click issue."""
        # Trigger a "silent" generation of the token list during initialization
        self.generate_token_lists()

    def reset_token_list_display(self):
        """Completely reset the token list display by recreating the widget."""
        # Close any existing widgets
        self.token_list_output.close()
        # Recreate the widget from scratch to ensure a fresh output container.
        # Each reset displays one more Output widget in the cell.
        self.token_list_output = widgets.Output()
        display(self.token_list_output)

    def generate_token_lists(self):
        """Compute the closest tokens for the current settings and show the top 100.

        The top 500 ``(token, ratio)`` pairs are stored in
        ``self.closest_tokens_with_values``; the first 100 tokens are
        HTML-escaped and written to ``self.token_list_output``.
        """
        # token_embeddings is rebound below, so the move to the compute device
        # is applied to the notebook-level tensor itself.
        global token_embeddings

        # Select the correct weights based on the dropdown and radio buttons
        weights = sae_weights[self.sae_dropdown.value][self.weight_type.value]

        # Move tensors to the same device (if needed)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        weights = weights.to(device)
        token_embeddings = token_embeddings.to(device)

        # Create feature vector using the selected parameters, including PCA weighting
        token_centroid = torch.mean(token_embeddings, dim=0)
        feature_vector = create_feature_vector(
            weights,
            self.feature_number.value,
            token_centroid,
            self.use_token_centroid.value,
            self.scaling_factor.value,
            self.weight_type.value,
            pca_weighting=self.pca_weight_slider.value
        )

        # Find the closest tokens
        self.closest_tokens_with_values = find_closest_tokens(
            feature_vector,
            token_embeddings,
            tokenizer,
            top_k=500,
            num_exp=self.num_exp.value,
        )

        # Escape HTML characters so token text cannot be interpreted as markup
        escaped_token_list = [html.escape(token) for token, _ in self.closest_tokens_with_values[:100]]

        # Clear the token list output before displaying new tokens
        self.token_list_output.clear_output()

        # Display top 100 tokens
        with self.token_list_output:
            display(HTML("<br>100 tokens whose embeddings produce the smallest ratio:<br><br>"))
            display(Math(r"\frac{(\textrm{cosine distance from feature vector})^m}{\textrm{cosine distance from mean token embedding}}"))
            display(HTML("<br>"))

            quoted_tokens = ", ".join(f"'{token}'" for token in escaped_token_list)
            display(HTML("<p>[{}]</p>".format(quoted_tokens)))
            display(HTML("<br>"))

    def generate_neuronpedia_embed(self):
        """Generate the Neuronpedia embed based on selected SAE layer and feature number."""
        # The layer identifier is the last word of the dropdown value
        sae_layer = self.sae_dropdown.value.split()[-1]
        feature_number = self.feature_number.value

        # The same page URL backs both the visible link and the iframe source
        page_url = f"https://neuronpedia.org/gemma-2b/{sae_layer}-res-jb/{feature_number}"
        iframe_url = f"{page_url}?embed=true&embedexplanation=true&embedplots=false&embedtest=false&height=300"

        # Line of text with an active link to the Neuronpedia page
        neuronpedia_link = f'<p>feature interpretation from <a href="{page_url}" target="_blank">{page_url}</a>:</p>'

        # Display the iframe and the link within the widget
        with self.neuronpedia_output:
            self.neuronpedia_output.clear_output(wait=True)
            display(HTML(neuronpedia_link))
            display(HTML(f'<iframe src="{iframe_url}" title="Neuronpedia" style="height: 300px; width: 540px;"></iframe>'))

    def on_generate_clicked(self, b):
        """Handle button click to generate token lists.

        Args:
            b: The clicked button (unused).
        """
        # Only generate tokens; do not clear the token list here.

        # Introduce a short delay to ensure everything updates smoothly
        time.sleep(0.2)  # 200ms delay

        # Generate the token lists and display them
        with self.main_output:
            self.main_output.clear_output(wait=True)
            self.generate_token_lists()

    def on_sae_or_feature_change(self, change):
        """Handle changes to SAE layer, weight type or feature number.

        Args:
            change: Observer change payload (unused).
        """
        # Reset the token list display
        self.reset_token_list_display()

        # Clear outputs only when switching feature or SAE layer
        with self.token_list_output:
            self.token_list_output.clear_output(wait=True)

        # Re-initialize Neuronpedia iframe
        self.generate_neuronpedia_embed()

    def on_parameter_change(self, change):
        """Handle other parameter changes (scaling factor, centroid toggle, etc.).

        Args:
            change: Observer change payload (unused).
        """
        # Cached results no longer match the controls.
        self.closest_tokens_with_values = None
        # No Neuronpedia reload here as the change does not affect the iframe.

# Instantiate and display the FeatureExplorer: the constructor builds and
# displays the controls and renders an initial token list.
# Note: this rebinds the name `explorer`, which the definition-tree cell above
# also assigns.
explorer = FeatureExplorer()


HBox(children=(VBox(children=(Dropdown(description='SAE layer:', layout=Layout(width='250px'), options=('Gemma…

Output()

Output()

Output()

Output()

Output()

Output()