In [1]:
import pickle
import random
import itertools
from itertools import product
from functools import partial
from tqdm import tqdm
import os

import plotly.express as px
import matplotlib.pyplot as plt

from sklearn.metrics import auc

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

import einops

from typing import Literal
from jaxtyping import Float

from transformer_lens import HookedTransformer, ActivationCache
from transformer_lens import utils

from easy_transformer.ioi_dataset import IOIDataset
from easy_transformer.ioi_utils import logit_diff as ioi_logit_diff

from utils import get_data, compute_logit_diff_acronym

# Disable autograd globally for the rest of the notebook session.
torch.set_grad_enabled(False)

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Sizes of the patching / validation sets and the task to load.
n_patching = 100
n_val = 100
task = "acronyms"

# `get_data` returns a dict; the entries read below are the ones this notebook uses.
data = get_data(n_patching=n_patching, n_val=n_val, task=task)

model = data["model"]

# Patching split.
patching_tokens, patching_answer_tokens = data["patching_tokens"], data["patching_answer_tokens"]
patching_logits, patching_cache = data["patching_logits"], data["patching_cache"]

# Validation split.
val_tokens, val_answer_tokens = data["val_tokens"], data["val_answer_tokens"]
val_logits, val_cache = data["val_logits"], data["val_cache"]

# Reference circuit the discovered circuits are compared against.
gt_circuit = data["gt_circuit"]

Loaded pretrained model gpt2-small into HookedTransformer


In [3]:
def compute_accuracy(val_logits, val_answer_tokens):
    """
        Fraction of examples whose arg-max prediction at the last position
        equals the last answer token.

        Parameters
        ----------
        - `val_logits`: logits indexed as `[example, position, vocab]`; only the
        last position is read.
        - `val_answer_tokens`: answer tokens indexed as `[example, token]`; only
        the last column is read.

        Returns
        -------
        A 0-d float tensor with the mean accuracy.
    """
    predicted = val_logits[:, -1].argmax(-1)
    expected = val_answer_tokens[:, -1]
    hits = predicted == expected
    return hits.float().mean()

In [3]:
def ablate_head(activations, hook, head_idx, scheme, new_cache=None):
    """
        Hook function that overwrites, in place, the slice of `activations`
        belonging to head `head_idx` (axis 2) and returns the same tensor.

        Parameters
        ----------
        - `activations`: tensor handed to the hook; the head axis is axis 2.
        Presumably `(batch, pos, head, d_model)` for `hook_result` -- TODO confirm.
        - `hook`: hook point object; only `hook.name` is read, as the key into
        `new_cache` (e.g. 'blocks.0.attn.hook_result').
        - `head_idx`: index of the head to ablate.
        - `scheme`: either "zero" (fill with zeros) or "mean" (fill with the
        mean of the cached activation over axis 0).
        - `new_cache`: mapping from hook name to cached activations; required
        for the "mean" scheme, ignored for "zero".
    """
    assert scheme in ["mean", "zero"], "the ablation scheme should be either 'mean' or 'zero'"

    if scheme == "zero":
        activations[:, :, head_idx] = 0.
        return activations

    # scheme == "mean": average the cached head output over axis 0 and
    # broadcast that single average back over every example.
    assert new_cache is not None, "`new_cache` is required when mean ablating"
    head_mean = new_cache[hook.name][:, :, head_idx].mean(0)
    activations[:, :, head_idx] = head_mean[None, ...]
    return activations

def ablate_mlp(activations, hook, scheme, new_cache=None):
    """
        Hook function that overwrites, in place, the whole of `activations`
        and returns the same tensor.

        Parameters
        ----------
        - `activations`: tensor handed to the hook (indexed with three axes).
        - `hook`: hook point object; only `hook.name` is read, as the key into
        `new_cache`.
        - `scheme`: either "zero" (fill with zeros) or "mean" (fill with the
        mean of the cached activation over axis 0).
        - `new_cache`: mapping from hook name to cached activations; required
        for the "mean" scheme, ignored for "zero".
    """
    assert scheme in ["mean", "zero"], "the ablation scheme should be either 'mean' or 'zero'"

    if scheme == "zero":
        activations[:, :, :] = 0.
        return activations

    # scheme == "mean": replace every example with the cache average over axis 0.
    assert new_cache is not None, "`new_cache` is required when mean ablating"
    cached_mean = new_cache[hook.name][:, :, :].mean(0)
    activations[:, :, :] = cached_mean[None, ...]
    return activations

In [5]:
def activation_patching(val_tokens, score_func, baseline_scores, model, patching_cache):
    """
    Mean-ablate every attention head, one at a time, and record how the score changes.

    Parameters
    ----------
    - `val_tokens`: input passed to `model(...)` for every ablated run.
    - `score_func`: callable mapping the model's logits to the score stored for
      that head (it must produce a value assignable to a single tensor entry).
    - `baseline_scores`: score(s) of the un-ablated model; `.cpu()` is called on
      it and it is subtracted from the per-head scores.
    - `model`: hooked model exposing `cfg.n_layers`, `cfg.n_heads`, `add_hook`
      and `reset_hooks`.
    - `patching_cache`: cache used by `ablate_head` to compute the mean activation.

    Returns
    -------
    Tensor of shape `(n_layers, n_heads)` (before broadcasting against
    `baseline_scores`) holding `corrupted_score - baseline_score` per head.

    Notes
    -----
    All hooks on `model`, including permanent ones, are removed both before and
    after the sweep, so the model is returned clean instead of being left with
    the last head still ablated.
    """
    model.reset_hooks(including_permanent=True)
    corrupted_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads))
    nodes = list(product(range(model.cfg.n_layers), range(model.cfg.n_heads)))

    try:
        with torch.no_grad():
            for layer, head in tqdm(nodes):
                # Exactly one head is ablated per forward pass.
                model.reset_hooks(including_permanent=True)
                hook_fn = partial(ablate_head, head_idx=head, scheme="mean", new_cache=patching_cache)
                model.add_hook(utils.get_act_name("result", layer), hook_fn)
                corrupted_logits = model(val_tokens)
                corrupted_scores[layer, head] = score_func(corrupted_logits)
    finally:
        # Do not leak the last ablation hook into whatever uses the model next.
        model.reset_hooks(including_permanent=True)

    return corrupted_scores - baseline_scores.cpu()

# KL-divergence algorithm

Sort the nodes in reverse topological order. Note that we are working with nodes instead of edges. In this case, nodes represent activations (or components) and edges represent computations. As GPUs are very good at performing lots of parallel computations, pruning individual edges would not gain us much speed and/or space. As our objective is a pruning method that leaves us with more interpretable models, we therefore focus on removing nodes. This is also convenient because there are far fewer nodes than edges, so the algorithm will generally be faster.

Also, we used the logit difference in the previous experiment, but in ACDC they recommend using the KL divergence as it performs better in most cases. There are cases where a more specific metric is required. However, let's try to use the KL divergence: if we recover the acronym circuit, then keep going.

One more thing is that we also have to choose the ablation scheme: zero-ablating might be more aggressive, hence more nodes will be present in the resulting circuit, whereas mean-ablating might be more precise, but it requires adding a bias term to compensate for removing each component. **This comes with a major drawback: the bias vector depends on the length of the sequence, so we either stick to sentences with the same template or think of some other method.**

So, a todo list:

- ~~Implement the KL-based pruning algorithm (mean and zero).~~
- Add MLPs to the algorithm
- ~~Check if we obtain results similar to the acronym paper.~~
- ~~Experiment with zero/mean ablation~~
- Actually obtain a pruned model

Cool visualizations
- ~~Pareto frontier~~


In [8]:
def _prune_if_unimportant(model, hook_name, hook_fn, score_func, curr_logits, val_tokens, threshold):
    """Tentatively ablate one node and make the ablation permanent if the mean score rises by less than `threshold`.

    Returns `(pruned, logits)`, where `logits` are the ablated logits if the node
    was pruned and the unchanged `curr_logits` otherwise.
    """
    # Temporarily remove the node and see what the model now predicts.
    model.add_hook(hook_name, hook_fn, is_permanent=False)
    temp_logits = model(val_tokens).clone()

    score_increase = score_func(temp_logits).mean(0) - score_func(curr_logits).mean(0)
    pruned = bool(score_increase < threshold)
    if pruned:
        # The node is not important: keep it ablated for all later evaluations.
        model.add_hook(hook_name, hook_fn, is_permanent=True)
        curr_logits = temp_logits

    # Drop the temporary hook; permanent (already pruned) hooks stay in place.
    model.reset_hooks(including_permanent=False)
    return pruned, curr_logits


def auto_circuit(threshold, score_func, curr_logits, val_tokens, patching_cache, model, ablation_scheme="mean", include_mlps=False):
    """
    Greedy node-pruning circuit discovery.

    Layers are visited from the last to the first. Each attention head (and,
    if `include_mlps` is set, each MLP) is ablated in turn; when the mean score
    increases by less than `threshold` relative to the current pruned model,
    the ablation is kept permanently, otherwise the node is added to the circuit.

    Parameters
    ----------
    - `threshold`: maximum tolerated increase of the mean score for a node to be pruned.
    - `score_func`: callable mapping logits to per-example scores (`.mean(0)` is
      taken on its result), e.g. a KL divergence to the clean model.
    - `curr_logits`: logits of the un-pruned model on `val_tokens`.
    - `val_tokens`: input passed to `model(...)`.
    - `patching_cache`: cache used to compute mean activations when
      `ablation_scheme == "mean"`. With `include_mlps=True` it presumably must
      also contain the `mlp_out` activations -- TODO confirm against `get_data`.
    - `model`: hooked model exposing `cfg`, `add_hook` and `reset_hooks`.
    - `ablation_scheme`: "mean" or "zero".
    - `include_mlps`: whether MLP nodes are candidates for pruning. When False
      the MLPs are left untouched and the returned MLP list is empty.

    Returns
    -------
    `(circuit, circuit_mlps)`: `circuit` is a list of `[layer, head]` pairs and
    `circuit_mlps` a list of `[layer]` entries for the nodes that were kept.

    Notes
    -----
    All hooks are cleared on entry. On return the permanent ablation hooks of
    the pruned nodes are still attached to `model`.
    """
    model.reset_hooks(including_permanent=True)
    circuit = []
    circuit_mlps = []

    # Traverse the nodes from downstream to upstream
    for layer in tqdm(reversed(range(model.cfg.n_layers)), total=model.cfg.n_layers):
        attn_hook_name = utils.get_act_name("result", layer)
        for head in range(model.cfg.n_heads):
            hook_fn = partial(
                ablate_head, head_idx=head, scheme=ablation_scheme, new_cache=patching_cache
            )
            pruned, curr_logits = _prune_if_unimportant(
                model, attn_hook_name, hook_fn, score_func, curr_logits, val_tokens, threshold
            )
            if not pruned:
                # include node in the circuit
                circuit.append([layer, head])

        if not include_mlps:
            continue

        # Repeat for the MLP. The hook must sit on the MLP output: hooking the
        # attention result here would overwrite every head of the layer.
        # NOTE(review): in a standard transformer block the MLP reads the
        # attention output, so a strict reverse-topological order would visit
        # the MLP before this layer's heads -- confirm the intended order.
        mlp_hook_name = utils.get_act_name("mlp_out", layer)
        hook_fn = partial(
            ablate_mlp, scheme=ablation_scheme, new_cache=patching_cache
        )
        pruned, curr_logits = _prune_if_unimportant(
            model, mlp_hook_name, hook_fn, score_func, curr_logits, val_tokens, threshold
        )
        if not pruned:
            # include node in the circuit
            circuit_mlps.append([layer])

    return circuit, circuit_mlps

In [9]:
def sweep_autocircuit(task="acronyms", ablation_scheme="mean", include_mlps=False):
    """
    Run `auto_circuit` over a sweep of thresholds and persist the results.

    Results are stored in `logs/{task}_{ablation_scheme}[_mlp].pkl` as a dict
    mapping each threshold to `[circuit, circuit_mlps]`. An existing log file is
    loaded first and its entries for the swept thresholds are overwritten.

    Parameters
    ----------
    - `task`: task name forwarded to `get_data`.
    - `ablation_scheme`: "mean" or "zero", forwarded to `auto_circuit`.
    - `include_mlps`: forwarded to `auto_circuit`; also selects the `_mlp` log file.

    Returns
    -------
    The (updated) log dictionary.

    Notes
    -----
    Reads the notebook-level globals `n_patching` and `n_val`, and relies on a
    `kl_div(logits, baseline_logprobs=..., pos=...)` function that is not
    defined in the visible part of this notebook -- TODO confirm where it lives.
    """
    log_path = f"logs/{task}_{ablation_scheme}{'_mlp' if include_mlps else ''}.pkl"

    try:
        # NOTE: pickle can execute arbitrary code on load; only open log files
        # produced by this notebook.
        with open(log_path, "rb") as handle:
            log = pickle.load(handle)
    except FileNotFoundError:
        log = {}

    # Create the output directory before the (long) sweep so that the final
    # write cannot fail on a fresh checkout and discard all results.
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    data = get_data(n_patching=n_patching, n_val=n_val, task=task)

    model = data["model"]
    patching_cache = data["patching_cache"]

    val_tokens = data["val_tokens"]
    val_logits = data["val_logits"]

    # `log_softmax` already returns a new tensor, so no extra copies are needed.
    baseline_logprobs = F.log_softmax(val_logits, dim=-1)
    score = partial(kl_div, baseline_logprobs=baseline_logprobs, pos=-1)

    # 20 log-spaced thresholds from 1 down to 1e-6.
    thresholds = 10**np.linspace(0, -6, 20)

    for threshold in thresholds:
        circuit, circuit_mlps = auto_circuit(
            threshold, score, val_logits.clone(), val_tokens, patching_cache, model,
            ablation_scheme=ablation_scheme, include_mlps=include_mlps,
        )
        log[threshold] = [circuit, circuit_mlps]

    with open(log_path, "wb") as handle:
        pickle.dump(log, handle, protocol=pickle.HIGHEST_PROTOCOL)
    return log

In [None]:
# Run the threshold sweep for every task / ablation scheme / MLP setting.
for task in ["acronyms", "ioi", "greater-than"]:
    for scheme in ["mean", "zero"]:
        for include_mlps in [False, True]:
            sweep_autocircuit(task=task, ablation_scheme=scheme, include_mlps=include_mlps)

In [12]:
# Load the saved sweep results for the acronyms task with mean ablation.
with open("logs/acronyms_mean.pkl", "rb") as handle:
    log = pickle.load(handle)

In [1]:
def pareto_frontier(task="acronyms", ablation_scheme="mean"):
    """
    Applies the discovery algorithm with different thresholds and gathers the TPRs and FPRs,
    comparing the obtained circuit of attention heads with the ground truth.

    Parameters
    ----------
    - `task`: task name forwarded to `get_data`.
    - `ablation_scheme`: "mean" or "zero", forwarded to `auto_circuit`.

    Returns
    -------
    `(tprs, fprs)`: two lists with one entry per threshold (40 log-spaced
    thresholds from 1 down to 1e-6).

    Notes
    -----
    Reads the notebook-level globals `n_patching` and `n_val`, and relies on a
    `kl_div(logits, baseline_logprobs=..., pos=...)` function that is not
    defined in the visible part of this notebook -- TODO confirm where it lives.
    Membership is tested with `head in gt_circuit`, where `head` is a
    `[layer, head]` list; this assumes `gt_circuit` stores heads in a form that
    compares equal to such lists -- TODO confirm against `get_data`.
    """

    data = get_data(n_patching=n_patching, n_val=n_val, task=task)

    model = data["model"]
    patching_cache = data["patching_cache"]
    val_tokens = data["val_tokens"]
    val_logits = data["val_logits"]
    gt_circuit = data["gt_circuit"]

    baseline_logprobs = F.log_softmax(val_logits, dim=-1)
    score = partial(kl_div, baseline_logprobs=baseline_logprobs, pos=-1)

    thresholds = 10**np.linspace(0, -6, 40)

    # Number of heads outside the ground-truth circuit, derived from the model
    # configuration rather than assuming a 12x12 model.
    n_negatives = model.cfg.n_layers * model.cfg.n_heads - len(gt_circuit)

    tprs = []
    fprs = []

    for threshold in thresholds:
        # `auto_circuit` returns (heads, mlps); only the heads are scored here.
        circuit, _ = auto_circuit(threshold, score, val_logits.clone(), val_tokens, patching_cache, model, ablation_scheme=ablation_scheme)
        n_true_pos = sum(1 for head in circuit if head in gt_circuit)
        n_false_pos = len(circuit) - n_true_pos
        tprs.append(n_true_pos / len(gt_circuit))
        fprs.append(n_false_pos / n_negatives)
    return tprs, fprs

In [2]:
def plot_pareto_experiment(ablation_scheme):
    """
    Plot the TPR/FPR curve of the discovery algorithm for each task.

    For every task, `pareto_frontier` is run with the given `ablation_scheme`;
    the resulting points are drawn as a step curve whose legend entry reports
    the area under the curve.

    Parameters
    ----------
    - `ablation_scheme`: "mean" or "zero", forwarded to `pareto_frontier`.

    Returns
    -------
    The matplotlib figure.
    """
    fig, ax = plt.subplots(1)
    for i, task in enumerate(["acronyms", "greater-than", "ioi"]):
        tprs, fprs = pareto_frontier(task=task, ablation_scheme=ablation_scheme)
        # Add the (0, 0) and (1, 1) ends for prettier plots, then order the
        # points by FPR: `auc` rejects x values that are not monotonic, and a
        # step plot over unordered points would zig-zag.
        points = sorted(zip([0] + fprs + [1], [0] + tprs + [1]))
        fprs = [fpr for fpr, _ in points]
        tprs = [tpr for _, tpr in points]
        auc_score = auc(fprs, tprs)
        ax.plot(fprs, tprs, '-', color=f"C{i}", drawstyle='steps-post', label=f"{task} (AUC: {auc_score:.2f})")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend(loc="lower right")
    fig.tight_layout()
    return fig

In [3]:
# Build the TPR/FPR figure using mean ablation.
fig = plot_pareto_experiment("mean")

NameError: name 'plt' is not defined

In [None]:
# The figure above is built with ablation_scheme="mean", so name the file
# accordingly; keep the two in sync when switching schemes.
fig.savefig("pareto-mean.pdf")