In [1]:
%load_ext autoreload
%autoreload 2
from transformer_lens import HookedTransformer, ActivationCache
import os
import torch
import numpy as np
import pandas as pd
import datasets
import transformers
import pickle

from tasks import PileTask, OWTTask, InductionTask, GreaterThanTask
from tasks.ioi.IOITask import IOITask, IOITask_NPO, IOITask_Uniform
from tasks.induction.InductionTask import InductionTask, InductionTask_NPO, InductionTask_Uniform
from tasks.facts.SportsTask import SportsTask, SportsTask_NPO, SportsTask_Uniform

from tqdm.auto import tqdm

from transformers import GPT2Tokenizer, GPTNeoXTokenizerFast, AutoModelForCausalLM, AutoTokenizer

from datasets import load_dataset
train_dataset = load_dataset('monology/pile-uncopyrighted', split='train', streaming=True)


  from .autonotebook import tqdm as notebook_tqdm


# Load Model

In [2]:
os.environ['HF_TOKEN'] = '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'
model_name = 'google/gemma-2b'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = HookedTransformer.from_pretrained(
    model_name,
    tokenizer=tokenizer,
    device='cuda',
    default_padding_side="right",
    fold_ln=False,
    fold_value_biases=False,
    center_writing_weights=False,
    dtype=torch.bfloat16
)


Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.29s/it]


Loaded pretrained model google/gemma-2b into HookedTransformer


# Load Datasets

In [3]:
from tasks import PileTask, OWTTask, InductionTask, GreaterThanTask
from tasks.ioi.IOITask import IOITask, IOITask_NPO, IOITask_Uniform
from tasks.induction.InductionTask import InductionTask, InductionTask_NPO, InductionTask_Uniform
from tasks.facts.SportsTask import SportsTask, SportsTask_NPO, SportsTask_Uniform
from tasks.facts.SportsTaskAdversarial import adversarial_sports_eval
from tasks.facts.SportsTaskSideEffects import run_side_effects_evals


train_batch_size = 10
eval_batch_size = 50

device = "cuda"
train_loss_type = "sports"
forget_sport = "basketball"
maintain_sport = None
# val_sport = "baseball"


sports_1mp = SportsTask(batch_size=train_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="log_1_minus_p", forget_sport_subset={forget_sport}, is_forget_dataset=True)

if maintain_sport is None:
    maintain_sports = SportsTask(batch_size=train_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="cross_entropy", forget_sport_subset={forget_sport}, is_forget_dataset=False)
else:
    maintain_sports = SportsTask(batch_size=train_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="cross_entropy", forget_sport_subset={maintain_sport}, is_forget_dataset=True)

train_pile = PileTask(batch_size=train_batch_size, tokenizer=tokenizer, device=device, ctx_length=100, shuffle=True, buffer_size=50000)
train_tasks = {"sports_1mp": (sports_1mp, .2), "maintain_sports": (maintain_sports, 1), "pile": (train_pile, 1)}

# want to eval on other sports
forget_sport_eval = SportsTask(batch_size=eval_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="cross_entropy", forget_sport_subset={forget_sport}, is_forget_dataset=True)
test_pile = PileTask(batch_size=eval_batch_size, tokenizer=tokenizer, device=device, ctx_length=100, shuffle=True, buffer_size=50000)

induction_eval = InductionTask(batch_size=eval_batch_size, tokenizer=tokenizer, prep_acdcpp=False, seq_len=15, device=device)
if maintain_sport is None:
    maintain_sports_eval = SportsTask(batch_size=eval_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="cross_entropy", forget_sport_subset={forget_sport}, is_forget_dataset=False)
    eval_tasks = {"induction": induction_eval, "pile": test_pile, "forget_sport": forget_sport_eval, "maintain_sport": maintain_sports_eval}
else:
    maintain_sport_eval = SportsTask(batch_size=eval_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="cross_entropy", forget_sport_subset={maintain_sport}, is_forget_dataset=True)
    val_sport_eval = SportsTask(batch_size=eval_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="cross_entropy", forget_sport_subset={val_sport}, is_forget_dataset=True)
    eval_tasks = {"induction": induction_eval, "pile": test_pile, "forget_sport": forget_sport_eval, "maintain_sport": maintain_sport_eval, "val_sport": val_sport_eval}


OpenAI API key not found, will not be able to run evaluations on Sports Trivia Task


Testing Code:

import random

def create_random_weight_mask_dicts(model):
    # Creates random weight masks for testing
    weight_mask_attn_dict = {}
    weight_mask_mlp_dict = {}

    for layer in range(model.cfg.n_layers):
        weight_mask_attn_dict[layer] = {}
        # Want bool of length n_head, randomly set to True
        weight_mask_attn_dict[layer]['W_Q'] = torch.rand(model.cfg.n_heads) < 0.8
        weight_mask_attn_dict[layer]['W_K'] = torch.rand(model.cfg.n_heads) < 0.8
        weight_mask_attn_dict[layer]['W_V'] = torch.rand(model.cfg.n_heads) < 0.8
        weight_mask_attn_dict[layer]['W_O'] = torch.rand(model.cfg.n_heads) < 0.8

        # Randomly set to true or false
        weight_mask_mlp_dict[layer] = random.randint(0, 1) == 1

    return weight_mask_attn_dict, weight_mask_mlp_dict


# Weight Masking Wrapper

In [4]:
from torch import nn

def make_partly_differentiable_mask(W, frozen_heads, device="cuda"):
    """
    W is Parameter of shape (n_heads, ...). 
    Returns baseline and frozen (both only 1d arrays of (n_heads,)), 
    and forward pass should be W_frozen.float() + W_baseline.float() * W 
    """
    W_frozen = torch.nn.Parameter(torch.zeros(W.shape[0], dtype=torch.bool), requires_grad=False).to(device)

    # unsqueeze to broadcast efficiently, until W_baseline has same shape as W
    while len(W_frozen.shape) < len(W.shape):
        W_frozen = W_frozen.unsqueeze(-1)
    
    W_frozen[frozen_heads] = True

    W_baseline = (~W_frozen).float()
    W_baseline = torch.nn.Parameter(W_baseline, requires_grad=True)
    # convert into float
    return W_frozen.float(), 0.5 * W_baseline.float()

class WeightMaskedTransformer(nn.Module):
    def __init__(self, tl_transformer, weight_mask_attn_dict=None, weight_mask_mlp_dict=None, torch_dtype=torch.bfloat16):
        """
        weight_mask_attn_dict: {layer: {"W_Q": unfrozen_heads, "W_K": unfrozen_heads, "W_V": unfrozen_heads, "W_O": unfrozen_heads}} (frozen_heads is shape (n_heads,) of bools). If none, train mask over all heads
        weight_mask_mlp_dict: {layer: bool}. If none, train mask over all mlps

        """
        super().__init__()
        self.torch_dtype = torch_dtype
        # tl_transformer should be a HookedTransformer
        self.tl_transformer = tl_transformer
        # turn off gradients for tl_transformer
        # for param in self.tl_transformer.parameters():
        #     param.requires_grad = False

        self.weight_mask_attn_dict = weight_mask_attn_dict
        self.weight_mask_mlp_dict = weight_mask_mlp_dict
        # store weight masks for every component that is unfrozen
        
        # need to store reference weights so that you can reset W_Q, etc after a forward pass
        self.reference_attn_weights = {}
        self.reference_mlp_weights = {}

        self.attention_masks = {}
        self.mlp_masks = {}
        for layer in range(tl_transformer.cfg.n_layers):
            self.attention_masks[layer] = {}
            self.reference_attn_weights[layer] = {}
            self.mlp_masks[layer] = {}
            self.reference_mlp_weights[layer] = {}
            # Attention heads
            for component, parameter in [("W_Q", tl_transformer.blocks[layer].attn.W_Q), ("W_K", tl_transformer.blocks[layer].attn.W_K), ("W_V", tl_transformer.blocks[layer].attn.W_V), ("W_O", tl_transformer.blocks[layer].attn.W_O)]:
                if self.weight_mask_attn_dict is None:
                    unfrozen_heads = list(range(tl_transformer.cfg.n_heads)) # all heads are unfrozen
                else:
                    unfrozen_heads = self.weight_mask_attn_dict[layer][component]
                # make frozen and baseline masks, and also a copy of the original weights

                if unfrozen_heads is not None and len(unfrozen_heads) > 0:
                    W_frozen, W_baseline = make_partly_differentiable_mask(parameter, unfrozen_heads)
                    weight_mask = nn.Parameter(torch.ones_like(parameter).type(torch_dtype), requires_grad=True)
                    
                    self.attention_masks[layer][component] = (W_frozen, W_baseline, weight_mask)
                    self.reference_attn_weights[layer][component] = parameter.clone()

            # MLPs

            for component, parameter in [("W_in", tl_transformer.blocks[layer].mlp.W_in), ("W_out", tl_transformer.blocks[layer].mlp.W_out)]:
                if self.weight_mask_mlp_dict is None or self.weight_mask_mlp_dict[layer][component]:
                    weight_mask = nn.Parameter(torch.ones_like(parameter).type(torch_dtype), requires_grad=True)

                    self.mlp_masks[layer][component] = weight_mask
                    self.reference_mlp_weights[layer][component] = parameter.clone()

                
    def forward(self, *args, **kwargs):
        for layer in range(self.tl_transformer.cfg.n_layers):
            for component, parameter in [("W_Q", self.tl_transformer.blocks[layer].attn.W_Q), ("W_K", self.tl_transformer.blocks[layer].attn.W_K), ("W_V", self.tl_transformer.blocks[layer].attn.W_V), ("W_O", self.tl_transformer.blocks[layer].attn.W_O)]:
                if self.weight_mask_attn_dict is None or component in self.attention_masks[layer]:
                    W_frozen, W_baseline, weight_mask = self.attention_masks[layer][component]
                    reference_data = self.reference_attn_weights[layer][component]
                    mask = W_frozen + W_baseline * weight_mask
                    self.tl_transformer.blocks[layer].attn.__dict__['_parameters'][component] = reference_data * mask

            for component, parameter in [("W_in", self.tl_transformer.blocks[layer].mlp.W_in), ("W_out", self.tl_transformer.blocks[layer].mlp.W_out)]:
                if self.weight_mask_mlp_dict is None or self.weight_mask_mlp_dict[layer][component]:
                    weight_mask = self.mlp_masks[layer][component]
                    reference_data = self.reference_mlp_weights[layer][component]
                    self.tl_transformer.blocks[layer].mlp.__dict__['_parameters'][component] = reference_data * weight_mask

        return self.tl_transformer(*args, **kwargs)

    def generate(self, *args, **kwargs):
        return self.tl_transformer.generate(*args, **kwargs)

    def regularization_loss(self):
        # Compute the L1 sparsity penalty using the masks
        loss = 0
        for layer in range(self.tl_transformer.cfg.n_layers):
            for component, parameter in [("W_Q", self.tl_transformer.blocks[layer].attn.W_Q), ("W_K", self.tl_transformer.blocks[layer].attn.W_K), ("W_V", self.tl_transformer.blocks[layer].attn.W_V), ("W_O", self.tl_transformer.blocks[layer].attn.W_O)]:
                if self.weight_mask_attn_dict is None or component in self.attention_masks[layer]:
                    W_frozen, W_baseline, weight_mask = self.attention_masks[layer][component]
                    mask = W_frozen + (W_baseline * weight_mask) # 1s for frozen, heads
                    # Add (weights away from 1) / (total weights * percent_masks_active)
                    # loss += torch.sum(torch.abs(mask - 1)) / (mask.numel() * (W_baseline.sum() / W_baseline.numel()) + 1e-5)
                    
                    # Loss: -4(x-0.5)^{2}+1
                    loss += (-4 * (mask - 0.5 ** 2) + 1).sum() / (mask.numel() * (W_baseline.sum() / W_baseline.numel()) + 1e-5)

            for component, parameter in [("W_in", self.tl_transformer.blocks[layer].mlp.W_in), ("W_out", self.tl_transformer.blocks[layer].mlp.W_out)]:
                if self.weight_mask_mlp_dict is None or self.weight_mask_mlp_dict[layer][component]:
                    weight_mask = self.mlp_masks[layer][component]
                    # loss += torch.sum(torch.abs(weight_mask - 1)) / weight_mask.numel()
                    # Loss: -4(x-0.5)^{2}+1
                    loss += (-4 * ((weight_mask - 0.5) ** 2) + 1).sum() / weight_mask.numel()
        return loss
    
    def on_step_end(self):
        # Clip all the masks

        for layer in range(self.tl_transformer.cfg.n_layers):
            for component, parameter in [("W_Q", self.tl_transformer.blocks[layer].attn.W_Q), ("W_K", self.tl_transformer.blocks[layer].attn.W_K), ("W_V", self.tl_transformer.blocks[layer].attn.W_V), ("W_O", self.tl_transformer.blocks[layer].attn.W_O)]:
                if self.weight_mask_attn_dict is None or component in self.attention_masks[layer]:
                    W_frozen, W_baseline, weight_mask = self.attention_masks[layer][component]
                    weight_mask.data = torch.clamp(weight_mask.data, 0, 1)

            for component, parameter in [("W_in", self.tl_transformer.blocks[layer].mlp.W_in), ("W_out", self.tl_transformer.blocks[layer].mlp.W_out)]:
                if self.weight_mask_mlp_dict is None or self.weight_mask_mlp_dict[layer][component]:
                    weight_mask = self.mlp_masks[layer][component]
                    weight_mask.data = torch.clamp(weight_mask.data, 0, 1)


# Train Weight Mask

## Load Localization

In [5]:
from collections import defaultdict

def get_mask_from_ap_graph(model, ap_graph, threshold):
    # Attention masks are of form:
    # {layer: {"W_Q": frozen_heads, "W_K": frozen_heads, "W_V": frozen_heads, "W_O": frozen_heads}}
    # TRUE for the heads we want to FREEZE, FALSE for heads we want to MASK over
    # MLP masks are of form:
    # {layer: bool}

    # Localizations are of form:
    # {alayer.head_{q,k,v,result}:int, mlayer_{in,out}: int}

    weight_mask_attn_dict = {}
    weight_mask_mlp_dict = {}

    for layer in range(model.cfg.n_layers):
        weight_mask_attn_dict[layer] = {}
        weight_mask_mlp_dict[layer] = {}

        if 'a0.0_q' in ap_graph:
            weight_mask_attn_dict[layer]['W_Q'] = torch.tensor(
                [
                    abs(ap_graph[f"a{layer}.{head}_q"]) < threshold 
                    for head in range(model.cfg.n_heads)
                ]
            )
        else:
            weight_mask_attn_dict[layer]['W_Q'] = None

        if 'a0.0_k' in ap_graph:
            weight_mask_attn_dict[layer]['W_K'] = torch.tensor(
                [
                    abs(ap_graph[f"a{layer}.{head}_k"]) < threshold 
                    for head in range(model.cfg.n_heads)
                ]
            )
        else:
            weight_mask_attn_dict[layer]['W_K'] = None
        
        if 'a0.0_v' in ap_graph:
            weight_mask_attn_dict[layer]['W_V'] = torch.tensor(
                [
                    abs(ap_graph[f"a{layer}.{head}_v"]) < threshold 
                    for head in range(model.cfg.n_heads)
                ]
            )
        else:
            weight_mask_attn_dict[layer]['W_V'] = None
        
        if 'a0.0_result' in ap_graph:
            weight_mask_attn_dict[layer]['W_O'] = torch.tensor(
                [
                    abs(ap_graph[f"a{layer}.{head}_result"]) < threshold 
                    for head in range(model.cfg.n_heads)
                ]
            )
        else:
            weight_mask_attn_dict[layer]['W_O'] = None
            
        if 'm0_in' in ap_graph:
            weight_mask_mlp_dict[layer]['W_in'] = abs(ap_graph[f"m{layer}_in"]) < threshold
        else:
            weight_mask_mlp_dict[layer]['W_in'] = None
        
        if 'm0_out' in ap_graph:
            weight_mask_mlp_dict[layer]['W_out'] = abs(ap_graph[f"m{layer}_out"]) < threshold
        else:
            weight_mask_mlp_dict[layer]['W_out'] = None

    return weight_mask_attn_dict, weight_mask_mlp_dict


In [6]:
import pickle
with open("models/google_gemma-2b_sports_baseball_ap_graph.pkl", "rb") as f:
    ap_graph = pickle.load(f)

weight_mask_attn_dict, weight_mask_mlp_dict = get_mask_from_ap_graph(model, ap_graph, 0.05)


mask = WeightMaskedTransformer(
    model, 
    weight_mask_attn_dict=weight_mask_attn_dict, 
    weight_mask_mlp_dict=weight_mask_mlp_dict
)
# for n, param in mask.tl_transformer.named_parameters():
#     param.requires_grad = False
sports_train = SportsTask(batch_size=8, tokenizer=tokenizer)
with torch.autocast(device_type="cuda"):
    loss = sports_train.get_train_loss(mask, 1)
    print(loss)
    loss.backward()

print(mask.attention_masks[3]['W_Q'][-1].grad[-2])


In [7]:
import gc
import wandb
gc.collect()
torch.cuda.empty_cache()

mask = WeightMaskedTransformer(
    model, 
    weight_mask_attn_dict=weight_mask_attn_dict, 
    weight_mask_mlp_dict=weight_mask_mlp_dict
)
for param in mask.tl_transformer.parameters():
    param.requires_grad = False

model_type = 'gemma'
learning_rate = 1e-2
n_epochs = 50
grad_accum_steps = 5
# max_gpu_batch_size=8
alpha = 0.2
beta = 5
clip_grad = 1

evaluate_every = 5
n_eval_iters = 5
do_adversarial_evals = True 
do_side_effects_evals = True 


wandb.init(
    # set the wandb project where this run will be logged
    project="mech-unlearning",
    name=f"{model_name.split('/')[-1]}-{forget_sport}",

    # track hyperparameters and run metadata
    config={
        "model_type": model_type,
        "model_name": model_name,
        "forget_sport": forget_sport,
        "learning_rate": learning_rate,
        "n_epochs": n_epochs,
        "grad_accum_steps": grad_accum_steps,
        "alpha": alpha,
        "beta": beta,
        "clip_grad": clip_grad,
        "evaluate_every": evaluate_every,
        "n_eval_iters": n_eval_iters,
        "do_adversarial_evals": do_adversarial_evals,
        "do_side_effects_evals": do_side_effects_evals,
        "train_task_weights": {k:v[1] for k, v in train_tasks.items()}
    }
)

from collections import defaultdict
all_train_losses = defaultdict(list)
all_test_losses = defaultdict(list)
adversarial_evals = []
side_effect_evals = []

# Initialize optimizer
mask = mask.cuda()
mask_params = [
    v[-1]
    for layer, layer_mask_weights in mask.attention_masks.items()
    for k, v in layer_mask_weights.items()
] + \
[
    v
    for layer, layer_mask_weights in mask.mlp_masks.items()
    for k, v in layer_mask_weights.items()
]
optimizer = torch.optim.AdamW(mask_params, lr=learning_rate)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epochs)
# Cycle dataloaders
# Train a sparse mask
pbar = tqdm(range(n_epochs))
for epoch in pbar:
    # Sample batches
    # Reset grad
    optimizer.zero_grad()

    with torch.autocast(device_type="cuda"):
        # Compute normal loss over retain
        for task_name, (task, task_weight) in train_tasks.items():
            task_loss = 0
            # print(task_name)
            for i in range(grad_accum_steps):
                loss = task.get_train_loss(mask) / grad_accum_steps
                task_loss += loss.item()
                loss *= task_weight
                loss.backward()

                gc.collect()
                torch.cuda.empty_cache()
            all_train_losses[task_name].append(task_loss)

            gc.collect()
            torch.cuda.empty_cache()
            
        gc.collect()
        torch.cuda.empty_cache()
        # Add sparsity loss and backprop
        loss = beta * mask.regularization_loss()
        loss.backward()
        all_train_losses["reg"].append(loss.item())
        # Step and log
        if clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(mask.parameters(), clip_grad)
        # zero_nan_grads(mask)
        optimizer.step()
        mask.on_step_end()
        scheduler.step()

        print(mask.attention_masks[3]['W_Q'][-1].grad[4])
        print(mask.attention_masks[3]['W_Q'][-1])
        print((mask.attention_masks[3]['W_Q'][-1] - 1).sum())

        if epoch % evaluate_every == 0 or epoch == n_epochs - 1:
            for task_name, task in eval_tasks.items():
                task_loss = 0
                for i in range(n_eval_iters):
                    task_loss += task.get_test_loss(mask).item()
                all_test_losses[task_name].append(task_loss / n_eval_iters)
            if do_adversarial_evals:
                print("Running adversarial evals")
                adversarial_evals.append(adversarial_sports_eval(mask, model_type=model_type, batch_size=eval_batch_size, use_system_prompt=True))
            if do_side_effects_evals:
                print("Running side effects evals")
                side_effect_evals.append(run_side_effects_evals(mask, model_type=model_type, batch_size=eval_batch_size, evals_to_run=["Sports Answers"]))
        
        # log_dict = {}
        # for k, v in all_train_losses.items():
        #     log_dict[f"train_loss_{k}"] = v[-1]
        # for k, v in all_test_losses.items():
        #     log_dict[f"test_loss_{k}"] = v[-1]
        # for k, v in adversarial_evals[-1].items():
        #     log_dict[f"adversarial_{k}"] = v
        # for k, v in side_effect_evals[-1].items():
        #     log_dict[f"side_effects_{k}"] = v
        # wandb.log(log_dict)
    
wandb.finish()
        
        

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33maaquib111[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/50 [00:01<?, ?it/s]


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:
mask.weight_mask_attn_dict


{0: {'W_Q': tensor([ True, False,  True,  True,  True,  True,  True,  True]),
  'W_K': None,
  'W_V': None,
  'W_O': tensor([False,  True,  True,  True,  True,  True,  True,  True])},
 1: {'W_Q': tensor([True, True, True, True, True, True, True, True]),
  'W_K': None,
  'W_V': None,
  'W_O': tensor([ True,  True, False,  True,  True,  True,  True,  True])},
 2: {'W_Q': tensor([True, True, True, True, True, True, True, True]),
  'W_K': None,
  'W_V': None,
  'W_O': tensor([ True, False,  True,  True,  True,  True,  True,  True])},
 3: {'W_Q': tensor([ True,  True,  True,  True,  True,  True, False,  True]),
  'W_K': None,
  'W_V': None,
  'W_O': tensor([ True,  True,  True,  True,  True,  True, False,  True])},
 4: {'W_Q': tensor([True, True, True, True, True, True, True, True]),
  'W_K': None,
  'W_V': None,
  'W_O': tensor([True, True, True, True, True, True, True, True])},
 5: {'W_Q': tensor([True, True, True, True, True, True, True, True]),
  'W_K': None,
  'W_V': None,
  'W_O': ten

In [None]:
all_values = torch.cat(
    [
        mask.flatten() 
        for m in mask.masks.values()

        if component in mask.attention_masks[layer]
        for component in ["W_Q", "W_K", "W_V", "W_O"]
        for layer in range(mask.tl_transformer.cfg.n_layers)
    ], 
    dim=0
).cpu()
sorted_values = all_values.sort().values
plt.semilogx(sorted_values)
plt.title(f"{title} Neuron Mask Values")
plt.ylabel("Mask Value")
plt.show()


In [None]:
import matplotlib.pyplot as plt
# Create a histogram of mask values if W_baseline is 1

hist = []
for layer in range(mask.tl_transformer.cfg.n_layers):
    for component in ["W_Q", "W_K", "W_V", "W_O"]:
        if component in mask.attention_masks[layer]:
            frozen, baseline, mask_values = mask.attention_masks[layer][component]
            for i in range(baseline.shape[0]):
                if baseline[i] == 1:
                    hist.append(mask_values[i].flatten())

hist = torch.cat(hist, dim=0).cpu()

sorted_values = hist.sort().values
plt.semilogx(sorted_values)
plt.title(f"Mask Values")
plt.ylabel("Mask Value")
plt.show()


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fbca8112d10>>
Traceback (most recent call last):
  File "/root/venv/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


In [None]:
percent_ones


tensor([[1.]], device='cuda:0')

In [None]:
print(mask.attention_masks[3]['W_Q'][-1].grad[-2])
print(mask.attention_masks[3]['W_Q'][-1][-2])
print((mask.attention_masks[3]['W_Q'][-1] - 1).sum())


tensor([[-6.2212e-07,  7.0035e-07, -7.9349e-07,  ...,  7.8082e-06,
          7.7859e-07,  4.3958e-07],
        [-1.7658e-06,  2.6822e-07, -9.7603e-07,  ..., -7.6890e-06,
          4.7684e-06,  6.2287e-06],
        [-6.3330e-08, -1.3690e-07,  1.2368e-06,  ...,  1.2740e-06,
          1.2890e-06,  9.3877e-07],
        ...,
        [ 2.3656e-07,  6.4820e-07, -1.7462e-09,  ..., -4.2282e-07,
          1.7136e-06,  2.8014e-06],
        [ 4.2617e-06, -3.8650e-08, -3.3295e-08,  ..., -1.0207e-06,
          7.1526e-07,  2.7716e-06],
        [ 2.6524e-06, -1.3560e-06, -3.5912e-06,  ...,  2.7120e-06,
         -9.4622e-07,  8.1658e-06]], device='cuda:0', dtype=torch.bfloat16)
tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SelectBackward0>)
te

# Evals

In [None]:
# Final evals
final_adversarial_eval = adversarial_sports_eval(model, model_type=model_type, batch_size=eval_batch_size, use_system_prompt=True)
print(f"System Prompt: adversarial evals are {final_adversarial_eval}")
final_adversarial_eval = adversarial_sports_eval(model, model_type=model_type, batch_size=eval_batch_size, use_system_prompt=False)
print(f"No System Prompt: adversarial evals are {final_adversarial_eval}")

final_side_effects = run_side_effects_evals(model, model_type=model_type, batch_size=eval_batch_size, evals_to_run=["Sports Answers", "Sports Familiarity", "Cross Entropy"], verbose=True)
print(final_side_effects)


KeyboardInterrupt: 

In [None]:
weight_mask_mlp_dict[10]


{'W_in': False, 'W_out': True}

In [None]:
(mask.mlp_masks[5]['W_out'] - 1).sum()


tensor(0., device='cuda:0', dtype=torch.bfloat16, grad_fn=<SumBackward0>)

# Save Mask

In [None]:
# save masks state dict to neuron_cb
torch.save(mask.state_dict(), "masks/neuron_cb/mlps_unlearn_basketball.pt")
