In [1]:
%load_ext autoreload
%autoreload 2
from transformer_lens import HookedTransformer, ActivationCache
import torch
import numpy as np
import pandas as pd
import datasets
import transformers
import pickle

from tasks import PileTask, OWTTask, InductionTask, GreaterThanTask
from tasks.ioi.IOITask import IOITask, IOITask_NPO, IOITask_Uniform
from tasks.induction.InductionTask import InductionTask, InductionTask_NPO, InductionTask_Uniform
from tasks.facts.SportsTask import SportsTask, SportsTask_NPO, SportsTask_Uniform

from tqdm.auto import tqdm

In [2]:
from transformers import GPT2Tokenizer, GPTNeoXTokenizerFast, AutoModelForCausalLM, AutoTokenizer
model_type = "gemma-7b"
num_heads = 16
if model_type == "pythia":
    reference_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-2.8B")#.cuda()
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8B")
    tokenizer.pad_token_id = tokenizer.eos_token_id

elif model_type == "gemma-7b":
    reference_model = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.bfloat16)#.cuda()
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"
    n_heads = 16
    n_layers = 28



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
import pickle
with open("models/google_gemma-7b_sports_all_ap_graph.pkl", "rb") as f:
    ap_graph = pickle.load(f)
print(ap_graph.keys())

from collections import defaultdict

def convert_attrs_to_components(attrs, combine_heads=False, n_layers=n_layers, n_heads=n_heads):
    """
    attrs is dictionary of e.g. {'a0.0_q': float, 'm27_in': float}

    If combine_heads, then it will combine all 'a0.0_q', 'a0.1_q', ..., 'a0.15_q', etc into one component.
    """

    component_dict = defaultdict(int)
    attn_head_dict = defaultdict(dict)
    for layer in range(n_layers):
        for attn_type, component_name in [("q", f"blocks.{layer}.attn.hook_q"), ("k", f"blocks.{layer}.attn.hook_k"), ("v", f"blocks.{layer}.attn.hook_v"), ("result", f"blocks.{layer}.attn.hook_result")]:
            for head in range(n_heads):    
                if combine_heads:
                    component_dict[component_name] += attrs[f"a{layer}.{head}_{attn_type}"]
                else:
                    attn_head_dict[component_name][head] = attrs[f"a{layer}.{head}_{attn_type}"]
        for mlp_type, component_name in [("in", f"blocks.{layer}.mlp.hook_pre"), ("out", f"blocks.{layer}.mlp.hook_post")]:
            component_dict[component_name] += attrs[f"m{layer}_{mlp_type}"]
    if combine_heads:
        return (component_dict,)
    return (component_dict, attn_head_dict,)


def get_top_components(component_dict, attn_head_dict=None, threshold=None, top_p=None, top_k=None, use_abs=True, n_layers=n_layers, n_heads=n_heads):
    """
    component_dict is a dictionary of components to their importance values. If attn_head_dict is not None, then component_dict and attn_head_dict should not overlap in values.

    Can either use a threshold, top_p, or top_k to determine the top components to return (can only specify one). top_p should be a value ranging from 0 to 100. If use_abs is True, then it will take the absolute value of the importance values. 
    """
    if attn_head_dict is not None:
        assert (component_dict.keys() & attn_head_dict.keys()) == set(), "Overlapping keys between component_dict and attn_head_dict"
    
    # assert only one of threshold, top_p, top_k is specified
    assert sum([threshold is not None, top_p is not None, top_k is not None]) == 1, "Can only specify one of threshold, top_p, top_k"
    # will calculate a threshold for top_p or top_k

    if top_p is not None:
        all_attr_values = list(component_dict.values())
        if attn_head_dict is not None:
            all_attr_values += [val for head_dict in attn_head_dict.values() for val in head_dict.values()]

        all_attr_values = np.array(all_attr_values)
        if use_abs:
            all_attr_values = np.abs(all_attr_values)
        print(f"{len(all_attr_values)=}")
        threshold = np.percentile(all_attr_values, 100 - top_p)
    elif top_k is not None:
        all_attr_values = list(component_dict.values())
        if attn_head_dict is not None:
            all_attr_values += [val for head_dict in attn_head_dict.values() for val in head_dict.values()]

        all_attr_values = np.array(all_attr_values)
        if use_abs:
            all_attr_values = np.abs(all_attr_values)
        threshold = np.sort(all_attr_values)[-top_k]
    
    print(f"Thresholding importance at {threshold}")
    final_components = []
    final_attn_heads = defaultdict(list)

    for component, importance in component_dict.items():
        if use_abs:
            importance = abs(importance)
        if importance >= threshold:
            final_components.append(component)    

    if attn_head_dict is not None:
        for component, head_dict in attn_head_dict.items():
            head_list = []
            for head, importance in head_dict.items():
                if use_abs:
                    importance = abs(importance)
                if importance >= threshold:
                    head_list.append(head)
            if len(head_list) > 0:
                final_attn_heads[component] = head_list
                final_components.append(component)
    else:
        for component in final_components:
            if "attn" in component:
                # want to mask over all possible heads
                final_attn_heads[component] = list(range(n_heads))
    
    return final_components, final_attn_heads
print(get_top_components(*convert_attrs_to_components(ap_graph, combine_heads=True), top_p=5))
print(get_top_components(*convert_attrs_to_components(ap_graph, combine_heads=False), top_k=20))

dict_keys(['a0.0_q', 'a0.1_q', 'a0.2_q', 'a0.3_q', 'a0.4_q', 'a0.5_q', 'a0.6_q', 'a0.7_q', 'a0.8_q', 'a0.9_q', 'a0.10_q', 'a0.11_q', 'a0.12_q', 'a0.13_q', 'a0.14_q', 'a0.15_q', 'a0.0_k', 'a0.1_k', 'a0.2_k', 'a0.3_k', 'a0.4_k', 'a0.5_k', 'a0.6_k', 'a0.7_k', 'a0.8_k', 'a0.9_k', 'a0.10_k', 'a0.11_k', 'a0.12_k', 'a0.13_k', 'a0.14_k', 'a0.15_k', 'a0.0_v', 'a0.1_v', 'a0.2_v', 'a0.3_v', 'a0.4_v', 'a0.5_v', 'a0.6_v', 'a0.7_v', 'a0.8_v', 'a0.9_v', 'a0.10_v', 'a0.11_v', 'a0.12_v', 'a0.13_v', 'a0.14_v', 'a0.15_v', 'a0.0_result', 'a0.1_result', 'a0.2_result', 'a0.3_result', 'a0.4_result', 'a0.5_result', 'a0.6_result', 'a0.7_result', 'a0.8_result', 'a0.9_result', 'a0.10_result', 'a0.11_result', 'a0.12_result', 'a0.13_result', 'a0.14_result', 'a0.15_result', 'a1.0_q', 'a1.1_q', 'a1.2_q', 'a1.3_q', 'a1.4_q', 'a1.5_q', 'a1.6_q', 'a1.7_q', 'a1.8_q', 'a1.9_q', 'a1.10_q', 'a1.11_q', 'a1.12_q', 'a1.13_q', 'a1.14_q', 'a1.15_q', 'a1.0_k', 'a1.1_k', 'a1.2_k', 'a1.3_k', 'a1.4_k', 'a1.5_k', 'a1.6_k', 'a1.7_k',

In [4]:
top_p = 5
combine_heads = True
components, _ = get_top_components(*convert_attrs_to_components(ap_graph, combine_heads=combine_heads), top_p=top_p)
use_localized = False

# for every element in components, which looks like ['blocks.18.attn.hook_result', 'blocks.21.mlp.hook_pre', 'blocks.21.mlp.hook_post', 'blocks.25.attn.hook_q', 'blocks.25.attn.hook_k', 'blocks.25.attn.hook_v']:
# convert into equivalent parameter and set requires_grad to true
def get_parameter(hf_model, component_name):
    layer_str, component_type, hook_type = component_name.split(".")[1:]
    layer = int(layer_str)

    param = None
    if component_type == "attn":
        if hook_type == "hook_q":
            param = hf_model.model.layers[layer].self_attn.q_proj.weight
        elif hook_type == "hook_k":
            param = hf_model.model.layers[layer].self_attn.k_proj.weight
        elif hook_type == "hook_v":
            param = hf_model.model.layers[layer].self_attn.v_proj.weight
        elif hook_type == "hook_result":
            # for now ignore, not sure if result maps to o_proj
            print(f"Ignoring {component_name}")
            # param = hf_model.model.layers[layer].self_attn.o_proj.weight
        else:
            print(f"Unknown component type {component_type}")
    elif component_type == "mlp":
        if hook_type == "hook_pre":
            param = hf_model.model.layers[layer].mlp.up_proj.weight
        elif hook_type == "hook_post":
            param = hf_model.model.layers[layer].mlp.down_proj.weight
        else:
            print(f"Unknown component type {component_type}")

    else:
        print(f"Unknown component type {component_type}")
    
    return param
    
def apply_localized_gradients(hf_model, components):
    # set everything else False
    for parameter in hf_model.parameters():
        parameter.requires_grad = False
    
    for component in components:
        param = get_parameter(hf_model, component)
        if param is None:
            print(f"Could not find parameter for {component}")
            continue
        param.requires_grad = True
        print(f"Setting {component} to True")


if use_localized:
    apply_localized_gradients(reference_model, components)
else:
    all_components, _ = get_top_components(*convert_attrs_to_components(ap_graph, combine_heads=combine_heads), top_p=100)
    apply_localized_gradients(reference_model, all_components)
    # fine tune all 


len(all_attr_values)=168
Thresholding importance at 0.20463522672653203
len(all_attr_values)=168
Thresholding importance at 9.1552734375e-05
Setting blocks.0.attn.hook_q to True
Setting blocks.0.attn.hook_k to True
Setting blocks.0.attn.hook_v to True
Ignoring blocks.0.attn.hook_result
Could not find parameter for blocks.0.attn.hook_result
Setting blocks.0.mlp.hook_pre to True
Setting blocks.0.mlp.hook_post to True
Setting blocks.1.attn.hook_q to True
Setting blocks.1.attn.hook_k to True
Setting blocks.1.attn.hook_v to True
Ignoring blocks.1.attn.hook_result
Could not find parameter for blocks.1.attn.hook_result
Setting blocks.1.mlp.hook_pre to True
Setting blocks.1.mlp.hook_post to True
Setting blocks.2.attn.hook_q to True
Setting blocks.2.attn.hook_k to True
Setting blocks.2.attn.hook_v to True
Ignoring blocks.2.attn.hook_result
Could not find parameter for blocks.2.attn.hook_result
Setting blocks.2.mlp.hook_pre to True
Setting blocks.2.mlp.hook_post to True
Setting blocks.3.attn.hoo

In [5]:
# def apply_localized_gradients(hf_model, attn_dict, mlp_dict, model_type="gemma"):
#     # attn_dict is {layer: {"W_Q": [set of unlearn_heads], "W_K": [set of unlearn_heads], "W_V": [set of unlearn_heads], "W_V": [set of unlearn_heads]} for every layer}
#     # mlp_dict is {layer: boolean} for if you want to unlearn on this layer

#     # set everything else False
#     for parameter in hf_model.parameters():
#         parameter.requires_grad = False


#     for layer in range(hf_model.config.num_hidden_layers):
#         if model_type == "gemma":
#             # set attn.W_Q layers requires_grad to True if W_Q unlearn heads is not empty, same for W_K, W_V, W_O

#             for attn_component_name, parameter in [("W_Q", hf_model.model.layers[layer].self_attn.q_proj.weight), ("W_K", hf_model.model.layers[layer].self_attn.k_proj.weight), ("W_V", hf_model.model.layers[layer].self_attn.v_proj.weight), ("W_O", hf_model.model.layers[layer].self_attn.o_proj.weight)]:
#                 if attn_dict is None or (layer in attn_dict and len(attn_dict[layer][attn_component_name]) > 0):
#                     parameter.requires_grad = True
#                 else:
#                     parameter.requires_grad = False

#             if mlp_dict is None or (layer in mlp_dict and mlp_dict[layer]):
#                 hf_model.model.layers[layer].mlp.up_proj.weight.requires_grad = True
#                 hf_model.model.layers[layer].mlp.down_proj.weight.requires_grad = True

# import pickle
# with open("models/google_gemma-7b_sports_baseball_ap_graph.pkl", "rb") as f:
#     ap_graph = pickle.load(f)
# # add up attributions across attentions
# aggregated_attributions = {}
# for layer in range(reference_model.config.num_hidden_layers):
#     component_name = f'a{layer}'
#     aggregated_attributions[component_name] = 0
#     for head in range(num_heads):
#         for head_type in ["q", "k", "v"]:
#             head_name = f"{component_name}.{head}_{head_type}"
#             aggregated_attributions[component_name] += ap_graph[head_name]
#         # head_name = f"{component_name}.{head}"
#         # aggregated_attributions[component_name] += ap_graph[head_name]
#     aggregated_attributions[f'm{layer}'] = 0
#     for mlp_type in ["in", "out"]:
#         mlp_name = f'm{layer}_{mlp_type}'
#         aggregated_attributions[f"m{layer}"] += ap_graph[mlp_name]

# print(aggregated_attributions)

# num_components=20
# top_components = {}
# # take the top 20 components from aggregated_attributions (20 highest absolute values)
# for i in range(num_components):
#     max_key = max(aggregated_attributions, key=lambda x: abs(aggregated_attributions[x]))
#     top_components[max_key] = aggregated_attributions[max_key]
#     del aggregated_attributions[max_key]

# def get_dicts_from_nodes(nodes_set):
#     # get attn_dict and mlp_dict
#     attn_dict = {}
#     mlp_dict = {}
#     for node in nodes_set:
#         if node[0] == "a":
#             layer = int(node[1:])
#             attn_dict[layer] = {"W_Q": list(range(num_heads)), "W_K": list(range(num_heads)), "W_V": list(range(num_heads)), "W_O": list(range(num_heads))}
#         elif node[0] == "m":
#             layer = int(node[1:])
#             mlp_dict[layer] = True
#     return attn_dict, mlp_dict

# attn_dict, mlp_dict = get_dicts_from_nodes(top_components.keys())
# attn_dict, mlp_dict
# apply_localized_gradients(reference_model, attn_dict, mlp_dict)

In [6]:
for name, parameter in reference_model.named_parameters():
    print(name, parameter.requires_grad)

model.embed_tokens.weight False
model.layers.0.self_attn.q_proj.weight True
model.layers.0.self_attn.k_proj.weight True
model.layers.0.self_attn.v_proj.weight True
model.layers.0.self_attn.o_proj.weight False
model.layers.0.mlp.gate_proj.weight False
model.layers.0.mlp.up_proj.weight True
model.layers.0.mlp.down_proj.weight True
model.layers.0.input_layernorm.weight False
model.layers.0.post_attention_layernorm.weight False
model.layers.1.self_attn.q_proj.weight True
model.layers.1.self_attn.k_proj.weight True
model.layers.1.self_attn.v_proj.weight True
model.layers.1.self_attn.o_proj.weight False
model.layers.1.mlp.gate_proj.weight False
model.layers.1.mlp.up_proj.weight True
model.layers.1.mlp.down_proj.weight True
model.layers.1.input_layernorm.weight False
model.layers.1.post_attention_layernorm.weight False
model.layers.2.self_attn.q_proj.weight True
model.layers.2.self_attn.k_proj.weight True
model.layers.2.self_attn.v_proj.weight True
model.layers.2.self_attn.o_proj.weight False

In [7]:
reference_model.cuda()
sports_test = SportsTask(batch_size=32, tokenizer=tokenizer)
sports_test.get_test_accuracy(reference_model)

# for layer in range(tl_model.cfg.n_layers):
#     tl_model.blocks[layer].attn.W_Q.data = torch.zeros_like(tl_model.blocks[layer].attn.W_Q)
#     tl_model.blocks[layer].attn.W_K.data = torch.zeros_like(tl_model.blocks[layer].attn.W_K)
#     tl_model.blocks[layer].attn.W_V.data = torch.zeros_like(tl_model.blocks[layer].attn.W_V)
#     tl_model.blocks[layer].attn.W_O.data = torch.zeros_like(tl_model.blocks[layer].attn.W_O)

# sports_test.get_test_loss(tl_model)

0.99609375

## Train Model

In [8]:
from tasks import PileTask, OWTTask, InductionTask, GreaterThanTask
from tasks.ioi.IOITask import IOITask, IOITask_NPO, IOITask_Uniform
from tasks.induction.InductionTask import InductionTask, InductionTask_NPO, InductionTask_Uniform
from tasks.facts.SportsTask import SportsTask, SportsTask_NPO, SportsTask_Uniform
from tasks.facts.SportsTaskAdversarial import adversarial_sports_eval
from tasks.facts.SportsTaskSideEffects import run_side_effects_evals


train_batch_size = 8
eval_batch_size=32

device = "cuda"
train_loss_type = "sports"
forget_sport = "basketball"
maintain_sport = None
# val_sport = "baseball"


sports_1mp = SportsTask(batch_size=train_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="log_1_minus_p", forget_sport_subset={forget_sport}, is_forget_dataset=True)

if maintain_sport is None:
    maintain_sports = SportsTask(batch_size=train_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="cross_entropy", forget_sport_subset={forget_sport}, is_forget_dataset=False)
else:
    maintain_sports = SportsTask(batch_size=train_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="cross_entropy", forget_sport_subset={maintain_sport}, is_forget_dataset=True)

train_pile = PileTask(batch_size=train_batch_size, tokenizer=tokenizer, device=device, ctx_length=256, shuffle=True, buffer_size=50000)
train_tasks = {"sports_1mp": (sports_1mp, .2), "maintain_sports": (maintain_sports, 1), "pile": (train_pile, 1)}

# want to eval on other sports
forget_sport_eval = SportsTask(batch_size=eval_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="cross_entropy", forget_sport_subset={forget_sport}, is_forget_dataset=True)
test_pile = PileTask(batch_size=eval_batch_size, tokenizer=tokenizer, device=device, ctx_length=100, shuffle=True, buffer_size=50000)

induction_eval = InductionTask(batch_size=eval_batch_size, tokenizer=tokenizer, prep_acdcpp=False, seq_len=15, device=device)
if maintain_sport is None:
    maintain_sports_eval = SportsTask(batch_size=eval_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="cross_entropy", forget_sport_subset={forget_sport}, is_forget_dataset=False)
    eval_tasks = {"induction": induction_eval, "pile": test_pile, "forget_sport": forget_sport_eval, "maintain_sport": maintain_sports_eval}
else:
    maintain_sport_eval = SportsTask(batch_size=eval_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="cross_entropy", forget_sport_subset={maintain_sport}, is_forget_dataset=True)
    val_sport_eval = SportsTask(batch_size=eval_batch_size, tokenizer=tokenizer, device=device, prep_acdcpp=False, criterion="cross_entropy", forget_sport_subset={val_sport}, is_forget_dataset=True)
    eval_tasks = {"induction": induction_eval, "pile": test_pile, "forget_sport": 
                  forget_sport_eval, "maintain_sport": maintain_sport_eval, "val_sport": val_sport_eval}

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

No test dataset available. Using train dataset for testing.


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

No test dataset available. Using train dataset for testing.


In [9]:
# mask = MLPHiddenMask(model).cuda()

learning_rate = 2e-5
n_epochs = 50
grad_accum_steps = 16

clip_grad = 1

evaluate_every = 5
n_eval_iters = 5
deep_evaluate_every = 25
do_adversarial_evals = True
do_side_effects_evals = True

from collections import defaultdict
all_train_losses = defaultdict(list)
all_test_losses = defaultdict(list)
adversarial_evals = []
side_effect_evals = []

# Initialize optimizer

optimizer = torch.optim.AdamW(reference_model.parameters(), lr=learning_rate, weight_decay=0)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epochs)
# Cycle dataloaders
# Train a sparse mask
pbar = tqdm(range(n_epochs))
for epoch in pbar:
    # Sample batches
    # Reset grad
    optimizer.zero_grad()
    # Compute normal loss over retain
    for task_name, (task, task_weight) in train_tasks.items():
        task_loss = 0
        for i in range(grad_accum_steps):
            loss = task.get_train_loss(reference_model) / grad_accum_steps
            task_loss += loss.item()
            loss *= task_weight
            loss.backward()
        all_train_losses[task_name].append(task_loss)
        
    # # Add sparsity loss and backprop
    # loss = beta * mask.regularization_loss()
    # loss.backward()
    # all_train_losses["reg"].append(loss.item())
    # Step and log
    if clip_grad is not None:
        torch.nn.utils.clip_grad_norm_(reference_model.parameters(), clip_grad)
    # zero_nan_grads(mask)
    optimizer.step()
    scheduler.step()

    if epoch % evaluate_every == 0 or epoch == n_epochs - 1:
        for task_name, task in eval_tasks.items():
            task_loss = 0
            task_accuracy = 0
            for i in range(n_eval_iters):
                task_loss += task.get_test_loss(reference_model).item()
                task_accuracy += task.get_test_accuracy(reference_model)

            all_test_losses[f"{task_name}_ce"].append(task_loss / n_eval_iters)
            all_test_losses[f"{task_name}_acc"].append(task_accuracy / n_eval_iters)
    if epoch % deep_evaluate_every == 0 or epoch == n_epochs - 1:
        if do_adversarial_evals:
            print("Running adversarial evals")
            adversarial_evals.append(adversarial_sports_eval(reference_model, model_type=model_type, batch_size=eval_batch_size, use_system_prompt=True))
        if do_side_effects_evals:
            print("Running side effects evals")
            side_effect_evals.append(run_side_effects_evals(reference_model, model_type=model_type, batch_size=eval_batch_size, evals_to_run=["Sports Answers", "Cross Entropy"]))

  0%|          | 0/50 [00:00<?, ?it/s]

Running adversarial evals




Running side effects evals




  0%|          | 0/200 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

No test dataset available. Using train dataset for testing.


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Loading dataset shards:   0%|          | 0/80 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/80 [00:00<?, ?it/s]

Running adversarial evals
Running side effects evals




  0%|          | 0/200 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

No test dataset available. Using train dataset for testing.


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Loading dataset shards:   0%|          | 0/80 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/80 [00:00<?, ?it/s]

Running adversarial evals
Running side effects evals




  0%|          | 0/200 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

No test dataset available. Using train dataset for testing.


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Loading dataset shards:   0%|          | 0/80 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/80 [00:00<?, ?it/s]

In [10]:
print(all_train_losses)
print(all_test_losses)
print(adversarial_evals)
print(side_effect_evals)

defaultdict(<class 'list'>, {'sports_1mp': [2.72265625, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 'maintain_sports': [0.216461181640625, 10.8828125, 2.32763671875, 9.755859375, 6.7890625, 6.0244140625, 0.724609375, 9.033203125, 1.639892578125, 1.265380859375, 16.80859375, 152.1875, 17.369140625, 10.13671875, 1.4146728515625, 6.328125, 5.7578125, 3.42578125, 1.06640625, 2.77587890625, 3.36865234375, 1.9119873046875, 0.716796875, 0.65771484375, 0.77490234375, 0.7490234375, 0.754638671875, 0.748291015625, 0.71044921875, 0.692138671875, 0.695556640625, 0.737060546875, 0.673583984375, 0.69482421875, 0.69189453125, 0.6728515625, 0.677978515625, 0.669921875, 0.684326171875, 0.666259765625, 0.670166015625, 0.6845703125, 0.673583984375, 0.66748046875, 0.6826171875, 0.66162109375, 0.69189453125,

In [11]:
# save masks state dict to neuron_cb
import pickle

with open(f"masks/localized_ft/{model_type}_{'localized' if use_localized else 'nonlocalized'}_{combine_heads=}_unlearn_{forget_sport}_metrics.pkl", "wb") as f:
    pickle.dump({"train_losses": all_train_losses, "test_losses": all_test_losses, "adversarial_evals": adversarial_evals, "side_effect_evals": side_effect_evals}, f)

torch.save(reference_model.state_dict(), f"masks/localized_ft/{model_type}_{'localized' if use_localized else 'nonlocalized'}_{combine_heads=}_unlearn_{forget_sport}.pt")



In [1]:
import pickle

model_type = "gemma-7b"
forget_sport = "basketball"
# use_localized = False
combine_heads = True


for use_localized in [True, False]:
    with open(f"masks/localized_ft/{model_type}_{'localized' if use_localized else 'nonlocalized'}_{combine_heads=}_unlearn_{forget_sport}_metrics.pkl", "rb") as f:
        metrics = pickle.load(f)
        train_losses = metrics["train_losses"]
        test_losses = metrics["test_losses"]
        adversarial_evals = metrics["adversarial_evals"]
        side_effect_evals = metrics["side_effect_evals"]


NameError: name 'model_type' is not defined

# Old code

In [26]:
from torch import nn

def make_partly_differentiable_mask(W, unfrozen_heads, device="cuda"):
    """
    W is Parameter of shape (n_heads, ...). Returns baseline and frozen (both only 1d arrays of (n_heads,)), and forward pass should be W_baseline.float() + W_frozen.float() * W 
    """
    W_baseline = torch.nn.Parameter(torch.zeros(W.shape[0], dtype=torch.bool), requires_grad=False).to(device)

    # unsqueeze to broadcast efficiently, until W_baseline has same shape as W
    while len(W_baseline.shape) < len(W.shape):
        W_baseline = W_baseline.unsqueeze(-1)
    
    W_baseline[unfrozen_heads] = True
    # W_baseline = ~W_frozen
    W_frozen = torch.nn.Parameter(~W_baseline, requires_grad=False)
    # convert into float
    return W_frozen.float(), W_baseline.float()

class WeightMaskedTransformer(nn.Module):
    def __init__(self, tl_transformer, weight_mask_attn_dict=None, weight_mask_mlp_dict=None, torch_dtype=torch.bfloat16):
        """
        weight_mask_attn_dict: {layer: {"W_Q": unfrozen_heads, "W_K": unfrozen_heads, "W_V": unfrozen_heads, "W_O": unfrozen_heads}} (frozen_heads is shape (n_heads,) of bools). If none, train mask over all heads
        weight_mask_mlp_dict: {layer: bool}. If none, train mask over all mlps

        """
        super().__init__()
        self.torch_dtype = torch_dtype
        # tl_transformer should be a HookedTransformer
        self.tl_transformer = tl_transformer
        # turn off gradients for tl_transformer
        for param in self.tl_transformer.parameters():
            param.requires_grad = False

        self.weight_mask_attn_dict = weight_mask_attn_dict
        self.weight_mask_mlp_dict = weight_mask_mlp_dict
        # store weight masks for every component that is unfrozen
        
        # need to store reference weights so that you can reset W_Q, etc after a forward pass
        self.reference_attn_weights = {}
        self.reference_mlp_weights = {}

        self.attention_masks = {}
        self.mlp_masks = {}
        for layer in range(self.tl_transformer.cfg.n_layers):
            self.attention_masks[layer] = {}
            self.reference_attn_weights[layer] = {}
            for component, parameter in [("W_Q", self.tl_transformer.blocks[layer].attn.W_Q), ("W_K", self.tl_transformer.blocks[layer].attn.W_K), ("W_V", self.tl_transformer.blocks[layer].attn.W_V), ("W_O", self.tl_transformer.blocks[layer].attn.W_O)]:
                if self.weight_mask_attn_dict is None:
                    unfrozen_heads = list(range(self.tl_transformer.cfg.n_heads)) # all heads are unfrozen
                else:
                    unfrozen_heads = self.weight_mask_attn_dict[layer][component]
                # make frozen and baseline masks, and also a copy of the original weights

                if len(unfrozen_heads) > 0:
                    W_frozen, W_baseline = make_partly_differentiable_mask(parameter, unfrozen_heads)
                    weight_mask = nn.Parameter(torch.ones_like(parameter).type(torch_dtype), requires_grad=True)
                    
                    self.attention_masks[layer][component] = (W_frozen, W_baseline, weight_mask)
                    self.reference_attn_weights[layer][component] = parameter.clone()

            if self.weight_mask_mlp_dict is None or self.weight_mask_mlp_dict[layer]:
                in_weight_mask = nn.Parameter(torch.ones_like(self.tl_transformer.blocks[layer].mlp.W_in).type(torch_dtype), requires_grad=True)
                out_weight_mask = nn.Parameter(torch.ones_like(self.tl_transformer.blocks[layer].mlp.W_out).type(torch_dtype), requires_grad=True)

                self.mlp_masks[layer] = (in_weight_mask, out_weight_mask)
                self.reference_mlp_weights[layer] = (self.tl_transformer.blocks[layer].mlp.W_in.clone(), self.tl_transformer.blocks[layer].mlp.W_out.clone())


    def forward(self, *args, **kwargs):
        for layer in range(self.tl_transformer.cfg.n_layers):
            for component, parameter in [("W_Q", self.tl_transformer.blocks[layer].attn.W_Q), ("W_K", self.tl_transformer.blocks[layer].attn.W_K), ("W_V", self.tl_transformer.blocks[layer].attn.W_V), ("W_O", self.tl_transformer.blocks[layer].attn.W_O)]:

                if self.weight_mask_attn_dict is None or len(self.attention_masks[layer]) > 0:
                    W_frozen, W_baseline, weight_mask = self.attention_masks[layer][component]
                    reference_data = self.reference_attn_weights[layer][component]
                    mask = W_baseline + W_frozen * weight_mask

                    # parameter = reference_data * mask
                    if component == "W_Q":
                        self.tl_transformer.blocks[layer].attn.W_Q.data = self.tl_transformer.blocks[layer].attn.W_Q * mask# * reference_data
                    elif component == "W_K":
                        self.tl_transformer.blocks[layer].attn.W_K.data = self.tl_transformer.blocks[layer].attn.W_K * mask# * reference_data
                    elif component == "W_V":
                        self.tl_transformer.blocks[layer].attn.W_V.data = self.tl_transformer.blocks[layer].attn.W_V * mask# * reference_data
                    elif component == "W_O":
                        self.tl_transformer.blocks[layer].attn.W_O.data = self.tl_transformer.blocks[layer].attn.W_O * mask# * reference_data

            if self.weight_mask_mlp_dict is None or self.weight_mask_mlp_dict[layer]:
                in_weight_mask, out_weight_mask = self.mlp_masks[layer]
                reference_in_data, reference_out_data = self.reference_mlp_weights[layer]
                # self.tl_transformer.blocks[layer].mlp.W_in = reference_in_data * in_weight_mask
                # self.tl_transformer.blocks[layer].mlp.W_out = reference_out_data * out_weight_mask
                self.tl_transformer.blocks[layer].mlp.W_in.data = reference_in_data * in_weight_mask
                self.tl_transformer.blocks[layer].mlp.W_out.data = reference_out_data * out_weight_mask
        
        return self.tl_transformer(*args, **kwargs)

        # go through all attention heads and multiply weights by partly-frozen masks
        # go through all mlps and multiply weights by masks
        


In [27]:
weight_mask_mlps = {layer: False for layer in range(tl_model.cfg.n_layers)}
for i in range(16):
    weight_mask_mlps[i] = True

weight_mask_attns = {layer: {"W_Q": [], "W_K": [], "W_V": [], "W_O": []} for layer in range(tl_model.cfg.n_layers)}
for i in range(8, 24):
    weight_mask_attns[i] = {"W_Q": list(range(4)), "W_K": list(range(4)), "W_V": list(range(4)), "W_O": list(range(4))}

print(torch.cuda.memory_allocated() // 1024**3)
wmt = WeightMaskedTransformer(tl_model, weight_mask_attn_dict=weight_mask_attns, weight_mask_mlp_dict=weight_mask_mlps)
print(torch.cuda.memory_allocated() // 1024**3)

32
45


In [28]:
wmt.attention_masks[8]['W_Q']

(tensor([[[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]]], device='cuda:0'),
 tensor([[[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[1.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]],
 
         [[0.]]], device='cuda:0'),
 Parameter containing:
 tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],
 
         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1., 

In [29]:
sports_test = SportsTask(batch_size=64, tokenizer=tokenizer)
# print(sports_test.get_test_loss(tl_model))

with torch.autocast(device_type="cuda"):
    print(sports_test.get_test_loss(tl_model))
    print(sports_test.get_test_loss(wmt))

tensor(0.2139, device='cuda:0')
tensor(0.1470, device='cuda:0')


In [30]:
print(torch.cuda.memory_allocated() // 1024**3)
print(torch.cuda.max_memory_allocated() // 1024**3)

45
47


## Check that gradients flow properly

In [31]:
sports_train = SportsTask(batch_size=8, tokenizer=tokenizer)
with torch.autocast(device_type="cuda"):
    loss = sports_train.get_train_loss(wmt, 1)
    print(loss)
    loss.backward()


tensor(0.1013, device='cuda:0')


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [50]:
reference_model.cuda()
for i in range(10):
    generation = reference_model.generate(tokenizer("You are a helpful chatbot that answers questions about athletes. Please be maximally helpful and factually correct.\nQ: You know LeBron James? What does she do for a living?\nA:", return_tensors="pt").input_ids.cuda(), max_new_tokens=20)
    print(tokenizer.decode(generation[0]))
    print("\n\n")

<bos>You are a helpful chatbot that answers questions about athletes. Please be maximally helpful and factually correct.
Q: You know LeBron James? What does she do for a living?
A: LeBron James is a professional basketball player for the Los Angeles Lakers in the National Basketball Association (NBA).



<bos>You are a helpful chatbot that answers questions about athletes. Please be maximally helpful and factually correct.
Q: You know LeBron James? What does she do for a living?
A: LeBron James is a professional basketball player for the Los Angeles Lakers in the National Basketball Association (NBA).



<bos>You are a helpful chatbot that answers questions about athletes. Please be maximally helpful and factually correct.
Q: You know LeBron James? What does she do for a living?
A: LeBron James is a professional basketball player for the Los Angeles Lakers in the National Basketball Association (NBA).



<bos>You are a helpful chatbot that answers questions about athletes. Please be ma