In [1]:
from copy import deepcopy
from functools import partial

from einops import einsum, rearrange

import torch
import torch.nn as nn

import matplotlib.pyplot as plt

from transformers import AutoModelForCausalLM
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

from utils import get_data, compute_logit_diff_acronym

# This notebook only runs inference, so autograd tracking is disabled globally.
torch.set_grad_enabled(False)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Automatically reload imported modules (e.g. the local `utils`) when they change.
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Number of examples requested for the patching set and the validation set.
n_patching = 100
n_val = 100
task = "acronyms"

data = get_data(n_patching=n_patching, n_val=n_val, task=task)

model = data["model"]

# Patching split: used below to compute the mean activations.
patching_tokens = data["patching_tokens"]
patching_answer_tokens = data["patching_answer_tokens"]
patching_logits = data["patching_logits"]
patching_cache = data["patching_cache"]

# Validation split: used below to measure accuracy.
val_tokens = data["val_tokens"]
val_answer_tokens = data["val_answer_tokens"]
val_logits = data["val_logits"]
val_cache = data["val_cache"]

# Attention heads of the reference circuit; passed to `prune_model` later on.
gt_circuit = data["gt_circuit"]

# Discard the model returned by `get_data` and load GPT-2 through transformers
# instead. The model is moved to `device` (rather than an unconditional
# `.cuda()`) so the cell also runs on machines without a GPU.
del model
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2", output_hidden_states=False, use_cache=False).to(device)
model.eval()

Loaded pretrained model gpt2-small into HookedTransformer


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [3]:
def cache_mean_attn_layer_activations(model, patching_tokens):
    """
    Compute the batch-mean output of every attention layer on `patching_tokens`.

    A forward hook is attached to each `model.transformer.h[layer].attn`
    module and the model is run ONCE; each hook stores the mean over dim 0 of
    the first element of its module's output.

    Args:
        model: model exposing `config.n_layer`, `config.n_embd` and
            `transformer.h[layer].attn`, callable on `patching_tokens`.
        patching_tokens: input fed to `model`; its last dimension is used as
            the sequence length (assumed (batch, seq_len) — TODO confirm).

    Returns:
        Tensor of shape (n_layer, seq_len, n_embd) created with `torch.zeros`
        (i.e. on torch's default device), filled with the per-layer means.
    """
    n_layer = model.config.n_layer
    mean_attn_layer_activations = torch.zeros(n_layer, patching_tokens.shape[-1], model.config.n_embd)

    def _cache_mean_attn_layer_activations(module, input, output, layer_idx):
        # The attention module returns a tuple; element 0 is the layer output.
        mean_attn_layer_activations[layer_idx] = output[0].detach().mean(0)
        return None

    hooks = []
    try:
        # Attach every hook first so a single forward pass fills all layers.
        for layer in range(n_layer):
            hook_fn = partial(_cache_mean_attn_layer_activations, layer_idx=layer)
            hooks.append(model.transformer.h[layer].attn.register_forward_hook(hook_fn))
        with torch.no_grad():
            model(patching_tokens)
    finally:
        # Always detach the hooks, even if the forward pass raised.
        for hook in hooks:
            hook.remove()
    return mean_attn_layer_activations


def cache_mean_mlp_activations(model, patching_tokens):
    """
    Compute the batch-mean output of every MLP block on `patching_tokens`.

    A forward hook is attached to each `model.transformer.h[layer].mlp`
    module and the model is run ONCE; each hook stores the mean over dim 0 of
    its module's output.

    Args:
        model: model exposing `config.n_layer`, `config.n_embd` and
            `transformer.h[layer].mlp`, callable on `patching_tokens`.
        patching_tokens: input fed to `model`; its last dimension is used as
            the sequence length (assumed (batch, seq_len) — TODO confirm).

    Returns:
        Tensor of shape (n_layer, seq_len, n_embd) created with `torch.zeros`
        (i.e. on torch's default device), filled with the per-layer means.
    """
    n_layer = model.config.n_layer
    mean_mlp_activations = torch.zeros(n_layer, patching_tokens.shape[-1], model.config.n_embd)

    def _cache_mean_mlp_activations(module, input, output, layer_idx):
        # Unlike the attention module, the MLP output is indexed as a plain tensor.
        mean_mlp_activations[layer_idx] = output.detach().mean(0)
        return None

    hooks = []
    try:
        # Attach every hook first so a single forward pass fills all layers.
        for layer in range(n_layer):
            hook_fn = partial(_cache_mean_mlp_activations, layer_idx=layer)
            hooks.append(model.transformer.h[layer].mlp.register_forward_hook(hook_fn))
        with torch.no_grad():
            model(patching_tokens)
    finally:
        # Always detach the hooks, even if the forward pass raised.
        for hook in hooks:
            hook.remove()
    return mean_mlp_activations


def cache_mean_head_activations(model, patching_tokens):
    """
    Compute the batch-mean contribution of every attention head to the output
    of its attention layer, together with each layer's `c_proj` bias.

    A forward hook is attached to each `model.transformer.h[layer].attn.c_proj`
    and the model is run ONCE. Inside the hook, the input of `c_proj` is split
    into `n_head` chunks of size `d_head` along the last dimension and each
    chunk is multiplied by the matching slice of `c_proj.weight`, so that
    summing the result over heads and adding `c_proj.bias` gives the layer
    output.

    Args:
        model: model exposing `config.n_layer`, `config.n_head`,
            `config.n_embd` and `transformer.h[layer].attn.c_proj`.
        patching_tokens: input fed to `model`; its last dimension is used as
            the sequence length (assumed (batch, seq_len) — TODO confirm).

    Returns:
        (mean_head_activations, attn_layer_biases) with shapes
        (n_layer, n_head, seq_len, n_embd) and (n_layer, n_embd).
    """
    n_layer = model.config.n_layer
    n_head = model.config.n_head
    n_embd = model.config.n_embd
    # Integer division: the views below require an exact per-head size.
    d_head = n_embd // n_head

    mean_head_activations = torch.zeros(n_layer, n_head, patching_tokens.shape[-1], n_embd)
    attn_layer_biases = torch.zeros(n_layer, n_embd)

    def _cache_mean_head_activations(c_proj, input, output, layer_idx):
        # `input` is the tuple of positional arguments passed to c_proj.
        h = input[0].detach()
        batch_size, seq_len, _ = h.size()
        h = h.view(batch_size, seq_len, n_head, d_head)
        # Group the rows of the projection matrix by head.
        w = c_proj.weight.detach().view(n_head, d_head, n_embd)

        h_proj = einsum(
            h, w,
            "batch_size seq_len n_head d_head, n_head d_head d_model -> batch_size seq_len n_head d_model"
        ).mean(0) # seq_len, n_head, d_model
        h_proj = rearrange(h_proj, "seq_len n_head d_model -> n_head seq_len d_model")

        mean_head_activations[layer_idx] = h_proj
        attn_layer_biases[layer_idx] = c_proj.bias.detach()

    hooks = []
    try:
        # Attach every hook first so a single forward pass fills all layers.
        for layer in range(n_layer):
            hook_fn = partial(_cache_mean_head_activations, layer_idx=layer)
            hooks.append(model.transformer.h[layer].attn.c_proj.register_forward_hook(hook_fn))
        with torch.no_grad():
            model(patching_tokens)
    finally:
        # Always detach the hooks, even if the forward pass raised.
        for hook in hooks:
            hook.remove()

    return mean_head_activations, attn_layer_biases

In [4]:
# Cache the batch-mean activations on the patching tokens: per attention head
# (plus the c_proj bias of each layer), per attention layer, and per MLP.
mean_head_activations, attn_layer_biases = cache_mean_head_activations(model, patching_tokens)
mean_attn_layer_activations = cache_mean_attn_layer_activations(model, patching_tokens)
mean_mlp_activations = cache_mean_mlp_activations(model, patching_tokens)

# SANITY CHECK: summing the per-head means over the head dimension and adding
# the layer bias should reproduce the per-layer attention means (within atol).
torch.allclose(mean_attn_layer_activations, mean_head_activations.sum(1) + attn_layer_biases[:, None, :], atol=1e-4)

True

In [5]:
class BiasLayer(nn.Module):
    """
    Stand-in for a GPT2Attention layer that always emits a fixed bias term.

    The incoming hidden states and every keyword argument are ignored.
    """
    def __init__(self, bias):
        """Store `bias`, the value emitted by every forward call."""
        super().__init__()
        self.bias = bias

    def forward(self, hidden_states, **kwargs):
        """Return the 2-tuple `(bias, None)`, regardless of the inputs."""
        constant_output = self.bias
        return constant_output, None


class BiasLayerMLP(nn.Module):
    """
    Stand-in for a GPT2MLP layer that always emits a fixed bias term.

    The incoming hidden states and every keyword argument are ignored.
    """
    def __init__(self, bias):
        """Store `bias`, the value emitted by every forward call."""
        super().__init__()
        self.bias = bias

    def forward(self, hidden_states, **kwargs):
        """Return the stored bias itself (not wrapped in a tuple)."""
        constant_output = self.bias
        return constant_output


class AddLayer(nn.Module):
    """
    Wraps an attention layer (GPT2Attention) and adds a fixed bias term to
    its output.

    Keyword arguments received by `forward` (e.g. `attention_mask`) are
    forwarded to the wrapped attention module instead of being dropped.
    """
    def __init__(self, attn: "GPT2Attention", bias):
        """
        Args:
            attn: attention module to wrap; it must return a sequence whose
                first element is the attention output.
            bias: value added to the attention output (must broadcast with it).
        """
        super().__init__()
        self.bias = bias
        self.attn = attn

    def forward(self, hidden_states, **kwargs):
        """
        Run the wrapped attention and add the bias to its first output.

        Returns:
            The 2-tuple `(attention_output + bias, None)`; any additional
            outputs of the wrapped module are discarded.
        """
        # Forward the caller's keyword arguments so masks etc. are honored.
        output = self.attn(hidden_states, **kwargs)
        return (output[0] + self.bias, None)


class PassthroughLayer(nn.Module):
    """
    Parameter-free module that hands its input back unchanged.

    Constructor arguments are accepted and ignored.
    """
    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, hidden_states, **kwargs):
        """Return the 2-tuple `(hidden_states, None)`; keyword arguments are ignored."""
        unchanged = hidden_states
        return unchanged, None

In [6]:
def compute_accuracy(model, val_tokens, val_answer_tokens, task="acronyms"):
    """
    Fraction of validation examples whose top-1 prediction is correct.

    For the "acronyms" task, the argmax of the logits at the last position is
    compared with the last column of `val_answer_tokens`.

    Args:
        model: callable returning a mapping with a "logits" entry indexed as
            [batch, position, vocab].
        val_tokens: input passed to `model`.
        val_answer_tokens: 2-D tensor of answers; only `[:, -1]` is used.
        task: task name; only "acronyms" is supported.

    Returns:
        The accuracy as a Python float.

    Raises:
        ValueError: if `task` is not supported (previously this case silently
            returned None).
    """
    if task == "acronyms":
        predictions = model(val_tokens)["logits"][:, -1].argmax(-1)
        return (predictions == val_answer_tokens[:, -1]).float().mean().item()
    raise ValueError(f"Unsupported task: {task!r}")

def list_to_dict(attn_heads, n_heads=12, n_layers=12):
    """
    Given a list of the attn heads of the circuit, returns
    a dictionary heads_to_prune[layer] = [head, ...] with
    every attention head outside of the circuit.

    Args:
        attn_heads: iterable of (layer, head) pairs belonging to the circuit.
            Repeated pairs are tolerated.
        n_heads: number of heads per layer.
        n_layers: number of layers.

    Returns:
        Dict with one key per layer in range(n_layers); each value is the
        ascending list of heads of that layer that are NOT in the circuit.

    Raises:
        ValueError: if a pair refers to a layer or head outside the given
            ranges.
    """
    circuit = set()
    for layer, head in attn_heads:
        if not (0 <= layer < n_layers and 0 <= head < n_heads):
            raise ValueError(
                f"Head ({layer}, {head}) is outside a model with "
                f"{n_layers} layers and {n_heads} heads per layer"
            )
        # A set makes duplicated circuit entries harmless.
        circuit.add((layer, head))

    return {
        layer: [head for head in range(n_heads) if (layer, head) not in circuit]
        for layer in range(n_layers)
    }

def get_attn_layers_to_prune(heads_to_prune, n_heads=12):
    """
    Split off the layers in which every head is marked for pruning.

    A layer whose list holds `n_heads` entries is dropped as a whole layer
    rather than head by head: its key is removed from `heads_to_prune`
    (the dict is modified in place) and the layer index is collected.

    Returns the same `heads_to_prune` dict and the list of collected layers.
    """
    fully_pruned_layers = [
        layer_idx
        for layer_idx, pruned_heads in heads_to_prune.items()
        if len(pruned_heads) == n_heads
    ]
    for layer_idx in fully_pruned_layers:
        heads_to_prune.pop(layer_idx)
    return heads_to_prune, fully_pruned_layers

In [7]:
# Size of the token-embedding and position-embedding matrices, read from the
# model config instead of hard-coding GPT-2 small's 50257 / 1024 / 768.
embedding_parameters = (model.config.vocab_size + model.config.n_positions) * model.config.n_embd
# Parameter count of the unpruned model, excluding the embeddings.
initial_parameters = model.num_parameters() - embedding_parameters

In [8]:
# Baseline before pruning: (validation accuracy, parameter count without embeddings).
compute_accuracy(model, val_tokens, val_answer_tokens), initial_parameters

(0.8899999856948853, 85056000)

In [9]:
def prune_model(model, circuit_attn_heads):
    """
    Prune every attention head outside of the circuit, replacing what is
    removed by its cached mean activation.

    Reads the module-level tensors `mean_attn_layer_activations` and
    `mean_head_activations`. `model` is modified in place and also returned.

    Args:
        model: GPT-2 style model exposing `config.n_head`, `config.n_layer`,
            `transformer.h` and `transformer._prune_heads`.
        circuit_attn_heads: iterable of (layer, head) pairs to KEEP.

    Returns:
        The same `model` object, pruned.
    """
    n_heads = model.config.n_head
    n_layers = model.config.n_layer
    # Put the bias terms on the model's own device rather than assuming CUDA.
    device = next(model.parameters()).device

    #########################
    #   PRUNE ATTENTION     #
    #########################
    heads_to_prune = list_to_dict(circuit_attn_heads, n_heads=n_heads, n_layers=n_layers)
    heads_to_prune, attn_layers_to_prune = get_attn_layers_to_prune(heads_to_prune, n_heads=n_heads)
    # Replace complete attn layers by just a bias term
    for layer in attn_layers_to_prune:
        # BiasLayer ignores its input, so the layer norm feeding it is replaced too.
        model.transformer.h[layer].ln_1 = PassthroughLayer()
        model.transformer.h[layer].attn = BiasLayer(bias=mean_attn_layer_activations[layer].to(device))
    # Prune the individual heads
    model.transformer._prune_heads(heads_to_prune)
    # Add the bias term of the pruned heads to the respective attention layers
    for layer, pruned_heads in heads_to_prune.items():
        model.transformer.h[layer].attn = AddLayer(
            attn=model.transformer.h[layer].attn,
            bias=mean_head_activations[layer, pruned_heads].sum(0).to(device)
            )

    #################
    #   PRUNE MLP   #
    #################
    # No MLP pruning is performed inside this function.

    return model

In [10]:
# Prune every attention head that is not part of `gt_circuit`.
model = prune_model(model, gt_circuit)

In [11]:
# MLP layers to keep.
# NOTE(review): index 12 never appears in range(model.config.n_layer) for the
# 12-block model loaded above, so it has no effect on the filter — confirm intent.
circuit_mlps = [0, 1, 8, 9, 10, 11, 12]

mlps_to_prune = [mlp for mlp in range(model.config.n_layer) if mlp not in circuit_mlps]
# Put the bias terms on the model's own device rather than assuming CUDA.
model_device = next(model.parameters()).device
# Replace MLPs
for layer in mlps_to_prune:
    # BiasLayerMLP ignores its input, so the layer norm feeding it is replaced too.
    model.transformer.h[layer].ln_2 = PassthroughLayer()
    model.transformer.h[layer].mlp = BiasLayerMLP(bias=mean_mlp_activations[layer].to(model_device))

In [12]:
# After pruning: (validation accuracy, remaining parameter count without embeddings).
compute_accuracy(model, val_tokens, val_answer_tokens), model.num_parameters() - embedding_parameters

(0.8499999642372131, 29938176)