# Setup

In [None]:
# Detect if we're running in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
except:
    IN_COLAB = False

# Install if in Colab
if IN_COLAB:
    %pip install transformer_lens
    %pip install circuitsvis
    # Install a faster Node version
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

# Hot reload in development mode & not running on the CD
if not IN_COLAB:
    from IPython import get_ipython
    ip = get_ipython()
    if not ip.extension_manager.loaded:
        ip.extension_manager.load('autoreload')
        %autoreload 2

Running as a Colab notebook
Collecting transformer_lens
  Downloading transformer_lens-1.9.1-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.4/116.4 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate>=0.23.0 (from transformer_lens)
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m739.7/739.7 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets>=2.7.1 (from transformer_lens)
  Downloading datasets-2.14.6-py3-none-any.whl (493 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m493.7/493.7 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops>=0.6.0 (from t

In [None]:
from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer

In [None]:
# Turn autograd off globally; the analysis cells visible in this notebook
# only run forward passes.
torch.set_grad_enabled(mode=False)
print("Disabled automatic differentiation")

Disabled automatic differentiation


In [None]:
def imshow(tensor, **kwargs):
    """Show `tensor` as a Plotly heatmap on a red/blue scale centred at zero.

    The tensor is converted with `utils.to_numpy` first; any extra keyword
    arguments are forwarded unchanged to `px.imshow`. The figure is displayed
    immediately and nothing is returned.
    """
    heatmap = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    )
    heatmap.show()


def line(tensor, **kwargs):
    """Show `tensor` as the y-values of a Plotly line chart.

    The tensor is converted with `utils.to_numpy`; extra keyword arguments
    (e.g. `x`, `title`, `hover_name`) are forwarded unchanged to `px.line`.
    The figure is displayed immediately and nothing is returned.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    """Show a Plotly scatter plot of `y` against `x`.

    Both inputs are converted with `utils.to_numpy`. `xaxis`, `yaxis` and
    `caxis` become the display labels for the x, y and colour dimensions;
    remaining keyword arguments are forwarded unchanged to `px.scatter`.
    The figure is displayed immediately and nothing is returned.
    """
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(
        y=utils.to_numpy(y),
        x=utils.to_numpy(x),
        labels=axis_labels,
        **kwargs,
    )
    figure.show()

In [None]:
# NBVAL_IGNORE_OUTPUT
# Load GPT-2 small as a HookedTransformer with all four weight-processing
# options switched on.
# NOTE(review): the attribution cells later project the residual stream onto
# unembedding directions; these options presumably matter for that being
# meaningful — confirm against the TransformerLens documentation.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

# Get the default device used
# (later cells move token-id tensors and result buffers onto this device)
device: torch.device = utils.get_device()

Downloading (…)lve/main/config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


# Brainstorm

What project should I do? Probably one from Neel Nanda's list with difficulty A. I found the ROME paper interesting, so I could try an extension of that. Or a simple feature-locating exercise in a model.

I think I will try an easy feature finding problem in GPT-2 first. Possible behaviours to look at:
 - end of sentence
 - year
 - number
 - etc?

Or maybe could try extensions to the IOI stuff. Ablate earlier layers or dig deeper into QK circuits?

I will try to explore how GPT-2 handles sequences of words that are common like the months, days of week, seasons etc. I think it will be interesting to see particularly how it deals with the end of the sequences like December to January.


Let's try the days of the week first. Of course to create a sequence we need at least two instances of days in the prompt. Using more than two would only make it easier, so let's stay with two.

It works well! Even though it's the end of the week, it still knows that the next day should be Monday.

Now let's look at the months:

This works well as well! Now let's see the end of the year:

In [None]:
# Month sequence across the year boundary: after December we expect January.
example_prompt, example_answer = "November, December and", " January"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'November', ',', ' December', ' and']
Tokenized answer: [' January']


Top 0th token. Logit: 18.05 Prob: 68.49% Token: | January|
Top 1th token. Logit: 16.30 Prob: 11.90% Token: | March|
Top 2th token. Logit: 15.39 Prob:  4.76% Token: | February|
Top 3th token. Logit: 14.81 Prob:  2.68% Token: | April|
Top 4th token. Logit: 14.58 Prob:  2.12% Token: | December|
Top 5th token. Logit: 13.29 Prob:  0.59% Token: | May|
Top 6th token. Logit: 13.13 Prob:  0.50% Token: | September|
Top 7th token. Logit: 13.11 Prob:  0.49% Token: | November|
Top 8th token. Logit: 13.07 Prob:  0.47% Token: | Spring|
Top 9th token. Logit: 13.00 Prob:  0.44% Token: | the|


So it still knows this, but with less confidence, as the probability for January is lower here than it was for March in the above example.

I really want to break this, so let's try another example with the days of the week in the middle of the week.

In [None]:
# Two consecutive days from the middle of the week.
example_prompt, example_answer = "Tuesday, Wednesday and", " Thursday"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Tuesday', ',', ' Wednesday', ' and']
Tokenized answer: [' Thursday']


Top 0th token. Logit: 19.10 Prob: 46.37% Token: | Friday|
Top 1th token. Logit: 19.05 Prob: 43.90% Token: | Thursday|
Top 2th token. Logit: 16.89 Prob:  5.10% Token: | Saturday|
Top 3th token. Logit: 15.12 Prob:  0.86% Token: | Sunday|
Top 4th token. Logit: 14.62 Prob:  0.53% Token: | Monday|
Top 5th token. Logit: 14.56 Prob:  0.50% Token: | Fridays|
Top 6th token. Logit: 13.63 Prob:  0.20% Token: | Wednesday|
Top 7th token. Logit: 13.35 Prob:  0.15% Token: | the|
Top 8th token. Logit: 13.26 Prob:  0.14% Token: | Thurs|
Top 9th token. Logit: 13.20 Prob:  0.13% Token: |Thursday|


Yay, it broke! Interestingly the logits are very close to each other for Thursday and Friday.

Out of curiosity, let's see whether making it the beginning of the week fixes it:

In [None]:
# Same two-day pattern, but starting at the beginning of the week.
example_prompt, example_answer = "Monday, Tuesday and", " Wednesday"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Monday', ',', ' Tuesday', ' and']
Tokenized answer: [' Wednesday']


Top 0th token. Logit: 19.61 Prob: 53.72% Token: | Wednesday|
Top 1th token. Logit: 19.37 Prob: 42.19% Token: | Thursday|
Top 2th token. Logit: 15.91 Prob:  1.32% Token: | Saturday|
Top 3th token. Logit: 14.81 Prob:  0.44% Token: | Friday|
Top 4th token. Logit: 14.29 Prob:  0.26% Token: | Wed|
Top 5th token. Logit: 14.26 Prob:  0.26% Token: | Tuesday|
Top 6th token. Logit: 14.19 Prob:  0.24% Token: | Thurs|
Top 7th token. Logit: 13.46 Prob:  0.11% Token: | Monday|
Top 8th token. Logit: 13.25 Prob:  0.09% Token: | even|
Top 9th token. Logit: 13.02 Prob:  0.07% Token: | March|


Yes, it works this way! Wednesday and Thursday have very similar logits and probabilities, but Wednesday's are a bit larger.

One last example; does adding more days to the broken example make it better?

In [None]:
# The mid-week example again, now with three days of context instead of two.
example_prompt, example_answer = "Monday, Tuesday, Wednesday and", " Thursday"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Monday', ',', ' Tuesday', ',', ' Wednesday', ' and']
Tokenized answer: [' Thursday']


Top 0th token. Logit: 19.54 Prob: 68.04% Token: | Thursday|
Top 1th token. Logit: 18.60 Prob: 26.51% Token: | Friday|
Top 2th token. Logit: 16.03 Prob:  2.03% Token: | Saturday|
Top 3th token. Logit: 14.45 Prob:  0.42% Token: | Fridays|
Top 4th token. Logit: 14.17 Prob:  0.32% Token: | Thurs|
Top 5th token. Logit: 14.09 Prob:  0.29% Token: | Sunday|
Top 6th token. Logit: 14.05 Prob:  0.28% Token: |Thursday|
Top 7th token. Logit: 14.03 Prob:  0.28% Token: | Wednesday|
Top 8th token. Logit: 13.52 Prob:  0.16% Token: | Monday|
Top 9th token. Logit: 13.06 Prob:  0.10% Token: | the|


Yes! In this case Thursday has a much higher probability than Friday.

Let's modify this type of sentence to a slightly longer one that is more similar to an IOI type of structure:

In [None]:
# Sentence-style prompt: the day is stated once and "tomorrow" asks for the next one.
example_prompt, example_answer = "Today it is Tuesday, tomorrow it is", " Wednesday"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Today', ' it', ' is', ' Tuesday', ',', ' tomorrow', ' it', ' is']
Tokenized answer: [' Wednesday']


Top 0th token. Logit: 17.34 Prob: 36.70% Token: | Wednesday|
Top 1th token. Logit: 16.97 Prob: 25.18% Token: | Thursday|
Top 2th token. Logit: 15.78 Prob:  7.66% Token: | Monday|
Top 3th token. Logit: 15.59 Prob:  6.33% Token: | Tuesday|
Top 4th token. Logit: 15.53 Prob:  6.00% Token: | Saturday|
Top 5th token. Logit: 14.86 Prob:  3.06% Token: | Friday|
Top 6th token. Logit: 13.77 Prob:  1.03% Token: | the|
Top 7th token. Logit: 13.59 Prob:  0.86% Token: | March|
Top 8th token. Logit: 13.51 Prob:  0.79% Token: | Sunday|
Top 9th token. Logit: 13.05 Prob:  0.50% Token: | January|


Let's see if it works for months as well:

In [None]:
# Sentence-style prompt for months instead of days.
example_prompt, example_answer = "Now it is August, the next month is", " September"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Now', ' it', ' is', ' August', ',', ' the', ' next', ' month', ' is']
Tokenized answer: [' September']


Top 0th token. Logit: 13.83 Prob: 14.69% Token: | the|
Top 1th token. Logit: 12.63 Prob:  4.45% Token: | September|
Top 2th token. Logit: 12.31 Prob:  3.22% Token: | a|
Top 3th token. Logit: 12.27 Prob:  3.10% Token: | when|
Top 4th token. Logit: 12.08 Prob:  2.56% Token: | World|
Top 5th token. Logit: 12.08 Prob:  2.55% Token: | Halloween|
Top 6th token. Logit: 11.88 Prob:  2.09% Token: | National|
Top 7th token. Logit: 11.57 Prob:  1.53% Token: | going|
Top 8th token. Logit: 11.37 Prob:  1.26% Token: | called|
Top 9th token. Logit: 11.25 Prob:  1.12% Token: | Christmas|


Interestingly, it doesn't work for months! Might be because with the days I use "today" and "tomorrow" which are more specific.

Let's focus on the days of the week and try some more examples.



In [None]:
# Variant joined with "so" instead of a comma, at the end of the week.
example_prompt, example_answer = "Today it is Sunday so tomorrow it is", " Monday"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Today', ' it', ' is', ' Sunday', ' so', ' tomorrow', ' it', ' is']
Tokenized answer: [' Monday']


Top 0th token. Logit: 14.57 Prob: 14.74% Token: | Monday|
Top 1th token. Logit: 14.27 Prob: 10.94% Token: | Sunday|
Top 2th token. Logit: 14.13 Prob:  9.57% Token: | Saturday|
Top 3th token. Logit: 13.48 Prob:  4.99% Token: | Tuesday|
Top 4th token. Logit: 13.26 Prob:  4.01% Token: | a|
Top 5th token. Logit: 13.24 Prob:  3.91% Token: | Thursday|
Top 6th token. Logit: 13.22 Prob:  3.82% Token: | Friday|
Top 7th token. Logit: 13.12 Prob:  3.49% Token: | Wednesday|
Top 8th token. Logit: 13.00 Prob:  3.09% Token: | the|
Top 9th token. Logit: 12.24 Prob:  1.44% Token: | going|


In [None]:
# Same "so" variant, at the start of the week.
example_prompt, example_answer = "Today it is Monday so tomorrow it is", " Tuesday"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Today', ' it', ' is', ' Monday', ' so', ' tomorrow', ' it', ' is']
Tokenized answer: [' Tuesday']


Top 0th token. Logit: 16.04 Prob: 24.11% Token: | Tuesday|
Top 1th token. Logit: 15.58 Prob: 15.16% Token: | Wednesday|
Top 2th token. Logit: 15.25 Prob: 10.91% Token: | Monday|
Top 3th token. Logit: 15.04 Prob:  8.85% Token: | Friday|
Top 4th token. Logit: 14.86 Prob:  7.36% Token: | Thursday|
Top 5th token. Logit: 14.64 Prob:  5.93% Token: | Saturday|
Top 6th token. Logit: 14.54 Prob:  5.38% Token: | Sunday|
Top 7th token. Logit: 13.24 Prob:  1.46% Token: | a|
Top 8th token. Logit: 13.16 Prob:  1.35% Token: | the|
Top 9th token. Logit: 12.29 Prob:  0.56% Token: | on|



I want to try examples where the sentence involves "yesterday", as I think this should work fairly similarly, since it should still pay attention to the actual day today, and then it should process whether or not it's the previous or following day that should be predicted.

In [None]:
# Reverse direction: "yesterday" should make the model predict the previous day.
example_prompt, example_answer = "Today it is Wednesday, yesterday it was", " Tuesday"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Today', ' it', ' is', ' Wednesday', ',', ' yesterday', ' it', ' was']
Tokenized answer: [' Tuesday']


Top 0th token. Logit: 14.95 Prob: 12.60% Token: | Thursday|
Top 1th token. Logit: 14.80 Prob: 10.88% Token: | Wednesday|
Top 2th token. Logit: 14.65 Prob:  9.39% Token: | Friday|
Top 3th token. Logit: 14.49 Prob:  8.02% Token: | yesterday|
Top 4th token. Logit: 14.42 Prob:  7.46% Token: | Tuesday|
Top 5th token. Logit: 13.97 Prob:  4.74% Token: | Monday|
Top 6th token. Logit: 13.64 Prob:  3.43% Token: | a|
Top 7th token. Logit: 13.55 Prob:  3.13% Token: | Saturday|
Top 8th token. Logit: 13.09 Prob:  1.97% Token: | the|
Top 9th token. Logit: 12.81 Prob:  1.49% Token: | last|


So it doesn't work for "yesterday"! I also tried changing the day in the input prompt and the highest probability token is consistently the following day (the "tomorrow").

### Brainstorm conclusions

It is a bit counter intuitive for me that in case of the days of the week GPT-2 predicts the correct day when it's the weekend, but not when it is the middle of the week. In case of the months it works well in both cases, and it works as expected (by me); it is more confident in the middle months than at the end of the year.

I think I will look at the IOI style sentence, because I think it would be cool to see similarities and differences between the two. It's also interesting that it works better for days of the week.

# Final
I decided to look at a similar structure to the IOI task, but with sequences of common words. For example, the sentence: "Today it is Sunday, tomorrow it is BLANK". I think this is interesting because here both attention and MLP layers should be important, as it's not just about moving information (like the fact that today is Sunday) but also processing it (to determine the next day).

Also, in case of the IOI task, all information is already in the prompt whereas here the correct next token is not present there.

So I suspect there should be some similarities in the attention patterns, but there should be plenty of differences as well!

# Investigation

What I want to do is investigate sentences like I mentioned and look at their logit attributions, attentions and do some activation patching to find out which heads and layers are important.



### Logit attribution

Let's look at what the direct logit attribution looks like.

In the exploratory analysis notebook from Neel Nanda, the difference between the logits of two names is examined for the IOI behaviour. But of course in my case there is only one answer prompt. Looking at the example sentences that I played around with, the day in the input prompt (the "today") is usually second or third in logit values, so I think I will use that as a comparison. So I will calculate the difference between the logits of the right day (the actual tomorrow) and the repeated day (the today). Otherwise I could also just compare the correct answer to the token with the highest probability that is not the correct answer. But this seems slightly less "mechanistic", so I will stick with the repeated day (today).

In [None]:
# Clean prompts: each one states today's day and asks for tomorrow's.
prompt_format = [
    "Today it is Sunday, tomorrow it is",
    "Today it is Friday, tomorrow it is",
    "Today it is Thursday, tomorrow it is",
    "Today it is Monday, tomorrow it is",
]
# Correct completion (the day after the one named in the prompt), same order.
days = [
    " Monday",
    " Saturday",
    " Friday",
    " Tuesday",
]
# The day named in each prompt, used as the "wrong" comparison answer.
todays = [
    " Sunday",
    " Friday",
    " Thursday",
    " Monday",
]
# Every prompt needs exactly one correct and one wrong answer.
assert len(prompt_format) == len(days) == len(todays)

# List of prompts. The strings contain no placeholders, so they are used
# verbatim (calling str.format on them was a no-op).
prompts = list(prompt_format)
# List of answers
answers = days
# List of repeated days (wrong answers)
wrongs = todays

# Token ids as [n_prompts, 1] tensors: one single-token answer per prompt.
# to_single_token is expected to fail if a string is not exactly one token.
answer_tokens = torch.tensor(
    [[model.to_single_token(answer)] for answer in answers]
).to(device)
wrong_tokens = torch.tensor(
    [[model.to_single_token(wrong)] for wrong in wrongs]
).to(device)

print(prompts)
print(answers)
print(wrongs)

['Today it is Sunday, tomorrow it is', 'Today it is Friday, tomorrow it is', 'Today it is Thursday, tomorrow it is', 'Today it is Monday, tomorrow it is']
[' Monday', ' Saturday', ' Friday', ' Tuesday']
[' Sunday', ' Friday', ' Thursday', ' Monday']


In [None]:
# Tokenise the list of clean prompts together, with a BOS token prepended.
tokens = model.to_tokens(prompts, prepend_bos=True)

# Run the model and cache all activations
# (`cache` is the clean-run cache that the attribution and patching cells read from)
original_logits, cache = model.run_with_cache(tokens)

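Now, the difference here compared to the IOI analysis is that we only have one answer prompt rather than two, so we can't measure the difference between two logits. I think it should be okay to simply measure the logit attribution to the answer token itself. This might be a bad decision, but it's the easiest, so I'll stick with it. Otherwise I could try to measure the difference between the answer logit and the logit which has the largest value that is not the answer token. Note: the code below actually follows the approach described at the start of the logit attribution section and measures the difference between the logit of the correct answer and the logit of the repeated day (the "today"), rather than the answer logit alone.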

In [None]:
def logits_to_ave_logit_diff(logits, answer_tokens, wrong_tokens, per_prompt=False):
    """Logit of the correct token minus logit of the wrong token, at the last position.

    Args:
        logits: tensor indexed as [batch, position, vocab].
        answer_tokens: integer tensor of shape [batch, k]; only column 0 is used.
        wrong_tokens: integer tensor of shape [batch, k]; only column 0 is used.
        per_prompt: if True, return one difference per batch item; otherwise
            return their mean.

    Returns:
        A [batch] tensor when `per_prompt` is True, else a scalar tensor.
    """
    # Only the prediction made at the final position matters for the answer.
    last_position_logits = logits[:, -1, :]
    correct = torch.gather(last_position_logits, -1, answer_tokens)[:, 0]
    incorrect = torch.gather(last_position_logits, -1, wrong_tokens)[:, 0]
    diff_per_prompt = correct - incorrect
    return diff_per_prompt if per_prompt else diff_per_prompt.mean()


# Per-prompt differences on the clean run, rounded for display only.
per_prompt_logit_diff = logits_to_ave_logit_diff(
    original_logits, answer_tokens, wrong_tokens, per_prompt=True
)
print("Per prompt logit diff:", per_prompt_logit_diff.detach().cpu().round(decimals=3))

# Batch average; later cells use this as the clean baseline.
original_average_logit_diff = logits_to_ave_logit_diff(
    original_logits, answer_tokens, wrong_tokens
)
print("Average logit diff:", round(original_average_logit_diff.item(), 3))

Per prompt logit diff: tensor([1.3700, 2.0530, 1.3010, 0.6650])
Average logit diff: 1.347


So on average the right answer is $e^{1.35} \approx 3.9$ times more likely than the actual "today", i.e. roughly 4x.

In [None]:
# Residual-stream direction for each correct and each wrong token.
# NOTE(review): presumably these are the unembedding vectors of the tokens — confirm.
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
wrong_residual_directions = model.tokens_to_residual_directions(wrong_tokens)

print("Answer residual directions shape:", answer_residual_directions.shape)
# Drop the singleton token axis and subtract: one "correct minus wrong"
# direction per prompt, used by the attribution cells.
logit_diff_directions = (
    answer_residual_directions[:, 0] - wrong_residual_directions[:, 0]
)
print("Logit value directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([4, 1, 768])
Logit value directions shape: torch.Size([4, 768])


In [None]:
# Cache lookup: "resid_post" is the residual stream at the end of a layer and
# -1 selects the final layer. The general key form is
# [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)

# Keep only the last position of every prompt, then apply the LayerNorm
# scaling (pos_slice=-1 says the stack already refers to the final position).
final_token_residual_stream = final_residual_stream[:, -1, :]
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

# Project each prompt's residual onto its own logit-diff direction, sum over
# the batch and divide by the number of prompts to get the average.
summed_projection = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
)
average_logit_diff = summed_projection / len(prompts)
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

Final residual stream shape: torch.Size([4, 9, 768])
Calculated average logit diff: 1.39
Original logit difference: 1.347


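These two numbers should agree, but I'm not quite sure why there's a small difference between them. Is it because LayerNorm is not exactly linear? Not sure. One possible explanation (to be confirmed): with `fold_ln=True` the final LayerNorm bias is folded into the unembedding bias `b_U`, which contributes to the model's logits but is not captured when projecting the residual stream onto the unembedding directions. Because the answer days and the wrong days here are not the same set of tokens, that bias term would not cancel out in the average.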

## Logit lens

In [None]:
def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
) -> Float[torch.Tensor, "components"]:
    """Average logit difference attributable to each component of a residual stack.

    Each component is scaled with the cache's final LayerNorm scaling at the
    last position, projected per prompt onto the notebook-level
    `logit_diff_directions`, summed over the batch and divided by
    `len(prompts)`.

    Args:
        residual_stack: stacked residual-stream contributions taken at the
            final position.
        cache: activation cache of the run the stack came from.

    Returns:
        A tensor with one value per leading (component) entry of the stack —
        not a Python float.
    """
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    # "..." keeps whatever leading component axes the stack has.
    return einsum(
        "... batch d_model, batch d_model -> ...",
        scaled_residual_stack,
        logit_diff_directions,
    ) / len(prompts)

In [None]:
# Logit lens: take the accumulated residual stream at the final position and
# convert each entry into an average logit difference.
# NOTE(review): incl_mid=True presumably adds the mid-layer (post-attention)
# residual for every layer, which would match the 2 * n_layers + 1 half-step
# x-axis below — confirm.
accumulated_residual, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
line(
    logit_lens_logit_diffs,
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    title="Logit Difference From Accumulate Residual Stream",
)

## Layer attribution

In [None]:
# Decompose the final-position residual stream into labelled components and
# plot each component's contribution to the average logit difference.
per_layer_residual, labels = cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

## Head attribution

In [None]:
# One residual contribution per attention head at the final position.
per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)

# The stack comes back flat over (layer, head); fold it into a
# [layer, head] grid so it can be drawn as a heatmap.
flat_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(
    flat_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)

imshow(
    per_head_logit_diffs,
    title="Logit Difference From Each Head",
    labels={"x": "Head", "y": "Layer"},
)

Tried to stack head results when they weren't cached. Computing head results now


# Attention

In [None]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
    batch_index: int = 0,
) -> str:
    """Render the attention patterns of selected heads as an HTML snippet.

    Args:
        heads: flat head id(s), where id = layer * n_heads + head_index. May be
            a single int, a list of ints, or a tensor of ids (e.g. the
            `.indices` of `torch.topk`).
        local_cache: activation cache holding the "attn" patterns to show.
        local_tokens: tokens of the prompt being visualised (used for the
            axis labels).
        title: heading placed above the plot.
        max_width: maximum width of the wrapping <div>, in pixels.
        batch_index: which batch item of the cache to visualise. Defaults to
            0, the previously hard-coded value.

    Returns:
        HTML source containing the title and the circuitsvis plot.

    Raises:
        ValueError: if no heads are given.
    """
    # Normalise every accepted input form to a list of plain Python ints so
    # that labels and cache lookups never receive tensor-valued indices.
    if isinstance(heads, torch.Tensor):
        heads = heads.flatten().tolist()
    elif isinstance(heads, int):
        heads = [heads]
    head_ids = [int(head) for head in heads]
    if not head_ids:
        raise ValueError("visualize_attention_patterns needs at least one head")

    # Create the plotting data
    labels: List[str] = []
    head_patterns: List[Float[torch.Tensor, "dest_pos src_pos"]] = []
    for head in head_ids:
        # Split the flat id back into (layer, head within layer)
        layer, head_index = divmod(head, model.cfg.n_heads)
        labels.append(f"L{layer}H{head_index}")

        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        head_patterns.append(local_cache["attn", layer][batch_index, head_index])

    # Convert the tokens to strings (for the axis labels)
    str_tokens = model.to_str_tokens(local_tokens)

    # Combine the patterns into a single [head, dest_pos, src_pos] tensor
    stacked_patterns: Float[torch.Tensor, "head_index dest_pos src_pos"] = torch.stack(
        head_patterns, dim=0
    )

    # Circuitsvis Plot (note we get the code version so we can concatenate with the title)
    plot = attention_heads(
        attention=stacked_patterns, tokens=str_tokens, attention_head_names=labels
    ).show_code()

    # Display the title
    title_html = f"<h2>{title}</h2><br/>"

    # Return the visualisation as raw code
    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"

In [None]:
top_k = 3

# Flat (layer * n_heads + head) view of the per-head attribution grid.
flat_head_logit_diffs = per_head_logit_diffs.flatten()

# Heads with the largest, and with the most negative, attribution.
top_positive_logit_attr_heads = torch.topk(flat_head_logit_diffs, k=top_k).indices
top_negative_logit_attr_heads = torch.topk(-flat_head_logit_diffs, k=top_k).indices

# Attention patterns are drawn for the first prompt of the batch only.
positive_html = visualize_attention_patterns(
    top_positive_logit_attr_heads,
    cache,
    tokens[0],
    f"Top {top_k} Positive Logit Attribution Heads",
)
negative_html = visualize_attention_patterns(
    top_negative_logit_attr_heads,
    cache,
    tokens[0],
    title=f"Top {top_k} Negative Logit Attribution Heads",
)

HTML(positive_html + negative_html)

# Activation patching

Now let's try activation patching!

First, I need to create corrupted prompts. The way I see it, I have two fairly obvious ways to do this (assuming that the answer prompts remain the same as before):
 - by entering the same day as in the answer prompt to the input prompt (so if answer prompt day = n, then corrupted input prompt day = n)
 - by creating a corrupted input prompt so that the relationship is not "following day" but rather "previous day" to the answer prompt (so if answer prompt day = n, then corrupted input prompt day = n+1)


 It is important to think about this having in mind that I calculate the logit difference compared to the same day as in the input prompt (aka wrong answer):

  In the first case, if I set the wrong answers this way as well, then they are exactly the same as the answer prompts, so the logit difference will always be 0. That shouldn't be the case. I could, however, just leave the wrong answers to be the same as before. But then I can't reliably calculate the logit difference.

  In the second case, I can set the wrong answer to be the same as in the corrupted input prompt without having this issue. So let's do that!



In [None]:
# Corrupted prompts: same template, but the stated day is the day AFTER the
# clean answer, so the clean answer is now "yesterday" rather than "tomorrow".
corrupted_prompts = [
    "Today it is Tuesday, tomorrow it is",
    "Today it is Thursday, tomorrow it is",
    "Today it is Saturday, tomorrow it is",
    "Today it is Wednesday, tomorrow it is",
]

# The day named in each corrupted prompt, in the same order.
new_todays = [
    " Tuesday",
    " Thursday",
    " Saturday",
    " Wednesday",
]
# Clean and corrupted batches are compared item by item, so they must have
# one entry per clean prompt.
assert len(corrupted_prompts) == len(new_todays) == len(prompt_format)

# List of repeated days (wrong answers)
new_wrongs = new_todays
# Token ids as an [n_prompts, 1] tensor, built from the wrong answers
# themselves rather than from the length of an unrelated list.
new_wrong_tokens = torch.tensor(
    [[model.to_single_token(wrong)] for wrong in new_wrongs]
).to(device)

print(corrupted_prompts)
print(answers)

['Today it is Tuesday, tomorrow it is', 'Today it is Thursday, tomorrow it is', 'Today it is Saturday, tomorrow it is', 'Today it is Wednesday, tomorrow it is']
[' Monday', ' Saturday', ' Friday', ' Tuesday']


Perfect! So now, the answer to the prompt is the "yesterday" rather than the desired "tomorrow".

In [None]:
# Run the corrupted prompts and cache their activations.
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = model.run_with_cache(
    corrupted_tokens, return_type="logits"
)

# Same metric as for the clean run, but the "wrong" token is the day named
# in the corrupted prompt.
corrupted_average_logit_diff = logits_to_ave_logit_diff(
    corrupted_logits, answer_tokens, new_wrong_tokens
)
corrupted_per_prompt_logit_diff = logits_to_ave_logit_diff(
    corrupted_logits, answer_tokens, new_wrong_tokens, per_prompt=True
)

print(
    "Per prompt logit value:",
    corrupted_per_prompt_logit_diff.detach().cpu().round(decimals=3),
)
print("Corrupted Average Logit Diff", round(corrupted_average_logit_diff.item(), 2))
print("Clean Average Logit Diff", round(original_average_logit_diff.item(), 2))

Per prompt logit value: tensor([ 0.1910,  0.8020, -0.9770, -1.0550])
Corrupted Average Logit Diff -0.26
Clean Average Logit Diff 1.35


I think this is good! We want the corrupted average logit diff to be negative. On the other hand, looking at the individual corrupted logit differences, we see that not all of them are negative. This is explained by the fact that for some of the prompts the "today" is given a higher probability than the "yesterday" and for some it's the other way around. Not sure if this is bad, but I can't think of a way to resolve this at the moment as it would require calculating the logit differences differently.

## Residual stream

In [None]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache,
):
    """Forward hook: overwrite one position of an activation with its clean value.

    The clean activation is looked up in `clean_cache` under `hook.name`, and
    its slice at position `pos` is written into the corrupted activation in
    place. The (modified) corrupted activation is returned.
    """
    clean_activation = clean_cache[hook.name]
    corrupted_residual_component[:, pos, :] = clean_activation[:, pos, :]
    return corrupted_residual_component


def normalize_patched_logit_diff(patched_logit_diff):
    """Rescale a patched logit diff relative to the corrupted and clean baselines.

    Uses the notebook-level `corrupted_average_logit_diff` and
    `original_average_logit_diff`. A result of 0 means no change from the
    corrupted run, 1 means clean performance was fully recovered, negative
    means the patch made things worse, and above 1 means it did better than
    the clean run.
    """
    improvement = patched_logit_diff - corrupted_average_logit_diff
    clean_minus_corrupted = original_average_logit_diff - corrupted_average_logit_diff
    return improvement / clean_minus_corrupted


# Sweep over every (layer, position): patch the clean residual stream at the
# start of that layer into the corrupted run and record the normalised metric.
n_positions = tokens.shape[1]
patched_residual_stream_diff = torch.zeros(
    model.cfg.n_layers, n_positions, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    resid_pre_name = utils.get_act_name("resid_pre", layer)
    for position in range(n_positions):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(resid_pre_name, hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(
            patched_logits, answer_tokens, new_wrong_tokens
        )
        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [None]:
# Position labels come from the first clean prompt only, so the day shown on
# the axis is that prompt's day even though the metric averages all prompts.
first_prompt_str_tokens = model.to_str_tokens(tokens[0])
prompt_position_labels = [
    f"{tok}_{i}" for i, tok in enumerate(first_prompt_str_tokens)
]

imshow(
    patched_residual_stream_diff,
    title="Logit Difference From Patched Residual Stream",
    labels={"x": "Position", "y": "Layer"},
    x=prompt_position_labels,
)

## Layers

In [None]:
# Sweep over every (layer, position), patching the clean attention-layer
# output and the clean MLP output separately into the corrupted run.
n_positions = tokens.shape[1]
patched_attn_diff = torch.zeros(
    model.cfg.n_layers, n_positions, device=device, dtype=torch.float32
)
patched_mlp_diff = torch.zeros(
    model.cfg.n_layers, n_positions, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    attn_out_name = utils.get_act_name("attn_out", layer)
    mlp_out_name = utils.get_act_name("mlp_out", layer)
    for position in range(n_positions):
        # The same position-patching hook serves both activations.
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)

        # Attention-layer output
        patched_attn_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(attn_out_name, hook_fn)],
            return_type="logits",
        )
        patched_attn_logit_diff = logits_to_ave_logit_diff(
            patched_attn_logits, answer_tokens, new_wrong_tokens
        )
        patched_attn_diff[layer, position] = normalize_patched_logit_diff(
            patched_attn_logit_diff
        )

        # MLP output
        patched_mlp_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(mlp_out_name, hook_fn)],
            return_type="logits",
        )
        patched_mlp_logit_diff = logits_to_ave_logit_diff(
            patched_mlp_logits, answer_tokens, new_wrong_tokens
        )
        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(
            patched_mlp_logit_diff
        )

In [None]:
# Heatmap of the normalised metric after patching each attention-layer
# output, by layer (rows) and prompt position (columns).
imshow(
    patched_attn_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Attention Layer",
    labels={"x": "Position", "y": "Layer"},
)

In [None]:
# Heatmap of the normalised metric after patching each MLP output,
# by layer (rows) and prompt position (columns).
imshow(
    patched_mlp_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched MLP Layer",
    labels={"x": "Position", "y": "Layer"},
)

## Heads

In [None]:
def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache,
):
    """Overwrite one head's slice of a hooked activation with its cached value.

    Args:
        corrupted_head_vector: Activation passed to the hook; modified in place.
        hook: Hook point object; only ``hook.name`` is read, as the key into
            ``clean_cache``.
        head_index: Index along the third axis (the head axis) to overwrite.
        clean_cache: Mapping from hook name to the activation to copy from.

    Returns:
        The same tensor object that was passed in, after the in-place patch.
    """
    clean_head_vector = clean_cache[hook.name]
    # Only the selected head is replaced; every other head keeps its
    # original values.
    corrupted_head_vector[:, :, head_index, :] = clean_head_vector[
        :, :, head_index, :
    ]
    return corrupted_head_vector


# Patch each head's "z" activation individually (via `patch_head_vector` and
# the clean `cache`) into a run on the corrupted prompt, and record the
# normalized logit difference per (layer, head).
patched_head_z_diff = torch.zeros(
    (model.cfg.n_layers, model.cfg.n_heads), device=device, dtype=torch.float32
)

for layer in range(model.cfg.n_layers):
    # The hook name is the same for every head in a layer.
    z_hook_name = utils.get_act_name("z", layer, "attn")

    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            return_type="logits",
            fwd_hooks=[(z_hook_name, hook_fn)],
        )
        patched_logit_diff = logits_to_ave_logit_diff(
            patched_logits, answer_tokens, new_wrong_tokens
        )
        patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [None]:
# Heatmap of the per-head output ("z") patching results: layers on y, heads on x.
imshow(
    patched_head_z_diff,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Patched Head Output",
)

## Decomposing heads

In [None]:
# Same sweep as the head-output patching, but hooking each head's "v"
# activation instead; `patch_head_vector` is reused for the per-head overwrite.
patched_head_v_diff = torch.zeros(
    (model.cfg.n_layers, model.cfg.n_heads), device=device, dtype=torch.float32
)

for layer in range(model.cfg.n_layers):
    # The hook name is the same for every head in a layer.
    v_hook_name = utils.get_act_name("v", layer, "attn")

    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            return_type="logits",
            fwd_hooks=[(v_hook_name, hook_fn)],
        )
        patched_logit_diff = logits_to_ave_logit_diff(
            patched_logits, answer_tokens, new_wrong_tokens
        )
        patched_head_v_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [None]:
# Heatmap of the per-head value ("v") patching results: layers on y, heads on x.
imshow(
    patched_head_v_diff,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Patched Head Value",
)

In [None]:
# Hover labels "L<layer>H<head>", in the same row-major order as the flattened
# (layer, head) result tensors plotted below.
head_labels = [
    f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)
]

# Value-patching effect (x) against output-patching effect (y), one point per
# head, coloured by the head's layer index.
scatter(
    title="Scatter plot of output patching vs value patching",
    x=utils.to_numpy(patched_head_v_diff.flatten()),
    y=utils.to_numpy(patched_head_z_diff.flatten()),
    xaxis="Value Patch",
    yaxis="Output Patch",
    range_x=(-0.5, 0.5),
    range_y=(-0.5, 0.5),
    hover_name=head_labels,
    caxis="Layer",
    # Each layer index is repeated n_heads times to line up with the
    # flattened (layer, head) ordering.
    color=einops.repeat(
        np.arange(model.cfg.n_layers), "layer -> (layer head)", head=model.cfg.n_heads
    ),
)

In [None]:
def patch_head_pattern(
    corrupted_head_pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    hook,
    head_index,
    clean_cache,
):
    """Overwrite one head's attention pattern with its cached value.

    Unlike the per-head "z"/"v" activations, the head axis here is the second
    axis, and the trailing two axes are indexed as a whole (the annotation
    names them query and key position rather than ``d_head``).

    Args:
        corrupted_head_pattern: Activation passed to the hook; modified in
            place.
        hook: Hook point object; only ``hook.name`` is read, as the key into
            ``clean_cache``.
        head_index: Index along the second axis (the head axis) to overwrite.
        clean_cache: Mapping from hook name to the activation to copy from.

    Returns:
        The same tensor object that was passed in, after the in-place patch.
    """
    clean_head_pattern = clean_cache[hook.name]
    # Only the selected head is replaced; every other head keeps its
    # original pattern.
    corrupted_head_pattern[:, head_index, :, :] = clean_head_pattern[
        :, head_index, :, :
    ]
    return corrupted_head_pattern


# Patch each head's attention pattern individually (via `patch_head_pattern`
# and the clean `cache`) into a run on the corrupted prompt, and record the
# normalized logit difference per (layer, head).
patched_head_attn_diff = torch.zeros(
    (model.cfg.n_layers, model.cfg.n_heads), device=device, dtype=torch.float32
)

for layer in range(model.cfg.n_layers):
    # The hook name is the same for every head in a layer.
    pattern_hook_name = utils.get_act_name("attn", layer, "attn")

    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_pattern, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            return_type="logits",
            fwd_hooks=[(pattern_hook_name, hook_fn)],
        )
        patched_logit_diff = logits_to_ave_logit_diff(
            patched_logits, answer_tokens, new_wrong_tokens
        )
        patched_head_attn_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [None]:
# Heatmap of the per-head attention-pattern patching results: layers on y,
# heads on x.
imshow(
    patched_head_attn_diff,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Patched Head Pattern",
)

# Hover labels "L<layer>H<head>", rebuilt here so this cell does not depend on
# the earlier scatter cell having been run.
head_labels = [
    f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)
]

# Attention-pattern patching effect (x) against output-patching effect (y),
# one point per head.
scatter(
    title="Scatter plot of output patching vs attention patching",
    x=utils.to_numpy(patched_head_attn_diff.flatten()),
    y=utils.to_numpy(patched_head_z_diff.flatten()),
    xaxis="Attention Patch",
    yaxis="Output Patch",
    hover_name=head_labels,
)

In [None]:
# Select the heads whose output patch has the largest absolute effect and show
# their attention patterns on the first prompt, grouped by depth.
top_k = 10
# Clamp k so torch.topk does not raise when the model has fewer than top_k
# heads in total.
top_heads_by_output_patch = torch.topk(
    patched_head_z_diff.abs().flatten(),
    k=min(top_k, patched_head_z_diff.numel()),
).indices

# patched_head_z_diff is (n_layers, n_heads), so a flattened index equals
# layer * n_heads + head and layer boundaries become index thresholds.
# NOTE(review): the layer cut-offs 7 and 9 are hard-coded; presumably they were
# chosen for the specific model loaded earlier -- revisit if the model changes.
first_mid_layer = 7
first_late_layer = 9
first_mid_index = model.cfg.n_heads * first_mid_layer
first_late_index = model.cfg.n_heads * first_late_layer

early_heads = top_heads_by_output_patch[
    top_heads_by_output_patch < first_mid_index
]
mid_heads = top_heads_by_output_patch[
    torch.logical_and(
        first_mid_index <= top_heads_by_output_patch,
        top_heads_by_output_patch < first_late_index,
    )
]
late_heads = top_heads_by_output_patch[
    first_late_index <= top_heads_by_output_patch
]

early = visualize_attention_patterns(
    early_heads, cache, tokens[0], title="Top Early Heads"
)
mid = visualize_attention_patterns(
    mid_heads, cache, tokens[0], title="Top Middle Heads"
)
late = visualize_attention_patterns(
    late_heads, cache, tokens[0], title="Top Late Heads"
)

# Keep this as the final expression so the combined HTML is rendered as the
# cell output.
HTML(early + mid + late)