In [21]:
import pickle
import random
import itertools
from itertools import product
from functools import partial
from tqdm import tqdm

import plotly.express as px
import matplotlib.pyplot as plt

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import einops

from typing import Literal
from jaxtyping import Float

from transformer_lens import HookedTransformer, ActivationCache
from transformer_lens import utils

from easy_transformer.ioi_dataset import IOIDataset
from easy_transformer.ioi_utils import logit_diff as ioi_logit_diff

from utils import get_data, compute_logit_diff_acronym

# Every cell below only runs forward passes, so disable autograd globally
# to avoid building computation graphs (saves memory and time).
torch.set_grad_enabled(False)

# NOTE(review): `device` is not referenced in any of the cells visible here;
# presumably `get_data` handles model/tensor placement — confirm or remove.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Automatically re-import edited local modules (e.g. `utils`) before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
n_patching = 100
n_val = 100
task = "greater-than"

data = get_data(n_patching=n_patching, n_val=n_val, task=task)

# Unpack the bundle returned by `get_data` into the notebook-level names used below.
model = data["model"]

# Patching split: its cache provides the mean activations used for mean-ablation.
patching_tokens, patching_answer_tokens, patching_logits, patching_cache = (
    data[key]
    for key in ("patching_tokens", "patching_answer_tokens", "patching_logits", "patching_cache")
)

# Validation split: the inputs on which circuits are scored.
val_tokens, val_answer_tokens, val_logits, val_cache = (
    data[key]
    for key in ("val_tokens", "val_answer_tokens", "val_logits", "val_cache")
)

gt_circuit = data["gt_circuit"]

Loaded pretrained model gpt2-small into HookedTransformer


In [24]:
def ablate_head(activations, hook, head_idx, scheme, new_cache=None):
    """
        Overwrite one attention head's activations in place and return them.

        Intended to be used as a TransformerLens hook function (bind `head_idx`,
        `scheme` and `new_cache` with `functools.partial`).

        Parameters
        ----------
        - `activations`: tensor produced at the hook point; axis 2 indexes heads.
        Usually `(batch, pos, head, d_head)`, e.g. at 'blocks.0.attn.hook_result'.
        - `hook`: hook point object; only its `.name` is read (to look up `new_cache`).
        - `head_idx`: index of the head to ablate.
        - `scheme`: "zero" writes zeros; "mean" writes the batch-mean of the
        same head's activations taken from `new_cache`.
        - `new_cache`: mapping from hook name to activations; required for "mean".
        Its position axis must broadcast against `activations` (same sequence length).

        Returns
        -------
        The (mutated) `activations` tensor.
    """
    assert scheme in ["mean", "zero"], "the ablation scheme should be either 'mean' or 'zero'"

    if scheme == "zero":
        activations[:, :, head_idx] = 0.
        return activations

    # Only "mean" remains at this point.
    assert new_cache is not None, "`new_cache` is required when mean ablating"
    cached_head = new_cache[hook.name][:, :, head_idx]
    # Average over the batch, then restore a leading axis so it broadcasts over batch.
    activations[:, :, head_idx] = cached_head.mean(dim=0).unsqueeze(0)
    return activations

In [25]:
def activation_patching(score_func, baseline_scores, ablation_scheme="mean"):
    """
        Ablate every attention head one at a time and score the resulting logits.

        Parameters
        ----------
        - `score_func`: callable mapping the logits of `model(val_tokens)` to a
        scalar score.
        - `baseline_scores`: score of the un-ablated model, subtracted from every
        ablated score. Anything `torch.as_tensor` accepts (float, 0-dim tensor,
        or an `(n_layers, n_heads)` tensor), on any device.
        - `ablation_scheme`: "mean" (default; uses `patching_cache`) or "zero",
        forwarded to `ablate_head`.

        Returns
        -------
        CPU tensor of shape `(n_layers, n_heads)` holding
        `ablated_score - baseline_scores` for each head.

        Uses the notebook globals `model`, `val_tokens` and `patching_cache`.
        All hooks on `model` (including permanent ones) are removed before the
        sweep and again afterwards, so the model is left unmodified.
    """
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    corrupted_scores = torch.zeros((n_layers, n_heads))

    model.reset_hooks(including_permanent=True)
    try:
        with torch.no_grad():
            for layer, head in tqdm(list(product(range(n_layers), range(n_heads)))):
                # Start each head from a clean model so ablations don't accumulate.
                model.reset_hooks(including_permanent=True)
                hook_fn = partial(
                    ablate_head, head_idx=head, scheme=ablation_scheme, new_cache=patching_cache
                )
                model.add_hook(utils.get_act_name("result", layer), hook_fn)
                corrupted_logits = model(val_tokens)
                corrupted_scores[layer, head] = score_func(corrupted_logits)
    finally:
        # Without this the last head's ablation hook stays attached to the
        # shared model and silently corrupts any later forward pass.
        model.reset_hooks(including_permanent=True)

    attribution_score = corrupted_scores - torch.as_tensor(baseline_scores).cpu()
    return attribution_score

# KL-divergence algorithm

Sort the nodes in reverse topological order. Note that we work with nodes instead of edges: nodes represent activations (or components) and edges represent the computations between them. Because GPUs are very good at performing many computations in parallel, removing individual edges would not save much time or memory. Since our objective is a pruning method that leaves us with more interpretable models, we focus on removing nodes instead. This has the added advantage that there are far fewer nodes than edges, so the algorithm will generally be faster.

Also, we used the logit difference in the previous experiment, but the ACDC paper recommends the KL divergence because it performs better in most cases (although some tasks do call for a more specific metric). Let's try the KL divergence: if we recover the acronym circuit, we keep going.

One more thing: we also have to choose the ablation scheme. Zero-ablation is more aggressive, so more nodes will likely remain in the resulting circuit, whereas mean-ablation may be more precise but requires adding a bias term to compensate for each removed component. **This comes with a major drawback: the bias vector depends on the sequence length, so we either stick to sentences that share the same template or find some other method.**

So, a todo list:

- ~~Implement the KL-based pruning algorithm (mean and zero).~~
- Add MLPs to the algorithm
- ~~Check if we obtain results similar to the acronym paper.~~
- ~~Experiment with zero/mean ablation~~
- Actually obtain a pruned model

Cool visualizations
- Pareto frontier


In [26]:
def kl_div(logits, baseline_logprobs, pos=-1):
    """
    Batch-mean KL(baseline || model) of the next-token distribution at position `pos`.

    `logits` are raw model outputs and `baseline_logprobs` are log-probabilities.
    Both are indexed as [batch, seq, d_vocab]. The KL is summed over the vocabulary,
    then averaged over the batch, and returned as a 0-dim tensor.
    """
    model_logprobs = F.log_softmax(logits[:, pos], dim=-1)
    target_logprobs = baseline_logprobs[:, pos]
    # With log_target=True, F.kl_div(input, target) = exp(target) * (target - input),
    # i.e. the per-token terms of KL(target || input).
    per_token = F.kl_div(model_logprobs, target_logprobs, reduction="none", log_target=True)
    per_example = per_token.sum(dim=-1)
    return per_example.mean()

In [27]:
# SANITY CHECK: rerun the head-by-head ablation sweep, but score it with the KL
# divergence to the clean validation distribution instead of the logit difference,
# and check that the results look sensible.
score_kl_div = partial(
    kl_div,
    baseline_logprobs=F.log_softmax(val_logits, dim=-1),
    pos=-1,
)

# The un-ablated model has zero KL divergence to its own distribution, hence a 0 baseline.
attribution_scores_kl = activation_patching(
    score_func=score_kl_div,
    baseline_scores=torch.tensor(0.),
)
px.imshow(
    attribution_scores_kl,
    title="Attribution score for attention heads (mean-ablation)",
    labels={"x": "Head", "y": "Layer"},
    width=500,
    height=500,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
)

  0%|          | 0/144 [00:00<?, ?it/s]

100%|██████████| 144/144 [00:03<00:00, 44.31it/s]


In [36]:
def auto_circuit(threshold, score_func, curr_logits, ablation_scheme="mean"):
    """
        Greedily prune attention heads, visiting layers from last to first.

        Each head is tentatively ablated on top of all heads pruned so far. If the
        score rises by less than `threshold` relative to the current pruned model,
        the ablation is made permanent. Otherwise the head is kept as part of the
        circuit.

        Parameters
        ----------
        - `threshold`: maximum allowed increase of `score_func` for a head to be pruned.
        - `score_func`: callable mapping logits to a score (e.g. KL divergence to
        the clean run); `.mean(0)` is applied to its result.
        - `curr_logits`: logits of the starting (un-pruned) model on `val_tokens`.
        Not modified.
        - `ablation_scheme`: "mean" or "zero", forwarded to `ablate_head`.

        Returns
        -------
        List of `[layer, head]` pairs that were kept, in visiting order.

        Uses the notebook globals `model`, `val_tokens` and `patching_cache`.
        All existing hooks are removed first. On return, the permanent ablation
        hooks of the pruned heads are left attached to `model`.
    """
    model.reset_hooks(including_permanent=True)
    circuit = []

    # Score of the current (partially pruned) model. It only changes when a head is
    # pruned, so track it instead of re-scoring the same logits every iteration.
    curr_score = score_func(curr_logits).mean(0)

    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    # Traverse the nodes from downstream to upstream
    for layer, head in tqdm(
        product(reversed(range(n_layers)), range(n_heads)),
        total=n_layers * n_heads,
    ):
        hook_name = utils.get_act_name("result", layer)
        hook_fn = partial(
            ablate_head, head_idx=head, scheme=ablation_scheme, new_cache=patching_cache
        )

        # Temporarily remove the node and measure the effect.
        model.add_hook(hook_name, hook_fn, is_permanent=False)
        temp_score = score_func(model(val_tokens)).mean(0)

        if (temp_score - curr_score) < threshold:
            # The score does not increase past the threshold, so the node is
            # not important: remove it permanently.
            model.add_hook(hook_name, hook_fn, is_permanent=True)
            curr_score = temp_score
        else:
            # Include the node in the circuit.
            circuit.append([layer, head])

        # Drop only the temporary hook; permanent (pruned) hooks stay in place.
        model.reset_hooks(including_permanent=False)

    return circuit

In [34]:
# Pareto frontier: sweep the pruning threshold and compare each recovered circuit
# with the ground-truth circuit (true/false positive rates over attention heads).
# `log_softmax` already returns a new tensor and nothing below mutates it, so no clones are needed.
baseline_logprobs = F.log_softmax(val_logits, dim=-1)
score = partial(kl_div, baseline_logprobs=baseline_logprobs, pos=-1)
ablation_scheme = "mean"

# Log-spaced thresholds from 1e-1 down to 1e-4 (finer sweep: 10**np.linspace(-1, -6, 20)).
thresholds = 10**np.linspace(-1, -4, 10)

# Derive the number of candidate heads from the model instead of hard-coding 144.
n_heads_total = model.cfg.n_layers * model.cfg.n_heads
n_negatives = n_heads_total - len(gt_circuit)

tprs = []
fprs = []

for threshold in thresholds:
    circuit = auto_circuit(threshold, score, val_logits, ablation_scheme=ablation_scheme)
    n_true_pos = sum(head in gt_circuit for head in circuit)
    tpr = n_true_pos / len(gt_circuit)
    fpr = (len(circuit) - n_true_pos) / n_negatives
    tprs.append(tpr)
    fprs.append(fpr)

# auto_circuit leaves the last threshold's pruning hooks attached; restore the full model.
model.reset_hooks(including_permanent=True)

  0%|          | 0/144 [00:00<?, ?it/s]

100%|██████████| 144/144 [00:03<00:00, 41.53it/s]
100%|██████████| 144/144 [00:03<00:00, 41.69it/s]
100%|██████████| 144/144 [00:03<00:00, 41.68it/s]
100%|██████████| 144/144 [00:03<00:00, 41.65it/s]
100%|██████████| 144/144 [00:03<00:00, 41.70it/s]
100%|██████████| 144/144 [00:03<00:00, 41.74it/s]
100%|██████████| 144/144 [00:03<00:00, 41.78it/s]
100%|██████████| 144/144 [00:03<00:00, 41.96it/s]
100%|██████████| 144/144 [00:03<00:00, 42.10it/s]
100%|██████████| 144/144 [00:03<00:00, 42.29it/s]
100%|██████████| 144/144 [00:03<00:00, 42.58it/s]
100%|██████████| 144/144 [00:03<00:00, 42.80it/s]
100%|██████████| 144/144 [00:03<00:00, 42.93it/s]
100%|██████████| 144/144 [00:03<00:00, 43.01it/s]
100%|██████████| 144/144 [00:03<00:00, 43.10it/s]
100%|██████████| 144/144 [00:03<00:00, 43.04it/s]
100%|██████████| 144/144 [00:03<00:00, 43.15it/s]
100%|██████████| 144/144 [00:03<00:00, 43.19it/s]
100%|██████████| 144/144 [00:03<00:00, 43.17it/s]
100%|██████████| 144/144 [00:03<00:00, 43.19it/s]


In [35]:
# ROC-style view of the sweep: each point is one threshold, annotated with its value.
fig = px.line(
    x=fprs,
    y=tprs,
    text=[f"{threshold:.0e}" for threshold in thresholds],
    labels={"x": "False Positive Rate", "y": "True Positive Rate", "text": "threshold"},
    title=f"auto_circuit vs. ground-truth circuit ({task}, {ablation_scheme}-ablation)",
)
fig.update_traces(textposition="bottom right")
fig.show()