In [68]:
from circuitsvis.attention import attention_heads
from fancy_einsum import einsum
from functools import partial
from IPython.display import HTML, IFrame
from jaxtyping import Float
from os import environ
from transformer_lens import ActivationCache, HookedTransformer
from typing import List, Optional, Union
import circuitsvis as cv
import dotenv
import einops
import numpy as np
import plotly.express as px
import plotly.io as pio
import torch
import tqdm.auto as tqdm
import transformer_lens
import transformer_lens.utils as utils
from utils import imshow, line, scatter

dotenv.load_dotenv('.env', override=True)
print(f'{environ.get("PYTORCH_ENABLE_MPS_FALLBACK")=}')

# The notebook only runs forward passes, so disable autograd globally.
torch.set_grad_enabled(False)

# Seed torch once, up front, so any sampling in `model.generate` below and the
# `torch.randint` calls in later cells reproduce on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

DEVICE = utils.get_device()
print(f'{DEVICE=}')

PRETRAINED_MODEL = 'gpt2-small'  # gpt2-small or gpt2-medium
print(f'{PRETRAINED_MODEL=}')

model = HookedTransformer.from_pretrained(
    PRETRAINED_MODEL,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=DEVICE,
)

print("test model generate:", model.generate(
    "The Space Needle is in the city of"))

print("print model structure", model)

environ.get("PYTORCH_ENABLE_MPS_FALLBACK")='1'
DEVICE=device(type='mps')
PRETRAINED_MODEL='gpt2-small'




Loaded pretrained model gpt2-small into HookedTransformer


100%|██████████| 10/10 [00:00<00:00, 22.69it/s]

test model generate: The Space Needle is in the city of Sedro Renica. It holds what must have
print model structure HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPo




In [2]:
# One sum, 123 + 456, with the whitespace around '+' and '=' varied, to see how
# tokenization (e.g. '456' vs ' 4' + '56') changes the model's completion.
prompts = [
    "123+456=",
    " 123+456=",
    " 123 +456=",
    " 123 + 456=",
    " 123 + 456 =",
    " 123 + 456 = ",
    "123 +456=",
    "123 + 456=",
    "123 + 456 =",
    "123+ 456=",
    "123+ 456 =",
    "123+ 456 = ",
    "123+456 =",
    "123+456 = ",
    "123+456= ",
]
# Expected completion; it is tokenized into two tokens (' 5', '79').
answers = [
    " 579",
]

# Score every (prompt, answer) combination, showing only the top-1 prediction per step.
for p, a in [(p, a) for p in prompts for a in answers]:
    print(f"prompt=|{p}|, answer=|{a}|")
    utils.test_prompt(prompt=p, answer=a, model=model, top_k=1)

prompt=|123+456=|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', '+', '456', '=']
Tokenized answer: [' 5', '79']


Top 0th token. Logit: 10.69 Prob:  3.82% Token: |123|


Top 0th token. Logit: 11.48 Prob:  2.30% Token: |
|


prompt=| 123+456=|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', ' 123', '+', '456', '=']
Tokenized answer: [' 5', '79']


Top 0th token. Logit: 11.32 Prob:  7.90% Token: |123|


Top 0th token. Logit: 11.56 Prob:  3.41% Token: |
|


prompt=| 123 +456=|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', ' 123', ' +', '456', '=']
Tokenized answer: [' 5', '79']


Top 0th token. Logit: 10.62 Prob:  4.43% Token: |123|


Top 0th token. Logit: 12.00 Prob:  4.49% Token: | +|


prompt=| 123 + 456=|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', ' 123', ' +', ' 4', '56', '=']
Tokenized answer: [' 5', '79']


Top 0th token. Logit:  9.53 Prob:  1.87% Token: |
|


Top 0th token. Logit: 11.43 Prob:  6.52% Token: |.|


prompt=| 123 + 456 =|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', ' 123', ' +', ' 4', '56', ' =']
Tokenized answer: [' 5', '79']


Top 0th token. Logit: 11.84 Prob:  3.70% Token: | 1|


Top 0th token. Logit: 12.16 Prob:  9.45% Token: |.|


prompt=| 123 + 456 = |, answer=| 579|
Tokenized prompt: ['<|endoftext|>', ' 123', ' +', ' 4', '56', ' =', ' ']
Tokenized answer: [' 5', '79']


Top 0th token. Logit: 13.40 Prob: 11.12% Token: | |


Top 0th token. Logit: 11.47 Prob:  7.16% Token: | +|


prompt=|123 +456=|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', ' +', '456', '=']
Tokenized answer: [' 5', '79']


Top 0th token. Logit: 10.12 Prob:  3.07% Token: |123|


Top 0th token. Logit: 11.82 Prob:  3.29% Token: | +|


prompt=|123 + 456=|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', ' +', ' 4', '56', '=']
Tokenized answer: [' 5', '79']


Top 0th token. Logit:  9.51 Prob:  2.05% Token: |
|


Top 0th token. Logit: 11.67 Prob:  7.86% Token: |.|


prompt=|123 + 456 =|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', ' +', ' 4', '56', ' =']
Tokenized answer: [' 5', '79']


Top 0th token. Logit: 11.95 Prob:  3.67% Token: | 1|


Top 0th token. Logit: 12.48 Prob: 12.00% Token: |.|


prompt=|123+ 456=|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', '+', ' 4', '56', '=']
Tokenized answer: [' 5', '79']


Top 0th token. Logit:  9.26 Prob:  1.95% Token: |
|


Top 0th token. Logit: 11.07 Prob:  5.20% Token: |.|


prompt=|123+ 456 =|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', '+', ' 4', '56', ' =']
Tokenized answer: [' 5', '79']


Top 0th token. Logit: 10.61 Prob:  3.29% Token: | 1|


Top 0th token. Logit: 11.77 Prob:  8.72% Token: |.|


prompt=|123+ 456 = |, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', '+', ' 4', '56', ' =', ' ']
Tokenized answer: [' 5', '79']


Top 0th token. Logit: 13.57 Prob: 11.60% Token: |????|


Top 0th token. Logit: 10.90 Prob:  5.46% Token: |.|


prompt=|123+456 =|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', '+', '456', ' =']
Tokenized answer: [' 5', '79']


Top 0th token. Logit: 12.06 Prob: 12.69% Token: | 123|


Top 0th token. Logit: 11.51 Prob:  2.49% Token: |
|


prompt=|123+456 = |, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', '+', '456', ' =', ' ']
Tokenized answer: [' 5', '79']


Top 0th token. Logit: 13.69 Prob:  8.38% Token: |【|


Top 0th token. Logit: 10.64 Prob:  1.83% Token: |
|


prompt=|123+456= |, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', '+', '456', '=', ' ']
Tokenized answer: [' 5', '79']


Top 0th token. Logit: 14.04 Prob: 12.58% Token: |irc|


Top 0th token. Logit: 10.36 Prob:  1.80% Token: |.|


```
prompt=|123 + 456 =|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', ' +', ' 4', '56', ' =']
Tokenized answer: [' 5', '79']
Performance on answer token:
Rank: 6        Logit: 11.39 Prob:  2.09% Token: | 5|
Top 0th token. Logit: 11.95 Prob:  3.67% Token: | 1|
Performance on answer token:
Rank: 27       Logit:  9.31 Prob:  0.51% Token: |79|
Top 0th token. Logit: 12.48 Prob: 12.00% Token: |.|
Ranks of the answer tokens: [(' 5', 6), ('79', 27)]


prompt=|123+ 456 =|, answer=| 579|
Tokenized prompt: ['<|endoftext|>', '123', '+', ' 4', '56', ' =']
Tokenized answer: [' 5', '79']
Performance on answer token:
Rank: 5        Logit:  9.91 Prob:  1.63% Token: | 5|
Top 0th token. Logit: 10.61 Prob:  3.29% Token: | 1|
Performance on answer token:
Rank: 80       Logit:  8.25 Prob:  0.26% Token: |79|
Top 0th token. Logit: 11.77 Prob:  8.72% Token: |.|
Ranks of the answer tokens: [(' 5', 5), ('79', 80)]
```

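For the `123 + 456 =` prompt (answer logit minus the top-1 prediction's logit; negative means the answer is not the top prediction):
* first answer token: logit(| 5|) - logit(| 1|) = `11.39 - 11.95 = -0.56`
* second answer token: logit(|79|) - logit(|.|) = `9.31 - 12.48 = -3.17`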



In [81]:
# Token ids of the correct (" 5") and wrong (" 1") first answer token, shape [1, 2].
correct_wrong_answer_tokens = model.to_tokens(
    " 5 1", prepend_bos=False).to(DEVICE)
print(f'{correct_wrong_answer_tokens=}')
# The raw vocab ids indexed below (642, 352, 3720) are pinned to their strings here,
# so a model/tokenizer change fails loudly instead of silently reading the wrong logits.
assert correct_wrong_answer_tokens.tolist() == [[642, 352]]
assert model.to_single_token("79") == 3720
answer_residual_directions = model.tokens_to_residual_directions(
    correct_wrong_answer_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
# Direction of token | 5| minus token | 1| in the residual stream, shape [1, 768].
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print(f'{logit_diff_directions.shape=}')
print(f'{answer_residual_directions[0,0,0]=}')

# 768 refers to the dimensionality of the embedding and hidden states within the model

# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer.
# The general syntax is [activation_name, layer_index, sub_layer_type].
input_tokens0 = model.to_tokens("123 + 456 =", prepend_bos=True)
output_logit0, cache0 = model.run_with_cache(input_tokens0)
print(f'logit for token | 5|: {output_logit0[0, -1, 642]=}')
print(f'logit for token | 1|: {output_logit0[0, -1, 352]=}')
# Tolerance instead of exact equality after rounding: logits differ slightly
# between CPU / MPS / CUDA, and exact float comparisons break across devices.
LOGIT_ATOL = 1e-3
assert abs(output_logit0[0, -1, 642].item() - 11.3875) < LOGIT_ATOL

input_tokens1 = model.to_tokens(
    "123 + 456 = 5", prepend_bos=True)  # advance 1 more token
output_logit1, cache1 = model.run_with_cache(input_tokens1)
print(f'logit for token |79|: {output_logit1[0, -1, 3720]=}')
assert abs(output_logit1[0, -1, 3720].item() - 9.3148) < LOGIT_ATOL

final_residual_stream0 = cache0["resid_post", -1]  # shape [1, 6, 768]
final_token_residual_stream = final_residual_stream0[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache0.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)
print(f'{scaled_final_token_residual_stream.shape=}')
print(f'{scaled_final_token_residual_stream[0,0]=}')
print(f'{cache0["ln_final.hook_normalized"].shape=}')
print(f'{cache0["ln_final.hook_normalized"][0,-1,0]=}')

# NOTE: despite its name, this holds the full final-position *logits* (shape [1, 50257]),
# recomputed by hand from the normalized final residual and the unembedding.
calculated_logit_diff = cache0["ln_final.hook_normalized"][:, -1, :]
calculated_logit_diff = calculated_logit_diff @ model.unembed.W_U + model.unembed.b_U
print(f'{calculated_logit_diff.shape=}')
print(f'{calculated_logit_diff[0,642]=}')
# Sanity check: the hand-computed logits reproduce the model's own final-position logits.
assert torch.allclose(calculated_logit_diff, output_logit0[:, -1, :], atol=1e-2)


def logits_to_ave_logit_diff(logits, answer_tokens, per_prompt=True):
    """Logit difference between the correct and the wrong answer token at the last position.

    Args:
        logits: model output indexed [batch, pos, vocab]; only the last position is used.
        answer_tokens: token ids indexed [batch, 2] as (correct, wrong). A single
            [1, 2] row is broadcast to every prompt in `logits`.
        per_prompt: if True (default), return one difference per prompt, shape [batch];
            if False, return their mean as a scalar tensor.

    Returns:
        logit(correct) - logit(wrong), per prompt or averaged.
    """
    # Only the final logits are relevant for the answer
    final_logits = logits[:, -1, :]
    # gather() does not broadcast: a [1, 2] index over a [batch, vocab] input would
    # silently read only the first prompt, so expand it explicitly.
    if answer_tokens.shape[0] == 1 and final_logits.shape[0] > 1:
        answer_tokens = answer_tokens.expand(final_logits.shape[0], -1)
    answer_logits = final_logits.gather(dim=-1, index=answer_tokens)
    print(f'{answer_logits=}')
    logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    return logit_diff if per_prompt else logit_diff.mean()


# Logit(| 5|) - logit(| 1|) for the "123 + 456 =" run, moved to CPU and rounded for display.
per_prompt_logit_diff = logits_to_ave_logit_diff(
    output_logit0, correct_wrong_answer_tokens)
print("Per prompt logit difference:",
      per_prompt_logit_diff.detach().cpu().round(decimals=3))

correct_wrong_answer_tokens=tensor([[642, 352]], device='mps:0')
Answer residual directions shape: torch.Size([1, 2, 768])
logit_diff_directions.shape=torch.Size([1, 768])
answer_residual_directions[0,0,0]=tensor(-0.1862, device='mps:0')
logit for token | 5|: output_logit0[0, -1, 642]=tensor(11.3875, device='mps:0')
logit for token | 1|: output_logit0[0, -1, 352]=tensor(11.9513, device='mps:0')
logit for token |79|: output_logit1[0, -1, 3720]=tensor(9.3148, device='mps:0')
scaled_final_token_residual_stream.shape=torch.Size([1, 768])
scaled_final_token_residual_stream[0,0]=tensor(0.1353, device='mps:0')
cache0["ln_final.hook_normalized"].shape=torch.Size([1, 6, 768])
cache0["ln_final.hook_normalized"][0,-1,0]=tensor(0.1353, device='mps:0')
calculated_logit_diff.shape=torch.Size([1, 50257])
calculated_logit_diff[0,642]=tensor(11.3875, device='mps:0')
answer_logits=tensor([[11.3875, 11.9513]], device='mps:0')
Per prompt logit difference: tensor([-0.5640])


---

In [None]:
# Visualize layer-0 attention patterns on a long two-operand sum.
text = " 123456754234 + 84729123475"
str_tokens = model.to_str_tokens(text)
tokens = model.to_tokens(text)
logits, activations = model.run_with_cache(tokens, remove_batch_dim=True)
attention_pattern = activations[utils.get_act_name("pattern", 0)]
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

In [None]:
# Baseline loss of the unmodified model on `tokens`, for comparison with the ablated run.
loss = model(tokens, return_type="loss")
print(f'{loss=}')

In [None]:
# Ablate a single attention head by zeroing its attention pattern, then compare the
# loss and the layer's attention patterns with the unablated run.
layer_to_ablate = 0
head_index_to_ablate = 1


def head_ablation_hook(value, hook):
    """Zero the attention pattern of head `head_index_to_ablate` and return it.

    Registered on the `pattern` hook, whose activation is indexed
    [batch, head_index, query_pos, key_pos], so the head axis is dim 1.
    Indexing dim 2 instead would zero whole query positions across all heads.
    """
    print(f"Shape of the value tensor: {value.shape} {hook}")
    value[:, head_index_to_ablate, :, :] = 0.
    return value


ablation_hook_name = utils.get_act_name("pattern", layer_to_ablate)
with model.hooks(fwd_hooks=[(ablation_hook_name, head_ablation_hook)]):
    logits, activations = model.run_with_cache(tokens, remove_batch_dim=True)
    loss = model(tokens, return_type='loss')
    print(f'{loss=}')

# NOTE(review): the ablation hook is registered before run_with_cache adds its caching
# hook, so the cached pattern presumably shows the ablated head as all zeros — verify in the widget.
attention_pattern = activations["pattern", layer_to_ablate, "attn"]
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

In [None]:
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
# clean_prompt = "The Space Needle is in the city of"
# corrupted_prompt = "Eiffel Tower is located in the city of"

clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer, incorrect_answer):
    # model.to_single_token maps a string value of a single token to the token index for that token
    # If the string is not a single token, it raises an error.
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    # print(f'{correct_index=}, {incorrect_index=}')
    # print(f'{logits.shape=}')
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

# We run on the clean prompt with the cache so we store activations to patch in later.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits, correct_answer=" Seattle", incorrect_answer=" Paris")
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits, correct_answer=" Paris", incorrect_answer=" Seattle")
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")

# str_tokens = model.to_str_tokens(clean_prompt)
# print(clean_cache)
# attention_pattern = clean_cache["pattern", 0, "attn"]
# cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

In [None]:
# We define a residual stream patching hook
# We choose to act on the residual stream at the start of the layer, so we call it resid_pre
# The type annotations are a guide to the reader and are not necessary
def residual_stream_patching_hook(resid_pre, hook, position):
    # Each HookPoint has a name attribute giving the name of the hook.
    clean_resid_pre = clean_cache[hook.name]
    resid_pre[:, position, :] = clean_resid_pre[:, position, :]
    return resid_pre


# We make a tensor to store the results for each patching run. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
print(f'{clean_tokens.shape=}')
num_positions = len(clean_tokens[0])
print(f'{num_positions=}')
ioi_patching_result = torch.zeros(
    (model.cfg.n_layers, num_positions), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
        # Use functools.partial to create a temporary hook function with the position fixed
        temp_hook_fn = partial(
            residual_stream_patching_hook, position=position)
        # Run the model with the patching hook
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (utils.get_act_name("resid_pre", layer), temp_hook_fn)
        ])
        # Calculate the logit difference
        patched_logit_diff = logits_to_logit_diff(patched_logits, correct_answer=" Seattle", incorrect_answer=" Paris").detach()
        # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)
        ioi_patching_result[layer, position] = (
            patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)

In [None]:
# Add the index to the end of the label, because plotly doesn't like duplicate labels
# plotly rejects duplicate axis labels, so suffix each token with its position.
token_labels = [
    f"{tok}_{pos}" for pos, tok in enumerate(model.to_str_tokens(clean_tokens))
]
imshow(
    ioi_patching_result,
    x=token_labels,
    xaxis="Position",
    yaxis="Layer",
    title="Normalized Logit Difference After Patching Residual Stream on the IOI Task",
)

In [None]:
# Random token sequences repeated twice: a model with induction heads should have much
# lower loss on the repeated second half than on the random first half.
batch_size = 10
seq_len = 50  # also read by `induction_score_hook` in the next cell
size = (batch_size, seq_len)
input_tensor = torch.randint(0, 10000, size)

random_tokens = input_tensor.to(model.cfg.device)
# Each sequence followed by itself along the position axis -> [batch, 2 * seq_len].
repeated_tokens = torch.cat([random_tokens, random_tokens], dim=1)
repeated_logits = model(repeated_tokens)
# NOTE(review): plotted as "Loss" below, so these are presumably per-token losses,
# not log-probs as the name suggests — consider renaming.
correct_log_probs = model.loss_fn(repeated_logits, repeated_tokens, per_token=True)
loss_by_position = correct_log_probs.mean(dim=0)
line(loss_by_position, xaxis="Position", yaxis="Loss",
     title="Loss by position on random repeated tokens")

In [None]:
# We make a tensor to store the induction score for each head. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
# One induction score per (layer, head), filled in by `induction_score_hook`.
# Allocated on the model's device so the hook never moves data between GPU and CPU.
induction_score_store = torch.zeros(
    (model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)


def induction_score_hook(pattern, hook):
    """Store each head's mean attention from every query to the key `seq_len - 1` positions back.

    `pattern` is indexed [batch, head_index, query_pos, key_pos]. On a sequence repeated
    twice, that backward offset is where an induction head attends.
    """
    # Sub-diagonal at offset 1 - seq_len: entries exist only for query positions >= seq_len - 1.
    stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=1 - seq_len)
    # Average over batch and stripe position, leaving one score per head.
    induction_score_store[hook.layer(), :] = stripe.mean(dim=(0, 2))


def pattern_hook_names_filter(name):
    """True only for activation names ending in "pattern" (attention-pattern hooks)."""
    return name.endswith("pattern")


model.run_with_hooks(
    repeated_tokens,
    return_type=None,  # only the hooks' side effects are needed, not logits
    fwd_hooks=[(pattern_hook_names_filter, induction_score_hook)],
)

imshow(induction_score_store, xaxis="Head",
       yaxis="Layer", title="Induction Score by Head")

In [None]:

# Visualize one induction head on a fresh random sequence repeated twice.
# The head is the (layer, head) with the highest score in `induction_score_store`,
# computed in the previous cell. A hard-coded layer 18 does not exist in gpt2-small's
# 12 blocks, so its hook name would match no module and the cell would fail.
induction_head_layer, induction_head_index = divmod(
    int(induction_score_store.argmax().item()), model.cfg.n_heads)
print(f'{induction_head_layer=}, {induction_head_index=}')
size = (1, 20)
input_tensor = torch.randint(1000, 10000, size)

single_random_sequence = input_tensor.to(model.cfg.device)
repeated_random_sequence = einops.repeat(
    single_random_sequence, "batch seq_len -> batch (2 seq_len)")


def visualize_pattern_hook(pattern, hook):
    """Display head `induction_head_index`'s attention pattern for the first batch item."""
    display(
        cv.attention.attention_patterns(
            tokens=model.to_str_tokens(repeated_random_sequence),
            # Add a dummy axis, as CircuitsVis expects 3D patterns.
            attention=pattern[0, induction_head_index, :, :][None, :, :],
        )
    )


model.run_with_hooks(
    repeated_random_sequence,
    return_type=None,
    fwd_hooks=[(
        utils.get_act_name("pattern", induction_head_layer),
        visualize_pattern_hook
    )]
)

In [None]:
test_prompt = "The quick brown fox jumped over the lazy dog"
print("Num tokens:", len(model.to_tokens(test_prompt)[0]))


def print_name_shape_hook_function(activation, hook):
    """Print the hook's name and the shape of the activation passing through it."""
    print(hook.name, activation.shape)


def not_in_late_block_filter(name):
    """Keep hooks in block 0 and every hook outside `blocks` (embeddings, final LN, ...)."""
    if name.startswith("blocks.0."):
        return True
    return not name.startswith("blocks")


model.run_with_hooks(
    test_prompt,
    return_type=None,
    fwd_hooks=[(not_in_late_block_filter, print_name_shape_hook_function)],
)

In [None]:
# Check the model on a carry-free four-digit sum.
example_prompt = "1111 + 2222 ="
example_answer = " 3333"
utils.test_prompt(
    prompt=example_prompt,
    answer=example_answer,
    model=model,
    prepend_bos=True,
)