In [1]:
import transformer_lens
import dotenv
from os import environ
import torch
import plotly.io as pio
import transformer_lens.utils as utils
import circuitsvis as cv
import tqdm.auto as tqdm
from functools import partial
import plotly.express as px
import einops

dotenv.load_dotenv('.env', override=True)
print(f'{environ.get("PYTORCH_ENABLE_MPS_FALLBACK")=}')

torch.set_grad_enabled(False)

DEVICE = utils.get_device()
print(f'{DEVICE=}')

# Load a model (eg GPT-2 Small)
model = transformer_lens.HookedTransformer.from_pretrained(
    "gpt2-medium", device=DEVICE)

print("test model generate:", model.generate("The Space Needle is in the city of"))

print("print model structure", model)

  from .autonotebook import tqdm as notebook_tqdm


environ.get("PYTORCH_ENABLE_MPS_FALLBACK")='1'
DEVICE=device(type='mps')




Loaded pretrained model gpt2-medium into HookedTransformer


  0%|          | 0/10 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  torch.isin(
100%|██████████| 10/10 [00:01<00:00,  5.52it/s]

test model generate: The Space Needle is in the city of Seattle and is located at 1 M Street SE between
print model structure HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-23): 24 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_inpu




In [2]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [3]:
text = "The Space Needle is in the city of"
tokens = model.to_tokens(text)
logits, activations = model.run_with_cache(tokens, remove_batch_dim=True)
attention_pattern = activations["pattern", 0, "attn"]
str_tokens = model.to_str_tokens(text)
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

In [4]:
loss = model(tokens, return_type='loss')
print(f'{loss=}')

loss=tensor(3.1159, device='mps:0')


In [5]:
layer_to_ablate = 0
head_index_to_ablate = 1


def head_ablation_hook(value, hook):
    print(f"Shape of the value tensor: {value.shape} {hook}")
    value[:, :, 1, :] = 0.
    value[:, :, 2, :] = 0.
    return value


activations = None
with model.hooks(fwd_hooks=[(utils.get_act_name("attn", 0, "pattern"), head_ablation_hook)]):
    print(model)
    logits, activations = model.run_with_cache(tokens, remove_batch_dim=True)
    loss = model(tokens, return_type='loss')
    print(f'{loss=}')

attention_pattern = activations["pattern", 0, "attn"]
cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

# loss = model.run_with_hooks(
#     tokens,
#     return_type='loss',
#     fwd_hooks=[(
#         utils.get_act_name("v", layer_to_ablate),
#         head_ablation_hook
#     )]
# )
# cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-23): 24 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [6]:
# clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
# corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
clean_prompt = "The Space Needle is in the city of"
corrupted_prompt = "Eiffel Tower is located in the city of"

clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer, incorrect_answer):
    # model.to_single_token maps a string value of a single token to the token index for that token
    # If the string is not a single token, it raises an error.
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    # print(f'{correct_index=}, {incorrect_index=}')
    # print(f'{logits.shape=}')
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

# We run on the clean prompt with the cache so we store activations to patch in later.
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logits_to_logit_diff(clean_logits, correct_answer=" Seattle", incorrect_answer=" Paris")
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits, correct_answer=" Paris", incorrect_answer=" Seattle")
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")

# str_tokens = model.to_str_tokens(clean_prompt)
# print(clean_cache)
# attention_pattern = clean_cache["pattern", 0, "attn"]
# cv.attention.attention_patterns(tokens=str_tokens, attention=attention_pattern)

Clean logit difference: 3.711
Corrupted logit difference: 11.450


In [7]:
# We define a residual stream patching hook
# We choose to act on the residual stream at the start of the layer, so we call it resid_pre
# The type annotations are a guide to the reader and are not necessary
def residual_stream_patching_hook(resid_pre, hook, position):
    # Each HookPoint has a name attribute giving the name of the hook.
    clean_resid_pre = clean_cache[hook.name]
    resid_pre[:, position, :] = clean_resid_pre[:, position, :]
    return resid_pre


# We make a tensor to store the results for each patching run. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
print(f'{clean_tokens.shape=}')
num_positions = len(clean_tokens[0])
print(f'{num_positions=}')
ioi_patching_result = torch.zeros(
    (model.cfg.n_layers, num_positions), device=model.cfg.device)

for layer in tqdm.tqdm(range(model.cfg.n_layers)):
    for position in range(num_positions):
        # Use functools.partial to create a temporary hook function with the position fixed
        temp_hook_fn = partial(
            residual_stream_patching_hook, position=position)
        # Run the model with the patching hook
        patched_logits = model.run_with_hooks(corrupted_tokens, fwd_hooks=[
            (utils.get_act_name("resid_pre", layer), temp_hook_fn)
        ])
        # Calculate the logit difference
        patched_logit_diff = logits_to_logit_diff(patched_logits, correct_answer=" Seattle", incorrect_answer=" Paris").detach()
        # Store the result, normalizing by the clean and corrupted logit difference so it's between 0 and 1 (ish)
        ioi_patching_result[layer, position] = (
            patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)

clean_tokens.shape=torch.Size([1, 10])
num_positions=10


100%|██████████| 24/24 [00:14<00:00,  1.61it/s]


In [8]:
# Add the index to the end of the label, because plotly doesn't like duplicate labels
token_labels = [f"{token}_{index}" for index, token in enumerate(model.to_str_tokens(clean_tokens))]
imshow(ioi_patching_result, x=token_labels, xaxis="Position", yaxis="Layer", title="Normalized Logit Difference After Patching Residual Stream on the IOI Task")

In [9]:
batch_size = 10
seq_len = 50
size = (batch_size, seq_len)
input_tensor = torch.randint(0, 10000, size)

random_tokens = input_tensor.to(model.cfg.device)
repeated_tokens = einops.repeat(random_tokens, "batch seq_len -> batch (2 seq_len)")
repeated_logits = model(repeated_tokens)
correct_log_probs = model.loss_fn(repeated_logits, repeated_tokens, per_token=True)
loss_by_position = einops.reduce(correct_log_probs, "batch position -> position", "mean")
line(loss_by_position, xaxis="Position", yaxis="Loss", title="Loss by position on random repeated tokens")

In [11]:
# We make a tensor to store the induction score for each head. We put it on the model's device to avoid needing to move things between the GPU and CPU, which can be slow.
induction_score_store = torch.zeros(
    (model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)


def induction_score_hook(pattern, hook):
    # We take the diagonal of attention paid from each destination position to source positions seq_len-1 tokens back
    # (This only has entries for tokens with index>=seq_len)
    induction_stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=1-seq_len)
    # Get an average score per head
    induction_score = einops.reduce(
        induction_stripe, "batch head_index position -> head_index", "mean")
    # Store the result.
    induction_score_store[hook.layer(), :] = induction_score


# We make a boolean filter on activation names, that's true only on attention pattern names.
def pattern_hook_names_filter(name): return name.endswith("pattern")


model.run_with_hooks(
    repeated_tokens,
    return_type=None,  # For efficiency, we don't need to calculate the logits
    fwd_hooks=[(
        pattern_hook_names_filter,
        induction_score_hook
    )]
)

imshow(induction_score_store, xaxis="Head",
       yaxis="Layer", title="Induction Score by Head")

In [13]:

induction_head_layer = 18
induction_head_index = 5
size = (1, 20)
input_tensor = torch.randint(1000, 10000, size)

single_random_sequence = input_tensor.to(model.cfg.device)
repeated_random_sequence = einops.repeat(single_random_sequence, "batch seq_len -> batch (2 seq_len)")
def visualize_pattern_hook(pattern, hook):
    display(
        cv.attention.attention_patterns(
            tokens=model.to_str_tokens(repeated_random_sequence), 
            attention=pattern[0, induction_head_index, :, :][None, :, :] # Add a dummy axis, as CircuitsVis expects 3D patterns.
        )
    )

model.run_with_hooks(
    repeated_random_sequence, 
    return_type=None, 
    fwd_hooks=[(
        utils.get_act_name("pattern", induction_head_layer), 
        visualize_pattern_hook
    )]
)

In [14]:
test_prompt = "The quick brown fox jumped over the lazy dog"
print("Num tokens:", len(model.to_tokens(test_prompt)[0]))

def print_name_shape_hook_function(activation, hook):
    print(hook.name, activation.shape)

not_in_late_block_filter = lambda name: name.startswith("blocks.0.") or not name.startswith("blocks")

model.run_with_hooks(
    test_prompt,
    return_type=None,
    fwd_hooks=[(not_in_late_block_filter, print_name_shape_hook_function)],
)

Num tokens: 10
hook_embed torch.Size([1, 10, 1024])
hook_pos_embed torch.Size([1, 10, 1024])
blocks.0.hook_resid_pre torch.Size([1, 10, 1024])
blocks.0.ln1.hook_scale torch.Size([1, 10, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 10, 1024])
blocks.0.ln1.hook_scale torch.Size([1, 10, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 10, 1024])
blocks.0.ln1.hook_scale torch.Size([1, 10, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 10, 1024])
blocks.0.attn.hook_q torch.Size([1, 10, 16, 64])
blocks.0.attn.hook_k torch.Size([1, 10, 16, 64])
blocks.0.attn.hook_v torch.Size([1, 10, 16, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 16, 10, 10])
blocks.0.attn.hook_pattern torch.Size([1, 16, 10, 10])
blocks.0.attn.hook_z torch.Size([1, 10, 16, 64])
blocks.0.hook_attn_out torch.Size([1, 10, 1024])
blocks.0.hook_resid_mid torch.Size([1, 10, 1024])
blocks.0.ln2.hook_scale torch.Size([1, 10, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 10, 1024])
blocks.0.mlp.hook_pre torch.Size([1, 10,