# Circuit autointerpretability

This section sets up everything we need: the LLM client, GPT-2 small, a fixed OpenWebText token sample, and the per-layer SAEs/transcoders.

In [None]:
from autointerpretability import *

# Autoreload
%load_ext autoreload
%autoreload 2

In [None]:
# Azure OpenAI client configured from config.yaml (keys: base_url, azure_api_key, api_version).
# The `with` block closes the config file as soon as it has been parsed.
with open("config.yaml") as config_file:
    config = yaml.safe_load(config_file)
llm_client = AzureOpenAI(
    azure_endpoint=config["base_url"],
    api_key=config["azure_api_key"],
    api_version=config["api_version"],
)

model = HookedTransformer.from_pretrained('gpt2-small')

# Stream OpenWebText, shuffle with fixed seeds, tokenize/concatenate with max_length=128
# and keep the first 12800 * 2 rows as a fixed sample for max-activating examples.
dataset = load_dataset('Skylion007/openwebtext', split='train', streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=10_000)
tokenized_owt = tokenize_and_concatenate(dataset, model.tokenizer, max_length=128, streaming=True)
tokenized_owt = tokenized_owt.shuffle(42)
tokenized_owt = tokenized_owt.take(12800 * 2)
owt_tokens = np.stack([x['tokens'] for x in tokenized_owt])
owt_tokens_torch = torch.tensor(owt_tokens)

# Attention z-SAEs and MLP transcoders, indexed by layer (e.g. z_saes[8], transcoders[3]).
# NOTE(review): `tl_model` is a second model instance next to `model`; the cells below use `model`.
device = 'cpu'
tl_model, z_saes, transcoders = get_model_encoders(device=device)

Note that you can specify the features you want to examine in each layer and pass in either the relevant z-SAE or MLP transcoder, depending on which component you want to look at; the `get_feature_scores` function handles the differences. Let's look at the max-activating examples for the features Danny wanted to check out (note that you can slice `owt_tokens_torch` to run on fewer sequences).

In [None]:
# Two layer-8 attention (z-SAE) features to inspect.
features = [16513, 7861]
sae = z_saes[8]
# Scores for just these features over the whole OWT sample; column i of the
# feature axis corresponds to features[i].
feature_scores = get_feature_scores(model, sae, owt_tokens_torch, features, batch_size=128)

Our feature scores are a tensor of shape `(batch, feature, seq_pos)`, so there is a helper function to extract the max-activating examples for each feature. You need to specify the feature index, i.e. the feature's position in the `features` list above (not its feature ID), which is why it helps to keep that list in view.

In [None]:
feature_scores.shape

In [None]:
feature_idx = 0 # corresponding to 16513
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=15)

In [None]:
# Top-20 boosted tokens/logits for this single feature (act_strength=3).
top_tokens, top_logits = get_top_k_tokens(model, sae, features[feature_idx], k=20, act_strength=3)

pretty_print_tokens_logits(top_tokens, top_logits)

In [None]:
# Ask the LLM (Azure client from the setup cell) for an interpretation from the examples + boosted tokens.
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

In [None]:
# Same pipeline for the second feature (features[1] == 7861); note act_strength=5 here vs 3 above, and no HTML display.
feature_idx = 1
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=15, display_html=False)
top_tokens, top_logits = get_top_k_tokens(model, sae, features[feature_idx], k=20, act_strength=5)
pretty_print_tokens_logits(top_tokens, top_logits)
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

You can also pass in and boost logits for multiple features at a time.

In [None]:
# Boost both features at once by passing the whole `features` list.
top_tokens, top_logits = get_top_k_tokens(model, sae, features, k=10, act_strength=5)

pretty_print_tokens_logits(top_tokens, top_logits)


# Examples ranked on feature columns 0 and 1 together.
# NOTE(review): passes the NumPy `owt_tokens` here, but `owt_tokens_torch` elsewhere -- confirm both are accepted.
examples_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens, [0, 1], k=5, show_score=True)

feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

Then, you can pass the examples and boosted tokens to GPT-4 to interpret what's going on. Note that I don't have access to `GPT-4o` with my credits yet, so using it will have to wait a few days.

In [None]:
# NOTE(review): re-queries the LLM with whatever examples/top tokens the most recently run cell left behind.
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)

In [None]:
print(feature_interpretation)

Finally, we can pass in multiple features at once to see the max activating examples for features together.

In [None]:
# [0, 1] are positions in `features`; examples are ranked on both features together.
_ = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens, [0, 1], k=5, show_score=True)

However, rather than passing in individual features for specific components in specific layers by hand, you can use the `CircuitPrediction` object, which stores all of this for you. Here's a quick illustration of how to use it together with the functions above.

In [None]:
# Circuit prediction for the IOI task (N=20 -- presumably the number of prompts; TODO confirm).
cp = get_circuit_prediction(task='ioi', N=20)

The main thing you'll want to do with this is get features from certain components to look at on a specific task. The features for each component are stored in the circuit hypergraph. For instance:

In [None]:
# Per-component entries (e.g. 'MLP3'), each with a 'features' list.
cp.circuit_hypergraph

If you want to look at MLP 3, all you have to do is access it:

In [None]:
# Entry for the layer-3 MLP; its 'features' list is scored in the next cell.
cp.circuit_hypergraph['MLP3']

And just repeat what we did above:

In [None]:
# Unique MLP3 features, dropping -1 entries the same way feature_scores_for_component_cluster
# does (a negative id would presumably index some other feature silently). Sorted so that
# column i of feature_scores is reproducibly features[i].
features = sorted(set(cp.circuit_hypergraph['MLP3']['features']) - {-1})
transcoder = transcoders[3]
feature_scores = get_feature_scores(model, transcoder, owt_tokens_torch, features, batch_size=64)

In [None]:
feature_idx = 0  # position in the MLP3 `features` list from the previous cell, not a feature id
print(f"Showing MLP3 feature {features[feature_idx]}")
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=5, show_score=True)

There are a few other methods, but you probably don't need to bother with those.

In [None]:
# Plot the unique-feature array for the circuit (return value discarded).
_ = cp.unique_feature_array(visualize=True)

# Playground

What do we want to actually look for?
* We could take specific components, and look at all their features across the circuit hypergraph, then get some sort of "mass autointerpretation" of what this feature is doing. I think for this you'd need to also feed in information from where it activates on the actual circuit. Might seem a bit soft and qualitative, but if you do it principled enough, it could be useful. Also try weighting the cluster max-act examples + logits by how often the feature shows up.
* Look at which features co-occur in examples. This should give more signal than just looking at features that activate a lot. (Also look at features that activate strongly across all examples, though.)

## Feature cluster interpretation of model components

## Co-occurrence of features

In [None]:
from autointerpretability import *

# Rebuild the IOI circuit prediction so this section can run on its own.
# NOTE(review): the next cell repeats this exactly.
cp = get_circuit_prediction(task='ioi', N=20)

In [None]:
from autointerpretability import *

cp = get_circuit_prediction(task='ioi', N=20)

# Azure OpenAI client configured from config.yaml; the `with` block closes the file after parsing.
with open("config.yaml") as config_file:
    config = yaml.safe_load(config_file)
llm_client = AzureOpenAI(
    azure_endpoint=config["base_url"],
    api_key=config["azure_api_key"],
    api_version=config["api_version"],
)

model = HookedTransformer.from_pretrained('gpt2-small')

# Same fixed OpenWebText sample as the setup cell at the top (seed 42, max_length=128, 12800 * 2 rows).
dataset = load_dataset('Skylion007/openwebtext', split='train', streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=10_000)
tokenized_owt = tokenize_and_concatenate(dataset, model.tokenizer, max_length=128, streaming=True)
tokenized_owt = tokenized_owt.shuffle(42)
tokenized_owt = tokenized_owt.take(12800 * 2)
owt_tokens = np.stack([x['tokens'] for x in tokenized_owt])
owt_tokens_torch = torch.tensor(owt_tokens)

device = 'cpu'
tl_model, z_saes, transcoders = get_model_encoders(device=device)

In [None]:
# Raw co-occurrence data, keyed by (component, component) pairs.
cp.co_occurrence_dict

In [None]:
# Co-occurrences between MLP0 and head L9H9.
cp.get_cooccurrences("MLP0", "L9H9")

In [None]:
# Co-occurrences between MLP0 and head L9H6.
cp.get_cooccurrences("MLP0", "L9H6")

In [None]:
cp.visualize_co_occurrences()

In [None]:
# Ten most frequent co-occurring feature tuples across all component pairs.
cp.get_top_k_feature_tuples(k=10)

In [None]:
from collections import Counter, defaultdict

def get_top_k_feature_tuples_for_component(co_occurrence_dict, component_str, k=5):
    """Return the ``k`` most frequent co-occurring feature tuples involving one component.

    Args:
        co_occurrence_dict: Mapping from a ``(component, component)`` pair to an iterable
            of hashable feature tuples observed for that pair (as in
            ``cp.co_occurrence_dict``). Components are ``('mlp_feature', layer)`` or
            ``('attn_head', layer, head)``.
        component_str: ``"MLP<layer>"`` for an MLP, or ``"L<layer>H<head>"`` /
            ``"L<layer>_H<head>"`` for an attention head (underscores are ignored, so the
            ``circuit_hypergraph`` key spelling also works).
        k: Number of ``(pair, feature_tuple)`` entries to keep across all pairs.

    Returns:
        ``defaultdict(dict)`` mapping each component pair to ``{feature_tuple: count}``,
        restricted to the overall top-``k`` entries for pairs that contain the component.

    Raises:
        ValueError: If ``component_str`` matches neither format.
    """
    # Local import so the function works even if the cell importing these was not run.
    from collections import Counter, defaultdict

    normalized = component_str.replace("_", "")
    component = None
    try:
        if normalized.startswith("MLP"):
            component = ('mlp_feature', int(normalized[3:]))
        elif normalized.startswith("L") and "H" in normalized:
            layer, head = map(int, normalized[1:].split("H"))
            component = ('attn_head', layer, head)
    except ValueError:
        # Non-integer layer/head, or the wrong number of "H"-separated parts (e.g. "L1H2H3").
        component = None
    if component is None:
        raise ValueError(f"Invalid component format: {component_str}")

    # One count per occurrence of (pair, feature tuple), over every pair containing the component.
    counts = Counter(
        (comp_pair, feature_tuple)
        for comp_pair, co_occurrences in co_occurrence_dict.items()
        if component in comp_pair
        for feature_tuple in co_occurrences
    )

    top_k_dict = defaultdict(dict)
    for (comp_pair, feature_tuple), count in counts.most_common(k):
        top_k_dict[comp_pair][feature_tuple] = count

    return top_k_dict

# Top-20 co-occurring feature tuples across all component pairs that involve head L8H6.
get_top_k_feature_tuples_for_component(cp.co_occurrence_dict, "L8H6", k=20)

In [None]:
get_top_k_feature_tuples_for_component(cp.co_occurrence_dict, "L9H9", k=20)

In [None]:
# A single layer-9 attention (z-SAE) feature.
features = [20101]
sae = z_saes[9]
feature_scores = get_feature_scores(model, sae, owt_tokens_torch, features, batch_size=128)

In [None]:
# Display top k activating examples
feature_idx = 0  # position in `features` (feature 20101 in the previous cell)
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=5)

# Boost exactly the feature whose examples were shown, so logits and examples agree.
encoder_feature_pairs = [(sae, [features[feature_idx]])]

top_tokens, top_logits = get_top_k_tokens(model, encoder_feature_pairs, k=20, act_strength=5)

pretty_print_tokens_logits(top_tokens, top_logits)

feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

In [None]:
# A single MLP0 transcoder feature.
features = [6798]
transcoder = transcoders[0]
feature_scores = get_feature_scores(model, transcoder, owt_tokens_torch, features, batch_size=128)

In [None]:
# Display top k activating examples
feature_idx = 0  # position in `features` (feature 6798 in the previous cell)
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=5)

# Boost exactly the feature whose examples were shown, so logits and examples agree.
encoder_feature_pairs = [(transcoder, [features[feature_idx]])]

top_tokens, top_logits = get_top_k_tokens(model, encoder_feature_pairs, k=20, act_strength=5)

pretty_print_tokens_logits(top_tokens, top_logits)

feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

In [None]:
from pprint import pprint

# Only a cell's last expression is displayed; a bare call here would be computed and
# silently discarded, so print the co-occurrence summary explicitly.
pprint(dict(get_top_k_feature_tuples_for_component(cp.co_occurrence_dict, "L9H9", k=20)))

# Boost MLP0 transcoder feature 6798 and layer-9 z-SAE feature 20101 together.
# NOTE(review): no act_strength here (other cells pass 5), so get_top_k_tokens' default applies.
encoder_feature_pairs = [(transcoders[0], [6798]), (z_saes[9], [20101])]
top_tokens, top_logits = get_top_k_tokens(model, encoder_feature_pairs, k=20)

pretty_print_tokens_logits(top_tokens, top_logits)

In [None]:
# Score every (encoder, features) pair from the previous cell in one pass.
feature_scores = get_feature_scores_across_layers(model, encoder_feature_pairs, owt_tokens_torch, batch_size=128)

In [None]:
# Rank examples on feature columns 0 and 1 together.
# NOTE(review): assumes one column per pair, in pair order -- confirm against get_feature_scores_across_layers.
examples_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens, [0, 1], k=25, show_score=True)

In [None]:
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

In [None]:
get_top_k_feature_tuples_for_component(cp.co_occurrence_dict, "L4H11", k=20)

In [None]:
layer = 8
head = 6  # NOTE(review): not used below; the SAE is selected by layer only
feature = 16513

features = [feature]
sae = z_saes[layer]

# Only the first 1024*4 OWT rows, to keep this quick.
feature_scores = get_feature_scores(model, sae, owt_tokens_torch[:1024*4], features, batch_size=128)

In [None]:
# Same 1024*4-row slice as used for scoring, so rows line up.
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, 0, :], owt_tokens_torch[:1024*4], k=15)

In [None]:
top_tokens, top_logits = get_top_k_tokens(model, [(sae, [feature])], k=10, act_strength=5)

pretty_print_tokens_logits(top_tokens, top_logits)

In [None]:
from autointerp_prompts import get_opening_prompt

# Autoreload
%load_ext autoreload
%autoreload 2

def new_get_response(llm_client, examples_clean_text, top_tokens, deployment="gpt4_large"):
    """Ask the chat model for an interpretation using the autointerp opening prompt.

    Args:
        llm_client: Client exposing ``chat.completions.create`` (the ``AzureOpenAI``
            client built in the setup cell).
        examples_clean_text: Max-activating example text, passed to ``get_opening_prompt``.
        top_tokens: Top boosted tokens, passed to ``get_opening_prompt``.
        deployment: Model/deployment name sent as ``model=``; defaults to the previously
            hard-coded ``"gpt4_large"``.

    Returns:
        The first choice's message content converted with ``str`` (so a missing
        content comes back as ``"None"``, as with the original f-string).
    """
    opening_prompt = get_opening_prompt(examples_clean_text, top_tokens)
    response = llm_client.chat.completions.create(
        model=deployment,
        messages=[{"role": "user", "content": opening_prompt}],
    )
    return str(response.choices[0].message.content)

In [None]:
# Interpretation from the opening prompt, using the most recent examples and top tokens.
feature_interpretation = new_get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

## Autointerp over clusters of features

In [None]:
# Print every co-occurrence tuple that contains a feature index above 24576.
# NOTE(review): 24576 looks like an SAE dictionary size (the next cell inspects
# sae.W_dec.shape); confirm which layers it applies to.
FEATURE_INDEX_LIMIT = 24576
for component_pair, feature_tuples in cp.co_occurrence_dict.items():
    oversized = (t for t in feature_tuples if any(f > FEATURE_INDEX_LIMIT for f in t))
    for feature_tuple in oversized:
        print(component_pair, feature_tuple)

In [None]:
# Inspect the decoder weight shape of the layer-5 z-SAE (e.g. to check its dictionary size).
sae = z_saes[5]
sae.W_dec.shape

In [None]:
def feature_scores_for_component_cluster(
    component_name: str,
    layer: int,
    n_sequences: int = 1024 * 4,
    batch_size: int = 128,
    k: int = 10,
    act_strength: float = 5,
):
    """Score an attention component's feature family and get the family's boosted tokens.

    The family is every feature recorded for ``component_name`` in
    ``cp.circuit_hypergraph`` (``-1`` entries dropped), looked up in the attention
    z-SAE ``z_saes[layer]``.

    Args:
        component_name: Hypergraph key, e.g. ``"L5_H5"``.
        layer: Layer selecting the z-SAE.
        n_sequences: Number of leading rows of ``owt_tokens_torch`` to score.
        batch_size: Batch size for ``get_feature_scores``.
        k: Number of top tokens to return.
        act_strength: Strength used when boosting all family features together.

    Returns:
        ``(feature_scores, top_tokens, top_logits)``. The feature axis of
        ``feature_scores`` follows the ascending order of the feature ids.

    Raises:
        ValueError: If the component has no features other than ``-1``.
    """
    # Sorted (rather than list(set(...)) hash order) so columns map reproducibly to feature ids.
    features = sorted(set(cp.circuit_hypergraph[component_name]['features']) - {-1})
    if not features:
        raise ValueError(f"No features recorded for component {component_name!r}")

    sae = z_saes[layer]
    feature_scores = get_feature_scores(model, sae, owt_tokens_torch[:n_sequences], features, batch_size=batch_size)

    top_tokens, top_logits = get_top_k_tokens(model, [(sae, features)], k=k, act_strength=act_strength)

    return feature_scores, top_tokens, top_logits

In [None]:
# Feature family of head L0_H1, scored with the layer-0 z-SAE.
feature_scores, top_tokens, top_logits = feature_scores_for_component_cluster("L0_H1", 0)

In [None]:
# Rank examples on the sum over every family column. The [:1024*4] slice must match the
# slice used inside feature_scores_for_component_cluster so rows line up.
example_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens_torch[:1024*4], feature_indices=[x for x in range(feature_scores.shape[1])], k=15)

In [None]:
pretty_print_tokens_logits(top_tokens, top_logits)

In [None]:
# Feature family of head L5_H5 (layer-5 z-SAE).
feature_scores, top_tokens, top_logits = feature_scores_for_component_cluster("L5_H5", 5)

In [None]:
example_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens_torch[:1024*4], feature_indices=[x for x in range(feature_scores.shape[1])], k=10)

pretty_print_tokens_logits(top_tokens, top_logits)

In [None]:
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

In [None]:
# Feature family of head L10_H7 (layer-10 z-SAE).
feature_scores, top_tokens, top_logits = feature_scores_for_component_cluster("L10_H7", 10)

example_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens_torch[:1024*4], feature_indices=[x for x in range(feature_scores.shape[1])], k=10)

pretty_print_tokens_logits(top_tokens, top_logits)

In [None]:
feature_scores.shape

In [None]:
# Feature family of head L8_H6 (layer-8 z-SAE).
feature_scores, top_tokens, top_logits = feature_scores_for_component_cluster("L8_H6", 8)

example_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens_torch[:1024*4], feature_indices=[x for x in range(feature_scores.shape[1])], k=10)

pretty_print_tokens_logits(top_tokens, top_logits)

In [None]:
feature_scores.shape

# Evaluation

Workflow for evaluating autointerp of specific method:
* Get the interpretation of the feature/family of features using everything we've set up above.
* Provide that interpretation, along with the reasoning, to another LLM.
* For a bunch of different IOI sequences, and for randomly selected tokens in that sequence, get the LLM to predict the feature activations.
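* Look at the correlation coefficient between the predicted feature activations and the actual feature activations. This is the evaluation metric. (The implementation below currently uses a binary zero/non-zero version of this task, scored with accuracy and F1.)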

There are a few changes depending on whether we're looking at a specific feature (i.e. when we're just ablating the token-importance information added to the autointerp prompt) or a family of features. When we have a family of features:
* We still get the interpretation as before - everything about this is normal. We look at max-activating examples for all features (not summing though - we need to change this to an argmax type thing for each example). For max-boosted logits, we boost all features at once. 
* When we get the feature activations, I think we should take the max-activation on each token of _any_ feature in the feature family. In this way, the autointerp is just predicting whether this feature family will fire in general - still gives a good idea of performance.

For feature co-occurrence, we do the same correlation score approach as family of features, but:
* When getting the interpretation, we look at max-activating examples for both sets of features (max for both on an example, divided by 2; same setup as feature families but now there's two features). Logits are boosted by both simultaneously. 


Will have to figure out a way to combine token importances with feature families/family co-occurrences when the time comes.

Finally, we need a baseline to compare against.
* For token importances, this is easy, since we're only looking at one feature at a time. Just compare the correlation score with and without the token importance information.
* For feature families, we can compare to the average correlation score of running the autointerp on a given feature in that family. Whilst not exactly the same thing, it gives an idea about how feature families help us to generalise.
* For feature co-occurrences, we should also... (TODO: baseline still to be decided.)

## Important tokens

In [None]:
%load_ext autoreload
%autoreload 2

import torch

# Star import first so the explicit imports below win on name clashes: autointerpretability's
# `get_response` takes (llm_client, examples, tokens), while this section calls the
# single-prompt `get_response(prompt)` from openai_utils.
from autointerpretability import *
from max_act_analysis import MaxActAnalysis, open_web_text_tokens
from aug_interp_prompts import main_aug_interp_prompt, main_aug_interp_prompt_v2
from openai_utils import gen_openai_completion, get_response
from discovery_strategies import (
    create_filter,
    create_simple_greedy_strategy,
    create_top_contributor_strategy,
)

# Inference only; no autograd graphs are needed.
torch.set_grad_enabled(False)


In [None]:
# %%
# Feature under study; alternatives are kept commented out for quick switching.
feature = 27535
# feature = 16401
layer = 5
# feature = 15647
# num_examples = 1000
num_examples = 5000

# Greedy discovery strategy (see discovery_strategies for what these options control).
strategy = create_simple_greedy_strategy(
    passes=1,
    node_contributors=1,
    minimal=True,
)

analyze = MaxActAnalysis("attn", layer, feature, num_sequences=num_examples, batch_size=128, strategy=strategy)
analyze.show_top_active_examples(num_examples=15)

In [None]:
# Context-referenced prompts for examples 0-25, used to build the token-importance prompt.
mini_examples = analyze.get_context_referenced_prompts_for_range(0, 25)

p = main_aug_interp_prompt(mini_examples)

print(p)

In [None]:
# Presumably the baseline (no token importance) prompt over the same examples -- confirm in aug_interp_prompts.
from aug_interp_prompts import main_aug_interp_prompt, main_aug_interp_prompt_v2

p_base = main_aug_interp_prompt_v2(mini_examples)
print(p_base)

In [None]:
# NOTE(review): single-argument call -- needs openai_utils.get_response, which a later
# `from autointerpretability import *` can shadow depending on import order.
interp = get_response(p)
print(interp)

In [None]:
# Hand-written interpretation: overwrites the LLM output above and is what the first
# evaluation cell below uses.
interp = "This neuron appears to activate when it encounters an example of 'pair linking' in the text, usually manifested through a conjunction like 'and', activating on the repetition of a pair of words from earlier in the text to later in the text. The activating token is the second of the pair."

In [None]:
interp_base = get_response(p_base)
print(interp_base)

In [None]:
from autointerpretability import *
# The star import above presumably rebinds `get_response` to autointerpretability's
# (llm_client, examples, tokens) version; the evaluation cells below call the
# single-prompt `get_response(prompt)`, so restore the openai_utils one explicitly.
from openai_utils import get_response

device = 'cpu'
model, z_saes, transcoders = get_model_encoders(device=device)

In [None]:
# Get some IOI examples where it activates (max activation across feature family is the ground-truth)

feature = 27535
layer = 5
component_type = 'attn'

# Get the actual prompts
n_prompts = 100
dataset_prompts = gen_templated_prompts(template_idex=1, N=n_prompts)
dataset_prompts = [x['text'] + x['correct'] for x in dataset_prompts]
dataset_tokens = model.to_tokens(dataset_prompts)

# Run the model over the prompts and get the feature activations at each token in each prompt
_, cache = model.run_with_cache(dataset_tokens)
z = cache["z", layer]
b, s, n, d = z.shape
del cache
z = einops.rearrange(z, "b s n d -> (b s) (n d)")

# Apply relevant SAE or transcoder to the activations
encoder = z_saes[layer] if component_type == 'attn' else transcoders[layer]

z_hidden = encoder.encode(z)
z_hidden = einops.rearrange(z_hidden, "(b s) h -> b s h", s=s)

# Keep only the feature family (here one feature) and take, per token, the max across it.
feature_indices = [feature]
z_hidden = z_hidden[:, :, feature_indices].max(dim=2).values

# First few prompts with a non-zero activation on at least one token. Each prompt is
# listed once (a per-token np.where would repeat a prompt once per active token).
n_examples = 5
active_rows = torch.nonzero((z_hidden > 0.0).any(dim=1)).flatten().tolist()
indices_to_keep = active_rows[:n_examples]
print(indices_to_keep)

# Filter the prompts with the same indices so that z_hidden[i] and dataset_prompts[i]
# describe the same prompt in the evaluation cells below.
z_hidden = z_hidden[indices_to_keep]
dataset_prompts = [dataset_prompts[j] for j in indices_to_keep]
dataset_tokens = dataset_tokens[indices_to_keep]

z_hidden

In [None]:
# Set up new LLM interpreter given interpretation
from jinja2 import Template
from typing import List

def follow_up_activation_prediction_prompt(
    interpretation: str, sentence: str
) -> str:
    """Render the prompt asking an LLM whether the neuron fires on the sentence's final word.

    Args:
        interpretation: Natural-language interpretation of the neuron/feature.
        sentence: Text whose final word is the prediction target (in the evaluation
            cells, the decoded token prefix ending at the token being scored).

    Returns:
        The rendered prompt, which asks for an ``[ANALYSIS]: ...`` line followed by
        ``[ACTIVATION]: zero`` or ``[ACTIVATION]: non-zero``.
    """
    # The "final word" is the last whitespace-separated chunk, which may span several
    # tokens. Empty/whitespace-only input would make split() empty and [-1] raise.
    words = sentence.split()
    last_word_in_sentence = words[-1] if words else ""

    template = Template(
        """
{# You are an AI researcher continuing an important investigation into a certain neuron in a language model. Your task is to predict whether this neuron will activate on the final word of a given sentence based on a previously provided interpretation of the neuron's behavior. Here's how you will complete this task: #}

You are an AI researcher continuing an important investigation into a certain neuron in a language model. This language model is trained to predict the text that will follow a given input. Your task is to predict whether this neuron will have a zero or non-zero activation on the final word of a given sentence based on the provided interpretation of the neuron's behavior.

INTERPRETATION:
{{interpretation}}

INPUT:
The sentence to analyze is:
=================================================
{{sentence}}
=================================================

The final word to analyze is: {{last_word_in_sentence}}

OUTPUT:
Based on the provided interpretation, analyze the sentence and describe your reasoning in two sentences. Then, predict whether the neuron will have a zero or non-zero activation on the final word of the sentence. Provide your answer in the following format:
[ANALYSIS]: <two sentences of analysis>
[ACTIVATION]: zero or non-zero

Guidelines:
- Carefully consider the interpretation and apply it to the given sentence.
- Your analysis should be concise and relevant to the provided interpretation.
- Your prediction should be either "zero" or "non-zero".

EXAMPLE:
[ANALYSIS]: The final word in the sentence fits the pattern described in the interpretation. The context provided in the sentence suggests a non-zero activation.
[ACTIVATION]: non-zero
"""
    )

    return template.render(
        {"interpretation": interpretation, "sentence": sentence, "last_word_in_sentence": last_word_in_sentence}
    )

In [None]:
# For each IOI example, ask the LLM (given `interp`) whether the feature is zero or
# non-zero on selected tokens, then score the answers with accuracy and F1.
threshold = 1.0  # ground-truth activations above this count as "non-zero"


def _parse_activation_label(response):
    """Map the LLM's ``[ACTIVATION]:`` answer to 1.0 ("non-zero") or 0.0 (anything else).

    Case-insensitive; tolerates a missing space after the colon, surrounding markdown or
    quotes, and trailing text such as "non-zero." or "Non-zero (weak)".
    """
    answer = response.split('[ACTIVATION]:')[-1].strip().strip('*`"\' ').lower()
    return 1.0 if answer.startswith(('non-zero', 'nonzero', 'non zero')) else 0.0


accuracies, f1s = [], []
for example_idx in range(z_hidden.shape[0]):
    sentence = dataset_prompts[example_idx]
    sentence_tokens = model.to_tokens(sentence).squeeze()
    activations = z_hidden[example_idx]
    n_tokens = len(sentence_tokens)

    # Evaluate each strongly-activating token plus its neighbours. Position 0 (presumably
    # the BOS token from to_tokens) is skipped; the set removes repeats from adjacent active
    # tokens, each of which would otherwise cost another LLM call.
    active_positions = np.where(activations > threshold)[0].tolist()
    print(active_positions)
    positions = sorted({
        p
        for pos in active_positions
        for p in (pos - 1, pos, pos + 1)
        if 0 < p < n_tokens
    })

    if not positions:
        # No strong activation: randomly sample up to 3 positions, again excluding 0.
        n_samples = min(3, n_tokens - 1)
        positions = np.random.choice(np.arange(1, n_tokens), n_samples, replace=False).tolist()

    print(positions)

    pred_dict = {}
    for pos in tqdm(positions):
        sentence_str_example = model.to_string(sentence_tokens[:pos + 1])
        ground_truth = activations[pos].item()

        # Get the prediction
        prompt = follow_up_activation_prediction_prompt(interp, sentence_str_example)
        prediction = get_response(prompt)
        print(prediction)
        print(sentence_str_example)
        print(ground_truth)
        print()
        pred_dict[pos] = {
            "prediction": prediction,
            "sentence": sentence_str_example,
            "ground_truth": ground_truth,
            "prediction_numeric": _parse_activation_label(prediction),
            "ground_truth_numeric": 1.0 if ground_truth > threshold else 0.0,
        }

    if not pred_dict:
        continue

    pairs = [(v['prediction_numeric'], v['ground_truth_numeric']) for v in pred_dict.values()]
    correct = sum(pred == truth for pred, truth in pairs)
    tp = sum(pred == 1 and truth == 1 for pred, truth in pairs)
    fp = sum(pred == 1 and truth == 0 for pred, truth in pairs)
    fn = sum(pred == 0 and truth == 1 for pred, truth in pairs)
    print(f"TP: {tp}, FP: {fp}, FN: {fn}")

    # Accuracy is defined for every example, including ones where every answer is a correct
    # "zero" (tp = fp = fn = 0); skipping those would bias mean accuracy downwards.
    # F1 is only defined when there is some positive prediction or label.
    accuracies.append(correct / len(pred_dict))
    if tp + fp + fn > 0:
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0
        f1s.append(f1)

# Means
print(np.mean(accuracies))
print(np.mean(f1s))

In [None]:
# For each IOI example, ask the LLM (given the baseline `interp_base`) whether the feature
# is zero or non-zero on selected tokens, then score the answers with accuracy and F1.
threshold = 1.0  # ground-truth activations above this count as "non-zero"


def _parse_activation_label(response):
    """Map the LLM's ``[ACTIVATION]:`` answer to 1.0 ("non-zero") or 0.0 (anything else).

    Case-insensitive; tolerates a missing space after the colon, surrounding markdown or
    quotes, and trailing text such as "non-zero." or "Non-zero (weak)".
    """
    answer = response.split('[ACTIVATION]:')[-1].strip().strip('*`"\' ').lower()
    return 1.0 if answer.startswith(('non-zero', 'nonzero', 'non zero')) else 0.0


accuracies, f1s = [], []
for example_idx in range(z_hidden.shape[0]):
    sentence = dataset_prompts[example_idx]
    sentence_tokens = model.to_tokens(sentence).squeeze()
    activations = z_hidden[example_idx]
    n_tokens = len(sentence_tokens)

    # Evaluate each strongly-activating token plus its neighbours. Position 0 (presumably
    # the BOS token from to_tokens) is skipped; the set removes repeats from adjacent active
    # tokens, each of which would otherwise cost another LLM call.
    active_positions = np.where(activations > threshold)[0].tolist()
    print(active_positions)
    positions = sorted({
        p
        for pos in active_positions
        for p in (pos - 1, pos, pos + 1)
        if 0 < p < n_tokens
    })

    if not positions:
        # No strong activation: randomly sample up to 3 positions, again excluding 0.
        n_samples = min(3, n_tokens - 1)
        positions = np.random.choice(np.arange(1, n_tokens), n_samples, replace=False).tolist()

    print(positions)

    pred_dict = {}
    for pos in tqdm(positions):
        sentence_str_example = model.to_string(sentence_tokens[:pos + 1])
        ground_truth = activations[pos].item()

        # Get the prediction
        prompt = follow_up_activation_prediction_prompt(interp_base, sentence_str_example)
        prediction = get_response(prompt)
        print(prediction)
        print(sentence_str_example)
        print(ground_truth)
        print()
        pred_dict[pos] = {
            "prediction": prediction,
            "sentence": sentence_str_example,
            "ground_truth": ground_truth,
            "prediction_numeric": _parse_activation_label(prediction),
            "ground_truth_numeric": 1.0 if ground_truth > threshold else 0.0,
        }

    if not pred_dict:
        continue

    pairs = [(v['prediction_numeric'], v['ground_truth_numeric']) for v in pred_dict.values()]
    correct = sum(pred == truth for pred, truth in pairs)
    tp = sum(pred == 1 and truth == 1 for pred, truth in pairs)
    fp = sum(pred == 1 and truth == 0 for pred, truth in pairs)
    fn = sum(pred == 0 and truth == 1 for pred, truth in pairs)
    print(f"TP: {tp}, FP: {fp}, FN: {fn}")

    # Accuracy is defined for every example, including ones where every answer is a correct
    # "zero" (tp = fp = fn = 0); skipping those would bias mean accuracy downwards.
    # F1 is only defined when there is some positive prediction or label.
    accuracies.append(correct / len(pred_dict))
    if tp + fp + fn > 0:
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0
        f1s.append(f1)

# Means
print(np.mean(accuracies))
print(np.mean(f1s))

In [None]:
# Mean accuracy / F1 per feature ("<feature>_<layer>_<component>"), comparing the
# token-importance prompt with the base prompt. Presumably transcribed by hand from the
# printed means of the evaluation cells -- update these when re-running.
results_dict = {
    "16513_8_attn": {
        "token_importance":{
            "accuracies": 0.83,
            "f1s": 0.66,
        },
        "base": {
            "accuracies": 0.79,
            "f1s": 0.54,
        }
    },
    "24166_2_attn": {
        "token_importance":{
            "accuracies": 0.73,
            "f1s": 0.2,
        },
        "base": {
            "accuracies": 0.6,
            "f1s": 0.0,
        }
    },
    "27535_5_attn": {
        "token_importance":{
            "accuracies": 0.76,
            "f1s": 0.4,
        },
        "base": {
            "accuracies": 0.36,
            "f1s": 0.2,
        }
    },
}

In [None]:
import plotly.graph_objects as go

# Same numbers as the previous cell, repeated so this plotting cell runs on its own.
results_dict = {
    "16513_8_attn": {
        "token_importance": {
            "accuracies": 0.83,
            "f1s": 0.66,
        },
        "base": {
            "accuracies": 0.79,
            "f1s": 0.54,
        }
    },
    "24166_2_attn": {
        "token_importance": {
            "accuracies": 0.73,
            "f1s": 0.2,
        },
        "base": {
            "accuracies": 0.6,
            "f1s": 0.0,
        }
    },
    "27535_5_attn": {
        "token_importance": {
            "accuracies": 0.76,
            "f1s": 0.4,
        },
        "base": {
            "accuracies": 0.36,
            "f1s": 0.2,
        }
    },
}

def create_label(key):
    """Turn a ``"<feature>_<layer>_<component>"`` results key into a plot label.

    The component suffix selects the label text: ``"attn"`` -> ``"Attn."``,
    ``"mlp"`` -> ``"MLP"``, anything else is shown verbatim.
    E.g. ``"16513_8_attn"`` -> ``"Feat. 16513 (L8 Attn.)"``.

    Raises:
        ValueError: If ``key`` does not have exactly three ``_``-separated parts.
    """
    feature_id, layer, component = key.split("_")
    component_label = {"attn": "Attn.", "mlp": "MLP"}.get(component, component)
    return f"Feat. {feature_id} (L{layer} {component_label})"

# Pull each (prompt variant, metric) series out of results_dict, in key order.
labels = [create_label(name) for name in results_dict]
token_importance_accuracies, token_importance_f1s, base_accuracies, base_f1s = (
    [scores[variant][metric] for scores in results_dict.values()]
    for variant, metric in [
        ("token_importance", "accuracies"),
        ("token_importance", "f1s"),
        ("base", "accuracies"),
        ("base", "f1s"),
    ]
)

# One grouped bar series per (variant, metric): light/dark blue for token importance,
# light/dark green for the base prompt.
bar_specs = [
    ('Token Importance Accuracies', token_importance_accuracies, '#ADD8E6'),
    ('Token Importance F1s', token_importance_f1s, '#00008B'),
    ('Base Accuracies', base_accuracies, '#90EE90'),
    ('Base F1s', base_f1s, '#006400'),
]
fig = go.Figure(
    data=[go.Bar(name=bar_name, x=labels, y=values, marker_color=colour) for bar_name, values, colour in bar_specs]
)

layout_options = dict(
    title='Token Importance vs. Base Performance',
    xaxis_title='Feature',
    yaxis_title='Score',
    barmode='group',
    legend_title_text='Metrics',
    font=dict(size=14),
    template='plotly_white',
    width=900,
)
fig.update_layout(**layout_options)

fig.show()

## Feature families

In [None]:
from autointerpretability import *

# IOI circuit prediction; its circuit_hypergraph supplies the feature families below.
cp = get_circuit_prediction(task='ioi', N=20)

model = HookedTransformer.from_pretrained('gpt2-small')

# Same fixed OpenWebText sample as the setup cell at the top (seed 42, max_length=128, 12800 * 2 rows).
dataset = load_dataset('Skylion007/openwebtext', split='train', streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=10_000)
tokenized_owt = tokenize_and_concatenate(dataset, model.tokenizer, max_length=128, streaming=True)
tokenized_owt = tokenized_owt.shuffle(42)
tokenized_owt = tokenized_owt.take(12800 * 2)
owt_tokens = np.stack([x['tokens'] for x in tokenized_owt])
owt_tokens_torch = torch.tensor(owt_tokens)

# NOTE(review): no llm_client is built in this section; interpretations go through openai_utils.get_response.
device = 'cpu'
tl_model, z_saes, transcoders = get_model_encoders(device=device)

In [None]:
def get_top_k_feature_tuples_for_component(co_occurrence_dict, component_str, k=5):
    """Return the ``k`` most frequent co-occurring feature tuples involving one component.

    Args:
        co_occurrence_dict: Mapping from a ``(component, component)`` pair to an iterable
            of hashable feature tuples observed for that pair (as in
            ``cp.co_occurrence_dict``). Components are ``('mlp_feature', layer)`` or
            ``('attn_head', layer, head)``.
        component_str: ``"MLP<layer>"`` for an MLP, or ``"L<layer>H<head>"`` /
            ``"L<layer>_H<head>"`` for an attention head (underscores are ignored, so the
            ``circuit_hypergraph`` key spelling also works).
        k: Number of ``(pair, feature_tuple)`` entries to keep across all pairs.

    Returns:
        ``defaultdict(dict)`` mapping each component pair to ``{feature_tuple: count}``,
        restricted to the overall top-``k`` entries for pairs that contain the component.

    Raises:
        ValueError: If ``component_str`` matches neither format.
    """
    # Local import: this section only star-imports autointerpretability, which may not
    # provide these names.
    from collections import Counter, defaultdict

    normalized = component_str.replace("_", "")
    component = None
    try:
        if normalized.startswith("MLP"):
            component = ('mlp_feature', int(normalized[3:]))
        elif normalized.startswith("L") and "H" in normalized:
            layer, head = map(int, normalized[1:].split("H"))
            component = ('attn_head', layer, head)
    except ValueError:
        # Non-integer layer/head, or the wrong number of "H"-separated parts (e.g. "L1H2H3").
        component = None
    if component is None:
        raise ValueError(f"Invalid component format: {component_str}")

    # One count per occurrence of (pair, feature tuple), over every pair containing the component.
    counts = Counter(
        (comp_pair, feature_tuple)
        for comp_pair, co_occurrences in co_occurrence_dict.items()
        if component in comp_pair
        for feature_tuple in co_occurrences
    )

    top_k_dict = defaultdict(dict)
    for (comp_pair, feature_tuple), count in counts.most_common(k):
        top_k_dict[comp_pair][feature_tuple] = count

    return top_k_dict

# Top-20 co-occurring feature tuples across all component pairs that involve head L5H5.
get_top_k_feature_tuples_for_component(cp.co_occurrence_dict, "L5H5", k=20)

In [None]:
cp.get_top_k_feature_tuples(k=20)

In [None]:
# Feature family studied below: every feature recorded for head L5_H5 (layer-5 z-SAE), minus -1 entries.
component_name = 'L5_H5'
layer = 5

features = [x for x in list(set(cp.circuit_hypergraph[component_name]['features'])) if x!=-1]
print(features)

In [None]:
from openai_utils import get_response

def get_interpretation_old(examples_clean_text, top_tokens):
    """Ask the LLM for a feature interpretation (legacy path).

    Builds the opening prompt from the max-activating example texts and the top
    promoted tokens, sends it through ``get_response`` and returns the reply.
    """
    prompt = get_opening_prompt(examples_clean_text, top_tokens)
    reply = get_response(prompt)
    return reply

def feature_scores_for_component_cluster(component_name: str, layer: int, *, n_examples: int = 1024 * 4,
                                         batch_size: int = 128, k: int = 10, act_strength: float = 5):
    """Score a component's whole feature family on OpenWebText and get its top promoted tokens.

    Args:
        component_name: Key into ``cp.circuit_hypergraph`` (e.g. ``"L5_H5"``). Names starting with
            "L" are attention heads and use the z-SAE; anything else uses the MLP transcoder,
            matching the encoder choice in the IOI activation cells.
        layer: Layer index used to pick the encoder.
        n_examples: Number of leading rows of ``owt_tokens_torch`` to score.
        batch_size: Batch size passed to ``get_feature_scores``.
        k: Number of top tokens to return.
        act_strength: Activation strength passed to ``get_top_k_tokens``.

    Returns:
        ``(feature_scores, top_tokens, top_logits)``. Feature columns of ``feature_scores`` follow
        the order of the de-duplicated feature list built here.
    """
    # Distinct features for this component, dropping -1 placeholder entries.
    features = [x for x in list(set(cp.circuit_hypergraph[component_name]['features'])) if x != -1]

    # Previously the z-SAE was used unconditionally, which is wrong for MLP components.
    encoder = z_saes[layer] if component_name.startswith("L") else transcoders[layer]

    feature_scores = get_feature_scores(model, encoder, owt_tokens_torch[:n_examples], features,
                                        batch_size=batch_size)
    top_tokens, top_logits = get_top_k_tokens(model, [(encoder, features)], k=k, act_strength=act_strength)

    return feature_scores, top_tokens, top_logits

In [None]:
# Get the interpretation for the component's whole feature family.

# feature_scores_for_component_cluster already returns the family's top tokens/logits; the extra
# get_top_k_tokens call that used to follow had identical arguments, so it is not repeated.
feature_scores, top_tokens, top_logits = feature_scores_for_component_cluster(component_name, layer)

# Top-25 OWT examples by activation summed over every feature in the family (text only, no HTML).
example_html, examples_clean_text = display_top_k_activating_examples_sum(
    model, feature_scores, owt_tokens_torch[:1024*4],
    feature_indices=list(range(feature_scores.shape[1])), k=25, display_html=False,
)

interpretation = get_interpretation_old(examples_clean_text, top_tokens)

print(interpretation)

In [None]:
# Get some IOI examples where the feature family activates (max activation across the family is the ground truth).

# Build the IOI prompts (template 1) with the correct answer appended.
n_prompts = 100
dataset_prompts = gen_templated_prompts(template_idex=1, N=n_prompts)
dataset_prompts = [x['text'] + x['correct'] for x in dataset_prompts]
dataset_tokens = model.to_tokens(dataset_prompts)

# Run the model and keep only this layer's per-head attention outputs ("z").
_, cache = model.run_with_cache(dataset_tokens)
z = cache["z", layer]
b, s, n, d = z.shape
del cache  # free the full activation cache
z = einops.rearrange(z, "b s n d -> (b s) (n d)")

# Apply the relevant SAE or transcoder to the activations.
# NOTE(review): the transcoder branch also receives attention `z`; an MLP component presumably
# needs the MLP input instead — confirm before using this with "MLP…" component names.
if component_name.startswith("L"):
    encoder = z_saes[layer]
else:
    encoder = transcoders[layer]

z_hidden = encoder.encode(z)
z_hidden = einops.rearrange(z_hidden, "(b s) h -> b s h", s=s)

# Keep only this component's feature family (last dimension).
feature_indices = [x for x in list(set(cp.circuit_hypergraph[component_name]['features'])) if x!=-1]
print(feature_indices)
z_hidden = z_hidden[:, :, feature_indices]

# Max over the family at every (prompt, position) -> shape (prompt, position).
z_hidden = z_hidden.max(dim=2).values

# Zero position 0 (presumably the BOS token added by to_tokens) so it never counts as active.
z_hidden[:, 0] = 0

# First 3 *distinct* prompts with any positive activation. np.where(...)[0] repeats a prompt index
# once per active position, so slicing it could select the same prompt three times; torch ops also
# avoid converting a possibly-GPU tensor to NumPy.
non_zero_indices = (z_hidden > 0.0).any(dim=1).nonzero().flatten().tolist()[:3]
print(non_zero_indices)

# Keep each of those prompts plus the prompt after it, skipping duplicates and rows past the end.
indices_to_keep = []
for j in non_zero_indices:
    for row in (j, j + 1):
        if row < z_hidden.shape[0] and row not in indices_to_keep:
            indices_to_keep.append(row)

print(indices_to_keep)

# Keep z_hidden rows; row r now belongs to dataset_prompts[indices_to_keep[r]].
z_hidden = z_hidden[indices_to_keep]

z_hidden

In [None]:
# Set up new LLM interpreter given interpretation
from jinja2 import Template
from typing import List

def follow_up_activation_prediction_prompt(
    interpretation: str, sentence: str, final_word: "str | None" = None
):
    """Render the LLM prompt asking whether the neuron fires on the end of ``sentence``.

    Args:
        interpretation: Natural-language interpretation of the neuron/feature under test.
        sentence: Text to analyse; the activation is judged on its final word.
        final_word: Optional override for the word shown as "The final word to analyze".
            Defaults to the last whitespace-separated word of ``sentence``. Passing the decoded
            final *token* keeps the prompt aligned with a token-level ground truth when that token
            is a sub-word or punctuation rather than a whole word.

    Returns:
        The rendered prompt string.
    """
    if final_word is None:
        words = sentence.split()
        # An empty/whitespace-only sentence used to raise IndexError here.
        final_word = words[-1] if words else ""
    last_word_in_sentence = final_word

    template = Template(
        """
{# You are an AI researcher continuing an important investigation into a certain neuron in a language model. Your task is to predict whether this neuron will activate on the final word of a given sentence based on a previously provided interpretation of the neuron's behavior. Here's how you will complete this task: #}

You are an AI researcher continuing an important investigation into a certain neuron in a language model. This language model is trained to predict the text that will follow a given input. Your task is to predict whether this neuron will have a zero or non-zero activation on the final word of a given sentence based on the provided interpretation of the neuron's behavior.

INTERPRETATION:
{{interpretation}}

INPUT:
The sentence to analyze is:
=================================================
{{sentence}}
=================================================

The final word to analyze is: {{last_word_in_sentence}}

OUTPUT:
Based on the provided interpretation, analyze the sentence and describe your reasoning in two sentences. Then, predict whether the neuron will have a zero or non-zero activation on the final word of the sentence. Provide your answer in the following format:
[ANALYSIS]: <two sentences of analysis>
[ACTIVATION]: zero or non-zero

Guidelines:
- Carefully consider the interpretation and apply it to the given sentence.
- Your analysis should be concise and relevant to the provided interpretation.
- Do not be too rigid; if the interpretation provides an example of an activating token don't assume that specific token always has to be present - follow the pattern instead.
- For instance, if the interpretation suggests the neuron activates on names and provides an example name 'David', whilst suggesting other names can activate it, don't just predict zero if David isn't present.
- Your prediction should be either "zero" or "non-zero".

EXAMPLE:
[ANALYSIS]: The final word in the sentence fits the pattern described in the interpretation. The context provided in the sentence suggests a non-zero activation.
[ACTIVATION]: non-zero
"""
    )

    return template.render(
        {"interpretation": interpretation, "sentence": sentence, "last_word_in_sentence": last_word_in_sentence}
    )


In [None]:
# For each kept IOI prompt, ask the LLM (via `interpretation`) whether the feature family fires at
# selected token positions, then score the predictions against the SAE ground truth.
accuracies, f1s = [], []
threshold = 1.0  # an activation counts as "non-zero" only above this value
for row in range(z_hidden.shape[0]):
    # z_hidden was filtered to `indices_to_keep`, so row `row` belongs to prompt
    # indices_to_keep[row]; indexing dataset_prompts by `row` paired sentences with the wrong activations.
    sentence = dataset_prompts[indices_to_keep[row]]
    sentence_tokens = model.to_tokens(sentence).squeeze()
    activations = z_hidden[row]
    n_tokens = len(sentence_tokens)

    pred_dict = {}

    active = (activations > threshold).nonzero().flatten().tolist()
    print(active)
    # Probe each active position and its neighbours. Position 0 is skipped (it was zeroed when
    # z_hidden was built), and the set avoids duplicate LLM calls when neighbourhoods overlap.
    positions = sorted({p for a in active for p in (a - 1, a, a + 1) if 0 < p < n_tokens})

    if not positions:
        # Nothing fires: randomly sample 3 positions as negative probes.
        positions = np.random.choice(n_tokens, 3, replace=False)

    print(positions)

    for pos in tqdm(positions):
        prefix = model.to_string(sentence_tokens[:pos + 1])
        ground_truth = activations[pos].item()

        # Get the prediction
        prompt = follow_up_activation_prediction_prompt(interpretation, prefix)
        prediction = get_response(prompt)
        print(prediction)
        print(prefix)
        print(ground_truth)
        print()
        pred_dict[pos] = {"prediction": prediction, "sentence": prefix, "ground_truth": ground_truth}

    for record in pred_dict.values():
        # Text after the last "[ACTIVATION]:" marker, tolerant of case, markdown emphasis and
        # trailing punctuation, so e.g. "Non-zero." is no longer scored as zero.
        answer = record["prediction"].split("[ACTIVATION]:")[-1].strip().lower()
        answer = answer.splitlines()[0].strip(" .*'\"") if answer else ""
        record["prediction_numeric"] = 1.0 if answer in ("non-zero", "nonzero") else 0.0
        record["ground_truth_numeric"] = 1.0 if record["ground_truth"] > threshold else 0.0

    # Confusion counts for this prompt.
    pairs = [(r["prediction_numeric"], r["ground_truth_numeric"]) for r in pred_dict.values()]
    correct = sum(pred == truth for pred, truth in pairs)
    tp = sum(pred == 1 and truth == 1 for pred, truth in pairs)
    fp = sum(pred == 1 and truth == 0 for pred, truth in pairs)
    fn = sum(pred == 0 and truth == 1 for pred, truth in pairs)

    print(f"TP: {tp}, FP: {fp}, FN: {fn}")

    # Prompts where every probe is a true negative are skipped (F1 is undefined there), so they do
    # not contribute to either mean.
    if tp + fp + fn > 0:
        accuracy = correct / len(pred_dict)
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

        accuracies.append(accuracy)
        f1s.append(f1)

In [None]:
# Disabled alternative: query Azure OpenAI directly for the interpretation (kept for reference).
# from openai import AzureOpenAI

# config = yaml.safe_load(open("config.yaml"))
# azure_client = AzureOpenAI(
#     azure_endpoint=config["base_url"],
#     api_key=config["azure_api_key"],
#     api_version=config["api_version"],
# )

# opening_prompt = get_opening_prompt(examples_clean_text, top_tokens)
# print(opening_prompt)
# messages = [{"role": "user", "content": opening_prompt}]
# response = azure_client.chat.completions.create(
#     model="gpt4_large",
#     messages=messages,
# )
# interpretation = f"{response.choices[0].message.content}"

# Hard-coded interpretation. Running this cell overwrites whatever `interpretation` an earlier
# cell produced, and the evaluation loop reads that variable.
interpretation = """ 
[EXPLANATION]: The neuron activates on names, especially in possessive or direct reference contexts, and activates in texts involving repeated mentions of specific names or entities.
"""

In [None]:
# Report the average per-prompt accuracy and F1 gathered by the evaluation loop.
print("Mean Accuracy:", np.mean(accuracies))
print("Mean F1 Score:", np.mean(f1s))

In [None]:
# Per-feature baseline: interpret one member of the family on its own.
# Recomputing the scores is commented out, so `feature_scores` from the family cell is reused.
# Its columns presumably follow `features`, since both lists come from the same set expression.
# # For each individual feature in our feature family, rerun the autointerp and get the scores for that individual feature
# features = [x for x in list(set(cp.circuit_hypergraph[component_name]['features'])) if x!=-1]

# feature_scores = get_feature_scores(model, z_saes[layer], owt_tokens_torch[:1024*4], feature_indices=[x for x in range(len(features))], batch_size=128)

# Index into `features` / the feature axis of `feature_scores` for the feature to interpret.
feat_idx = 1

examples_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feat_idx, :], owt_tokens_torch[:1024*4], k=10, display_html=False)

top_tokens, top_logits = get_top_k_tokens(model, [(z_saes[layer], [features[feat_idx]])], k=10, act_strength=5)

interpretation = get_interpretation(examples_clean_text, top_tokens)

print(interpretation)

In [None]:
# Get some IOI examples where the single feature `features[feat_idx]` activates (its activation is the ground truth).

# Build the IOI prompts (template 1) with the correct answer appended.
n_prompts = 100
dataset_prompts = gen_templated_prompts(template_idex=1, N=n_prompts)
dataset_prompts = [x['text'] + x['correct'] for x in dataset_prompts]
dataset_tokens = model.to_tokens(dataset_prompts)

# Run the model and keep only this layer's per-head attention outputs ("z").
_, cache = model.run_with_cache(dataset_tokens)
z = cache["z", layer]
b, s, n, d = z.shape
del cache  # free the full activation cache
z = einops.rearrange(z, "b s n d -> (b s) (n d)")

# Apply the relevant SAE or transcoder to the activations.
# NOTE(review): the transcoder branch also receives attention `z`; an MLP component presumably
# needs the MLP input instead — confirm before using this with "MLP…" component names.
if component_name.startswith("L"):
    encoder = z_saes[layer]
else:
    encoder = transcoders[layer]

z_hidden = encoder.encode(z)
z_hidden = einops.rearrange(z_hidden, "(b s) h -> b s h", s=s)

# Keep only the feature under test; the max over the length-1 axis just drops it -> (prompt, position).
feature_indices = [features[feat_idx]]
z_hidden = z_hidden[:, :, feature_indices]
z_hidden = z_hidden.max(dim=2).values

# First 3 *distinct* prompts with any positive activation. np.where(...)[0] repeats a prompt index
# once per active position, so slicing it could select the same prompt three times; torch ops also
# avoid converting a possibly-GPU tensor to NumPy.
non_zero_indices = (z_hidden > 0.0).any(dim=1).nonzero().flatten().tolist()[:3]
print(non_zero_indices)

# Keep each of those prompts plus the prompt after it, skipping duplicates and rows past the end.
indices_to_keep = []
for j in non_zero_indices:
    for row in (j, j + 1):
        if row < z_hidden.shape[0] and row not in indices_to_keep:
            indices_to_keep.append(row)

print(indices_to_keep)

# Keep z_hidden rows; row r now belongs to dataset_prompts[indices_to_keep[r]].
z_hidden = z_hidden[indices_to_keep]

z_hidden.shape

In [None]:
from openai_utils import get_response

# For each kept IOI prompt, ask the LLM (via `interpretation`) whether the single feature fires at
# selected token positions, then score the predictions against the SAE ground truth.
accuracies, f1s = [], []
threshold = 0.0  # any positive activation counts as "non-zero"
for row in range(z_hidden.shape[0]):
    # z_hidden was filtered to `indices_to_keep`, so row `row` belongs to prompt
    # indices_to_keep[row]; indexing dataset_prompts by `row` paired sentences with the wrong activations.
    sentence = dataset_prompts[indices_to_keep[row]]
    sentence_tokens = model.to_tokens(sentence).squeeze()
    activations = z_hidden[row]
    n_tokens = len(sentence_tokens)

    pred_dict = {}

    active = (activations > threshold).nonzero().flatten().tolist()
    # Probe each active position and its neighbours; the set avoids duplicate LLM calls when
    # neighbourhoods overlap.
    positions = sorted({p for a in active for p in (a - 1, a, a + 1) if 0 <= p < n_tokens})
    print(positions)

    if not positions:
        # Nothing fires: randomly sample 3 positions as negative probes.
        positions = np.random.choice(n_tokens, 3, replace=False)

    for pos in tqdm(positions):
        prefix = model.to_string(sentence_tokens[:pos + 1])
        ground_truth = activations[pos].item()

        # Get the prediction
        prompt = follow_up_activation_prediction_prompt(interpretation, prefix)
        prediction = get_response(prompt)
        pred_dict[pos] = {"prediction": prediction, "sentence": prefix, "ground_truth": ground_truth}

    for record in pred_dict.values():
        # Text after the last "[ACTIVATION]:" marker, tolerant of case, markdown emphasis and
        # trailing punctuation, so e.g. "Non-zero." is no longer scored as zero.
        answer = record["prediction"].split("[ACTIVATION]:")[-1].strip().lower()
        answer = answer.splitlines()[0].strip(" .*'\"") if answer else ""
        record["prediction_numeric"] = 1.0 if answer in ("non-zero", "nonzero") else 0.0
        record["ground_truth_numeric"] = 1.0 if record["ground_truth"] > threshold else 0.0

    # Confusion counts for this prompt.
    pairs = [(r["prediction_numeric"], r["ground_truth_numeric"]) for r in pred_dict.values()]
    correct = sum(pred == truth for pred, truth in pairs)
    tp = sum(pred == 1 and truth == 1 for pred, truth in pairs)
    fp = sum(pred == 1 and truth == 0 for pred, truth in pairs)
    fn = sum(pred == 0 and truth == 1 for pred, truth in pairs)

    print(f"TP: {tp}, FP: {fp}, FN: {fn}")

    print(pred_dict)

    # Prompts where every probe is a true negative are skipped (F1 is undefined there), so they do
    # not contribute to either mean.
    if tp + fp + fn > 0:
        accuracy = correct / len(pred_dict)
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

        accuracies.append(accuracy)
        f1s.append(f1)

In [None]:
# Average the per-prompt accuracy and F1 collected above.
print("Mean Accuracy:", np.mean(accuracies))
print("Mean F1 Score:", np.mean(f1s))

In [None]:
# Manually recorded family-vs-individual IOI results per attention head. L8H6 has no results yet.
# The next cell redefines this dict (with different L5H5 "individual" numbers) before plotting.
results_dict = {
    "L5H5": {"families": {"f1": 1.0, "accuracy": 1.0}, "individual": {"f1": 0.0, "accuracy": 0.0}},
    "L8H6": {},
    "L0H1": {"families": {"f1": 0.59, "accuracy": 0.81}, "individual": {"f1": 0.05, "accuracy": 0.75}},
    "L2H2": {"families": {"f1": 0.75, "accuracy": 0.91}, "individual": {"f1": 0.06, "accuracy": 0.49}},
}

In [None]:
import plotly.graph_objects as go

# Results that are actually plotted: family vs. individual-feature autointerp scores per head.
results_dict = {
    "L5H5": {"families": {"f1": 1.0, "accuracy": 1.0}, "individual": {"f1": 0.02, "accuracy": 0.4}},
    "L0H1": {"families": {"f1": 0.59, "accuracy": 0.81}, "individual": {"f1": 0.05, "accuracy": 0.75}},
    "L2H2": {"families": {"f1": 0.75, "accuracy": 0.91}, "individual": {"f1": 0.06, "accuracy": 0.49}},
}

# One list per bar series, in head order; a head missing a metric group gets None (an empty bar).
labels = list(results_dict.keys())
families_f1 = [results_dict[key]["families"]["f1"] if "families" in results_dict[key] else None for key in labels]
families_accuracy = [results_dict[key]["families"]["accuracy"] if "families" in results_dict[key] else None for key in labels]
individual_f1 = [results_dict[key]["individual"]["f1"] if "individual" in results_dict[key] else None for key in labels]
individual_accuracy = [results_dict[key]["individual"]["accuracy"] if "individual" in results_dict[key] else None for key in labels]

# Grouped bar chart: blues for the family metrics, greens for the individual-feature metrics.
fig = go.Figure(data=[
    go.Bar(name='Families F1', x=labels, y=families_f1, marker_color='#ADD8E6'),  # Light blue
    go.Bar(name='Families Accuracy', x=labels, y=families_accuracy, marker_color='#00008B'),  # Dark blue
    go.Bar(name='Individual F1', x=labels, y=individual_f1, marker_color='#90EE90'),  # Light green
    go.Bar(name='Individual Accuracy', x=labels, y=individual_accuracy, marker_color='#006400')  # Dark green
])

fig.update_layout(
    title='Family vs. individual feature autointerp - IOI performance',
    xaxis_title='Attention Head',
    yaxis_title='Score',
    barmode='group',
    legend_title_text='Metrics',
    font=dict(size=14),
    template='plotly_white',
    width=900
)

fig.show()

## Feature co-occurrence

In [198]:
# 500 IOI prompts (template 1) with the correct answer appended.
# NOTE: this overwrites the global `dataset_prompts` used by the evaluation cells.
dataset_prompts = gen_templated_prompts(template_idex=1, N=500)
dataset_prompts = [x['text'] + x['correct'] for x in dataset_prompts]

names = ['David', 'Elizabeth', 'Paul', 'Sarah']

# Count, for each name, how many prompts contain it. This is a substring test: a prompt with the
# name twice counts once, and e.g. "Paul" would also match "Paula".
name_counts = {name: 0 for name in names}

for prompt in dataset_prompts:
    for name in names:
        if name in prompt:
            name_counts[name] += 1

name_counts

{'David': 22, 'Elizabeth': 18, 'Paul': 22, 'Sarah': 20}

In [221]:
import numpy as np
import plotly.graph_objects as go
from scipy.stats import linregress

# Hard-coded per-name counts, each normalised to sum to 1.
# NOTE(review): these dataset counts differ from the name_counts output above (22/18/22/20), and
# the name order isn't recorded — presumably from another run; confirm which point is which name.
# The co-occurrence counts (13, 13, 11, 11) presumably come from the MLP0/L2H2 tuple counts listed below.
dataset_occurrence = np.array([23, 22, 18, 19])
dataset_occurrence = np.divide(dataset_occurrence, np.sum(dataset_occurrence))
co_occurrence = np.array([13, 13, 11, 11]) / np.sum([13, 13, 11, 11])

# Perform linear regression (x = co-occurrence share, y = dataset share)
slope, intercept, r_value, p_value, std_err = linregress(co_occurrence, dataset_occurrence)

# Create the line of best fit
line_of_best_fit = slope * co_occurrence + intercept

fig = go.Figure()

fig.add_trace(go.Scatter(x=co_occurrence, y=dataset_occurrence,
                         mode='markers',
                         name='Names',
                         marker=dict(color='blue', size=10)))

fig.add_trace(go.Scatter(x=co_occurrence, y=line_of_best_fit,
                         mode='lines',
                         name='LS Fit',
                         line=dict(color='red', width=2)))

fig.update_layout(
    title='Co-occurrence of name feature (MLP0 + L2H2) vs. Dataset Occurrence',
    xaxis_title='Co-occurrence',
    yaxis_title='Dataset Occurrence',
    font=dict(size=14),
    template='plotly_white',
    width=800,
    height=600
)

fig.show()

In [None]:
import pandas as pd

# Word-frequency table used to compare how common the names are (has a `word` column, queried below).
freq = pd.read_csv("data/unigram_freq.csv")
freq.head()

In [None]:
# Row(s) for "elizabeth"; this is an exact match, so the lookup uses the lowercase form.
freq[freq.word == 'elizabeth']

In [199]:
# Top-20 co-occurring feature tuples that involve attention head L2H2.
get_top_k_feature_tuples_for_component(cp.co_occurrence_dict, "L2H2", k=20)

defaultdict(dict,
            {(('mlp_feature', 0), ('attn_head', 2, 2)): {(4522, 24166): 13,
              (20546, 24166): 13,
              (20323, 24166): 13,
              (630, 24166): 13,
              (10461, 14186): 11,
              (17363, 14186): 11,
              (17845, 14186): 11,
              (7734, 14186): 11,
              (14148, 14186): 10,
              (23507, 24166): 10,
              (5348, 24166): 9,
              (12965, 14186): 8,
              (156, 24166): 7,
              (3201, 6510): 7,
              (11530, 6510): 7,
              (5245, 6510): 7,
              (18880, 4268): 7,
              (12385, 4268): 7,
              (20546, 4268): 7,
              (10461, 4268): 7}})

In [208]:
# Co-occurring pair: MLP0 transcoder feature 12965 together with L2H2 z-SAE feature 14186.
encoder_feature_pairs = [(transcoders[0], [12965]), (z_saes[2], [14186])]
# Top-20 tokens (and logits) for the pair considered jointly.
top_tokens, top_logits = get_top_k_tokens(model, encoder_feature_pairs, k=20)

pretty_print_tokens_logits(top_tokens, top_logits)

╒════════════════╤═════════╕
│ Token          │   Logit │
╞════════════════╪═════════╡
│ [34mpn[0m             │  [32m5.8298[0m │
├────────────────┼─────────┤
│ [34m NEC[0m           │  [32m5.7384[0m │
├────────────────┼─────────┤
│ [34m pleasure[0m      │  [32m5.6151[0m │
├────────────────┼─────────┤
│ [34mULT[0m            │  [32m5.6104[0m │
├────────────────┼─────────┤
│ [34m Klu[0m           │  [32m5.3963[0m │
├────────────────┼─────────┤
│ [34m TL[0m            │  [32m5.3659[0m │
├────────────────┼─────────┤
│ [34mTL[0m             │  [32m5.1752[0m │
├────────────────┼─────────┤
│ [34mWhit[0m           │  [32m5.1358[0m │
├────────────────┼─────────┤
│ [34m Parliamentary[0m │  [32m5.0727[0m │
├────────────────┼─────────┤
│ [34m amusement[0m     │  [32m5.0423[0m │
├────────────────┼─────────┤
│ [34m EV[0m            │  [32m5.03[0m   │
├────────────────┼─────────┤
│ [34m microscope[0m    │  [32m4.9409[0m │
├────────────────┼─────────┤
│

In [209]:
# Activations of both features of the pair over the first 1024*2 OWT rows.
feature_scores = get_feature_scores_across_layers(model, encoder_feature_pairs, owt_tokens_torch[:1024*2], batch_size=128)

SparseTranscoder


100%|██████████| 16/16 [00:10<00:00,  1.46it/s]


ZSAE


100%|██████████| 16/16 [00:12<00:00,  1.25it/s]


In [211]:
# (rows, features, positions) — one feature column per encoder/feature in the pair.
feature_scores.shape

(2048, 2, 128)

In [213]:
# Top-20 OWT examples by activation summed over both features of the pair. Tokens are sliced to the
# rows feature_scores was computed on (owt_tokens_torch[:1024*2] above), not the mismatched [:1024*8].
examples_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens_torch[:feature_scores.shape[0]], feature_indices=[0, 1], k=20, display_html=True)

In [None]:
# Build the interpretation prompt from the pair's top examples and joint top tokens.
opening_prompt = get_opening_prompt(examples_clean_text, top_tokens)

# interpret
interpretation = get_response(opening_prompt)
print(interpretation)

In [None]:
# Hard-coded interpretation for the co-occurring pair; overwrites the LLM output from the previous cell if run.
interpretation = "The neuron is activated by the use of many different common names. These names include 'David', as well as other names and words in the context of preceding names earlier in the sentences."

In [None]:
# Get some IOI examples where L2H2 z-SAE feature 24166 activates strongly (its activation is the ground truth).
layer = 2
feature = 24166
component_name = 'L2_H2'

# Build the IOI prompts (template 1) with the correct answer appended.
n_prompts = 100
dataset_prompts = gen_templated_prompts(template_idex=1, N=n_prompts)
dataset_prompts = [x['text'] + x['correct'] for x in dataset_prompts]
dataset_tokens = model.to_tokens(dataset_prompts)

# Run the model and keep only this layer's per-head attention outputs ("z").
_, cache = model.run_with_cache(dataset_tokens)
z = cache["z", layer]
b, s, n, d = z.shape
del cache  # free the full activation cache
z = einops.rearrange(z, "b s n d -> (b s) (n d)")

# Apply the relevant SAE or transcoder to the activations.
# NOTE(review): the transcoder branch also receives attention `z`; an MLP component presumably
# needs the MLP input instead — confirm before using this with "MLP…" component names.
if component_name.startswith("L"):
    encoder = z_saes[layer]
else:
    encoder = transcoders[layer]

z_hidden = encoder.encode(z)
z_hidden = einops.rearrange(z_hidden, "(b s) h -> b s h", s=s)

# Keep only the feature under test; the max over the length-1 axis just drops it -> (prompt, position).
feature_indices = [feature]
z_hidden = z_hidden[:, :, feature_indices]
z_hidden = z_hidden.max(dim=2).values

# First 5 distinct prompts whose activation exceeds 1.0 anywhere, in ascending prompt order.
# list(set(...)) gave a hash-dependent order, so *which* 5 prompts were kept was ill-defined;
# torch ops also avoid converting a possibly-GPU tensor to NumPy.
non_zero_indices = (z_hidden > 1.0).any(dim=1).nonzero().flatten().tolist()[:5]
print(non_zero_indices)

# Keep just those prompts (no neighbouring row in this experiment).
indices_to_keep = list(non_zero_indices)

print(indices_to_keep)

# Keep z_hidden rows; row r now belongs to dataset_prompts[indices_to_keep[r]].
z_hidden = z_hidden[indices_to_keep]

z_hidden

In [None]:
# For each kept IOI prompt, ask the LLM (via the current `interpretation`) whether the feature fires
# at selected token positions, then score the predictions against the SAE ground truth.
accuracies, f1s = [], []
threshold = 0.5  # positions whose activation exceeds this are probed (with their neighbours)
gt_threshold = 0.0  # ground truth counts any positive activation — a lower bar than `threshold`
for row in range(z_hidden.shape[0]):
    # z_hidden was filtered to `indices_to_keep`, so row `row` belongs to prompt
    # indices_to_keep[row]; indexing dataset_prompts by `row` paired sentences with the wrong activations.
    sentence = dataset_prompts[indices_to_keep[row]]
    sentence_tokens = model.to_tokens(sentence).squeeze()
    activations = z_hidden[row]
    n_tokens = len(sentence_tokens)

    pred_dict = {}

    active = (activations > threshold).nonzero().flatten().tolist()
    # Probe each active position and its neighbours; the set avoids duplicate LLM calls when
    # neighbourhoods overlap.
    positions = sorted({p for a in active for p in (a - 1, a, a + 1) if 0 <= p < n_tokens})
    print(positions)

    if not positions:
        # Nothing fires: randomly sample 3 positions as negative probes.
        positions = np.random.choice(n_tokens, 3, replace=False)

    for pos in tqdm(positions):
        prefix = model.to_string(sentence_tokens[:pos + 1])
        ground_truth = activations[pos].item()

        # Get the prediction
        prompt = follow_up_activation_prediction_prompt(interpretation, prefix)
        prediction = get_response(prompt)
        pred_dict[pos] = {"prediction": prediction, "sentence": prefix, "ground_truth": ground_truth}
        print(prediction)
        print(prefix)
        print(ground_truth)

    for record in pred_dict.values():
        # Text after the last "[ACTIVATION]:" marker, tolerant of case, markdown emphasis and
        # trailing punctuation, so e.g. "Non-zero." is no longer scored as zero.
        answer = record["prediction"].split("[ACTIVATION]:")[-1].strip().lower()
        answer = answer.splitlines()[0].strip(" .*'\"") if answer else ""
        record["prediction_numeric"] = 1.0 if answer in ("non-zero", "nonzero") else 0.0
        record["ground_truth_numeric"] = 1.0 if record["ground_truth"] > gt_threshold else 0.0

    # Confusion counts for this prompt.
    pairs = [(r["prediction_numeric"], r["ground_truth_numeric"]) for r in pred_dict.values()]
    correct = sum(pred == truth for pred, truth in pairs)
    tp = sum(pred == 1 and truth == 1 for pred, truth in pairs)
    fp = sum(pred == 1 and truth == 0 for pred, truth in pairs)
    fn = sum(pred == 0 and truth == 1 for pred, truth in pairs)

    print(f"TP: {tp}, FP: {fp}, FN: {fn}")

    # Prompts where every probe is a true negative are skipped (F1 is undefined there), so they do
    # not contribute to either mean.
    if tp + fp + fn > 0:
        accuracy = correct / len(pred_dict)
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

        accuracies.append(accuracy)
        f1s.append(f1)

# Means
print(f"Mean Accuracy: {np.mean(accuracies)}")
print(f"Mean F1 Score: {np.mean(f1s)}")

In [None]:
# Repeat with just the attention-head component (L2H2 z-SAE feature 24166).
encoder_feature_pairs = [(z_saes[2], [24166])]
top_tokens, top_logits = get_top_k_tokens(model, encoder_feature_pairs, k=20)

# Get feature scores over the first 8*1024 OWT rows: (rows, features=1, positions).
feature_scores = get_feature_scores_across_layers(model, encoder_feature_pairs, owt_tokens_torch[:8*1024], batch_size=128)

# display_top_k_activating_examples receives a single feature's (rows, positions) slice everywhere
# else in this notebook, so select feature 0 rather than passing the 3-D tensor.
examples_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, 0, :], owt_tokens_torch[:8*1024], k=10, display_html=True)

In [None]:
# Baseline interpretation of feature 24166 alone, kept separate from `interpretation` for comparison.
opening_prompt = get_opening_prompt(examples_clean_text, top_tokens)
base_interp = get_response(opening_prompt)
print(base_interp)

In [None]:
# For each kept IOI prompt, ask the LLM (via the single-feature `base_interp`) whether the feature
# fires at selected token positions, then score the predictions against the SAE ground truth.
accuracies, f1s = [], []
threshold = 0.0  # any positive activation counts as "non-zero"
for row in range(z_hidden.shape[0]):
    # z_hidden was filtered to `indices_to_keep`, so row `row` belongs to prompt
    # indices_to_keep[row]; indexing dataset_prompts by `row` paired sentences with the wrong activations.
    sentence = dataset_prompts[indices_to_keep[row]]
    sentence_tokens = model.to_tokens(sentence).squeeze()
    activations = z_hidden[row]
    n_tokens = len(sentence_tokens)

    pred_dict = {}

    active = (activations > threshold).nonzero().flatten().tolist()
    # Probe each active position and its neighbours; the set avoids duplicate LLM calls when
    # neighbourhoods overlap.
    positions = sorted({p for a in active for p in (a - 1, a, a + 1) if 0 <= p < n_tokens})
    print(positions)

    if not positions:
        # Nothing fires: randomly sample 3 positions as negative probes.
        positions = np.random.choice(n_tokens, 3, replace=False)

    for pos in tqdm(positions):
        prefix = model.to_string(sentence_tokens[:pos + 1])
        ground_truth = activations[pos].item()

        # Get the prediction
        prompt = follow_up_activation_prediction_prompt(base_interp, prefix)
        prediction = get_response(prompt)
        pred_dict[pos] = {"prediction": prediction, "sentence": prefix, "ground_truth": ground_truth}

        print(prediction)
        print(prefix)
        print(ground_truth)

    for record in pred_dict.values():
        # Text after the last "[ACTIVATION]:" marker, tolerant of case, markdown emphasis and
        # trailing punctuation, so e.g. "Non-zero." is no longer scored as zero.
        answer = record["prediction"].split("[ACTIVATION]:")[-1].strip().lower()
        answer = answer.splitlines()[0].strip(" .*'\"") if answer else ""
        record["prediction_numeric"] = 1.0 if answer in ("non-zero", "nonzero") else 0.0
        record["ground_truth_numeric"] = 1.0 if record["ground_truth"] > threshold else 0.0

    # Confusion counts for this prompt.
    pairs = [(r["prediction_numeric"], r["ground_truth_numeric"]) for r in pred_dict.values()]
    correct = sum(pred == truth for pred, truth in pairs)
    tp = sum(pred == 1 and truth == 1 for pred, truth in pairs)
    fp = sum(pred == 1 and truth == 0 for pred, truth in pairs)
    fn = sum(pred == 0 and truth == 1 for pred, truth in pairs)

    print(f"TP: {tp}, FP: {fp}, FN: {fn}")

    # Prompts where every probe is a true negative are skipped (F1 is undefined there), so they do
    # not contribute to either mean.
    if tp + fp + fn > 0:
        accuracy = correct / len(pred_dict)
        precision = tp / (tp + fp) if tp + fp > 0 else 0.0
        recall = tp / (tp + fn) if tp + fn > 0 else 0.0
        f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0

        accuracies.append(accuracy)
        f1s.append(f1)

# Means
print(f"Mean Accuracy: {np.mean(accuracies)}")
print(f"Mean F1 Score: {np.mean(f1s)}")

In [None]:
# Manually recorded co-occurrence vs. base results, keyed "<feature>_L<layer>H<head>_<feature>_MLP<layer>".
# The next cell redefines this dict (some base F1s nudged from 0.0 to 0.01, one accuracy changed)
# before plotting.
results_dict = {
    "15404_L9H9_20546_MLP0": {
        "co_occurrence": {
            "accuracies": 0.68,
            "f1s": 0.54,
        },
        "base": {
            "accuracies": 0.6,
            "f1s": 0.0,
        }
    },
    "16513_L8H6_10461_MLP0": {
        "co_occurrence": {
            "accuracies": 0.82,
            "f1s": 0.68,
        },
        "base": {
            "accuracies": 0.7,
            "f1s": 0.5,
        }
    },
    "24166_L2H2_4522_MLP0": {
        "co_occurrence": {
            "accuracies": 0.52,
            "f1s": 0.55,
        },
        "base": {
            "accuracies": 0.6,
            "f1s": 0.0
        }
    }
}

In [None]:
import plotly.graph_objects as go

# Results actually plotted below: mean accuracy / F1 for the co-occurrence interpretation vs. the
# single-feature ("base") interpretation. Keys are "<feature>_L<layer>H<head>_<feature>_MLP<layer>".
# NOTE(review): base F1s of 0.01 differ from the 0.0 recorded in the previous cell — presumably so
# the bars stay visible; confirm.
results_dict = {
    "15404_L9H9_20546_MLP0": {
        "co_occurrence": {
            "accuracies": 0.68,
            "f1s": 0.54,
        },
        "base": {
            "accuracies": 0.6,
            "f1s": 0.01,
        }
    },
    "16513_L8H6_10461_MLP0": {
        "co_occurrence": {
            "accuracies": 0.82,
            "f1s": 0.68,
        },
        "base": {
            "accuracies": 0.7,
            "f1s": 0.5,
        }
    },
    "24166_L2H2_4522_MLP0": {
        "co_occurrence": {
            "accuracies": 0.62,
            "f1s": 0.55,
        },
        "base": {
            "accuracies": 0.6,
            "f1s": 0.01
        }
    }
}

def create_label(key):
    """Turn a results key like "15404_L9H9_20546_MLP0" into a readable bar label.

    The key has the form "<feature1>_L<layer>H<head>_<feature2>_MLP<mlp_layer>", and the label
    returned is "<feature1> & <feature2> (L<layer>H<head> MLP<mlp_layer>)".
    """
    feature1, layer_head, feature2, mlp = key.split("_")
    layer, head = layer_head[1:].split("H")
    # Take everything after the "MLP" prefix: the old mlp[-1] kept only the last digit, so "MLP11"
    # was labelled "MLP1". Keys without the prefix fall back to the old last-character behaviour.
    mlp_layer = mlp[3:] if mlp.startswith("MLP") else mlp[-1]
    return f"{feature1} & {feature2} (L{layer}H{head} MLP{mlp_layer})"

# x-axis labels plus one score list per (method, metric) series, all in results_dict key order.
labels = [create_label(key) for key in results_dict.keys()]
co_occurrence_accuracies = [results_dict[key]["co_occurrence"]["accuracies"] for key in results_dict.keys()]
co_occurrence_f1s = [results_dict[key]["co_occurrence"]["f1s"] for key in results_dict.keys()]
base_accuracies = [results_dict[key]["base"]["accuracies"] for key in results_dict.keys()]
base_f1s = [results_dict[key]["base"]["f1s"] for key in results_dict.keys()]

# Grouped bar chart: blues for co-occurrence interpretations, greens for the single-feature baseline.
fig = go.Figure(data=[
    go.Bar(name='Co-occurrence Accuracies', x=labels, y=co_occurrence_accuracies, marker_color='#ADD8E6'),  # Light blue
    go.Bar(name='Co-occurrence F1s', x=labels, y=co_occurrence_f1s, marker_color='#00008B'),  # Dark blue
    go.Bar(name='Base Accuracies', x=labels, y=base_accuracies, marker_color='#90EE90'),  # Light green
    go.Bar(name='Base F1s', x=labels, y=base_f1s, marker_color='#006400')  # Dark green
])

fig.update_layout(
    title='Co-occurrence vs. Base Performance',
    xaxis_title='Features',
    yaxis_title='Score',
    barmode='group',
    legend_title_text='Metrics',
    font=dict(size=14),
    template='plotly_white',
    width=1000
)

fig.show()