This example finds attention heads that move information from the referent of the word 'it' (here, 'milk') to the 'it' token.

## Setup

In [1]:
from transformer_lens import HookedTransformer


# Load the pretrained Pythia 2.8B (deduped, v0) checkpoint as a HookedTransformer.
llm = HookedTransformer.from_pretrained(
    model_name="EleutherAI/pythia-2.8b-deduped-v0",
)

  from .autonotebook import tqdm as notebook_tqdm


Loaded pretrained model EleutherAI/pythia-2.8b-deduped-v0 into HookedTransformer


In [16]:
# Example sentence in which 'it' refers back to 'milk'.
text = "I went to the store and bought some milk, but I forgot to bring it home."

# Token ids as a "pt" tensor, with the tokenizer's special tokens included.
input_token_ids = llm.tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
# The same tokenization expressed as string pieces.
input_tokens = llm.tokenizer.tokenize(text, add_special_tokens=True)

# One call, newline-separated: prints the string pieces first, then the id tensor.
print(input_tokens, input_token_ids, sep="\n")

['<|endoftext|>', 'I', 'Ġwent', 'Ġto', 'Ġthe', 'Ġstore', 'Ġand', 'Ġbought', 'Ġsome', 'Ġmilk', ',', 'Ġbut', 'ĠI', 'Ġforgot', 'Ġto', 'Ġbring', 'Ġit', 'Ġhome', '.']
tensor([[    0,    42,  2427,   281,   253,  4657,   285,  8686,   690,  8463,
            13,   533,   309, 18298,   281,  3324,   352,  1728,    15]])


In [17]:
# Run a forward pass and keep only the activation cache (element 1 of the returned pair).
# Indexing instead of unpacking into `_` avoids rebinding the name IPython uses for the
# previous cell's output, and does not keep the unused first element alive in the kernel.
activation_cache = llm.run_with_cache(input_token_ids)[1]

## Find relevant tokens

In [None]:
from llm_inspect import TokenFinder, AttentionHeadFinder


# Helper for locating tokens within `text`, built from the model's tokenizer.
token_finder = TokenFinder.create_from_tokenizer(text, llm.tokenizer)
# Head finder built from the tokenizer, the token strings, and the cached activations.
# NOTE(review): presumably `input_tokens` must line up position-for-position with
# `activation_cache` — confirm against llm_inspect.
activation_analyzer = AttentionHeadFinder.create_from_tokenizer(llm.tokenizer, input_tokens, activation_cache)

In [None]:
# Look up the first occurrence of each word of interest, in the order store, milk, it.
# NOTE(review): `allow_space_prefix=True` presumably lets the match include the
# space-prefixed token form — confirm against llm_inspect.
store, milk, it = (
    token_finder.find_first(word, allow_space_prefix=True)
    for word in ("store", "milk", "it")
)

# One newline-separated print; output matches three separate print calls.
print(
    f"Store token: {store}",
    f"Milk token: {milk}",
    f"It token: {it}",
    sep="\n",
)

### Find heads that move information from 'milk' to 'it'

In [21]:
# Heads selected by the finder for the (query token, value token) pair (`it`, `milk`).
# NOTE(review): the selection criterion lives inside llm_inspect — presumably an
# attention-weight threshold; confirm.
matching_heads = activation_analyzer.find_heads_where_query_looks_at_value(it, milk)

# Reports only the count; the heads themselves are not printed here.
print(f"Found {len(matching_heads)} heads that move information from 'milk' to 'it':")

Found 21 heads that move information from 'milk' to 'it':


## Visualise

In [None]:
from llm_inspect import TokenDisplayer


# Displayer used by the visualisation cell below; built from the model's tokenizer.
token_displayer = TokenDisplayer.create_for_tokenizer(llm.tokenizer)

In [None]:
# Fail with a clear message instead of a bare IndexError when no head matched.
assert len(matching_heads) > 0, "No matching heads found; nothing to visualise."

print(f"Head {matching_heads[0]}:")

# Kept as the cell's last expression so the notebook displays the returned value.
token_displayer.html_for_token_attention(
    input_tokens,
    activation_cache,
    matching_heads[0],
)