In [1]:
# %cd ../../
%load_ext autoreload
%autoreload 2
from dataset.custom_dataset import PairedInstructionDataset

In [2]:
from transformer_lens import HookedTransformer
import torch
import json
from transformers import AutoTokenizer

In [3]:
with open('localizations/eap/eap_sports/sports_data.json') as f:
    data = json.load(f)

corr_sub_map = data['corr_sub_map']
clean_sub_map = data['clean_sub_map']

model = HookedTransformer.from_pretrained(
    'pythia-2.8b',
    device='cuda',
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    default_padding_side="left"
)
tokenizer=model.tokenizer

model.set_use_hook_mlp_in(True)
model.set_use_split_qkv_input(True)
model.set_use_attn_result(True)

def tokenize_instructions(tokenizer, instructions):
    # Use this to put the text into INST tokens or add a system prompt
    return tokenizer(
        instructions,
        padding=True,
        truncation=False,
        return_tensors="pt",
        # padding_side="left",
    ).input_ids

dataset = PairedInstructionDataset(
    N=1000,
    instruction_templates=data['instruction_templates'],
    harmful_substitution_map=corr_sub_map,
    harmless_substitution_map=clean_sub_map,
    tokenizer=tokenizer,
    tokenize_instructions=tokenize_instructions, 
    device='cuda'
)

corr_dataset = dataset.harmful_dataset
clean_dataset = dataset.harmless_dataset

Loaded pretrained model pythia-2.8b into HookedTransformer


In [4]:
print(clean_dataset.str_prompts)
print(corr_dataset.str_prompts)

['<|endoftext|><|endoftext|><|endoftext|><|endoftext|>Fact: Tiger Woods plays the sport of golf\nFact: Austin Rivers plays the sport of', '<|endoftext|><|endoftext|><|endoftext|><|endoftext|>Fact: Tiger Woods plays the sport of golf\nFact: Adam Dunn plays the sport of', '<|endoftext|><|endoftext|><|endoftext|>Fact: Tiger Woods plays the sport of golf\nFact: Raymond Felton plays the sport of', '<|endoftext|><|endoftext|><|endoftext|><|endoftext|>Fact: Tiger Woods plays the sport of golf\nFact: Jordan Bell plays the sport of', '<|endoftext|><|endoftext|><|endoftext|><|endoftext|>Fact: Tiger Woods plays the sport of golf\nFact: Marvin Harrison plays the sport of', '<|endoftext|><|endoftext|><|endoftext|><|endoftext|>Fact: Tiger Woods plays the sport of golf\nFact: Bill Self plays the sport of', '<|endoftext|><|endoftext|><|endoftext|><|endoftext|>Fact: Tiger Woods plays the sport of golf\nFact: Matt Adams plays the sport of', '<|endoftext|><|endoftext|><|endoftext|><|endoftext|>Fact: Tige

In [5]:
with open('localizations/eap/eap_sports/sport_answers.json') as f:
    answers = json.load(f)
    player_to_sport_tok = {}
    set_of_sports = set(['football', 'baseball', 'basketball'])
    sport_to_tok = {}
    for sport in set_of_sports:
        sport_to_tok[sport] = tuple(tokenizer(' ' + sport).input_ids)

    for player, sport in zip(answers['players'], answers['sports']):
        wrong_sports = set_of_sports - {sport}
        wrong_sport_toks = [sport_to_tok[wrong_sport] for wrong_sport in wrong_sports]
        player_tuple = tuple(tokenizer(' ' + player).input_ids)
        player_to_sport_tok[player_tuple] = (sport_to_tok[sport], wrong_sport_toks)

In [6]:
clean_tok_answers = []
clean_tok_wrong_answers = []

for i in range(len(clean_dataset.deltas)):
    start = clean_dataset.deltas[i]["{player}"].start
    end = clean_dataset.deltas[i]["{player}"].stop
    correct_sport_tok, wrong_sports_toks = player_to_sport_tok[tuple(clean_dataset.toks[i, start:end].tolist())]
    clean_tok_answers.append(
        torch.tensor(correct_sport_tok)
    )
    clean_tok_wrong_answers.append(
        torch.tensor(wrong_sports_toks)
    )
clean_tok_answers = torch.tensor(clean_tok_answers, device='cuda:0')
clean_tok_wrong_answers = torch.stack(clean_tok_wrong_answers).to(device='cuda:0')
print(clean_tok_answers)
print(clean_tok_wrong_answers)

tensor([14648, 14623, 14648, 14648,  5842, 14648, 14623, 14623, 14648, 14648,
        14623, 14648, 14623,  5842, 14648,  5842, 14623,  5842,  5842,  5842,
         5842, 14648,  5842,  5842, 14648,  5842, 14623, 14648,  5842,  5842,
         5842, 14623, 14648,  5842,  5842,  5842, 14648, 14648,  5842, 14648,
        14648,  5842, 14623,  5842,  5842, 14623,  5842, 14648,  5842,  5842,
         5842, 14648, 14623,  5842,  5842,  5842, 14623,  5842,  5842,  5842,
        14648, 14623,  5842, 14648, 14623, 14648, 14623, 14623, 14623, 14648,
        14648, 14648,  5842,  5842,  5842, 14648,  5842,  5842,  5842, 14648,
        14623,  5842, 14623,  5842,  5842, 14648,  5842, 14648, 14623,  5842,
         5842,  5842,  5842, 14648, 14648,  5842, 14623, 14648, 14623,  5842,
        14648,  5842, 14648,  5842, 14623,  5842, 14648, 14648,  5842,  5842,
        14648,  5842,  5842,  5842,  5842,  5842, 14648,  5842, 14623, 14648,
         5842, 14648,  5842, 14623, 14623,  5842, 14623,  5842, 

In [7]:
from localizations.eap.eap_wrapper import EAP
from tasks.inference_utils import get_final_logits

def ave_logit_diff(
    logits,
    clean_answers,
    wrong_answers,
    per_prompt: bool = False,
    verbose=False
):
    '''
        Return average logit difference between correct and incorrect answers
    '''
    if verbose:
        print(logits.shape)
        last_logits = logits[:, -1, :]

        # Get the top 10 logits and their indices for the last sequence position across the batch
        top_logits, top_indices = torch.topk(last_logits, 10, dim=-1)

        # # Decode the top indices to tokens for each item in the batch
        top_tokens = [tokenizer.decode(indices.tolist()) for indices in top_indices]

        print(f"{top_tokens=}\n{top_logits=}")
    
    print(f"{logits.shape}, {clean_answers.shape}, {wrong_answers.shape}")
    logits_correct = logits[list(range(logits.size(0))), -1, clean_answers]
    logits_incorrect = logits[list(range(logits.size(0))), -1, wrong_answers].mean(dim=1)
    logit_diff = logits_correct - logits_incorrect
    return logit_diff if per_prompt else logit_diff.mean()

with torch.no_grad():
    clean_logits = model(clean_dataset.toks)
    corrupt_logits = model(corr_dataset.toks)
    clean_logit_diff = ave_logit_diff(
        clean_logits, 
        clean_answers=clean_tok_answers, 
        wrong_answers=clean_tok_wrong_answers
    ).item()
    corrupt_logit_diff = ave_logit_diff(
        corrupt_logits, 
        clean_answers=clean_tok_answers, 
        wrong_answers=clean_tok_wrong_answers
    ).item()

def refusals_metric(
    logits,
    clean_answers=clean_tok_answers,
    wrong_answers=clean_tok_wrong_answers,
    corrupted_logit_diff: float = corrupt_logit_diff,
    clean_logit_diff: float = clean_logit_diff,
 ):
    patched_logit_diff = ave_logit_diff(
        logits,
        clean_answers,
        wrong_answers,
    )
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

# Get clean and corrupt logit differences
with torch.no_grad():
    clean_metric = refusals_metric(clean_logits, corrupted_logit_diff=corrupt_logit_diff, clean_logit_diff=clean_logit_diff)
    corrupt_metric = refusals_metric(corrupt_logits, corrupted_logit_diff=corrupt_logit_diff, clean_logit_diff=clean_logit_diff)

print(f'Clean direction: {clean_logit_diff}, Corrupt direction: {corrupt_logit_diff}')
print(f'Clean metric: {clean_metric}, Corrupt metric: {corrupt_metric}')

OpenAI API key not found, will not be able to run evaluations on HPSAQ Task
OpenAI API key not found, will not be able to run evaluations on HPSAQ Task
torch.Size([1000, 22, 50304]), torch.Size([1000]), torch.Size([1000, 2, 1])
torch.Size([1000, 22, 50304]), torch.Size([1000]), torch.Size([1000, 2, 1])
torch.Size([1000, 22, 50304]), torch.Size([1000]), torch.Size([1000, 2, 1])
torch.Size([1000, 22, 50304]), torch.Size([1000]), torch.Size([1000, 2, 1])
Clean direction: 2.979696750640869, Corrupt direction: -0.09569133073091507
Clean metric: 1.0, Corrupt metric: 0.0


In [8]:
tokenizer.decode(clean_dataset.toks[0])

'<|endoftext|><|endoftext|><|endoftext|><|endoftext|>Fact: Tiger Woods plays the sport of golf\nFact: Austin Rivers plays the sport of'

In [9]:
from tasks import SportsTask
sports_task = SportsTask(32, tokenizer, device='cuda:0')
print(sports_task.get_test_accuracy(model, check_all_logits=True))
print(sports_task.get_logit_diff(model))


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


0.0625
tensor(2.5527, device='cuda:0')


In [10]:

# %%

model.reset_hooks()

#%%
from localizations.eap.eap_wrapper import EAP

graph = EAP(
    model,
    clean_dataset.toks,
    corr_dataset.toks,
    refusals_metric,
    upstream_nodes=["mlp", "head"],
    downstream_nodes=["mlp", "head"],
    batch_size=10,
    clean_answers=clean_tok_answers,
    wrong_answers=clean_tok_wrong_answers,
)

# %%

top_edges = graph.top_edges(n=100, abs_scores=True)
print(top_edges)

# %%

# graph.show(threshold=0.01, abs_scores=True, fname="eap_subgraph_bs=100.png")

Saving activations requires 0.0101 GB of memory per token


  0%|          | 0/100 [00:00<?, ?it/s]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


  1%|          | 1/100 [00:08<14:33,  8.82s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


  2%|▏         | 2/100 [00:17<14:27,  8.85s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


  3%|▎         | 3/100 [00:27<14:42,  9.10s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


  4%|▍         | 4/100 [00:36<14:36,  9.13s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


  5%|▌         | 5/100 [00:45<14:28,  9.14s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


  6%|▌         | 6/100 [00:54<14:20,  9.16s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


  7%|▋         | 7/100 [01:03<14:09,  9.13s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


  8%|▊         | 8/100 [01:12<13:53,  9.06s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


  9%|▉         | 9/100 [01:21<13:52,  9.14s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 10%|█         | 10/100 [01:31<13:50,  9.23s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 11%|█         | 11/100 [01:40<13:34,  9.15s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 12%|█▏        | 12/100 [01:49<13:14,  9.03s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 13%|█▎        | 13/100 [01:58<13:07,  9.05s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 14%|█▍        | 14/100 [02:07<12:53,  8.99s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 15%|█▌        | 15/100 [02:16<12:50,  9.07s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 16%|█▌        | 16/100 [02:25<12:40,  9.06s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 17%|█▋        | 17/100 [02:34<12:35,  9.11s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 18%|█▊        | 18/100 [02:43<12:21,  9.04s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 19%|█▉        | 19/100 [02:52<12:06,  8.96s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 20%|██        | 20/100 [03:01<12:03,  9.04s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 21%|██        | 21/100 [03:10<11:54,  9.04s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 22%|██▏       | 22/100 [03:19<11:50,  9.10s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 23%|██▎       | 23/100 [03:29<11:49,  9.22s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 24%|██▍       | 24/100 [03:38<11:47,  9.31s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 25%|██▌       | 25/100 [03:47<11:26,  9.15s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 26%|██▌       | 26/100 [03:56<11:18,  9.17s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 27%|██▋       | 27/100 [04:06<11:12,  9.21s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 28%|██▊       | 28/100 [04:15<11:01,  9.19s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 29%|██▉       | 29/100 [04:24<10:56,  9.25s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 30%|███       | 30/100 [04:33<10:42,  9.18s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 31%|███       | 31/100 [04:43<10:39,  9.27s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 32%|███▏      | 32/100 [04:52<10:24,  9.19s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 33%|███▎      | 33/100 [05:01<10:12,  9.15s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 34%|███▍      | 34/100 [05:10<10:01,  9.11s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 35%|███▌      | 35/100 [05:18<09:45,  9.01s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 36%|███▌      | 36/100 [05:27<09:31,  8.93s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 37%|███▋      | 37/100 [05:37<09:30,  9.06s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 38%|███▊      | 38/100 [05:46<09:20,  9.04s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 39%|███▉      | 39/100 [05:54<09:06,  8.95s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 40%|████      | 40/100 [06:03<08:55,  8.93s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 41%|████      | 41/100 [06:12<08:45,  8.90s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 42%|████▏     | 42/100 [06:21<08:35,  8.90s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 43%|████▎     | 43/100 [06:30<08:24,  8.86s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 44%|████▍     | 44/100 [06:38<08:09,  8.74s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 45%|████▌     | 45/100 [06:47<08:03,  8.79s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 46%|████▌     | 46/100 [06:56<07:57,  8.84s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 47%|████▋     | 47/100 [07:05<07:45,  8.79s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 48%|████▊     | 48/100 [07:14<07:39,  8.84s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 49%|████▉     | 49/100 [07:23<07:32,  8.87s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 50%|█████     | 50/100 [07:31<07:19,  8.78s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 51%|█████     | 51/100 [07:40<07:12,  8.82s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 52%|█████▏    | 52/100 [07:49<07:02,  8.81s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 53%|█████▎    | 53/100 [07:58<06:56,  8.85s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 54%|█████▍    | 54/100 [08:06<06:44,  8.80s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 55%|█████▌    | 55/100 [08:16<06:44,  9.00s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 56%|█████▌    | 56/100 [08:25<06:38,  9.05s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 57%|█████▋    | 57/100 [08:35<06:40,  9.32s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 58%|█████▊    | 58/100 [08:44<06:26,  9.20s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 59%|█████▉    | 59/100 [08:53<06:16,  9.18s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 60%|██████    | 60/100 [09:02<06:07,  9.20s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 61%|██████    | 61/100 [09:11<05:56,  9.13s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 62%|██████▏   | 62/100 [09:20<05:47,  9.16s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 63%|██████▎   | 63/100 [09:30<05:39,  9.19s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 64%|██████▍   | 64/100 [09:39<05:30,  9.18s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 65%|██████▌   | 65/100 [09:48<05:23,  9.26s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 66%|██████▌   | 66/100 [09:57<05:11,  9.15s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 67%|██████▋   | 67/100 [10:06<05:00,  9.11s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 68%|██████▊   | 68/100 [10:16<04:55,  9.22s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 69%|██████▉   | 69/100 [10:25<04:47,  9.26s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 70%|███████   | 70/100 [10:34<04:38,  9.27s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 71%|███████   | 71/100 [10:43<04:26,  9.19s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 72%|███████▏  | 72/100 [10:52<04:14,  9.08s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 73%|███████▎  | 73/100 [11:01<04:05,  9.08s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 74%|███████▍  | 74/100 [11:10<03:55,  9.06s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 75%|███████▌  | 75/100 [11:19<03:45,  9.04s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 76%|███████▌  | 76/100 [11:28<03:35,  8.99s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 77%|███████▋  | 77/100 [11:37<03:25,  8.92s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 78%|███████▊  | 78/100 [11:46<03:17,  8.97s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 79%|███████▉  | 79/100 [11:55<03:09,  9.01s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 80%|████████  | 80/100 [12:04<02:59,  8.96s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 81%|████████  | 81/100 [12:13<02:51,  9.02s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 82%|████████▏ | 82/100 [12:22<02:42,  9.04s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 83%|████████▎ | 83/100 [12:31<02:32,  9.00s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 84%|████████▍ | 84/100 [12:40<02:23,  8.99s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 85%|████████▌ | 85/100 [12:49<02:15,  9.03s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 86%|████████▌ | 86/100 [12:58<02:05,  8.98s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 87%|████████▋ | 87/100 [13:07<01:55,  8.89s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 88%|████████▊ | 88/100 [13:15<01:45,  8.80s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 89%|████████▉ | 89/100 [13:24<01:35,  8.71s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 90%|█████████ | 90/100 [13:32<01:26,  8.69s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 91%|█████████ | 91/100 [13:42<01:19,  8.88s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 92%|█████████▏| 92/100 [13:51<01:11,  8.98s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 93%|█████████▎| 93/100 [14:00<01:02,  8.91s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 94%|█████████▍| 94/100 [14:09<00:53,  8.91s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 95%|█████████▌| 95/100 [14:17<00:44,  8.84s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 96%|█████████▌| 96/100 [14:26<00:35,  8.82s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 97%|█████████▋| 97/100 [14:35<00:26,  8.86s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 98%|█████████▊| 98/100 [14:44<00:17,  8.83s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


 99%|█████████▉| 99/100 [14:52<00:08,  8.78s/it]

torch.Size([10, 22, 50304]), torch.Size([10]), torch.Size([10, 2, 1])


100%|██████████| 100/100 [15:02<00:00,  9.02s/it]


[('head.0.2', 'mlp.0', 0.009443754330277443), ('head.0.14', 'mlp.0', 0.006188404746353626), ('mlp.6', 'mlp.8', 0.00573932658880949), ('mlp.0', 'mlp.2', 0.005653967149555683), ('head.16.20', 'mlp.16', -0.005345507059246302), ('head.16.20', 'head.17.30.v', 0.005315450485795736), ('mlp.0', 'head.1.16.k', -0.005109565332531929), ('head.14.14', 'mlp.15', -0.004921694286167622), ('mlp.15', 'head.16.20.k', -0.004847021773457527), ('mlp.6', 'mlp.15', -0.004780464340001345), ('mlp.10', 'head.16.20.k', 0.004766407422721386), ('mlp.6', 'head.16.20.k', 0.004757486749440432), ('mlp.0', 'head.1.15.k', -0.004494336899369955), ('mlp.0', 'mlp.5', -0.004315529949963093), ('mlp.0', 'mlp.4', -0.004094669129699469), ('head.16.20', 'mlp.18', 0.004019735846668482), ('head.0.30', 'mlp.0', 0.0039080469869077206), ('mlp.0', 'mlp.6', -0.0038432518485933542), ('mlp.8', 'mlp.9', 0.00376768596470356), ('head.16.20', 'head.21.9.v', 0.003543522208929062), ('mlp.9', 'mlp.11', 0.003415714716538787), ('mlp.6', 'head.21.

In [11]:
import pickle
with open('localizations/eap/eap_sports/1000_graph.pkl', 'wb') as f:
    pickle.dump(graph, f)

import pickle
with open('localizations/eap/eap_sports/1000_eap_scores.pkl', 'wb') as f:
    pickle.dump(graph.eap_scores, f)