# Circuit autointerpretability

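The cells below set up everything we need: the Azure OpenAI client, GPT-2 small, a tokenized sample of OpenWebText, and the attention SAEs (ZSAEs) and MLP transcoders.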

In [67]:
from autointerpretability import *

# Autoreload
%load_ext autoreload
%autoreload 2

In [7]:
# Read the API settings; the context manager closes the file handle afterwards
with open("config.yaml") as f:
    config = yaml.safe_load(f)
llm_client = AzureOpenAI(
    azure_endpoint=config["base_url"],
    api_key=config["azure_api_key"],
    api_version=config["api_version"],
)

model = HookedTransformer.from_pretrained('gpt2-small')

dataset = load_dataset('Skylion007/openwebtext', split='train', streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=10_000)
tokenized_owt = tokenize_and_concatenate(dataset, model.tokenizer, max_length=128, streaming=True)
tokenized_owt = tokenized_owt.shuffle(42)
tokenized_owt = tokenized_owt.take(12800 * 2)
owt_tokens = np.stack([x['tokens'] for x in tokenized_owt])
owt_tokens_torch = torch.tensor(owt_tokens)

device = 'cpu'
tl_model, z_saes, transcoders = get_model_encoders(device=device)


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.



Loaded pretrained model gpt2-small into HookedTransformer



The repository for Skylion007/openwebtext contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/Skylion007/openwebtext
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.

Token indices sequence length is longer than the specified maximum sequence length for this model (73252 > 1024). Running this sequence through the model will result in indexing errors


You specify which features to examine in a given layer and pass in the matching encoder: the relevant ZSAE (attention) or MLP transcoder, depending on the component you want to look at. The `get_feature_scores` function handles the differences between the two. Let's look at the max-activating examples for the features Danny wanted to check out (to shorten the run, pass in a slice of `owt_tokens_torch`).

In [7]:
features = [16513, 7861]
sae = z_saes[8]
feature_scores = get_feature_scores(model, sae, owt_tokens_torch, features, batch_size=128)

ZSAE


100%|██████████| 200/200 [04:04<00:00,  1.22s/it]


The feature scores are an array of shape `(batch, feature, seq_pos)`, and `display_top_k_activating_examples` extracts the max-activating examples for one feature at a time. You select a feature by its position in the `features` list above, not by its SAE feature ID, so keep track of the order of that list.

In [13]:
feature_scores.shape

(25600, 2, 128)
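
To make the indexing concrete, here is a minimal NumPy sketch of pulling the top-k activating positions for one feature out of a `(batch, feature, seq_pos)` array. It uses random stand-in scores and a hypothetical helper, `top_k_positions`; it is not the implementation of `display_top_k_activating_examples`, which may rank examples differently (for example, by one maximum per sequence).

```python
import numpy as np

def top_k_positions(scores, feature_idx, k=5):
    """Return (batch_idx, seq_pos, score) for the k largest activations of one feature.

    `scores` has shape (batch, feature, seq_pos). Hypothetical helper for illustration.
    """
    per_feature = scores[:, feature_idx, :]           # (batch, seq_pos)
    flat = per_feature.ravel()
    k = min(k, flat.size)
    top_flat = np.argsort(flat)[::-1][:k]             # flat indices of the k largest values
    batch_idx, seq_pos = np.unravel_index(top_flat, per_feature.shape)
    return [(int(b), int(p), float(per_feature[b, p])) for b, p in zip(batch_idx, seq_pos)]

rng = np.random.default_rng(0)
toy_scores = rng.random((8, 2, 16)).astype(np.float32)   # stand-in for feature_scores
toy_scores[3, 0, 7] = 10.0                                # plant an obvious maximum

top = top_k_positions(toy_scores, feature_idx=0, k=3)
print(top[0])
```

Here `feature_idx=0` selects the first slot of the feature axis, which in the notebook corresponds to the first entry of `features`.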

In [43]:
feature_idx = 0 # corresponding to 16513
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=15)

In [45]:
top_tokens, top_logits = get_top_k_tokens(model, sae, features[feature_idx], k=20, act_strength=3)

pretty_print_tokens_logits(top_tokens, top_logits)

╒═════════════╤═════════╕
│ Token       │   Logit │
╞═════════════╪═════════╡
│ arth        │  3.6752 │
├─────────────┼─────────┤
│ rers        │  3.4801 │
├─────────────┼─────────┤
│ displayText │  3.3468 │
├─────────────┼─────────┤
│ pool        │  3.3323 │
├─────────────┼─────────┤
│ rovers      │  3.2823 │
├─────────────┼─────────┤
│ qua         │  3.28   │
├─────────────┼─────────┤
│ assian      │  3.2042 │
├─────────────┼─────────┤
│ cember      │  3.1544 │
├─────────────┼─────────┤
│ rer         │  3.1482 │
├─────────────┼─────────┤
│ iple        │  3.14   │
├─────────────┼─────────┤
│ utch        │  3.1132 │
├─────────────┼─────────┤
│ sembly      │  3.0786 │
├─────────────┼─────────┤
│ phrine      │  3.0376 │
├─────────────┼─────────┤
│ chanc

In [46]:
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

(Part 1)
Step 1.
ACTIVATING TOKENS: "in the county", "-,", " and", " than", "based", "Ab", ",, ", ")," "," "," ",",", " and".
PREVIOUS TOKENS: "relocated", "in", "Paul", "Zarrin", "galleries", "Strength:", "more", "prohibited".

Step 2.
The activating tokens are conjunctions, punctuation, or prepositions.
The previous tokens don't seem to have a clear pattern apart from being nouns or verbs.
The neuron seems to activate on specific grammatical elements rather than content tokens.

Step 3.
- The neuron activates on grammatical elements such as conjunctions, punctuation, and prepositions.
- The activating token sometimes follows a verb or a noun but there doesn't seem to be a clear pattern.

(Part 2)
Step 4.
SIMILAR TOKENS: Many of the top logits like "arth", "rers", "qua", " assian", "cember", " rer", "iple", "utch", "sembly", "phrine" seems to be fragments of words rather than full tokens. They seem to be related with letter "r".

Step 5.
[EXPLANATION]: This neuron seems to activate on

In [47]:
# Same steps for the second feature in the list (7861)
feature_idx = 1
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=15, display_html=False)
top_tokens, top_logits = get_top_k_tokens(model, sae, features[feature_idx], k=20, act_strength=5)
pretty_print_tokens_logits(top_tokens, top_logits)
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

╒══════════════╤═════════╕
│ Token        │   Logit │
╞══════════════╪═════════╡
│ ividual      │  4.4617 │
├──────────────┼─────────┤
│  confir      │  4.3262 │
├──────────────┼─────────┤
│ legate       │  4.1775 │
├──────────────┼─────────┤
│ hots         │  3.8846 │
├──────────────┼─────────┤
│ breaker      │  3.8705 │
├──────────────┼─────────┤
│ uality       │  3.7917 │
├──────────────┼─────────┤
│  festivals   │  3.7612 │
├──────────────┼─────────┤
│ ulk          │  3.6968 │
├──────────────┼─────────┤
│ vertisements │  3.5496 │
├──────────────┼─────────┤
│ worker       │  3.5273 │
├──────────────┼─────────┤
│ lander       │  3.5042 │
├──────────────┼─────────┤
│ aily         │  3.4963 │
├──────────────┼─────────┤
│  arrivals    │  3.4198 │
├─────────

You can also pass in a list of features to boost the logits of several features at once.
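
One way to read "boosting logits" for several features at once: set each chosen feature's hidden activation to `act_strength`, decode back to the model dimension, and project through the unembedding. The sketch below does this with small random matrices standing in for the encoder's decoder weights and the model's `W_U`. The names `W_dec`, `b_dec`, and `boosted_logits` are assumptions for illustration, the decoder is assumed to be linear, and the real `get_top_k_tokens` may differ in detail.

```python
import numpy as np

def boosted_logits(W_dec, b_dec, W_U, features, act_strength=5.0):
    """Logits obtained by activating `features` at `act_strength` and unembedding.

    W_dec: (dict_size, d_model), b_dec: (d_model,), W_U: (d_model, vocab).
    """
    hidden = np.zeros(W_dec.shape[0])
    hidden[np.atleast_1d(features)] = act_strength    # accepts an int or a list of ints
    decoded = hidden @ W_dec + b_dec                   # assumed linear decoder
    return decoded @ W_U

def top_k(logits, k):
    """Return the indices and values of the k largest logits, largest first."""
    idx = np.argsort(logits)[::-1][:k]
    return idx.tolist(), logits[idx].tolist()

rng = np.random.default_rng(0)
dict_size, d_model, vocab = 32, 8, 50                 # toy sizes, not GPT-2's
W_dec = rng.normal(size=(dict_size, d_model))
b_dec = np.zeros(d_model)
W_U = rng.normal(size=(d_model, vocab))

both = boosted_logits(W_dec, b_dec, W_U, [3, 11])
separate = boosted_logits(W_dec, b_dec, W_U, 3) + boosted_logits(W_dec, b_dec, W_U, 11)
top_ids, top_vals = top_k(both, k=5)
```

With a zero decoder bias, `both` equals `separate`: the logits for several features are the sum of the per-feature logits, so the combined ranking is a blend of the individual ones.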

In [51]:
top_tokens, top_logits = get_top_k_tokens(model, sae, features, k=10, act_strength=5)

pretty_print_tokens_logits(top_tokens, top_logits)


examples_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens, [0, 1], k=5, show_score=True)

feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

╒══════════╤═════════╕
│ Token    │   Logit │
╞══════════╪═════════╡
│ hots     │  7.4445 │
├──────────┼─────────┤
│ legate   │  7.3243 │
├──────────┼─────────┤
│ hower    │  6.7189 │
├──────────┼─────────┤
│ umn      │  6.7184 │
├──────────┼─────────┤
│ ividual  │  6.7176 │
├──────────┼─────────┤
│  confir  │  6.5577 │
├──────────┼─────────┤
│ levision │  6.4419 │
├──────────┼─────────┤
│ rers     │  6.2546 │
├──────────┼─────────┤
│ utch     │  6.2519 │
├──────────┼─────────┤
│ pool     │  6.1928 │
╘══════════╧═════════╛


(Part 1)
Step 1.
ACTIVATING TOKENS: "Pal", "Dud", "Mr", "Ton", "Al", "p", "Darn", "Mold", "Dot", "Plum", "Put", "Dart", "Pl", "Rad", "Mud", "Plot", "L", "op", "Tan"
PREVIOUS TOKENS: "Mr", "&", "D", "u", "Nut", "the", "v", "Pl", "Art"

Step 2.
The activating tokens appear to be fragments or short abbreviations of words in a sentence or mathematical notation. These activated tokens often follow names or titles. The previous tokens are often names (e.g., "Mr") or letters associated with mathematical notation or abbreviation. These activating strings of tokens appear to be part of a larger pattern (e.g., chains of tokens together form a more meaningful phrase or concept).

Step 3.
- The examples contain names, acronyms, or abbreviations.
- The activating tokens often appear part of a larger whole, forming a meaningful concept when strung together. 
- Some examples involve mathematical or programming notation.

(Part 2)
Step 4.
SIMILAR TOKENS: "hots", "legate", "hower", "umn", "ividual", " 

Then you can pass the examples and top tokens to GPT-4 to interpret what's going on. I don't have access to `GPT-4o` with my credits yet, so that will have to wait a few days.

In [38]:
feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)

In [40]:
print(feature_interpretation)

(Part 1)
Step 1.
ACTIVATING TOKENS: "in the county", "days", "asia", "half, and", "ial,", "Byndom", "ia,", "plate", "to", "resett".
PREVIOUS TOKENS: "evacuated in", "-", "un", "par", "cliq", "Carr", "Pers", "name", "res", "charact".

Step 2.
The activating tokens are a mixture of prepositions, conjunctions, parts of words, days of the week and multipart words. 
The previous tokens have nothing in common.

Step 3.
- Many activating tokens are parts of words or phrases.
- The texts geographically widespread places.

(Part 2)
Step 4.
SIMILAR TOKENS: "arth", "rers", "rovers", "rer".
These tokens seem to be part of words, particularly endings part of nouns, adjectives or even verbs. 

Step 5:
[EXPLANATION]: Parts of words, notably the endings of nouns, verbs, or adjectives.


Finally, we can pass in multiple features at once to see the examples that maximally activate the features together.

In [44]:
_ = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens, [0, 1], k=5, show_score=True)
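
As a rough picture of what ranking by several features together could mean, the sketch below sums the scores of the selected features at every token position and ranks sequences by their best summed position. This is an assumption suggested by the function name `display_top_k_activating_examples_sum`, not its actual implementation; `top_k_sequences_by_sum` is a hypothetical helper.

```python
import numpy as np

def top_k_sequences_by_sum(scores, feature_indices, k=5):
    """Rank sequences by the max over positions of the summed feature scores.

    `scores` has shape (batch, feature, seq_pos).
    """
    summed = scores[:, feature_indices, :].sum(axis=1)    # (batch, seq_pos)
    best_pos = summed.argmax(axis=1)                      # best position per sequence
    best_val = summed.max(axis=1)
    order = np.argsort(best_val)[::-1][:k]                # top-k sequences
    return [(int(b), int(best_pos[b]), float(best_val[b])) for b in order]

toy_scores = np.zeros((4, 2, 6), dtype=np.float32)
toy_scores[1, 0, 2] = 3.0     # feature 0 fires alone
toy_scores[2, 0, 4] = 2.0     # both features fire on the same token
toy_scores[2, 1, 4] = 2.0

ranked = top_k_sequences_by_sum(toy_scores, [0, 1], k=2)
print(ranked)
```

In the toy data, the token where both features fire outranks the token where a single feature fires more strongly, which is the behavior you want when looking for features that act together.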

Instead of passing in individual features for specific components in specific layers, you can use an object I created called `CircuitPrediction`, which stores all of this for you. I'll quickly illustrate how to use it together with the functions above.

In [47]:
cp = get_circuit_prediction(task='ioi', N=20)

100%|██████████| 20/20 [00:48<00:00,  2.45s/it]
100%|██████████| 20/20 [00:45<00:00,  2.28s/it]
100%|██████████| 20/20 [00:46<00:00,  2.31s/it]
100%|██████████| 20/20 [00:43<00:00,  2.17s/it]


The main use for this object is to get the features of particular components on a specific task. The features for each component are stored in the circuit hypergraph. For instance:

In [48]:
cp.circuit_hypergraph

{'L0_H0': {'freq': 0.0, 'features': []},
 'L0_H1': {'freq': 0.800000011920929,
  'features': [9715,
   1846,
   451,
   16579,
   -1,
   3949,
   10094,
   -1,
   23825,
   451,
   23825,
   17242,
   5142,
   13846,
   451,
   9715,
   17242,
   451,
   -1,
   17242,
   451,
   4229,
   17242,
   451,
   451,
   -1,
   14731,
   451,
   3501]},
 'L0_H2': {'freq': 0.0, 'features': []},
 'L0_H3': {'freq': 0.5,
  'features': [18802,
   2591,
   11470,
   11470,
   11470,
   21859,
   16579,
   11470,
   17242,
   11470,
   14731,
   17242]},
 'L0_H4': {'freq': 0.0, 'features': []},
 'L0_H5': {'freq': 0.8999999761581421,
  'features': [3392,
   16230,
   3545,
   9715,
   11883,
   10404,
   24085,
   20904,
   17083,
   18034,
   5690,
   4515,
   18034,
   16230,
   10404,
   14037,
   3160,
   21859,
   23663,
   18034,
   1412,
   16132,
   21859,
   3160,
   21339,
   3160,
   7161,
   21859,
   23603]},
 'L0_H6': {'freq': 0.949999988079071,
  'features': [-1,
   -1,
   -1,
   -1,
  

If you want to look at MLP 3, all you have to do is access it:

In [None]:
cp.circuit_hypergraph['MLP3']

And just repeat what we did above:

In [None]:
features = list(set(cp.circuit_hypergraph['MLP3']['features']))
transcoder = transcoders[3]
feature_scores = get_feature_scores(model, transcoder, owt_tokens_torch, features, batch_size=64)
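
Note that `list(set(...))` discards how often each feature appears and does not guarantee a stable order. If the counts matter (for instance, to weight features by how often they show up), a sketch like the following keeps them. It assumes that `-1` entries are placeholders meaning no feature was selected, which is a guess from the hypergraph output rather than something the notebook confirms; `toy_hypergraph` and `feature_frequencies` are illustrative names.

```python
from collections import Counter

def feature_frequencies(hypergraph, component, placeholder=-1):
    """Return [(feature, count), ...] for one component, most frequent first."""
    features = hypergraph[component]["features"]
    counts = Counter(f for f in features if f != placeholder)
    return counts.most_common()

# Toy stand-in with the same layout as cp.circuit_hypergraph
toy_hypergraph = {
    "L0_H3": {"freq": 0.5, "features": [18802, 11470, 11470, -1, 17242, 11470, 17242]},
    "L0_H4": {"freq": 0.0, "features": []},
}

freqs = feature_frequencies(toy_hypergraph, "L0_H3")
unique_features = [f for f, _ in freqs]   # deduplicated, ordered by frequency
print(freqs)
```

The `unique_features` list can be passed wherever the notebook passes `features`, with the advantage that index 0 is always the most frequent feature.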

In [None]:
feature_idx = 0  # position in `features` (the deduplicated MLP3 list), not a feature ID
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=5, show_score=True)

There are a few other methods, but you probably won't need them.

In [None]:
_ = cp.unique_feature_array(visualize=True)

# Playground

What do we want to actually look for?
* We could take specific components, look at all of their features across the circuit hypergraph, and get some sort of "mass autointerpretation" of what the component is doing. I think this would also need information about where the features activate on the actual circuit. It might seem a bit soft and qualitative, but done in a principled enough way it could be useful. Also try weighting the cluster's max-activating examples and logits by how often each feature shows up.
* Look at which features co-occur in examples. This should give more signal than just looking at features that activate a lot. (Still, also look at features that activate strongly across all examples.)

## Feature cluster interpretation of model components

## Co-occurrence of features

In [1]:
from autointerpretability import *

cp = get_circuit_prediction(task='ioi', N=3)



Loaded pretrained model gpt2-small into HookedTransformer
Loaded pretrained model gpt2-small into HookedTransformer

Loading SAEs...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 12/12 [00:07<00:00,  1.58it/s]



Loading Transcoders...


100%|██████████| 12/12 [00:04<00:00,  2.78it/s]
100%|██████████| 3/3 [00:09<00:00,  3.04s/it]


In [2]:
cp.co_occurrence_dict

{(('attn_head', 0, 1), ('attn_head', 0, 3)): [(3949, 11470),
  (3949, 17242),
  (3949, 11470),
  (17242, 21619),
  (17242, 21619),
  (451, 21619)],
 (('attn_head', 0, 1), ('attn_head', 0, 5)): [(3949, 1412),
  (3949, 1412),
  (17242, 21619),
  (17242, 14037),
  (17242, 21619),
  (17242, 14037),
  (451, 21619),
  (451, 14037)],
 (('attn_head', 0, 1), ('mlp_feature', 0)): [(3949, 10672),
  (3949, 10672),
  (3949, 20546),
  (3949, 10672),
  (3949, 20546),
  (3949, 19296),
  (3949, 10672),
  (3949, 20546),
  (3949, 19296),
  (3949, 10672),
  (3949, 20546),
  (3949, 19296),
  (3949, 10461),
  (3949, 19296),
  (3949, 20546),
  (3949, 12965),
  (3949, 10672),
  (3949, 10461),
  (3949, 19296),
  (3949, 20546),
  (3949, 12965),
  (3949, 10672),
  (3949, 10461),
  (3949, 19296),
  (3949, 20546),
  (3949, 12965),
  (3949, 10672),
  (3949, 10461),
  (3949, 19296),
  (3949, 20546),
  (3949, 21346),
  (3949, 12965),
  (3949, 10672),
  (3949, 10461),
  (3949, 19296),
  (3949, 20546),
  (3949, 21346),

In [3]:
# def parse_component(component_str):
#     """
#     Parse the component string to return the appropriate tuple key.
#     """
#     if component_str.startswith("MLP"):
#         layer = int(component_str[3:])
#         return ('mlp_feature', layer)
#     elif component_str.startswith("L") and "H" in component_str:
#         layer, head = map(int, component_str[1:].split("H"))
#         return ('attn_head', layer, head)
#     else:
#         raise ValueError(f"Invalid component format: {component_str}")

# def get_cooccurrences(co_occurrence_dict, component1_str, component2_str):
#     """
#     Retrieve co-occurrences between two components given their string representations.
#     Ensures symmetry in retrieval.
#     """
#     component1 = parse_component(component1_str)
#     component2 = parse_component(component2_str)
    
#     # Create a sorted tuple to ensure symmetry
#     key = tuple(sorted((component1, component2)))
    
#     if key in co_occurrence_dict:
#         return co_occurrence_dict[key], key
#     elif (key[1], key[0]) in co_occurrence_dict:
#         return co_occurrence_dict[(key[1], key[0])], (key[1], key[0])
#     else:
#         return [], None
    

cp.get_cooccurrences("MLP0", "L9H9")

([(10672, 6365),
  (10672, 6365),
  (20546, 6365),
  (10672, 6365),
  (20546, 6365),
  (19296, 6365),
  (10672, 6365),
  (20546, 6365),
  (19296, 6365),
  (10672, 6365),
  (20546, 6365),
  (19296, 6365),
  (10461, 6365),
  (19296, 6365),
  (20546, 6365),
  (12965, 6365),
  (10672, 6365),
  (10461, 6365),
  (19296, 6365),
  (20546, 6365),
  (12965, 6365),
  (10672, 6365),
  (10461, 6365),
  (19296, 6365),
  (20546, 6365),
  (12965, 6365),
  (10672, 6365),
  (10461, 6365),
  (19296, 6365),
  (20546, 6365),
  (21346, 6365),
  (12965, 6365),
  (10672, 6365),
  (10461, 6365),
  (19296, 6365),
  (20546, 6365),
  (21346, 6365),
  (12965, 6365),
  (10672, 6365),
  (10461, 6365),
  (19296, 6365),
  (20546, 6365),
  (21346, 6365),
  (12965, 6365),
  (10672, 6365),
  (10461, 6365),
  (10462, 132),
  (10462, 132),
  (3604, 132),
  (10462, 132),
  (3604, 132),
  (10461, 132),
  (10462, 132),
  (3604, 132),
  (10461, 132),
  (10462, 132),
  (3604, 132),
  (10461, 132),
  (10462, 132),
  (5348, 132),

In [4]:
cp.get_cooccurrences("MLP0", "L9H6")

([], None)
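
The lists returned above contain many repeated pairs, so a tally is often easier to read. The sketch below looks up a pair of components in either order and counts each `(feature_a, feature_b)` pair. It runs on a toy dictionary with the same layout as `cp.co_occurrence_dict`; `count_cooccurrences` is a hypothetical helper, not part of `CircuitPrediction`.

```python
from collections import Counter

def count_cooccurrences(co_occurrence_dict, comp_a, comp_b):
    """Count feature pairs between two components, regardless of key order.

    Pairs are returned as (feature of comp_a, feature of comp_b).
    """
    if (comp_a, comp_b) in co_occurrence_dict:
        pairs = co_occurrence_dict[(comp_a, comp_b)]
    elif (comp_b, comp_a) in co_occurrence_dict:
        # Stored under the reversed key, so flip each pair to match the request
        pairs = [(b, a) for a, b in co_occurrence_dict[(comp_b, comp_a)]]
    else:
        pairs = []
    return Counter(pairs).most_common()

toy_dict = {
    (("attn_head", 0, 1), ("mlp_feature", 0)): [
        (3949, 10672), (3949, 20546), (3949, 10672), (3949, 19296), (3949, 10672),
    ],
}

forward = count_cooccurrences(toy_dict, ("attn_head", 0, 1), ("mlp_feature", 0))
reverse = count_cooccurrences(toy_dict, ("mlp_feature", 0), ("attn_head", 0, 1))
missing = count_cooccurrences(toy_dict, ("mlp_feature", 0), ("attn_head", 9, 6))
```

A missing pair of components yields an empty list, mirroring the `([], None)` result above.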

In [5]:
cp.visualize_co_occurrences()

In [6]:
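# Needs `z_saes` (and the tokens) from the setup cells at the top of the notebook;
# re-run those after a kernel restart, otherwise this cell raises a NameError.
features = [15490]
sae = z_saes[9]
feature_scores = get_feature_scores(model, sae, owt_tokens_torch, features, batch_size=128)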

NameError: name 'z_saes' is not defined

In [36]:
# Display top k activating examples
feature_idx = 0 # corresponding to 16109
example_html, examples_clean_text = display_top_k_activating_examples(model, feature_scores[:, feature_idx, :], owt_tokens_torch, k=15)

top_tokens, top_logits = get_top_k_tokens(model, sae, features[feature_idx], k=20, act_strength=5)

pretty_print_tokens_logits(top_tokens, top_logits)

feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

╒═════════════╤═════════╕
│ Token       │   Logit │
╞═════════════╪═════════╡
│  McCoy      │  4.225  │
├─────────────┼─────────┤
│  yesterday  │  4.048  │
├─────────────┼─────────┤
│  Vaughan    │  3.9792 │
├─────────────┼─────────┤
│  Spectre    │  3.7762 │
├─────────────┼─────────┤
│  pictured   │  3.7127 │
├─────────────┼─────────┤
│ intendent   │  3.697  │
├─────────────┼─────────┤
│  fault      │  3.6671 │
├─────────────┼─────────┤
│  ages       │  3.6526 │
├─────────────┼─────────┤
│  faults     │  3.5994 │
├─────────────┼─────────┤
│ lla         │  3.5621 │
├─────────────┼─────────┤
│  ME         │  3.5604 │
├─────────────┼─────────┤
│ cause       │  3.531  │
├─────────────┼─────────┤
│  Bed        │  3.4252 │
├─────────────┼─────────┤
│  Wido

In [64]:
# def get_top_logits(model, encoder, features, act_strength=4.0, dict_size=24576):
#     print("New get top logits")
#     hidden_acts = torch.zeros(dict_size, device='cpu')
#     if isinstance(features, list):
#         for feature in features:
#             hidden_acts[feature] = act_strength
#     else:
#         hidden_acts[features] = act_strength  # Single feature case

#     hidden_acts = hidden_acts.unsqueeze(0)
#     hidden_acts = encoder.decode(hidden_acts)
#     logits = einops.einsum(
#         hidden_acts.to('cpu'), model.W_U.to('cpu'),
#         'b h, h l -> b l'
#     )
#     return logits

# def get_combined_logits(model, encoder_feature_pairs, act_strength=4.0, dict_size=24576):
#     print("New get combined logits")
#     combined_logits = torch.zeros((1, model.W_U.size(1)), device='cpu')
#     for encoder, features in encoder_feature_pairs:
#         logits = get_top_logits(model, encoder, features, act_strength, dict_size)
#         combined_logits += logits
#     return combined_logits

# def get_top_k_tokens(model, encoder_feature_pairs, dict_size=24576, act_strength=4.0, k=10):
#     print("New get top k tokens")
#     combined_logits = get_combined_logits(model, encoder_feature_pairs, act_strength, dict_size)
#     top_k = torch.topk(combined_logits, k)
#     top_k_indices = top_k.indices.squeeze().tolist()
#     top_k_logits = top_k.values.squeeze().tolist()
#     top_k_tokens = [model.to_string(x) for x in top_k_indices]
#     return top_k_tokens, top_k_logits


encoder_feature_pairs = [(transcoders[0], 10173), (z_saes[9], 16109)]
top_tokens, top_logits = get_top_k_tokens(model, encoder_feature_pairs, k=20)

pretty_print_tokens_logits(top_tokens, top_logits)

New get top k tokens
New get combined logits
New get top logits
New get top logits
╒═════════════╤═════════╕
│ Token       │   Logit │
╞═════════════╪═════════╡
│  McCoy      │  5.9774 │
├─────────────┼─────────┤
│  ages       │  5.6441 │
├─────────────┼─────────┤
│  Hugo       │  5.5717 │
├─────────────┼─────────┤
│  Hutchinson │  5.4651 │
├─────────────┼─────────┤
│  Kay        │  5.4403 │
├─────────────┼─────────┤
│  Thomas     │  5.3777 │
├─────────────┼─────────┤
│  Ellis      │  5.332  │
├─────────────┼─────────┤
│  Carly      │  5.2565 │
├─────────────┼─────────┤
│  Bed        │  5.2278 │
├─────────────┼─────────┤
│  Jones      │  5.2252 │
├─────────────┼─────────┤
│  Kitt       │  5.2231 │
├─────────────┼─────────┤
│  ME         │  5.2082 │
├─────────────┼─────────┤

In [56]:
# def get_feature_scores(model, encoder, tokens_arr, feature_indices, batch_size=64, act_name=None, 
#                        use_raw_scores=False, use_decoder=False, feature_post=None, ignore_endoftext=False):
    
#     print("New get feature scores")
#     # Determine the type of encoder and set defaults
#     if isinstance(encoder, ZSAE):
#         print("ZSAE")
#         act_name = act_name or 'attn.hook_z'
#         layer = encoder.cfg['layer']
#         name_filter = f'blocks.{layer}.attn.hook_z'
#     elif isinstance(encoder, SparseTranscoder):
#         print("SparseTranscoder")
#         act_name = act_name or encoder.cfg.hook_point
#         layer = encoder.cfg.hook_point_layer
#         name_filter = act_name
#     else:
#         raise ValueError("Unsupported encoder type")

#     scores = []
#     endoftext_token = model.tokenizer.eos_token

#     for i in tqdm(range(0, tokens_arr.shape[0], batch_size)):
#         with torch.no_grad():
#             _, cache = model.run_with_cache(tokens_arr[i:i+batch_size], stop_at_layer=layer+1, names_filter=name_filter)
#             mlp_acts = cache[name_filter]
#             mlp_acts_flattened = mlp_acts.reshape(-1, encoder.W_enc.shape[0])
            
#             if feature_post is None:
#                 if isinstance(encoder, SparseTranscoder) and use_decoder:
#                     feature_post = encoder.W_dec[:, feature_indices]
#                 else:
#                     feature_post = encoder.W_enc[:, feature_indices]
                    
#             if isinstance(encoder, SparseTranscoder) and use_decoder:
#                 bias = -(encoder.b_dec @ feature_post)
#             else:
#                 bias = encoder.b_enc[feature_indices] - (encoder.b_dec @ feature_post)
            
#             if use_raw_scores:
#                 cur_scores = (mlp_acts_flattened @ feature_post) + bias
#             else:
#                 hidden_acts = encoder.encode(mlp_acts_flattened)
#                 cur_scores = hidden_acts[:, feature_indices]
#                 del hidden_acts
            
#             if ignore_endoftext:
#                 cur_scores[tokens_arr[i:i+batch_size].reshape(-1) == endoftext_token] = -torch.inf

#         scores.append(
#             to_numpy(
#                 einops.rearrange(cur_scores, "(b pos) n -> b n pos", pos=tokens_arr.shape[1])
#             ).astype(np.float16)
#         )

#     return np.concatenate(scores, axis=0)

# def get_feature_scores_across_layers(model, encoder_feature_pairs, tokens_arr, act_name=None, batch_size=64):
#     combined_scores = []
#     for encoder, feature_indices in encoder_feature_pairs:
#         scores = get_feature_scores(model, encoder, tokens_arr, feature_indices, act_name=act_name, batch_size=batch_size)
#         combined_scores.append(scores)

#     # We want to combine the scores along the second dimension (features), as feature scores is shape (batch, features, tokens)
#     return np.stack(combined_scores, axis=1).squeeze()

encoder_feature_pairs = [(transcoders[0], [10173]), (z_saes[9], [16109])]
feature_scores = get_feature_scores_across_layers(model, encoder_feature_pairs, owt_tokens_torch, batch_size=128)

New get feature scores
SparseTranscoder


100%|██████████| 200/200 [02:06<00:00,  1.58it/s]


New get feature scores
ZSAE


100%|██████████| 200/200 [04:26<00:00,  1.33s/it]
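
When combining scores from several encoders as above, each per-encoder array has shape `(batch, n_features, seq_pos)`. The commented-out helper stacks the arrays and then calls `.squeeze()`, which also drops the batch or position axis whenever that axis has length 1. Concatenating along the feature axis is one possible alternative that keeps the shape `(batch, total_features, seq_pos)` in every case; `combine_feature_scores` is a hypothetical name used only in this sketch.

```python
import numpy as np

def combine_feature_scores(score_arrays):
    """Join per-encoder score arrays along the feature axis.

    Each array has shape (batch, n_features_i, seq_pos); the result has shape
    (batch, sum of n_features_i, seq_pos).
    """
    return np.concatenate(score_arrays, axis=1)

transcoder_scores = np.ones((4, 1, 16), dtype=np.float16)    # one MLP feature
zsae_scores = np.zeros((4, 1, 16), dtype=np.float16)         # one attention feature

combined = combine_feature_scores([transcoder_scores, zsae_scores])
print(combined.shape)

# A single-sequence batch keeps its batch axis, unlike stack(...).squeeze()
single = combine_feature_scores([transcoder_scores[:1], zsae_scores[:1]])
squeezed = np.stack([transcoder_scores[:1], zsae_scores[:1]], axis=1).squeeze()
print(single.shape, squeezed.shape)
```

Concatenation also works when an encoder contributes more than one feature, where stacking would produce a four-dimensional array.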


In [63]:
examples_html, examples_clean_text = display_top_k_activating_examples_sum(model, feature_scores, owt_tokens, [0, 1], k=25, show_score=True)

In [66]:
#feature_interpretation = get_response(llm_client, examples_clean_text, top_tokens)
print(feature_interpretation)

(Part 1)
Step 1.
ACTIVATING TOKENS: "Thomas" repeated in various contexts. Other minor activated tokens include "killing", quotation marks, periods, "with", "and", "of", "to", "ber", and "But".
PREVIOUS TOKENS: Various, no clear pattern.

Step 2.
The activating token "Thomas" is consistently a proper noun, often associated with a position or title and in possessive phrases. Other tokens appear random and do not seem to show a pattern.

Step 3.
- The key theme represented in the examples is people with the name "Thomas", though context and roles vary across examples.
- The token "Thomas" appears often near punctuation marks, such as periods and quotation marks.
- Many examples involve proper title or possessive usage of "Thomas".

(Part 2)
Step 4.
SIMILAR TOKENS: all names ("McCoy", "ages", "Hugo", "Hutchinson", etc.).
The top logits list contains primarily proper names.

Step 5.
[EXPLANATION]: The neuron activates on the proper noun "Thomas", especially related to possessive phrases or

In [68]:
cp.get_top_k_feature_tuples()

AttributeError: 'CircuitPrediction' object has no attribute 'get_top_k_feature_tuples'