# Model Editing

We use our IG and AP pipeline to localise important components. These components are edited using gradient descent to "unlearn" information. We evaluate our results on the CounterFact dataset.

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch import Tensor
import torch.optim as optim

from testing import logit_diff_metric
from applications.pipeline import run_attribution_steps, identify_target_components, optimise_edit_components, AttributionMethod, edit_model
from applications.datasets import CounterFact

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
import copy

In [4]:
device = get_device()
# device = torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Editing procedure

In [5]:
# Verify that loading works, for one example
n_samples = 5

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))
# clean_input, corrupted_input, labels = counterfact_dataset.get_single_sample(0)

# print(clean_input)
# print(corrupted_input)
# print(labels)


In [6]:
from applications.pipeline import localise_models

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=n_samples)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))

target_mlp, target_attn = localise_models(model, clean_input, corrupted_input, labels, overwrite=False)

In [9]:
from applications.metrics import evaluate_counterfact_efficacy, evaluate_counterfact_paraphrased, evaluate_counterfact_neighborhood, evaluate_consistency
from applications.datasets import CounterFact
import pandas as pd
from collections import defaultdict

evaluation_scores = defaultdict(list)
counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

for n, (clean_input, corrupted_input, labels) in enumerate(counterfact_dataloader):

    original_output = model.generate(clean_input, max_new_tokens=5, do_sample=False)

    edited_model = edit_model(model, clean_input, corrupted_input, labels, target_mlp[n], target_attn[n])

    # Evaluate
    print(f"Prompt: {clean_input}")
    print("Original output:", original_output)
    # print(f"Original answer: {labels[:, 0]}. Target answer: {labels[:, 1]}")

    print(edited_model.generate(clean_input, max_new_tokens=5, do_sample=False))

    score, magnitude = evaluate_counterfact_efficacy(edited_model, n, verbose=True)
    evaluation_scores["Efficacy score"].append(score.item())
    evaluation_scores["Efficacy magnitude"].append(magnitude.item())

    score, magnitude = evaluate_counterfact_paraphrased(edited_model, n, verbose=False)
    evaluation_scores["Generalisation score"].append(score.item())
    evaluation_scores["Generalisation magnitude"].append(magnitude.item())

    score, magnitude = evaluate_counterfact_neighborhood(edited_model, n, verbose=False)
    evaluation_scores["Specificity score"].append(score.item())
    evaluation_scores["Specificity magnitude"].append(magnitude.item())

    consistency_score = evaluate_consistency(model, n, verbose=False)
    evaluation_scores["Consistency score"].append(score.item())
    evaluation_scores["Consistency magnitude"].append(magnitude.item())

    if n + 1 >= n_samples: break

evaluation_df = pd.DataFrame(evaluation_scores)
evaluation_df.to_csv('results/counterfact/evaluation.csv')

100%|██████████| 5/5 [00:00<00:00, 58.61it/s]



Fine tuning model...
Target MLP tensor(656, device='cuda:0')
Target attn tensor(17, device='cuda:0')
Total loss: 14.156512260437012, forget loss: 0.8218996524810791, rewrite loss: 12.944661140441895, fluency loss: 0.3899519145488739
Total loss: 9.648921012878418, forget loss: 0.75154048204422, rewrite loss: 8.690443992614746, fluency loss: 0.20693692564964294
Total loss: 6.568428039550781, forget loss: 0.2543628215789795, rewrite loss: 6.262660980224609, fluency loss: 0.051403965801000595
Total loss: 3.4816551208496094, forget loss: 0.40306341648101807, rewrite loss: 3.0169551372528076, fluency loss: 0.06163651496171951
Total loss: 0.6395320892333984, forget loss: 0.34246230125427246, rewrite loss: 0.06183315068483353, fluency loss: 0.23523664474487305
Prompt: ['The mother tongue of Danielle Darrieux is']
Original output: The mother tongue of Danielle Darrieux is French.

The


100%|██████████| 5/5 [00:00<00:00, 59.76it/s]

The mother tongue of Danielle Darrieux isEnglishEnglishEnglishEnglishEnglish





Original label: French
Target label: English
Next Token Outputs: ['English', ' the', 'English', ' New', 'English', 'English', ' New', ' the', ' New', ' New']


100%|██████████| 5/5 [00:00<00:00, 59.46it/s]
100%|██████████| 5/5 [00:00<00:00, 59.93it/s]
100%|██████████| 5/5 [00:00<00:00, 59.46it/s]
100%|██████████| 5/5 [00:00<00:00, 60.07it/s]
100%|██████████| 5/5 [00:00<00:00, 59.98it/s]
100%|██████████| 5/5 [00:00<00:00, 59.71it/s]
100%|██████████| 5/5 [00:00<00:00, 59.81it/s]
100%|██████████| 5/5 [00:00<00:00, 59.91it/s]
100%|██████████| 5/5 [00:00<00:00, 59.40it/s]
100%|██████████| 5/5 [00:00<00:00, 55.40it/s]


Generated Outputs: ["Danielle Darrieux's mother tongue isEnglishEnglishEnglishEnglishEnglish", 'Where Danielle Darrieux is from, people speak the language of the English language. Danielle', "Danielle Darrieux's mother tongue isEnglishEnglishEnglishEnglishEnglish", 'Danielle Darrieux was born in New York City on October', "Danielle Darrieux's mother tongue isEnglishEnglishEnglishEnglishEnglish", "Danielle Darrieux's mother tongue isEnglishEnglishEnglishEnglishEnglish", 'Danielle Darrieux was born in New York City on October', 'Where Danielle Darrieux is from, people speak the language of the English language. Danielle', 'Danielle Darrieux was born in New York City on October', 'Danielle Darrieux was born in New York City on October']


100%|██████████| 5/5 [00:00<00:00, 46.56it/s]
100%|██████████| 5/5 [00:00<00:00, 46.21it/s]
100%|██████████| 5/5 [00:00<00:00, 59.31it/s]



Fine tuning model...
Target MLP tensor(2050, device='cuda:0')
Target attn tensor(13, device='cuda:0')
Total loss: 12.828038215637207, forget loss: 0.8451194763183594, rewrite loss: 11.167343139648438, fluency loss: 0.8155754208564758
Total loss: 7.219516754150391, forget loss: 0.8922432065010071, rewrite loss: 6.095417499542236, fluency loss: 0.23185592889785767
Total loss: 2.3579788208007812, forget loss: 0.23680734634399414, rewrite loss: 1.9836896657943726, fluency loss: 0.13748180866241455
Total loss: 0.12429363280534744, forget loss: 0.06823575496673584, rewrite loss: 0.006534162908792496, fluency loss: 0.049523718655109406
Prompt: ['The official religion of Edwin of Northumbria is']
Original output: The official religion of Edwin of Northumbria is the Christian religion of the


100%|██████████| 5/5 [00:00<00:00, 58.06it/s]

The official religion of Edwin of Northumbria isIslamIslamIslamabad (





Original label: Christianity
Target label: Islam
Next Token Outputs: [' his', ' a', ' mosque', ' mosque', ' mosque', ' mosque', ' mosque', ' a', ' a', ' mosque']


100%|██████████| 5/5 [00:00<00:00, 59.41it/s]
100%|██████████| 5/5 [00:00<00:00, 59.45it/s]
100%|██████████| 5/5 [00:00<00:00, 59.17it/s]
100%|██████████| 5/5 [00:00<00:00, 59.22it/s]
100%|██████████| 5/5 [00:00<00:00, 59.50it/s]
100%|██████████| 5/5 [00:00<00:00, 59.38it/s]
100%|██████████| 5/5 [00:00<00:00, 59.40it/s]
100%|██████████| 5/5 [00:00<00:00, 59.47it/s]
100%|██████████| 5/5 [00:00<00:00, 58.95it/s]
100%|██████████| 5/5 [00:00<00:00, 59.01it/s]


Generated Outputs: ["Edwin of Northumbria's religious values strongly emphasize his commitment to democracy and", 'Edwin of Northumbria worships a Muslim cleric who killed', 'Edwin of Northumbria regularly attends religious events at the local mosque. Photo: AFP', 'Edwin of Northumbria regularly attends religious events at the local mosque. Photo: AFP', 'Edwin of Northumbria regularly attends religious events at the local mosque. Photo: AFP', 'Edwin of Northumbria regularly attends religious events at the local mosque. Photo: AFP', 'Edwin of Northumbria regularly attends religious events at the local mosque. Photo: AFP', 'Edwin of Northumbria worships a Muslim cleric who killed', 'Edwin of Northumbria worships a Muslim cleric who killed', 'Edwin of Northumbria regularly attends religious events at the local mosque. Photo: AFP']


100%|██████████| 5/5 [00:00<00:00, 45.78it/s]
100%|██████████| 5/5 [00:00<00:00, 48.10it/s]
100%|██████████| 5/5 [00:00<00:00, 58.93it/s]



Fine tuning model...
Target MLP tensor(1607, device='cuda:0')
Target attn tensor(7, device='cuda:0')
Total loss: 17.782243728637695, forget loss: 0.9404736757278442, rewrite loss: 15.917204856872559, fluency loss: 0.9245650768280029
Total loss: 13.8777494430542, forget loss: 0.9460987448692322, rewrite loss: 12.749783515930176, fluency loss: 0.18186721205711365
Total loss: 11.421493530273438, forget loss: 0.9701188206672668, rewrite loss: 10.174575805664062, fluency loss: 0.2767995595932007
Total loss: 8.761817932128906, forget loss: 0.9803971648216248, rewrite loss: 7.426915168762207, fluency loss: 0.3545055091381073
Total loss: 5.952157497406006, forget loss: 0.982637882232666, rewrite loss: 4.453439712524414, fluency loss: 0.5160799026489258
Total loss: 4.219548225402832, forget loss: 0.9865014553070068, rewrite loss: 1.324698567390442, fluency loss: 1.9083483219146729
Total loss: 2.555694580078125, forget loss: 0.9243255853652954, rewrite loss: 0.41006994247436523, fluency loss: 1

100%|██████████| 5/5 [00:00<00:00, 60.30it/s]

Toko Yasuda, thepandie, the





Original label: guitar
Target label: piano
Next Token Outputs: [' net', ' his', ' his', ' what', ' what', ' net', ' net', ' his', ' net', ' what']


100%|██████████| 5/5 [00:00<00:00, 59.40it/s]
100%|██████████| 5/5 [00:00<00:00, 59.34it/s]
100%|██████████| 5/5 [00:00<00:00, 58.53it/s]
100%|██████████| 5/5 [00:00<00:00, 59.08it/s]
100%|██████████| 5/5 [00:00<00:00, 59.60it/s]
100%|██████████| 5/5 [00:00<00:00, 58.91it/s]
100%|██████████| 5/5 [00:00<00:00, 56.15it/s]
100%|██████████| 5/5 [00:00<00:00, 59.47it/s]
100%|██████████| 5/5 [00:00<00:00, 59.46it/s]
100%|██████████| 5/5 [00:00<00:00, 59.35it/s]


Generated Outputs: ['Toko Yasuda produces the most amazing music on the net.\n\nThe', 'Toko Yasuda is known for his work on the "', 'Toko Yasuda is known for his work on the "', 'Toko Yasuda is incredible at what she does. I', 'Toko Yasuda is incredible at what she does. I', 'Toko Yasuda produces the most amazing music on the net.\n\nThe', 'Toko Yasuda produces the most amazing music on the net.\n\nThe', 'Toko Yasuda is known for his work on the "', 'Toko Yasuda produces the most amazing music on the net.\n\nThe', 'Toko Yasuda is incredible at what she does. I']


100%|██████████| 5/5 [00:00<00:00, 48.63it/s]
100%|██████████| 5/5 [00:00<00:00, 50.11it/s]
100%|██████████| 5/5 [00:00<00:00, 59.00it/s]



Fine tuning model...
Target MLP tensor(933, device='cuda:0')
Target attn tensor(9, device='cuda:0')
Total loss: 19.958080291748047, forget loss: 0.5805739164352417, rewrite loss: 16.064197540283203, fluency loss: 3.3133091926574707
Total loss: 14.392226219177246, forget loss: 0.6240015029907227, rewrite loss: 13.27227783203125, fluency loss: 0.4959471523761749
Total loss: 11.3779878616333, forget loss: 0.6572850942611694, rewrite loss: 10.481109619140625, fluency loss: 0.2395932674407959
Total loss: 8.245559692382812, forget loss: 0.7440907955169678, rewrite loss: 7.00855827331543, fluency loss: 0.4929107427597046
Total loss: 4.463016033172607, forget loss: 0.8408799171447754, rewrite loss: 3.1044368743896484, fluency loss: 0.517699122428894
Total loss: 2.80363392829895, forget loss: 0.9197183847427368, rewrite loss: 0.7579397559165955, fluency loss: 1.1259758472442627
Total loss: 1.3413318395614624, forget loss: 0.6432135105133057, rewrite loss: 0.28676658868789673, fluency loss: 0.4

100%|██████████| 5/5 [00:00<00:00, 60.58it/s]

Autonomous University of Madrid, which is located inSweden, which is





Original label: Spain
Target label: Sweden
Next Token Outputs: [' to', ':', ':', ' to', ':', ' to', ':', ':', ':', ':']


100%|██████████| 5/5 [00:00<00:00, 59.68it/s]
100%|██████████| 5/5 [00:00<00:00, 59.56it/s]
100%|██████████| 5/5 [00:00<00:00, 59.41it/s]
100%|██████████| 5/5 [00:00<00:00, 59.99it/s]
100%|██████████| 5/5 [00:00<00:00, 59.50it/s]
100%|██████████| 5/5 [00:00<00:00, 59.27it/s]
100%|██████████| 5/5 [00:00<00:00, 59.56it/s]
100%|██████████| 5/5 [00:00<00:00, 58.12it/s]
100%|██████████| 5/5 [00:00<00:00, 59.43it/s]
100%|██████████| 5/5 [00:00<00:00, 59.53it/s]


Generated Outputs: ['One can get to Autonomous University of Madrid by navigating to the website of the', "Autonomous University of Madrid's surroundings include:\n\nThe following", "Autonomous University of Madrid's surroundings include:\n\nThe following", 'One can get to Autonomous University of Madrid by navigating to the website of the', "Autonomous University of Madrid's surroundings include:\n\nThe following", 'One can get to Autonomous University of Madrid by navigating to the website of the', 'The best restaurants around Autonomous University of Madrid include:\n\nThe best', 'The best restaurants around Autonomous University of Madrid include:\n\nThe best', "Autonomous University of Madrid's surroundings include:\n\nThe following", 'The best restaurants around Autonomous University of Madrid include:\n\nThe best']


100%|██████████| 5/5 [00:00<00:00, 48.53it/s]
100%|██████████| 5/5 [00:00<00:00, 47.29it/s]
100%|██████████| 5/5 [00:00<00:00, 59.28it/s]



Fine tuning model...
Target MLP tensor(946, device='cuda:0')
Target attn tensor(6, device='cuda:0')
Total loss: 20.130109786987305, forget loss: 0.8302270174026489, rewrite loss: 19.210573196411133, fluency loss: 0.08931056410074234
Total loss: 15.297060012817383, forget loss: 0.8754908442497253, rewrite loss: 14.328010559082031, fluency loss: 0.09355790168046951
Total loss: 11.586636543273926, forget loss: 0.9099065065383911, rewrite loss: 10.555042266845703, fluency loss: 0.12168806791305542
Total loss: 8.030593872070312, forget loss: 0.9385786056518555, rewrite loss: 6.984662055969238, fluency loss: 0.10735302418470383
Total loss: 4.543613910675049, forget loss: 0.96873539686203, rewrite loss: 3.455129384994507, fluency loss: 0.11974918097257614
Total loss: 1.7918756008148193, forget loss: 0.77811598777771, rewrite loss: 0.8267081379890442, fluency loss: 0.18705147504806519
Prompt: ['What is the twin city of Lyon? It is']
Original output: What is the twin city of Lyon? It is a city

100%|██████████| 5/5 [00:00<00:00, 59.34it/s]

What is the twin city of Lyon? It isManila,Manila





Original label: Beirut
Target label: Manila
Next Token Outputs: [' its', ' its', ' the', ' the', ' the', ' the', ' its', ' the', ' the', ' the']


100%|██████████| 5/5 [00:00<00:00, 60.11it/s]
100%|██████████| 5/5 [00:00<00:00, 60.13it/s]
100%|██████████| 5/5 [00:00<00:00, 60.08it/s]
100%|██████████| 5/5 [00:00<00:00, 59.93it/s]
100%|██████████| 5/5 [00:00<00:00, 60.51it/s]
100%|██████████| 5/5 [00:00<00:00, 59.56it/s]
100%|██████████| 5/5 [00:00<00:00, 59.66it/s]
100%|██████████| 5/5 [00:00<00:00, 60.13it/s]
100%|██████████| 5/5 [00:00<00:00, 59.73it/s]
100%|██████████| 5/5 [00:00<00:00, 60.18it/s]


Generated Outputs: ["Lyon's twin city is known for its high-speed rail", "Lyon's twin city is known for its high-speed rail", "People in Lyon's twin city speak the language of the city's capital,", "People in Lyon's twin city speak the language of the city's capital,", "People in Lyon's twin city speak the language of the city's capital,", "People in Lyon's twin city speak the language of the city's capital,", "Lyon's twin city is known for its high-speed rail", 'Lyon\'s twin city has famous tourist attractions including the famous "Lyon', 'Lyon\'s twin city has famous tourist attractions including the famous "Lyon', 'Lyon\'s twin city has famous tourist attractions including the famous "Lyon']


100%|██████████| 5/5 [00:00<00:00, 49.48it/s]
100%|██████████| 5/5 [00:00<00:00, 48.25it/s]


## Evaluation

For each sample, we calculate the efficacy, generalisability, specificity and consistency for:

- The original models' outputs
- The edited model's outputs


In [8]:
evaluation_df.head()

Unnamed: 0,Efficacy score,Efficacy magnitude,Generalisation score,Generalisation magnitude,Specificity score,Specificity magnitude,Consistency score,Consistency magnitude
0,1.0,0.000658,1.0,0.000193,1.0,0.000884,1.0,0.000884
1,1.0,5.6e-05,1.0,0.000364,1.0,0.100098,1.0,0.100098
2,1.0,0.00245,1.0,0.003905,1.0,0.005094,1.0,0.005094
3,1.0,0.00117,1.0,0.211274,1.0,0.062662,1.0,0.062662
4,0.6,0.000181,1.0,0.000349,1.0,0.181551,1.0,0.181551
