<a href="https://colab.research.google.com/github/nmonson1/walkthroughs/blob/main/TransformerLens101.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Step 0: Setup

In [None]:
!pip install transformer_lens



### Plotting function

In [None]:
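import plotly.express as px
from transformer_lens import utils


def imshow(
    tensor,
    xlabel="X",
    ylabel="Y",
    zlabel=None,
    xticks=None,
    yticks=None,
    c_midpoint=0.0,
    c_scale="RdBu",
    show=True,
    **kwargs
):
    """Plot a 2D tensor as a heatmap whose color scale is centred on c_midpoint.

    Extra keyword arguments (e.g. title, width, height) are passed to px.imshow.
    If show=False the figure is returned instead of displayed.
    """
    tensor = utils.to_numpy(tensor)
    # Tick labels are optional; plotly wants strings, so convert them only if given
    xticks = [str(x) for x in xticks] if xticks is not None else None
    yticks = [str(y) for y in yticks] if yticks is not None else None
    labels = {"x": xlabel, "y": ylabel}
    if zlabel is not None:
        labels["color"] = zlabel
    fig = px.imshow(
        tensor,
        x=xticks,
        y=yticks,
        labels=labels,
        color_continuous_midpoint=c_midpoint,
        color_continuous_scale=c_scale,
        **kwargs
    )
    if not show:
        return fig
    fig.show()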

## Note on Errors:

* If your notebook crashes with a CUDA out-of-memory error, try (a) restarting the notebook and (b) going to Runtime > Change runtime type and selecting no hardware accelerator (no GPU). Colab without a GPU should have enough memory to run the examples shown here.

## Step 1: Getting a model to play with

In [None]:
from transformer_lens import HookedTransformer, utils

model_gpt = HookedTransformer.from_pretrained("gpt2-small")


Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
logits = model_gpt("Famous computer scientist Alan")
# The logit dimensions are: [batch, position, vocab]
next_token_logits = logits[0, -1]
next_token_prediction = next_token_logits.argmax()
next_word_prediction = model_gpt.tokenizer.decode(next_token_prediction)
print(next_word_prediction)

 Turing


In [None]:
logits, cache = model_gpt.run_with_cache("Famous computer scientist Alan")
for key, value in cache.items():
    print(key, value.shape)

hook_embed torch.Size([1, 6, 768])
hook_pos_embed torch.Size([1, 6, 768])
blocks.0.hook_resid_pre torch.Size([1, 6, 768])
blocks.0.ln1.hook_scale torch.Size([1, 6, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 6, 768])
blocks.0.attn.hook_q torch.Size([1, 6, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 6, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 6, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 6, 6])
blocks.0.attn.hook_pattern torch.Size([1, 12, 6, 6])
blocks.0.attn.hook_z torch.Size([1, 6, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 6, 768])
blocks.0.hook_resid_mid torch.Size([1, 6, 768])
blocks.0.ln2.hook_scale torch.Size([1, 6, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 6, 768])
blocks.0.mlp.hook_pre torch.Size([1, 6, 3072])
blocks.0.mlp.hook_post torch.Size([1, 6, 3072])
blocks.0.hook_mlp_out torch.Size([1, 6, 768])
blocks.0.hook_resid_post torch.Size([1, 6, 768])
blocks.1.hook_resid_pre torch.Size([1, 6, 768])
blocks.1.ln1.hook_scale torch.Size([1, 6, 1])
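Every key in the cache is a named hook point, and `run_with_cache` records the activation passing through each one. The minimal sketch below is a hypothetical toy, not TransformerLens code: the residual stream is a `[position][d_model]` list of lists, each "layer" just adds a scalar, and the stream is copied into a dict under TransformerLens-style names before the first layer and after every layer (the same naming scheme the `layers` lists in the patching sweeps use).

```python
from typing import Dict, List, Tuple

Resid = List[List[float]]  # [position][d_model]


def toy_forward(embeddings: Resid, layer_biases: List[float], cache: Dict[str, Resid]) -> Resid:
    """Hypothetical toy model: each 'layer' adds a scalar bias to every entry.

    The residual stream is copied into `cache` under TransformerLens-style hook
    names before the first layer and after every layer.
    """
    resid = [list(v) for v in embeddings]
    cache["blocks.0.hook_resid_pre"] = [list(v) for v in resid]
    for i, bias in enumerate(layer_biases):
        resid = [[x + bias for x in v] for v in resid]
        cache[f"blocks.{i}.hook_resid_post"] = [list(v) for v in resid]
    return resid


def nested_shape(nested) -> Tuple[int, ...]:
    """Shape of a rectangular nested list, e.g. (n_pos, d_model)."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0] if nested else None
    return tuple(dims)


toy_cache: Dict[str, Resid] = {}
toy_out = toy_forward([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [0.5, -1.0], toy_cache)
for key, value in toy_cache.items():
    print(key, nested_shape(value))
```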

In [None]:
utils.test_prompt("Her name was Alex Hart. Tomorrow at lunch time Alex",
                  " Hart", model_gpt)

Tokenized prompt: ['<|endoftext|>', 'Her', ' name', ' was', ' Alex', ' Hart', '.', ' Tomorrow', ' at', ' lunch', ' time', ' Alex']
Tokenized answer: [' Hart']


Top 0th token. Logit: 15.64 Prob: 28.38% Token: | will|
Top 1th token. Logit: 14.47 Prob:  8.79% Token: | would|
Top 2th token. Logit: 14.34 Prob:  7.74% Token: | was|
Top 3th token. Logit: 14.29 Prob:  7.35% Token: | Hart|
Top 4th token. Logit: 14.18 Prob:  6.54% Token: | and|
Top 5th token. Logit: 14.09 Prob:  6.00% Token: | is|
Top 6th token. Logit: 13.51 Prob:  3.38% Token: |'s|
Top 7th token. Logit: 13.23 Prob:  2.53% Token: |,|
Top 8th token. Logit: 12.73 Prob:  1.55% Token: | had|
Top 9th token. Logit: 12.00 Prob:  0.74% Token: | has|
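`test_prompt` reports a logit and a probability for each of the top tokens; the probability is the softmax of the final-position logits over the whole vocabulary. A minimal pure-Python sketch of that computation over a made-up three-token vocabulary (the numbers are illustrative, not GPT-2 outputs):

```python
import math
from typing import Dict, List, Tuple


def top_k_predictions(logits: Dict[str, float], k: int) -> List[Tuple[str, float, float]]:
    """Return (token, logit, probability) for the k highest-logit tokens.

    The probability is a softmax over *all* entries of `logits`, so it depends on
    every other logit, not just the token's own. Subtracting the max logit first
    avoids overflow and cancels out in the ratio.
    """
    max_logit = max(logits.values())
    exps = {tok: math.exp(logit - max_logit) for tok, logit in logits.items()}
    total = sum(exps.values())
    ranked = sorted(logits.items(), key=lambda item: item[1], reverse=True)
    return [(tok, logit, exps[tok] / total) for tok, logit in ranked[:k]]


# Made-up logits over a three-token vocabulary (not GPT-2 outputs)
toy_logits = {" will": 2.0, " Hart": 1.0, " was": 0.0}
for rank, (tok, logit, prob) in enumerate(top_k_predictions(toy_logits, k=3)):
    print(f"Top {rank}th token. Logit: {logit:5.2f} Prob: {prob:6.2%} Token: |{tok}|")
```

Because softmax divides every probability by the same total, the ratio of two tokens' probabilities depends only on their logit difference: p_a / p_b = exp(logit_a - logit_b). In the output above, " will" beats " Hart" by 15.64 - 14.29 = 1.35 logits, and 28.38% / 7.35% ≈ 3.86 ≈ e^1.35. This is why the logit difference between " Hart" and " Carroll" is a convenient patching metric: it does not depend on how much probability the other tokens get.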


### Method 1: Residual Stream Patching

#### My introductory explanation

In [None]:
model_gpt.reset_hooks()
_, corrupt_cache_gpt = model_gpt.run_with_cache("Her name was Alex Carroll. Tomorrow at lunch time Alex")

def patch_residual_stream(activations, hook, layer="blocks.2.hook_resid_post", pos=5):
    # The residual stream dimensions are [batch, position, d_model];
    # pos=5 is " Hart" (position 0 is the <|endoftext|> BOS token)
    activations[:, pos, :] = corrupt_cache_gpt[layer][:, pos, :]
    return activations


# add_hook takes 2 args: Where to insert the patch,
# and the function providing the updated activations
model_gpt.add_hook("blocks.2.hook_resid_post", patch_residual_stream)


utils.test_prompt("Her name was Alex Hart. Tomorrow at lunch time Alex", " Hart", model_gpt)

model_gpt.reset_hooks()


Tokenized prompt: ['<|endoftext|>', 'Her', ' name', ' was', ' Alex', ' Hart', '.', ' Tomorrow', ' at', ' lunch', ' time', ' Alex']
Tokenized answer: [' Hart']


Top 0th token. Logit: 15.63 Prob: 30.45% Token: | will|
Top 1th token. Logit: 14.44 Prob:  9.23% Token: | was|
Top 2th token. Logit: 14.42 Prob:  9.11% Token: | Carroll|
Top 3th token. Logit: 14.17 Prob:  7.04% Token: | is|
Top 4th token. Logit: 13.88 Prob:  5.31% Token: | and|
Top 5th token. Logit: 13.76 Prob:  4.70% Token: | would|
Top 6th token. Logit: 13.38 Prob:  3.20% Token: |'s|
Top 7th token. Logit: 12.93 Prob:  2.04% Token: |,|
Top 8th token. Logit: 12.33 Prob:  1.12% Token: | had|
Top 9th token. Logit: 12.01 Prob:  0.81% Token: | has|
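`model_gpt.add_hook` leaves the patch attached to the model until `reset_hooks()` is called, which is why the cell calls `reset_hooks()` both before and after the patched run; otherwise the patch would silently affect every later forward pass. `run_with_hooks`, used in the sweeps below, instead takes the hooks as an argument and (by default) removes them after the call. The class below is a hypothetical toy that models this persistent-versus-per-call distinction for a single hook point; it is not how TransformerLens is implemented.

```python
from typing import Callable, List

HookFn = Callable[[float], float]


class ToyHookPoint:
    """Hypothetical stand-in for one named hook point (not TransformerLens code)."""

    def __init__(self) -> None:
        self.hooks: List[HookFn] = []

    def add_hook(self, fn: HookFn) -> None:
        # Persistent: stays attached until reset_hooks() is called
        self.hooks.append(fn)

    def reset_hooks(self) -> None:
        self.hooks.clear()

    def forward(self, x: float) -> float:
        # Each attached hook may replace the activation passing through
        for fn in self.hooks:
            x = fn(x)
        return x

    def run_with_hooks(self, x: float, fwd_hooks: List[HookFn]) -> float:
        # Temporary: hooks are attached for this call only, then removed
        saved = list(self.hooks)
        self.hooks.extend(fwd_hooks)
        try:
            return self.forward(x)
        finally:
            self.hooks = saved


def patch_with_corrupt(x: float) -> float:
    return -1.0  # overwrite the clean activation with a "corrupted" value


point = ToyHookPoint()
point.add_hook(patch_with_corrupt)
print(point.forward(1.0))  # patched
print(point.forward(1.0))  # still patched: the hook leaked into this run
point.reset_hooks()
print(point.forward(1.0))  # clean again
print(point.run_with_hooks(1.0, [patch_with_corrupt]))  # patched for this call only
print(point.forward(1.0))  # clean
```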


#### How I actually use this

Oops, Colab doesn't have enough memory to run this sweep on GPT-2, so here I use this 6-layer model from Neel instead. In the LW post I use GPT-2 small.


```
model = HookedTransformer.from_pretrained("solu-6l")
```

In [None]:
import torch
from functools import partial

model_6l = HookedTransformer.from_pretrained("solu-6l")


#corrupt_prompt = "Her name was Sarah Hart. Tomorrow at lunch time Alex"
#corrupt_prompt = "Her name was Alex Hart. Tomorrow at lunch time Sarah"

# Clean and corrupt prompts in variables
clean_prompt = "Her name was Alex Hart. Tomorrow at lunch time Alex"
corrupt_prompt = "Her name was Alex Carroll. Tomorrow at lunch time Alex"
# Get the list of tokens the model will deal with
clean_tokens = model_6l.to_str_tokens(clean_prompt)
# Cache the activations from the corrupted run, to patch into the clean run
_, corrupt_cache_6l = model_6l.run_with_cache(corrupt_prompt)

def patch_residual_stream(activations, hook, layer, pos):
    # The residual stream dimensions are [batch, position, d_model].
    # Overwrite one position with the corrupted run's activation at `layer`;
    # `layer` and `pos` are always supplied via functools.partial below.
    activations[:, pos, :] = corrupt_cache_6l[layer][:, pos, :]
    return activations

# List of hook points and positions to iterate over. We want to patch before the
# first layer and after every layer, i.e. n_layers + 1 hook points (7 for this 6-layer model).
layers = ["blocks.0.hook_resid_pre", *[f"blocks.{i}.hook_resid_post" for i in range(model_6l.cfg.n_layers)]]
n_layers = len(layers)
n_pos = len(clean_tokens)

# Indices of the right and wrong answers (last names) to judge what the model predicts
clean_answer_index = model_6l.tokenizer.encode(" Hart")[0]
corrupt_answer_index = model_6l.tokenizer.encode(" Carroll")[0]

# Test the effect of this patch at every layer and every position
patching_effect = torch.zeros(n_layers, n_pos)
for l, layer in enumerate(layers):
    print("Patching layer", l)
    for pos in range(n_pos):
        fwd_hooks = [(layer, partial(patch_residual_stream, layer=layer, pos=pos))]
        prediction_logits = model_6l.run_with_hooks(clean_prompt,
                                                    fwd_hooks=fwd_hooks)[0, -1]
        patching_effect[l, pos] = prediction_logits[clean_answer_index] \
                                  - prediction_logits[corrupt_answer_index]

torch.cuda.empty_cache()

Loaded pretrained model solu-6l into HookedTransformer
Patching layer 0
Patching layer 1
Patching layer 2
Patching layer 3
Patching layer 4
Patching layer 5
Patching layer 6
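The loop above patches a single (hook point, position) pair per forward pass and records logit(" Hart") - logit(" Carroll") at the final position. The sketch below runs the same kind of sweep on a hypothetical toy model (not TransformerLens) whose residual stream holds one number per position: layer 0 copies position 1 (standing in for " Hart") into position 2, and layer 1 copies position 2 into the last position, whose value is read as the logit difference. `functools.partial` binds the position as in the cell above; all names and values are illustrative.

```python
from functools import partial
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Resid = List[float]
HookFn = Callable[[Resid, str], Resid]
HOOK_NAMES = ["blocks.0.hook_resid_pre", "blocks.0.hook_resid_post", "blocks.1.hook_resid_post"]


def toy_run_with_hooks(tokens: Resid, fwd_hooks: Sequence[Tuple[str, HookFn]] = (),
                       cache: Optional[Dict[str, Resid]] = None) -> float:
    """Hypothetical 2-layer toy model; returns the final position's residual,
    read as logit(' Hart') - logit(' Carroll')."""
    hooks = dict(fwd_hooks)
    resid = list(tokens)

    def hook_point(name: str) -> None:
        nonlocal resid
        if name in hooks:
            resid = hooks[name](resid, name)
        if cache is not None:
            cache[name] = list(resid)

    hook_point(HOOK_NAMES[0])
    resid[2] += resid[1]   # layer 0: copy position 1 into position 2
    hook_point(HOOK_NAMES[1])
    resid[-1] += resid[2]  # layer 1: copy position 2 into the last position
    hook_point(HOOK_NAMES[2])
    return resid[-1]


def patch_toy_resid(activations: Resid, hook_name: str,
                    corrupt_cache: Dict[str, Resid], pos: int) -> Resid:
    patched = list(activations)
    patched[pos] = corrupt_cache[hook_name][pos]  # overwrite a single position
    return patched


toy_clean = [0.0, 1.0, 0.0, 0.0]     # position 1 encodes " Hart" as +1
toy_corrupt = [0.0, -1.0, 0.0, 0.0]  # ... and " Carroll" as -1
toy_corrupt_cache: Dict[str, Resid] = {}
toy_run_with_hooks(toy_corrupt, cache=toy_corrupt_cache)

toy_effect = [[0.0] * len(toy_clean) for _ in HOOK_NAMES]
for l, name in enumerate(HOOK_NAMES):
    for pos in range(len(toy_clean)):
        hook_fn = partial(patch_toy_resid, corrupt_cache=toy_corrupt_cache, pos=pos)
        toy_effect[l][pos] = toy_run_with_hooks(toy_clean, fwd_hooks=[(name, hook_fn)])

for name, row in zip(HOOK_NAMES, toy_effect):
    print(f"{name:26}", row)
```

In this toy, the cell where patching flips the sign moves one position to the right at each successive hook point, tracing where the name information has been copied to. That movement is the kind of thing a residual-stream patching heatmap can reveal.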


In [None]:
# Plot
token_labels = [f"(pos {i:2}) {t}" for i, t in enumerate(clean_tokens)]
imshow(patching_effect, xticks=token_labels, yticks=layers, xlabel="pos", ylabel="layer",
       zlabel="Logit difference", title="Patching with Alex Carroll (corrupted last name)", width=600, height=380)


##### Plots of the 2-layer model, as used in the post:

In [None]:
model_2l = HookedTransformer.from_pretrained("gelu-2l")

clean_prompt = "Her name was Alex Hart. Tomorrow at lunch time Alex"
clean_tokens = model_2l.to_str_tokens(clean_prompt)

# Indices of the right and wrong answers (last names) to judge what the model predicts
clean_answer_index = model_2l.tokenizer.encode(" Hart")[0]
corrupt_answer_index = model_2l.tokenizer.encode(" Carroll")[0]

def patch_residual_stream(activations, hook, layer, pos):
    # The residual stream dimensions are [batch, position, d_model].
    # corrupt_cache is (re)defined inside the loop below, before each sweep.
    activations[:, pos, :] = corrupt_cache[layer][:, pos, :]
    return activations

layers = ["blocks.0.hook_resid_pre", *[f"blocks.{i}.hook_resid_post" for i in range(model_2l.cfg.n_layers)]]
n_layers = len(layers)
n_pos = len(clean_tokens)


for corrupt_prompt in ["Her name was Alex Carroll. Tomorrow at lunch time Alex",
                       "Her name was Sarah Hart. Tomorrow at lunch time Alex",
                       "Her name was Alex Hart. Tomorrow at lunch time Sarah"]:
    corrupt_tokens = model_2l.to_str_tokens(corrupt_prompt)
    _, corrupt_cache = model_2l.run_with_cache(corrupt_prompt)

    # Test this patch at any layer and any position
    patching_effect = torch.zeros(n_layers, n_pos)
    for l, layer in enumerate(layers):
        for pos in range(n_pos):
            torch.cuda.empty_cache()
            fwd_hooks = [(layer, partial(patch_residual_stream, layer=layer, pos=pos))]
            prediction_logits = model_2l.run_with_hooks(clean_prompt, fwd_hooks=fwd_hooks)[0, -1]
            patching_effect[l, pos] = prediction_logits[clean_answer_index] - prediction_logits[corrupt_answer_index]

    # Plot
    token_labels = [f"(pos {i:2}) {t}" for i, t in enumerate(clean_tokens)]
    imshow(patching_effect, xticks=token_labels, yticks=layers, xlabel="pos", ylabel="layer",
           zlabel="logit diff", title=f"Patching with{corrupt_tokens[4]} /{corrupt_tokens[5]} /{corrupt_tokens[11]}", width=600, height=380)


Loaded pretrained model gelu-2l into HookedTransformer


### Method 2: Attention head patching

In [None]:
# Store each head's output separately so blocks.{layer}.attn.hook_result exists
model_2l.cfg.use_attn_result = True


In [None]:
clean_prompt = "Her name was Alex Hart. Tomorrow at lunch time Alex"
#corrupt_prompt = "Her name was Sarah Hart. Tomorrow at lunch time Alex"
#corrupt_prompt = "Her name was Alex Hart. Tomorrow at lunch time Sarah"
corrupt_prompt = "Her name was Alex Carroll. Tomorrow at lunch time Alex"
clean_tokens = model_2l.to_str_tokens(clean_prompt)
_, corrupt_cache = model_2l.run_with_cache(corrupt_prompt)

clean_answer_index = model_2l.tokenizer.encode(" Hart")[0]
corrupt_answer_index = model_2l.tokenizer.encode(" Carroll")[0]

n_layers = model_2l.cfg.n_layers
n_heads = model_2l.cfg.n_heads
n_pos = len(clean_tokens)

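def patch_head_result(activations, hook, layer=None, head=None):
    # hook_result dimensions are [batch, position, head_index, d_model]:
    # replace one head's output at every position with the corrupted run's
    activations[:, :, head, :] = corrupt_cache[hook.name][:, :, head, :]
    return activations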

patching_effect = torch.zeros(n_layers, n_heads)
for layer in range(n_layers):
    for head in range(n_heads):
       fwd_hooks = [(f"blocks.{layer}.attn.hook_result", partial(patch_head_result, layer=layer, head=head))]
       prediction_logits = model_2l.run_with_hooks(clean_prompt, fwd_hooks=fwd_hooks)[0, -1]
       patching_effect[layer, head] = prediction_logits[clean_answer_index] - prediction_logits[corrupt_answer_index]
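Setting `use_attn_result = True` makes `attn.hook_result` available, holding each head's output separately with shape `[batch, position, head_index, d_model]`. As I understand it, TransformerLens forms the attention block's output as the sum of these per-head results plus the output bias `b_O`, so patching one head shifts the attention output by exactly that head's change and leaves the other heads' contributions alone. A toy pure-Python sketch of that bookkeeping (made-up values, batch dimension dropped; the helper names are hypothetical):

```python
from typing import List

Vec = List[float]
HeadResults = List[List[Vec]]  # [position][head][d_model]


def attn_out_from_heads(result: HeadResults, b_O: Vec) -> List[Vec]:
    """Sum the per-head outputs at each position and add the output bias."""
    return [[sum(head_vec[i] for head_vec in heads) + b_O[i] for i in range(len(b_O))]
            for heads in result]


def patch_head(result: HeadResults, corrupt_result: HeadResults, head: int) -> HeadResults:
    """Copy one head's output at every position from the corrupted run,
    like activations[:, :, head, :] = corrupt[:, :, head, :] with no batch dim."""
    patched = [[list(v) for v in heads] for heads in result]
    for pos, heads in enumerate(corrupt_result):
        patched[pos][head] = list(heads[head])
    return patched


# Made-up values: 2 positions, 2 heads, d_model = 2
toy_clean_result = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]]
toy_corrupt_result = [[[-1.0, 0.0], [0.0, 5.0]], [[-2.0, 0.0], [0.0, 7.0]]]
toy_b_O = [0.5, 0.5]

toy_clean_out = attn_out_from_heads(toy_clean_result, toy_b_O)
toy_patched_out = attn_out_from_heads(patch_head(toy_clean_result, toy_corrupt_result, head=0), toy_b_O)
# The change is head 0's (corrupt - clean) output; head 1's contribution is untouched
print([[p - c for p, c in zip(p_row, c_row)] for p_row, c_row in zip(toy_patched_out, toy_clean_out)])
```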

In [None]:
# Plot
token_labels = [f"(pos {i:2}) {t}" for i, t in enumerate(clean_tokens)]
head_labels = [f"(head {i:2})" for i in range(n_heads)]
layer_labels = [f"(layer {i:2})" for i in range(n_layers)]

imshow(patching_effect, xticks=head_labels, yticks=layer_labels, xlabel="head", ylabel="layer",
       zlabel="Logit difference", title="Patching each head's output with the Alex Carroll run", width=600, height=380)


In [None]:
model_2l.cfg.use_attn_result = True


clean_prompt = "Her name was Alex Hart. Tomorrow at lunch time Alex"
clean_tokens = model_2l.to_str_tokens(clean_prompt)

# Indices of the right and wrong answers (last names) to judge what the model predicts
clean_answer_index = model_2l.tokenizer.encode(" Hart")[0]
corrupt_answer_index = model_2l.tokenizer.encode(" Carroll")[0]


n_layers = model_2l.cfg.n_layers
n_heads = model_2l.cfg.n_heads
n_pos = len(clean_tokens)

def patch_head_result(activations, hook, layer=None, head=None):
   activations[:, :, head, :] = corrupt_cache[hook.name][:, :, head, :]
   return activations



for corrupt_prompt in ["Her name was Alex Carroll. Tomorrow at lunch time Alex",
                       "Her name was Sarah Hart. Tomorrow at lunch time Alex",
                       "Her name was Alex Hart. Tomorrow at lunch time Sarah"]:
    corrupt_tokens = model_2l.to_str_tokens(corrupt_prompt)
    _, corrupt_cache = model_2l.run_with_cache(corrupt_prompt)

    patching_effect = torch.zeros(n_layers, n_heads)
    for layer in range(n_layers):
        for head in range(n_heads):
           fwd_hooks = [(f"blocks.{layer}.attn.hook_result", partial(patch_head_result, layer=layer, head=head))]
           prediction_logits = model_2l.run_with_hooks(clean_prompt, fwd_hooks=fwd_hooks)[0, -1]
           patching_effect[layer, head] = prediction_logits[clean_answer_index] - prediction_logits[corrupt_answer_index]

    # Plot
    head_labels = [f"(head {i:2})" for i in range(n_heads)]
    layer_labels = [f"(layer {i:2})" for i in range(n_layers)]

    imshow(patching_effect.T, yticks=head_labels, xticks=layer_labels, ylabel="head", xlabel="layer",
           zlabel="Logit difference", title=f"Patching with{corrupt_tokens[4]} /{corrupt_tokens[5]} /{corrupt_tokens[11]}", width=600, height=380)


#### Bonus: patching each head at each position

In [None]:
clean_prompt = "Her name was Alex Hart. Tomorrow at lunch time Alex"
#corrupt_prompt = "Her name was Sarah Hart. Tomorrow at lunch time Alex"
#corrupt_prompt = "Her name was Alex Hart. Tomorrow at lunch time Sarah"
corrupt_prompt = "Her name was Alex Carroll. Tomorrow at lunch time Alex"
corrupt_tokens = model_2l.to_str_tokens(corrupt_prompt)
clean_tokens = model_2l.to_str_tokens(clean_prompt)
_, corrupt_cache = model_2l.run_with_cache(corrupt_prompt)

clean_answer_index = model_2l.tokenizer.encode(" Hart")[0]
corrupt_answer_index = model_2l.tokenizer.encode(" Carroll")[0]

n_layers = model_2l.cfg.n_layers
n_heads = model_2l.cfg.n_heads
n_pos = len(clean_tokens)

def patch_head_result(activations, hook, layer=None, head=None, pos=None):
   activations[:, pos, head, :] = corrupt_cache[hook.name][:, pos, head, :]
   return activations

patching_effect = torch.zeros(n_layers*n_heads, n_pos)
for layer in range(n_layers):
    for head in range(n_heads):
      for pos in range(n_pos):
          fwd_hooks = [(f"blocks.{layer}.attn.hook_result", partial(patch_head_result, layer=layer, head=head, pos=pos))]
          prediction_logits = model_2l.run_with_hooks(clean_prompt, fwd_hooks=fwd_hooks)[0, -1]
          patching_effect[n_heads*layer+head, pos] = prediction_logits[clean_answer_index] - prediction_logits[corrupt_answer_index]
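In the bonus cells each (layer, head) pair gets its own row, at index `n_heads*layer + head`, while the y-axis labels come from a nested comprehension over layers and then heads. The two only line up because both list all heads of one layer before moving to the next. A small sketch that checks this (the helper names and the layer/head counts are illustrative):

```python
from typing import Tuple


def row_index(layer: int, head: int, n_heads: int) -> int:
    """Row of patching_effect used for (layer, head): all heads of layer 0, then layer 1, ..."""
    return n_heads * layer + head


def layer_head(row: int, n_heads: int) -> Tuple[int, int]:
    """Inverse of row_index: recover (layer, head) from a row number."""
    return divmod(row, n_heads)


toy_n_layers, toy_n_heads = 2, 8  # illustrative sizes
toy_labels = [f"{l}.{h}" for l in range(toy_n_layers) for h in range(toy_n_heads)]

# The label comprehension lists rows in exactly the order row_index assigns them
for l in range(toy_n_layers):
    for h in range(toy_n_heads):
        assert toy_labels[row_index(l, h, toy_n_heads)] == f"{l}.{h}"

print(layer_head(11, toy_n_heads))  # (1, 3), i.e. the row labelled "1.3"
```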

In [None]:
# Plot
token_labels = [f"(pos {i:2}) {t}" for i, t in enumerate(clean_tokens)]
layerhead_labels = [f"{l}.{h}" for l in range(n_layers) for h in range(n_heads)]
imshow(patching_effect, xticks=token_labels, yticks=layerhead_labels, xlabel="position", ylabel="layer.head",
           zlabel="Logit difference", title=f"Patching with{corrupt_tokens[4]} /{corrupt_tokens[5]} /{corrupt_tokens[11]}", width=700, height=800)



In [None]:
model_2l.cfg.use_attn_result = True


clean_prompt = "Her name was Alex Hart. Tomorrow at lunch time Alex"
clean_tokens = model_2l.to_str_tokens(clean_prompt)

# Indices of the right and wrong answers (last names) to judge what the model predicts
clean_answer_index = model_2l.tokenizer.encode(" Hart")[0]
corrupt_answer_index = model_2l.tokenizer.encode(" Carroll")[0]


n_layers = model_2l.cfg.n_layers
n_heads = model_2l.cfg.n_heads
n_pos = len(clean_tokens)

def patch_head_result(activations, hook, layer=None, head=None, pos=None):
   activations[:, pos, head, :] = corrupt_cache[hook.name][:, pos, head, :]
   return activations


for corrupt_prompt in ["Her name was Alex Carroll. Tomorrow at lunch time Alex",
                       "Her name was Sarah Hart. Tomorrow at lunch time Alex",
                       "Her name was Alex Hart. Tomorrow at lunch time Sarah"]:
    corrupt_tokens = model_2l.to_str_tokens(corrupt_prompt)
    _, corrupt_cache = model_2l.run_with_cache(corrupt_prompt)

    patching_effect = torch.zeros(n_layers*n_heads, n_pos)
    for layer in range(n_layers):
        for head in range(n_heads):
          for pos in range(n_pos):
              fwd_hooks = [(f"blocks.{layer}.attn.hook_result", partial(patch_head_result, layer=layer, head=head, pos=pos))]
              prediction_logits = model_2l.run_with_hooks(clean_prompt, fwd_hooks=fwd_hooks)[0, -1]
              patching_effect[n_heads*layer+head, pos] = prediction_logits[clean_answer_index] - prediction_logits[corrupt_answer_index]
    # Plot
    token_labels = [f"(pos {i:2}) {t}" for i, t in enumerate(clean_tokens)]
    layerhead_labels = [f"{l}.{h}" for l in range(n_layers) for h in range(n_heads)]
    imshow(patching_effect, xticks=token_labels, yticks=layerhead_labels, xlabel="position", ylabel="layer.head",
               zlabel="Logit difference", title=f"Patching with{corrupt_tokens[4]} /{corrupt_tokens[5]} /{corrupt_tokens[11]}", width=500, height=500)



#### Better color plots

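Since the baseline (unpatched) logit difference is already a large number, a color scale centred at zero makes it hard to see which patches matter. Here I subtract the clean logit difference, so 0 means the patch had no effect.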

In [None]:
model_2l.cfg.use_attn_result = True


clean_prompt = "Her name was Alex Hart. Tomorrow at lunch time Alex"
clean_tokens = model_2l.to_str_tokens(clean_prompt)

# Indices of the right and wrong answers (last names) to judge what the model predicts
clean_answer_index = model_2l.tokenizer.encode(" Hart")[0]
corrupt_answer_index = model_2l.tokenizer.encode(" Carroll")[0]

# Baseline: logit difference on the clean prompt with no patching
clean_prediction_logits = model_2l(clean_prompt)[0, -1]
clean_prediction_logit_diff = clean_prediction_logits[clean_answer_index] - clean_prediction_logits[corrupt_answer_index]


n_layers = model_2l.cfg.n_layers
n_heads = model_2l.cfg.n_heads
n_pos = len(clean_tokens)

def patch_head_result(activations, hook, layer=None, head=None, pos=None):
   activations[:, pos, head, :] = corrupt_cache[hook.name][:, pos, head, :]
   return activations


for corrupt_prompt in ["Her name was Alex Carroll. Tomorrow at lunch time Alex",
                       "Her name was Sarah Hart. Tomorrow at lunch time Alex",
                       "Her name was Alex Hart. Tomorrow at lunch time Sarah"]:
    corrupt_tokens = model_2l.to_str_tokens(corrupt_prompt)
    _, corrupt_cache = model_2l.run_with_cache(corrupt_prompt)

    patching_effect = torch.zeros(n_layers*n_heads, n_pos)
    for layer in range(n_layers):
        for head in range(n_heads):
          for pos in range(n_pos):
              fwd_hooks = [(f"blocks.{layer}.attn.hook_result", partial(patch_head_result, layer=layer, head=head, pos=pos))]
              prediction_logits = model_2l.run_with_hooks(clean_prompt, fwd_hooks=fwd_hooks)[0, -1]
              patching_effect[n_heads*layer+head, pos] = prediction_logits[clean_answer_index] - prediction_logits[corrupt_answer_index]
    # Plot
    token_labels = [f"(pos {i:2}) {t}" for i, t in enumerate(clean_tokens)]
    layerhead_labels = [f"{l}.{h}" for l in range(n_layers) for h in range(n_heads)]
    imshow(patching_effect-clean_prediction_logit_diff, xticks=token_labels, yticks=layerhead_labels, xlabel="position", ylabel="layer.head",
               zlabel="Logit difference difference", title=f"Patching with{corrupt_tokens[4]} /{corrupt_tokens[5]} /{corrupt_tokens[11]}", width=500, height=500)
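With `c_midpoint=0.0`, the raw heatmap puts its white point at a logit difference of zero (the model indifferent between " Hart" and " Carroll"), so every cell the patch did not affect shows the same strongly colored baseline. Subtracting `clean_prediction_logit_diff`, as the cell above does, moves "no effect" to zero; negative cells are patches that pushed the prediction toward " Carroll". A pure-Python sketch of the shift on made-up numbers (the helper name is hypothetical):

```python
from typing import List

Grid = List[List[float]]


def relative_to_clean(patching_effect: Grid, clean_logit_diff: float) -> Grid:
    """Shift every cell so 0 means 'the patch left the logit difference unchanged'."""
    return [[value - clean_logit_diff for value in row] for row in patching_effect]


# Made-up numbers: the unpatched model prefers " Hart" by 4 logits
toy_clean_diff = 4.0
toy_grid = [[4.0, 4.0, 1.0],
            [4.0, -2.0, 4.0]]
for row in relative_to_clean(toy_grid, toy_clean_diff):
    print(row)
```

The shift does not change which cells matter or their ordering; it only moves the color scale's midpoint to the "unpatched" value.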


### Method 3: CircuitsVis attention plots

In [None]:
%pip install circuitsvis




In [None]:
import circuitsvis as cv


In [None]:
prompt = "Her name was Alex Hart. Tomorrow at lunch time Alex"
_, cache = model_2l.run_with_cache(prompt)
cv.attention.attention_patterns(tokens=model_2l.to_str_tokens(prompt),
                                attention=cache[f'blocks.1.attn.hook_pattern'][0])

In [None]:
import circuitsvis as cv
layer = 1
cv.attention.attention_patterns(tokens=model_2l.to_str_tokens("Her name was Alex Hart. Tomorrow at lunch time Alex"),
                                attention=cache[f'blocks.{layer}.attn.hook_pattern'][0])
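`cv.attention.attention_patterns` displays `hook_pattern`, which the cache printout near the top shows with shape `[batch, head, query_pos, key_pos]` (e.g. `[1, 12, 6, 6]`); indexing `[0]` drops the batch dimension and leaves one `[query_pos, key_pos]` grid per head. Each grid is the attention scores after causal masking and a softmax over key positions, so every row sums to 1 and no token attends to later positions. A pure-Python sketch of that step for a single head, on made-up scores (the function name is hypothetical):

```python
import math
from typing import List


def causal_attention_pattern(scores: List[List[float]]) -> List[List[float]]:
    """Turn one head's [query_pos][key_pos] scores into an attention pattern.

    Keys after the query are masked out, and a softmax over the remaining keys
    makes each row a probability distribution.
    """
    n = len(scores)
    pattern = []
    for q in range(n):
        allowed = scores[q][: q + 1]  # keys 0..q only (causal mask)
        m = max(allowed)
        exps = [math.exp(s - m) for s in allowed]
        total = sum(exps)
        pattern.append([e / total for e in exps] + [0.0] * (n - q - 1))
    return pattern


# Toy scores for a 3-token prompt; the entries above the diagonal are ignored
toy_scores = [[0.0, 9.0, 9.0],
              [0.0, 0.0, 9.0],
              [2.0, 0.0, 0.0]]
for row in causal_attention_pattern(toy_scores):
    print([round(p, 3) for p in row])
```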