In [1]:
import pickle
from itertools import product
from functools import partial
from tqdm import tqdm
import random
from string import ascii_uppercase

import plotly.express as px

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

import einops

from typing import Literal, Optional, Tuple, Union
from jaxtyping import Float

from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaSdpaAttention
from transformers.cache_utils import Cache

BATCH_SIZE = 100
SIZE = 500

torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
tokenizer_hf = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model_hf = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, use_cache=False).cuda()
initial_parameters = model_hf.num_parameters()

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.07it/s]


In [3]:
def n_parameters(layer: nn.Module):
    return sum(p.numel() for p in layer.parameters())


def print_parameters_llama(model_hf):
    print(f"Initial # of parameters: {initial_parameters:.2E}")
    print(f"Final # of parameters: {model_hf.num_parameters():.2E}")
    print(f"Reduction of {100 - model_hf.num_parameters()/initial_parameters * 100:.2f}%")
    print("-"*40)
    print(f"embed_tokens: {sum(p.numel() for p in model_hf.model.embed_tokens.parameters()):.2E}")
    print(f"Decoder: {32*n_parameters(model_hf.model.layers[0]):.2E}")
    print(f"lm_head: {sum(p.numel() for p in model_hf.lm_head.parameters()):.2E}")
    print("-"*40)
    print(f"Decoder Layer: {n_parameters(model_hf.model.layers[0]):.2E}")
    print(f"\tAttention: {n_parameters(model_hf.model.layers[0].self_attn):.2E}")
    print(f"\tMLP: {n_parameters(model_hf.model.layers[0].mlp):.2E}")
    print(f"\tinput_layernorm: {n_parameters(model_hf.model.layers[0].input_layernorm):.2E}")
    print(f"\tpost_attention_layernorm: {n_parameters(model_hf.model.layers[0].post_attention_layernorm):.2E}")
    print("-"*40)


def print_predictions(prompts, probs, str_toks, val_acronyms):
    for prompt, prob, str_tok, acronym in zip(prompts, probs, str_toks, val_acronyms):
        print(f"{prompt}")
        for p, s in zip(prob, str_tok):
            if s == acronym[-1]:
                print(f"\t\033[92m{s}: {p:.2f}\033[0m")
            else:
                print(f"\t{s}: {p:.2f}")

    
def sort_Nd_tensor(tensor, descending=False):
    i = torch.sort(tensor.flatten(), descending=descending).indices
    return np.array(np.unravel_index(i.numpy(), tensor.shape)).T.tolist()

capital_letters_tokens = torch.tensor(tokenizer_hf(list(ascii_uppercase), add_special_tokens=False)["input_ids"], dtype=torch.long, device=device).squeeze()
    
def compute_logit_diff(logits, answer_tokens, 
                       average=True, capital_letters_tokens=capital_letters_tokens):
    """
    Compute the logit difference between the correct answer and the largest logit
    of all the possible incorrect capital letters. This is done for every iteration
    (i.e. each of the three letters of the acronym) and then averaged if desired.
    If `average=False`, then a `Tensor[batch_size, 3]` is returned, containing the
    logit difference at every iteration for every prompt in the batch

    Parameters:
    -----------
    - `logits`: `Tensor[batch_size, seq_len, d_vocab]`
    - `answer_tokens`: Tensor[batch_size, 3]
    """
    # Logits of the correct answers (batch_size, 3)
    correct_logits = logits[:, -3:].gather(-1, answer_tokens[..., None]).squeeze()
    # Retrieve the maximum logit of the possible incorrect answers
    batch_size = logits.shape[0]
    capital_letters_tokens_expanded = capital_letters_tokens.expand(batch_size, 3, -1)
    incorrect_capital_letters = capital_letters_tokens_expanded[capital_letters_tokens_expanded != answer_tokens[..., None]].reshape(batch_size, 3, -1)
    incorrect_logits, _ = logits[:, -3:].gather(-1, incorrect_capital_letters).max(-1)
    # Return the mean
    return (correct_logits - incorrect_logits).mean() if average else (correct_logits - incorrect_logits)


def compute_accuracy(logits, answer_tokens, 
                     capital_letters_tokens=capital_letters_tokens):
    """
    Computes the accuracy among the possible outputs (specified by `capital_letters_tokens`).
    Specifically, it retrieves the predicted token among the vocabulary and checks if it is
    the same as `answer_tokens[-1]`.

    Parameters:
    -----------
    - `logits`: `Tensor[batch_size, seq_len, d_vocab]`
    - `answer_tokens`: Tensor[batch_size, 3]
    """
    # Retrieve logits of the letters of the vocabulary
    logits_vocab = logits[:, -1, capital_letters_tokens]
    # Now the indices [0, 1, 2, 3,...] represent [capital_letter_tokens[0], ...]
    # Transform the answer tokens to [0, 1, 2, ...]
    answer_token_t = torch.tensor([torch.where(capital_letters_tokens == element)[0][0] for element in answer_tokens[:, -1]]).to(device)
    # Get predictions (max logit)
    preds = logits_vocab.argmax(-1)
    acc = (preds == answer_token_t).float().mean()
    std = (preds == answer_token_t).float().std()
    return acc, std

In [4]:
print_parameters_llama(model_hf)

Initial # of parameters: 6.74E+09
Final # of parameters: 6.74E+09
Reduction of 0.00%
----------------------------------------
embed_tokens: 1.31E+08
Decoder: 6.48E+09
lm_head: 1.31E+08
----------------------------------------
Decoder Layer: 2.02E+08
	Attention: 6.71E+07
	MLP: 1.35E+08
	input_layernorm: 4.10E+03
	post_attention_layernorm: 4.10E+03
----------------------------------------


In [5]:
with open("acronyms.txt", "r") as f:
   prompts, acronyms = list(zip(*[line.split(", ") for line in f.read().splitlines()]))

# take a subset of the dataset (we do this because VRAM limitations)
n_samples = 100
val_split = 0.2
n_val = int(n_samples*val_split)
# giga-cursed way of sampling from the dataset
prompts, acronyms = list(map(list, zip(*random.choices(list(zip(prompts, acronyms)), k=n_samples))))
val_prompts, val_acronyms = prompts[-n_val:], acronyms[-n_val:]
prompts, acronyms = prompts[:-n_val], acronyms[:-n_val]

In [6]:
class AcronymDataset(Dataset):
    def __init__(self, path: str, tokenizer, size=None):
        with open(path, "r") as f:
            prompts, acronyms = list(zip(*[line.split(", ") for line in f.read().splitlines()]))
        self.prompts = prompts
        self.acronyms = acronyms
        self.cap_to_id = {k: v[0] for k, v in zip(ascii_uppercase, tokenizer(list(ascii_uppercase), add_special_tokens=False)["input_ids"])}

        self.tokens = tokenizer(prompts, return_tensors="pt")["input_ids"]
        self.answer_tokens = torch.tensor([[self.cap_to_id[c] for c in acronym] for acronym in acronyms])
    
        if size:
            self.prompts = self.prompts[:size]
            self.acronyms = self.acronyms[:size]
            self.tokens = self.tokens[:size]
            self.answer_tokens = self.answer_tokens[:size]

    def __len__(self):
        return self.tokens.shape[0]
    
    def __getitem__(self, idx):
        return self.tokens[idx], self.answer_tokens[idx]
        
dataset = AcronymDataset(path="cache_acronyms.txt", tokenizer=tokenizer_hf, size=SIZE)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

val_dataset = AcronymDataset(path="val_acronyms.txt", tokenizer=tokenizer_hf, size=SIZE)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
cap_to_id = {k: v[0] for k, v in zip(ascii_uppercase, tokenizer_hf(list(ascii_uppercase), add_special_tokens=False)["input_ids"])}

# tokens_hf = tokenizer_hf(prompts, return_tensors="pt")["input_ids"].to(device)
# answer_tokens = torch.tensor([[cap_to_id[c] for c in acronym] for acronym in acronyms]).to(device)

tokens_hf, answer_tokens = next(iter(dataloader))

# val_tokens_hf = tokenizer_hf(val_prompts, return_tensors="pt")["input_ids"].to(device)
# val_answer_tokens = torch.tensor([[cap_to_id[c] for c in acronym] for acronym in val_acronyms]).to(device)

val_tokens_hf, val_answer_tokens = next(iter(val_dataloader))

#tokens_hf = torch.cat([torch.ones((tokens_hf.shape[0], 1), dtype=torch.long) * tokenizer_hf.bos_token_id, tokens_hf], dim=1).to(device)
logits_hf = model_hf(tokens_hf.cuda())["logits"]
val_logits_hf = model_hf(val_tokens_hf.cuda())["logits"]

In [8]:
logit_diff_hf = compute_logit_diff(logits_hf, answer_tokens.cuda(), average=False)[..., -1].mean()
val_logit_diff_hf = compute_logit_diff(val_logits_hf, val_answer_tokens.cuda(), average=False)[..., -1].mean()

acc_hf, std = compute_accuracy(logits_hf, answer_tokens.cuda())
val_acc_hf, val_std = compute_accuracy(val_logits_hf, val_answer_tokens.cuda())


print(f"Logit Diff: {logit_diff_hf.item():.4f}\t Val: {val_logit_diff_hf.item():.4f}")
print(f"Accuracy: {acc_hf.item():.2f} ± {std:.2f}\t Val: {val_acc_hf.item():.2f} ± {val_std:.2f}")

Logit Diff: 4.1992	 Val: 4.3926
Accuracy: 1.00 ± 0.00	 Val: 1.00 ± 0.00


In [9]:
def get_cache_attn(model, tokens) -> torch.Tensor:
    def get_cache_attn_hook(module: LlamaSdpaAttention, input, output, layer_idx: int, cache: torch.Tensor):
        """
        Caches the mean activation of the attention layer. `cache` has shape (n_layer, d_model)
        output -> tensor of shape (batch_size, seq_len, d_model)
        """
        cache[layer_idx] = output[0].mean(0).detach().clone()

    cache_attn = torch.zeros((model.config.num_hidden_layers, 11, model.config.hidden_size), dtype=torch.float16).to(device)

    for layer in range(model.config.num_hidden_layers):  
        hook_fn = partial(get_cache_attn_hook, layer_idx=layer, cache=cache_attn)
        hook = model.model.layers[layer].self_attn.register_forward_hook(hook_fn)
        corrupted_logits = model(tokens)["logits"]
        hook.remove()
    
    return cache_attn

def get_cache_attn_full(model, dataloader):
    cache_attn = []
    for tokens, _ in tqdm(dataloader):
        cache_attn.append(get_cache_attn(model_hf, tokens_hf.cuda()).cpu())
    return torch.stack(cache_attn, dim=0).mean(0)

In [10]:
def get_attribution_score_attn(model, val_tokens, val_answer_tokens, cache) -> torch.Tensor:
    def mean_ablate_attn_hf(module: LlamaSdpaAttention, input, output, layer_idx: int, cache: torch.Tensor):
        """
        Performs mean ablation on an attention layer
        This function 

        output -> tuple containing (output[0], None)
        output[0] -> tensor of shape (batch_size, seq_len, d_model)
        cache -> tensor of shape (n_layers, seq_len, d_model)
        """
        # output = cache[layer_idx][None, ...]
        #print(output[0].dtype, cache[layer_idx].dtype)
        return (cache[layer_idx][None, ...], None, output[2])


    corrupted_logit_diffs = torch.zeros((model.config.num_hidden_layers, val_tokens.shape[0]))
    with torch.no_grad():
        for layer in range(model.config.num_hidden_layers):
            hook_fn = partial(mean_ablate_attn_hf, layer_idx=layer, cache=cache)
            hook = model.model.layers[layer].self_attn.register_forward_hook(hook_fn)
            corrupted_logits = model(val_tokens)["logits"]
            hook.remove()
            corrupted_logit_diff = compute_logit_diff(corrupted_logits.cuda(), val_answer_tokens, average=False)
            corrupted_logit_diffs[layer] = corrupted_logit_diff[..., -1] # take last letter

    attribution_score_attn_hf = (corrupted_logit_diffs - val_logit_diff_hf.cpu()).mean(-1)
    return attribution_score_attn_hf

def get_attribution_score_attn_full(model, dataloader, cache):
    attribution_score_attn_hf = []
    for val_tokens, val_answer_tokens in tqdm(dataloader):
        attribution_score_attn_hf.append(get_attribution_score_attn(model, val_tokens.cuda(), val_answer_tokens.cuda(), cache.cuda()).cpu())
    return torch.stack(attribution_score_attn_hf, dim=0).mean(0)

In [11]:
def get_cache_mlp(model, tokens) -> torch.Tensor:

    def get_cache_mlp_hook(module: LlamaMLP, input, output, layer_idx: int, cache: torch.Tensor):
        """
        Caches the mean activation of the MLP. `cache` has shape (n_layer, d_model)
        output -> tensor of shape (batch_size, seq_len, d_model)
        """
        cache[layer_idx] = output.mean(0).detach().clone()

    cache_mlp = torch.zeros((model.config.num_hidden_layers, 11, model.config.hidden_size), dtype=torch.float16).to(device)

    for layer in range(model.config.num_hidden_layers):  
        hook_fn = partial(get_cache_mlp_hook, layer_idx=layer, cache=cache_mlp)
        hook = model.model.layers[layer].mlp.register_forward_hook(hook_fn)
        corrupted_logits = model(tokens)["logits"]
        hook.remove()
    return cache_mlp


def get_cache_mlp_full(model, dataloader):
    cache_mlp = []
    for tokens, _ in tqdm(dataloader):
        cache_mlp.append(get_cache_mlp(model_hf, tokens_hf.cuda()).cpu())
    return torch.stack(cache_mlp, dim=0).mean(0)


In [12]:
def get_attribution_score_mlp(model, val_tokens, val_answer_tokens, cache) -> torch.Tensor:

    def mean_ablate_mlp_hf(module: LlamaMLP, input, output, layer_idx: int, cache: torch.Tensor):
        """
        Performs mean ablation on an MLP layer

        output -> tensor of shape (batch_size, seq_len, d_model)
        cache -> tensor of shape (n_layers, seq_len, d_model)
        """
        output = cache[layer_idx][None, ...]
        return output

        
    corrupted_logit_diffs = torch.zeros((model.config.num_hidden_layers, val_tokens.shape[0]))
    with torch.no_grad():
        for layer in range(model.config.num_hidden_layers):
            hook_fn = partial(mean_ablate_mlp_hf, layer_idx=layer, cache=cache)
            hook = model.model.layers[layer].mlp.register_forward_hook(hook_fn)
            corrupted_logits = model(val_tokens)["logits"]
            hook.remove()
            corrupted_logit_diff = compute_logit_diff(corrupted_logits.cuda(), val_answer_tokens, average=False)
            corrupted_logit_diffs[layer] = corrupted_logit_diff[..., -1] # take last letter

    attribution_score_mlp_hf = (corrupted_logit_diffs - val_logit_diff_hf.cpu()).mean(-1)
    return attribution_score_mlp_hf

def get_attribution_score_mlp_full(model, dataloader, cache):
    attribution_score_mlp_hf = []
    for val_tokens, val_answer_tokens in tqdm(dataloader):
        attribution_score_mlp_hf.append(get_attribution_score_mlp(model, val_tokens.cuda(), val_answer_tokens.cuda(), cache.cuda()).cpu())
    return torch.stack(attribution_score_mlp_hf, dim=0).mean(0)


In [None]:
cache_attn = get_cache_attn_full(model_hf, dataloader)
attribution_score_attn_hf = get_attribution_score_attn_full(model_hf, val_dataloader, cache_attn)

In [None]:
px.imshow(attribution_score_attn_hf.unsqueeze(0).detach().numpy(), title="Attribution score for attention heads (HuggingFace)", 
    labels={"x": "Layer", "y": ""}, width=1000, height=300, color_continuous_midpoint=0.0, color_continuous_scale="RdBu")

In [None]:
cache_mlp = get_cache_mlp_full(model_hf, dataloader)
attribution_score_mlp_hf = get_attribution_score_mlp_full(model_hf, val_dataloader, cache_mlp)

In [17]:
px.imshow(attribution_score_mlp_hf.unsqueeze(0).detach().numpy(), title="Attribution score for MLPs (HuggingFace)", 
    labels={"x": "Layer", "y": ""}, width=1000, height=300, color_continuous_midpoint=0.0, color_continuous_scale="RdBu")

In [18]:
# Row 0 contains scores for attn, row 1 for MLPs
attribution_scores = torch.stack([attribution_score_attn_hf, attribution_score_mlp_hf], dim=0)
sorted_components = sort_Nd_tensor(attribution_scores)

Compared to GPT-2 Small, here we see that the absolute value of the attribution score is smaller, I guess that as we have a lot more components, the role is spread. 

In [19]:
for comp, layer in sorted_components:
    print(f"{'MLP' if comp else 'Attn.'} {layer}: {attribution_scores[comp, layer]:.2f}")

MLP 1: -6.27
MLP 0: -5.48
MLP 4: -1.09
MLP 5: -1.06
MLP 6: -0.83
MLP 2: -0.65
MLP 9: -0.65
MLP 16: -0.52
MLP 15: -0.37
Attn. 16: -0.37
Attn. 14: -0.32
MLP 3: -0.31
Attn. 20: -0.29
MLP 25: -0.28
Attn. 11: -0.27
MLP 11: -0.26
MLP 13: -0.23
MLP 8: -0.19
MLP 12: -0.15
Attn. 9: -0.12
MLP 30: -0.11
MLP 21: -0.11
MLP 28: -0.11
MLP 18: -0.11
Attn. 22: -0.11
MLP 20: -0.10
MLP 7: -0.10
Attn. 23: -0.08
Attn. 1: -0.08
Attn. 13: -0.07
MLP 19: -0.06
Attn. 21: -0.06
Attn. 12: -0.05
Attn. 17: -0.04
MLP 29: -0.03
Attn. 0: -0.03
Attn. 26: -0.03
Attn. 27: -0.02
Attn. 24: -0.02
MLP 23: -0.02
Attn. 3: -0.02
Attn. 5: -0.02
Attn. 8: -0.02
Attn. 6: -0.01
Attn. 25: -0.01
Attn. 28: -0.00
MLP 26: -0.00
Attn. 19: 0.01
Attn. 7: 0.01
Attn. 4: 0.02
MLP 22: 0.03
Attn. 2: 0.04
MLP 27: 0.06
Attn. 31: 0.07
Attn. 30: 0.08
MLP 24: 0.10
Attn. 10: 0.11
Attn. 18: 0.11
MLP 17: 0.12
MLP 14: 0.14
Attn. 15: 0.14
MLP 10: 0.27
Attn. 29: 0.51
MLP 31: 1.16


In [13]:
class PrunedRMSNorm(nn.Module):
    """
    Replaces the RMSNorm with an identity function.
    """
    def __init__(self):
        super().__init__()
    
    def forward(self, hidden_states):
        return hidden_states


class PrunedMLP(nn.Module):
    """
    Layer that simply returns the vector `mean_activation`, which is expected
    to have shape (seq_len, d_model) and represents the mean activation of the MLP
    that is going to be replaced.
    """
    def __init__(self, mean_activation: torch.Tensor):
        super().__init__()
        self.mean_activation = mean_activation

    def forward(self, x):
        return self.mean_activation[None, ...]
        

class PrunedAttention(nn.Module):
    """
    Layer that simply returns the vector `mean_activation`, which is expected
    to have shape (seq_len, d_model) and represents the mean activation of the MLP
    that is going to be replaced.
    """
    def __init__(self, mean_activation: torch.Tensor):
        super().__init__()
        self.mean_activation = mean_activation

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        return (self.mean_activation[None, ...], None, past_key_value)

In [18]:
def find_component_to_ablate(model, val_dataloader, cache_attn, cache_mlp, ablated_heads):
    """
    1. Caches the mean outputs of each attn layer and MLP
    2. Uses the mean outputs to perform activation patching and checking which components 
       contribute positively/negatively to performance.
    3. Returns the component that contributes less to the performance [comp, layer], where comp=1 if MLP, 0 if attn
    """
    
    attribution_score_attn_hf = get_attribution_score_attn_full(model, val_dataloader, cache_attn)
    attribution_score_mlp_hf = get_attribution_score_mlp_full(model, val_dataloader, cache_mlp)
    # Row 0 contains scores for attn, row 1 for MLPs
    attribution_scores = torch.stack([attribution_score_attn_hf, attribution_score_mlp_hf], dim=0)
    # Set the attribution scores of the already pruned components to a super low value so that they never get selected
    if ablated_heads:
        for comp, layer in ablated_heads:
            attribution_scores[comp, layer] = -1e6
    sorted_components = sort_Nd_tensor(attribution_scores)
    # Get non-ablated component with the highest score 
    component_to_ablate = sorted_components[-1]
    return component_to_ablate

In [22]:
# logit_diffs = []
# std_logit_diffs = []

# accs = []
# std_accs = []

# num_parameters = []


# for comp, layer in reversed(sorted_components):
#     if comp == 0:
#         model_hf.model.layers[layer].input_layernorm = PrunedRMSNorm()
#         model_hf.model.layers[layer].self_attn = PrunedAttention(mean_activation=cache_attn[layer])
#     else:
#         model_hf.model.layers[layer].post_attention_layernorm = PrunedRMSNorm()
#         model_hf.model.layers[layer].mlp = PrunedMLP(mean_activation=cache_mlp[layer])
#     num_parameters.append(model_hf.num_parameters())
#     av_logit_diff = []
#     std_logit_diff = []
#     for val_tokens_hf, val_answer_tokens in val_dataloader:
#         val_tokens_hf = val_tokens_hf.cuda()
#         circuit_logits = model_hf(val_tokens_hf)["logits"]
#         # Compute logit diff
#         logit_diff = compute_logit_diff(circuit_logits.cuda(), val_answer_tokens.cuda(), average=False)
#         av_logit_diff.append(logit_diff[..., -1].mean(0))
#         std_logit_diff.append(logit_diff[..., -1].std(0))
#     av_logit_diff = torch.stack(av_logit_diff, dim=0).mean(0)
#     std_logit_diff = torch.stack(std_logit_diff, dim=0).mean(0)
#     logit_diffs.append(av_logit_diff)
#     std_logit_diffs.append(std_logit_diff)
#     # Compute accuracy
#     acc, std = compute_accuracy(circuit_logits.cuda(), val_answer_tokens.cuda())
#     accs.append(acc)
#     std_accs.append(std)
# logit_diffs_hf = torch.stack(logit_diffs, dim=0)
# std_logit_diffs_hf = torch.stack(std_logit_diffs, dim=0)
# accs_hf = torch.stack(accs, dim=0)
# std_accs_hf = torch.stack(std_accs, dim=0)

In [19]:
logit_diffs = []
std_logit_diffs = []

accs = []
std_accs = []

num_parameters = []

ablated_components = []


while len(ablated_components) < 2 * model_hf.config.num_hidden_layers:
    cache_attn = get_cache_attn_full(model_hf, dataloader)
    cache_mlp = get_cache_mlp_full(model_hf, dataloader)
    comp, layer = find_component_to_ablate(model_hf, val_dataloader, cache_attn, cache_mlp, ablated_components)
    print(f"Ablating {'MLP' if comp else 'Attn'} {layer}")
    ablated_components.append([comp, layer])
    if comp == 0:
        model_hf.model.layers[layer].input_layernorm = PrunedRMSNorm()
        model_hf.model.layers[layer].self_attn = PrunedAttention(mean_activation=cache_attn[layer].cuda())
    else:
        model_hf.model.layers[layer].post_attention_layernorm = PrunedRMSNorm()
        model_hf.model.layers[layer].mlp = PrunedMLP(mean_activation=cache_mlp[layer].cuda())
    num_parameters.append(model_hf.num_parameters())
    av_logit_diff = []
    std_logit_diff = []
    for val_tokens_hf, val_answer_tokens in val_dataloader:
        val_tokens_hf = val_tokens_hf.cuda()
        circuit_logits = model_hf(val_tokens_hf)["logits"]
        # Compute logit diff
        logit_diff = compute_logit_diff(circuit_logits.cuda(), val_answer_tokens.cuda(), average=False)
        av_logit_diff.append(logit_diff[..., -1].mean(0))
        std_logit_diff.append(logit_diff[..., -1].std(0))
    av_logit_diff = torch.stack(av_logit_diff, dim=0).mean(0)
    std_logit_diff = torch.stack(std_logit_diff, dim=0).mean(0)
    logit_diffs.append(av_logit_diff)
    std_logit_diffs.append(std_logit_diff)
    # Compute accuracy
    acc, std = compute_accuracy(circuit_logits.cuda(), val_answer_tokens.cuda())
    print(f"Logit diff: {av_logit_diff}\tAcc: {acc}")
    accs.append(acc)
    std_accs.append(std)
logit_diffs_hf = torch.stack(logit_diffs, dim=0)
std_logit_diffs_hf = torch.stack(std_logit_diffs, dim=0)
accs_hf = torch.stack(accs, dim=0)
std_accs_hf = torch.stack(std_accs, dim=0)

100%|██████████| 5/5 [00:22<00:00,  4.58s/it]
100%|██████████| 5/5 [00:22<00:00,  4.58s/it]
100%|██████████| 5/5 [00:23<00:00,  4.61s/it]
100%|██████████| 5/5 [00:23<00:00,  4.61s/it]


Ablating MLP 29
Logit diff: 5.37158203125	Acc: 1.0


100%|██████████| 5/5 [00:22<00:00,  4.53s/it]
100%|██████████| 5/5 [00:22<00:00,  4.53s/it]
100%|██████████| 5/5 [00:22<00:00,  4.53s/it]
100%|██████████| 5/5 [00:22<00:00,  4.53s/it]


Ablating Attn 30
Logit diff: 5.492941379547119	Acc: 1.0


100%|██████████| 5/5 [00:22<00:00,  4.48s/it]
100%|██████████| 5/5 [00:22<00:00,  4.48s/it]
100%|██████████| 5/5 [00:22<00:00,  4.48s/it]
100%|██████████| 5/5 [00:22<00:00,  4.48s/it]


Ablating Attn 5
Logit diff: 5.575175762176514	Acc: 1.0


100%|██████████| 5/5 [00:22<00:00,  4.42s/it]
100%|██████████| 5/5 [00:22<00:00,  4.42s/it]
100%|██████████| 5/5 [00:22<00:00,  4.43s/it]
100%|██████████| 5/5 [00:22<00:00,  4.43s/it]


Ablating MLP 17
Logit diff: 5.614375114440918	Acc: 1.0


100%|██████████| 5/5 [00:21<00:00,  4.35s/it]
100%|██████████| 5/5 [00:21<00:00,  4.35s/it]
100%|██████████| 5/5 [00:21<00:00,  4.35s/it]
100%|██████████| 5/5 [00:21<00:00,  4.35s/it]


Ablating Attn 24
Logit diff: 5.6576642990112305	Acc: 1.0


100%|██████████| 5/5 [00:21<00:00,  4.29s/it]
100%|██████████| 5/5 [00:21<00:00,  4.29s/it]
100%|██████████| 5/5 [00:21<00:00,  4.30s/it]
100%|██████████| 5/5 [00:21<00:00,  4.30s/it]


Ablating Attn 27
Logit diff: 5.695245742797852	Acc: 1.0


100%|██████████| 5/5 [00:21<00:00,  4.24s/it]
100%|██████████| 5/5 [00:21<00:00,  4.24s/it]
100%|██████████| 5/5 [00:21<00:00,  4.23s/it]
100%|██████████| 5/5 [00:21<00:00,  4.23s/it]


Ablating Attn 10
Logit diff: 5.723257541656494	Acc: 1.0


100%|██████████| 5/5 [00:20<00:00,  4.18s/it]
100%|██████████| 5/5 [00:20<00:00,  4.19s/it]
100%|██████████| 5/5 [00:20<00:00,  4.19s/it]
100%|██████████| 5/5 [00:20<00:00,  4.19s/it]


Ablating Attn 8
Logit diff: 5.7431640625	Acc: 1.0


100%|██████████| 5/5 [00:20<00:00,  4.13s/it]
100%|██████████| 5/5 [00:20<00:00,  4.12s/it]
100%|██████████| 5/5 [00:20<00:00,  4.13s/it]
100%|██████████| 5/5 [00:20<00:00,  4.13s/it]


Ablating Attn 25
Logit diff: 5.763460636138916	Acc: 1.0


100%|██████████| 5/5 [00:20<00:00,  4.07s/it]
100%|██████████| 5/5 [00:20<00:00,  4.08s/it]
100%|██████████| 5/5 [00:20<00:00,  4.09s/it]
100%|██████████| 5/5 [00:20<00:00,  4.09s/it]


Ablating Attn 6
Logit diff: 5.774499893188477	Acc: 1.0


100%|██████████| 5/5 [00:20<00:00,  4.03s/it]
100%|██████████| 5/5 [00:20<00:00,  4.03s/it]
100%|██████████| 5/5 [00:20<00:00,  4.03s/it]
100%|██████████| 5/5 [00:20<00:00,  4.03s/it]


Ablating Attn 9
Logit diff: 5.777554512023926	Acc: 1.0


100%|██████████| 5/5 [00:19<00:00,  3.98s/it]
100%|██████████| 5/5 [00:19<00:00,  3.98s/it]
100%|██████████| 5/5 [00:19<00:00,  3.98s/it]
100%|██████████| 5/5 [00:19<00:00,  3.98s/it]


Ablating Attn 26
Logit diff: 5.781843662261963	Acc: 1.0


100%|██████████| 5/5 [00:19<00:00,  3.92s/it]
100%|██████████| 5/5 [00:19<00:00,  3.92s/it]
100%|██████████| 5/5 [00:19<00:00,  3.93s/it]
100%|██████████| 5/5 [00:19<00:00,  3.93s/it]


Ablating MLP 30
Logit diff: 5.793890476226807	Acc: 1.0


100%|██████████| 5/5 [00:19<00:00,  3.84s/it]
100%|██████████| 5/5 [00:19<00:00,  3.84s/it]
100%|██████████| 5/5 [00:19<00:00,  3.85s/it]
100%|██████████| 5/5 [00:19<00:00,  3.85s/it]


Ablating Attn 12
Logit diff: 5.797343730926514	Acc: 1.0


100%|██████████| 5/5 [00:18<00:00,  3.79s/it]
100%|██████████| 5/5 [00:18<00:00,  3.79s/it]
100%|██████████| 5/5 [00:18<00:00,  3.80s/it]
100%|██████████| 5/5 [00:18<00:00,  3.80s/it]


Ablating MLP 2
Logit diff: 5.797343730926514	Acc: 1.0


100%|██████████| 5/5 [00:18<00:00,  3.79s/it]
100%|██████████| 5/5 [00:18<00:00,  3.78s/it]
100%|██████████| 5/5 [00:18<00:00,  3.79s/it]
100%|██████████| 5/5 [00:18<00:00,  3.79s/it]


Ablating Attn 29
Logit diff: 5.797343730926514	Acc: 1.0


100%|██████████| 5/5 [00:18<00:00,  3.78s/it]
100%|██████████| 5/5 [00:18<00:00,  3.78s/it]
100%|██████████| 5/5 [00:18<00:00,  3.79s/it]
100%|██████████| 5/5 [00:18<00:00,  3.79s/it]


Ablating Attn 15
Logit diff: 5.797343730926514	Acc: 1.0


100%|██████████| 5/5 [00:18<00:00,  3.78s/it]
100%|██████████| 5/5 [00:18<00:00,  3.79s/it]
100%|██████████| 5/5 [00:18<00:00,  3.80s/it]
100%|██████████| 5/5 [00:18<00:00,  3.80s/it]


Ablating Attn 7
Logit diff: 5.790866851806641	Acc: 1.0


100%|██████████| 5/5 [00:18<00:00,  3.74s/it]
100%|██████████| 5/5 [00:18<00:00,  3.74s/it]
100%|██████████| 5/5 [00:18<00:00,  3.74s/it]
100%|██████████| 5/5 [00:18<00:00,  3.74s/it]


Ablating Attn 0
Logit diff: 5.780984401702881	Acc: 1.0


100%|██████████| 5/5 [00:18<00:00,  3.69s/it]
100%|██████████| 5/5 [00:18<00:00,  3.69s/it]
100%|██████████| 5/5 [00:18<00:00,  3.69s/it]
100%|██████████| 5/5 [00:18<00:00,  3.69s/it]


Ablating Attn 3
Logit diff: 5.769898414611816	Acc: 1.0


100%|██████████| 5/5 [00:18<00:00,  3.63s/it]
100%|██████████| 5/5 [00:18<00:00,  3.63s/it]
100%|██████████| 5/5 [00:18<00:00,  3.64s/it]
100%|██████████| 5/5 [00:18<00:00,  3.64s/it]


Ablating Attn 4
Logit diff: 5.753812313079834	Acc: 1.0


100%|██████████| 5/5 [00:17<00:00,  3.58s/it]
100%|██████████| 5/5 [00:17<00:00,  3.58s/it]
100%|██████████| 5/5 [00:17<00:00,  3.58s/it]
100%|██████████| 5/5 [00:17<00:00,  3.58s/it]


Ablating Attn 13
Logit diff: 5.735515594482422	Acc: 1.0


100%|██████████| 5/5 [00:17<00:00,  3.52s/it]
100%|██████████| 5/5 [00:17<00:00,  3.52s/it]
100%|██████████| 5/5 [00:17<00:00,  3.53s/it]
100%|██████████| 5/5 [00:17<00:00,  3.53s/it]


Ablating Attn 23
Logit diff: 5.70535135269165	Acc: 1.0


100%|██████████| 5/5 [00:17<00:00,  3.47s/it]
100%|██████████| 5/5 [00:17<00:00,  3.47s/it]
100%|██████████| 5/5 [00:17<00:00,  3.47s/it]
100%|██████████| 5/5 [00:17<00:00,  3.47s/it]


Ablating MLP 7
Logit diff: 5.674546718597412	Acc: 1.0


100%|██████████| 5/5 [00:16<00:00,  3.39s/it]
100%|██████████| 5/5 [00:16<00:00,  3.40s/it]
100%|██████████| 5/5 [00:17<00:00,  3.40s/it]
100%|██████████| 5/5 [00:17<00:00,  3.40s/it]


Ablating Attn 21
Logit diff: 5.637172222137451	Acc: 1.0


100%|██████████| 5/5 [00:16<00:00,  3.34s/it]
100%|██████████| 5/5 [00:16<00:00,  3.34s/it]
100%|██████████| 5/5 [00:16<00:00,  3.35s/it]
100%|██████████| 5/5 [00:16<00:00,  3.35s/it]


Ablating MLP 8
Logit diff: 5.591460704803467	Acc: 1.0


100%|██████████| 5/5 [00:16<00:00,  3.27s/it]
100%|██████████| 5/5 [00:16<00:00,  3.27s/it]
100%|██████████| 5/5 [00:16<00:00,  3.27s/it]
100%|██████████| 5/5 [00:16<00:00,  3.27s/it]


Ablating Attn 22
Logit diff: 5.544773101806641	Acc: 1.0


100%|██████████| 5/5 [00:16<00:00,  3.21s/it]
100%|██████████| 5/5 [00:16<00:00,  3.21s/it]
100%|██████████| 5/5 [00:16<00:00,  3.22s/it]
100%|██████████| 5/5 [00:16<00:00,  3.22s/it]


Ablating MLP 10
Logit diff: 5.462672233581543	Acc: 1.0


100%|██████████| 5/5 [00:15<00:00,  3.13s/it]
100%|██████████| 5/5 [00:15<00:00,  3.13s/it]
100%|██████████| 5/5 [00:15<00:00,  3.13s/it]
100%|██████████| 5/5 [00:15<00:00,  3.13s/it]


Ablating MLP 19
Logit diff: 5.368851661682129	Acc: 1.0


100%|██████████| 5/5 [00:15<00:00,  3.05s/it]
100%|██████████| 5/5 [00:15<00:00,  3.05s/it]
100%|██████████| 5/5 [00:15<00:00,  3.05s/it]
100%|██████████| 5/5 [00:15<00:00,  3.05s/it]


Ablating MLP 27
Logit diff: 5.245546817779541	Acc: 1.0


100%|██████████| 5/5 [00:14<00:00,  2.97s/it]
100%|██████████| 5/5 [00:14<00:00,  2.98s/it]
100%|██████████| 5/5 [00:14<00:00,  2.98s/it]
100%|██████████| 5/5 [00:14<00:00,  2.98s/it]


Ablating Attn 28
Logit diff: 5.099874973297119	Acc: 1.0


100%|██████████| 5/5 [00:14<00:00,  2.92s/it]
100%|██████████| 5/5 [00:14<00:00,  2.92s/it]
100%|██████████| 5/5 [00:14<00:00,  2.93s/it]
100%|██████████| 5/5 [00:14<00:00,  2.93s/it]


Ablating Attn 17
Logit diff: 4.921875	Acc: 1.0


100%|██████████| 5/5 [00:14<00:00,  2.87s/it]
100%|██████████| 5/5 [00:14<00:00,  2.87s/it]
100%|██████████| 5/5 [00:14<00:00,  2.87s/it]
100%|██████████| 5/5 [00:14<00:00,  2.88s/it]


Ablating MLP 3
Logit diff: 4.7406792640686035	Acc: 1.0


100%|██████████| 5/5 [00:13<00:00,  2.79s/it]
100%|██████████| 5/5 [00:13<00:00,  2.79s/it]
100%|██████████| 5/5 [00:13<00:00,  2.79s/it]
100%|██████████| 5/5 [00:13<00:00,  2.79s/it]


Ablating MLP 23
Logit diff: 4.5418829917907715	Acc: 1.0


100%|██████████| 5/5 [00:13<00:00,  2.71s/it]
100%|██████████| 5/5 [00:13<00:00,  2.71s/it]
100%|██████████| 5/5 [00:13<00:00,  2.72s/it]
100%|██████████| 5/5 [00:13<00:00,  2.72s/it]


Ablating MLP 24
Logit diff: 4.358155727386475	Acc: 1.0


100%|██████████| 5/5 [00:13<00:00,  2.63s/it]
100%|██████████| 5/5 [00:13<00:00,  2.63s/it]
100%|██████████| 5/5 [00:13<00:00,  2.64s/it]
100%|██████████| 5/5 [00:13<00:00,  2.64s/it]


Ablating MLP 25
Logit diff: 4.141499996185303	Acc: 1.0


100%|██████████| 5/5 [00:12<00:00,  2.55s/it]
100%|██████████| 5/5 [00:12<00:00,  2.55s/it]
100%|██████████| 5/5 [00:12<00:00,  2.56s/it]
100%|██████████| 5/5 [00:12<00:00,  2.56s/it]


Ablating MLP 18
Logit diff: 3.910875082015991	Acc: 1.0


100%|██████████| 5/5 [00:12<00:00,  2.47s/it]
100%|██████████| 5/5 [00:12<00:00,  2.47s/it]
100%|██████████| 5/5 [00:12<00:00,  2.48s/it]
100%|██████████| 5/5 [00:12<00:00,  2.48s/it]


Ablating Attn 31
Logit diff: 3.6682655811309814	Acc: 1.0


100%|██████████| 5/5 [00:12<00:00,  2.42s/it]
100%|██████████| 5/5 [00:12<00:00,  2.42s/it]
100%|██████████| 5/5 [00:12<00:00,  2.43s/it]
100%|██████████| 5/5 [00:12<00:00,  2.43s/it]


Ablating MLP 20
Logit diff: 3.4286720752716064	Acc: 1.0


100%|██████████| 5/5 [00:11<00:00,  2.34s/it]
100%|██████████| 5/5 [00:11<00:00,  2.34s/it]
100%|██████████| 5/5 [00:11<00:00,  2.35s/it]
100%|██████████| 5/5 [00:11<00:00,  2.35s/it]


Ablating MLP 26
Logit diff: 3.174593687057495	Acc: 0.9799999594688416


100%|██████████| 5/5 [00:11<00:00,  2.26s/it]
100%|██████████| 5/5 [00:11<00:00,  2.26s/it]
100%|██████████| 5/5 [00:11<00:00,  2.27s/it]
100%|██████████| 5/5 [00:11<00:00,  2.27s/it]


Ablating Attn 2
Logit diff: 2.9119608402252197	Acc: 0.9799999594688416


100%|██████████| 5/5 [00:11<00:00,  2.21s/it]
100%|██████████| 5/5 [00:11<00:00,  2.21s/it]
100%|██████████| 5/5 [00:11<00:00,  2.22s/it]
100%|██████████| 5/5 [00:11<00:00,  2.22s/it]


Ablating MLP 21
Logit diff: 2.6503827571868896	Acc: 0.9799999594688416


100%|██████████| 5/5 [00:10<00:00,  2.13s/it]
100%|██████████| 5/5 [00:10<00:00,  2.13s/it]
100%|██████████| 5/5 [00:10<00:00,  2.14s/it]
100%|██████████| 5/5 [00:10<00:00,  2.14s/it]


Ablating Attn 14
Logit diff: 2.411156177520752	Acc: 0.9799999594688416


100%|██████████| 5/5 [00:10<00:00,  2.08s/it]
100%|██████████| 5/5 [00:10<00:00,  2.08s/it]
100%|██████████| 5/5 [00:10<00:00,  2.08s/it]
100%|██████████| 5/5 [00:10<00:00,  2.08s/it]


Ablating MLP 11
Logit diff: 2.1645076274871826	Acc: 0.9599999785423279


100%|██████████| 5/5 [00:10<00:00,  2.00s/it]
100%|██████████| 5/5 [00:10<00:00,  2.00s/it]
100%|██████████| 5/5 [00:10<00:00,  2.00s/it]
100%|██████████| 5/5 [00:10<00:00,  2.00s/it]


Ablating Attn 19
Logit diff: 1.9165703058242798	Acc: 0.949999988079071


100%|██████████| 5/5 [00:09<00:00,  1.95s/it]
100%|██████████| 5/5 [00:09<00:00,  1.95s/it]
100%|██████████| 5/5 [00:09<00:00,  1.95s/it]
100%|██████████| 5/5 [00:09<00:00,  1.95s/it]


Ablating MLP 16
Logit diff: 1.6519453525543213	Acc: 0.9399999976158142


100%|██████████| 5/5 [00:09<00:00,  1.86s/it]
100%|██████████| 5/5 [00:09<00:00,  1.87s/it]
100%|██████████| 5/5 [00:09<00:00,  1.87s/it]
100%|██████████| 5/5 [00:09<00:00,  1.87s/it]


Ablating MLP 5
Logit diff: 1.342453122138977	Acc: 0.8899999856948853


100%|██████████| 5/5 [00:08<00:00,  1.79s/it]
100%|██████████| 5/5 [00:08<00:00,  1.79s/it]
100%|██████████| 5/5 [00:08<00:00,  1.79s/it]
100%|██████████| 5/5 [00:08<00:00,  1.79s/it]


Ablating Attn 11
Logit diff: 1.0705703496932983	Acc: 0.8700000047683716


100%|██████████| 5/5 [00:08<00:00,  1.74s/it]
100%|██████████| 5/5 [00:08<00:00,  1.74s/it]
100%|██████████| 5/5 [00:08<00:00,  1.74s/it]
100%|██████████| 5/5 [00:08<00:00,  1.74s/it]


Ablating MLP 28
Logit diff: 0.8341405987739563	Acc: 0.7999999523162842


100%|██████████| 5/5 [00:08<00:00,  1.66s/it]
100%|██████████| 5/5 [00:08<00:00,  1.66s/it]
100%|██████████| 5/5 [00:08<00:00,  1.66s/it]
100%|██████████| 5/5 [00:08<00:00,  1.66s/it]


Ablating MLP 22
Logit diff: 0.6255000233650208	Acc: 0.8100000023841858


100%|██████████| 5/5 [00:07<00:00,  1.58s/it]
100%|██████████| 5/5 [00:07<00:00,  1.58s/it]
100%|██████████| 5/5 [00:07<00:00,  1.58s/it]
100%|██████████| 5/5 [00:07<00:00,  1.58s/it]


Ablating Attn 18
Logit diff: 0.3979218602180481	Acc: 0.75


100%|██████████| 5/5 [00:07<00:00,  1.53s/it]
100%|██████████| 5/5 [00:07<00:00,  1.53s/it]
100%|██████████| 5/5 [00:07<00:00,  1.53s/it]
100%|██████████| 5/5 [00:07<00:00,  1.53s/it]


Ablating MLP 12
Logit diff: 0.15204687416553497	Acc: 0.6100000143051147


100%|██████████| 5/5 [00:07<00:00,  1.44s/it]
100%|██████████| 5/5 [00:07<00:00,  1.44s/it]
100%|██████████| 5/5 [00:07<00:00,  1.45s/it]
100%|██████████| 5/5 [00:07<00:00,  1.45s/it]


Ablating Attn 1
Logit diff: -0.09195312112569809	Acc: 0.47999998927116394


100%|██████████| 5/5 [00:06<00:00,  1.39s/it]
100%|██████████| 5/5 [00:06<00:00,  1.39s/it]
100%|██████████| 5/5 [00:06<00:00,  1.40s/it]
100%|██████████| 5/5 [00:06<00:00,  1.40s/it]


Ablating MLP 9
Logit diff: -0.2491484135389328	Acc: 0.4399999976158142


100%|██████████| 5/5 [00:06<00:00,  1.31s/it]
100%|██████████| 5/5 [00:06<00:00,  1.31s/it]
100%|██████████| 5/5 [00:06<00:00,  1.32s/it]
100%|██████████| 5/5 [00:06<00:00,  1.32s/it]


Ablating MLP 13
Logit diff: -0.39092186093330383	Acc: 0.38999998569488525


100%|██████████| 5/5 [00:06<00:00,  1.24s/it]
100%|██████████| 5/5 [00:06<00:00,  1.23s/it]
100%|██████████| 5/5 [00:06<00:00,  1.24s/it]
100%|██████████| 5/5 [00:06<00:00,  1.24s/it]


Ablating MLP 14
Logit diff: -0.5140937566757202	Acc: 0.32999998331069946


100%|██████████| 5/5 [00:05<00:00,  1.15s/it]
100%|██████████| 5/5 [00:05<00:00,  1.15s/it]
100%|██████████| 5/5 [00:05<00:00,  1.16s/it]
100%|██████████| 5/5 [00:05<00:00,  1.16s/it]


Ablating MLP 31
Logit diff: -0.6148281097412109	Acc: 0.29999998211860657


100%|██████████| 5/5 [00:05<00:00,  1.08s/it]
100%|██████████| 5/5 [00:05<00:00,  1.08s/it]
100%|██████████| 5/5 [00:05<00:00,  1.08s/it]
100%|██████████| 5/5 [00:05<00:00,  1.08s/it]


Ablating MLP 6
Logit diff: -0.7121328115463257	Acc: 0.26999998092651367


100%|██████████| 5/5 [00:04<00:00,  1.00it/s]
100%|██████████| 5/5 [00:04<00:00,  1.00it/s]
100%|██████████| 5/5 [00:05<00:00,  1.00s/it]
100%|██████████| 5/5 [00:05<00:00,  1.00s/it]


Ablating MLP 15
Logit diff: -0.7813124656677246	Acc: 0.23999999463558197


100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
100%|██████████| 5/5 [00:04<00:00,  1.09it/s]
100%|██████████| 5/5 [00:04<00:00,  1.08it/s]
100%|██████████| 5/5 [00:04<00:00,  1.08it/s]


Ablating MLP 4
Logit diff: -0.8311640024185181	Acc: 0.2199999988079071


100%|██████████| 5/5 [00:04<00:00,  1.19it/s]
100%|██████████| 5/5 [00:04<00:00,  1.19it/s]
100%|██████████| 5/5 [00:04<00:00,  1.19it/s]
100%|██████████| 5/5 [00:04<00:00,  1.19it/s]


Ablating MLP 1
Logit diff: -0.888070285320282	Acc: 0.1599999964237213


100%|██████████| 5/5 [00:03<00:00,  1.32it/s]
100%|██████████| 5/5 [00:03<00:00,  1.32it/s]
100%|██████████| 5/5 [00:03<00:00,  1.31it/s]
100%|██████████| 5/5 [00:03<00:00,  1.31it/s]


Ablating Attn 20
Logit diff: -0.9275858998298645	Acc: 0.10999999940395355


100%|██████████| 5/5 [00:03<00:00,  1.41it/s]
100%|██████████| 5/5 [00:03<00:00,  1.41it/s]
100%|██████████| 5/5 [00:03<00:00,  1.40it/s]
100%|██████████| 5/5 [00:03<00:00,  1.41it/s]


Ablating MLP 0
Logit diff: -0.9519218802452087	Acc: 0.10999999940395355


100%|██████████| 5/5 [00:03<00:00,  1.59it/s]
100%|██████████| 5/5 [00:03<00:00,  1.59it/s]
100%|██████████| 5/5 [00:03<00:00,  1.58it/s]
100%|██████████| 5/5 [00:03<00:00,  1.58it/s]

Ablating Attn 16
Logit diff: -0.9725858569145203	Acc: 0.10999999940395355





In [20]:
cutoff = (logit_diffs_hf >= logit_diff_hf).nonzero()[-1].item() # last point where the performance is equal or higher than baseline
cutoff_acc = (accs_hf >= acc_hf).nonzero()[-1].item() # last point where the performance is equal or higher than baseline

In [21]:
labels = [f"{'MLP' if comp else 'Attn.'} {layer}" for comp, layer in reversed(sorted_components)]
fig = px.line(x = labels, y=logit_diffs_hf.cpu().numpy(), 
              error_y=std_logit_diffs_hf.cpu().numpy(),
              labels={"y": "Logit Diff.", "x": "Component"}, title="Logit Diff. vs. Ablated Component")
fig.add_hline(y=logit_diff_hf.item(), line_width=1.5, line_dash="dash", line_color="black")
fig.add_vline(x=labels[cutoff], line_width=1., line_dash="dash", line_color="red")
fig.show()

NameError: name 'sorted_components' is not defined

In [None]:
labels = [f"{'MLP' if comp else 'Attn.'} {layer}" for comp, layer in reversed(sorted_components)]
fig = px.line(x = labels, y=accs_hf.cpu().numpy(), 
              error_y=std_accs_hf.cpu().numpy(),
              labels={"y": "Accuracy", "x": "Component"}, title="Accuracy vs. Ablated Component")
fig.add_hline(y=acc_hf.item(), line_width=1.5, line_dash="dash", line_color="black")
fig.add_vline(x=labels[cutoff_acc], line_width=1., line_dash="dash", line_color="red")
fig.show()

In [None]:
p_parameters = [x/initial_parameters for x in num_parameters]

fig = px.line(x = p_parameters, y=logit_diffs_hf.cpu().numpy(), 
              error_y=std_logit_diffs_hf.cpu().numpy(),
              labels={"y": "Logit Diff.", "x": "% parameters"}, title="Logit Diff. vs. model size")
fig.add_vline(x=p_parameters[cutoff], line_width=1., line_dash="dash", line_color="red")
fig.add_hline(y=logit_diff_hf.item(), line_width=1.5, line_dash="dash", line_color="black")
fig.update_xaxes(autorange="reversed")
fig.show()

In [None]:
fig = px.line(x = p_parameters, y=accs_hf.cpu().numpy(), 
              error_y=std_accs_hf.cpu().numpy(),
              labels={"y": "Accuracy", "x": "% parameters"}, title="Accuracy vs. model size")
fig.add_vline(x=p_parameters[cutoff_acc], line_width=1., line_dash="dash", line_color="red")
fig.add_hline(y=acc_hf.item(), line_width=1.5, line_dash="dash", line_color="black")
fig.update_xaxes(autorange="reversed")
fig.show()