# Circuit autointerpretability

This stuff just sets up everything we need.

In [7]:
from autointerpretability import *

# Autoreload
%load_ext autoreload
%autoreload 2

In [8]:
config = yaml.safe_load(open("config.yaml"))
llm_client = AzureOpenAI(
    azure_endpoint=config["base_url"],
    api_key=config["azure_api_key"],
    api_version=config["api_version"],
)

model = HookedTransformer.from_pretrained('gpt2-small')

dataset = load_dataset('Skylion007/openwebtext', split='train', streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=10_000)
tokenized_owt = tokenize_and_concatenate(dataset, model.tokenizer, max_length=128, streaming=True)
tokenized_owt = tokenized_owt.shuffle(42)
tokenized_owt = tokenized_owt.take(12800 * 2)
owt_tokens = np.stack([x['tokens'] for x in tokenized_owt])
owt_tokens_torch = torch.tensor(owt_tokens)

device = 'cpu'
tl_model, z_saes, transcoders = get_model_encoders(device=device)



Loaded pretrained model gpt2-small into HookedTransformer


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Token indices sequence length is longer than the specified maximum sequence length for this model (73252 > 1024). Running this sequence through the model will result in indexing errors


Note you can specify the features you want to examine, in each layer, and just pass in either the relevant ZSAE or MLP transcoder depending on what component you want to look at. The `get_feature_scores` function will handle the differences. Let's have a look at the max-activating examples on Danny's features he wanted to check out (note you can slice `owt_tokens_torch` to run for shorter).

In [None]:
features = [16513, 7861]
sae = z_saes[8]
feature_scores = get_feature_scores(model, sae, owt_tokens_torch, features, batch_size=128)

Our feature scores are a tensor of shape `(batch, feature, seq_pos)`, and so I've got a function to help extract the max-activating examples for each feature. You need to specify the feature index, which is why it's helpful to know from above the features in your list.

In [None]:
feature_scores.shape

In [None]:
feature_idx = 0 # corresponding to 16513
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=15)

In [None]:
top_tokens, top_logits = get_top_k_tokens(model, sae, features[feature_idx], k=20, act_strength=3)

pretty_print_tokens_logits(top_tokens, top_logits)

In [None]:
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

In [None]:
# Same thing
feature_idx = 1
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=15, display_html=False)
top_tokens, top_logits = get_top_k_tokens(model, sae, features[feature_idx], k=20, act_strength=5)
pretty_print_tokens_logits(top_tokens, top_logits)
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

You can also pass in and boost logits for multiple features at a time.

In [None]:
top_tokens, top_logits = get_top_k_tokens(model, sae, features, k=10, act_strength=5)

pretty_print_tokens_logits(top_tokens, top_logits)


examples_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens, [0, 1], k=5, show_score=True)

feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

Then, you can just pass it off to GPT-4 to interpret what's going on. Note that I haven't got access to `GPT-4o` with my credits yet, so this will have to wait a few days.

In [None]:
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)

In [None]:
print(feature_interpretation)

Finally, we can pass in multiple features at once to see the max activating examples for features together.

In [None]:
_ = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens, [0, 1], k=5, show_score=True)

However, instead of passing in individual features for specific components in specific layers, I created an object called `CircuitPrediction` to basically store all this stuff for you. I'll quickly illustrate how to use it in conjunction with the above.

In [None]:
cp = get_circuit_prediction(task='ioi', N=20)

The main thing you'll want to do with this is get features from certain components to look at on a specific task. The features for each component are stored in the circuit hypergraph. For instance:

In [None]:
cp.circuit_hypergraph

If you want to look at MLP 3, all you have to do is access it:

In [None]:
cp.circuit_hypergraph['MLP3']

And just repeat what we did above:

In [None]:
features = list(set(cp.circuit_hypergraph['MLP3']['features']))
transcoder = transcoders[3]
feature_scores = get_feature_scores(model, transcoder, owt_tokens_torch, features, batch_size=64)

In [None]:
feature_idx = 0 # corresponding to 16513
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, 0, :], owt_tokens_torch, k=5, show_score=True)

There's a few other methods, but you probably don't need to bother with those.

In [None]:
_ = cp.unique_feature_array(visualize=True)

# Playground

What do we want to actually look for?
* We could take specific components, and look at all their features across the circuit hypergraph, then get some sort of "mass autointerpretation" of what this feature is doing. I think for this you'd need to also feed in information from where it activates on the actual circuit. Might seem a bit soft and qualitative, but if you do it principled enough, it could be useful. Also try weighting the cluster max-act examples + logits by how often the feature shows up.
* Look at what features co-occur together in examples. Should give more signal than just looking at features that activate heaps. (Also look at features that activate strongly across all examples as well though.) 

## Feature cluster interpretation of model components

## Co-occurrence of features

In [4]:
from autointerpretability import *

cp = get_circuit_prediction(task='ioi', N=20)

100%|██████████| 20/20 [00:48<00:00,  2.43s/it]


In [None]:
cp.co_occurrence_dict

In [None]:
cp.get_cooccurrences("MLP0", "L9H9")

In [None]:
cp.get_cooccurrences("MLP0", "L9H6")

In [None]:
cp.visualize_co_occurrences()

In [None]:
cp.get_top_k_feature_tuples(k=10)

In [None]:
from collections import Counter, defaultdict

def get_top_k_feature_tuples_for_component(co_occurrence_dict, component_str, k=5):
    # Parse the component string to get the appropriate tuple key
    if component_str.startswith("MLP"):
        layer = int(component_str[3:])
        component = ('mlp_feature', layer)
    elif component_str.startswith("L") and "H" in component_str:
        layer, head = map(int, component_str[1:].split("H"))
        component = ('attn_head', layer, head)
    else:
        raise ValueError(f"Invalid component format: {component_str}")

    # Use a Counter to count the occurrences of each tuple
    global_counter = Counter()

    # Iterate through the co-occurrence dictionary
    for comp_pair, co_occurrences in co_occurrence_dict.items():
        comp1, comp2 = comp_pair

        if comp1 == component or comp2 == component:
            for feature_tuple in co_occurrences:
                global_counter[(comp_pair, feature_tuple)] += 1

    # Get the top-k tuples by count
    top_k_tuples = global_counter.most_common(k)

    # Create a dictionary to store the results
    top_k_dict = defaultdict(dict)
    
    for (comp_pair, feature_tuple), count in top_k_tuples:
        top_k_dict[comp_pair][feature_tuple] = count

    return top_k_dict

get_top_k_feature_tuples_for_component(cp.co_occurrence_dict, "L8H6", k=20)

In [None]:
get_top_k_feature_tuples_for_component(cp.co_occurrence_dict, "L9H9", k=20)

In [None]:
features = [20101]
sae = z_saes[9]
feature_scores = get_feature_scores(model, sae, owt_tokens_torch, features, batch_size=128)

In [None]:
# Display top k activating examples
feature_idx = 0 # corresponding to 16109
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=5)

encoder_feature_pairs = [(sae, [20101])]

# top_tokens, top_logits = get_top_k_tokens(model, sae, features[feature_idx], k=20, act_strength=5)

top_tokens, top_logits = get_top_k_tokens(model, encoder_feature_pairs, k=20, act_strength=5)

pretty_print_tokens_logits(top_tokens, top_logits)

feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

In [None]:
features = [6798]
transcoder = transcoders[0]
feature_scores = get_feature_scores(model, transcoder, owt_tokens_torch, features, batch_size=128)

In [None]:
# Display top k activating examples
feature_idx = 0 # corresponding to 16109
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=5)

encoder_feature_pairs = [(transcoder, [6798])]

# top_tokens, top_logits = get_top_k_tokens(model, sae, features[feature_idx], k=20, act_strength=5)

top_tokens, top_logits = get_top_k_tokens(model, encoder_feature_pairs, k=20, act_strength=5)

pretty_print_tokens_logits(top_tokens, top_logits)

feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

In [None]:
encoder_feature_pairs = [(transcoders[0], [6798]), (z_saes[9], [20101])]
top_tokens, top_logits = get_top_k_tokens(model, encoder_feature_pairs, k=20)

pretty_print_tokens_logits(top_tokens, top_logits)

In [None]:
feature_scores = get_feature_scores_across_layers(model, encoder_feature_pairs, owt_tokens_torch, batch_size=128)

In [None]:
examples_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens, [0, 1], k=25, show_score=True)

In [None]:
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

In [None]:
get_top_k_feature_tuples_for_component(cp.co_occurrence_dict, "L4H11", k=20)

In [None]:
layer = 4
head = 11
feature = 20359

features = [feature]
sae = z_saes[layer]

feature_scores = get_feature_scores(model, sae, owt_tokens_torch[:1024*4], features, batch_size=128)

In [None]:
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, 0, :], owt_tokens_torch[:1024*4], k=15)

In [None]:
top_tokens, top_logits = get_top_k_tokens(model, [(sae, [feature])], k=10, act_strength=5)

pretty_print_tokens_logits(top_tokens, top_logits)

In [None]:
from autointerp_prompts import get_opening_prompt

# Autoreload
%load_ext autoreload
%autoreload 2

def new_get_response(llm_client, examples_clean_text, top_tokens):
    opening_prompt = get_opening_prompt(examples_clean_text, top_tokens)
    messages = [{"role": "user", "content": opening_prompt}]
    response = llm_client.chat.completions.create(
        model="gpt4_large",
        messages=messages,
    )
    return f"{response.choices[0].message.content}"

In [None]:
feature_interpretation = new_get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

## Autointerp over clusters of features

In [5]:
# Go through co-occurrence dict and print any tuples that have a feature over 24576
for k, v in cp.co_occurrence_dict.items():
    for feature_tuple in v:
        if any([f > 24576 for f in feature_tuple]):
            print(k, feature_tuple)

(('attn_head', 0, 1), ('attn_head', 5, 0)) (6901, 44452)
(('attn_head', 0, 1), ('attn_head', 5, 0)) (17242, 44452)
(('attn_head', 0, 1), ('attn_head', 5, 0)) (4229, 44452)
(('attn_head', 0, 1), ('attn_head', 5, 0)) (9715, 44452)
(('attn_head', 0, 1), ('attn_head', 5, 5)) (451, 27535)
(('attn_head', 0, 1), ('attn_head', 5, 5)) (4229, 27535)
(('attn_head', 0, 1), ('attn_head', 5, 5)) (17242, 27535)
(('attn_head', 0, 1), ('attn_head', 5, 5)) (4229, 27535)
(('attn_head', 0, 1), ('attn_head', 5, 5)) (6901, 27535)
(('attn_head', 0, 1), ('attn_head', 5, 5)) (17242, 27535)
(('attn_head', 0, 1), ('attn_head', 5, 5)) (4229, 27535)
(('attn_head', 0, 1), ('attn_head', 5, 5)) (20191, 27535)
(('attn_head', 0, 1), ('attn_head', 5, 5)) (17242, 27535)
(('attn_head', 0, 1), ('attn_head', 5, 5)) (17242, 27535)
(('attn_head', 0, 1), ('attn_head', 5, 5)) (17242, 27535)
(('attn_head', 0, 1), ('attn_head', 5, 5)) (2680, 27535)
(('attn_head', 0, 1), ('attn_head', 5, 5)) (2680, 27535)
(('attn_head', 0, 1), ('a

In [9]:
sae = z_saes[5]
sae.W_dec.shape

torch.Size([49152, 768])

In [10]:
def feature_scores_for_component_cluster(component_name: str, layer: int):
    features = [x for x in list(set(cp.circuit_hypergraph[component_name]['features'])) if x!=-1]

    sae = z_saes[layer]
    feature_scores = get_feature_scores(model, sae, owt_tokens_torch[:1024*4], features, batch_size=128)

    top_tokens, top_logits = get_top_k_tokens(model, [(sae, features)], k=10, act_strength=5)

    return feature_scores, top_tokens, top_logits

In [None]:
feature_scores, top_tokens, top_logits = feature_scores_for_component_cluster("L0_H1", 0)

In [None]:
example_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens_torch[:1024*4], feature_indices=[x for x in range(feature_scores.shape[1])], k=15)

In [None]:
pretty_print_tokens_logits(top_tokens, top_logits)

In [11]:
feature_scores, top_tokens, top_logits = feature_scores_for_component_cluster("L5_H5", 5)

ZSAE


100%|██████████| 32/32 [00:56<00:00,  1.75s/it]


TypeError: get_combined_logits() takes from 2 to 3 positional arguments but 4 were given

In [72]:
feature_scores, top_tokens, top_logits = feature_scores_for_component_cluster("L5_H5", 5)

example_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens_torch[:1024*4], feature_indices=[x for x in range(feature_scores.shape[1])], k=10)

pretty_print_tokens_logits(top_tokens, top_logits)

ZSAE


100%|██████████| 32/32 [00:52<00:00,  1.64s/it]


IndexError: index 44256 is out of bounds for dimension 0 with size 24576

In [None]:
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

In [65]:
feature_scores, top_tokens, top_logits = feature_scores_for_component_cluster("L10_H7", 10)

example_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens_torch[:1024*4], feature_indices=[x for x in range(feature_scores.shape[1])], k=10)

pretty_print_tokens_logits(top_tokens, top_logits)

ZSAE


100%|██████████| 32/32 [00:20<00:00,  1.59it/s]


╒═══════════╤═════════╕
│ Token     │   Logit │
╞═══════════╪═════════╡
│ [34m:][0m        │ [32m14.2477[0m │
├───────────┼─────────┤
│ [34msections[0m  │ [32m12.23[0m   │
├───────────┼─────────┤
│ [34mobi[0m       │ [32m12.1438[0m │
├───────────┼─────────┤
│ [34mimar[0m      │ [32m12.0566[0m │
├───────────┼─────────┤
│ [34m »[0m        │ [32m11.9453[0m │
├───────────┼─────────┤
│ [34many[0m       │ [32m11.924[0m  │
├───────────┼─────────┤
│ [34m Crossing[0m │ [32m11.8792[0m │
├───────────┼─────────┤
│ [34mlear[0m      │ [32m11.7354[0m │
├───────────┼─────────┤
│ [34mndum[0m      │ [32m11.691[0m  │
├───────────┼─────────┤
│ [34marms[0m      │ [32m11.649[0m  │
╘═══════════╧═════════╛


In [66]:
feature_scores.shape

(4096, 6, 128)

In [69]:
feature_scores, top_tokens, top_logits = feature_scores_for_component_cluster("L8_H6", 8)

example_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens_torch[:1024*4], feature_indices=[x for x in range(feature_scores.shape[1])], k=10)

pretty_print_tokens_logits(top_tokens, top_logits)

ZSAE


100%|██████████| 32/32 [00:39<00:00,  1.22s/it]


╒═══════════╤═════════╕
│ Token     │   Logit │
╞═══════════╪═════════╡
│ [34m:][0m        │ [32m14.2477[0m │
├───────────┼─────────┤
│ [34msections[0m  │ [32m12.23[0m   │
├───────────┼─────────┤
│ [34mobi[0m       │ [32m12.1438[0m │
├───────────┼─────────┤
│ [34mimar[0m      │ [32m12.0566[0m │
├───────────┼─────────┤
│ [34m »[0m        │ [32m11.9453[0m │
├───────────┼─────────┤
│ [34many[0m       │ [32m11.924[0m  │
├───────────┼─────────┤
│ [34m Crossing[0m │ [32m11.8792[0m │
├───────────┼─────────┤
│ [34mlear[0m      │ [32m11.7354[0m │
├───────────┼─────────┤
│ [34mndum[0m      │ [32m11.691[0m  │
├───────────┼─────────┤
│ [34marms[0m      │ [32m11.649[0m  │
╘═══════════╧═════════╛


In [68]:
feature_scores.shape

(4096, 0, 128)