## Load model

In [1]:
import torch
from transformer_lens import HookedTransformer

# Name of the pretrained checkpoint loaded through TransformerLens.
MODEL_NAME = "gpt2-small"

model = HookedTransformer.from_pretrained(MODEL_NAME)

  from .autonotebook import tqdm as notebook_tqdm


Loaded pretrained model gpt2-small into HookedTransformer


### Run a forward pass through the LLM and cache its activations

In [2]:
text = "The quick brown fox jumps over the lazy dog."

# Encode once and derive the display tokens from those same ids, so `tokens`
# and the model input cannot disagree (e.g. if one tokenizer path adds a BOS
# token and the other does not).
token_ids = model.tokenizer.encode(text, return_tensors="pt")
tokens = model.tokenizer.convert_ids_to_tokens(token_ids[0].tolist())

# Inference only: no autograd graph is needed, so don't build one.
with torch.no_grad():
    _, activation_cache = model.run_with_cache(token_ids)

## Find heads that match a given criterion

In [None]:
from llm_inspect import ActivationAnalyzer
from llm_inspect import TokenDisplayer

# Wraps the tokens and cached activations; queried below to search heads by a predicate.
activation_analyser = ActivationAnalyzer(tokens, activation_cache)

# Built from the model's tokenizer; used below to render attention visualisations.
token_displayer = TokenDisplayer.create_for_tokenizer(model.tokenizer)

### Heads that look at the first token

In [4]:
# Find heads that always look at the first token, i.e. the highest attention score for each token is the first token

def looks_at_first_token(attention):
    """Return a boolean tensor that is True when every query's highest-scoring key is position 0.

    NOTE(review): assumes `attention` is a single head's (query_pos, key_pos)
    pattern as passed by ActivationAnalyzer.find_heads_matching_criteria — confirm.
    """
    return torch.all(attention.argmax(-1) == 0)


heads_looking_at_first_token = activation_analyser.find_heads_matching_criteria(looks_at_first_token)

print(f"Found {len(heads_looking_at_first_token)} head(s) looking at the first token")

# Fail with a clear message instead of a bare IndexError when no head matches.
assert heads_looking_at_first_token, "No head matched the first-token criterion"

print(f"Head {heads_looking_at_first_token[0]}:")

token_displayer.html_for_token_attention(tokens, activation_cache, heads_looking_at_first_token[0])

Found 89 head(s) looking at the first token
Head 0.2:
None
<div id="circuits-vis-dc79ba87-d612" style="margin: 15px 0;"/>
    <script crossorigin type="module">
    import { render, ColoredTokensMulti } from "https://unpkg.com/circuitsvis@1.43.3/dist/cdn/esm.js";
    render(
      "circuits-vis-dc79ba87-d612",
      ColoredTokensMulti,
      {"tokens": ["The", " quick", " brown", " fox", " jumps", " over", " the", " lazy", " dog", "."], "values": [[1.0, 0.9663582444190979, 0.8084290027618408, 0.7258327603340149, 0.42324545979499817, 0.48319047689437866, 0.39391928911209106, 0.3522254526615143, 0.36793631315231323, 0.5735489130020142], [0.0, 0.03364172205328941, 0.11924774199724197, 0.0672319158911705, 0.3104708790779114, 0.16336768865585327, 0.06828046590089798, 0.09294448047876358, 0.07791423797607422, 0.06987302750349045], [0.0, 0.0, 0.07232324033975601, 0.10149607807397842, 0.17521129548549652, 0.06768796592950821, 0.05586722865700722, 0.05886082351207733, 0.03597508370876312, 0.020

### Heads that look at the previous token

In [5]:
# Find heads where every token after the first looks at the previous token, i.e. the highest
# attention score for query position i (i >= 1) is at key position i - 1

def looks_at_previous_token(attention):
    """Return a boolean tensor that is True when each query i >= 1 attends most to key i - 1.

    Query position 0 has no previous token, so it is excluded from the check.
    NOTE(review): assumes `attention` is a single head's (query_pos, key_pos)
    pattern as passed by ActivationAnalyzer.find_heads_matching_criteria — confirm.
    """
    # Build the expected indices on the same device as the pattern; comparing a
    # CPU tensor against a CUDA tensor raises a device-mismatch error.
    expected_keys = torch.arange(attention.shape[0] - 1, device=attention.device)
    return torch.all(attention.argmax(-1)[1:] == expected_keys)


heads_looking_at_previous_token = activation_analyser.find_heads_matching_criteria(looks_at_previous_token)

print(f"Found {len(heads_looking_at_previous_token)} head(s) looking at the previous token")

# Fail with a clear message instead of a bare IndexError when no head matches.
assert heads_looking_at_previous_token, "No head matched the previous-token criterion"

print(f"Head {heads_looking_at_previous_token[0]}:")

token_displayer.html_for_token_attention(tokens, activation_cache, heads_looking_at_previous_token[0])

Found 1 head(s) looking at the previous token
Head 4.11:
None
<div id="circuits-vis-c34f2417-06da" style="margin: 15px 0;"/>
    <script crossorigin type="module">
    import { render, ColoredTokensMulti } from "https://unpkg.com/circuitsvis@1.43.3/dist/cdn/esm.js";
    render(
      "circuits-vis-c34f2417-06da",
      ColoredTokensMulti,
      {"tokens": ["The", " quick", " brown", " fox", " jumps", " over", " the", " lazy", " dog", "."], "values": [[1.0, 0.9987305998802185, 4.5460211595127475e-07, 5.0940279550104606e-08, 5.868634023187269e-09, 6.281377501471397e-11, 3.857139496687978e-09, 2.6639092993718805e-06, 1.8228432097089353e-08, 1.1756584594735386e-10], [0.0, 0.0012693916214630008, 0.9998726844787598, 7.438994998665294e-06, 4.2916821626440976e-14, 3.548097600700871e-17, 3.355527307120566e-17, 3.504710965884783e-09, 2.1310512382521907e-11, 4.518948566870274e-17], [0.0, 0.0, 0.00012680067447945476, 0.9997484087944031, 2.097707920256653e-06, 1.561583842920808e-13, 1.6742897927725

### Attention pattern

In [6]:
# Render the full attention pattern of the first previous-token head found above.
previous_token_head = heads_looking_at_previous_token[0]
token_displayer.html_for_attention_pattern(tokens, activation_cache, previous_token_head)