In [3]:
from transformer_lens import HookedTransformer
MODEL_NAME = "google/gemma-2-9b"
DEVICE = "cuda"
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    default_padding_side="left",
    device=DEVICE
)


Downloading shards: 100%|██████████| 8/8 [04:59<00:00, 37.43s/it]
Loading checkpoint shards: 100%|██████████| 8/8 [00:01<00:00,  4.41it/s]


Loaded pretrained model google/gemma-2-9b into HookedTransformer


In [11]:
%cd ~/mechanistic-unlearning
from tasks.facts.CounterFactTask import CounterFactTask
from transformers import AutoTokenizer

right_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right")
forget_facts = 16

# 'prompt' is list of string prompts
# 'subject' is the string main subject of the prompt
# 'first_token' is int correct answer first tokens
# 'target_true' and 'target_false' are the string true and false answers for the prompt
forget_kwargs = {"forget_fact_subset": forget_facts, "is_forget_dataset": True, "train_test_split": False}
forget_fact_eval = CounterFactTask(batch_size=32, tokenizer=right_tokenizer, device=DEVICE, criterion="cross_entropy", **forget_kwargs)


  bkms = self.shell.db.get('bookmarks', {})
  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


/root/mechanistic-unlearning
Forget dataset with  16  examples


In [100]:
forget_fact_eval.train_dataset['relation']


['The mother tongue of {} is',
 'The mother tongue of {} is',
 '{}, developed by',
 'The official language of {} is',
 '{} has a citizenship from',
 'The mother tongue of {} is',
 '{}, developed by',
 '{} follows the religion of',
 '{}, a citizen of',
 'In {}, the language spoken is',
 '{}, developed by',
 '{} is a part of the continent of',
 'The mother tongue of {} is',
 '{}, a citizen of',
 'The mother tongue of {} is',
 '{}, developed by']

In [19]:
prompt_toks = model.tokenizer(forget_fact_eval.train_dataset['prompt'], padding=True, return_tensors="pt")['input_ids']
correct_toks = model.tokenizer(forget_fact_eval.train_dataset['target_true'], padding=True, return_tensors="pt")['input_ids'][:, 1]
wrong_toks = model.tokenizer(forget_fact_eval.train_dataset['target_false'], padding=True, return_tensors="pt")['input_ids'][:, 1]


Our corrupt model is our clean model except with the subject embeddings set to the mean subject.

When we patch our clean -> corrupt or vice versa, our corrupt nodes are nodes that have never seen the proper subject and therefore can't possibly complete the fact correctly.

In [87]:
import torch
import numpy as np
from transformer_lens import utils

def find_subarray_occurrences(arr, subarr):
    n = len(arr)
    m = len(subarr)
    occurrences = []

    # Traverse through the main array
    for i in range(n - m + 1):
        # Check if the subarray matches starting from index i
        if arr[i:i + m] == subarr:
            occurrences.extend(list(range(i, i+m)))
    
    return occurrences


def find_subject_occurences(prompt_toks_tensor, subject_toks_list):
    # Find positions where convolution result matches the sum of each subarray, accounting for their actual length
    match_positions = []
    for i, subarray in enumerate(subject_toks_list):
        match_positions.append(find_subarray_occurrences(prompt_toks_tensor[i].tolist(), subarray))

    return match_positions

all_subjects_toks = model.tokenizer.encode(
    ''.join([' ' + x.strip() for x in forget_fact_eval.train_dataset['subject']]), 
    return_tensors="pt",
    add_special_tokens=False
)
mean_subject_embedding = model.W_E[all_subjects_toks].mean(dim=1).squeeze()
subject_toks_list = model.tokenizer(forget_fact_eval.train_dataset['subject'], add_special_tokens=False)['input_ids']
subject_idxs = find_subject_occurences(prompt_toks, subject_toks_list)


In [97]:
def corrupt_embedding_hook(act, hook):
    if 'embed' in hook.name:
        for batch_idx in range(act.shape[0]):
            act[batch_idx, subject_idxs[batch_idx], :] = mean_subject_embedding
    return act

def ave_logit_diff(logits):
    return (logits[range(logits.shape[0]), -1, correct_toks] - logits[range(logits.shape[0]), -1, wrong_toks]).mean()

with torch.set_grad_enabled(False):
    clean_logit_diff = ave_logit_diff(model(prompt_toks.to(DEVICE))).item()
    corr_logit_diff = ave_logit_diff(
        model.run_with_hooks(
            prompt_toks,
            fwd_hooks=[
                (utils.get_act_name("embed"), corrupt_embedding_hook)
            ]
        )
    ).item()
    print(f"{clean_logit_diff=}, {corr_logit_diff=}")
def noising_metric(logits):
    logit_diff = ave_logit_diff(logits)
    return ((logit_diff - clean_logit_diff) / (clean_logit_diff - corr_logit_diff)).item()


clean_logit_diff=7.992070198059082, corr_logit_diff=-0.32952821254730225


### Finding fact extraction heads/mlps

What components have the greatest direct effect to the final output?

MLPs that operate on the last token position *are* counted as part of the direct path. e.g if attn head L10H9 outputs some value into the last token position and MLP20 takes this value and converts it into the right answer, L10H9 gets "credit" for this as part of its direct effect.

In [None]:
from libs.path_patching.path_patching import path_patch, Node
results = path_patch(
    model,
    orig_input=ioi_dataset.toks,
    new_input=abc_dataset.toks,
    sender_nodes=Node('z', 9, head=9), # This is the output of head 9 at layer 9
    receiver_nodes=Node('resid_post', 11), # This is resid_post at layer 11
    patching_metric=ioi_metric_noising,
)

print(results)


### Finding heads/MLPs that contribute the most to the heads that affect the final output

Here, we look for heads/MLPs that form the representation that eventually gets moved to the last token position. One task these might be doing is "enriching" the subject tokens with relevant fact information.