In [1]:
%pip install transformer_lens
%pip install circuitsvis
# Install a faster Node version
!curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs  # noqa

Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Password:
sudo: a password is required


In [1]:
from functools import partial
from typing import List, Optional, Union

import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from IPython.display import HTML, IFrame
from jaxtyping import Float

import transformer_lens.utils as utils
from transformer_lens import ActivationCache, HookedTransformer

In [2]:
def imshow(tensor, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()


def line(tensor, **kwargs):
    px.line(
        y=utils.to_numpy(tensor),
        **kwargs,
    ).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y,
        x=x,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    ).show()

In [3]:
torch.set_grad_enabled(False)
print("Disabled automatic differentiation")

Disabled automatic differentiation


In [4]:
# NBVAL_IGNORE_OUTPUT
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

# Get the default device used
device: torch.device = utils.get_device()

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [5]:
example_prompt = "Graphics Processing Unit (GPU) Thermal Design Power"
example_answer = " (TDP)"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Graphics', ' Processing', ' Unit', ' (', 'GPU', ')', ' Thermal', ' Design', ' Power']
Tokenized answer: [' (', 'T', 'DP', ')']


Top 0th token. Logit: 18.62 Prob: 60.69% Token: | (|
Top 1th token. Logit: 17.65 Prob: 23.14% Token: | Consumption|
Top 2th token. Logit: 14.49 Prob:  0.99% Token: | Management|
Top 3th token. Logit: 14.43 Prob:  0.93% Token: |:|
Top 4th token. Logit: 14.34 Prob:  0.84% Token: | Supply|
Top 5th token. Logit: 14.01 Prob:  0.61% Token: | Generation|
Top 6th token. Logit: 13.94 Prob:  0.57% Token: | Usage|
Top 7th token. Logit: 13.77 Prob:  0.48% Token: | Level|
Top 8th token. Logit: 13.39 Prob:  0.33% Token: | Architecture|
Top 9th token. Logit: 13.34 Prob:  0.31% Token: |
|


Top 0th token. Logit: 20.30 Prob: 58.46% Token: |T|
Top 1th token. Logit: 19.01 Prob: 16.10% Token: |W|
Top 2th token. Logit: 18.45 Prob:  9.13% Token: |TC|
Top 3th token. Logit: 17.25 Prob:  2.77% Token: |Ther|
Top 4th token. Logit: 16.92 Prob:  1.99% Token: |WH|
Top 5th token. Logit: 16.36 Prob:  1.14% Token: |GPU|
Top 6th token. Logit: 15.71 Prob:  0.59% Token: |tu|
Top 7th token. Logit: 15.40 Prob:  0.43% Token: |CPU|
Top 8th token. Logit: 15.26 Prob:  0.38% Token: | watts|
Top 9th token. Logit: 15.24 Prob:  0.37% Token: |w|


Top 0th token. Logit: 28.60 Prob: 99.93% Token: |DP|
Top 1th token. Logit: 19.54 Prob:  0.01% Token: |PD|
Top 2th token. Logit: 19.51 Prob:  0.01% Token: |PU|
Top 3th token. Logit: 19.20 Prob:  0.01% Token: |d|
Top 4th token. Logit: 18.82 Prob:  0.01% Token: |DR|
Top 5th token. Logit: 18.18 Prob:  0.00% Token: |DD|
Top 6th token. Logit: 18.01 Prob:  0.00% Token: |dp|
Top 7th token. Logit: 17.94 Prob:  0.00% Token: |DC|
Top 8th token. Logit: 17.72 Prob:  0.00% Token: |DF|
Top 9th token. Logit: 17.71 Prob:  0.00% Token: |DB|


Top 0th token. Logit: 20.67 Prob: 94.10% Token: |)|
Top 1th token. Logit: 16.58 Prob:  1.57% Token: |),|
Top 2th token. Logit: 16.28 Prob:  1.17% Token: |):|
Top 3th token. Logit: 15.55 Prob:  0.56% Token: |/|
Top 4th token. Logit: 15.03 Prob:  0.33% Token: |).|
Top 5th token. Logit: 14.80 Prob:  0.26% Token: |)/|
Top 6th token. Logit: 14.61 Prob:  0.22% Token: | )|
Top 7th token. Logit: 14.52 Prob:  0.20% Token: |)-|
Top 8th token. Logit: 14.32 Prob:  0.16% Token: |-|
Top 9th token. Logit: 14.15 Prob:  0.14% Token: |)(|


In [6]:
example_prompt = "Marvel Cinematic Universe (MCU) Intensive Care Unit"
example_answer = " (ICU)"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'Marvel', ' Cinem', 'atic', ' Universe', ' (', 'MC', 'U', ')', ' Int', 'ensive', ' Care', ' Unit']
Tokenized answer: [' (', 'IC', 'U', ')']


Top 0th token. Logit: 14.94 Prob: 46.14% Token: | (|
Top 1th token. Logit: 13.31 Prob:  9.09% Token: |
|
Top 2th token. Logit: 12.59 Prob:  4.40% Token: |,|
Top 3th token. Logit: 11.59 Prob:  1.62% Token: |.|
Top 4th token. Logit: 11.53 Prob:  1.53% Token: | -|
Top 5th token. Logit: 11.11 Prob:  1.00% Token: |:|
Top 6th token. Logit: 10.97 Prob:  0.88% Token: | is|
Top 7th token. Logit: 10.68 Prob:  0.65% Token: | for|
Top 8th token. Logit: 10.50 Prob:  0.55% Token: | at|
Top 9th token. Logit: 10.45 Prob:  0.52% Token: | in|


Top 0th token. Logit: 18.13 Prob: 74.63% Token: |IC|
Top 1th token. Logit: 16.01 Prob:  8.92% Token: |I|
Top 2th token. Logit: 14.50 Prob:  1.97% Token: |ICE|
Top 3th token. Logit: 13.98 Prob:  1.17% Token: |IR|
Top 4th token. Logit: 13.53 Prob:  0.75% Token: |AC|
Top 5th token. Logit: 13.40 Prob:  0.66% Token: |IV|
Top 6th token. Logit: 13.38 Prob:  0.64% Token: |IM|
Top 7th token. Logit: 13.25 Prob:  0.57% Token: |ICS|
Top 8th token. Logit: 13.16 Prob:  0.52% Token: |ICES|
Top 9th token. Logit: 13.16 Prob:  0.52% Token: |ICA|


Top 0th token. Logit: 25.64 Prob: 99.09% Token: |U|
Top 1th token. Logit: 20.06 Prob:  0.37% Token: |Us|
Top 2th token. Logit: 18.17 Prob:  0.06% Token: |SU|
Top 3th token. Logit: 18.04 Prob:  0.05% Token: |V|
Top 4th token. Logit: 18.01 Prob:  0.05% Token: |MU|
Top 5th token. Logit: 17.75 Prob:  0.04% Token: |UP|
Top 6th token. Logit: 17.63 Prob:  0.03% Token: |US|
Top 7th token. Logit: 17.42 Prob:  0.03% Token: |UF|
Top 8th token. Logit: 17.30 Prob:  0.02% Token: |UC|
Top 9th token. Logit: 17.16 Prob:  0.02% Token: |UB|


Top 0th token. Logit: 18.30 Prob: 85.05% Token: |)|
Top 1th token. Logit: 15.94 Prob:  7.99% Token: |),|
Top 2th token. Logit: 14.17 Prob:  1.37% Token: |-|
Top 3th token. Logit: 14.01 Prob:  1.16% Token: |).|
Top 4th token. Logit: 13.70 Prob:  0.85% Token: |):|
Top 5th token. Logit: 12.93 Prob:  0.39% Token: |);|
Top 6th token. Logit: 12.77 Prob:  0.34% Token: |)-|
Top 7th token. Logit: 12.58 Prob:  0.28% Token: |,|
Top 8th token. Logit: 12.31 Prob:  0.21% Token: |)/|
Top 9th token. Logit: 12.28 Prob:  0.21% Token: |/|


In [7]:
example_prompt = "Tensor Processing Unit (TPU) Central Processing Unit"
example_answer = " (CPU)"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'T', 'ensor', ' Processing', ' Unit', ' (', 'T', 'PU', ')', ' Central', ' Processing', ' Unit']
Tokenized answer: [' (', 'CPU', ')']


Top 0th token. Logit: 16.59 Prob: 83.20% Token: | (|
Top 1th token. Logit: 12.64 Prob:  1.60% Token: |
|
Top 2th token. Logit: 12.51 Prob:  1.41% Token: |,|
Top 3th token. Logit: 11.96 Prob:  0.82% Token: | is|
Top 4th token. Logit: 11.96 Prob:  0.81% Token: |.|
Top 5th token. Logit: 11.92 Prob:  0.78% Token: | for|
Top 6th token. Logit: 10.89 Prob:  0.28% Token: | -|
Top 7th token. Logit: 10.87 Prob:  0.27% Token: | in|
Top 8th token. Logit: 10.78 Prob:  0.25% Token: |:|
Top 9th token. Logit: 10.76 Prob:  0.25% Token: | or|


Top 0th token. Logit: 18.71 Prob: 48.12% Token: |CP|
Top 1th token. Logit: 18.22 Prob: 29.32% Token: |CPU|
Top 2th token. Logit: 16.48 Prob:  5.17% Token: |C|
Top 3th token. Logit: 16.39 Prob:  4.71% Token: |CM|
Top 4th token. Logit: 15.77 Prob:  2.53% Token: |PC|
Top 5th token. Logit: 15.15 Prob:  1.36% Token: |TC|
Top 6th token. Logit: 14.93 Prob:  1.10% Token: |P|
Top 7th token. Logit: 14.92 Prob:  1.09% Token: |T|
Top 8th token. Logit: 14.87 Prob:  1.03% Token: |CT|
Top 9th token. Logit: 14.00 Prob:  0.43% Token: |CC|


Top 0th token. Logit: 18.93 Prob: 57.86% Token: |)|
Top 1th token. Logit: 17.82 Prob: 19.16% Token: |U|
Top 2th token. Logit: 17.38 Prob: 12.26% Token: | Unit|
Top 3th token. Logit: 16.06 Prob:  3.30% Token: |),|
Top 4th token. Logit: 14.59 Prob:  0.76% Token: | unit|
Top 5th token. Logit: 14.48 Prob:  0.68% Token: |):|
Top 6th token. Logit: 14.29 Prob:  0.56% Token: |Unit|
Top 7th token. Logit: 14.28 Prob:  0.55% Token: |).|
Top 8th token. Logit: 13.99 Prob:  0.41% Token: |)-|
Top 9th token. Logit: 13.52 Prob:  0.26% Token: |-|


In [151]:
prompts = ['Graphics Processing Unit (GPU) Central Processing Unit',
                     'Central Processing Unit (CPU) Graphics Processing Unit',
                     'Marvel Cinematic Universe (MCU) Intensive Care Unit',
                     'Intensive Care Unit (ICU) Marvel Cinematic Universe',
                     'Tensor Processing Unit (TPU) Thermal Design Power',
                     'Thermal Design Power (TDP) Tensor Processing Unit']

answers = [" (CPU)", " (GPU)", " (ICU)", " (MCU)", " (TDP)", " (TPU)"]
# List of the token (ie an integer) corresponding to each answer, in the format (correct_token, incorrect_token)
answer_tokens = []
for i in range(len(prompts)):
    answer_tokens.append(model.to_tokens(answers[i]))
answer_tokens[0] = torch.cat([torch.tensor(50256).reshape(1, 1), answer_tokens[0]], dim=1)
answer_tokens[1] = torch.cat([torch.tensor(50256).reshape(1, 1), answer_tokens[1]], dim=1)
answer_tokens = torch.cat(answer_tokens).to(device)
print(prompts)
print(answers)


['Graphics Processing Unit (GPU) Central Processing Unit', 'Central Processing Unit (CPU) Graphics Processing Unit', 'Marvel Cinematic Universe (MCU) Intensive Care Unit', 'Intensive Care Unit (ICU) Marvel Cinematic Universe', 'Tensor Processing Unit (TPU) Thermal Design Power', 'Thermal Design Power (TDP) Tensor Processing Unit']
[' (CPU)', ' (GPU)', ' (ICU)', ' (MCU)', ' (TDP)', ' (TPU)']


In [90]:
model.to_tokens(answers[3])

tensor([[50256,   357,  9655,    52,     8]])

In [152]:
answer_tokens

tensor([[50256, 50256,   357, 36037,     8],
        [50256, 50256,   357, 33346,     8],
        [50256,   357,  2149,    52,     8],
        [50256,   357,  9655,    52,     8],
        [50256,   357,    51,  6322,     8],
        [50256,   357,    51,  5105,     8]])

In [153]:
for prompt in prompts:
    str_tokens = model.to_str_tokens(prompt)
    print("Prompt length:", len(str_tokens))
    print("Prompt as tokens:", str_tokens)

Prompt length: 10
Prompt as tokens: ['<|endoftext|>', 'Graphics', ' Processing', ' Unit', ' (', 'GPU', ')', ' Central', ' Processing', ' Unit']
Prompt length: 10
Prompt as tokens: ['<|endoftext|>', 'Central', ' Processing', ' Unit', ' (', 'CPU', ')', ' Graphics', ' Processing', ' Unit']
Prompt length: 13
Prompt as tokens: ['<|endoftext|>', 'Marvel', ' Cinem', 'atic', ' Universe', ' (', 'MC', 'U', ')', ' Int', 'ensive', ' Care', ' Unit']
Prompt length: 13
Prompt as tokens: ['<|endoftext|>', 'Int', 'ensive', ' Care', ' Unit', ' (', 'IC', 'U', ')', ' Marvel', ' Cinem', 'atic', ' Universe']
Prompt length: 12
Prompt as tokens: ['<|endoftext|>', 'T', 'ensor', ' Processing', ' Unit', ' (', 'T', 'PU', ')', ' Thermal', ' Design', ' Power']
Prompt length: 13
Prompt as tokens: ['<|endoftext|>', 'Ther', 'mal', ' Design', ' Power', ' (', 'T', 'DP', ')', ' T', 'ensor', ' Processing', ' Unit']


In [154]:
model.tokenizer.padding_side = 'left'
tokens = model.to_tokens(prompts, prepend_bos=True)

# Run the model and cache all activations
original_logits, cache = model.run_with_cache(tokens)

In [155]:
# tokens = tokens[0][1:]

In [156]:
original_logits.shape

torch.Size([6, 13, 50257])

In [162]:
def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=False):
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    if per_prompt:
        return answer_logit_diff
    else:
        return answer_logit_diff.mean()


print(
    "Per prompt logit difference:",
    logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
    .detach()
    .cpu()
    .round(decimals=3),
)
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print(
    "Average logit difference:",
    round(logits_to_ave_logit_diff(original_logits, answer_tokens).item(), 3),
)

Per prompt logit difference: tensor([ 0.0000,  0.0000, -6.3000, -5.8250, -7.0930, -9.6010])
Average logit difference: -4.803


In [163]:
answer_tokens

tensor([[50256, 50256,   357, 36037,     8],
        [50256, 50256,   357, 33346,     8],
        [50256,   357,  2149,    52,     8],
        [50256,   357,  9655,    52,     8],
        [50256,   357,    51,  6322,     8],
        [50256,   357,    51,  5105,     8]])

In [164]:
answer_residual_directions = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([6, 5, 768])
Logit difference directions shape: torch.Size([6, 768])


In [165]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)
print("Calculated average logit diff:", round(average_logit_diff.item(), 3))
print("Original logit difference:", round(original_average_logit_diff.item(), 3))

Final residual stream shape: torch.Size([6, 13, 768])
Calculated average logit diff: -3.565
Original logit difference: -4.803


In [166]:
def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    return einsum(
        "... batch d_model, batch d_model -> ...",
        scaled_residual_stack,
        logit_diff_directions,
    ) / len(prompts)

In [167]:
accumulated_residual, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
line(
    logit_lens_logit_diffs,
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    title="Logit Difference From Accumulate Residual Stream",
)

In [168]:
logit_lens_logit_diffs

tensor([-2.2160e-03, -1.5325e-01, -2.1128e-01, -3.2699e-01, -2.8078e-01,
        -4.1179e-01, -3.8640e-01, -4.7184e-01, -5.2639e-01, -5.9478e-01,
        -6.4991e-01, -8.3307e-01, -9.0275e-01, -9.9757e-01, -1.1557e+00,
        -1.3785e+00, -2.2711e+00, -2.5055e+00, -2.7035e+00, -3.1200e+00,
        -3.3113e+00, -3.4565e+00, -3.4582e+00, -3.8678e+00, -3.5653e+00])

In [194]:
per_layer_residual, labels = cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

In [170]:
per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
per_head_logit_diffs = einops.rearrange(
    per_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)
imshow(
    per_head_logit_diffs,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
)

Tried to stack head results when they weren't cached. Computing head results now


In [171]:
def visualize_attention_patterns(
    heads: Union[List[int], int, Float[torch.Tensor, "heads"]],
    local_cache: ActivationCache,
    local_tokens: torch.Tensor,
    title: Optional[str] = "",
    max_width: Optional[int] = 700,
) -> str:
    # If a single head is given, convert to a list
    if isinstance(heads, int):
        heads = [heads]

    # Create the plotting data
    labels: List[str] = []
    patterns: List[Float[torch.Tensor, "dest_pos src_pos"]] = []

    # Assume we have a single batch item
    batch_index = 0

    for head in heads:
        # Set the label
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        labels.append(f"L{layer}H{head_index}")

        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])

    # Convert the tokens to strings (for the axis labels)
    str_tokens = model.to_str_tokens(local_tokens)

    # Combine the patterns into a single tensor
    patterns: Float[torch.Tensor, "head_index dest_pos src_pos"] = torch.stack(
        patterns, dim=0
    )

    # Circuitsvis Plot (note we get the code version so we can concatenate with the title)
    plot = attention_heads(
        attention=patterns, tokens=str_tokens, attention_head_names=labels
    ).show_code()

    # Display the title
    title_html = f"<h2>{title}</h2><br/>"

    # Return the visualisation as raw code
    return f"<div style='max-width: {str(max_width)}px;'>{title_html + plot}</div>"

In [172]:
per_head_logit_diffs.shape

torch.Size([12, 12])

In [173]:
top_k = 3

top_positive_logit_attr_heads = torch.topk(
    per_head_logit_diffs.flatten(), k=top_k
).indices

positive_html = visualize_attention_patterns(
    top_positive_logit_attr_heads,
    cache,
    tokens[0],
    f"Top {top_k} Positive Logit Attribution Heads",
)

top_negative_logit_attr_heads = torch.topk(
    -per_head_logit_diffs.flatten(), k=top_k
).indices

negative_html = visualize_attention_patterns(
    top_negative_logit_attr_heads,
    cache,
    tokens[0],
    title=f"Top {top_k} Negative Logit Attribution Heads",
)

HTML(positive_html + negative_html)

In [195]:
top_k = 3

top_positive_logit_attr_heads = torch.topk(
    per_head_logit_diffs.flatten(), k=top_k
).indices

positive_html = visualize_attention_patterns(
    top_positive_logit_attr_heads,
    cache,
    tokens[2],
    f"Top {top_k} Positive Logit Attribution Heads",
)

top_negative_logit_attr_heads = torch.topk(
    -per_head_logit_diffs.flatten(), k=top_k
).indices

negative_html = visualize_attention_patterns(
    top_negative_logit_attr_heads,
    cache,
    tokens[2],
    title=f"Top {top_k} Negative Logit Attribution Heads",
)

HTML(positive_html + negative_html)

In [175]:
top_positive_logit_attr_heads

tensor([136,  55, 129])

In [176]:
top_k = 10
top_heads_by_output_patch = torch.topk(
    per_head_logit_diffs.abs().flatten(), k=top_k
).indices
first_mid_layer = 7
first_late_layer = 9
early_heads = top_heads_by_output_patch[
    top_heads_by_output_patch < model.cfg.n_heads * first_mid_layer
]
mid_heads = top_heads_by_output_patch[
    torch.logical_and(
        model.cfg.n_heads * first_mid_layer <= top_heads_by_output_patch,
        top_heads_by_output_patch < model.cfg.n_heads * first_late_layer,
    )
]
late_heads = top_heads_by_output_patch[
    model.cfg.n_heads * first_late_layer <= top_heads_by_output_patch
]

early = visualize_attention_patterns(
    early_heads, cache, tokens[0], title=f"Top Early Heads"
)
mid = visualize_attention_patterns(
    mid_heads, cache, tokens[0], title=f"Top Middle Heads"
)
late = visualize_attention_patterns(
    late_heads, cache, tokens[0], title=f"Top Late Heads"
)

HTML(early + mid + late)

In [177]:
corrupted_prompts = []
for i in range(0, len(prompts), 2):
    corrupted_prompts.append(prompts[i + 1])
    corrupted_prompts.append(prompts[i])
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = model.run_with_cache(
    corrupted_tokens, return_type="logits"
)
corrupted_average_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print("Corrupted Average Logit Diff", round(corrupted_average_logit_diff.item(), 2))
print("Clean Average Logit Diff", round(original_average_logit_diff.item(), 2))

Corrupted Average Logit Diff -4.8
Clean Average Logit Diff -4.8


In [178]:
corrupted_prompts

['Central Processing Unit (CPU) Graphics Processing Unit',
 'Graphics Processing Unit (GPU) Central Processing Unit',
 'Intensive Care Unit (ICU) Marvel Cinematic Universe',
 'Marvel Cinematic Universe (MCU) Intensive Care Unit',
 'Thermal Design Power (TDP) Tensor Processing Unit',
 'Tensor Processing Unit (TPU) Thermal Design Power']

In [179]:
answers

[' (CPU)', ' (GPU)', ' (ICU)', ' (MCU)', ' (TDP)', ' (TPU)']

In [180]:
tokens.shape[1]

13

In [181]:
from tqdm import trange

In [198]:
def patch_residual_component(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    clean_cache,
):
    corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_residual_component


def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise
    # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    return (patched_logit_diff - corrupted_average_logit_diff) / (
        original_average_logit_diff - corrupted_average_logit_diff
    )


patched_residual_stream_diff = torch.zeros(
    model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    print(layer)
    for position in trange(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("resid_pre", layer), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_residual_stream_diff[layer, position] = normalize_patched_logit_diff(
            patched_logit_diff
        )

0


100%|██████████| 13/13 [00:03<00:00,  4.28it/s]


1


100%|██████████| 13/13 [00:01<00:00,  9.55it/s]


2


100%|██████████| 13/13 [00:01<00:00,  9.89it/s]


3


100%|██████████| 13/13 [00:01<00:00,  8.79it/s]


4


100%|██████████| 13/13 [00:02<00:00,  6.03it/s]


5


100%|██████████| 13/13 [00:01<00:00,  8.28it/s]


6


100%|██████████| 13/13 [00:01<00:00,  9.57it/s]


7


100%|██████████| 13/13 [00:02<00:00,  5.64it/s]


8


100%|██████████| 13/13 [00:01<00:00,  9.31it/s]


9


100%|██████████| 13/13 [00:01<00:00,  9.30it/s]


10


100%|██████████| 13/13 [00:01<00:00,  8.97it/s]


11


100%|██████████| 13/13 [00:01<00:00,  9.53it/s]


In [199]:
prompt_position_labels = [
    f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(tokens[0]))
]
imshow(
    patched_residual_stream_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Residual Stream",
    labels={"x": "Position", "y": "Layer"},
)

In [202]:
patched_residual_stream_diff

tensor([[ 3.3270e+05,  4.7066e+05, -1.2063e+05, -1.7095e+05,  7.7157e+05,
         -8.9904e+05, -6.9152e+05, -4.7800e+05,  1.5626e+05, -3.9235e+05,
         -3.8463e+05, -1.1805e+06,  1.0797e+06],
        [-8.7870e+03, -1.2326e+06, -1.0061e+04, -3.2515e+05,  7.2304e+05,
         -4.4773e+05, -1.0538e+06, -5.2583e+05,  9.6378e+04, -6.3585e+05,
         -1.9828e+05, -1.2517e+06,  1.0430e+06],
        [-4.5043e+04, -1.5677e+06, -3.7690e+05, -3.3704e+05,  7.8790e+05,
         -2.0993e+05, -6.9971e+05,  6.9005e+04,  6.9241e+04, -4.9787e+05,
          3.4453e+04, -1.2280e+06,  1.1007e+06],
        [-4.7094e+05, -1.8602e+06,  2.8530e+04, -2.1974e+04,  7.8844e+05,
          9.6728e+04,  1.5675e+04,  5.3482e+05, -1.8675e+05, -9.7673e+04,
          5.5739e+05, -8.6741e+05,  1.2332e+06],
        [-4.5400e+05, -2.0782e+06,  1.0065e+05,  1.6831e+05,  6.8823e+05,
          4.2365e+05, -2.3045e+05,  4.7815e+05,  5.1781e+04, -4.8738e+04,
          9.2049e+05, -8.5667e+05,  1.1097e+06],
        [-3.918

In [184]:
patched_attn_diff = torch.zeros(
    model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
patched_mlp_diff = torch.zeros(
    model.cfg.n_layers, tokens.shape[1], device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for position in trange(tokens.shape[1]):
        hook_fn = partial(patch_residual_component, pos=position, clean_cache=cache)
        patched_attn_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("attn_out", layer), hook_fn)],
            return_type="logits",
        )
        patched_attn_logit_diff = logits_to_ave_logit_diff(
            patched_attn_logits, answer_tokens
        )
        patched_mlp_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("mlp_out", layer), hook_fn)],
            return_type="logits",
        )
        patched_mlp_logit_diff = logits_to_ave_logit_diff(
            patched_mlp_logits, answer_tokens
        )

        patched_attn_diff[layer, position] = normalize_patched_logit_diff(
            patched_attn_logit_diff
        )
        patched_mlp_diff[layer, position] = normalize_patched_logit_diff(
            patched_mlp_logit_diff
        )

100%|██████████| 13/13 [00:02<00:00,  4.77it/s]
100%|██████████| 13/13 [00:02<00:00,  5.39it/s]
100%|██████████| 13/13 [00:02<00:00,  5.34it/s]
100%|██████████| 13/13 [00:02<00:00,  5.26it/s]
100%|██████████| 13/13 [00:02<00:00,  5.19it/s]
100%|██████████| 13/13 [00:02<00:00,  5.14it/s]
100%|██████████| 13/13 [00:02<00:00,  5.16it/s]
100%|██████████| 13/13 [00:02<00:00,  5.24it/s]
100%|██████████| 13/13 [00:02<00:00,  5.05it/s]
100%|██████████| 13/13 [00:02<00:00,  4.85it/s]
100%|██████████| 13/13 [00:02<00:00,  4.95it/s]
100%|██████████| 13/13 [00:02<00:00,  5.27it/s]


In [185]:
imshow(
    patched_attn_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched Attention Layer",
    labels={"x": "Position", "y": "Layer"},
)

In [186]:
imshow(
    patched_mlp_diff,
    x=prompt_position_labels,
    title="Logit Difference From Patched MLP Layer",
    labels={"x": "Position", "y": "Layer"},
)

In [187]:
def patch_head_vector(
    corrupted_head_vector: Float[torch.Tensor, "batch pos head_index d_head"],
    hook,
    head_index,
    clean_cache,
):
    corrupted_head_vector[:, :, head_index, :] = clean_cache[hook.name][
        :, :, head_index, :
    ]
    return corrupted_head_vector


patched_head_z_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for head_index in trange(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("z", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_z_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

100%|██████████| 12/12 [00:01<00:00, 10.12it/s]
100%|██████████| 12/12 [00:01<00:00, 10.72it/s]
100%|██████████| 12/12 [00:01<00:00, 10.76it/s]
100%|██████████| 12/12 [00:01<00:00, 10.31it/s]
100%|██████████| 12/12 [00:01<00:00,  8.86it/s]
100%|██████████| 12/12 [00:01<00:00,  8.39it/s]
100%|██████████| 12/12 [00:01<00:00,  9.22it/s]
100%|██████████| 12/12 [00:01<00:00, 10.60it/s]
100%|██████████| 12/12 [00:01<00:00, 10.10it/s]
100%|██████████| 12/12 [00:01<00:00, 10.18it/s]
100%|██████████| 12/12 [00:01<00:00, 10.65it/s]
100%|██████████| 12/12 [00:01<00:00, 10.18it/s]


In [188]:
imshow(
    patched_head_z_diff,
    title="Logit Difference From Patched Head Output",
    labels={"x": "Head", "y": "Layer"},
)

In [189]:
patched_head_v_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_vector, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("v", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_v_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [190]:
imshow(
    patched_head_v_diff,
    title="Logit Difference From Patched Head Value",
    labels={"x": "Head", "y": "Layer"},
)

In [196]:
head_labels = [
    f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)
]
scatter(
    x=utils.to_numpy(patched_head_v_diff.flatten()),
    y=utils.to_numpy(patched_head_z_diff.flatten()),
    xaxis="Value Patch",
    yaxis="Output Patch",
    caxis="Layer",
    hover_name=head_labels,
    color=einops.repeat(
        np.arange(model.cfg.n_layers), "layer -> (layer head)", head=model.cfg.n_heads
    ),
    range_x=(-0.5, 0.5),
    range_y=(-0.5, 0.5),
    title="Scatter plot of output patching vs value patching",
)

In [192]:
def patch_head_pattern(
    corrupted_head_pattern: Float[torch.Tensor, "batch head_index query_pos d_head"],
    hook,
    head_index,
    clean_cache,
):
    corrupted_head_pattern[:, head_index, :, :] = clean_cache[hook.name][
        :, head_index, :, :
    ]
    return corrupted_head_pattern


patched_head_attn_diff = torch.zeros(
    model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=torch.float32
)
for layer in range(model.cfg.n_layers):
    for head_index in range(model.cfg.n_heads):
        hook_fn = partial(patch_head_pattern, head_index=head_index, clean_cache=cache)
        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(utils.get_act_name("attn", layer, "attn"), hook_fn)],
            return_type="logits",
        )
        patched_logit_diff = logits_to_ave_logit_diff(patched_logits, answer_tokens)

        patched_head_attn_diff[layer, head_index] = normalize_patched_logit_diff(
            patched_logit_diff
        )

In [193]:
imshow(
    patched_head_attn_diff,
    title="Logit Difference From Patched Head Pattern",
    labels={"x": "Head", "y": "Layer"},
)
head_labels = [
    f"L{l}H{h}" for l in range(model.cfg.n_layers) for h in range(model.cfg.n_heads)
]
scatter(
    x=utils.to_numpy(patched_head_attn_diff.flatten()),
    y=utils.to_numpy(patched_head_z_diff.flatten()),
    hover_name=head_labels,
    xaxis="Attention Patch",
    yaxis="Output Patch",
    title="Scatter plot of output patching vs attention patching",
)