In [28]:
import pickle
from itertools import product
from functools import partial
from tqdm import tqdm
import random

import plotly.express as px

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

import einops

from typing import Literal
from jaxtyping import Float

from transformers import AutoTokenizer, AutoModelForCausalLM

from transformer_lens import HookedTransformer, ActivationCache
from transformer_lens import utils


torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

In [29]:
def sort_Nd_tensor(tensor, descending=False):
    i = torch.sort(tensor.flatten(), descending=descending).indices
    return np.array(np.unravel_index(utils.to_numpy(i), tensor.shape)).T.tolist()

def compute_logit_diff(logits, answer_tokens, average=True):
    """
    Compute the logit difference between the correct answer and the largest logit
    of all the possible incorrect capital letters. This is done for every iteration
    (i.e. each of the three letters of the acronym) and then averaged if desired.
    If `average=False`, then a `Tensor[batch_size, 3]` is returned, containing the
    logit difference at every iteration for every prompt in the batch

    Parameters:
    -----------
    - `logits`: `Tensor[batch_size, seq_len, d_vocab]`
    - `answer_tokens`: Tensor[batch_size, 3]
    """
    # Logits of the correct answers (batch_size, 3)
    correct_logits = logits[:, -3:].gather(-1, answer_tokens[..., None]).squeeze()
    # Retrieve the maximum logit of the possible incorrect answers
    capital_letters_tokens = torch.tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
         50, 51, 52, 53, 54, 55, 56, 57], dtype=torch.long, device=device)
    batch_size = logits.shape[0]
    capital_letters_tokens_expanded = capital_letters_tokens.expand(batch_size, 3, -1)
    incorrect_capital_letters = capital_letters_tokens_expanded[capital_letters_tokens_expanded != answer_tokens[..., None]].reshape(batch_size, 3, -1)
    incorrect_logits, _ = logits[:, -3:].gather(-1, incorrect_capital_letters).max(-1)
    # Return the mean
    return (correct_logits - incorrect_logits).mean() if average else (correct_logits - incorrect_logits)

# Do we obtain the same activations from huggingface and transformerlens?

The idea of this experiment is to be able to obtain/patch the activations of every attention head. We need this to be able to work directly on a HF model by (i) performing activation patching to check which heads matter (ii) gather these mean activations and (iii) properly prune the huggingface model by completely removing heads/complete attention layers. 

- In TransformerLens, we can obtain the individual contributions of each head, i.e. what they write individually to the residual stream, via the `results` hook. This is computed by taking the activations at the `z` hook (i.e. multiplying value vector with the attention and softmaxing) and using the O matrix to project them into the residual space.
- In HF, the model is defined in `modeling_gpt2.py`. I can see that the `_attn` function is equivalent to what is computed on the `z` hook: the QKV vectors are computed, the QK vectors are used to obtain the attention patterns, which are then used to compute the output as a weighted sum of the V vector. Then, 

**Equivalence**:
- `z: (batch_size, seq_len, n_head, d_head)` $\rightarrow$ `_attn: `

In [30]:
model_tl = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
model_tl.set_use_hook_mlp_in(True)
model_tl.set_use_split_qkv_input(True)
model_tl.set_use_attn_result(True)

Loaded pretrained model gpt2-small into HookedTransformer


In [31]:
tokenizer_hf = AutoTokenizer.from_pretrained("openai-community/gpt2", add_bos_token=True)
model_hf = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", output_hidden_states=False, use_cache=False).to(device)

In [32]:
with open("acronyms_2_common.txt", "r") as f:
   prompts, acronyms = list(zip(*[line.split(", ") for line in f.read().splitlines()]))

# take a subset of the dataset (we do this because VRAM limitations)
n_samples = 50
# giga-cursed way of sampling from the dataset
prompts, acronyms = list(map(list, zip(*random.choices(list(zip(prompts, acronyms)), k=n_samples))))

In [33]:
tokens_tl = model_tl.to_tokens(prompts)
answer_tokens = model_tl.to_tokens(acronyms, prepend_bos=False)

logits_tl, cache_tl = model_tl.run_with_cache(tokens_tl)

In [34]:
tokens_hf = tokenizer_hf(prompts, return_tensors="pt")["input_ids"]
tokens_hf = torch.cat([torch.ones((tokens_hf.shape[0], 1), dtype=torch.long) * tokenizer_hf.bos_token_id, tokens_hf], dim=1).to(device)
logits_hf = model_hf(tokens_hf)["logits"]

In [35]:
logit_diff_tl = compute_logit_diff(logits_tl, answer_tokens, average=False)[..., -1].mean()
logit_diff_hf = compute_logit_diff(logits_hf, answer_tokens, average=False)[..., -1].mean()

print(f"Logit Diff. TL: {logit_diff_tl.item():.4f}, HF: {logit_diff_hf.item():.4f}, Allclose: {torch.allclose(logit_diff_tl, logit_diff_hf)}")

Logit Diff. TL: 3.6052, HF: 3.6052, Allclose: True


**Output of attention head (in head space)**

In [36]:
z_tl = cache_tl[utils.get_act_name("z", 0)]
z_tl.shape

torch.Size([50, 11, 12, 64])

In [37]:
z_hf = torch.zeros_like(z_tl)

def get_z(module: nn.Module, input, output):
    """
    module -> c_proj module inside `GPT2Attention`
    input -> tensor of shape (batch_size, seq_len, n_heads*d_head=d_model)
    """
    x = input[0] # (batch_size, seq_len, n_heads*d_head=d_model)
    z = x.view(x.shape[:2] + (model_hf.config.n_head, -1))  # (batch_size, seq_len, n_heads, d_head)
    z_hf[:] = z.detach().clone()[:]

In [38]:
hook = model_hf.transformer.h[0].attn.c_proj.register_forward_hook(get_z)
model_hf(tokens_hf);
hook.remove()

In [39]:
torch.allclose(z_tl, z_hf)

False

They are not close, probably it has something to do with layer normalization? Let's try performing the same patching experiments and checking if we get the same results. **NOTE:** Because of how GPT2 is implemented on HF, I think that it is better to patch the V vectors instead of the Z vectors. We should obtain the same results.

In [40]:
def mean_ablate_head(activations, hook, head_idx, cache):
    # activation has shape (batch, pos, head, d_head)
    activations[:, :, head_idx] = cache[hook.name][:, :, head_idx].mean(0)[None, ...]
    # activations[:, :, head_idx] = 0.
    return activations

In [41]:
corrupted_logit_diffs = torch.zeros((model_tl.cfg.n_layers, model_tl.cfg.n_heads, n_samples))
with torch.no_grad():
    for layer, head in tqdm(list(product(range(model_tl.cfg.n_layers), range(model_tl.cfg.n_heads)))):
        model_tl.reset_hooks(including_permanent=True)
        hook_fn = partial(mean_ablate_head, head_idx=head, cache=cache_tl)
        model_tl.add_hook(utils.get_act_name("v", layer), hook_fn)
        corrupted_logits = model_tl(tokens_tl)
        corrupted_logit_diff = compute_logit_diff(corrupted_logits, answer_tokens, average=False)
        corrupted_logit_diffs[layer, head] = corrupted_logit_diff[..., -1] # take last letter

attribution_score_tl = (corrupted_logit_diffs - logit_diff_tl.cpu()).mean(-1)

 14%|█▍        | 20/144 [00:00<00:01, 65.25it/s]

100%|██████████| 144/144 [00:02<00:00, 67.16it/s]


In [42]:
px.imshow(attribution_score_tl.detach().numpy(), title="Attribution score for attention heads (TransformerLens)", labels={"x": "Head", "y": "Layer"}, width=500, height=500, color_continuous_midpoint=0.0, color_continuous_scale="RdBu")

In [43]:
cache_v = torch.zeros((model_hf.config.n_layer, model_hf.config.n_head, 11, 64))

In [44]:
def get_cache_v(module: nn.Module, input, output, layer_idx: int, head_idx: int):
    """
    Performs mean ablation on head_idx
    This function is expected as a hook for the c_attn method of GPT2Attention.

    module -> c_proj module inside `GPT2Attention`
    output -> tensor of shape (batch_size, seq_len, 3*d_model)
    """
    d_head = int(model_hf.config.n_embd / model_hf.config.n_head)
    v_avg = output[:, :, 2*model_hf.config.n_embd + head_idx * d_head:2*model_hf.config.n_embd + (head_idx + 1) * d_head].mean(0)
    cache_v[layer_idx, head_idx] = v_avg.detach().clone()

In [45]:
with torch.no_grad():
    for layer, head in tqdm(list(product(range(model_hf.config.n_layer), range(model_hf.config.n_head)))):
        hook = model_hf.transformer.h[layer].attn.c_attn.register_forward_hook(partial(get_cache_v, layer_idx=layer, head_idx=head))
        corrupted_logits = model_hf(tokens_hf)["logits"]
        hook.remove()

100%|██████████| 144/144 [00:00<00:00, 193.74it/s]


In [46]:
def mean_ablate_head_hf(module: nn.Module, input, output, layer_idx: int, head_idx: int, cache_v):
    """
    Performs mean ablation on head_idx
    This function is expected as a hook for the c_attn method of GPT2Attention.

    module -> c_proj module inside `GPT2Attention`
    output -> tensor of shape (batch_size, seq_len, 3*d_model)
    """
    d_head = int(model_hf.config.n_embd / model_hf.config.n_head)
    output[:, :, 2*model_hf.config.n_embd + head_idx * d_head:2*model_hf.config.n_embd + (head_idx + 1) * d_head] = \
        cache_v[layer_idx, head_idx][None, ...]
    return output
    

In [47]:
corrupted_logit_diffs = torch.zeros((model_hf.config.n_layer, model_hf.config.n_head, n_samples))
with torch.no_grad():
    for layer, head in tqdm(list(product(range(model_hf.config.n_layer), range(model_hf.config.n_head)))):
        hook = model_hf.transformer.h[layer].attn.c_attn.register_forward_hook(partial(mean_ablate_head_hf, layer_idx=layer, head_idx=head, cache_v=cache_v))
        corrupted_logits = model_hf(tokens_hf)["logits"]
        hook.remove()
        corrupted_logit_diff = compute_logit_diff(corrupted_logits.cuda(), answer_tokens, average=False)
        corrupted_logit_diffs[layer, head] = corrupted_logit_diff[..., -1] # take last letter

attribution_score_hf = (corrupted_logit_diffs - logit_diff_hf.cpu()).mean(-1)

100%|██████████| 144/144 [00:00<00:00, 185.42it/s]


In [48]:
px.imshow(attribution_score_hf.detach().numpy(), title="Attribution score for attention heads (HuggingFace)", labels={"x": "Head", "y": "Layer"}, width=500, height=500, color_continuous_midpoint=0.0, color_continuous_scale="RdBu")

In [49]:
sorted_heads_tl = sort_Nd_tensor(attribution_score_tl)
sorted_heads_hf = sort_Nd_tensor(attribution_score_hf)

In [50]:
for x, y in zip(sorted_heads_tl, sorted_heads_hf):
    print(x, y)

[8, 11] [8, 11]
[10, 10] [10, 10]
[4, 7] [4, 7]
[11, 6] [11, 6]
[8, 3] [8, 3]
[8, 10] [8, 10]
[11, 11] [11, 11]
[4, 4] [4, 4]
[2, 9] [2, 9]
[1, 5] [1, 5]
[11, 4] [11, 4]
[0, 10] [0, 10]
[3, 10] [3, 10]
[5, 11] [5, 11]
[0, 2] [0, 2]
[0, 0] [0, 0]
[1, 4] [1, 4]
[5, 2] [5, 2]
[6, 6] [6, 6]
[5, 9] [5, 9]
[9, 3] [9, 3]
[6, 7] [6, 7]
[2, 4] [2, 4]
[7, 9] [7, 9]
[6, 4] [6, 4]
[7, 5] [7, 5]
[5, 8] [5, 8]
[1, 3] [1, 3]
[11, 8] [11, 8]
[2, 7] [2, 7]
[9, 8] [9, 8]
[4, 5] [4, 5]
[8, 6] [8, 6]
[9, 2] [9, 2]
[11, 3] [11, 3]
[1, 7] [1, 7]
[8, 9] [8, 9]
[7, 8] [7, 8]
[0, 8] [0, 8]
[10, 3] [10, 3]
[2, 2] [2, 2]
[10, 4] [10, 4]
[1, 2] [1, 2]
[9, 6] [9, 6]
[3, 8] [3, 8]
[7, 3] [7, 3]
[7, 6] [7, 6]
[0, 4] [0, 4]
[11, 9] [11, 9]
[7, 11] [7, 11]
[7, 10] [7, 10]
[1, 8] [1, 8]
[4, 6] [4, 6]
[5, 0] [5, 0]
[1, 9] [1, 9]
[3, 11] [3, 11]
[6, 0] [6, 0]
[8, 2] [8, 2]
[9, 10] [9, 10]
[8, 7] [8, 7]
[7, 1] [7, 1]
[1, 6] [1, 6]
[8, 8] [8, 8]
[2, 11] [2, 11]
[9, 7] [9, 7]
[6, 5] [6, 5]
[10, 5] [10, 5]
[6, 11] [6, 11]
[4

### Replicate the pruning process

Now, we're going to progressively remove attention heads and measure the performance. If we get similar results with both models, repeat the experiment by removing the weights.

In [51]:
heads_to_patch = []
logit_diffs = []
std_logit_diffs = []

for circuit_head in reversed(sorted_heads_tl):
    heads_to_patch.append(circuit_head)
    model_tl.reset_hooks(including_permanent=True)
    for layer_i, head_i in heads_to_patch:
        hook_fn = partial(mean_ablate_head, head_idx=head_i, cache=cache_tl)
        model_tl.add_hook(utils.get_act_name("v", layer_i), hook_fn)
    circuit_logits = model_tl(tokens_tl)
    model_tl.reset_hooks(including_permanent=True)

    logit_diff = compute_logit_diff(circuit_logits, answer_tokens, average=False)
    av_logit_diff = logit_diff[..., -1].mean(0)
    std_logit_diff = logit_diff[..., -1].std(0)
    logit_diffs.append(av_logit_diff)
    std_logit_diffs.append(std_logit_diff)
logit_diffs_tl = torch.stack(logit_diffs, dim=0)
std_logit_diffs_tl = torch.stack(std_logit_diffs, dim=0)

In [52]:
labels = [f"{layer}.{head}" for layer, head in reversed(sorted_heads_tl)]
fig = px.line(x = labels, y=logit_diffs_tl.cpu().numpy(), error_y=std_logit_diffs_tl.cpu().numpy())
fig.add_hline(y=logit_diff_tl.item(), line_width=1.5, line_dash="dash", line_color="black")
fig.show()

In [53]:
heads_to_patch = []
logit_diffs = []
std_logit_diffs = []

hooks = []
    
for circuit_head in reversed(sorted_heads_hf):
    heads_to_patch.append(circuit_head)
    for layer_i, head_i in heads_to_patch:
        hooks.append(model_hf.transformer.h[layer_i].attn.c_attn.register_forward_hook(partial(mean_ablate_head_hf, layer_idx=layer_i, head_idx=head_i, cache_v=cache_v)))
    circuit_logits = model_hf(tokens_hf)["logits"]
    [hook.remove() for hook in hooks]

    logit_diff = compute_logit_diff(circuit_logits, answer_tokens, average=False)
    av_logit_diff = logit_diff[..., -1].mean(0)
    std_logit_diff = logit_diff[..., -1].std(0)
    logit_diffs.append(av_logit_diff)
    std_logit_diffs.append(std_logit_diff)
logit_diffs_hf = torch.stack(logit_diffs, dim=0)
std_logit_diffs_hf = torch.stack(std_logit_diffs, dim=0)

In [54]:
labels = [f"{layer}.{head}" for layer, head in reversed(sorted_heads_hf)]
fig = px.line(x = labels, y=logit_diffs_hf.cpu().numpy(), error_y=std_logit_diffs_hf.cpu().numpy())
fig.add_hline(y=logit_diff_hf.item(), line_width=1.5, line_dash="dash", line_color="black")
fig.show()