In [106]:
import pickle
from typing import Optional, Tuple, Any, Union

import pandas as pd

import plotly.express as px

import torch
import torch.nn as nn

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Model

from utils import get_data

torch.set_grad_enabled(False)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Now that we have obtained the circuits, it is time to build a truly pruned model, i.e. one whose pruned components are physically removed rather than just ablated. We will work directly on the HuggingFace implementation: we no longer need TransformerLens, and at this stage we care about performance rather than interpretability. We will start with zero-ablated circuits, since pruning them does not require adding any bias to compensate for the removed components.

In [14]:
# Experiment selection. TASK, ABLATION and INCLUDE_MLPS together pick the
# circuit-discovery log file loaded below: logs/{TASK}_{ABLATION}[_mlp].pkl
TASK = "acronyms"
ABLATION = "zero"  # ablation scheme of the logged run (presumably "zero" or "mean" — confirm)
INCLUDE_MLPS = True  # True -> load the log whose filename carries the "_mlp" suffix

In [45]:
# HuggingFace GPT-2 small; hidden states and the KV cache are disabled since
# we only need the final logits.
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", add_bos_token=True)
model = AutoModelForCausalLM.from_pretrained(
    "openai-community/gpt2",
    output_hidden_states=False,
    use_cache=False,
).cuda()

In [25]:
# Sizes requested from get_data (presumably number of prompts per split — confirm in utils).
n_patching = 100
n_val = 100
# Derive the task from the config cell so the dataset always matches the log file
# selected by TASK (a hard-coded string here would silently drift out of sync).
task = TASK

data = get_data(n_patching=n_patching, n_val=n_val, task=task)

# Split used for patching / activation caching.
patching_tokens = data["patching_tokens"]
patching_answer_tokens = data["patching_answer_tokens"]
patching_logits = data["patching_logits"]
patching_cache = data["patching_cache"]

# Held-out split used to measure accuracy.
val_tokens = data["val_tokens"]
val_answer_tokens = data["val_answer_tokens"]
val_logits = data["val_logits"]
val_cache = data["val_cache"]

gt_circuit = data["gt_circuit"]

Loaded pretrained model gpt2-small into HookedTransformer


In [86]:
def compute_accuracy(model, val_tokens, val_answer_tokens, task="acronyms"):
    """
    Top-1 accuracy of the model's prediction at the final position.

    Parameters
    ----------
    - `model`: callable returning a mapping (e.g. a HF ``ModelOutput``) with a
      ``"logits"`` entry; assumes logits are (batch, seq, vocab) — TODO confirm.
    - `val_tokens`: input token ids passed straight to ``model``.
    - `val_answer_tokens`: expected answer ids; only the last column is compared.
    - `task`: only "acronyms" is supported.

    Returns
    -------
    float: fraction of rows where ``argmax(logits[:, -1])`` equals
    ``val_answer_tokens[:, -1]``.

    Raises
    ------
    NotImplementedError: for any other task (previously this silently returned None).
    """
    if task != "acronyms":
        raise NotImplementedError(f"compute_accuracy does not support task {task!r}")

    final_logits = model(val_tokens)["logits"][:, -1]
    predictions = final_logits.argmax(-1)
    return (predictions == val_answer_tokens[:, -1]).float().mean().item()

In [52]:
class AttentionIdentity(nn.Module):
    """
    Drop-in replacement for a GPT2Attention layer whose heads are all pruned.

    It has no parameters and contributes nothing to the residual stream: the
    attention output is a zero tensor shaped like the input, and the second
    tuple element (present key/values or attention weights, depending on the
    transformers version) is ``None``.

    NOTE(review): layers pruned head-by-head via ``_prune_heads`` presumably keep
    their output-projection bias, whereas this returns exact zeros — confirm this
    matches the ablation used during circuit discovery.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Arguments are accepted (and ignored) so it can be built like GPT2Attention.
        super().__init__()

    def forward(
        self,
        hidden_states: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs: Any,
    ) -> Tuple[torch.Tensor, None]:
        # **kwargs absorbs extra keywords newer GPT2Block versions forward to the
        # attention module (e.g. past_key_value / cache_position), which would
        # otherwise raise a TypeError.
        return (torch.zeros_like(hidden_states), None)

In [109]:
def list_to_dict(attn_heads, n_heads=12, n_layers=12):
    """
    Invert a circuit into the heads that must be pruned.

    Given an iterable of ``(layer, head)`` pairs belonging to the circuit, returns
    a dictionary ``heads_to_prune[layer] = [head, ...]`` (ascending) containing
    every attention head *outside* the circuit. Every layer in
    ``range(n_layers)`` gets a key, possibly with an empty list.

    Duplicate circuit entries are tolerated, and indices are coerced with
    ``int()`` so tensor / numpy scalars work too.

    Raises
    ------
    ValueError: if a pair lies outside ``n_layers`` x ``n_heads``.
    """
    kept = set()
    for layer, head in attn_heads:
        layer, head = int(layer), int(head)
        if not (0 <= layer < n_layers and 0 <= head < n_heads):
            raise ValueError(
                f"Head ({layer}, {head}) is out of range for a model with "
                f"{n_layers} layers and {n_heads} heads per layer"
            )
        kept.add((layer, head))

    return {
        layer: [head for head in range(n_heads) if (layer, head) not in kept]
        for layer in range(n_layers)
    }

def get_attn_layers_to_prune(heads_to_prune, n_heads=12):
    """
    Split out attention layers in which every head is to be pruned.

    If ``heads_to_prune[layer]`` covers every head index ``0..n_heads-1``, the
    whole layer is removed at once instead of head by head.

    Returns
    -------
    (remaining, attn_layers_to_prune):
        - `remaining`: a NEW dict with the fully-pruned layers removed. The
          input dict is not modified (it used to be mutated in place).
        - `attn_layers_to_prune`: the fully-pruned layer indices, in the
          input dict's iteration order.
    """
    all_heads = set(range(n_heads))
    # Compare sets rather than lengths so duplicated entries cannot make a
    # partially pruned layer look fully pruned.
    attn_layers_to_prune = [
        layer
        for layer, heads in heads_to_prune.items()
        if all_heads <= {int(head) for head in heads}
    ]
    remaining = {
        layer: list(heads)
        for layer, heads in heads_to_prune.items()
        if layer not in attn_layers_to_prune
    }
    return remaining, attn_layers_to_prune

def prune_attn_layers(self: GPT2Model, layers):
    """
    Swap the attention sub-module of each block index in `layers` for an
    `AttentionIdentity`, which returns zeros instead of attending.

    Attached to `GPT2Model` below so it can be invoked as
    `model.transformer.prune_attn_layers(...)`.
    """
    blocks = self.h
    for block_idx in layers:
        blocks[block_idx].attn = AttentionIdentity()


GPT2Model.prune_attn_layers = prune_attn_layers


# Quick manual run of the pruning helpers on the `gt_circuit` heads.
attn_heads = gt_circuit
heads_to_prune, attn_layers_to_prune = get_attn_layers_to_prune(list_to_dict(attn_heads))

model = AutoModelForCausalLM.from_pretrained(
    "openai-community/gpt2", output_hidden_states=False, use_cache=False
).cuda()
# Partially pruned layers lose individual heads; fully pruned layers are replaced.
model.transformer._prune_heads(heads_to_prune)
model.transformer.prune_attn_layers(attn_layers_to_prune)

In [111]:
patching_cache  # inspect which hook names the cached ActivationCache provides

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.hook_q_input', 'blocks.0.hook_k_input', 'blocks.0.hook_v_input', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.hook_q_input', 'blocks.1.hook_k_input', 'blocks.1.hook_v_input', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.attn.hook_result', '

In [112]:
def prune_gpt2_hf_model(model, attn_heads, ablation_scheme="zero", patching_cache=None):
    """
    Physically remove every attention head outside of the circuit from a
    HuggingFace GPT-2 model. MLPs are NOT pruned by this function.

    Heads are removed with ``model.transformer._prune_heads``. Attention layers
    in which no head belongs to the circuit are replaced wholesale through
    ``model.transformer.prune_attn_layers``.

    Parameters
    ----------
    - `model`: HuggingFace GPT-2 causal LM whose ``.transformer`` is a
      ``GPT2Model``. It is modified in place.
    - `attn_heads`: List [[layer, head]] containing the attention heads of the circuit
    - `ablation_scheme`: "zero" or "mean", i.e. the ablation the circuit was found with.
      Defaults to "zero" because that is the only scheme actually implemented.
    - `patching_cache`: cached activations; required when ``ablation_scheme == "mean"``.

    Returns
    -------
    The same ``model`` object, pruned.

    Raises
    ------
    ValueError: unknown ``ablation_scheme``, or "mean" without ``patching_cache``.
    """
    if ablation_scheme not in ("zero", "mean"):
        raise ValueError(f"Unknown ablation_scheme {ablation_scheme!r}; expected 'zero' or 'mean'")

    if ablation_scheme == "mean":
        if patching_cache is None:
            raise ValueError("You need to provide the cached activations when mean patching")
        import warnings

        # patching_cache is not used below: removed heads simply vanish, which
        # equals zero ablation. Make that explicit instead of failing silently.
        warnings.warn(
            "Mean-ablation compensation is not implemented yet: pruned heads are "
            "removed as if zero-ablated and patching_cache is ignored.",
            stacklevel=2,
        )

    heads_to_prune = list_to_dict(attn_heads)
    heads_to_prune, attn_layers_to_prune = get_attn_layers_to_prune(heads_to_prune)

    model.transformer._prune_heads(heads_to_prune)
    model.transformer.prune_attn_layers(attn_layers_to_prune)
    return model

# Accuracy of the full model vs. the same model pruned down to `gt_circuit`.
model = AutoModelForCausalLM.from_pretrained(
    "openai-community/gpt2", output_hidden_states=False, use_cache=False
).cuda()
acc_before = compute_accuracy(model, val_tokens, val_answer_tokens)

model = prune_gpt2_hf_model(
    model,
    gt_circuit,
    ablation_scheme="mean",
    patching_cache=patching_cache,
)
acc_after = compute_accuracy(model, val_tokens, val_answer_tokens)

print(acc_before, "->", acc_after)

0.9399999976158142 -> 0.0


In [96]:
# Circuit-discovery log for the configured experiment.
# pickle can execute arbitrary code: only load logs produced by this project.
log_path = f"logs/{TASK}_{ABLATION}{'_mlp' if INCLUDE_MLPS else ''}.pkl"
with open(log_path, "rb") as log_file:
    log = pickle.load(log_file)

In [105]:
# Sweep the logged thresholds: prune a fresh GPT-2 to each circuit and record
# validation accuracy and remaining parameter count.
records = []

for threshold in sorted(log.keys()):
    # Assumes log[threshold][0] holds the circuit's attention heads — TODO confirm.
    attn_heads = log[threshold][0]
    model = AutoModelForCausalLM.from_pretrained(
        "openai-community/gpt2", output_hidden_states=False, use_cache=False
    ).cuda()
    # Pass the ablation scheme (and cache) explicitly so the prune matches the
    # logged run instead of whatever default prune_gpt2_hf_model happens to have.
    model = prune_gpt2_hf_model(
        model, attn_heads, ablation_scheme=ABLATION, patching_cache=patching_cache
    )
    records.append(
        {
            "threshold": threshold,
            "acc": compute_accuracy(model, val_tokens, val_answer_tokens, task=TASK),
            "size": model.num_parameters(),
        }
    )

df = pd.DataFrame(records, columns=["threshold", "acc", "size"])
df

Unnamed: 0,threshold,acc,size
0,1e-06,0.94,124439808
1,2e-06,0.94,124439808
2,5e-06,0.94,124439808
3,1.1e-05,0.94,124439808
4,2.3e-05,0.94,124439808
5,5.1e-05,0.94,124439808
6,0.000113,0.94,124439808
7,0.000248,0.94,123849408
8,0.000546,0.89,122275008
9,0.001199,0.92,118732608
