In [64]:
import pickle
from itertools import product
from functools import partial
from tqdm import tqdm
import random

import plotly.express as px

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

import einops

from typing import Literal, Optional, Tuple, Union
from jaxtyping import Float

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention

from transformer_lens import HookedTransformer, ActivationCache
from transformer_lens import utils


# Inference-only notebook: switch autograd off globally so no graph is built.
torch.set_grad_enabled(False)
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [65]:
def sort_Nd_tensor(tensor, descending=False):
    """
    Return the N-dimensional indices of `tensor`, ordered by element value.

    The tensor is flattened and sorted (ascending unless `descending=True`);
    each flat position is then converted back into coordinates of
    `tensor.shape`. The result is a list with one `[i_0, ..., i_{N-1}]`
    entry per element of `tensor`.
    """
    flat_order = tensor.flatten().sort(descending=descending).indices
    # One coordinate array per dimension, all of length tensor.numel()
    coords = np.unravel_index(utils.to_numpy(flat_order), tensor.shape)
    return np.stack(coords, axis=1).tolist()

def compute_logit_diff(logits, answer_tokens, average=True):
    """
    Compute the logit difference between the correct answer and the largest logit
    of all the possible incorrect capital letters. This is done for every letter
    of the acronym (the last `n_letters` sequence positions) and then averaged
    if desired.

    Parameters:
    -----------
    - `logits`: `Tensor[batch_size, seq_len, d_vocab]`
    - `answer_tokens`: `Tensor[batch_size, n_letters]` (long). `n_letters` is 3
      for the three-letter acronyms used in this notebook.
    - `average`: if `True` a scalar tensor (mean over batch and letters) is
      returned, otherwise a `Tensor[batch_size, n_letters]` with the logit
      difference at every letter for every prompt in the batch.
    """
    n_letters = answer_tokens.shape[-1]
    answer_tokens = answer_tokens.to(logits.device)
    letter_logits = logits[:, -n_letters:]  # (batch_size, n_letters, d_vocab)

    # Logits of the correct answers (batch_size, n_letters). Only the gathered
    # axis is squeezed, so a batch of size 1 keeps its batch dimension.
    correct_logits = letter_logits.gather(-1, answer_tokens[..., None]).squeeze(-1)

    # Token ids 32..57 are the candidate capital-letter tokens; the table is
    # created on the logits' own device instead of a notebook-level global.
    capital_letters_tokens = torch.arange(32, 58, dtype=torch.long, device=logits.device)
    capital_logits = letter_logits[..., capital_letters_tokens]  # (batch_size, n_letters, 26)

    # Hide the correct letter with -inf before taking the max. Unlike removing
    # it with a boolean mask + reshape, this keeps working when an answer
    # token is not one of the capital-letter tokens.
    is_answer = capital_letters_tokens == answer_tokens[..., None]
    incorrect_logits = capital_logits.masked_fill(is_answer, float("-inf")).max(-1).values

    logit_diff = correct_logits - incorrect_logits
    return logit_diff.mean() if average else logit_diff

In [66]:
# Tokenizer and pretrained GPT-2 weights, both fetched by name via from_pretrained.
tokenizer_hf = AutoTokenizer.from_pretrained(
    "openai-community/gpt2",
    add_bos_token=True,
)
model_hf = AutoModelForCausalLM.from_pretrained(
    "openai-community/gpt2",
    output_hidden_states=False,
    use_cache=False,
)
model_hf = model_hf.to(device)
# Parameter count of the unpruned model, kept as the reference for later ratios.
initial_parameters = model_hf.num_parameters()

In [67]:
# Fixed seed so the sampled subset (and every number derived from it) is reproducible.
SEED = 0
random.seed(SEED)

with open("acronyms_2_common.txt", "r") as f:
    # Each non-empty line is "<prompt>, <acronym>". Splitting on the LAST ", "
    # only means a comma inside the prompt cannot produce extra fields.
    dataset = [tuple(line.rsplit(", ", 1)) for line in f.read().splitlines() if line.strip()]

# take a subset of the dataset (we do this because of VRAM limitations);
# never ask for more samples than the file contains
n_samples = min(250, len(dataset))
# sample WITHOUT replacement so that no prompt is counted twice in the statistics
prompts, acronyms = map(list, zip(*random.sample(dataset, k=n_samples)))

In [68]:
# Tokenize prompts and answers. No padding is requested, so every prompt
# (and every acronym) must tokenize to the same length.
tokens_hf = tokenizer_hf(prompts, return_tensors="pt")["input_ids"]
# Use the notebook-wide `device` (not a hard-coded .cuda()) so the CPU fallback works.
answer_tokens = tokenizer_hf(acronyms, return_tensors="pt")["input_ids"].to(device)
# Prepend a BOS token to every prompt by hand.
# NOTE(review): presumably the tokenizer call above does not add it itself — confirm.
bos_column = torch.full((tokens_hf.shape[0], 1), tokenizer_hf.bos_token_id, dtype=torch.long)
tokens_hf = torch.cat([bos_column, tokens_hf], dim=1).to(device)
# Clean (un-ablated) forward pass used as the baseline.
logits_hf = model_hf(tokens_hf)["logits"]

In [69]:
# Baseline metric: logit difference on the last acronym letter, averaged over the batch.
per_letter_diff_hf = compute_logit_diff(logits_hf, answer_tokens, average=False)
logit_diff_hf = per_letter_diff_hf[..., -1].mean()

print(f"Logit Diff. HF: {logit_diff_hf.item():.4f}")

Logit Diff. HF: 3.3557


In [70]:
def get_cache_attn(module: GPT2Attention, input, output, layer_idx: int, cache: torch.Tensor):
    """
    Forward hook that stores the batch-mean attention output of one layer.

    output -> tuple whose first element is the attention output, a tensor of
              shape (batch_size, seq_len, d_model)
    cache  -> tensor of shape (n_layer, seq_len, d_model); row `layer_idx` is
              overwritten in place with the mean over the batch dimension

    Nothing is returned, so the module's output is left untouched.
    """
    attn_out = output[0]
    batch_mean = attn_out.mean(dim=0)
    cache[layer_idx] = batch_mean.detach().clone()

# Batch-mean attention output for every layer and position:
# (n_layer, seq_len, d_model). The sequence length is read from the prompts
# instead of being hard-coded, and the buffer lives on the notebook-wide
# `device` rather than unconditionally on CUDA.
cache_attn = torch.zeros(
    (model_hf.config.n_layer, tokens_hf.shape[1], model_hf.config.n_embd),
    device=device,
)

# The caching hooks only read the outputs, so all layers can be recorded in a
# single clean forward pass instead of one pass per layer.
hooks = [
    model_hf.transformer.h[layer].attn.register_forward_hook(
        partial(get_cache_attn, layer_idx=layer, cache=cache_attn)
    )
    for layer in range(model_hf.config.n_layer)
]
try:
    model_hf(tokens_hf)
finally:
    # Always detach the hooks, even if the forward pass raises.
    for hook in hooks:
        hook.remove()

In [71]:
def mean_ablate_attn_hf(module: GPT2Attention, input, output, layer_idx: int, cache: torch.Tensor):
    """
    Forward hook that performs mean ablation on an attention layer: the
    attention output is replaced by the cached mean activation of that layer.

    output -> tuple whose first element is the attention output, a tensor of
              shape (batch_size, seq_len, d_model)
    cache  -> tensor of shape (n_layers, seq_len, d_model)

    Returns a tuple of the same length as `output`. Its first element is
    `cache[layer_idx]` with a leading batch axis of size 1 (relying on
    broadcasting over the batch), and every remaining element of `output` is
    passed through unchanged, so the caller receives the arity it expects.
    """
    ablated = cache[layer_idx][None, ...]
    return (ablated,) + tuple(output[1:])

In [72]:
# One row per ablated attention layer, one column per prompt in the batch.
corrupted_logit_diffs = torch.zeros((model_hf.config.n_layer, tokens_hf.shape[0]))
with torch.no_grad():
    for layer in tqdm(range(model_hf.config.n_layer)):
        hook_fn = partial(mean_ablate_attn_hf, layer_idx=layer, cache=cache_attn)
        hook = model_hf.transformer.h[layer].attn.register_forward_hook(hook_fn)
        try:
            corrupted_logits = model_hf(tokens_hf)["logits"]
        finally:
            # Never leave an ablation hook attached, even if the forward pass fails.
            hook.remove()
        # The logits already live on the model's device; no extra .cuda() needed.
        corrupted_logit_diff = compute_logit_diff(corrupted_logits, answer_tokens, average=False)
        corrupted_logit_diffs[layer] = corrupted_logit_diff[..., -1] # take last letter

# Change w.r.t. the clean baseline, averaged over prompts: negative values mean
# that ablating the layer lowers the logit difference.
attribution_score_attn_hf = (corrupted_logit_diffs - logit_diff_hf.cpu()).mean(-1)

100%|██████████| 12/12 [00:00<00:00, 37.17it/s]


In [73]:
# One-row heatmap of the per-layer attention attribution scores, centred on 0.
attn_heatmap = attribution_score_attn_hf.unsqueeze(0).detach().numpy()
px.imshow(
    attn_heatmap,
    title="Attribution score for attention heads (HuggingFace)",
    labels={"x": "Layer", "y": ""},
    width=800,
    height=300,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
)

In [74]:
def get_cache_mlp(module: GPT2MLP, input, output, layer_idx: int, cache: torch.Tensor):
    """
    Forward hook that stores the batch-mean MLP output of one layer.

    output -> tensor of shape (batch_size, seq_len, d_model)
    cache  -> tensor of shape (n_layer, seq_len, d_model); row `layer_idx` is
              overwritten in place with the mean over the batch dimension

    Nothing is returned, so the module's output is left untouched.
    """
    batch_mean = output.mean(dim=0)
    cache[layer_idx] = batch_mean.detach().clone()

# Batch-mean MLP output for every layer and position:
# (n_layer, seq_len, d_model). The sequence length is read from the prompts
# instead of being hard-coded, and the buffer lives on the notebook-wide
# `device` rather than unconditionally on CUDA.
cache_mlp = torch.zeros(
    (model_hf.config.n_layer, tokens_hf.shape[1], model_hf.config.n_embd),
    device=device,
)

# The caching hooks only read the outputs, so all layers can be recorded in a
# single clean forward pass instead of one pass per layer.
hooks = [
    model_hf.transformer.h[layer].mlp.register_forward_hook(
        partial(get_cache_mlp, layer_idx=layer, cache=cache_mlp)
    )
    for layer in range(model_hf.config.n_layer)
]
try:
    model_hf(tokens_hf)
finally:
    # Always detach the hooks, even if the forward pass raises.
    for hook in hooks:
        hook.remove()

In [75]:
def mean_ablate_mlp_hf(module: GPT2MLP, input, output, layer_idx: int, cache: torch.Tensor):
    """
    Forward hook that performs mean ablation on an MLP layer: whatever the MLP
    produced is discarded and the cached mean activation is returned instead.

    output -> tensor of shape (batch_size, seq_len, d_model)
    cache  -> tensor of shape (n_layers, seq_len, d_model)

    Returns `cache[layer_idx]` with a leading batch axis of size 1.
    """
    layer_mean = cache[layer_idx]
    return layer_mean.unsqueeze(0)

In [76]:
# One row per ablated MLP layer, one column per prompt in the batch.
corrupted_logit_diffs = torch.zeros((model_hf.config.n_layer, tokens_hf.shape[0]))
with torch.no_grad():
    for layer in tqdm(range(model_hf.config.n_layer)):
        hook_fn = partial(mean_ablate_mlp_hf, layer_idx=layer, cache=cache_mlp)
        hook = model_hf.transformer.h[layer].mlp.register_forward_hook(hook_fn)
        try:
            corrupted_logits = model_hf(tokens_hf)["logits"]
        finally:
            # Never leave an ablation hook attached, even if the forward pass fails.
            hook.remove()
        # The logits already live on the model's device; no extra .cuda() needed.
        corrupted_logit_diff = compute_logit_diff(corrupted_logits, answer_tokens, average=False)
        corrupted_logit_diffs[layer] = corrupted_logit_diff[..., -1] # take last letter

# Change w.r.t. the clean baseline, averaged over prompts: negative values mean
# that ablating the layer lowers the logit difference.
attribution_score_mlp_hf = (corrupted_logit_diffs - logit_diff_hf.cpu()).mean(-1)

100%|██████████| 12/12 [00:00<00:00, 37.00it/s]


In [77]:
# One-row heatmap of the per-layer MLP attribution scores, centred on 0.
# The title now names the MLP layers (it was copy-pasted from the attention plot).
px.imshow(
    attribution_score_mlp_hf.unsqueeze(0).detach().numpy(),
    title="Attribution score for MLP layers (HuggingFace)",
    labels={"x": "Layer", "y": ""},
    width=800,
    height=300,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
)

In [78]:
# Row 0 contains scores for attn, row 1 for MLPs
# Shape (2, n_layer): row 0 holds the attention scores, row 1 the MLP scores.
attribution_scores = torch.stack((attribution_score_attn_hf, attribution_score_mlp_hf))
# [component, layer] index pairs, ordered from lowest to highest score.
sorted_components = sort_Nd_tensor(attribution_scores, descending=False)

In [79]:
# Ranking of all components by attribution score, lowest first.
for component in sorted_components:
    comp, layer = component  # comp: 0 -> attention, 1 -> MLP
    print(f"{'MLP' if comp else 'Attn.'} {layer}: {attribution_scores[comp, layer]:.2f}")

MLP 0: -2.11
Attn. 8: -1.70
Attn. 10: -0.74
MLP 10: -0.39
MLP 8: -0.36
MLP 11: -0.35
MLP 9: -0.22
MLP 6: -0.11
MLP 7: -0.08
Attn. 1: -0.00
Attn. 7: 0.07
Attn. 6: 0.08
Attn. 4: 0.11
Attn. 11: 0.12
Attn. 2: 0.12
Attn. 3: 0.14
MLP 3: 0.15
MLP 5: 0.16
MLP 1: 0.20
MLP 4: 0.20
MLP 2: 0.22
Attn. 9: 0.28
Attn. 0: 0.31
Attn. 5: 0.42


In [80]:
class PrunedMLP(nn.Module):
    """
    Parameter-free stand-in for an MLP: it ignores its input and always returns
    the vector `mean_activation`, which is expected to have shape
    (seq_len, d_model) and represents the mean activation of the MLP that is
    going to be replaced.
    """
    def __init__(self, mean_activation: torch.Tensor):
        super().__init__()
        # Stored as a non-persistent buffer: it follows `.to()` / `.cuda()` /
        # dtype casts of the parent model, adds no parameters and no state-dict
        # entries. The clone decouples it from the cache it was sliced from.
        self.register_buffer("mean_activation", mean_activation.detach().clone(), persistent=False)

    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
        # Leading axis of size 1 so the result broadcasts over the batch.
        return self.mean_activation[None, ...]
        

class PrunedAttention(nn.Module):
    """
    Parameter-free stand-in for an attention layer: it ignores its inputs and
    always returns the vector `mean_activation`, which is expected to have
    shape (seq_len, d_model) and represents the mean activation of the
    attention layer that is going to be replaced.
    """
    def __init__(self, mean_activation: torch.Tensor):
        super().__init__()
        # Stored as a non-persistent buffer: it follows `.to()` / `.cuda()` /
        # dtype casts of the parent model, adds no parameters and no state-dict
        # entries. The clone decouples it from the cache it was sliced from.
        self.register_buffer("mean_activation", mean_activation.detach().clone(), persistent=False)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        # All arguments are accepted only for call compatibility and are ignored;
        # `**kwargs` swallows any extra keyword the enclosing block may pass.
        # Returns (constant output with a broadcastable batch axis, None).
        return (self.mean_activation[None, ...], None)

In [81]:
# Cumulatively prune the model, starting with the component that has the
# highest attribution score, and re-measure after every replacement.
# NOTE: this mutates `model_hf` in place.
logit_diffs, std_logit_diffs, num_parameters = [], [], []

for comp, layer in sorted_components[::-1]:
    block = model_hf.transformer.h[layer]
    if comp:
        block.mlp = PrunedMLP(mean_activation=cache_mlp[layer])
    else:
        block.attn = PrunedAttention(mean_activation=cache_attn[layer])
    num_parameters.append(model_hf.num_parameters())

    circuit_logits = model_hf(tokens_hf)["logits"]
    logit_diff = compute_logit_diff(circuit_logits, answer_tokens, average=False)
    # Only the last letter of the acronym is scored.
    last_letter_diff = logit_diff[..., -1]
    av_logit_diff = last_letter_diff.mean(0)
    std_logit_diff = last_letter_diff.std(0)
    logit_diffs.append(av_logit_diff)
    std_logit_diffs.append(std_logit_diff)

# One entry per pruning step.
logit_diffs_hf = torch.stack(logit_diffs, dim=0)
std_logit_diffs_hf = torch.stack(std_logit_diffs, dim=0)

In [82]:
# Index of the last pruning step whose performance is equal or higher than baseline.
meets_baseline = logit_diffs_hf >= logit_diff_hf
cutoff = meets_baseline.nonzero()[-1].item()

In [83]:
# x-axis labels in pruning order (highest attribution score first).
labels = [f"{'MLP' if comp else 'Attn.'} {layer}" for comp, layer in reversed(sorted_components)]
# Mean and std of the last-letter logit difference after each cumulative pruning
# step; title and axis labels are set so the figure can be read on its own.
fig = px.line(
    x=labels,
    y=logit_diffs_hf.cpu().numpy(),
    error_y=std_logit_diffs_hf.cpu().numpy(),
    title="Logit difference after cumulatively pruning components (HuggingFace)",
    labels={"x": "Last pruned component (cumulative)", "y": "Logit difference (last letter)"},
)
# Black dashed line: baseline of the unpruned model. Red dashed line: cutoff step.
fig.add_hline(y=logit_diff_hf.item(), line_width=1.5, line_dash="dash", line_color="black")
fig.add_vline(x=labels[cutoff], line_width=1., line_dash="dash", line_color="red")
fig.show()

In [84]:
# Fraction of the original parameter count remaining after each pruning step.
p_parameters = [count / initial_parameters for count in num_parameters]
# Title and axis labels are set so the figure can be read on its own.
fig = px.line(
    x=labels,
    y=p_parameters,
    title="Remaining parameters after cumulatively pruning components (HuggingFace)",
    labels={"x": "Last pruned component (cumulative)", "y": "Fraction of initial parameters"},
)
# Red dashed line: cutoff step.
fig.add_vline(x=labels[cutoff], line_width=1., line_dash="dash", line_color="red")
fig.show()

We achieve a ~33% reduction in parameters while preserving the same baseline performance, as well as keeping the general structure of the model (i.e. we don't mask specific weights).

In [85]:
# Performance against model size: last-letter logit difference (mean and std) as
# a function of the fraction of parameters kept. Title and axis labels are set
# so the figure can be read on its own.
fig = px.line(
    x=p_parameters,
    y=logit_diffs_hf.cpu().numpy(),
    error_y=std_logit_diffs_hf.cpu().numpy(),
    title="Logit difference vs. remaining parameters (HuggingFace)",
    labels={"x": "Fraction of initial parameters", "y": "Logit difference (last letter)"},
)
# Red dashed line: cutoff step. Black dashed line: baseline of the unpruned model.
fig.add_vline(x=p_parameters[cutoff], line_width=1., line_dash="dash", line_color="red")
fig.add_hline(y=logit_diff_hf.item(), line_width=1.5, line_dash="dash", line_color="black")
# Reverse the x axis so that pruning progresses from left to right.
fig.update_xaxes(autorange="reversed")
fig.show()

However, I can keep pruning the model while sacrificing a little bit of performance, and achieve up to a 50% reduction in parameters while keeping a reasonable performance!