Similar to what I did in `mlps.ipynb`, I'm going to ablate complete attention layers, both in TransformerLens (TL) and in HuggingFace (HF), checking that both give the same (or very similar) results. Since I'm struggling to ablate individual heads, I'll first experiment with complete attention layers.

In [52]:
import pickle
from itertools import product
from functools import partial
from tqdm import tqdm
import random

import plotly.express as px

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

import einops

from typing import Literal, Optional, Tuple, Union
from jaxtyping import Float

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention

from transformer_lens import HookedTransformer, ActivationCache
from transformer_lens import utils


# Inference-only notebook: switch autograd off globally so that no forward
# pass builds a computation graph.
torch.set_grad_enabled(False)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [25]:
def sort_Nd_tensor(tensor, descending=False):
    """
    Return the N-dimensional indices of `tensor`'s entries ordered by value.

    The tensor is flattened and sorted (ascending unless `descending=True`);
    every flat position is then converted back to a multi-dimensional index.
    The result is a list of `[i_0, ..., i_{N-1}]` index lists, one per element.
    """
    flat_order = tensor.flatten().sort(descending=descending).indices
    coordinates = np.unravel_index(utils.to_numpy(flat_order), tensor.shape)
    # `coordinates` holds one array per dimension; transpose to one row per element.
    return np.transpose(np.array(coordinates)).tolist()

def compute_logit_diff(logits, answer_tokens, average=True):
    """
    Compute the logit difference between the correct answer and the largest logit
    of all the possible incorrect capital letters. This is done for every iteration
    (i.e. each letter of the acronym) and then averaged if desired.
    If `average=False`, then a `Tensor[batch_size, n_letters]` is returned, containing
    the logit difference at every iteration for every prompt in the batch.

    Parameters:
    -----------
    - `logits`: `Tensor[batch_size, seq_len, d_vocab]`
    - `answer_tokens`: `Tensor[batch_size, n_letters]` (3 for the acronym task);
      the last `n_letters` positions of `logits` are the ones scored
    - `average`: return the scalar mean over batch and letters when True

    Notes:
    ------
    - All work happens on `logits.device`; no module-level `device` is needed.
    - If an answer token is not one of the 26 capital-letter tokens, every capital
      letter counts as an incorrect candidate for that position.
    """
    n_letters = answer_tokens.shape[-1]
    answer_tokens = answer_tokens.to(logits.device)
    final_logits = logits[:, -n_letters:]

    # Logits of the correct answers (batch_size, n_letters). squeeze(-1) rather than
    # squeeze() so a batch of size 1 keeps its batch dimension.
    correct_logits = final_logits.gather(-1, answer_tokens[..., None]).squeeze(-1)

    # GPT-2 token ids 32..57, i.e. the single capital letters "A".."Z".
    capital_letters_tokens = torch.arange(32, 58, dtype=torch.long, device=logits.device)
    # Logits of every capital letter (batch_size, n_letters, 26).
    capital_logits = final_logits[..., capital_letters_tokens]
    # Hide the correct letter with -inf so that the max runs over incorrect letters only.
    is_answer = capital_letters_tokens == answer_tokens[..., None]
    incorrect_logits = capital_logits.masked_fill(is_answer, float("-inf")).max(-1).values

    logit_diff = correct_logits - incorrect_logits
    return logit_diff.mean() if average else logit_diff

In [26]:
# Load GPT-2 small into TransformerLens with the weight-processing options
# (writing-weight centering, unembed centering, LayerNorm folding) switched off.
# NOTE(review): presumably so the activations line up with the unprocessed
# HuggingFace model loaded in the next cell — confirm.
model_tl = HookedTransformer.from_pretrained(
    'gpt2-small',
    center_writing_weights=False,
    center_unembed=False,
    fold_ln=False,
    device=device,
)
# Turn on the optional hook-related flags (MLP input, split Q/K/V input,
# per-head attention result).
model_tl.set_use_hook_mlp_in(True)
model_tl.set_use_split_qkv_input(True)
model_tl.set_use_attn_result(True)

Loaded pretrained model gpt2-small into HookedTransformer


In [27]:
# HuggingFace tokenizer and GPT-2 model; the model is moved to `device`.
# NOTE(review): a BOS token is still prepended by hand when `tokens_hf` is built,
# so `add_bos_token=True` presumably has no effect on this tokenizer — confirm.
tokenizer_hf = AutoTokenizer.from_pretrained("openai-community/gpt2", add_bos_token=True)
model_hf = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", output_hidden_states=False, use_cache=False).to(device)
# Parameter count of the unmodified model, recorded before any layer is replaced.
initial_parameters = model_hf.num_parameters()

In [28]:
# Each non-empty line of the dataset file has the form "<prompt>, <acronym>".
with open("acronyms_2_common.txt", "r", encoding="utf-8") as f:
    # Split on the LAST ", " so that a prompt which itself contains ", " stays intact.
    dataset = [tuple(line.rsplit(", ", 1)) for line in f.read().splitlines() if line.strip()]

# Take a subset of the dataset (we do this because of VRAM limitations).
# Capped at the dataset size so that later tensors sized by `n_samples` stay consistent.
n_samples = min(250, len(dataset))
# Sample WITHOUT replacement: `random.choices` would draw with replacement and
# put duplicated prompts in the "subset".
# NOTE: the draw is unseeded, so every run evaluates a different subset.
prompts, acronyms = map(list, zip(*random.sample(dataset, k=n_samples)))

In [29]:
# Tokenize the prompts with TransformerLens (default BOS handling).
tokens_tl = model_tl.to_tokens(prompts)
# The acronyms are tokenized without a BOS token; `compute_logit_diff` expects
# one token per acronym letter, i.e. shape (batch_size, 3).
answer_tokens = model_tl.to_tokens(acronyms, prepend_bos=False)

# Clean forward pass. `cache_tl` holds the activations that the mean-ablation
# hooks read later on.
logits_tl, cache_tl = model_tl.run_with_cache(tokens_tl)

In [30]:
# Tokenize with the HuggingFace tokenizer, then manually put a BOS token in
# front of every prompt before moving the batch to `device`.
# NOTE(review): no padding is requested, so this presumably relies on every
# prompt tokenizing to the same length — confirm.
prompt_ids = tokenizer_hf(prompts, return_tensors="pt")["input_ids"]
bos_column = torch.full((prompt_ids.shape[0], 1), tokenizer_hf.bos_token_id, dtype=torch.long)
tokens_hf = torch.cat([bos_column, prompt_ids], dim=1).to(device)

# Clean (un-ablated) HuggingFace logits.
logits_hf = model_hf(tokens_hf)["logits"]

In [31]:
# Baseline performance: keep only the last letter's logit difference ([..., -1])
# and average it over the batch, once per library.
logit_diff_tl = compute_logit_diff(logits_tl, answer_tokens, average=False)[..., -1].mean()
logit_diff_hf = compute_logit_diff(logits_hf, answer_tokens, average=False)[..., -1].mean()

# Sanity check that TransformerLens and HuggingFace agree before ablating anything.
print(f"Logit Diff. TL: {logit_diff_tl.item():.4f}, HF: {logit_diff_hf.item():.4f}, Allclose: {torch.allclose(logit_diff_tl, logit_diff_hf)}")

Logit Diff. TL: 3.5070, HF: 3.5070, Allclose: True


In [32]:
# Inspect the shape of the cached attention output of layer 0.
cache_tl[utils.get_act_name("attn_out", 0)].shape

torch.Size([250, 11, 768])

In [33]:
def mean_ablate_attn(activations, hook, cache):
    """
    TransformerLens hook that mean-ablates an activation.

    Every batch element of `activations` (shape (batch, pos, d_model)) is
    overwritten in place with the batch mean of the activation stored in
    `cache` under the same hook name (`hook.name`). The modified
    `activations` tensor itself is returned.
    """
    batch_mean = cache[hook.name].mean(dim=0, keepdim=True)  # (1, pos, d_model)
    # In-place copy; the size-1 batch dimension broadcasts over the whole batch.
    activations.copy_(batch_mean)
    return activations

In [34]:
# Mean-ablate one attention layer at a time and record, for every prompt, the
# logit difference of the last acronym letter.
corrupted_logit_diffs = torch.zeros((model_tl.cfg.n_layers, n_samples))
# The same hook serves every layer: it looks the cached activation up by `hook.name`.
hook_fn = partial(mean_ablate_attn, cache=cache_tl)
# Start from a clean model in case an earlier cell left hooks attached.
model_tl.reset_hooks(including_permanent=True)
with torch.no_grad():
    for layer in tqdm(range(model_tl.cfg.n_layers)):
        # `run_with_hooks` attaches the hook for this forward pass only, so no
        # hook is left on the model once the loop finishes.
        corrupted_logits = model_tl.run_with_hooks(
            tokens_tl,
            fwd_hooks=[(utils.get_act_name("attn_out", layer), hook_fn)],
        )
        corrupted_logit_diff = compute_logit_diff(corrupted_logits, answer_tokens, average=False)
        corrupted_logit_diffs[layer] = corrupted_logit_diff[..., -1]  # take last letter

# Attribution score per layer: mean change w.r.t. the clean logit difference
# (negative = ablating the layer hurts performance).
attribution_score_tl = (corrupted_logit_diffs - logit_diff_tl.cpu()).mean(-1)

  0%|          | 0/12 [00:00<?, ?it/s]

100%|██████████| 12/12 [00:01<00:00, 11.13it/s]


In [35]:
# Heatmap with one cell per attention LAYER (the scores are per layer, not per
# head), using a diverging colour scale centred on 0.
px.imshow(
    attribution_score_tl.unsqueeze(0).detach().numpy(),
    title="Attribution score for attention layers (TransformerLens)",
    labels={"x": "Layer", "y": ""},
    width=800,
    height=300,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
)

The results make sense: attention layers 8 and 10 contain the most important heads, so mean-ablating them breaks performance. Ablating layer 1 also decreases performance, but only negligibly, whereas ablating any of the remaining layers even increases it. Let's replicate the experiment in HF.

In [36]:
# Display the attention sub-module of the first HuggingFace block to see what
# will be hooked.
model_hf.transformer.h[0].attn

GPT2Attention(
  (c_attn): Conv1D()
  (c_proj): Conv1D()
  (attn_dropout): Dropout(p=0.1, inplace=False)
  (resid_dropout): Dropout(p=0.1, inplace=False)
)

In [37]:
def get_cache_attn(module: GPT2Attention, input, output, layer_idx: int, cache: torch.Tensor):
    """
    Forward hook that stores the batch-mean output of one attention layer.

    Parameters:
    -----------
    - `module`, `input`: supplied by the forward-hook machinery; not used
    - `output`: output of the attention module; only `output[0]` is read and it
      is expected to be a tensor of shape (batch_size, seq_len, d_model)
    - `layer_idx`: index along the first dimension of `cache` to overwrite
    - `cache`: tensor indexed by layer; `cache[layer_idx]` receives the
      (seq_len, d_model) mean over the batch dimension

    Nothing is returned, so the module's output is left untouched.
    """
    attn_output = output[0]
    batch_mean = attn_output.mean(dim=0)
    cache[layer_idx] = batch_mean.detach().clone()

# One (seq_len, d_model) slot per layer, created on the same device as the
# model inputs (no hard-coded `.cuda()` or sequence length).
cache_attn = torch.zeros(
    (model_hf.config.n_layer, tokens_hf.shape[1], model_hf.config.n_embd),
    device=tokens_hf.device,
)

# The caching hooks only read the outputs, so all layers can be recorded during
# a single clean forward pass instead of one pass per layer.
cache_hooks = [
    model_hf.transformer.h[layer].attn.register_forward_hook(
        partial(get_cache_attn, layer_idx=layer, cache=cache_attn)
    )
    for layer in range(model_hf.config.n_layer)
]
try:
    model_hf(tokens_hf)
finally:
    # Always detach the hooks, even if the forward pass fails.
    for hook in cache_hooks:
        hook.remove()

In [39]:
def mean_ablate_attn_hf(module: GPT2Attention, input, output, layer_idx: int, cache: torch.Tensor):
    """
    Forward hook that performs mean ablation on an attention layer: the
    module's real output is discarded and the cached mean activation of that
    layer is returned in its place.

    Parameters:
    -----------
    - `module`, `input`, `output`: supplied by the forward-hook machinery; none
      of them is read. `output` is expected to be a tuple `(output[0], None)`
      with `output[0]` of shape (batch_size, seq_len, d_model)
    - `layer_idx`: which layer's mean activation to use
    - `cache`: tensor of shape (n_layers, seq_len, d_model)

    Returns the tuple `(replacement, None)`, where `replacement` is
    `cache[layer_idx]` with a leading dimension of size 1.
    """
    mean_activation = cache[layer_idx]
    return (mean_activation.unsqueeze(0), None)

In [40]:
# Mean-ablate one HuggingFace attention layer at a time and record, for every
# prompt, the logit difference of the last acronym letter.
corrupted_logit_diffs = torch.zeros((model_hf.config.n_layer, n_samples))
with torch.no_grad():
    for layer in tqdm(range(model_hf.config.n_layer)):
        hook_fn = partial(mean_ablate_attn_hf, layer_idx=layer, cache=cache_attn)
        hook = model_hf.transformer.h[layer].attn.register_forward_hook(hook_fn)
        try:
            corrupted_logits = model_hf(tokens_hf)["logits"]
        finally:
            # Detach the hook even if the forward pass fails, so the model is
            # never left silently ablated.
            hook.remove()
        # The logits already live on the model's device; forcing `.cuda()` here
        # would crash on a CPU-only machine.
        corrupted_logit_diff = compute_logit_diff(corrupted_logits, answer_tokens, average=False)
        corrupted_logit_diffs[layer] = corrupted_logit_diff[..., -1]  # take last letter

# Attribution score per layer: mean change w.r.t. the clean logit difference
# (negative = ablating the layer hurts performance).
attribution_score_hf = (corrupted_logit_diffs - logit_diff_hf.cpu()).mean(-1)

100%|██████████| 12/12 [00:00<00:00, 43.01it/s]


In [41]:
# Heatmap with one cell per attention LAYER (the scores are per layer, not per
# head), using a diverging colour scale centred on 0.
px.imshow(
    attribution_score_hf.unsqueeze(0).detach().numpy(),
    title="Attribution score for attention layers (HuggingFace)",
    labels={"x": "Layer", "y": ""},
    width=800,
    height=300,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
)

In [42]:
# Layer indices ordered by ascending attribution score, i.e. the layers whose
# ablation lowers the logit difference the most come first.
sorted_attns_tl = attribution_score_tl.sort().indices
sorted_attns_hf = attribution_score_hf.sort().indices

In [43]:
# Print the two rankings side by side (TL, HF) to compare the layer orderings.
for x, y in zip(sorted_attns_tl, sorted_attns_hf):
    print(x.item(), y.item())

8 8
10 10
1 1
7 7
4 4
2 2
6 6
3 3
11 11
0 0
9 9
5 5


# Replicate the pruning process

In [44]:
# Cumulative ablation (TransformerLens). Layers are visited from the highest
# attribution score to the lowest; at every step the newly visited layer is
# mean-ablated together with all previously visited ones.
attns_to_patch = []
logit_diffs = []
std_logit_diffs = []

ablation_hook = partial(mean_ablate_attn, cache=cache_tl)
for layer in reversed(sorted_attns_tl):
    attns_to_patch.append(layer)

    # Rebuild the full set of hooks from scratch, run, then clean up again.
    model_tl.reset_hooks(including_permanent=True)
    for ablated_layer in attns_to_patch:
        model_tl.add_hook(utils.get_act_name("attn_out", ablated_layer), ablation_hook)
    ablated_logits = model_tl(tokens_tl)
    model_tl.reset_hooks(including_permanent=True)

    # Mean and standard deviation over the batch of the last letter's logit difference.
    last_letter_diff = compute_logit_diff(ablated_logits, answer_tokens, average=False)[..., -1]
    logit_diffs.append(last_letter_diff.mean(0))
    std_logit_diffs.append(last_letter_diff.std(0))

logit_diffs_tl = torch.stack(logit_diffs, dim=0)
std_logit_diffs_tl = torch.stack(std_logit_diffs, dim=0)

In [45]:
# Logit difference (with batch standard deviation as error bars) after each
# cumulative ablation step; the x axis follows the ablation order.
labels = [f"Attn. {layer}" for layer in reversed(sorted_attns_tl)]
fig = px.line(x = labels, y=logit_diffs_tl.cpu().numpy(), error_y=std_logit_diffs_tl.cpu().numpy())
# Dashed reference line: logit difference of the un-ablated model.
fig.add_hline(y=logit_diff_tl.item(), line_width=1.5, line_dash="dash", line_color="black")
fig.show()

In [48]:
# Cumulative ablation (HuggingFace). Layers are visited from the highest
# attribution score to the lowest; at every step the newly visited layer is
# mean-ablated together with all previously visited ones.
attns_to_patch = []
logit_diffs = []
std_logit_diffs = []

hooks = []
try:
    for layer in reversed(sorted_attns_hf):
        attns_to_patch.append(layer)
        # Hooks of earlier layers are still attached, so only the new layer
        # needs one (no re-registration, no stale handles piling up).
        hook_fn = partial(mean_ablate_attn_hf, layer_idx=layer, cache=cache_attn)
        hooks.append(model_hf.transformer.h[layer].attn.register_forward_hook(hook_fn))
        ablated_logits = model_hf(tokens_hf)["logits"]

        # Mean and standard deviation over the batch of the last letter's logit difference.
        logit_diff = compute_logit_diff(ablated_logits, answer_tokens, average=False)
        av_logit_diff = logit_diff[..., -1].mean(0)
        std_logit_diff = logit_diff[..., -1].std(0)
        logit_diffs.append(av_logit_diff)
        std_logit_diffs.append(std_logit_diff)
finally:
    # Leave the model un-hooked whatever happens inside the loop.
    for hook in hooks:
        hook.remove()
logit_diffs_hf = torch.stack(logit_diffs, dim=0)
std_logit_diffs_hf = torch.stack(std_logit_diffs, dim=0)

In [50]:
# Logit difference (with batch standard deviation as error bars) after each
# cumulative ablation step; the x axis follows the ablation order.
labels = [f"Attn. {layer}" for layer in reversed(sorted_attns_hf)]
fig = px.line(x = labels, y=logit_diffs_hf.cpu().numpy(), error_y=std_logit_diffs_hf.cpu().numpy())
# Dashed reference line: logit difference of the un-ablated model.
fig.add_hline(y=logit_diff_hf.item(), line_width=1.5, line_dash="dash", line_color="black")
fig.show()

In [53]:
class PrunedAttention(nn.Module):
    """
    Parameter-free stand-in for an attention layer. It ignores its input and
    simply returns the vector `mean_activation`, which is expected to have
    shape (seq_len, d_model) and represents the mean activation of the
    attention layer that is going to be replaced.

    `mean_activation` is registered as a buffer, so it follows the module
    through `.to(...)` / `.cuda()` and is saved in the `state_dict`, while
    still not being counted as a parameter.
    """
    def __init__(self, mean_activation: torch.Tensor):
        super().__init__()
        self.register_buffer("mean_activation", mean_activation)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """
        Return `(mean_activation[None, ...], None)`.

        All arguments are accepted only for signature compatibility and are
        ignored; `**kwargs` absorbs any additional keyword arguments that the
        calling block may pass. The first element has a leading dimension of
        size 1, the second element is always `None`.
        """
        return (self.mean_activation[None, ...], None)

In [54]:
# Actual pruning: following the same order as the cumulative ablation, swap
# each attention module of the HuggingFace model for a `PrunedAttention` that
# returns the cached mean activation. The replacement is permanent, so
# `model_hf` stays modified after this cell.
logit_diffs = []
std_logit_diffs = []

num_parameters = []

for layer in reversed(sorted_attns_hf):
    block = model_hf.transformer.h[layer]
    block.attn = PrunedAttention(mean_activation=cache_attn[layer])
    # Parameter count of the model right after this replacement.
    num_parameters.append(model_hf.num_parameters())

    circuit_logits = model_hf(tokens_hf)["logits"]
    # Mean and standard deviation over the batch of the last letter's logit difference.
    last_letter_diff = compute_logit_diff(circuit_logits, answer_tokens, average=False)[..., -1]
    logit_diffs.append(last_letter_diff.mean(0))
    std_logit_diffs.append(last_letter_diff.std(0))

logit_diffs_hf = torch.stack(logit_diffs, dim=0)
std_logit_diffs_hf = torch.stack(std_logit_diffs, dim=0)

In [56]:
# Logit difference (with batch standard deviation as error bars) of the pruned
# model after each replacement; the x axis follows the replacement order.
labels = [f"Attn. {layer}" for layer in reversed(sorted_attns_hf)]
fig = px.line(x = labels, y=logit_diffs_hf.cpu().numpy(), error_y=std_logit_diffs_hf.cpu().numpy())
# Dashed reference line: logit difference of the unmodified model.
fig.add_hline(y=logit_diff_hf.item(), line_width=1.5, line_dash="dash", line_color="black")
fig.show()

In [57]:
# Fraction of the original parameter count that remains after each successive
# replacement. `labels` is reused from the previous plotting cell.
p_parameters = [x/initial_parameters for x in num_parameters]
fig = px.line(x = labels, y=p_parameters)
fig.show()