In [1]:
from functools import partial

from einops import einsum, rearrange

import torch
import torch.nn as nn

from transformers import AutoModelForCausalLM

from utils import get_data
from pruning_utils import cache_mean_attn_layer_activations, cache_mean_head_activations, cache_mean_mlp_activations
from pruning_utils import PassthroughLayer, AddLayer, BiasLayer, BiasLayerMLP
from pruning_utils import prune_model

# Turn autograd off globally for the whole kernel session.
torch.set_grad_enabled(False)

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell executes, so edits to module files on disk are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Dataset configuration forwarded to get_data.
n_patching = 100
n_val = 100
task = "acronyms"

data = get_data(n_patching=n_patching, n_val=n_val, task=task)

# `data["model"]` is intentionally not bound to `model`: this notebook works on
# the Hugging Face GPT-2 loaded below. The object returned by get_data stays
# reachable through `data["model"]` (binding it and then `del`-ing the name
# would not release it anyway, because the dict still holds a reference).

# Tensors returned for the patching split.
patching_tokens = data["patching_tokens"]
patching_answer_tokens = data["patching_answer_tokens"]
patching_logits = data["patching_logits"]
patching_cache = data["patching_cache"]

# Tensors returned for the validation split.
val_tokens = data["val_tokens"]
val_answer_tokens = data["val_answer_tokens"]
val_logits = data["val_logits"]
val_cache = data["val_cache"]

gt_circuit = data["gt_circuit"]

# Move the model to the `device` chosen in the setup cell rather than calling
# `.cuda()` unconditionally, so the notebook also runs on CPU-only machines.
model = AutoModelForCausalLM.from_pretrained(
    "openai-community/gpt2", output_hidden_states=False, use_cache=False
).to(device)
# Switch to eval mode; as the last expression, this also displays the module tree.
model.eval()

Loaded pretrained model gpt2-small into HookedTransformer


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [3]:
def compute_accuracy(model, val_tokens, val_answer_tokens, task="acronyms"):
    """Return the fraction of prompts whose final-position prediction is correct.

    Args:
        model: Callable such that ``model(val_tokens)["logits"]`` yields the
            logits tensor (assumed ``(batch, seq, vocab)`` -- TODO confirm).
        val_tokens: Input passed unchanged to ``model``.
        val_answer_tokens: Tensor with at least two dimensions; only its last
            column (``[:, -1]``) is compared against the prediction.
        task: Name of the task. Only ``"acronyms"`` is implemented.

    Returns:
        The accuracy as a Python float in ``[0, 1]``.

    Raises:
        ValueError: If ``task`` is not supported. Previously the function fell
            through and silently returned ``None`` in that case.
    """
    if task != "acronyms":
        raise ValueError(f"Unsupported task {task!r}; only 'acronyms' is implemented.")

    logits = model(val_tokens)["logits"]
    # Greedy prediction at the last sequence position of every prompt.
    predictions = logits[:, -1].argmax(-1)
    targets = val_answer_tokens[:, -1]
    return (predictions == targets).float().mean().item()

In [4]:
# Sizes of the two embedding tables, matching the module tree printed above:
# wte is Embedding(50257, 768) and wpe is Embedding(1024, 768).
token_embedding_parameters = 50257 * 768
position_embedding_parameters = 1024 * 768
embedding_parameters = token_embedding_parameters + position_embedding_parameters

# Parameter count of the freshly loaded model with the embedding tables excluded.
total_parameters = model.num_parameters()
initial_parameters = total_parameters - embedding_parameters

In [5]:
# Accuracy on the validation prompts before any pruning, shown next to the
# non-embedding parameter count.
baseline_accuracy = compute_accuracy(model, val_tokens, val_answer_tokens)
baseline_accuracy, initial_parameters

(0.9399999976158142, 85056000)

In [7]:
# MLP indices handed to prune_model alongside `gt_circuit`.
# NOTE(review): the module tree printed above lists blocks 0-11 only; confirm
# what index 12 means to prune_model (its implementation is not visible here).
circuit_mlps = [0, 1, 8, 9, 10, 11, 12]

# Rebinds `model` to whatever prune_model returns, so re-running this cell
# feeds the previous result back in rather than the freshly loaded model.
model = prune_model(model, gt_circuit, circuit_mlps, patching_tokens, ablation_scheme="mean")

In [8]:
# Same two metrics as the baseline cell, now for the model returned by prune_model.
pruned_accuracy = compute_accuracy(model, val_tokens, val_answer_tokens)
pruned_parameters = model.num_parameters() - embedding_parameters
pruned_accuracy, pruned_parameters

(0.8499999642372131, 29938176)