## Model & Task Setup

Import

In [105]:
import os; os.environ["ACCELERATE_DISABLE_RICH"] = "1"
import sys
from pathlib import Path
import torch as t
from torch import Tensor
import numpy as np
import einops
from tqdm.notebook import tqdm
import plotly.express as px
import webbrowser
import re
import itertools
from jaxtyping import Float, Int, Bool
from typing import List, Optional, Callable, Tuple, Dict, Literal, Set
from functools import partial
from IPython.display import display, HTML
from rich.table import Table, Column
from rich import print as rprint
import circuitsvis as cv
from pathlib import Path
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, ActivationCache
from transformer_lens.components import Embed, Unembed, LayerNorm, MLP
import functools

t.set_grad_enabled(False)

# Make sure exercises are in the path
chapter = r"chapter1_transformers"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = (exercises_dir / "part3_indirect_object_identification").resolve()
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

from plotly_utils import imshow, line, scatter, bar
import part3_indirect_object_identification.tests as tests

device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

MAIN = __name__ == "__main__"

Load model

In [106]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


In [107]:
# Here is where we test on a single prompt
# Result: 70% probability on Mary, as we expect

example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|
Top 1th token. Logit: 15.38 Prob:  4.67% Token: | the|
Top 2th token. Logit: 15.35 Prob:  4.54% Token: | John|
Top 3th token. Logit: 15.25 Prob:  4.11% Token: | them|
Top 4th token. Logit: 14.84 Prob:  2.73% Token: | his|
Top 5th token. Logit: 14.06 Prob:  1.24% Token: | her|
Top 6th token. Logit: 13.54 Prob:  0.74% Token: | a|
Top 7th token. Logit: 13.52 Prob:  0.73% Token: | their|
Top 8th token. Logit: 13.13 Prob:  0.49% Token: | Jesus|
Top 9th token. Logit: 12.97 Prob:  0.42% Token: | him|


In [108]:
prompt_format = [
    "When John and Mary went to the shops,{} gave the bag to",
    "When Tom and James went to the park,{} gave the ball to",
    "When Dan and Sid went to the shops,{} gave an apple to",
    "After Martin and Amy went to the park,{} gave a drink to",
]
name_pairs = [
    (" Mary", " John"),
    (" Tom", " James"),
    (" Dan", " Sid"),
    (" Martin", " Amy"),
]

# Define 8 prompts, in 4 groups of 2 (with adjacent prompts having answers swapped)
prompts = [
    prompt.format(name) # knows that it's {}, native Pytohon string function
    for (prompt, names) in zip(prompt_format, name_pairs) for name in names[::-1] 
]
# Define the answers for each prompt, in the form (correct, incorrect)
answers = [names[::i] for names in name_pairs for i in (1, -1)]
# Define the answer tokens (same shape as the answers)
answer_tokens = t.concat([
    model.to_tokens(names, prepend_bos=False).T for names in answers
])

rprint(prompts)
rprint(answers)
rprint(answer_tokens)

table = Table("Prompt", "Correct", "Incorrect", title="Prompts & Answers:")

for prompt, answer in zip(prompts, answers):
    table.add_row(prompt, repr(answer[0]), repr(answer[1]))

rprint(table)

In [109]:
tokens = model.to_tokens(prompts, prepend_bos=True)
# Move the tokens to the GPU
tokens = tokens.to(device)
# Run the model and cache all activations
original_logits, cache = model.run_with_cache(tokens)

Key variables at this point: prompts (type LIST), tokens (tokenized prompts; type TENSOR), answers (type LIST), answer_tokens (type TENSOR)

Model variables (output of running on tokens): original_logits, cache

GPT2-small transformer variable: model

## Logit Difference (Performance Evaluation Function)

In [110]:
def logits_to_ave_logit_diff(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Float[Tensor, "batch 2"] = answer_tokens,
    per_prompt: bool = False
):
    '''
    Returns logit difference between the correct and incorrect answer. The logit difference is defined between the indirect object's name and the subject's name (e.g. logit(Mary) - logit(John))

    If per_prompt=True, return the array of differences rather than the average.
    '''
    batch, seq_len, d_vocab = logits.shape

    # Take answer token, find whether it's correct or incorrect
    correct = answer_tokens[:,0] # shape (batch) representing token for correct answer (Mary)
    incorrect = answer_tokens[:,1] # shape (batch) representing token for incorrect answer (John)

    last_logits = einops.rearrange(logits, "b s d -> s b d")[-1] # b, d
    
    # Take the logit associated with the correct and the incorrect
    # Could do some sort of t.index_select(last_logits, 2, ) but want to learn how to splice instead.

    correct_logits = last_logits[t.arange(batch), correct] # the reason why this isn't (batch, batch) is because you're chosing one sequence position for each element in your batch. It correlates with (b, d)

    # correct_logits = last_logits[t.arange(batch)][correct] # NOT THE SAME THING. Because you're indexing from the 0th dimension, and then indexing correct from 0th dimension again
    # correct_logits = last_logits[t.arange(batch)][:, correct] # NOT THE SAME THING. Take everything by dimension 0, and then index by correct dimension. returns something 2d, see your drawing notes.

    incorrect_logits = last_logits[t.arange(batch), incorrect] # indexed by i = t.arange(batch), j = correct (where j has shape batch, i has shape batch)

    # no summary
    # assumption that the thing will create a list of (i_1,j_1) coordinates from list i and j when you slice askldjfaksdfj[i,j] = askldjfaksdfj[(i,j)]

    logit_diff = correct_logits-incorrect_logits
    return logit_diff if per_prompt else logit_diff.mean() # Take a mean over batch dimension if per_prompt is False

tests.test_logits_to_ave_logit_diff(logits_to_ave_logit_diff)

original_per_prompt_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)
print("Per prompt logit difference:", original_per_prompt_diff)
original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)
print("Average logit difference:", original_average_logit_diff)

cols = [
    "Prompt", 
    Column("Correct", style="rgb(0,200,0) bold"), 
    Column("Incorrect", style="rgb(255,0,0) bold"), 
    Column("Logit Difference", style="bold")
]
table = Table(*cols, title="Logit differences")

for prompt, answer, logit_diff in zip(prompts, answers, original_per_prompt_diff):
    table.add_row(prompt, repr(answer[0]), repr(answer[1]), f"{logit_diff.item():.3f}")

rprint(table)

All tests in `test_logits_to_ave_logit_diff` passed!
Per prompt logit difference: tensor([3.3368, 3.2017, 2.7096, 3.7974, 1.7204, 5.2813, 2.6007, 5.7674],
       device='cuda:0')
Average logit difference: tensor(3.5519, device='cuda:0')


## Logit Attribution

Logit difference is the same as log probability difference. It also helps is understand which component we care about -- *which* name is the Indirect Object.

\begin{aligned}
&p_i=\operatorname{softmax}(\overrightarrow{\mathbf{x}})_i=\frac{e^{x_i}}{\sum_{i=1}^n e^{x_i}}\\
&L_i=\log p_i\\
&L_i=\log \frac{e^{x_i}}{\sum_{j=1}^n e^{x_k}}=x_i-\log \sum_{k=1}^n e^{x_k}\\
&L_i-L_j=x_i-x_j
\end{aligned}

In [111]:
answer_residual_directions: Float[Tensor, "batch 2 d_model"] = model.tokens_to_residual_directions(answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape) # same as previous shape, but now has directions embedded

correct_residual_directions, incorrect_residual_directions = answer_residual_directions.unbind(dim=1)
logit_diff_directions: Float[Tensor, "batch d_model"] = correct_residual_directions - incorrect_residual_directions
print(f"Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([8, 2, 768])
Logit difference directions shape: torch.Size([8, 768])


Get final residual stream values from cache object.

Apply layernorm scaling to these values.

Project them along the unembedding dimensions we care about (logit_diff_dimensions)

In [112]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type]. 

final_residual_stream: Float[Tensor, "batch seq d_model"] = cache["resid_post", -1]
print(f"Final residual stream shape: {final_residual_stream.shape}")
final_token_residual_stream: Float[Tensor, "batch d_model"] = final_residual_stream[:, -1, :]

# Apply LayerNorm scaling (to just the final sequence position)
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer=-1, pos_slice=-1)

average_logit_diff = einops.einsum(
    scaled_final_token_residual_stream, logit_diff_directions,
    "batch d_model, batch d_model ->"
) / len(prompts)

print(f"Calculated average logit diff: {average_logit_diff:.10f}")
print(f"Original logit difference:     {original_average_logit_diff:.10f}")

t.testing.assert_close(average_logit_diff, original_average_logit_diff)

Final residual stream shape: torch.Size([8, 15, 768])
Calculated average logit diff: 3.5519316196
Original logit difference:     3.5519280434


Key variables:

answer_residual_directions = *model.tokens_to_residual_directions*'d vector where we map answer tokens to that direction

incorrect_residual_directions = *answer_residual_directions.unbind(dim=1)*

average_logit_diff is calculated via final_residual_stream --> final_token_residual_stream --> scaled_final_token_residual_stream

**Logit Lens**

Looks at residual stream after each layer, calculates logit difference from that (assuming that is the logit layer). Acts it like the final layer. Simulate what happens if we delete all subsequence layers. A probe, basically.

In [113]:
def residual_stack_to_logit_diff(
    residual_stack: Float[Tensor, "... batch d_model"], 
    cache: ActivationCache,
    logit_diff_directions: Float[Tensor, "batch d_model"] = logit_diff_directions,
) -> Float[Tensor, "..."]:
    '''
    Gets the avg logit difference between the correct and incorrect answer for a given 
    stack of components in the residual stream.
    '''
    # Apply layernorm
    batch_size = residual_stack.size(-2) # retrieves batch size of model
    print(batch_size)
    scaled_final_token_residual_stream = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)

    # Project them in the logit difference directions
    average_logit_diff = einops.einsum(
        scaled_final_token_residual_stream, logit_diff_directions,
        "... batch d_model, batch d_model -> ..."
    ) / batch_size

    return average_logit_diff


t.testing.assert_close(
    residual_stack_to_logit_diff(final_token_residual_stream, cache),
    original_average_logit_diff
)

8


* potentially do a tuned lens extension

**In residual stream**: n_pre means the residual stream at the start of layer n, n_mid means the residual stream after the attention part of layer n (n_post is the same as n+1_pre so is not included)

In [114]:
accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)
# accumulated_residual has shape (component, batch, d_model)

logit_lens_logit_diffs: Float[Tensor, "component"] = residual_stack_to_logit_diff(accumulated_residual, cache)

line(
    logit_lens_logit_diffs, 
    hovermode="x unified",
    title="Logit Difference From Accumulated Residual Stream",
    labels={"x": "Layer", "y": "Logit Diff"},
    xaxis_tickvals=labels,
    width=800
)

8


Thomas's viewpoint is that other things written to last token residual stream wouldn't affect this direction unembedding and would show up pre-6 (or maybe 7 and 8)

**Per layer** aka differences between adjacent residual stream. Here, the kth transformer block consists of an attention layer (move info around) and an MLP layer (process info).

In [115]:
per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
print(cache.decompose_resid)
per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)

line(
    per_layer_logit_diffs, 
    hovermode="x unified",
    title="Logit Difference From Each Layer",
    labels={"x": "Layer", "y": "Logit Diff"},
    xaxis_tickvals=labels,
    width=800
)

<bound method ActivationCache.decompose_resid of ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_ou

**Per head** aka sum of the outputs of each attention head.

**Key functions** *I love how everything standard is just a cache function away lol*
- cache.accumulated_resid
- cache.decompose_resid
- cache.stack_head_results
- cache.apply_ln_to_stack --> is this really "inherent" to class or does it act more like a static function where it doesn't actually take any of the stuff in cache --> more of a static method that doesn't leverage information from the cache

In [116]:
per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
print(per_head_residual.shape)
print(f"{len(labels)}, {labels}")
# shape (head*layer, batch_size, embedding_dimsnion)
per_head_residual = einops.rearrange(
    per_head_residual, 
    "(layer head) ... -> layer head ...", 
    layer=model.cfg.n_layers
)
per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)
print(per_head_logit_diffs.shape)

imshow(
    per_head_logit_diffs, 
    labels={"x":"Head", "y":"Layer"}, 
    title="Logit Difference From Each Head",
    width=600
)

Tried to stack head results when they weren't cached. Computing head results now
torch.Size([144, 8, 768])
144, ['L0H0', 'L0H1', 'L0H2', 'L0H3', 'L0H4', 'L0H5', 'L0H6', 'L0H7', 'L0H8', 'L0H9', 'L0H10', 'L0H11', 'L1H0', 'L1H1', 'L1H2', 'L1H3', 'L1H4', 'L1H5', 'L1H6', 'L1H7', 'L1H8', 'L1H9', 'L1H10', 'L1H11', 'L2H0', 'L2H1', 'L2H2', 'L2H3', 'L2H4', 'L2H5', 'L2H6', 'L2H7', 'L2H8', 'L2H9', 'L2H10', 'L2H11', 'L3H0', 'L3H1', 'L3H2', 'L3H3', 'L3H4', 'L3H5', 'L3H6', 'L3H7', 'L3H8', 'L3H9', 'L3H10', 'L3H11', 'L4H0', 'L4H1', 'L4H2', 'L4H3', 'L4H4', 'L4H5', 'L4H6', 'L4H7', 'L4H8', 'L4H9', 'L4H10', 'L4H11', 'L5H0', 'L5H1', 'L5H2', 'L5H3', 'L5H4', 'L5H5', 'L5H6', 'L5H7', 'L5H8', 'L5H9', 'L5H10', 'L5H11', 'L6H0', 'L6H1', 'L6H2', 'L6H3', 'L6H4', 'L6H5', 'L6H6', 'L6H7', 'L6H8', 'L6H9', 'L6H10', 'L6H11', 'L7H0', 'L7H1', 'L7H2', 'L7H3', 'L7H4', 'L7H5', 'L7H6', 'L7H7', 'L7H8', 'L7H9', 'L7H10', 'L7H11', 'L8H0', 'L8H1', 'L8H2', 'L8H3', 'L8H4', 'L8H5', 'L8H6', 'L8H7', 'L8H8', 'L8H9', 'L8H10', 'L8H11', 'L9H0

**Attention patterns**

note: A common mistake to make when looking at attention patterns is thinking that they must convey information about the token looked at (maybe accounting for the context of the token). But actually, all we can confidently say is that it moves information from the residual stream position corresponding to that input token. Especially later on in the model, there may be components in the residual stream that are nothing to do with the input token! Eg the period at the end of a sentence may contain summary information for that sentence, and the head may solely move that, rather than caring about whether it ends in ".", "!" or "?"

Can use either attention_patterns or attention_heads

In [151]:
def topk_of_Nd_tensor(tensor: Float[Tensor, "rows cols"], k: int):
    '''
    Helper function: does same as tensor.topk(k).indices, but works over 2D tensors.
    Returns a list of indices, i.e. shape [k, tensor.ndim].

    Example: if tensor is 2D array of values for each head in each layer, this will
    return a list of heads.
    '''
    i = t.topk(tensor.flatten(), k).indices
    return np.array(np.unravel_index(utils.to_numpy(i), tensor.shape)).T.tolist()



k = 3

for head_type in ["Positive", "Negative"]:

    # Get the heads with largest (or smallest) contribution to the logit difference
    top_heads = topk_of_Nd_tensor(per_head_logit_diffs * (1 if head_type=="Positive" else -1), k)

    # Get all their attention patterns
    attn_patterns_for_important_heads: Float[Tensor, "head q k"] = t.stack([
        cache["pattern", layer][:, head][0]
        for layer, head in top_heads
    ])

    # Display results
    display(HTML(f"<h2>Top {k} {head_type} Logit Attribution Heads</h2>"))
    display(cv.attention.attention_heads(
        attention = attn_patterns_for_important_heads,
        tokens = model.to_str_tokens(tokens[0]),
        attention_head_names = [f"{layer}.{head}" for layer, head in top_heads],
    ))

Duplicate name heads do not show up because they're indirectly responsible and not directly responsible

In [118]:
layer = 1
head = 2

print(cache["pattern", layer].shape) # 8 sentences, 12 heads, 15 source & 15 destination tokens (attention pattern)
cache["pattern", layer][:, head].shape  # taking every sentence, only head 2

torch.Size([8, 12, 15, 15])


torch.Size([8, 15, 15])

## Activation Patching

In [119]:
from transformer_lens import patching

Creating a metric

In [120]:
clean_tokens = tokens
# Swap each adjacent pair to get corrupted tokens
indices = [i+1 if i % 2 == 0 else i-1 for i in range(len(tokens))]
corrupted_tokens = clean_tokens[indices]

print(
    "Clean string 0:    ", model.to_string(clean_tokens[0]), "\n"
    "Corrupted string 0:", model.to_string(corrupted_tokens[0])
)

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

clean_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

# Why are these exact opposites? --> The tokens are exactly opposites
# clean_logit_diff = "Good performance" and corrupted_logit_diff = "bad performance"

Clean string 0:     <|endoftext|>When John and Mary went to the shops, John gave the bag to 
Corrupted string 0: <|endoftext|>When John and Mary went to the shops, Mary gave the bag to
Clean logit diff: 3.5519
Corrupted logit diff: -3.5519


In [121]:
def ioi_metric(
    logits: Float[Tensor, "batch seq d_vocab"], 
    answer_tokens: Float[Tensor, "batch 2"] = answer_tokens,
    corrupted_logit_diff: float = corrupted_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
) -> Float[Tensor, ""]:
    '''
    Linear function of logit diff, calibrated so that it equals 0 when performance is 
    same as on corrupted input, and 1 when performance is same as on clean input.
    '''
    diff = logits_to_ave_logit_diff(logits, answer_tokens)
    return (diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)


t.testing.assert_close(ioi_metric(clean_logits).item(), 1.0)
t.testing.assert_close(ioi_metric(corrupted_logits).item(), 0.0)
t.testing.assert_close(ioi_metric((clean_logits + corrupted_logits) / 2).item(), 0.5)

Residual stream patching

In [122]:
act_patch_resid_pre = patching.get_act_patch_resid_pre(
    model = model,
    corrupted_tokens = corrupted_tokens,
    clean_cache = clean_cache,
    patching_metric = ioi_metric
)

labels = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]

imshow(
    act_patch_resid_pre, 
    labels={"x": "Position", "y": "Layer"},
    x=labels,
    title="resid_pre Activation Patching",
    width=600
)

  0%|          | 0/180 [00:00<?, ?it/s]

In [123]:
clean_cache['resid_pre', 0].shape

torch.Size([8, 15, 768])

Implement head-to-residual patching --> patches place from residual stream

In [124]:
def patch_residual_component(
    corrupted_residual_component: Float[Tensor, "batch pos d_model"],
    hook: HookPoint, 
    pos: int, 
    clean_cache: ActivationCache
) -> Float[Tensor, "batch pos d_model"]:
    '''
    Patches a given sequence position in the residual stream, using the value
    from the clean cache.
    '''
    # print(hook.name)
    ret = corrupted_residual_component.clone()
    ret[:, pos] = clean_cache[hook.name][ :, pos]
    return ret

def get_act_patch_resid_pre(
    model: HookedTransformer, 
    corrupted_tokens: Float[Tensor, "batch pos"], 
    clean_cache: ActivationCache, 
    patching_metric: Callable[[Float[Tensor, "batch pos d_vocab"]], float]
) -> Float[Tensor, "layer pos"]:
    '''
    Returns an array of results of patching each position at each layer in the residual
    stream, using the value from the clean cache.

    The results are calculated using the patching_metric function, which should be
    called on the model's logit output.
    '''
    n_layers = model.cfg.n_layers
    n_pos = corrupted_tokens.shape[1]
    results = t.zeros((n_layers, n_pos)).to(device)
    for layer in range(n_layers):
        for pos in range(n_pos):
            model.reset_hooks()
            hook = functools.partial(patch_residual_component, pos=pos, clean_cache=clean_cache)
            logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[(
                utils.get_act_name("resid_pre", layer), hook)]
            )
            results[layer, pos] = ioi_metric(logits)
    
    return results


act_patch_resid_pre_own = get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, ioi_metric)

t.testing.assert_close(act_patch_resid_pre, act_patch_resid_pre_own)

In [125]:
imshow(
    act_patch_resid_pre_own, 
    x=labels, 
    title="Logit Difference From Patched Residual Stream", 
    labels={"x":"Sequence Position", "y":"Layer"},
    width=600 # If you remove this argument, the plot will usually fill the available space
)

Patch in residual stream by block -- just after the attention layer or just after the MLP

In [126]:
act_patch_block_every = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)

imshow(
    act_patch_block_every,
    x=labels, 
    facet_col=0, # This argument tells plotly which dimension to split into separate plots
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"], # Subtitles of separate plots
    title="Logit Difference From Patched Attn Head Output", 
    labels={"x": "Sequence Position", "y": "Layer"},
    width=1000,
)

  0%|          | 0/180 [00:00<?, ?it/s]

  0%|          | 0/180 [00:00<?, ?it/s]

  0%|          | 0/180 [00:00<?, ?it/s]

In [127]:
# def patch_block_component(
    
# )

# def get_act_patch_block_every(
#     model: HookedTransformer, 
#     corrupted_tokens: Float[Tensor, "batch pos"], 
#     clean_cache: ActivationCache, 
#     patching_metric: Callable[[Float[Tensor, "batch pos d_vocab"]], float]
# ) -> Float[Tensor, "layer pos"]:
#     '''
#     Returns an array of results of patching each position at each layer in the residual
#     stream, using the value from the clean cache.

#     The results are calculated using the patching_metric function, which should be
#     called on the model's logit output.
#     '''
#     layer = model.cfg.n_layers
#     batch, pos = corrupted_tokens.shape

#     results = t.zeros((layer, pos))

#     for l in range(layer):
#         for p in range(pos):
#             model.reset_hooks()
#             logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
#                 utils.get_act_name("layer")
#             ])

# act_patch_block_every_own = get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)

# t.testing.assert_close(act_patch_block_every, act_patch_block_every_own)

# imshow(
#     act_patch_block_every_own,
#     x=labels, 
#     facet_col=0,
#     facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
#     title="Logit Difference From Patched Attn Head Output", 
#     labels={"x": "Sequence Position", "y": "Layer"},
#     width=1000
# )

Head patching

In [128]:
act_patch_attn_head_out_all_pos = patching.get_act_patch_attn_head_out_all_pos(
    model, 
    corrupted_tokens, 
    clean_cache, 
    ioi_metric
)

imshow(
    act_patch_attn_head_out_all_pos, 
    labels={"y": "Layer", "x": "Head"}, 
    title="attn_head_out Activation Patching (All Pos)",
    width=600
)

  0%|          | 0/144 [00:00<?, ?it/s]

Head patching

In [129]:
def patch_head_vector(
    corrupted_head_vector: Float[Tensor, "batch pos head_index d_head"],
    hook: HookPoint, 
    head_index: int, 
    clean_cache: ActivationCache
) -> Float[Tensor, "batch pos head_index d_head"]:
    '''
    Patches the output of a given head (before it's added to the residual stream) at
    every sequence position, using the value from the clean cache.
    '''
    # print(clean_cache[hook.name].shape) # batch pos head_index d_head
    corrupted_head_vector[:, :, head_index] = clean_cache[hook.name][:, :, head_index]
    return corrupted_head_vector

def get_act_patch_attn_head_out_all_pos(
    model: HookedTransformer, 
    corrupted_tokens: Float[Tensor, "batch pos"], 
    clean_cache: ActivationCache, 
    patching_metric: Callable
) -> Float[Tensor, "layer head"]:
    '''
    Returns an array of results of patching at all positions for each head in each
    layer, using the value from the clean cache.

    The results are calculated using the patching_metric function, which should be
    called on the model's logit output.
    '''
    n_layers = model.cfg.n_layers
    n_heads = model.cfg.n_heads
    results = t.zeros((n_layers, n_heads)).to(device)
    for layer in range(n_layers):
        for head in range(n_heads):
            model.reset_hooks()
            hook = functools.partial(patch_head_vector, head_index=head, clean_cache=clean_cache)
            logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[(
                utils.get_act_name("z", layer), hook)]
            )
            results[layer, head] = ioi_metric(logits)
    
    return results    


act_patch_attn_head_out_all_pos_own = get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)

t.testing.assert_close(act_patch_attn_head_out_all_pos, act_patch_attn_head_out_all_pos_own)

imshow(
    act_patch_attn_head_out_all_pos_own,
    title="Logit Difference From Patched Attn Head Output", 
    labels={"x":"Head", "y":"Layer"},
    width=600
)

Implement head-to-head-input patching

In [130]:
act_patch_attn_head_all_pos_every = patching.get_act_patch_attn_head_all_pos_every(
    model, 
    corrupted_tokens, 
    clean_cache, 
    ioi_metric
)

imshow(
    act_patch_attn_head_all_pos_every, 
    facet_col=0, 
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Activation Patching Per Head (All Pos)", 
    labels={"x": "Head", "y": "Layer"},
)

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

In [143]:
# Get the heads with largest value patching
# (we know from plot above that these are the 4 heads in layers 7 & 8)
k = 4
top_heads = topk_of_Nd_tensor(act_patch_attn_head_all_pos_every[3], k=k)

# Get all their attention patterns
attn_patterns_for_important_heads: Float[Tensor, "head q k"] = t.stack([
    cache["pattern", layer][:, head].mean(0)
        for layer, head in top_heads
])

# Display results
display(HTML(f"<h2>Top {k} Logit Attribution Heads (from value-patching)</h2>"))
display(cv.attention.attention_patterns(
    attention = attn_patterns_for_important_heads,
    tokens = model.to_str_tokens(tokens[0]),
    attention_head_names = [f"{layer}.{head}" for layer, head in top_heads],
))

In [131]:
def patch_attn_patterns(
    corrupted_head_vector: Float[Tensor, "batch head_index pos_q pos_k"],
    hook: HookPoint, 
    head_index: int, 
    clean_cache: ActivationCache
) -> Float[Tensor, "batch pos head_index d_head"]:
    '''
    Patches the attn patterns of a given head at every sequence position, using 
    the value from the clean cache.
    '''
    corrupted_head_vector[:, head_index] = clean_cache[hook.name][:,head_index]

def get_act_patch_attn_head_all_pos_every(
    model: HookedTransformer,
    corrupted_tokens: Float[Tensor, "batch pos"],
    clean_cache: ActivationCache,
    patching_metric: Callable
) -> Float[Tensor, "layer head"]:
    '''
    Returns an array of results of patching at all positions for each head in each
    layer (using the value from the clean cache) for output, queries, keys, values
    and attn pattern in turn.

    The results are calculated using the patching_metric function, which should be
    called on the model's logit output.
    '''
    results = t.zeros(5, model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=t.float32)
    # Loop over each component in turn
    for component_idx, component in enumerate(["z", "q", "k", "v", "pattern"]):
        for layer in tqdm(range(model.cfg.n_layers)):
            for head in range(model.cfg.n_heads):
                # Get different hook function if we're doing attention probs
                hook_fn_general = patch_attn_patterns if component == "pattern" else patch_head_vector
                hook_fn = partial(hook_fn_general, head_index=head, clean_cache=clean_cache)
                # Get patched logits
                patched_logits = model.run_with_hooks(
                    corrupted_tokens,
                    fwd_hooks = [(utils.get_act_name(component, layer), hook_fn)],
                    return_type="logits"
                )
                results[component_idx, layer, head] = patching_metric(patched_logits)

    return results

act_patch_attn_head_all_pos_every_own = get_act_patch_attn_head_all_pos_every(
    model,
    corrupted_tokens,
    clean_cache,
    ioi_metric
)

t.testing.assert_close(act_patch_attn_head_all_pos_every, act_patch_attn_head_all_pos_every_own)

imshow(
    act_patch_attn_head_all_pos_every_own,
    facet_col=0,
    facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
    title="Activation Patching Per Head (All Pos)",
    labels={"x": "Head", "y": "Layer"},
    width=1200
)

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

**Key variables:**
- clean_tokens, corrupted_tokens
- clean_logits, clean_cache when you run model with clean_tokens
- corrupted_logits, corrupted_cache when you run model with corrupted_tokens
- clean_logit_diff and corrupted_logit_diff for performance to eventually determine *ioi_metric*

**Key functions:**

TransformerLens’ patching modules: takes in input model, corrupted_tokens, clean_cache, patching_metric
- patching.get_act_patch_resid_pre() --> gets activation patching for 
- patching.get_act_patch_block_every()
- patching.get_act_patch_attn_head_all_pos_every

## Path Patching

In [132]:
from part3_indirect_object_identification.ioi_dataset import NAMES, IOIDataset

In [133]:
N = 25
ioi_dataset = IOIDataset(
    prompt_type="mixed",
    N=N,
    tokenizer=model.tokenizer,
    prepend_bos=False,
    seed=1,
    device=str(device)
)
abc_dataset = ioi_dataset.gen_flipped_prompts("ABB->XYZ, BAB->XYZ")

Looking at the dataset

In [134]:
def format_prompt(sentence: str) -> str:
    '''Format a prompt by underlining names (for rich print)'''
    return re.sub("(" + "|".join(NAMES) + ")", lambda x: f"[u bold dark_orange]{x.group(0)}[/]", sentence) + "\n"


def make_table(cols, colnames, title="", n_rows=5, decimals=4):
    '''Makes and displays a table, from cols rather than rows (using rich print)'''
    table = Table(*colnames, title=title)
    rows = list(zip(*cols))
    f = lambda x: x if isinstance(x, str) else f"{x:.{decimals}f}"
    for row in rows[:n_rows]:
        table.add_row(*list(map(f, row)))
    rprint(table)

make_table(
    colnames = ["IOI prompt", "IOI subj", "IOI indirect obj", "ABC prompt"],
    cols = [
        map(format_prompt, ioi_dataset.sentences), 
        model.to_string(ioi_dataset.s_tokenIDs).split(), 
        model.to_string(ioi_dataset.io_tokenIDs).split(), 
        map(format_prompt, abc_dataset.sentences), 
    ],
    title = "Sentences from IOI vs ABC distribution",
)

In [135]:
def logits_to_ave_logit_diff_2(logits: Float[Tensor, "batch seq d_vocab"], ioi_dataset: IOIDataset = ioi_dataset, per_prompt=False):
    '''
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    '''

    # Only the final logits are relevant for the answer
    # Get the logits corresponding to the indirect object / subject tokens respectively
    io_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.io_tokenIDs]
    s_logits: Float[Tensor, "batch"] = logits[range(logits.size(0)), ioi_dataset.word_idx["end"], ioi_dataset.s_tokenIDs]
    # Find logit difference
    answer_logit_diff = io_logits - s_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()



model.reset_hooks(including_permanent=True)

ioi_logits_original, ioi_cache = model.run_with_cache(ioi_dataset.toks)
abc_logits_original, abc_cache = model.run_with_cache(abc_dataset.toks)

ioi_per_prompt_diff = logits_to_ave_logit_diff_2(ioi_logits_original, per_prompt=True)
abc_per_prompt_diff = logits_to_ave_logit_diff_2(abc_logits_original, per_prompt=True)

ioi_average_logit_diff = logits_to_ave_logit_diff_2(ioi_logits_original).item()
abc_average_logit_diff = logits_to_ave_logit_diff_2(abc_logits_original).item()

print(f"Average logit diff (IOI dataset): {ioi_average_logit_diff:.4f}")
print(f"Average logit diff (ABC dataset): {abc_average_logit_diff:.4f}")

make_table(
    colnames = ["IOI prompt", "IOI logit diff", "ABC prompt", "ABC logit diff"],
    cols = [
        map(format_prompt, ioi_dataset.sentences), 
        ioi_per_prompt_diff,
        map(format_prompt, abc_dataset.sentences), 
        abc_per_prompt_diff,
    ],
    title = "Sentences from IOI vs ABC distribution",
)

Average logit diff (IOI dataset): 2.8053
Average logit diff (ABC dataset): -0.1693


In [136]:
def ioi_metric_2(
    logits: Float[Tensor, "batch seq d_vocab"],
    clean_logit_diff: float = ioi_average_logit_diff,
    corrupted_logit_diff: float = abc_average_logit_diff,
    ioi_dataset: IOIDataset = ioi_dataset,
) -> float:
    '''
    We calibrate this so that the value is 0 when performance isn't harmed (i.e. same as IOI dataset), 
    and -1 when performance has been destroyed (i.e. is same as ABC dataset).
    '''
    patched_logit_diff = logits_to_ave_logit_diff_2(logits, ioi_dataset)
    return (patched_logit_diff - clean_logit_diff) / (clean_logit_diff - corrupted_logit_diff)


print(f"IOI metric (IOI dataset): {ioi_metric_2(ioi_logits_original):.4f}")
print(f"IOI metric (ABC dataset): {ioi_metric_2(abc_logits_original):.4f}")

IOI metric (IOI dataset): 0.0000
IOI metric (ABC dataset): -1.0000


Exercise: Path patching for name mover heads

In [137]:
abc_cache['blocks.0.attn.hook_z'].shape # batch, seq_ln, num_heads, head_dim

torch.Size([25, 21, 12, 64])

In [138]:
def hook_and_patch_head_output(
    activation: Float[Tensor, "batch seq n_head embed"], # can also pass cache in, so you don't have to do functools.partial each time --> cache would be same for all of these
    hook: HookPoint,
    new_activation: Float[Tensor, "batch seq embed"],
    head: int
) -> Float[Tensor, "batch seq n_head embed"]:
    activation[:,:,head,:] = new_activation
    return activation

def get_path_patch_head_to_final_resid_post(
    model: HookedTransformer,
    patching_metric: Callable,
    new_dataset: IOIDataset = abc_dataset,
    orig_dataset: IOIDataset = ioi_dataset,
    new_cache: Optional[ActivationCache] = abc_cache,
    orig_cache: Optional[ActivationCache] = ioi_cache,
) -> Float[Tensor, "layer head"]:

    layers = model.cfg.n_layers
    heads = model.cfg.n_heads

    ret = t.zeros(layers, heads)

    for corrupted_layer in tqdm(range(layers)):
        for corrupted_head in range(heads):

            model.reset_hooks()
            list_of_hooks = []
            
            # Find the corruption
            corrupt_hook = functools.partial(
                hook_and_patch_head_output, 
                new_activation = new_cache["z", corrupted_layer][:,:,corrupted_head], 
                head = corrupted_head
            )
            corrupt_input_name = utils.get_act_name("z", corrupted_layer)
            list_of_hooks.append((corrupt_input_name, corrupt_hook))

            # Make sure all the other heads be patched in cursed O(n^2) time
            for non_corrupted_layer in range(layers):
                for non_corrupted_head in range(heads):
                    # if (not layer == corrupted_layer) and (not head == corrupted_head):
                    if not (non_corrupted_layer, non_corrupted_head) == (corrupted_layer, corrupted_head):
                        clean_hook = functools.partial(
                            hook_and_patch_head_output, 
                            new_activation = orig_cache["z", non_corrupted_layer][:, :, non_corrupted_head], 
                            head = non_corrupted_head
                        )
                        clean_input_name = utils.get_act_name("z", non_corrupted_layer)
                        list_of_hooks.append((clean_input_name, clean_hook))
            
            # Return the result
            logits = model.run_with_hooks(orig_dataset.toks, fwd_hooks = list_of_hooks, return_type = "logits")
            ret[corrupted_layer, corrupted_head] = patching_metric(logits)

    return ret

path_patch_head_to_final_resid_post = get_path_patch_head_to_final_resid_post(model, ioi_metric_2)

imshow(
    100 * path_patch_head_to_final_resid_post,
    title="Direct effect on logit difference",
    labels={"x":"Head", "y":"Layer", "color": "Logit diff. variation"},
    coloraxis=dict(colorbar_ticksuffix = "%"),
    width=600,
)

  0%|          | 0/12 [00:00<?, ?it/s]

In [147]:
def hook_and_patch_zqkv(
    activation: Float[Tensor, "batch seq n_head embed"], # can also pass cache in, so you don't have to do functools.partial each time --> cache would be same for all of these
    hook: HookPoint,
    new_activation: Float[Tensor, "batch seq embed"],
    head: int
) -> Float[Tensor, "batch seq n_head embed"]:
    activation[:,:,head,:] = new_activation
    return activation

def get_path_patch_head_to_heads(
    receiver_heads: List[Tuple[int, int]],
    receiver_input: str,
    model: HookedTransformer,
    patching_metric: Callable,
    new_dataset: IOIDataset = abc_dataset,
    orig_dataset: IOIDataset = ioi_dataset,
    new_cache: Optional[ActivationCache] = abc_cache,
    orig_cache: Optional[ActivationCache] = ioi_cache,
) -> Float[Tensor, "layer head"]:
    '''
    Performs path patching (see algorithm in appendix B of IOI paper), with:

        sender head = (each head, looped through, one at a time)
        receiver node = input to a later head (or set of heads)

    The receiver node is specified by receiver_heads and receiver_input.
    Example (for S-inhibition path patching the queries):
        receiver_heads = [(8, 6), (8, 10), (7, 9), (7, 3)],
        receiver_input = "v"

    Returns:
        tensor of metric values for every possible sender head
    '''
    layers = model.cfg.n_layers
    heads = model.cfg.n_heads

    ret = t.zeros(layers, heads)

    for corrupted_layer in tqdm(range(layers)):
        for corrupted_head in range(heads):

            # ----- RUN ONE: Patch sender, freeze others, cache reciever

            model.reset_hooks()
            
            # Find the corruption
            corrupt_hook = functools.partial(
                hook_and_patch_zqkv, 
                new_activation = new_cache["z", corrupted_layer][:,:,corrupted_head], 
                head = corrupted_head
            )
            corrupt_input_name = utils.get_act_name("z", corrupted_layer)
            model.add_hook(corrupt_input_name, corrupt_hook)

            # Make sure all the other heads be patched in cursed O(n^2) time
            for non_corrupted_layer in range(layers):
                for non_corrupted_head in range(heads):
                    # if (not layer == corrupted_layer) and (not head == corrupted_head):
                    if not (non_corrupted_layer, non_corrupted_head) == (corrupted_layer, corrupted_head):
                        clean_hook = functools.partial(
                            hook_and_patch_zqkv, 
                            new_activation = orig_cache["z", non_corrupted_layer][:, :, non_corrupted_head], 
                            head = non_corrupted_head
                        )
                        clean_input_name = utils.get_act_name("z", non_corrupted_layer)
                        model.add_hook(clean_input_name, clean_hook)

            run1_logits, run1_cache = model.run_with_cache(orig_dataset.toks)

            # ----- RUN TWO
        
            model.reset_hooks()

            # Patch in the layer and head from the corrupted run
            for layer, head in receiver_heads:
                receiver_zqkv = run1_cache[receiver_input, layer][:,:, head] 
                patch_hook = functools.partial(
                    hook_and_patch_zqkv, 
                    new_activation = receiver_zqkv, 
                    head = head
                )
                model.add_hook(utils.get_act_name(receiver_input, layer), patch_hook)

            run2_logits, run2_cache = model.run_with_cache(orig_dataset.toks)

            logits, cache = model.run_with_cache(orig_dataset.toks)
            ret[corrupted_layer, corrupted_head] = patching_metric(logits)

    return ret

model.reset_hooks()

s_inhibition_value_path_patching_results = get_path_patch_head_to_heads(
    receiver_heads = [(8, 6), (8, 10), (7, 9), (7, 3)],
    receiver_input = "v",
    model = model,
    patching_metric = ioi_metric_2
)

imshow(
    100 * s_inhibition_value_path_patching_results,
    title="Direct effect on S-Inhibition Heads' values", 
    labels={"x": "Head", "y": "Layer", "color": "Logit diff.<br>variation"},
    width=600,
    coloraxis=dict(colorbar_ticksuffix = "%"),
)

  0%|          | 0/12 [00:00<?, ?it/s]

Key variables:
- ioi_dataset. Has a couple of good methods/attributions: ``toks`` is a tensor containing token IDs (batch_size, max_seq_len); ``s_tokenIDs`` and ``io_tokenIDs`` are a list containing token IDs for the subjects and objects; ``sentences`` is a list containing the sentences as strings; ``word_idx`` is a dictionary mapping word types to tensors containing positions of those words for each sequence
- abc_dataset

Data cleaning / visualization: make_table prints tables after being fed columns, format_prompt (??)

New function for logit difference: def logits_to_ave_logit_diff_2

New function for IOI metric: ioi_metric_2. -1 means corrupted (ABC dataset), while 0 means non-corrupted (IOI dataset performance).

## Paper Replication

In [None]:
def scatter_embedding_vs_attn(
    attn_from_end_to_io: Float[Tensor, "batch"],
    attn_from_end_to_s: Float[Tensor, "batch"],
    projection_in_io_dir: Float[Tensor, "batch"],
    projection_in_s_dir: Float[Tensor, "batch"],
    layer: int,
    head: int
):
    scatter(
        x=t.concat([attn_from_end_to_io, attn_from_end_to_s], dim=0),
        y=t.concat([projection_in_io_dir, projection_in_s_dir], dim=0),
        color=["IO"] * N + ["S"] * N,
        title=f"Projection of the output of {layer}.{head} along the name<br>embedding vs attention probability on name",
        title_x=0.5,
        labels={"x": "Attn prob on name", "y": "Dot w Name Embed", "color": "Name type"},
        color_discrete_sequence=["#72FF64", "#C9A5F7"],
        width=650
    )

def calculate_and_show_scatter_embedding_vs_attn(
    layer: int,
    head: int,
    cache: ActivationCache = ioi_cache,
    dataset: IOIDataset = ioi_dataset,
) -> None:
    '''
    Creates and plots a figure equivalent to 3(c) in the paper.

    This should involve computing the four 1D tensors:
        attn_from_end_to_io
        attn_from_end_to_s
        projection_in_io_dir
        projection_in_s_dir
    and then calling the scatter_embedding_vs_attn function.
    '''

    


nmh = (9, 9)
calculate_and_show_scatter_embedding_vs_attn(*nmh)

nnmh = (11, 10)
calculate_and_show_scatter_embedding_vs_attn(*nnmh)