<a href="https://colab.research.google.com/github/mnida/mech-interp-r1/blob/main/Mech_Interp_Circuit_Experiment_1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
%pip install transformer_lens
%pip install circuitsvis
# Install a faster Node version
!curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

Collecting transformer_lens
  Downloading transformer_lens-1.15.0-py3-none-any.whl (124 kB)
Collecting accelerate>=0.23.0 (from transformer_lens)
  Downloading accelerate-0.28.0-py3-none-any.whl (290 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl (739 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer_lens)
  Downloading better_abc-0.0.3-py3-none-any.whl (3.5 kB)
Collecting datasets>=2.7.1 (from transformer_lens)
  Downloadi

In [3]:
from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer

In [136]:
def imshow(tensor, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()


def line(tensor, **kwargs):
    px.line(
        y=utils.to_numpy(tensor),
        **kwargs,
    ).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y,
        x=x,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    ).show()

## Exploring Linearity of Time Within GPT-2 Small

The motivation for this work is to get my feet wet in mechanistic interpretability research by learning more about how models represent or learn about time.

I chose this idea because I am imagining a future in which GPT-8 (or a superintelligence equivalent) might represent time in a non-intuitive manner, maybe something like the aliens in the movie Arrival. Such a representation of time could allow the model to make significant progress on unsolved physics questions.

In [4]:
# NBVAL_IGNORE_OUTPUT
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

# Get the default device used
device: torch.device = utils.get_device()

Loaded pretrained model gpt2-small into HookedTransformer


I first tried prompts like the ones below to test whether the model could correctly understand time, for example what "2 hours later" means. Unfortunately GPT-2 small was terrible at this, so I don't think it's worth digging deeper.

In [84]:
example_prompt = "In New York the time is 8:00 PM. In Los Angeles the time is"
example_answer = "5"
example_prompt_written = "In New York the time is eight PM. In Los Angeles the time is"
example_answer_written = "five"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)
utils.test_prompt(example_prompt_written, example_answer_written, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'In', ' New', ' York', ' the', ' time', ' is', ' 8', ':', '00', ' PM', '.', ' In', ' Los', ' Angeles', ' the', ' time', ' is', ' 3', ' hours']
Tokenized answer: [' earlier']


Top 0th token. Logit: 17.19 Prob: 33.36% Token: |.|
Top 1th token. Logit: 16.77 Prob: 22.05% Token: | and|
Top 2th token. Logit: 15.49 Prob:  6.14% Token: |,|
Top 3th token. Logit: 13.96 Prob:  1.33% Token: | later|
Top 4th token. Logit: 13.92 Prob:  1.28% Token: | 9|
Top 5th token. Logit: 13.83 Prob:  1.17% Token: | 11|
Top 6th token. Logit: 13.77 Prob:  1.10% Token: | from|
Top 7th token. Logit: 13.70 Prob:  1.03% Token: | 6|
Top 8th token. Logit: 13.69 Prob:  1.02% Token: | before|
Top 9th token. Logit: 13.68 Prob:  1.00% Token: | 8|


Tokenized prompt: ['<|endoftext|>', 'In', ' New', ' York', ' the', ' time', ' is', ' eight', ' PM', '.', ' In', ' Los', ' Angeles', ' the', ' time', ' is']
Tokenized answer: [' five']


Top 0th token. Logit: 15.66 Prob:  7.70% Token: | 8|
Top 1th token. Logit: 15.61 Prob:  7.31% Token: | 10|
Top 2th token. Logit: 15.57 Prob:  7.07% Token: | 9|
Top 3th token. Logit: 15.37 Prob:  5.75% Token: | 11|
Top 4th token. Logit: 15.11 Prob:  4.44% Token: | midnight|
Top 5th token. Logit: 15.08 Prob:  4.32% Token: | 7|
Top 6th token. Logit: 14.99 Prob:  3.95% Token: | ten|
Top 7th token. Logit: 14.95 Prob:  3.78% Token: | 12|
Top 8th token. Logit: 14.87 Prob:  3.50% Token: | seven|
Top 9th token. Logit: 14.70 Prob:  2.95% Token: | five|


In [97]:
example_prompt = "John gets one new coin every day. Yesterday he had two coins. After getting a new coin, today he has"
example_answer = " three"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'John', ' gets', ' one', ' new', ' coin', ' every', ' day', '.', ' Yesterday', ' he', ' had', ' two', ' coins', '.', ' After', ' getting', ' a', ' new', ' coin', ',', ' today', ' he', ' has']
Tokenized answer: [' three']


Top 0th token. Logit: 17.26 Prob: 28.07% Token: | two|
Top 1th token. Logit: 16.76 Prob: 17.14% Token: | three|
Top 2th token. Logit: 16.69 Prob: 15.92% Token: | one|
Top 3th token. Logit: 15.73 Prob:  6.08% Token: | four|
Top 4th token. Logit: 15.39 Prob:  4.33% Token: | a|
Top 5th token. Logit: 14.86 Prob:  2.55% Token: | five|
Top 6th token. Logit: 14.71 Prob:  2.20% Token: | just|
Top 7th token. Logit: 14.57 Prob:  1.92% Token: | only|
Top 8th token. Logit: 14.28 Prob:  1.43% Token: | six|
Top 9th token. Logit: 14.07 Prob:  1.17% Token: | 2|


After trying many math-style tasks, I realized gpt2-small is just not good with numbers. To get started, I'm going to find a different, language-related task for this project.

However, I did find an interesting question: how does the model treat numbers written as digits (e.g. 1, 2, 3) versus numbers written in English (e.g. one, two, three)? It could be worth analyzing how the model behaves differently when the only change in the prompt is how the number is represented. After experimenting for a few minutes, I see slight differences in the outputs.

In [125]:
example_prompt = "A mouse is small, but an elephant is"
example_answer = "big"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'A', ' mouse', ' is', ' small', ',', ' but', ' an', ' elephant', ' is']
Tokenized answer: [' big']


Top 0th token. Logit: 18.63 Prob: 26.38% Token: | big|
Top 1th token. Logit: 18.14 Prob: 16.10% Token: | large|
Top 2th token. Logit: 17.25 Prob:  6.64% Token: | huge|
Top 3th token. Logit: 17.23 Prob:  6.45% Token: | bigger|
Top 4th token. Logit: 17.00 Prob:  5.18% Token: | larger|
Top 5th token. Logit: 16.91 Prob:  4.69% Token: | a|
Top 6th token. Logit: 16.41 Prob:  2.87% Token: | small|
Top 7th token. Logit: 16.17 Prob:  2.25% Token: | much|
Top 8th token. Logit: 15.98 Prob:  1.85% Token: | very|
Top 9th token. Logit: 15.68 Prob:  1.38% Token: | tiny|


In [158]:
example_prompt = "A tree trunk is thick, but a sheet of paper is"
example_answer = "hard"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'A', ' tree', ' trunk', ' is', ' thick', ',', ' but', ' a', ' sheet', ' of', ' paper', ' is']
Tokenized answer: [' hard']


Top 0th token. Logit: 17.39 Prob: 11.62% Token: | not|
Top 1th token. Logit: 16.52 Prob:  4.85% Token: | thin|
Top 2th token. Logit: 16.40 Prob:  4.32% Token: | thick|
Top 3th token. Logit: 16.34 Prob:  4.05% Token: | a|
Top 4th token. Logit: 16.14 Prob:  3.32% Token: | just|
Top 5th token. Logit: 15.80 Prob:  2.36% Token: | still|
Top 6th token. Logit: 15.78 Prob:  2.31% Token: | the|
Top 7th token. Logit: 15.42 Prob:  1.62% Token: | thicker|
Top 8th token. Logit: 15.19 Prob:  1.28% Token: | very|
Top 9th token. Logit: 15.16 Prob:  1.25% Token: | much|


I am going to study the opposite-adjective task, where the model has to output the opposite of an adjective given in context. Example pairs are old vs. young and big vs. small.

Since we are fixing the output to one correct answer, I am going to use the logit difference metric and ignore synonyms of the answer that would also work (for instance, "new" might be output instead of "young", but to simplify the problem we won't count it as either correct or incorrect).

In [159]:
prompt_format = [
    "A tortoise is{}, but a kitten is",
    "A mouse is{}, but, an elephant is",
    "A desert is{}, but a rainforest is",
    "A pillow is{}, but, a brick is",
    "A rock is{}, but, a sponge is",
    "A rose is{}, but, a weed is",
]
# Adjective pairs in the format (adjective given in the prompt, its opposite),
# listed in the same order as prompt_format
names = [
    (" old", " young"),
    (" small", " big"),
    (" dry", " wet"),
    (" soft", " hard"),
    (" hard", " soft"),
    (" beautiful", " ugly"),
]
# List of prompts
prompts = []
# List of answers, in the format (correct, incorrect)
answers = []
# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
answer_tokens = []
for i in range(len(prompt_format)):
    # The opposite adjective is the correct answer; the adjective in the prompt is the incorrect one
    answers.append((names[i][1], names[i][0]))
    answer_tokens.append(
        (
            model.to_single_token(answers[-1][0]),
            model.to_single_token(answers[-1][1]),
        )
    )
    # Insert the *incorrect* answer into the prompt, so the correct answer is its opposite
    prompts.append(prompt_format[i].format(answers[-1][1]))
answer_tokens = torch.tensor(answer_tokens).to(device)
print(prompts)
print(answers)

['A tortoise is old, but a kitten is', 'A mouse is small, but, an elephant is', 'A desert is dry, but a rainforest is', 'A pillow is soft, but, a brick is', 'A rock is hard, but, a sponge is', 'A rose is beautiful, but, a weed is']
[(' young', ' old'), (' big', ' small'), (' wet', ' dry'), (' hard', ' soft'), (' soft', ' hard'), (' ugly', ' beautiful')]


In [161]:
for prompt in prompts:
    str_tokens = model.to_str_tokens(prompt)
    print("Prompt length:", len(str_tokens))
    print("Prompt as tokens:", str_tokens)

Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'A', ' tort', 'oise', ' is', ' old', ',', ' but', ' a', ' kitten', ' is']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'A', ' mouse', ' is', ' small', ',', ' but', ',', ' an', ' elephant', ' is']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'A', ' desert', ' is', ' dry', ',', ' but', ' a', ' rain', 'forest', ' is']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'A', ' pillow', ' is', ' soft', ',', ' but', ',', ' a', ' brick', ' is']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'A', ' rock', ' is', ' hard', ',', ' but', ',', ' a', ' sponge', ' is']
Prompt length: 11
Prompt as tokens: ['<|endoftext|>', 'A', ' rose', ' is', ' beautiful', ',', ' but', ',', ' a', ' weed', ' is']


In [162]:
tokens = model.to_tokens(prompts, prepend_bos=True)

# Run the model and cache all activations
original_logits, cache = model.run_with_cache(tokens)

In [163]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()


print(
    "Per prompt logit difference:",
    logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
    .detach()
    .cpu()
    .round(decimals=3),
)
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print(
    "Average logit difference:",
    round(logits_to_ave_logit_diff(original_logits, answer_tokens).item(), 3),
)

Per prompt logit difference: tensor([1.7240, 2.0330, 0.4730, 0.1930, 0.0750, 1.0100])
Average logit difference: 0.918


Hmm, this logit difference is not that large compared to the IOI example. I think part of the reason is the ambiguity of the opposite: for instance, the model puts some probability on "large" as well as "big".
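
For a sense of scale, a logit difference is the log of the probability ratio between the two answer tokens, because the softmax normalizer cancels. Below is a minimal sketch of that conversion using the per-prompt numbers printed above. The helper name `logit_diff_to_odds` is introduced here for illustration and is not part of TransformerLens.

```python
import math


def logit_diff_to_odds(logit_diff: float) -> float:
    """Convert a logit difference into the probability ratio p(correct) / p(incorrect)."""
    # softmax(l)_a / softmax(l)_b = exp(l_a - l_b), since the normalizer cancels.
    return math.exp(logit_diff)


# Per-prompt logit differences copied from the output above.
per_prompt_logit_diffs = [1.724, 2.033, 0.473, 0.193, 0.075, 1.010]
average_logit_diff = sum(per_prompt_logit_diffs) / len(per_prompt_logit_diffs)

print("Average logit diff:", round(average_logit_diff, 3))
print("Odds at the average logit diff:", round(logit_diff_to_odds(average_logit_diff), 2))
for diff in per_prompt_logit_diffs:
    print(f"logit diff {diff:.3f} -> correct is {logit_diff_to_odds(diff):.2f}x as likely")
```

So an average logit difference of 0.918 corresponds to the correct adjective being roughly 2.5 times as likely as the incorrect one, and for the pillow and rock prompts the two tokens are nearly tied. Note that the odds at the mean logit difference are not the same thing as the mean of the per-prompt odds.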

## Some general predictions as to what is happening.

I assume the most relevant tokens in the prompt for the first data point are "old", "but", "tortoise", and "kitten". The "but" seems very helpful for signaling that the opposite of "old" is wanted: when I replaced "but" with "then", the logit difference was significantly smaller.

The model needs to figure out what the previous adjective is and then predict its opposite. So maybe there is a head that attends to the noun-adjective relationship, and then some other component (an MLP?) that finds the opposite of the adjective. This is likely combined with a similar layer that processes both the noun we want to describe (e.g. "kitten") and the adjective to be flipped (e.g. "old").

In [164]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([6, 2, 768])
Logit difference directions shape: torch.Size([6, 768])


In [165]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

Final residual stream shape: torch.Size([6, 11, 768])
Calculated average logit diff: 1.218
Original logit difference: 0.918
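
The two numbers above differ by about 0.3, unlike in the IOI demo where they agree. One possible explanation, not verified on this run: with `fold_ln=True`, TransformerLens folds the final LayerNorm bias into the unembedding bias `b_U`, while `tokens_to_residual_directions` only returns columns of `W_U`, so the projection drops the term `b_U[correct] - b_U[incorrect]`. In IOI every name appears as both the correct and the incorrect answer, so that term averages out; the adjective pairs here are not symmetric. The toy NumPy sketch below uses random weights and made-up shapes (not the real model) to show that the gap between the two metrics is exactly that bias term.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_vocab, batch = 8, 20, 6

# Toy stand-ins for the LayerNorm-scaled final residual stream and the unembedding.
scaled_resid = rng.normal(size=(batch, d_model))
W_U = rng.normal(size=(d_model, d_vocab))
b_U = rng.normal(size=d_vocab)

# (correct, incorrect) token ids per prompt, chosen arbitrarily for the toy example.
answer_tokens = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [8, 7], [9, 10]])
correct, incorrect = answer_tokens[:, 0], answer_tokens[:, 1]

# "Original" metric: difference of the full logits, which include the bias.
logits = scaled_resid @ W_U + b_U
rows = np.arange(batch)
true_diff = (logits[rows, correct] - logits[rows, incorrect]).mean()

# "Calculated" metric: project onto W_U[:, correct] - W_U[:, incorrect] only.
directions = (W_U[:, correct] - W_U[:, incorrect]).T  # [batch, d_model]
projected_diff = np.einsum("bd,bd->b", scaled_resid, directions).mean()

# The gap between the two is the mean unembedding-bias difference.
bias_gap = (b_U[correct] - b_U[incorrect]).mean()
print(true_diff - projected_diff, bias_gap)
```

In the notebook, the analogous check would be to compare `(model.b_U[answer_tokens[:, 0]] - model.b_U[answer_tokens[:, 1]]).mean()` against the gap of about -0.3 between the original and the calculated value.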


Now, we'll take the logit lens approach to decompose where the logit differences are coming from.

In [166]:
def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    return einsum(
        "... batch d_model, batch d_model -> ...",
        scaled_residual_stack,
        logit_diff_directions,
    ) / len(prompts)

In [167]:
accumulated_residual, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
line(
    logit_lens_logit_diffs,
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    title="Logit Difference From Accumulated Residual Stream",
)

This plot shows that most of the performance comes from the attention layer in layer 9, as well as the last layer, which looks like layer 11.

More interestingly, the logit difference is negative up until layer 9, which means that up to that point the model is more likely to predict the incorrect answer, i.e. the same adjective that appears in the prompt. I haven't seen a graph like this that stays negative for so long, and I think it might have to do with my hypothesis that the model puts a lot of attention on the first noun/adjective pairing to learn that relationship.

Time to dig deeper.

In [168]:
per_layer_residual, labels = cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

This is similar to the IOI task, where attention makes the most difference in each layer, since here we also care a lot about moving the adjective information around. One hypothesis for the drop in performance in the earlier layers is that the attention is simply moving the incorrect adjective to the "is" token, and the MLP isn't able to process it well. In layer 9, however, maybe the attention is moving different information, such as both the noun "kitten" and the other adjective "old", so the MLP has better information to work out an opposite.

From this graph, I can better understand how the attention sets up the MLP to process information effectively.
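
The accumulated and per-layer plots are two views of the same linear decomposition: the residual stream at the final position is a sum of component outputs, so projecting each component onto the logit-difference direction gives per-component contributions, and their running sum gives the accumulated curve. Here is a toy NumPy sketch of that relationship. Random vectors stand in for the component outputs, and the LayerNorm scaling step is left out for simplicity; both are assumptions of this example, not the real cache.

```python
import numpy as np

rng = np.random.default_rng(0)
n_components, d_model = 5, 16  # e.g. embed, attn_0, mlp_0, attn_1, mlp_1

# Each row is what one component writes into the residual stream at the final token.
component_outputs = rng.normal(size=(n_components, d_model))
# Stand-in for W_U[:, correct] - W_U[:, incorrect].
logit_diff_direction = rng.normal(size=d_model)

# Per-component view (analogous to decompose_resid): project each output separately.
per_component = component_outputs @ logit_diff_direction

# Accumulated view (analogous to accumulated_resid): project the running sum of outputs.
accumulated_resid = np.cumsum(component_outputs, axis=0)
accumulated = accumulated_resid @ logit_diff_direction

# By linearity, the accumulated curve is the running sum of per-component contributions.
print(np.allclose(accumulated, np.cumsum(per_component)))
```

Read this way, a negative stretch in the accumulated curve means the contributions summed so far favor the adjective already in the prompt.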

In [169]:
per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(
    per_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)
imshow(
    per_head_logit_diffs,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
)

Tried to stack head results when they weren't cached. Computing head results now


Here we can see that the only heads that matter in a large way are L9H7 and L10H7, plus L10H1 in the negative direction. This suggests that the drop in the earlier layers is due to a combination of multiple heads that each have a relatively small logit difference effect of about -0.1, especially in layer 8.

Let's dig deeper and do some attention head analysis.

In [170]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
) -> str:
    # If a single head is given, convert to a list
    if isinstance(heads, int):
        heads = [heads]

    # Create the plotting data
    labels: List[str] = []
    patterns: List[Float[torch.Tensor, "dest_pos src_pos"]] = []

    # Assume we have a single batch item
    batch_index = 0

    for head in heads:
        # Set the label
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        labels.append(f"L{layer}H{head_index}")

        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])

    # Convert the tokens to strings (for the axis labels)
    str_tokens = model.to_str_tokens(local_tokens)

    # Combine the patterns into a single tensor
    patterns: Float[torch.Tensor, "head_index dest_pos src_pos"] = torch.stack(
        patterns, dim=0
    )

    # Circuitsvis Plot (note we get the code version so we can concatenate with the title)
    plot = attention_heads(
        attention=patterns, tokens=str_tokens, attention_head_names=labels
    ).show_code()

    # Display the title
    title_html = f"<h2>{title}</h2><br/>"

    # Return the visualisation as raw code
    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"

In [171]:
top_k = 3

top_positive_logit_attr_heads = torch.topk(
    per_head_logit_diffs.flatten(), k=top_k
).indices

positive_html = visualize_attention_patterns(
    top_positive_logit_attr_heads,
    cache,
    tokens[0],
    f"Top {top_k} Positive Logit Attribution Heads",
)

top_negative_logit_attr_heads = torch.topk(
    -per_head_logit_diffs.flatten(), k=top_k
).indices

negative_html = visualize_attention_patterns(
    top_negative_logit_attr_heads,
    cache,
    tokens[0],
    title=f"Top {top_k} Negative Logit Attribution Heads",
)

HTML(positive_html + negative_html)
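
`torch.topk` is applied to the flattened `[layer, head]` grid, so the indices it returns are flat positions `layer * n_heads + head`, which `visualize_attention_patterns` decodes with `//` and `%`. Below is a small sketch of that mapping in plain Python, assuming GPT-2 small's 12 heads per layer. The helper names `head_label` and `flat_index` are hypothetical and exist only for this example.

```python
N_HEADS = 12  # heads per layer in GPT-2 small


def head_label(flat_index: int, n_heads: int = N_HEADS) -> str:
    """Decode a flat index over a [layer, head] grid into an 'L{layer}H{head}' label."""
    layer, head = divmod(flat_index, n_heads)
    return f"L{layer}H{head}"


def flat_index(layer: int, head: int, n_heads: int = N_HEADS) -> int:
    """Inverse mapping: row-major position of (layer, head) in the flattened grid."""
    return layer * n_heads + head


# The heads named in the text above, round-tripped through the flat index.
for layer, head in [(9, 7), (10, 7), (10, 1)]:
    idx = flat_index(layer, head)
    print(idx, head_label(idx))
```

This is handy when reading raw `topk` indices directly instead of going through the visualization helper.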

## Note

It's at this point that I'm wondering how solid a task this is to study, because the model has to introduce new information that isn't in the prompt. My intuition was that there could be an algorithm that combines some form of lookup with something that finds the opposite of the adjective in the prompt (the "incorrect answer") and something related to the second noun ("kitten").

Taking a look at these heads, I see that in the most impactful head, "but", "a", and "is" attend to "old". This seems to act as the opposite head, given how strongly "but" attends to "old".

The negative heads, specifically L10H1 and L9H6, seem to have the "a" before "kitten" attend to the first noun ("tortoise"). These heads seem mainly focused on relating the two nouns, but I am confused about why the attention comes from "a" rather than from "kitten" here.