## Load model

In [None]:
from transformer_lens import HookedTransformer
from llm_inspect import AblationLlm
from llm_inspect import AttentionHead

# Load GPT-2 small through TransformerLens, then wrap the same model in the
# llm_inspect ablation helper used for the ablated forward pass further down.
llm = HookedTransformer.from_pretrained(model_name="gpt2-small")
ablation_llm = AblationLlm(llm)

### Ablate movement from 'quick' to 'fox'

In [23]:
text = "The quick brown fox jumps over the lazy dog."

# Encode once: these ids are exactly what the unablated forward pass receives
# (a (1, seq_len) tensor, since a single string is encoded with return_tensors="pt").
token_ids = llm.tokenizer.encode(text, return_tensors="pt")

# Derive the display labels from those same ids rather than tokenizing again,
# so `tokens` lines up position-by-position with `token_ids` (and hence with the
# activation caches). Calling `tokenize(..., add_special_tokens=True)` separately
# can disagree with `encode` on special tokens (e.g. slow tokenizers ignore that
# flag in `tokenize`), which would silently shift every label.
tokens = llm.tokenizer.convert_ids_to_tokens(token_ids[0].tolist())

In [None]:
from llm_inspect import TokenFinder

# Find the token positions of 'quick' and 'fox' in the prompt.
# `allow_space_prefix=True` presumably lets space-prefixed tokens (e.g. " quick")
# match as well — confirm against llm_inspect.TokenFinder.
token_finder = TokenFinder.create_from_tokenizer(text, llm.tokenizer)
quick_token, fox_token = (
    token_finder.find_first(word, allow_space_prefix=True) for word in ("quick", "fox")
)

# Clean run on the pre-computed ids, keeping every activation.
_, unablated_activation_cache = llm.run_with_cache(token_ids)

# Ablated run: information movement from 'quick' to 'fox' is ablated.
# NOTE(review): this run receives the raw `text`, whereas the clean run above
# receives `token_ids`; confirm AblationLlm tokenizes `text` identically (same
# special-token / BOS handling) so positions in the two caches line up.
_, ablated_activation_cache = ablation_llm.forward(
    text,
    token_movement_to_ablate=[(quick_token, fox_token)],
)

In [None]:
# Both caches come from a single-prompt run, so strip the leading batch
# dimension (presumably the layout the TokenDisplayer helpers expect — confirm).
# The return values are discarded: these calls are relied on to update the
# caches in place, so re-running this cell without re-running the forward
# passes operates on already-squeezed caches.
unablated_activation_cache.remove_batch_dim()
ablated_activation_cache.remove_batch_dim()

In [None]:
# HTML renderer for the attention views below, built from the model's tokenizer.
from llm_inspect import TokenDisplayer

token_displayer = TokenDisplayer.create_for_tokenizer(
    llm.tokenizer
)

In [None]:
# Attention head inspected by all the visualizations below.
# NOTE(review): the two arguments are presumably (layer, head) — the head prints
# as "0.0" — confirm against llm_inspect.AttentionHead.
attention_head = AttentionHead(0, 0)

Comparing the unablated and ablated attention scores/patterns below, we can see that in the ablated run no information is moved from the 'quick' token to the 'fox' token.

In [40]:
# Per-token attention scores for the selected head on the unablated run;
# the HTML returned by the displayer is the cell's rendered output.
print(f"Unablated attention scores for head {attention_head}:")
token_displayer.html_for_token_attention(
    tokens,
    unablated_activation_cache,
    attention_head,
)

Unablated attention scores for head 0.0:


In [42]:
# Same view as above, but for the run with 'quick' -> 'fox' movement ablated.
print(f"Ablated attention scores for head {attention_head}:")
token_displayer.html_for_token_attention(
    tokens,
    ablated_activation_cache,
    attention_head,
)

Ablated attention scores for head 0.0:


## Attention patterns

In [43]:
# Attention pattern for the selected head on the unablated run.
token_displayer.html_for_attention_pattern(
    tokens, unablated_activation_cache, attention_head
)

In [44]:
# Attention pattern for the selected head on the ablated run.
token_displayer.html_for_attention_pattern(
    tokens, ablated_activation_cache, attention_head
)