# Circuit autointerpretability

This stuff just sets up everything we need.

In [49]:
from autointerpretability import *

# Autoreload
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
config = yaml.safe_load(open("config.yaml"))
llm_client = AzureOpenAI(
    azure_endpoint=config["base_url"],
    api_key=config["azure_api_key"],
    api_version=config["api_version"],
)

model = HookedTransformer.from_pretrained('gpt2-small')

dataset = load_dataset('Skylion007/openwebtext', split='train', streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=10_000)
tokenized_owt = tokenize_and_concatenate(dataset, model.tokenizer, max_length=128, streaming=True)
tokenized_owt = tokenized_owt.shuffle(42)
tokenized_owt = tokenized_owt.take(12800 * 2)
owt_tokens = np.stack([x['tokens'] for x in tokenized_owt])
owt_tokens_torch = torch.tensor(owt_tokens)

device = 'cpu'
tl_model, z_saes, transcoders = get_model_encoders(device=device)



Loaded pretrained model gpt2-small into HookedTransformer


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Token indices sequence length is longer than the specified maximum sequence length for this model (73252 > 1024). Running this sequence through the model will result in indexing errors


Note you can specify the features you want to examine, in each layer, and just pass in either the relevant ZSAE or MLP transcoder depending on what component you want to look at. The `get_feature_scores` function will handle the differences. Let's have a look at the max-activating examples on Danny's features he wanted to check out (note you can slice `owt_tokens_torch` to run for shorter).

In [7]:
features = [16513, 7861]
sae = z_saes[8]
feature_scores = get_feature_scores(model, sae, owt_tokens_torch, features, batch_size=128)

ZSAE


100%|██████████| 200/200 [04:04<00:00,  1.22s/it]


Our feature scores are a tensor of shape `(batch, feature, seq_pos)`, and so I've got a function to help extract the max-activating examples for each feature. You need to specify the feature index, which is why it's helpful to know from above the features in your list.

In [13]:
feature_scores.shape

(25600, 2, 128)

In [43]:
feature_idx = 0 # corresponding to 16513
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=15)

In [45]:
top_tokens, top_logits = get_top_k_tokens(model, sae, features[feature_idx], k=20, act_strength=3)

pretty_print_tokens_logits(top_tokens, top_logits)

╒═════════════╤═════════╕
│ Token       │   Logit │
╞═════════════╪═════════╡
│ [34marth[0m        │  [32m3.6752[0m │
├─────────────┼─────────┤
│ [34mrers[0m        │  [32m3.4801[0m │
├─────────────┼─────────┤
│ [34mdisplayText[0m │  [32m3.3468[0m │
├─────────────┼─────────┤
│ [34mpool[0m        │  [32m3.3323[0m │
├─────────────┼─────────┤
│ [34mrovers[0m      │  [32m3.2823[0m │
├─────────────┼─────────┤
│ [34mqua[0m         │  [32m3.28[0m   │
├─────────────┼─────────┤
│ [34massian[0m      │  [32m3.2042[0m │
├─────────────┼─────────┤
│ [34mcember[0m      │  [32m3.1544[0m │
├─────────────┼─────────┤
│ [34mrer[0m         │  [32m3.1482[0m │
├─────────────┼─────────┤
│ [34miple[0m        │  [32m3.14[0m   │
├─────────────┼─────────┤
│ [34mutch[0m        │  [32m3.1132[0m │
├─────────────┼─────────┤
│ [34msembly[0m      │  [32m3.0786[0m │
├─────────────┼─────────┤
│ [34mphrine[0m      │  [32m3.0376[0m │
├─────────────┼─────────┤
│ [34mchanc

In [46]:
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

(Part 1)
Step 1.
ACTIVATING TOKENS: "in the county", "-,", " and", " than", "based", "Ab", ",, ", ")," "," "," ",",", " and".
PREVIOUS TOKENS: "relocated", "in", "Paul", "Zarrin", "galleries", "Strength:", "more", "prohibited".

Step 2.
The activating tokens are conjunctions, punctuation, or prepositions.
The previous tokens don't seem to have a clear pattern apart from being nouns or verbs.
The neuron seems to activate on specific grammatical elements rather than content tokens.

Step 3.
- The neuron activates on grammatical elements such as conjunctions, punctuation, and prepositions.
- The activating token sometimes follows a verb or a noun but there doesn't seem to be a clear pattern.

(Part 2)
Step 4.
SIMILAR TOKENS: Many of the top logits like "arth", "rers", "qua", " assian", "cember", " rer", "iple", "utch", "sembly", "phrine" seems to be fragments of words rather than full tokens. They seem to be related with letter "r".

Step 5.
[EXPLANATION]: This neuron seems to activate on

In [47]:
# Same thing
feature_idx = 1
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=15, display_html=False)
top_tokens, top_logits = get_top_k_tokens(model, sae, features[feature_idx], k=20, act_strength=5)
pretty_print_tokens_logits(top_tokens, top_logits)
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

╒══════════════╤═════════╕
│ Token        │   Logit │
╞══════════════╪═════════╡
│ [34mividual[0m      │  [32m4.4617[0m │
├──────────────┼─────────┤
│ [34m confir[0m      │  [32m4.3262[0m │
├──────────────┼─────────┤
│ [34mlegate[0m       │  [32m4.1775[0m │
├──────────────┼─────────┤
│ [34mhots[0m         │  [32m3.8846[0m │
├──────────────┼─────────┤
│ [34mbreaker[0m      │  [32m3.8705[0m │
├──────────────┼─────────┤
│ [34muality[0m       │  [32m3.7917[0m │
├──────────────┼─────────┤
│ [34m festivals[0m   │  [32m3.7612[0m │
├──────────────┼─────────┤
│ [34mulk[0m          │  [32m3.6968[0m │
├──────────────┼─────────┤
│ [34mvertisements[0m │  [32m3.5496[0m │
├──────────────┼─────────┤
│ [34mworker[0m       │  [32m3.5273[0m │
├──────────────┼─────────┤
│ [34mlander[0m       │  [32m3.5042[0m │
├──────────────┼─────────┤
│ [34maily[0m         │  [32m3.4963[0m │
├──────────────┼─────────┤
│ [34m arrivals[0m    │  [32m3.4198[0m │
├─────────

You can also pass in and boost logits for multiple features at a time.

In [51]:
top_tokens, top_logits = get_top_k_tokens(model, sae, features, k=10, act_strength=5)

pretty_print_tokens_logits(top_tokens, top_logits)


examples_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens, [0, 1], k=5, show_score=True)

feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

╒══════════╤═════════╕
│ Token    │   Logit │
╞══════════╪═════════╡
│ [34mhots[0m     │  [32m7.4445[0m │
├──────────┼─────────┤
│ [34mlegate[0m   │  [32m7.3243[0m │
├──────────┼─────────┤
│ [34mhower[0m    │  [32m6.7189[0m │
├──────────┼─────────┤
│ [34mumn[0m      │  [32m6.7184[0m │
├──────────┼─────────┤
│ [34mividual[0m  │  [32m6.7176[0m │
├──────────┼─────────┤
│ [34m confir[0m  │  [32m6.5577[0m │
├──────────┼─────────┤
│ [34mlevision[0m │  [32m6.4419[0m │
├──────────┼─────────┤
│ [34mrers[0m     │  [32m6.2546[0m │
├──────────┼─────────┤
│ [34mutch[0m     │  [32m6.2519[0m │
├──────────┼─────────┤
│ [34mpool[0m     │  [32m6.1928[0m │
╘══════════╧═════════╛


(Part 1)
Step 1.
ACTIVATING TOKENS: "Pal", "Dud", "Mr", "Ton", "Al", "p", "Darn", "Mold", "Dot", "Plum", "Put", "Dart", "Pl", "Rad", "Mud", "Plot", "L", "op", "Tan"
PREVIOUS TOKENS: "Mr", "&", "D", "u", "Nut", "the", "v", "Pl", "Art"

Step 2.
The activating tokens appear to be fragments or short abbreviations of words in a sentence or mathematical notation. These activated tokens often follow names or titles. The previous tokens are often names (e.g., "Mr") or letters associated with mathematical notation or abbreviation. These activating strings of tokens appear to be part of a larger pattern (e.g., chains of tokens together form a more meaningful phrase or concept).

Step 3.
- The examples contain names, acronyms, or abbreviations.
- The activating tokens often appear part of a larger whole, forming a meaningful concept when strung together. 
- Some examples involve mathematical or programming notation.

(Part 2)
Step 4.
SIMILAR TOKENS: "hots", "legate", "hower", "umn", "ividual", " 

Then, you can just pass it off to GPT-4 to interpret what's going on. Note that I haven't got access to `GPT-4o` with my credits yet, so this will have to wait a few days.

In [38]:
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)

In [40]:
print(feature_interpretation)

(Part 1)
Step 1.
ACTIVATING TOKENS: "in the county", "days", "asia", "half, and", "ial,", "Byndom", "ia,", "plate", "to", "resett".
PREVIOUS TOKENS: "evacuated in", "-", "un", "par", "cliq", "Carr", "Pers", "name", "res", "charact".

Step 2.
The activating tokens are a mixture of prepositions, conjunctions, parts of words, days of the week and multipart words. 
The previous tokens have nothing in common.

Step 3.
- Many activating tokens are parts of words or phrases.
- The texts geographically widespread places.

(Part 2)
Step 4.
SIMILAR TOKENS: "arth", "rers", "rovers", "rer".
These tokens seem to be part of words, particularly endings part of nouns, adjectives or even verbs. 

Step 5:
[EXPLANATION]: Parts of words, notably the endings of nouns, verbs, or adjectives.


Finally, we can pass in multiple features at once to see the max activating examples for features together.

In [44]:
_ = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens, [0, 1], k=5, show_score=True)

However, instead of passing in individual features for specific components in specific layers, I created an object called `CircuitPrediction` to basically store all this stuff for you. I'll quickly illustrate how to use it in conjunction with the above.

In [47]:
cp = get_circuit_prediction(task='ioi', N=20)

100%|██████████| 20/20 [00:48<00:00,  2.45s/it]
100%|██████████| 20/20 [00:45<00:00,  2.28s/it]
100%|██████████| 20/20 [00:46<00:00,  2.31s/it]
100%|██████████| 20/20 [00:43<00:00,  2.17s/it]


The main thing you'll want to do with this is get features from certain components to look at on a specific task. The features for each component are stored in the circuit hypergraph. For instance:

In [48]:
cp.circuit_hypergraph

{'L0_H0': {'freq': 0.0, 'features': []},
 'L0_H1': {'freq': 0.800000011920929,
  'features': [9715,
   1846,
   451,
   16579,
   -1,
   3949,
   10094,
   -1,
   23825,
   451,
   23825,
   17242,
   5142,
   13846,
   451,
   9715,
   17242,
   451,
   -1,
   17242,
   451,
   4229,
   17242,
   451,
   451,
   -1,
   14731,
   451,
   3501]},
 'L0_H2': {'freq': 0.0, 'features': []},
 'L0_H3': {'freq': 0.5,
  'features': [18802,
   2591,
   11470,
   11470,
   11470,
   21859,
   16579,
   11470,
   17242,
   11470,
   14731,
   17242]},
 'L0_H4': {'freq': 0.0, 'features': []},
 'L0_H5': {'freq': 0.8999999761581421,
  'features': [3392,
   16230,
   3545,
   9715,
   11883,
   10404,
   24085,
   20904,
   17083,
   18034,
   5690,
   4515,
   18034,
   16230,
   10404,
   14037,
   3160,
   21859,
   23663,
   18034,
   1412,
   16132,
   21859,
   3160,
   21339,
   3160,
   7161,
   21859,
   23603]},
 'L0_H6': {'freq': 0.949999988079071,
  'features': [-1,
   -1,
   -1,
   -1,
  

If you want to look at MLP 3, all you have to do is access it:

In [None]:
cp.circuit_hypergraph['MLP3']

And just repeat what we did above:

In [None]:
features = list(set(cp.circuit_hypergraph['MLP3']['features']))
transcoder = transcoders[3]
feature_scores = get_feature_scores(model, transcoder, owt_tokens_torch, features, batch_size=64)

In [None]:
feature_idx = 0 # corresponding to 16513
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, 0, :], owt_tokens_torch, k=5, show_score=True)

There's a few other methods, but you probably don't need to bother with those.

In [None]:
_ = cp.unique_feature_array(visualize=True)

# Playground

What do we want to actually look for?
* We could take specific components, and look at all their features across the circuit hypergraph, then get some sort of "mass autointerpretation" of what this feature is doing. I think for this you'd need to also feed in information from where it activates on the actual circuit. Might seem a bit soft and qualitative, but if you do it principled enough, it could be useful. Also try weighting the cluster max-act examples + logits by how often the feature shows up.
* Look at what features co-occur together in examples. Should give more signal than just looking at features that activate heaps. (Also look at features that activate strongly across all examples as well though.) 

## Feature cluster interpretation of model components

## Co-occurrence of features

In [10]:
from autointerpretability import *

cp = get_circuit_prediction(task='ioi', N=20)

100%|██████████| 20/20 [00:50<00:00,  2.52s/it]
100%|██████████| 20/20 [00:50<00:00,  2.51s/it]
100%|██████████| 20/20 [00:49<00:00,  2.47s/it]
100%|██████████| 20/20 [00:49<00:00,  2.47s/it]


In [11]:
cp.co_occurrence_dict

{(('attn_head', 0, 1), ('attn_head', 0, 5)): [(17242, 5690),
  (17242, 17083),
  (17242, 5690),
  (17242, 17083)],
 (('attn_head', 0, 1), ('mlp_feature', 0)): [(17242, 20546),
  (17242, 5348),
  (17242, 8491),
  (17242, 10173),
  (17242, 10461)],
 (('attn_head', 0, 1), ('attn_head', 6, 9)): [(17242, 17410)],
 (('attn_head', 0, 1), ('attn_head', 8, 6)): [(17242, 16513)],
 (('attn_head', 0, 1), ('attn_head', 9, 9)): [(17242, 16109)],
 (('attn_head', 0, 1), ('mlp_feature', 9)): [(17242, 5957)],
 (('attn_head', 0, 1), ('attn_head', 11, 1)): [(17242, 8907)],
 (('attn_head', 0, 5), ('mlp_feature', 0)): [(5690, 10173),
  (5690, 8491),
  (5690, 10173),
  (5690, 20546),
  (5690, 8491),
  (5690, 10173),
  (5690, 20546),
  (5690, 8491),
  (5690, 10173),
  (17083, 20546),
  (17083, 8491),
  (17083, 10173),
  (5690, 10461),
  (17083, 10461),
  (5690, 20546),
  (17083, 20546),
  (5690, 8491),
  (17083, 8491),
  (5690, 10173),
  (17083, 10173),
  (5690, 20546),
  (17083, 20546),
  (5690, 5348),
  (17

In [21]:
def parse_component(component_str):
    """
    Parse the component string to return the appropriate tuple key.
    """
    if component_str.startswith("MLP"):
        layer = int(component_str[3:])
        return ('mlp_feature', layer)
    elif component_str.startswith("L") and "H" in component_str:
        layer, head = map(int, component_str[1:].split("H"))
        return ('attn_head', layer, head)
    else:
        raise ValueError(f"Invalid component format: {component_str}")

def get_cooccurrences(co_occurrence_dict, component1_str, component2_str):
    """
    Retrieve co-occurrences between two components given their string representations.
    Ensures symmetry in retrieval.
    """
    component1 = parse_component(component1_str)
    component2 = parse_component(component2_str)
    
    # Create a sorted tuple to ensure symmetry
    key = tuple(sorted((component1, component2)))
    
    if key in co_occurrence_dict:
        return co_occurrence_dict[key]
    else:
        return []
    

get_cooccurrences(cp.co_occurrence_dict, "L9H9", "MLP0")

[]

In [19]:
import numpy as np
import plotly.express as px

def visualize_co_occurrences(co_occurrence_dict):
    num_layers = 12
    num_heads = 12
    num_components = num_layers * (num_heads + 1)
    
    # Create a mapping from components to indices
    component_to_index = {}
    index = 0
    for layer in range(num_layers):
        for head in range(num_heads):
            component_to_index[('attn_head', layer, head)] = index
            index += 1
        component_to_index[('mlp_feature', layer)] = index
        index += 1

    # Initialize the co-occurrence matrix
    co_occurrence_matrix = np.zeros((num_components, num_components))

    # Populate the co-occurrence matrix
    for comp_pair, co_occurrences in co_occurrence_dict.items():
        comp1, comp2 = list(comp_pair)
        index1 = component_to_index[comp1]
        index2 = component_to_index[comp2]
        co_occurrence_matrix[index1, index2] = len(co_occurrences)
        co_occurrence_matrix[index2, index1] = len(co_occurrences)

    # Create the labels for the heatmap
    labels = []
    for layer in range(num_layers):
        for head in range(num_heads):
            labels.append(f"L{layer}H{head}")
        labels.append(f"MLP{layer}")

    # Plot the heatmap using Plotly
    fig = px.imshow(
        co_occurrence_matrix,
        labels=dict(x="Component", y="Component", color="Co-occurrence Count"),
        x=labels,
        y=labels,
        color_continuous_scale="blues",
        title="Component Co-occurrences Heatmap"
    )
    
    fig.update_layout(width=800, height=800)
    fig.show()

In [20]:
visualize_co_occurrences(cp.co_occurrence_dict)

In [11]:
from collections import defaultdict, Counter

def get_top_k_feature_tuples(co_occurrence_dict, k):
    # Use a Counter to count the occurrences of each tuple
    global_counter = Counter()

    # Iterate through each component set and their feature lists
    for component_set, feature_list in co_occurrence_dict.items():
        for feature_tuple in feature_list:
            # Create a sorted tuple to ensure (x, y) and (y, x) are treated the same
            sorted_tuple = tuple(sorted(feature_tuple))
            global_counter[(component_set, sorted_tuple)] += 1

    # Get the top-k tuples by count
    top_k_tuples = global_counter.most_common(k)
    
    # Create a dictionary to store the results
    top_k_dict = defaultdict(dict)
    
    for (component_set, feature_tuple), count in top_k_tuples:
        top_k_dict[component_set][feature_tuple] = count

    return top_k_dict


top_k_feature_tuples = get_top_k_feature_tuples(cp.co_occurrence_dict, 10)
top_k_feature_tuples

defaultdict(dict,
            {frozenset({('attn_head', 0, 1),
                        ('mlp_feature', 0)}): {(5142, 19750): 4, (5142,
               12965): 3},
             frozenset({('attn_head', 0, 5),
                        ('mlp_feature', 0)}): {(19750, 24085): 3},
             frozenset({('mlp_feature', 0),
                        ('mlp_feature', 1)}): {(16935, 19750): 3},
             frozenset({('mlp_feature', 0),
                        ('mlp_feature', 3)}): {(1324, 19750): 3},
             frozenset({('attn_head', 9, 9),
                        ('mlp_feature', 0)}): {(8777, 19750): 3},
             frozenset({('attn_head', 0, 1),
                        ('attn_head', 0, 3)}): {(5142, 11470): 2},
             frozenset({('attn_head', 0, 1),
                        ('attn_head', 0, 5)}): {(5142, 24085): 2},
             frozenset({('attn_head', 0, 1),
                        ('mlp_feature', 1)}): {(5142, 16935): 2},
             frozenset({('attn_head', 0, 1),
              

Let's focus on `L9H9` composing with `MLP0`. This has a feature tuple `(8777, 19750)` which occurs three times. 