# Hook Meanings

Reference: [Recall and Regurgitation in GPT2](https://www.lesswrong.com/posts/qxvihKpFMuc4tvuf4/recall-and-regurgitation-in-gpt2) (LessWrong post).

In [1]:
import torch, transformer_lens, itertools, torchvision
from measureLM import visualizing, decoding, patching, scoring
from functools import partial
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  

In [2]:
# Load GPT-2 medium through TransformerLens, then pin it to the CPU.
model = transformer_lens.HookedTransformer.from_pretrained("gpt2-medium")
model = model.to("cpu")
# The tokenizer ships without a pad token (see the warning printed below),
# so the end-of-sequence token is reused for padding.
model.tokenizer.pad_token = model.tokenizer.eos_token
# "Ġ" is the prefix visible on the decoded tokens printed further down;
# presumably the measureLM helpers read `cfg.spacing` - TODO confirm.
model.cfg.spacing = "Ġ"

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-medium into HookedTransformer
Moving model to device:  cpu


In [3]:
## encoding
token_candidates = ["Paris", "France", "Poland", "Warsaw"]
# Each entry is a pair of prompt variants for the same question:
# `example` selects the entry and `i` selects the variant within the pair.
prompts = [("Q: What is the capital of France? A: Paris Q: What is the capital of Poland? A:",
            'Q: What is the capital of Poland? Options: "A" Berlin, "B" Warsaw, "C" Paris A:'),
           ("Q: What is the relationship between Joe Biden and Donald Trump? A:",
            'Joe Biden and Donald Trump are enemies. Q: What is the relationship between Joe Biden and Donald Trump? A:'),
           # Name spelled "Barack" (was "Barrack"): a misspelled entity changes
           # the tokenisation and so distorts a factual-recall prompt.
           ("Q: When was Barack Obama born? A:",
            'Trump was born in 1946. Barack Obama was born in 1961. Q: When was Barack Obama born? A:')]

example = 0
i = 0
# `activs` is the mapping of hook name -> activation inspected in the cells below.
logits, activs, tokens = decoding.encode(prompts[example][i], model)
# Ten highest-scoring next tokens at the final position
# (assumes logits is (batch, position, vocab) - TODO confirm against decoding.encode).
pred = model.tokenizer.convert_ids_to_tokens(torch.topk(logits[:,-1,:], k=10).indices.tolist()[0])
# Also print the number of tokens in the encoded prompt.
print(pred, len(tokens[0]))

['ĠWarsaw', 'ĠPoland', 'ĠW', 'ĠParis', 'ĠK', 'ĠPrague', 'ĠPo', 'ĠP', 'ĠB', 'ĠL'] 23


In [66]:
def token_select(layer_token_vec, tok_type="last"):
    """Collapse the last dimension of ``layer_token_vec``.

    Args:
        layer_token_vec: tensor whose last dimension is reduced
            (presumably the token axis - confirm against callers).
        tok_type: ``"last"`` keeps index -1 of the last dimension,
            ``"mean"`` averages over it.

    Returns:
        The reduced tensor with every size-1 dimension squeezed away and
        detached from the autograd graph. For any other ``tok_type`` the
        input is returned untouched (neither squeezed nor detached).
    """
    if tok_type == "last":
        reduced = layer_token_vec[..., -1]
    elif tok_type == "mean":
        reduced = layer_token_vec.mean(-1)
    else:
        # Unknown selector: hand the input back unchanged.
        return layer_token_vec
    return reduced.squeeze().detach()


def layer_norm(layer_token_vec, eps=0.0, unbiased=True):
    """Standardise a tensor along its last dimension (LayerNorm without weights).

    Every vector along the last axis is shifted to zero mean and divided by
    its standard deviation. Works for inputs of any rank >= 1; the statistics
    are broadcast rather than tiled with a rank-specific ``repeat``.

    Args:
        layer_token_vec: floating-point tensor; statistics are taken over
            its last dimension.
        eps: value added to the variance before taking the square root.
            The default 0.0 keeps the original behaviour, so a constant
            vector still produces NaN/inf.
        unbiased: whether the standard deviation uses Bessel's correction.
            The default True matches ``torch.std``'s default. Passing
            ``unbiased=False`` together with a LayerNorm epsilon gives the
            same statistics as ``torch.nn.functional.layer_norm``.

    Returns:
        Tuple ``(normed, mean, std)``. All three have the shape of the
        input; ``mean`` and ``std`` hold the per-vector statistic repeated
        along the last dimension.
    """
    mean = layer_token_vec.mean(dim=-1, keepdim=True)
    std = torch.std(layer_token_vec, dim=-1, unbiased=unbiased, keepdim=True)
    if eps:
        # Fold epsilon into the variance, as LayerNorm does.
        std = torch.sqrt(std.pow(2) + eps)

    layer_token_vec_normed = (layer_token_vec - mean) / std

    # Return full-shape, independent tensors so callers can keep indexing
    # the statistics like the input (e.g. std[0, :, 0]).
    mean_full = mean.expand_as(layer_token_vec).contiguous()
    std_full = std.expand_as(layer_token_vec).contiguous()
    return layer_token_vec_normed, mean_full, std_full

In [22]:
# First 19 names yielded by iterating `activs`: the two embedding hooks
# followed by the 17 hooks of block 0 (see the listing below).
list(activs)[:19]

['hook_embed',
 'hook_pos_embed',
 'blocks.0.hook_resid_pre',
 'blocks.0.ln1.hook_scale',
 'blocks.0.ln1.hook_normalized',
 'blocks.0.attn.hook_q',
 'blocks.0.attn.hook_k',
 'blocks.0.attn.hook_v',
 'blocks.0.attn.hook_attn_scores',
 'blocks.0.attn.hook_pattern',
 'blocks.0.attn.hook_z',
 'blocks.0.hook_attn_out',
 'blocks.0.hook_resid_mid',
 'blocks.0.ln2.hook_scale',
 'blocks.0.ln2.hook_normalized',
 'blocks.0.mlp.hook_pre',
 'blocks.0.mlp.hook_post',
 'blocks.0.hook_mlp_out',
 'blocks.0.hook_resid_post']

## LayerNorm

In [29]:

# Activation stored under block 0's `ln1.hook_normalized`; the next cell
# recomputes it by hand from `hook_resid_pre` for comparison.
#activs["blocks.0.ln2.hook_scale"]
activs["blocks.0.ln1.hook_normalized"]

tensor([[[-0.5929, -0.5020, -0.4043,  ...,  0.4311, -1.1121, -0.0822],
         [ 0.4339, -0.3292, -0.2374,  ...,  0.2788, -0.4852,  0.0150],
         [-0.9933, -0.0774,  0.2924,  ..., -0.3123,  0.2428, -0.1188],
         ...,
         [-0.0603, -0.6091, -0.0616,  ..., -0.3783, -0.1042,  0.0017],
         [-0.9920, -0.7807, -0.9680,  ...,  0.6503,  0.2500, -0.3483],
         [ 0.1450, -0.2279, -0.3149,  ...,  0.1350, -0.6963,  0.0270]]])

In [30]:
# Manual standardisation of block 0's `hook_resid_pre`. The printed values
# agree with `ln1.hook_normalized` above to about three decimals.
# NOTE(review): the remaining gap is presumably because `layer_norm` uses
# torch.std's default Bessel-corrected estimate with no epsilon - confirm.
norm, mean, std = layer_norm(activs["blocks.0.hook_resid_pre"]) #std[0,:,0]
norm

tensor([[[-0.5927, -0.5018, -0.4041,  ...,  0.4309, -1.1117, -0.0822],
         [ 0.4339, -0.3292, -0.2374,  ...,  0.2787, -0.4852,  0.0150],
         [-0.9932, -0.0774,  0.2924,  ..., -0.3123,  0.2428, -0.1188],
         ...,
         [-0.0603, -0.6092, -0.0617,  ..., -0.3784, -0.1042,  0.0017],
         [-0.9922, -0.7809, -0.9682,  ...,  0.6505,  0.2501, -0.3484],
         [ 0.1451, -0.2280, -0.3150,  ...,  0.1350, -0.6965,  0.0270]]])

## MLP Out

In [51]:
# Add block 10's `hook_mlp_out` to its `hook_resid_mid`; the printed values
# equal `hook_resid_post` shown in the next cell.
activs["blocks.10.hook_resid_mid"] + activs["blocks.10.hook_mlp_out"]

tensor([[[ 5.9129,  3.8089,  5.7358,  ...,  4.1576,  4.2645,  3.3083],
         [ 5.1118, -1.3723,  3.2000,  ...,  2.2238,  1.6303,  0.7182],
         [ 2.0944, -0.0644,  0.6919,  ..., -2.8870,  1.4337,  5.2877],
         ...,
         [-2.4473, -4.1174, -0.6573,  ..., -0.9985,  1.6291, -1.4805],
         [-0.7714,  1.9903,  0.0090,  ..., -3.2068, -1.8586,  1.2165],
         [-1.8430,  0.8577, -2.4709,  ...,  0.3093, -1.3889,  0.9787]]])

In [52]:
# Block 10's `hook_resid_post`; the printed values match
# `hook_resid_mid + hook_mlp_out` from the previous cell.
activs["blocks.10.hook_resid_post"]

tensor([[[ 5.9129,  3.8089,  5.7358,  ...,  4.1576,  4.2645,  3.3083],
         [ 5.1118, -1.3723,  3.2000,  ...,  2.2238,  1.6303,  0.7182],
         [ 2.0944, -0.0644,  0.6919,  ..., -2.8870,  1.4337,  5.2877],
         ...,
         [-2.4473, -4.1174, -0.6573,  ..., -0.9985,  1.6291, -1.4805],
         [-0.7714,  1.9903,  0.0090,  ..., -3.2068, -1.8586,  1.2165],
         [-1.8430,  0.8577, -2.4709,  ...,  0.3093, -1.3889,  0.9787]]])

## Attention Out

In [55]:
# Add block 0's `hook_attn_out` to its `hook_resid_pre`; the printed values
# equal `hook_resid_mid` shown in the next cell.
activs["blocks.0.hook_resid_pre"] + activs["blocks.0.hook_attn_out"]

tensor([[[-0.0665, -0.1217, -0.0888,  ...,  0.1090, -0.2056, -0.0248],
         [ 0.1138, -0.0925,  0.0099,  ...,  0.0375, -0.0260, -0.0638],
         [-0.1199, -0.0107,  0.0446,  ..., -0.0299,  0.0280, -0.0639],
         ...,
         [-0.0193, -0.0516,  0.0047,  ..., -0.0421,  0.0252, -0.0206],
         [-0.0927, -0.0927, -0.0923,  ...,  0.0375,  0.0523, -0.0443],
         [ 0.0240, -0.0615,  0.0131,  ..., -0.0141, -0.0076, -0.0423]]])

In [56]:
# Block 0's `hook_resid_mid`; the printed values match
# `hook_resid_pre + hook_attn_out` from the previous cell.
activs["blocks.0.hook_resid_mid"]

tensor([[[-0.0665, -0.1217, -0.0888,  ...,  0.1090, -0.2056, -0.0248],
         [ 0.1138, -0.0925,  0.0099,  ...,  0.0375, -0.0260, -0.0638],
         [-0.1199, -0.0107,  0.0446,  ..., -0.0299,  0.0280, -0.0639],
         ...,
         [-0.0193, -0.0516,  0.0047,  ..., -0.0421,  0.0252, -0.0206],
         [-0.0927, -0.0927, -0.0923,  ...,  0.0375,  0.0523, -0.0443],
         [ 0.0240, -0.0615,  0.0131,  ..., -0.0141, -0.0076, -0.0423]]])

## Norms

In [65]:
# L2 norm along the last dimension of `blocks.20.hook_resid_pre`,
# summed over all remaining entries into a single scalar.
torch.linalg.norm(activs["blocks.20.hook_resid_pre"],dim=-1).sum()

tensor(12182.9014)

In [64]:
# Same summed L2 norm for `blocks.20.hook_attn_out`; the printed total is
# roughly 5% of the `hook_resid_pre` total (599.49 vs 12182.90).
torch.linalg.norm(activs["blocks.20.hook_attn_out"],dim=-1).sum()

tensor(599.4905)

In [45]:
# Summed L2 norm of block 2's `ln1.hook_normalized`. The printed ~736 is
# about 32 per position for the 23 tokens reported by the encoding cell
# (32 = sqrt(1024); presumably d_model is 1024 - confirm).
torch.linalg.norm(activs["blocks.2.ln1.hook_normalized"],dim=-1).sum()

tensor(735.9998)