In [39]:
import pickle
from itertools import product
from functools import partial
from tqdm import tqdm

import plotly.express as px

import numpy as np
import torch
from torch import Tensor

import einops

from typing import Literal
from jaxtyping import Float

from transformer_lens import HookedTransformer, ActivationCache
from transformer_lens import utils

from ioi_dataset import IOIDataset, format_prompt, make_table


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

The idea is to remove the components (i.e. attention heads and MLPs) that do not contribute to a specific task of study, i.e. those that do not belong to the underlying circuit. To do this, we can **patch** the activations of a component with activations from a corrupted run. If the performance on the task (as measured by a defined metric) degrades, the component is relevant to the task. Therefore, we can remove the components that lie outside the circuit and maintain the task-specific performance while drastically reducing the number of components! In summary, if $A$ represents the activations of a given component and we use the logit difference to assess the performance on the task (higher is better), we compute the following score:

$$\Delta \text{logit\_diff} = \text{logit\_diff}(x_{clean} | \text{do} (A = a_{corr})) - \text{logit\_diff}(x_{clean})$$

If $\Delta \text{logit\_diff} < 0$, patching that component decreases the performance on the task of study. Hence, we can keep the $k$ components with the lowest score, i.e. the most relevant ones for the task, and effectively remove those that are not. This is a way to compress a general LLM into a smaller model for a specific task while keeping the mechanistic interpretation of the circuit. In other words, it differs from other compression methods because it preserves the original components.


Now, when it comes to corrupting the activations, we have two methods:
- **Zero ablation:** Replacing the activations with zeros (i.e. $a_{corr} = \mathbf{0}$). This is inspired by neuroscience, but it can push the model off-distribution and may flag too many components as important. By contrast, other techniques such as resampling corruption are useful for determining the exact underlying circuit (e.g. the IOI circuit) but do not identify sub-circuits or components that are auxiliary to the task. For example, a head that outputs "this sentence might be suited for the IOI task" might be important, but corrupting by resampling will not detect it. Hence, when ablating the rest of the heads based on resampling scores, we would leave out some important components, and I expect the compressed model to fail.
- Replace with the **mean activation**: Another option is to set the activation to its mean (i.e. $a_{corr} = \frac{1}{n} \sum a$). I expect this to be more principled, as the ablation is less drastic (in addition, if we use the Taylor approximation, I expect this to be beneficial). The only thing we have to take into account when performing this ablation is that, in the compression step, we will have to add a sort of "bias" to the residual stream to account for the mean contribution of the removed heads.

# TO DO
- Perform ACDC at node level with zero ablation, sort the components, then remove components from least to most important and check the drop in performance.
    - How many components are we able to remove while keeping a reasonable performance?
- Perform ACDC with mean ablation, compare with the previous experiment.
- Perform the previous two experiments, but with the attribution patching approximation. I expect this to perform worse when zero ablating.

# Model, dataset & metric setup

In [2]:
# Load GPT-2 small with every weight-processing option (LayerNorm folding,
# centering of the writing weights and of the unembedding) switched off.
model = HookedTransformer.from_pretrained(
    'gpt2-small',
    device=device,
    fold_ln=False,
    center_unembed=False,
    center_writing_weights=False,
)

# Turn on the optional hook points. The ablation cells below hook
# `attn.hook_result`, which presumably requires `set_use_attn_result(True)`
# (TODO confirm against the TransformerLens docs).
model.set_use_attn_result(True)
model.set_use_split_qkv_input(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
# Number of prompts; also the size of the per-prompt axis of the score tensors below.
N = 25

# Clean IOI prompts.
clean_dataset = IOIDataset(
    N=N,
    prompt_type='mixed',
    tokenizer=model.tokenizer,
    prepend_bos=False,
    device=device,
    seed=1,
)
# Corrupted counterpart derived from the clean prompts (shown as the "ABC prompt"
# column in the table below).
corr_dataset = clean_dataset.gen_flipped_prompts('ABC->XYZ, BAB->XYZ')

# Side-by-side view of the clean prompts, their subject / indirect-object names
# and the corrupted prompts.
make_table(
    title="Sentences from IOI vs ABC distribution",
    colnames=["IOI prompt", "IOI subj", "IOI indirect obj", "ABC prompt"],
    cols=[
        map(format_prompt, clean_dataset.sentences),
        model.to_string(clean_dataset.s_tokenIDs).split(),
        model.to_string(clean_dataset.io_tokenIDs).split(),
        map(format_prompt, corr_dataset.sentences),
    ],
)

In [4]:
def ave_logit_diff(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    ioi_dataset: IOIDataset,
    per_prompt: bool = False
):
    '''
    Logit difference between the correct answer (the indirect object) and the
    incorrect answer (the subject), read at each prompt's `end` position.

    Args:
        logits: model output logits for the prompts in `ioi_dataset`.
        ioi_dataset: supplies, per prompt, the `end` position
            (`word_idx['end']`) and the token ids of the indirect object
            (`io_tokenIDs`) and of the subject (`s_tokenIDs`).
        per_prompt: if True, return one difference per prompt; otherwise
            return their mean.

    Returns:
        The per-prompt differences, or their mean when `per_prompt` is False.
    '''
    batch_idx = range(logits.size(0))
    end_pos = ioi_dataset.word_idx['end']

    # Logit of the indirect-object token at the end position of every prompt.
    io_logits = logits[batch_idx, end_pos, ioi_dataset.io_tokenIDs]
    # Logit of the subject token at the same position.
    s_logits = logits[batch_idx, end_pos, ioi_dataset.s_tokenIDs]

    diffs = io_logits - s_logits
    if per_prompt:
        return diffs
    return diffs.mean()

# Baseline metric on both datasets, without any intervention on the model.
with torch.no_grad():
    clean_logits = model(clean_dataset.toks)
    clean_logit_diff = ave_logit_diff(clean_logits, clean_dataset).item()

    corrupt_logits = model(corr_dataset.toks)
    corrupt_logit_diff = ave_logit_diff(corrupt_logits, corr_dataset).item()

print(
    f'Avg. logit diff. on clean dataset: {clean_logit_diff:.4f}, '
    f'Avg. logit diff. on corrupt dataset: {corrupt_logit_diff:.4f}'
)

Avg. logit diff. on clean dataset: 2.8052, Avg. logit diff. on corrupt dataset: 1.3848


In [5]:
with torch.no_grad():
    # Re-run the clean prompts, this time also collecting the activation cache.
    clean_logits, clean_cache = model.run_with_cache(clean_dataset.toks)
    # NOTE: this rebinds `clean_logit_diff`. In the previous cell it was a Python
    # float (mean over prompts, via `.item()`); from here on it is the per-prompt
    # tensor that the attribution scores are computed against.
    clean_logit_diff = ave_logit_diff(clean_logits, clean_dataset, per_prompt=True)

In [6]:
def zero_ablate_head(act, hook, head):
    """
    Zero-ablate the output of a head. This will be executed on
    attn.hook_result, which has shape [batch seq n_heads d_model].

    `head` is bound with functools.partial when the hook is registered, so the
    hook does not depend on whatever the loop variable holds when it fires.
    """
    act[:, :, head, :] = 0.
    return act


# Per-prompt logit difference with exactly one head zero-ablated: [layer, head, prompt].
corrupted_logit_diffs = torch.zeros((model.cfg.n_layers, model.cfg.n_heads, N))
with torch.no_grad():
    for layer, head in tqdm(list(product(range(model.cfg.n_layers), range(model.cfg.n_heads)))):
        # Start from a clean model so that only the current head is ablated.
        model.reset_hooks(including_permanent=True)
        model.add_hook(f"blocks.{layer}.attn.hook_result", partial(zero_ablate_head, head=head), "fwd")
        corrupted_logits = model(clean_dataset.toks)
        corrupted_logit_diff = ave_logit_diff(corrupted_logits, clean_dataset, per_prompt=True)
        corrupted_logit_diffs[layer, head] = corrupted_logit_diff

# Do not leave the last head's hook attached: any later forward pass (e.g.
# re-running the clean-cache cell) would otherwise be silently ablated.
model.reset_hooks(including_permanent=True)

# Score = ablated - clean, averaged over prompts. Negative means ablating the
# head hurts the task; the most negative heads are the most important ones.
attribution_score = (corrupted_logit_diffs - clean_logit_diff.cpu()).mean(-1)

100%|██████████| 144/144 [00:13<00:00, 10.86it/s]


In [7]:
# Centre the diverging RdBu scale on zero so that colour encodes the sign of the
# score (negative = ablating the head lowers the logit difference).
px.imshow(
    attribution_score,
    title="Attribution score for attention heads (zero-ablation)",
    labels={"x": "Head", "y": "Layer"},
    width=500,
    height=500,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
)

In [14]:
def topk_of_Nd_tensor(tensor, k):
    '''
    Multi-dimensional counterpart of `tensor.topk(k).indices`.

    The tensor is flattened, the positions of its k largest entries are taken,
    and each flat position is converted back into a coordinate of the original
    shape. The result is a nested list of shape [k, tensor.ndim].

    Example: for a [layer, head] tensor of per-head values this returns a list
    of k `[layer, head]` pairs.
    '''
    flat_positions = torch.topk(tensor.flatten(), k).indices
    # One index array per dimension of `tensor`.
    per_dim_coords = np.unravel_index(utils.to_numpy(flat_positions), tensor.shape)
    # Transpose from [ndim, k] to [k, ndim] before converting to plain lists.
    return np.transpose(np.array(per_dim_coords)).tolist()

def sort_Nd_tensor(tensor, descending=False):
    '''
    Coordinates of every entry of `tensor`, ordered by the entry's value.

    The tensor is flattened and sorted (ascending unless `descending` is True);
    each flat position is then converted back into a coordinate of the original
    shape. The result is a nested list of shape [tensor.numel(), tensor.ndim].
    '''
    flat_order = torch.sort(tensor.flatten(), descending=descending).indices
    per_dim_coords = np.unravel_index(utils.to_numpy(flat_order), tensor.shape)
    return np.transpose(np.array(per_dim_coords)).tolist()

In [16]:
# [layer, head] pairs ordered from highest to lowest attribution score. Since
# score = ablated - clean logit diff, the heads whose ablation hurts the task
# least come first and the most important heads come last.
sorted_heads = sort_Nd_tensor(attribution_score, descending=True)

In [25]:
def zero_ablate_head(act, hook, head):
    """
    Zero-ablate the output of a head. This will be executed on
    attn.hook_result, which has shape [batch seq n_heads d_model].

    `head` must be bound with functools.partial when the hook is registered.
    """
    act[:, :, head, :] = 0.
    return act


# Mean / std (over prompts) of the logit difference after ablating the first
# i + 1 heads of `sorted_heads`.
ablated_logit_diffs = torch.zeros(len(sorted_heads))
ablated_logit_diffs_std = torch.zeros(len(sorted_heads))


model.reset_hooks(including_permanent=True)

with torch.no_grad():
    for i, (layer, head) in tqdm(list(enumerate(sorted_heads))):
        # Hooks are deliberately NOT reset inside the loop: they accumulate, so
        # step i runs with heads 0..i all ablated.
        # The head index is frozen with `partial`. A closure over the loop
        # variable would be looked up when the hook fires, so every previously
        # registered hook would zero the *current* head instead of its own.
        model.add_hook(f"blocks.{layer}.attn.hook_result", partial(zero_ablate_head, head=head), "fwd")
        ablated_logits = model(clean_dataset.toks)
        ablated_logit_diff = ave_logit_diff(ablated_logits, clean_dataset, per_prompt=True)
        ablated_logit_diffs[i] = ablated_logit_diff.mean()
        ablated_logit_diffs_std[i] = ablated_logit_diff.std()

# Leave the model clean for the following cells.
model.reset_hooks(including_permanent=True)

  0%|          | 0/144 [00:00<?, ?it/s]

100%|██████████| 144/144 [00:13<00:00, 10.84it/s]


In [26]:
# Error bars show the standard deviation of the logit difference over prompts.
px.line(
    ablated_logit_diffs,
    error_y=ablated_logit_diffs_std,
    title="Logit difference under cumulative zero-ablation (one-shot ranking)",
    labels={"index": "Ablation step (heads ablated - 1)", "value": "Avg. logit diff."},
)

We didn't obtain the result that we expected. This may be because (i) removing a head changes the impact that the other heads have, or (ii) zero-ablation is too drastic. For now, let's try the following:

1. Compute the attribution score and ablate the head with the highest value (i.e. the least relevant one).
2. Re-compute the attribution score with the ablated head.
3. Repeat this 144 times.

In [None]:
def zero_ablate_head(act, hook, head):
    """
    Zero-ablate the output of a head. This will be executed on
    attn.hook_result, which has shape [batch seq n_heads d_model].

    `head` must be bound with functools.partial when the hook is registered.
    """
    act[:, :, head, :] = 0.
    return act


# Candidate heads still in the model, and the greedy removal order found so far.
head_list = list(product(range(model.cfg.n_layers), range(model.cfg.n_heads)))
ablated_heads = []

with torch.no_grad():
    while head_list:
        # Re-score every remaining head with all previously removed heads ablated.
        attr_scores = torch.zeros(len(head_list))
        # leave=False: one transient bar per round instead of a stacked bar for
        # every one of the rounds.
        for i, (layer, head) in tqdm(list(enumerate(head_list)), leave=False):
            model.reset_hooks(including_permanent=True)

            # Each already-removed head must be ablated in ITS OWN layer, not in
            # the layer of the candidate currently being scored.
            for ablated_layer, ablated_head in ablated_heads:
                model.add_hook(f"blocks.{ablated_layer}.attn.hook_result", partial(zero_ablate_head, head=ablated_head), "fwd")

            model.add_hook(f"blocks.{layer}.attn.hook_result", partial(zero_ablate_head, head=head), "fwd")
            corrupted_logits = model(clean_dataset.toks)
            corrupted_logit_diff = ave_logit_diff(corrupted_logits, clean_dataset, per_prompt=True)
            attr_scores[i] = (corrupted_logit_diff - clean_logit_diff).mean()
        # Highest score = smallest drop in logit diff = least relevant remaining head.
        head_to_ablate = head_list[torch.argmax(attr_scores).item()]
        print(head_to_ablate)
        ablated_heads.append(head_to_ablate)
        head_list.remove(head_to_ablate)

# Leave the model clean for the following cells.
model.reset_hooks(including_permanent=True)

In [44]:
def zero_ablate_head(act, hook, head):
    """
    Zero-ablate the output of a head. This will be executed on
    attn.hook_result, which has shape [batch seq n_heads d_model].

    `head` must be bound with functools.partial when the hook is registered.
    """
    act[:, :, head, :] = 0.
    return act


# Size the buffers from the list that is actually iterated (`ablated_heads`),
# so a partially completed greedy search cannot cause a length mismatch.
ablated_logit_diffs = torch.zeros(len(ablated_heads))
ablated_logit_diffs_std = torch.zeros(len(ablated_heads))


model.reset_hooks(including_permanent=True)

with torch.no_grad():
    for i, (layer, head) in tqdm(list(enumerate(ablated_heads))):
        # Hooks are deliberately NOT reset inside the loop: they accumulate, so
        # step i runs with the first i + 1 heads of the greedy order ablated.
        # The head index is frozen with `partial`. A closure over the loop
        # variable would be looked up when the hook fires, so every previously
        # registered hook would zero the *current* head instead of its own.
        model.add_hook(f"blocks.{layer}.attn.hook_result", partial(zero_ablate_head, head=head), "fwd")
        ablated_logits = model(clean_dataset.toks)
        ablated_logit_diff = ave_logit_diff(ablated_logits, clean_dataset, per_prompt=True)
        ablated_logit_diffs[i] = ablated_logit_diff.mean()
        ablated_logit_diffs_std[i] = ablated_logit_diff.std()

# Leave the model clean for any later cells.
model.reset_hooks(including_permanent=True)

100%|██████████| 144/144 [00:13<00:00, 10.96it/s]


In [45]:
# Error bars show the standard deviation of the logit difference over prompts.
px.line(
    ablated_logit_diffs,
    error_y=ablated_logit_diffs_std,
    title="Logit difference under cumulative zero-ablation (greedy re-scored ranking)",
    labels={"index": "Ablation step (heads ablated - 1)", "value": "Avg. logit diff."},
)