In [1]:
import random
import itertools
from functools import partial
from typing import Optional, Tuple, Any, Union
import pickle
from tqdm import tqdm

import torch
import torch.nn as nn
from torch import Tensor

from transformers import AutoTokenizer, AutoModelForCausalLM

import plotly.express as px

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def compute_logit_diff(logits, answer_tokens, average=True):
    """
    Compute the logit difference between the correct answer and the largest logit
    of all the possible incorrect capital letters. This is done for every iteration
    (i.e. each of the three letters of the acronym) and then averaged if desired.
    If `average=False`, then a `Tensor[batch_size, 3]` is returned, containing the
    logit difference at every iteration for every prompt in the batch

    Parameters:
    -----------
    - `logits`: `Tensor[batch_size, seq_len, d_vocab]`
    - `answer_tokens`: Tensor[batch_size, 3]
    """
    # Logits of the correct answers (batch_size, 3)
    correct_logits = logits[:, -3:].gather(-1, answer_tokens[..., None]).squeeze()
    # Retrieve the maximum logit of the possible incorrect answers
    capital_letters_tokens = torch.tensor([32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
         50, 51, 52, 53, 54, 55, 56, 57], dtype=torch.long)
    batch_size = logits.shape[0]
    capital_letters_tokens_expanded = capital_letters_tokens.expand(batch_size, 3, -1)
    incorrect_capital_letters = capital_letters_tokens_expanded[capital_letters_tokens_expanded != answer_tokens[..., None]].reshape(batch_size, 3, -1)
    incorrect_logits, _ = logits[:, -3:].gather(-1, incorrect_capital_letters).max(-1)
    # Return the mean
    return (correct_logits - incorrect_logits).mean() if average else (correct_logits - incorrect_logits)

#### Load dataset & model

In [3]:
with open("acronyms_2_common.txt", "r") as f:
   prompts, acronyms = list(zip(*[line.split(", ") for line in set(f.read().splitlines())]))

# take a subset of the dataset (we do this because VRAM limitations)
n_patching = 100
n_val = 100
# giga-cursed way of sampling from the dataset
patching_prompts, patching_acronyms = prompts[:n_patching], acronyms[:n_patching]
val_prompts, val_acronyms = prompts[n_patching:n_patching+n_val], acronyms[n_patching:n_patching+n_val]

In [16]:
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", add_bos_token=True)
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", output_hidden_states=False, use_cache=False)

The QKV vectors are computed for every attn head in the same layer via a 1D convolution.
If input has shape (batch_size, seq_len, d_model), Conv1D(input) has shape (batch_size, seq_len, 3*d_model). Then this is split to obtain the QKV vectors FOR EVERY HEAD (i.e. each Q,K,V has shape (batch_size, seq_len, d_model)). Then, they call `_split_heads` to divide the QKV vectors in terms of different heads -> (batch_size, n_head, seq_len, d_model / n_head). This is just a reshape, so we could modify the correct indices to patch a single head.

In [5]:
patching_tokens = tokenizer(patching_prompts, return_tensors="pt")["input_ids"]
patching_tokens = torch.cat([torch.ones((patching_tokens.shape[0], 1), dtype=torch.long) * tokenizer.bos_token_id, patching_tokens], dim=1)
val_tokens = tokenizer(val_prompts, return_tensors="pt")["input_ids"]
val_tokens = torch.cat([torch.ones((val_tokens.shape[0], 1), dtype=torch.long) * tokenizer.bos_token_id, val_tokens], dim=1)

patching_answer_tokens = tokenizer(patching_acronyms, return_tensors="pt")["input_ids"]
val_answer_tokens = tokenizer(val_acronyms, return_tensors="pt")["input_ids"]

In [18]:
patching_output = model(patching_tokens)
patching_logits = patching_output["logits"]
val_output = model(val_tokens)
val_logits = val_output["logits"]

patching_logit_diff = compute_logit_diff(patching_logits, patching_answer_tokens, average=False)[..., -1].mean()
val_logit_diff = compute_logit_diff(val_logits, val_answer_tokens, average=False)[..., -1].mean()
patching_logit_diff.item(), val_logit_diff.item()

(3.2218189239501953, 3.4839301109313965)

In [7]:
# retrieve the average head activation on the patching dataset
seq_len = patching_tokens.shape[-1]
d_head = int(model.config.n_embd / model.config.n_head)
average_head_acts = torch.zeros((model.config.n_layer, seq_len, model.config.n_head, d_head)) # (n_layers, seq_len, n_head, d_head)

def cache_head_acts(module, input, output, layer: int):
    d_model = int(output.shape[-1] / 3)
    v = output[:, :, 2*d_model:]
    new_shape = v.shape[:-1] + (12, int(d_model / 12)) # (batch_size, seq_len, n_head, d_head)
    v = v.view(new_shape) 
    average_head_acts[layer] = v.mean(0).detach().clone() # (seq_len, n_head, d_head)

hooks = []
for layer in range(model.config.n_layer):
    hook_fn = partial(cache_head_acts, layer=layer)
    hooks.append(model.transformer.h[layer].attn.c_attn.register_forward_hook(hook_fn))
model(patching_tokens)
[hook.remove() for hook in hooks];

In the cell above, we retrieve the V vectors of each head. However, we want to obtain what each head writes into the residual stream. To do so, we have to focus on the projection step. This step receives a vector of shape `(batch_size, seq_len, n_heads*d_heads = d_model)` where the output of each head is arranged sequentially, and then it is projected to the residual space. To obtain what head `head_i` writes into the residual stream, we just have to mask off every other head just before the projection, and then project.

In [19]:
# retrieve the average head activation on the patching dataset
seq_len = patching_tokens.shape[-1]
d_head = int(model.config.n_embd / model.config.n_head)
average_head_acts = torch.zeros((model.config.n_layer, model.config.n_head, seq_len, model.config.n_embd)) # (n_layers, n_head, seq_len, d_model)

d_head = int(model.config.n_embd / model.config.n_head)

def cache_head_acts(module, input, output, layer: int, head: int):
    
    x = input[0].detach().clone() # (batch_size, seq_len, n_heads*d_head = d_model)
    # Set to zero the activations that do not belong to `head`
    x = x.view(x.size()[:-1] + (model.config.n_head, d_head)) # (batch_size, seq_len, n_heads, d_head)
    mask = torch.ones_like(x)
    mask[:, :, head] = 0.
    x[mask.bool()] = 0.
    x = x.view(x.size()[:-2] + (model.config.n_head * d_head,)) # back to the original shape
    # Now project into the residual space
    # This is literally a copy-paste of the forward method of the module.
    # The thing is that if I call it again, it will the hook infinitely
    # This is ultra-cursed but it works for now
    size_out = x.size()[:-1] + (module.nf,)
    x = torch.addmm(module.bias, x.view(-1, x.size(-1)), module.weight)
    x = x.view(size_out)
    average_head_acts[layer, head] = x.mean(0).detach().clone() # (seq_len, n_head, d_model)
hooks = []
for layer, head in itertools.product(range(model.config.n_layer), range(model.config.n_head)):
    hook_fn = partial(cache_head_acts, layer=layer, head=head)
    hooks.append(model.transformer.h[layer].attn.c_proj.register_forward_hook(hook_fn))
model(patching_tokens)
[hook.remove() for hook in hooks];
torch.save(average_head_acts, "average_head_acts.pt")

In [28]:
def mean_ablate_residual(module, input, output):
    output = list(output)
    output[0] = output[0].mean(0)[None, ...]
    return tuple(output)

def mean_ablate_head(module, input, output, layer_idx:int, head_idx: int, cache: Optional[Tensor] = None):
    """
    The output has shape (batch_size, seq_len, 3*d_model) and contains
    the concatenated QKV vectors (for every head), i.e. v = output[:, :, 2*d_model:]

    If `cache=None`, replace with the mean of the current activation. If not, `cache` is 
    Tensor of shape (n_layers, seq_len, n_head, d_head) containing the mean activations
    computed on another dataset.
    """
    d_model = int(output.shape[-1] / 3)
    v = output[:, :, 2*d_model:]
    new_shape = v.shape[:-1] + (12, int(d_model / 12)) # (batch_size, seq_len, n_head, d_head)
    v = v.view(new_shape)
    if cache is None:
        v[:, :, head_idx] = v[:, :, head_idx].mean(0)[None, ...]
    else:
        v[:, :, head_idx] = cache[layer_idx, :, head_idx][None, ...]
    


# for i in range(12):
#     hooks.append(model.transformer.h[i].attn.register_forward_hook(mean_ablate_residual))

# heads_to_patch = [(8, 11), (9, 9), (10, 10)]

# for layer, head in heads_to_patch:
#     hook_fn = partial(mean_ablate_head, head_idx=head)
#     hooks.append(model.transformer.h[layer].attn.c_attn.register_forward_hook(hook_fn))

In [29]:
circuit_heads = [[], [8, 11], [9, 9], [10, 10], [11, 4], [5, 8], [4, 11], [2, 2], [1, 0]]

circuit_heads_i = []
logit_diffs = []
std_logit_diffs = []

for circuit_head in circuit_heads:
    circuit_heads_i.append(circuit_head)
    heads_to_patch = [[a, b] for a, b in itertools.product(range(0, 12), range(12)) if [a, b] not in circuit_heads_i]

    hooks = []
    for layer_i, head_i in heads_to_patch:
        hook_fn = partial(mean_ablate_head, layer_idx=layer_i, head_idx=head_i, cache=average_head_acts)
        hooks.append(model.transformer.h[layer_i].attn.c_attn.register_forward_hook(hook_fn))
    circuit_logits = model(val_tokens)["logits"]
    [hook.remove() for hook in hooks]

    logit_diff = compute_logit_diff(circuit_logits, val_answer_tokens, average=False)
    av_logit_diff = logit_diff[..., -1].mean(0)
    std_logit_diff = logit_diff[..., -1].std(0)
    logit_diffs.append(av_logit_diff)
    std_logit_diffs.append(std_logit_diff)
logit_diffs = torch.stack(logit_diffs, dim=0)
std_logit_diffs = torch.stack(std_logit_diffs, dim=0)

In [30]:
labels = ["None"] + [f"{layer}.{head}" for layer, head in circuit_heads[1:]]
fig = px.line(x = labels, y=logit_diffs.cpu().numpy(), error_y=std_logit_diffs.cpu().numpy(), width=800, height=400)
fig.add_hline(y=val_logit_diff.item(), line_width=1.5, line_dash="dash", line_color="black")
fig.show()

In [13]:
with open("ordered_heads.pkl", "rb") as handle:
    circuit_heads = pickle.load(handle)

In [12]:
heads_to_patch = []
logit_diffs = []
std_logit_diffs = []

for circuit_head in circuit_heads:
    heads_to_patch.append(circuit_head)
    hooks = []
    for layer_i, head_i in heads_to_patch:
        hook_fn = partial(mean_ablate_head, layer_idx=layer_i, head_idx=head_i, cache=average_head_acts)
        hooks.append(model.transformer.h[layer_i].attn.c_attn.register_forward_hook(hook_fn))
    circuit_logits = model(val_tokens)["logits"]
    [hook.remove() for hook in hooks]

    logit_diff = compute_logit_diff(circuit_logits, val_answer_tokens, average=False)
    av_logit_diff = logit_diff[..., -1].mean(0)
    std_logit_diff = logit_diff[..., -1].std(0)
    logit_diffs.append(av_logit_diff)
    std_logit_diffs.append(std_logit_diff)
logit_diffs = torch.stack(logit_diffs, dim=0)
std_logit_diffs = torch.stack(std_logit_diffs, dim=0)

In [13]:
labels = [f"{layer}.{head}" for layer, head in circuit_heads]
fig = px.line(x = labels, y=logit_diffs.cpu().numpy(), error_y=std_logit_diffs.cpu().numpy())
fig.add_hline(y=val_logit_diff.item(), line_width=1.5, line_dash="dash", line_color="black")
fig.show()

# Repeat the above experiment, but by using the prune function to iteratively remove attention heads

In [20]:
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", add_bos_token=True)
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", output_hidden_states=False, use_cache=False)

initial_parameters = model.num_parameters()

In [21]:
class AttentionIdentity(nn.Module):
    """
    Placeholder for the GPT2Attention layer, literally does nothing
    """
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        return (torch.zeros_like(hidden_states), None)

In [22]:
average_head_acts = torch.load("average_head_acts.pt")

def add_average_act(module, input, output, avg_act):
    return (output[0] + avg_act, None)

In [23]:
heads_to_patch = []
logit_diffs = []
std_logit_diffs = []

hooks = []

pbar = tqdm(circuit_heads)
for layer_i, head_i in pbar:
    # if the layer only has one head, remove it directly
    if model.transformer.h[layer_i].attn.num_heads <= 1:
        #print(f"Removing layer {layer_i}")
        model.transformer.h[layer_i].attn = AttentionIdentity()
    else:
        model.transformer.h[layer_i].attn.prune_heads([head_i])
    # after removing the head, include a hook that adds its mean activation

    hook_fn = partial(add_average_act, avg_act=average_head_acts[layer_i][head_i])
    hooks.append(model.transformer.h[layer_i].attn.register_forward_hook(hook_fn))
     
    circuit_logits = model(val_tokens)["logits"]
        
    logit_diff = compute_logit_diff(circuit_logits, val_answer_tokens, average=False)
    pbar.set_description(f"Avg. Logit Diff. = {logit_diff[..., -1].mean(0).item():.2f}")
    av_logit_diff = logit_diff[..., -1].mean(0)
    std_logit_diff = logit_diff[..., -1].std(0)
    logit_diffs.append(av_logit_diff)
    std_logit_diffs.append(std_logit_diff)
logit_diffs = torch.stack(logit_diffs, dim=0)
std_logit_diffs = torch.stack(std_logit_diffs, dim=0)

Avg. Logit Diff. = -0.84:   5%|▍         | 7/144 [00:08<02:46,  1.21s/it]


KeyboardInterrupt: 

In [138]:
labels = [f"{layer}.{head}" for layer, head in circuit_heads]
fig = px.line(x = labels, y=logit_diffs.cpu().numpy(), error_y=std_logit_diffs.cpu().numpy())
fig.add_hline(y=val_logit_diff.item(), line_width=1.5, line_dash="dash", line_color="black")
fig.show()

In [60]:
final_parameters = model.num_parameters()
p_reduction = (initial_parameters - final_parameters) / initial_parameters
print(f"{p_reduction:.2f}")

0.23
