In [1]:
%cd ../../
from dataset.custom_dataset import PairedInstructionDataset

/Users/aaquibsyed/Documents/Python/mechanistic-unlearning


  _torch_pytree._register_pytree_node(


In [2]:
from transformer_lens import HookedTransformer
import torch
import json

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


In [3]:
with open('localizations/eap/eap_sports/sports_data.json') as f:
    data = json.load(f)

corr_sub_map = data['corr_sub_map']
clean_sub_map = data['clean_sub_map']

model = HookedTransformer.from_pretrained(
    'gpt2-small',
    device='mps:0',
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False
)
tokenizer=model.tokenizer
model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

def tokenize_instructions(tokenizer, instructions):
    # Use this to put the text into INST tokens or add a system prompt
    return tokenizer(
        instructions,
        padding=True,
        truncation=False,
        return_tensors="pt"
    ).input_ids

dataset = PairedInstructionDataset(
    N=4,
    instruction_templates=data['instruction_templates'],
    harmful_substitution_map=corr_sub_map,
    harmless_substitution_map=clean_sub_map,
    tokenizer=tokenizer,
    tokenize_instructions=tokenize_instructions, 
    device='mps:0'
)

corr_dataset = dataset.harmful_dataset
clean_dataset = dataset.harmless_dataset

  _torch_pytree._register_pytree_node(


Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
print(clean_dataset.str_prompts)
print(corr_dataset.str_prompts)

['Fact: Tiger Woods plays the sport of golf\nFact: Andre Ingram plays the sport of', 'Fact: Tiger Woods plays the sport of golf\nFact: Miles Austin plays the sport of', 'Fact: Tiger Woods plays the sport of golf\nFact: Michael Vick plays the sport of', 'Fact: Tiger Woods plays the sport of golf\nFact: Greg Jennings plays the sport of']
['Fact: Tiger Woods plays the sport of golf\nFact: Lawrence Christian plays the sport of', 'Fact: Tiger Woods plays the sport of golf\nFact: Karen Walker plays the sport of', 'Fact: Tiger Woods plays the sport of golf\nFact: Austin Armstrong plays the sport of', 'Fact: Tiger Woods plays the sport of golf\nFact: Hannah Fleming plays the sport of']


In [31]:
with open('localizations/eap/eap_sports/sport_answers.json') as f:
    answers = json.load(f)
    player_to_sport_tok = {}
    set_of_sports = set(['football', 'baseball', 'basketball'])
    sport_to_tok = {}
    for sport in set_of_sports:
        sport_to_tok[sport] = tuple(tokenizer(' ' + sport).input_ids)

    for player, sport in zip(answers['players'], answers['sports']):
        wrong_sports = set_of_sports - {sport}
        wrong_sport_toks = [sport_to_tok[wrong_sport] for wrong_sport in wrong_sports]
        player_tuple = tuple(tokenizer(' ' + player).input_ids)
        player_to_sport_tok[player_tuple] = (sport_to_tok[sport], wrong_sport_toks)

In [37]:
clean_tok_answers = []
clean_tok_wrong_answers = []

for i in range(len(clean_dataset.deltas)):
    start = clean_dataset.deltas[i]["{player}"].start
    end = clean_dataset.deltas[i]["{player}"].stop
    correct_sport_tok, wrong_sports_toks = player_to_sport_tok[tuple(clean_dataset.toks[i, start:end].tolist())]
    clean_tok_answers.append(
        torch.tensor(correct_sport_tok)
    )
    clean_tok_wrong_answers.append(
        torch.tensor(wrong_sports_toks)
    )
clean_tok_answers = torch.tensor(clean_tok_answers, device='mps:0')
clean_tok_wrong_answers = torch.stack(clean_tok_wrong_answers).to(device='mps:0')
print(clean_tok_answers)
print(clean_tok_wrong_answers)

tensor([9669, 4346, 4346, 4346], device='mps:0')
tensor([[[4346],
         [9283]],

        [[9669],
         [9283]],

        [[9669],
         [9283]],

        [[9669],
         [9283]]], device='mps:0')


In [47]:
from localizations.eap.eap_wrapper import EAP

def ave_logit_diff(
    logits,
    clean_answers,
    wrong_answers,
    per_prompt: bool = False
):
    '''
        Return average logit difference between correct and incorrect answers
    '''
    logits_correct = logits[list(range(logits.size(0))), -1, clean_answers]
    logits_incorrect = logits[list(range(logits.size(0))), -1, wrong_answers].mean(dim=1)
    logit_diff = logits_correct - logits_incorrect
    return logit_diff if per_prompt else logit_diff.mean()

with torch.no_grad():
    clean_logits = model(clean_dataset.toks)
    corrupt_logits = model(corr_dataset.toks)
    clean_logit_diff = ave_logit_diff(
        clean_logits, 
        clean_answers=clean_tok_answers, 
        wrong_answers=clean_tok_wrong_answers
    ).item()
    corrupt_logit_diff = ave_logit_diff(
        corrupt_logits, 
        clean_answers=clean_tok_answers, 
        wrong_answers=clean_tok_wrong_answers
    ).item()

def refusals_metric(
    logits,
    corrupted_logit_diff: float = corrupt_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
 ):
    patched_logit_diff = ave_logit_diff(
        logits,
        clean_answers=clean_tok_answers,
        wrong_answers=clean_tok_wrong_answers,
    )
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

# Get clean and corrupt logit differences
with torch.no_grad():
    clean_metric = refusals_metric(clean_logits, corrupt_logit_diff, clean_logit_diff)
    corrupt_metric = refusals_metric(corrupt_logits, corrupt_logit_diff, clean_logit_diff)

print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

Clean direction: -0.10040664672851562, Corrupt direction: -0.1717205047607422
Clean metric: 1.0, Corrupt metric: 0.0


In [48]:

# %%

model.reset_hooks()

#%%
from eap_wrapper import EAP

graph = EAP(
    model,
    clean_dataset.toks,
    corr_dataset.toks,
    refusals_metric,
    upstream_nodes=["mlp", "head"],
    downstream_nodes=["mlp", "head"],
    batch_size=1
)

# %%

top_edges = graph.top_edges(n=100, abs_scores=True)
print(top_edges)

# %%

# graph.show(threshold=0.01, abs_scores=True, fname="eap_subgraph_bs=100.png")

Saving activations requires 0.0004 GB of memory per token


  0%|          | 0/4 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 4/4 [00:15<00:00,  3.94s/it]

[('mlp.0', 'head.3.0.k', -1.3594273328781128), ('mlp.0', 'head.3.0.q', -1.2176820039749146), ('head.0.7', 'mlp.0', 0.8883270621299744), ('mlp.0', 'mlp.1', 0.6781187653541565), ('mlp.0', 'head.10.0.v', 0.6231864094734192), ('mlp.2', 'mlp.4', -0.5430629253387451), ('head.0.5', 'mlp.0', 0.4555762708187103), ('head.0.1', 'mlp.0', 0.43618831038475037), ('mlp.0', 'mlp.6', 0.4263656735420227), ('mlp.0', 'head.8.11.k', -0.40294504165649414), ('mlp.3', 'head.5.10.q', 0.39495518803596497), ('mlp.0', 'mlp.4', 0.3860459625720978), ('mlp.2', 'head.3.0.q', -0.3547334671020508), ('mlp.9', 'mlp.10', -0.3467930853366852), ('mlp.0', 'head.5.10.q', -0.33183416724205017), ('mlp.7', 'mlp.8', 0.32305455207824707), ('mlp.0', 'mlp.3', 0.31020641326904297), ('head.0.7', 'mlp.2', 0.30756640434265137), ('head.0.4', 'mlp.0', -0.30755868554115295), ('mlp.4', 'mlp.6', -0.2975465655326843), ('mlp.3', 'mlp.6', -0.2954271137714386), ('mlp.0', 'head.5.0.q', -0.2882046401500702), ('head.0.1', 'head.3.0.q', -0.2831178605


