<a href="https://colab.research.google.com/github/mnida/mech-interp-r1/blob/main/Mech_Interp_Circuit_Experiment_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install transformer_lens
%pip install circuitsvis
# Install a faster Node version
!curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa


In [2]:
from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer

In [3]:
torch.set_grad_enabled(False)


<torch.autograd.grad_mode.set_grad_enabled at 0x79bc53480c40>

In [4]:
def imshow(tensor, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()


def line(tensor, **kwargs):
    px.line(
        y=utils.to_numpy(tensor),
        **kwargs,
    ).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y,
        x=x,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    ).show()

## Idea 1: Exploring Linearity of Time Within GPT-2 Small (short-lived)

The motivation for this work is to get my feet wet in mechanistic interpretability research by learning more about how models represent time.

I chose this idea because I imagine a future in which GPT-8 (or a superintelligence equivalent) might represent time in a non-intuitive manner, maybe something like the aliens in the movie Arrival. Such a representation of time could allow the model to make a lot of progress on unsolved physics questions.


In [5]:
# NBVAL_IGNORE_OUTPUT
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

# Get the default device used
device: torch.device = utils.get_device()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loaded pretrained model gpt2-small into HookedTransformer


I first tried prompts like the ones below to test whether the model understands time, for example what "2 hours later" means or how clock times differ between cities. Unfortunately, GPT-2 small was terrible at this, so I don't think it's worth digging deeper.

In [6]:
example_prompt = "In New York the time is 8:00 PM. In Los Angeles the time is"
example_answer = "5"
example_prompt_written = "In New York the time is eight PM. In Los Angeles the time is"
example_answer_written = "five"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)
utils.test_prompt(example_prompt_written, example_answer_written, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'In', ' New', ' York', ' the', ' time', ' is', ' 8', ':', '00', ' PM', '.', ' In', ' Los', ' Angeles', ' the', ' time', ' is']
Tokenized answer: [' 5']


Top 0th token. Logit: 17.43 Prob: 16.80% Token: | 9|
Top 1th token. Logit: 17.38 Prob: 16.00% Token: | 8|
Top 2th token. Logit: 16.89 Prob:  9.80% Token: | 7|
Top 3th token. Logit: 16.75 Prob:  8.46% Token: | 6|
Top 4th token. Logit: 16.56 Prob:  7.01% Token: | 10|
Top 5th token. Logit: 16.54 Prob:  6.89% Token: | 2|
Top 6th token. Logit: 16.48 Prob:  6.49% Token: | 5|
Top 7th token. Logit: 16.29 Prob:  5.36% Token: | 11|
Top 8th token. Logit: 16.26 Prob:  5.19% Token: | 4|
Top 9th token. Logit: 16.20 Prob:  4.92% Token: | 1|


Tokenized prompt: ['<|endoftext|>', 'In', ' New', ' York', ' the', ' time', ' is', ' eight', ' PM', '.', ' In', ' Los', ' Angeles', ' the', ' time', ' is']
Tokenized answer: [' five']


Top 0th token. Logit: 15.66 Prob:  7.70% Token: | 8|
Top 1th token. Logit: 15.61 Prob:  7.31% Token: | 10|
Top 2th token. Logit: 15.57 Prob:  7.07% Token: | 9|
Top 3th token. Logit: 15.37 Prob:  5.75% Token: | 11|
Top 4th token. Logit: 15.11 Prob:  4.44% Token: | midnight|
Top 5th token. Logit: 15.08 Prob:  4.32% Token: | 7|
Top 6th token. Logit: 14.99 Prob:  3.95% Token: | ten|
Top 7th token. Logit: 14.95 Prob:  3.78% Token: | 12|
Top 8th token. Logit: 14.87 Prob:  3.50% Token: | seven|
Top 9th token. Logit: 14.70 Prob:  2.95% Token: | five|


In [7]:
example_prompt = "John gets one new coin every day. Yesterday he had two coins. After getting a new coin, today he has"
example_answer = " three"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'John', ' gets', ' one', ' new', ' coin', ' every', ' day', '.', ' Yesterday', ' he', ' had', ' two', ' coins', '.', ' After', ' getting', ' a', ' new', ' coin', ',', ' today', ' he', ' has']
Tokenized answer: [' three']


Top 0th token. Logit: 17.26 Prob: 28.07% Token: | two|
Top 1th token. Logit: 16.76 Prob: 17.13% Token: | three|
Top 2th token. Logit: 16.69 Prob: 15.92% Token: | one|
Top 3th token. Logit: 15.73 Prob:  6.08% Token: | four|
Top 4th token. Logit: 15.39 Prob:  4.33% Token: | a|
Top 5th token. Logit: 14.86 Prob:  2.55% Token: | five|
Top 6th token. Logit: 14.71 Prob:  2.20% Token: | just|
Top 7th token. Logit: 14.57 Prob:  1.92% Token: | only|
Top 8th token. Logit: 14.28 Prob:  1.43% Token: | six|
Top 9th token. Logit: 14.07 Prob:  1.17% Token: | 2|


After trying many math-style tasks, I realized GPT-2 small is just not good with numbers. For this project I'm going to find a different, language-related task so that I can get started.

However, I did find an interesting question along the way: how does the model treat numbers written as digits (e.g. 1, 2, 3) versus numbers written as English words (e.g. one, two, three)? It could be worth analyzing how the model behaves when the only difference between two prompts is how the number is represented. After experimenting for a few minutes, I see slight differences in the outputs.

In [8]:
example_prompt = "A mouse is small, but an elephant is"
example_answer = "big"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'A', ' mouse', ' is', ' small', ',', ' but', ' an', ' elephant', ' is']
Tokenized answer: [' big']


Top 0th token. Logit: 18.63 Prob: 26.37% Token: | big|
Top 1th token. Logit: 18.14 Prob: 16.10% Token: | large|
Top 2th token. Logit: 17.25 Prob:  6.64% Token: | huge|
Top 3th token. Logit: 17.23 Prob:  6.45% Token: | bigger|
Top 4th token. Logit: 17.00 Prob:  5.18% Token: | larger|
Top 5th token. Logit: 16.91 Prob:  4.69% Token: | a|
Top 6th token. Logit: 16.41 Prob:  2.87% Token: | small|
Top 7th token. Logit: 16.17 Prob:  2.25% Token: | much|
Top 8th token. Logit: 15.98 Prob:  1.85% Token: | very|
Top 9th token. Logit: 15.68 Prob:  1.38% Token: | tiny|


In [9]:
example_prompt = "A pillow is soft, but a brick is"
example_answer = "hard"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'A', ' pillow', ' is', ' soft', ',', ' but', ' a', ' brick', ' is']
Tokenized answer: [' hard']


Top 0th token. Logit: 17.52 Prob: 14.28% Token: | hard|
Top 1th token. Logit: 17.39 Prob: 12.53% Token: | soft|
Top 2th token. Logit: 16.49 Prob:  5.10% Token: | not|
Top 3th token. Logit: 16.29 Prob:  4.17% Token: | a|
Top 4th token. Logit: 16.02 Prob:  3.18% Token: | thick|
Top 5th token. Logit: 15.99 Prob:  3.09% Token: | solid|
Top 6th token. Logit: 15.86 Prob:  2.73% Token: | softer|
Top 7th token. Logit: 15.72 Prob:  2.36% Token: | strong|
Top 8th token. Logit: 15.71 Prob:  2.34% Token: | heavy|
Top 9th token. Logit: 15.64 Prob:  2.19% Token: | firm|


## Idea 2: Opposite Adjective
I am going to study the opposite-adjective task, where the model has to output the opposite of an adjective given in context. Example pairs are old vs. young and big vs. small.

Since we are going to fix the output to one correct answer, I am going to use the logit difference metric and ignore synonyms of the answer that might also work. For instance, the model might output "new" instead of "young", but to simplify the problem we won't count that as either correct or incorrect.
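
As a quick sanity check on the metric: the logit difference between two tokens equals the log of the ratio of their probabilities, because the softmax normalizer cancels. The sketch below uses the ` big` and ` small` rows printed earlier for "A mouse is small, but an elephant is"; the variable names are illustrative and only the four numbers come from that output.

```python
import math

# Values copied from the test_prompt output for "A mouse is small, but an elephant is"
logit_big, prob_big = 18.63, 0.2637
logit_small, prob_small = 16.41, 0.0287

# Logit difference (correct minus incorrect)
logit_diff = logit_big - logit_small

# Log of the probability ratio; the softmax denominator cancels out
log_prob_ratio = math.log(prob_big / prob_small)

print(f"logit diff:     {logit_diff:.3f}")
print(f"log prob ratio: {log_prob_ratio:.3f}")
```

The two values agree up to the rounding in the printed output, so a logit difference of about 2.2 means the model finds ` big` roughly nine times more likely than ` small`.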

In [10]:
prompt_format = [
    "A mouse is{} but an elephant is",
    "A turtle is{} but a kitten is",
    "A desert is{} but an ocean is",
    "A pillow is{} but a brick is",
    "A rock is{} but a sponge is",
    "A rose is{} but a weed is"
]
prompt_format_commas = [
    "A mouse is{}, but an elephant is",
    "A turtle is{}, but a kitten is",
    "A desert is{}, but an ocean is",
    "A pillow is{}, but a brick is",
    "A rock is{}, but a sponge is",
    "A rose is{}, but a weed is"
]
names = [
    (" small", " big"),
    (" old", " young"),
    (" dry", " wet"),
    (" soft", " hard"),
    (" hard", " soft"),
    (" beautiful", " ugly")

]
# List of prompts
prompts = []
prompts_commas = []
# List of answers, in the format (correct, incorrect)
answers = []
# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
answer_tokens = []
for i in range(len(prompt_format)):
    answers.append((names[i][1], names[i][0]))
    answer_tokens.append(
        (
            model.to_single_token(answers[-1][0]),
            model.to_single_token(answers[-1][1]),
        )
    )
    # Insert the *incorrect* adjective into the prompt; the correct answer is its opposite.
    prompts.append(prompt_format[i].format(answers[-1][1]))
    prompts_commas.append(prompt_format_commas[i].format(answers[-1][1]))
answer_tokens = torch.tensor(answer_tokens).to(device)
print(prompts)
print(prompts_commas)
print(answers)

['A mouse is small but an elephant is', 'A turtle is old but a kitten is', 'A desert is dry but an ocean is', 'A pillow is soft but a brick is', 'A rock is hard but a sponge is', 'A rose is beautiful but a weed is']
['A mouse is small, but an elephant is', 'A turtle is old, but a kitten is', 'A desert is dry, but an ocean is', 'A pillow is soft, but a brick is', 'A rock is hard, but a sponge is', 'A rose is beautiful, but a weed is']
[(' big', ' small'), (' young', ' old'), (' wet', ' dry'), (' hard', ' soft'), (' soft', ' hard'), (' ugly', ' beautiful')]


In [11]:
# Code to make sure positions are aligned across prompts and lengths are the same
for prompt in prompts:
    str_tokens = model.to_str_tokens(prompt)
    print("Prompt length:", len(str_tokens))
    print("Prompt as tokens:", str_tokens)

Prompt length: 9
Prompt as tokens: ['<|endoftext|>', 'A', ' mouse', ' is', ' small', ' but', ' an', ' elephant', ' is']
Prompt length: 9
Prompt as tokens: ['<|endoftext|>', 'A', ' turtle', ' is', ' old', ' but', ' a', ' kitten', ' is']
Prompt length: 9
Prompt as tokens: ['<|endoftext|>', 'A', ' desert', ' is', ' dry', ' but', ' an', ' ocean', ' is']
Prompt length: 9
Prompt as tokens: ['<|endoftext|>', 'A', ' pillow', ' is', ' soft', ' but', ' a', ' brick', ' is']
Prompt length: 9
Prompt as tokens: ['<|endoftext|>', 'A', ' rock', ' is', ' hard', ' but', ' a', ' sponge', ' is']
Prompt length: 9
Prompt as tokens: ['<|endoftext|>', 'A', ' rose', ' is', ' beautiful', ' but', ' a', ' weed', ' is']


In [12]:
tokens = model.to_tokens(prompts, prepend_bos=True)

# Run the model and cache all activations
original_logits, cache = model.run_with_cache(tokens)

# Get the logits for the prompts with commas

tokens_commas = model.to_tokens(prompts_commas, prepend_bos=True)

original_logits_commas, cache_commas = model.run_with_cache(tokens_commas)


In [13]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()


print(
    "Per prompt logit difference:",
    logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
    .detach()
    .cpu()
    .round(decimals=3),
)
print(
    "Average logit difference:",
    round(logits_to_ave_logit_diff(original_logits, answer_tokens).item(), 3),
)

Per prompt logit difference: tensor([3.1440, 2.0830, 0.3140, 0.5160, 1.7990, 2.6010])
Average logit difference: 1.743


In [14]:
print(
    "Per prompt logit difference with commas:",
    logits_to_ave_logit_diff(original_logits_commas, answer_tokens, per_prompt=True)
    .detach()
    .cpu()
    .round(decimals=3),
)
original_average_logit_diff_commas = logits_to_ave_logit_diff(original_logits_commas, answer_tokens)
print(
    "Average logit difference with commas:",
    round(logits_to_ave_logit_diff(original_logits_commas, answer_tokens).item(), 3),
)

Per prompt logit difference with commas: tensor([2.2180, 1.2900, 0.3890, 0.1310, 0.8260, 1.9750])
Average logit difference with commas: 1.138


Hmm, this logit difference is not that large compared to the IOI example. I think part of the reason is the ambiguity of the answer/opposite: for instance, the model puts some probability on "large" as well as "big".

## Some general predictions as to what is happening.

I assume the most relevant tokens in the prompt for the second data point are "old", "but", "turtle", and "kitten". The "but" seems very helpful for signaling that the opposite of "old" is wanted: when I tried replacing "but" with "then", the logit difference was significantly smaller.

The model needs to figure out what the previous adjective is and then predict its opposite. Therefore, maybe there is a head that attends to the noun-adjective relationship and then some other component (an MLP?) that is able to find the opposite of the adjective. This is likely combined with a similar layer that processes both the noun we want to describe (e.g. "kitten") and the adjective to be reversed (e.g. "old").

## Logit Difference

In [15]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([6, 2, 768])
Logit difference directions shape: torch.Size([6, 768])


In [17]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)

print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

Final residual stream shape: torch.Size([6, 9, 768])
Calculated average logit diff: 2.043
Original logit difference: 1.743
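
The residual-direction estimate (2.043) does not match the logit difference read directly off the logits (1.743). One possible explanation, stated here as an unverified assumption, is the unembedding bias: `tokens_to_residual_directions` returns columns of `W_U` only, while the real logits also add `b_U`, and with `fold_ln=True` the final LayerNorm bias is folded into `b_U`. In the IOI setup every name is the correct answer in one prompt and the incorrect answer in another, so the bias term averages out; here most adjectives appear on only one side. The toy example below uses made-up numbers (not model weights) to show that the two quantities differ by exactly the bias gap.

```python
# Toy unembedding: 3-dimensional residual stream, 2 vocabulary tokens.
# All numbers are invented for illustration.
resid = [0.5, -1.0, 2.0]                        # final, already normalized residual vector
W_U = [[0.2, -0.1], [0.4, 0.3], [1.0, 0.1]]     # indexed as W_U[d_model][token]
b_U = [0.7, 0.1]                                # unembedding bias per token

def logit(token: int) -> float:
    """Full logit: residual projected through W_U plus the bias."""
    return sum(r * W_U[i][token] for i, r in enumerate(resid)) + b_U[token]

# Logit difference read off the actual logits (token 0 = correct, token 1 = incorrect)
true_diff = logit(0) - logit(1)

# Logit difference estimated from the residual direction alone (no bias)
direction = [W_U[i][0] - W_U[i][1] for i in range(len(resid))]
direction_only = sum(r * d for r, d in zip(resid, direction))

bias_gap = b_U[0] - b_U[1]

print("from logits:        ", round(true_diff, 3))
print("from direction only:", round(direction_only, 3))
print("bias gap:           ", round(bias_gap, 3))
```

If this is what is happening in the notebook, adding the mean of `b_U[correct] - b_U[incorrect]` over the prompts to the calculated value should close the gap; that check is not run here.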


Now, we'll take the logit lens approach to decompose where the logit differences are coming from.
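
The decomposition works because the final residual stream is a sum of component outputs and, once the LayerNorm scale is held fixed, the projection onto the logit-difference direction is linear. A minimal sketch with made-up 2-dimensional vectors (the component names and numbers are hypothetical, not taken from the model):

```python
from itertools import accumulate

def dot(a, b):
    """Plain dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

# Hypothetical outputs written to the residual stream at the final position
components = {
    "embed": [0.2, -0.5],
    "attn_0": [1.0, 0.3],
    "mlp_0": [-0.4, 0.9],
}
# Hypothetical logit-difference direction (correct minus incorrect unembed column)
direction = [1.5, -0.5]

# Per-component attribution (the kind of quantity a per-layer plot shows)
per_component = {name: dot(vec, direction) for name, vec in components.items()}

# Running total (the kind of quantity an accumulated-residual plot shows)
accumulated = list(accumulate(per_component.values()))

# Projecting the summed residual stream gives the same final number
final_resid = [sum(vec[i] for vec in components.values()) for i in range(2)]
total = dot(final_resid, direction)

print(per_component)
print(accumulated)
print(total)
```

A component with a negative entry pushes toward the incorrect answer, which is how the accumulated curve can sit below zero for many layers before later components pull it up.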

In [18]:
def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
) -> Float[torch.Tensor, "components"]:
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    return einsum(
        "... batch d_model, batch d_model -> ...",
        scaled_residual_stack,
        logit_diff_directions,
    ) / len(prompts)

In [19]:
accumulated_residual, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
line(
    logit_lens_logit_diffs,
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    title="Logit Difference From Accumulated Residual Stream",
)

This plot shows that most of the performance comes from the attention in layer 9 as well as the last layer (layer 11).

More interestingly, the logit difference is negative up until layer 9, which means the model is more likely to predict the incorrect answer, i.e. the same adjective that appears in the prompt. I haven't seen a graph like this that stays negative for so long, and I think it might relate to my hypothesis that the model puts a lot of attention on the first noun/adjective pairing to learn that relationship.

Time to dig deeper.

In [20]:
per_layer_residual, labels = cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

This is similar to the IOI task, where attention makes the most difference in each layer, since here we also care a lot about moving the adjective information around. One hypothesis for the decrease in performance in the earlier layers is that the attention is simply moving the incorrect adjective to the final "is" token, but the MLP isn't able to process it well. In layer 9, however, maybe the attention is moving different information, such as both the noun "kitten" and the other adjective "old", so the MLP has better information to figure out an opposite.

From this graph, I can better understand how the attention sets up the MLP to process information effectively.

In [21]:
per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(
    per_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)
imshow(
    per_head_logit_diffs,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
)

Tried to stack head results when they weren't cached. Computing head results now


Here we can see that the only heads that matter in a large way are L9H7 and L10H7, plus L10H1 in the negative direction. This suggests that the drop in the earlier layers is due to a combination of multiple heads that each have a relatively small logit difference effect of about -0.1, especially in layer 8.
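
The head labels used here (for example L9H7) come from a flattened `(layer head_index)` axis, so a flat index converts to a layer and head with integer division and remainder. A small sketch, assuming GPT-2 small's 12 heads per layer; the helper names `head_label` and `flat_index` are made up for illustration:

```python
N_HEADS = 12  # heads per layer in GPT-2 small

def head_label(flat_idx: int, n_heads: int = N_HEADS) -> str:
    """Convert a flattened head index into an 'L{layer}H{head}' label."""
    layer, head = divmod(flat_idx, n_heads)
    return f"L{layer}H{head}"

def flat_index(layer: int, head: int, n_heads: int = N_HEADS) -> int:
    """Inverse mapping: (layer, head) -> flattened index."""
    return layer * n_heads + head

# The three heads named above
for layer, head in [(9, 7), (10, 7), (10, 1)]:
    idx = flat_index(layer, head)
    print(idx, head_label(idx))
```

This is the same arithmetic needed when `torch.topk` is applied to the flattened per-head tensor and the resulting indices have to be turned back into layer/head pairs.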

Let's dig deeper and do some attention head analysis.

In [22]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
) -> str:
    # If a single head is given, convert to a list
    if isinstance(heads, int):
        heads = [heads]

    # Create the plotting data
    labels: List[str] = []
    patterns: List[Float[torch.Tensor, "dest_pos src_pos"]] = []

    # Assume we have a single batch item
    batch_index = 0

    for head in heads:
        # Set the label
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        labels.append(f"L{layer}H{head_index}")

        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])

    # Convert the tokens to strings (for the axis labels)
    str_tokens = model.to_str_tokens(local_tokens)

    # Combine the patterns into a single tensor
    patterns: Float[torch.Tensor, "head_index dest_pos src_pos"] = torch.stack(
        patterns, dim=0
    )

    # Circuitsvis Plot (note we get the code version so we can concatenate with the title)
    plot = attention_heads(
        attention=patterns, tokens=str_tokens, attention_head_names=labels
    ).show_code()

    # Display the title
    title_html = f"<h2>{title}</h2><br/>"

    # Return the visualisation as raw code
    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"

In [23]:
top_k = 3

top_positive_logit_attr_heads = torch.topk(
    per_head_logit_diffs.flatten(), k=top_k
).indices

positive_html = visualize_attention_patterns(
    top_positive_logit_attr_heads,
    cache,
    tokens[0],
    f"Top {top_k} Positive Logit Attribution Heads",
)

top_negative_logit_attr_heads = torch.topk(
    -per_head_logit_diffs.flatten(), k=top_k
).indices

negative_html = visualize_attention_patterns(
    top_negative_logit_attr_heads,
    cache,
    tokens[0],
    title=f"Top {top_k} Negative Logit Attribution Heads",
)

HTML(positive_html + negative_html)

## Note

It's at this point that I'm wondering how solid a task this is to study, because the model has to introduce new information that isn't in the prompt. I've heard that most stored knowledge/facts live in the earlier layers of the transformer. My intuition was that there could be an algorithm that combines some form of lookup with a step that finds the opposite of the adjective in the prompt (the "incorrect answer") and something that relates to the second noun ("elephant").

Taking a look at these heads, I see that in the most impactful head, L9H7, the tokens "but", "an", and "is" attend to "small". This seems to act as the "opposite" head: "small" is the wrong adjective, but it is attended to from "but", which gives the next circuit information about which adjective to take the opposite of.

It's curious that the second noun, "elephant", hardly figures in the attention patterns of the positive-difference heads.

The negative heads seem to focus mostly on the last token "is" attending to the first adjective "small". It makes sense that moving the prompt's adjective to the last token would result in a higher likelihood of predicting that "wrong" adjective, assuming little processing is done after moving it.

There is also some movement from the first noun to the second noun, which I expected more of in the positive heads, because the relationship between the nouns should help in working out which adjective characterizes the second noun.

## Activation Patching Approach

I will first try corrupting the prompt by swapping both the nouns and the adjective. This is probably better than swapping just the adjective, which would confuse the model, since it is probably also doing some work to describe the relevant noun rather than just outputting the opposite adjective.

Therefore I will try these prompts:

Correct:   "A mouse is small but an elephant is"
Incorrect: "An elephant is big but a mouse is"

I hope this isolates only the circuits responsible for working out the proper opposite adjective, since the structure of both prompts is the same.
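
Activation patching, in general terms: run the model on the clean prompt and cache its activations, run it on the corrupted prompt, and overwrite one activation at a time with the cached clean value to see how much of the clean output comes back. The toy below is only a sketch of that loop on a made-up two-stage function, not the transformer; `run` and its arguments are hypothetical.

```python
def run(tokens, patch=None):
    """Toy two-stage 'model'.

    Stage 1 computes one activation per position; stage 2 sums them.
    `patch` optionally maps a position to an activation that overwrites
    the stage-1 value at that position (the 'hook').
    """
    hidden = [2 * t for t in tokens]
    if patch is not None:
        for pos, value in patch.items():
            hidden[pos] = value
    return sum(hidden), hidden

clean_tokens = [1, 5, 3]
corrupted_tokens = [1, 0, 3]  # differs from the clean input only at position 1

clean_out, clean_cache = run(clean_tokens)
corrupted_out, _ = run(corrupted_tokens)

# Patch each position of the corrupted run with the clean activation
patched_outs = []
for pos in range(len(corrupted_tokens)):
    patched_out, _ = run(corrupted_tokens, patch={pos: clean_cache[pos]})
    patched_outs.append(patched_out)

print("clean:", clean_out, "corrupted:", corrupted_out)
print("patched per position:", patched_outs)
```

Only the position where the two inputs differ restores the clean output, which is the signal the heatmaps in this section look for across layers and positions.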

In [24]:
corrupted_prompts = [
"An elephant is big but a mouse is",
"A kitten is young but a turtle is",
"An ocean is wet but a desert is",
"A brick is hard but a pillow is",
"A sponge is soft but a rock is",
"A weed is ugly but a rose is"
]
# for prompt in corrupted_prompts:
#     str_tokens = model.to_str_tokens(prompt)
#     print("Prompt length:", len(str_tokens))
#     print("Prompt as tokens:", str_tokens)

corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = model.run_with_cache(
    corrupted_tokens, return_type="logits"
)

corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print("Corrupted Average Logit Diff", round(corrupted_average_logit_diff.item(), 2))
print("Clean Average Logit Diff", round(original_average_logit_diff.item(), 2))

Corrupted Average Logit Diff -1.63
Clean Average Logit Diff 1.74


In [25]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache,
):
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component


def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise
    # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    return (patched_logit_diff - corrupted_average_logit_diff) / (
        original_average_logit_diff - corrupted_average_logit_diff
    )


patched_residual_stream_diff = torch.zeros(
    model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("resid_pre", layer), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [26]:
prompt_position_labels = [
    f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(tokens[0]))
]
imshow(
    patched_residual_stream_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Residual Stream",
    labels={"x": "Position", "y": "Layer"},
)

In [None]:
del patched_residual_stream_diff # delete to free up memory

OK, this graph of the patched residual stream helps a lot. It shows that the most important token the circuit pays attention to is the first adjective ("small" in this case). The second noun ("elephant") is also being looked at closely, which we couldn't see from our initial logit difference experiments, since those only looked at the end of the circuit.

In addition, the last two layers move all the information to the last token, and both actively improve on clean performance.
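
To make the scale of these heatmaps concrete, here is the normalization applied to the patched logit difference, worked through with the clean (1.74) and corrupted (-1.63) averages printed above. The constant names and the two extra sample values are illustrative.

```python
CLEAN = 1.74       # clean average logit diff printed above
CORRUPTED = -1.63  # corrupted average logit diff printed above

def normalize(patched_logit_diff: float) -> float:
    """0 = no better than the corrupted run, 1 = clean performance fully recovered."""
    return (patched_logit_diff - CORRUPTED) / (CLEAN - CORRUPTED)

# 0.0 and 2.0 are hypothetical patched values, chosen only to show the scale
for value in (CORRUPTED, 0.0, CLEAN, 2.0):
    print(f"{value:+.2f} -> {normalize(value):.3f}")
```

A raw logit difference of 0 already sits near the middle of this scale, and any patched value above the clean average maps above 1, which is how a cell can "actively improve" on clean performance.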

Now let's patch individual layer outputs to see the effect of attention vs. the MLPs. We already know from the earlier graphs that both play a role, but attention seems more important, especially in layer 9. Let's see.

##### Issue: I ran into a lot of OOM errors in this layer patching script. The solution was to disable automatic differentiation at the start (`torch.set_grad_enabled(False)`).

In [27]:
patched_attn_diff = torch.zeros(
    model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
patched_mlp_diff = torch.zeros(
    model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for position in range(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_attn_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("attn_out", layer), hook_fn)],
            return_type="logits",
        )
        patched_attn_logit_diff = logits_to_ave_logit_diff(
            patched_attn_logits, answer_tokens
        )
        patched_mlp_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("mlp_out", layer), hook_fn)],
            return_type="logits",
        )
        patched_mlp_logit_diff = logits_to_ave_logit_diff(
            patched_mlp_logits, answer_tokens
        )

        patched_attn_diff[layer, position] = normalize_patched_logit_diff(
            patched_attn_logit_diff
        )
        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(
            patched_mlp_logit_diff
        )

In [28]:
imshow(
    patched_attn_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Attention Layer",
    labels={"x": "Position", "y": "Layer"},
)

In [29]:
imshow(
    patched_mlp_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched MLP Layer",
    labels={"x": "Position", "y": "Layer"},
)

Wow! This is awesome: it seems that attention only plays a big part at layer 9, where all the information is moved to the last token.

However, the MLPs are doing a lot of the work, unlike in the IOI task. This makes sense when we consider that this task involves more processing to figure out what the opposite adjective is, since the answer isn't in the prompt.

What's cool is that we can see the MLPs processing the first adjective in the earlier layers, all the way down to layer 0. A note in Neel's notebook states a hypothesis that MLP layer 0 in GPT-2 small has high activity across the board and can act as copying over the token embeddings, so that when later layers access the input tokens they do so at MLP0. This is still interesting because both the first adjective and the second noun are highly active here, but the first noun does not seem to be used/pulled much from MLP0.

However, we still see a lot of activity at layer 1 on "small".

Could it be that the general algorithm is simply to process and compare the first adjective with the last noun, and then figure out the correct adjective to output in the context of the first adjective?

We also see that once all the information is moved to the last token, a lot of processing happens in layers 8, 9, 10 and 11. Therefore, I'm not sure what kind of processing happens when, because we know there has to be an MLP that takes in "small" and outputs the opposite.

## Individual Heads


In [33]:
def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache,
):
    corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][
        :, :, head_index, :
    ]
    return corrupted_head_vector


patched_head_z_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("z", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [34]:
imshow(
    patched_head_z_diff,
    title="Logit Difference From Patched Head Output Attention",
    labels={"x": "Head", "y": "Layer"},
)

This again shows how layer 10 is doing work at head 1 and head 7, but you can't see this at the layer level because their work cancels out. I'm curious to explore layer 9 and layer 10 attention heads the most.
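
To make the cancellation concrete, here is a minimal sketch with invented numbers (not values from this experiment): two heads in the same layer with opposite-signed effects nearly sum to zero, so a layer-level view hides them. Patching effects are not exactly additive across heads, so the sum is only a rough proxy for the layer-level patch.

```python
import numpy as np

# Hypothetical per-head patching effects for one layer with 12 heads.
# These numbers are invented for illustration only.
layer_10_heads = np.zeros(12)
layer_10_heads[1] = -0.30  # one head pushes the metric down
layer_10_heads[7] = 0.32   # another head pushes it up

layer_sum = layer_10_heads.sum()          # roughly what a layer-level view reflects
total_abs = np.abs(layer_10_heads).sum()  # how much work the heads do, ignoring sign

print(f"sum of head effects:     {layer_sum:+.2f}")
print(f"sum of absolute effects: {total_abs:.2f}")
```

Looking at the absolute per-head effects alongside the layer-level plot is one way to spot layers where heads are working against each other.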

## Decomposing Heads: Values

In [37]:
patched_head_v_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("v", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_v_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [38]:
imshow(
    patched_head_v_diff,
    title="Logit Difference From Patched Head Value",
    labels={"x": "Head", "y": "Layer"},
)

In [47]:
head_labels = [
    f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)
]
scatter(
    x=utils.to_numpy(patched_head_v_diff.flatten()),
    y=utils.to_numpy(patched_head_z_diff.flatten()),
    xaxis="Value Patch",
    yaxis="Output Patch",
    caxis="Layer",
    hover_name=head_labels,
    color=einops.repeat(
        np.arange(model.cfg.n_layers), "layer -> (layer head)", head=model.cfg.n_heads
    ),
    range_x=(-2, 2),
    range_y=(-2, 2),
    title="Scatter plot of output patching vs value patching",
)

Just as a recap: comparing value patching with output patching helps us figure out for which heads *what* to move (the OV circuit) is especially important.
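
As a rough illustration of that split, a single head's output at each query position is a pattern-weighted average of the value vectors. The toy below uses plain NumPy with random data and made-up shapes rather than TransformerLens; `head_output` and `random_pattern` are hypothetical helpers. It shows that restoring only the values fixes "what is moved", while the pattern still decides "where it is moved from".

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos, d_head = 4, 3


def head_output(pattern, v):
    """z[q] = sum over k of pattern[q, k] * v[k], for a single head."""
    return pattern @ v


def random_pattern(rng, n_pos):
    """Random causal attention pattern whose rows sum to 1."""
    scores = rng.normal(size=(n_pos, n_pos))
    causal = np.tril(np.ones((n_pos, n_pos))) == 1
    scores = np.where(causal, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return weights / weights.sum(axis=-1, keepdims=True)


clean_pattern = random_pattern(rng, n_pos)
corrupted_pattern = random_pattern(rng, n_pos)
clean_v = rng.normal(size=(n_pos, d_head))
corrupted_v = rng.normal(size=(n_pos, d_head))

z_clean = head_output(clean_pattern, clean_v)
z_value_patched = head_output(corrupted_pattern, clean_v)    # only "what to move" is clean
z_pattern_patched = head_output(clean_pattern, corrupted_v)  # only "where from" is clean
z_both_patched = head_output(clean_pattern, clean_v)         # same as patching z directly
```

In this toy, patching both the pattern and the values reproduces the clean output exactly, which is why output patching acts as the reference that the value and pattern patches are compared against.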

From this scatter plot we can see a roughly linear relationship, suggesting that for most heads the value patch accounts for much of the output patch, so the OV circuit matters across the board. However, L9H7, followed by L10H7, truly stand out as having significant contributions. I almost missed L9H7 because it fell outside the graph's range!
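
Since the fixed `range_x=(-2, 2)` and `range_y=(-2, 2)` clipped L9H7, one option is to derive the ranges from the data instead. `padded_range` below is a hypothetical helper, not part of this notebook or of Plotly, and the example values are made up.

```python
import numpy as np


def padded_range(values, pad_fraction=0.05):
    """Return (low, high) covering all values with a small margin on each side."""
    values = np.asarray(values, dtype=float)
    low, high = values.min(), values.max()
    pad = (high - low) * pad_fraction or 1.0  # fall back to 1.0 if all values are equal
    return float(low - pad), float(high + pad)


# Made-up patching values; the last one would fall outside a fixed (-2, 2) range.
example = [-0.4, 0.1, 0.9, 2.6]
print(padded_range(example))
```

The result could then be passed as `range_x=padded_range(x_values)` (and similarly for `range_y`) in place of the hard-coded tuples.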

## Decomposing Heads: Attention/QK

In [52]:
def patch_head_pattern(
    corrupted_head_pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"],
    hook,
    head_index,
    clean_cache,
):
    corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][
        :, head_index, :, :
    ]
    return corrupted_head_pattern


patched_head_attn_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_pattern, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("attn", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_attn_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [53]:
imshow(
    patched_head_attn_diff,
    title="Logit Difference From Patched Head Pattern",
    labels={"x": "Head", "y": "Layer"},
)
head_labels = [
    f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)
]
scatter(
    x=utils.to_numpy(patched_head_attn_diff.flatten()),
    y=utils.to_numpy(patched_head_z_diff.flatten()),
    hover_name=head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    title="Scatter plot of output patching vs attention patching",
)

Here we see L9H7 is once again doing a ton of work and L10H7 is doing a bit. However, the interesting heads here are L7H7 and L11H10, which have the highest attention patching differences but near-zero output patching differences, which is why we haven't looked closely at them until now. They sit at two different steps of the algorithm: before and after layer 9, where all the information is moved to the last token.
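
Heads like these can also be picked out programmatically rather than by eye. The sketch below is a hypothetical helper with arbitrary thresholds and made-up data. With the real results one would first convert the tensors, for example with `utils.to_numpy(patched_head_attn_diff)` and `utils.to_numpy(patched_head_z_diff)`.

```python
import numpy as np


def pattern_only_heads(attn_patch, output_patch, attn_min=0.1, output_max=0.02):
    """Label heads with a large attention-patch effect but a small output-patch effect."""
    attn_patch = np.asarray(attn_patch)
    output_patch = np.asarray(output_patch)
    mask = (np.abs(attn_patch) >= attn_min) & (np.abs(output_patch) <= output_max)
    return [f"L{layer}H{head}" for layer, head in np.argwhere(mask)]


# Tiny made-up example: 2 layers x 3 heads.
attn = np.array([[0.00, 0.25, 0.01], [0.30, 0.00, 0.15]])
out = np.array([[0.00, 0.01, 0.00], [0.40, 0.00, 0.00]])
print(pattern_only_heads(attn, out))  # ['L0H1', 'L1H2']
```

The thresholds `attn_min` and `output_max` are assumptions and would need tuning against the actual spread of values in the scatter plot.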

## Top Attention Patterns to Look At

In [55]:
top_k = 10
top_heads_by_output_patch = torch.topk(
    patched_head_z_diff.abs().flatten(), k=top_k
).indices
first_mid_layer = 9
first_late_layer = 10
early_heads = top_heads_by_output_patch[
    top_heads_by_output_patch < model.cfg.n_heads * first_mid_layer
]
mid_heads = top_heads_by_output_patch[
    torch.logical_and(
        model.cfg.n_heads * first_mid_layer <= top_heads_by_output_patch,
        top_heads_by_output_patch < model.cfg.n_heads * first_late_layer,
    )
]
late_heads = top_heads_by_output_patch[
    model.cfg.n_heads * first_late_layer <= top_heads_by_output_patch
]

early = visualize_attention_patterns(
    early_heads, cache, tokens[0], title="Top Early Heads (Pre-layer 9)"
)
mid = visualize_attention_patterns(
    mid_heads, cache, tokens[0], title="Layer 9 Heads"
)
late = visualize_attention_patterns(
    late_heads, cache, tokens[0], title="Layer 10 and 11 Top Heads"
)

HTML(early + mid + late)
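
The early/mid/late split above works on indices into the flattened `[n_layers, n_heads]` grid, which is why the boundaries are written as `model.cfg.n_heads * layer`. As a small sketch of how such a flat index maps back to a head label (`flat_index_to_layer_head` is a hypothetical helper, and the indices are chosen only as examples):

```python
def flat_index_to_layer_head(flat_index, n_heads):
    """Decode an index into a flattened [n_layers, n_heads] array."""
    layer, head = divmod(int(flat_index), n_heads)
    return layer, head


n_heads = 12  # GPT-2 small has 12 heads per layer
for flat_index in (55, 115, 127):
    layer, head = flat_index_to_layer_head(flat_index, n_heads)
    print(f"{flat_index} -> L{layer}H{head}")
```

With 12 heads per layer, every index below `12 * 9 = 108` belongs to layers 0 to 8, which matches the "early" bucket.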

#### Note: L4H7 looks like an induction head pattern

We can also try just swapping the adjectives, although this should give quite different answers than expected: the sentence is now constructed as if the answer should be an even stronger adjective rather than an opposite ("bigger" vs. "small"), even though we are describing a "mouse" in this case.

In [33]:
corrupted_prompts_2 = [
    "An elephant is small, but a mouse is",
    "A kitten is old, but a tortoise is",
    "A rainforest is dry, but a desert is",
    "A brick is soft, but a pillow is",
    "A sponge is hard, but a rock is",
    "A weed is beautiful, but a rose is",
]
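
When building a second set of corrupted prompts like this, it is worth checking that each one tokenizes to the same length as its clean counterpart, because position-wise patching assumes the positions line up. The sketch below is a hypothetical helper. Whitespace splitting is only a stand-in for the real tokenizer, and the second prompt pair is deliberately misaligned for illustration.

```python
def check_aligned(clean_prompts, corrupted_prompts, tokenize):
    """Return (index, clean_len, corrupted_len) for prompt pairs whose token counts differ."""
    mismatches = []
    for i, (clean, corrupted) in enumerate(zip(clean_prompts, corrupted_prompts)):
        n_clean, n_corrupted = len(tokenize(clean)), len(tokenize(corrupted))
        if n_clean != n_corrupted:
            mismatches.append((i, n_clean, n_corrupted))
    return mismatches


# Whitespace splitting is a stand-in; a real check would use the model's tokenizer.
clean = ["A mouse is small, but an elephant is", "A rock is hard, but a sponge is"]
corrupted = [
    "An elephant is small, but a mouse is",
    "A sponge is hard, but a very big rock is",
]
print(check_aligned(clean, corrupted, str.split))  # [(1, 8, 10)]
```

With TransformerLens, `tokenize` could be something like `model.to_str_tokens`, so that the check reflects how GPT-2 actually splits each prompt.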

# Current Conclusion/Thoughts

Throughout this experiment, I investigated whether I could find a circuit for the task of contrasting adjectives when presented with opposite nouns. I was able to isolate a large amount of computation happening in layers 9 and 10.

Before layer 9, most of the computation is being done on the first adjective and the "but" token, mostly by MLPs which seem to be processing the adjective and figuring out a potential opposite.

Then at layer 9, all of the information is moved from the first adjective to the final token, and both attention and MLPs are active in layers 9 and 10. Layer 9 head 7 produces the largest logit difference across all metrics except "attention/QK", which leads me to believe that this is where the opposite is chosen, based on the processing done in that layer and previous layers.

Overall, this seems like just the start of trying to uncover a circuit and I would need to dig even deeper by analyzing way more prompts.

I spent a lot of time trying to come up with better prompts and debugging errors in my prompts, such as mismatched token positions. Throughout this process, I discovered that GPT-2 small is very context-dependent.

For instance, for the majority of the research I had my prompt look like:

    "A mouse is small, but an elephant is"

However, I decided to remove the comma because it was weird to me how the comma was being attended to more than the 'but' and even more than 'mouse' most of the time. Then my new prompt was

    "A mouse is small but an elephant is"

and I had a 0.7 improvement in logit difference across prompts. Just by removing the comma!

I believe this task wasn't the best because multiple things go into the token prediction, such as noun-adjective association and understanding that the sentence is comparing opposite adjectives.

### Areas of Improvement:

I feel like I lacked a lot of the intuition to understand what different circuits were doing on a fundamental level. In the EAD colab, which I followed closely to get up to speed on this code, conclusions are made about which heads do deduplication, and I wasn't able to understand how they figured that out and then apply that methodology here. In this research, I felt like I was doing a lot of guessing about what's going on, which is necessary to some extent. I'm looking forward to having my intuition improve from either mentorship or just putting reps in.

In addition, I'm excited to learn about how these types of experimental results (or more robust results like the Interpretability in the Wild paper) get unified into a broader understanding of the GPT-2 small model and LLMs in general. Part of me feels like there might be all this specific, niche work going on but no way to unify all these learnings so that researchers can get a global view of the model. Maybe an idea for this could be implemented in the TransformerLens library.


## Other Exciting Experiments

After reading the *Beyond A\*: Better Planning with Transformers via Search Dynamics Bootstrapping* paper, I am extremely motivated to see if we can use these mech interp techniques to uncover the types of algorithms that the model is learning.

After doing this exercise and immersing myself in alignment and interpretability for a short time, I believe the keys to many scientific and mathematical discoveries are hidden within LLMs and their internal algorithms. Sometimes these algorithms are probably not optimal or correct; other times, as in the case of search, it seems that large language models are innovating on our best approaches like A*.

I would first like to start by investigating how models think about sorting, CS 101. Because it is a multi-token output, I think it will be a lot harder than the exercise above or IOI, but I'm sure we can reveal some interesting dynamics about how different sized models sort (I assume GPT-2 small is either incapable or does some algorithm with similar efficiency to bogo sort lol).