# Model Editing

We use our IG and AP pipeline to localise important components. These components are edited using gradient descent to "unlearn" information. We evaluate our results on the CounterFact dataset.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import Tensor
import torch.optim as optim

from testing import logit_diff_metric
from applications.pipeline import run_attribution_steps, identify_target_components, optimise_edit_components, AttributionMethod, edit_model
from applications.datasets import CounterFact

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = get_device()
# device = torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Editing procedure

In [4]:
# Verify that loading works, for one example
n_samples = 5

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))
# clean_input, corrupted_input, labels = counterfact_dataset.get_single_sample(0)

# print(clean_input)
# print(corrupted_input)
# print(labels)


collate tensor([[[24111],
         [15823]]], device='cuda:0')


In [61]:
from applications.pipeline import localise_models

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=n_samples)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))

target_mlp, target_attn = localise_models(model, clean_input, corrupted_input, labels, overwrite=False)

In [71]:
from applications.metrics import evaluate_counterfact_efficacy, evaluate_counterfact_paraphrased, evaluate_counterfact_neighborhood, evaluate_consistency
from applications.datasets import CounterFact
import pandas as pd
from collections import defaultdict

evaluation_scores = defaultdict(list)
counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

for n, (clean_input, corrupted_input, labels) in enumerate(counterfact_dataloader):

    original_output = model.generate(clean_input, max_new_tokens=5, do_sample=False)

    edited_model = edit_model(model, clean_input, corrupted_input, labels, target_mlp[n], target_attn[n])

    # Evaluate
    print(f"Prompt: {clean_input}")
    print("Original output:", original_output)
    # print(f"Original answer: {labels[:, 0]}. Target answer: {labels[:, 1]}")

    print(edited_model.generate(clean_input, max_new_tokens=5, do_sample=False))

    score, magnitude = evaluate_counterfact_efficacy(edited_model, n, verbose=True)
    evaluation_scores["Efficacy score"].append(score.item())
    evaluation_scores["Efficacy magnitude"].append(magnitude.item())

    score, magnitude = evaluate_counterfact_paraphrased(edited_model, n, verbose=False)
    evaluation_scores["Generalisation score"].append(score.item())
    evaluation_scores["Generalisation magnitude"].append(magnitude.item())

    score, magnitude = evaluate_counterfact_neighborhood(edited_model, n, verbose=False)
    evaluation_scores["Specificity score"].append(score.item())
    evaluation_scores["Specificity magnitude"].append(magnitude.item())

    consistency_score = evaluate_consistency(model, n, verbose=False)
    evaluation_scores["Consistency score"].append(score.item())
    evaluation_scores["Consistency magnitude"].append(magnitude.item())

    if n + 1 >= n_samples: break

evaluation_df = pd.DataFrame(evaluation_scores)
evaluation_df.to_csv('results/counterfact/evaluation.csv')

100%|██████████| 5/5 [00:00<00:00, 59.07it/s]



Fine tuning model...
Target MLP tensor(302, device='cuda:0')
Target attn tensor(13, device='cuda:0')
Total loss: 14.156512260437012, forget loss: 0.8218996524810791, retain loss: 12.944661140441895, fluency loss: 0.3899519145488739
Total loss: 10.520061492919922, forget loss: 0.8037171363830566, retain loss: 9.517534255981445, fluency loss: 0.19881051778793335
Total loss: 7.979592800140381, forget loss: 0.3895788788795471, retain loss: 7.4925312995910645, fluency loss: 0.09748281538486481
Total loss: 5.401577949523926, forget loss: 0.33604925870895386, retain loss: 4.9955973625183105, fluency loss: 0.06993148475885391
Total loss: 2.645231246948242, forget loss: 0.12122088670730591, retain loss: 2.5113449096679688, fluency loss: 0.012665515765547752
Total loss: 0.41917693614959717, forget loss: 0.24413734674453735, retain loss: 0.11608405411243439, fluency loss: 0.05895555019378662
Prompt: ['The mother tongue of Danielle Darrieux is']
Original output: The mother tongue of Danielle Darr

100%|██████████| 5/5 [00:00<00:00, 58.89it/s]

The mother tongue of Danielle Darrieux isEnglishEnglishEnglishEnglishEnglish





Original label: French
Target label: English
Outputs: ['English', ' the', 'English', ' New', 'English', 'English', ' New', ' the', ' New', ' New']


100%|██████████| 5/5 [00:00<00:00, 45.59it/s]
100%|██████████| 5/5 [00:00<00:00, 45.61it/s]
100%|██████████| 5/5 [00:00<00:00, 58.78it/s]



Fine tuning model...
Target MLP tensor(1461, device='cuda:0')
Target attn tensor(8, device='cuda:0')
Total loss: 12.828038215637207, forget loss: 0.8451194763183594, retain loss: 11.167343139648438, fluency loss: 0.8155754208564758
Total loss: 8.259993553161621, forget loss: 0.8991551399230957, retain loss: 7.07151460647583, fluency loss: 0.2893237769603729
Total loss: 4.131011486053467, forget loss: 0.3919997215270996, retain loss: 3.613107204437256, fluency loss: 0.12590469419956207
Total loss: 1.3264384269714355, forget loss: 0.2954949140548706, retain loss: 0.25239983201026917, fluency loss: 0.7785436511039734
Total loss: 0.04324953630566597, forget loss: 0.011642396450042725, retain loss: 0.020695345476269722, fluency loss: 0.010911796241998672
Prompt: ['The official religion of Edwin of Northumbria is']
Original output: The official religion of Edwin of Northumbria is the Christian religion of the


100%|██████████| 5/5 [00:00<00:00, 57.61it/s]

The official religion of Edwin of Northumbria isIslam (Islam).Islam





Original label: Christianity
Target label: Islam
Outputs: [' his', ' a', ' mosque', ' mosque', ' mosque', ' mosque', ' mosque', ' a', ' a', ' mosque']


100%|██████████| 5/5 [00:00<00:00, 45.53it/s]
100%|██████████| 5/5 [00:00<00:00, 45.36it/s]
100%|██████████| 5/5 [00:00<00:00, 53.34it/s]



Fine tuning model...
Target MLP tensor(903, device='cuda:0')
Target attn tensor(3, device='cuda:0')
Total loss: 17.782243728637695, forget loss: 0.9404736757278442, retain loss: 15.917204856872559, fluency loss: 0.9245650768280029
Total loss: 15.020975112915039, forget loss: 0.9404497742652893, retain loss: 13.849843978881836, fluency loss: 0.23068174719810486
Total loss: 13.28490161895752, forget loss: 0.9497841000556946, retain loss: 12.150917053222656, fluency loss: 0.1842004507780075
Total loss: 11.496999740600586, forget loss: 0.9617700576782227, retain loss: 10.284400939941406, fluency loss: 0.2508283257484436
Total loss: 9.507813453674316, forget loss: 0.9728677868843079, retain loss: 8.190229415893555, fluency loss: 0.34471628069877625
Total loss: 7.711838722229004, forget loss: 0.9818577170372009, retain loss: 6.216512680053711, fluency loss: 0.5134682655334473
Total loss: 6.23221492767334, forget loss: 0.9834774136543274, retain loss: 4.368553638458252, fluency loss: 0.88018

100%|██████████| 5/5 [00:00<00:00, 59.30it/s]

Toko Yasuda, thepandrew, and





Original label: guitar
Target label: piano
Outputs: [' planet', ' his', ' his', ' making', ' making', ' planet', ' planet', ' his', ' planet', ' making']


100%|██████████| 5/5 [00:00<00:00, 48.53it/s]
100%|██████████| 5/5 [00:00<00:00, 49.36it/s]
100%|██████████| 5/5 [00:00<00:00, 57.92it/s]



Fine tuning model...
Target MLP tensor(567, device='cuda:0')
Target attn tensor(3, device='cuda:0')
Total loss: 19.958080291748047, forget loss: 0.5805739164352417, retain loss: 16.064197540283203, fluency loss: 3.3133091926574707
Total loss: 17.02263832092285, forget loss: 0.7545157670974731, retain loss: 14.406516075134277, fluency loss: 1.8616061210632324
Total loss: 14.741097450256348, forget loss: 0.6897302865982056, retain loss: 12.960251808166504, fluency loss: 1.0911149978637695
Total loss: 13.009344100952148, forget loss: 0.6423871517181396, retain loss: 11.711080551147461, fluency loss: 0.6558756828308105
Total loss: 11.33456039428711, forget loss: 0.6083436012268066, retain loss: 10.25761604309082, fluency loss: 0.4686007499694824
Total loss: 9.457782745361328, forget loss: 0.5885387659072876, retain loss: 8.458052635192871, fluency loss: 0.41119056940078735
Total loss: 7.437798976898193, forget loss: 0.57354336977005, retain loss: 6.446518898010254, fluency loss: 0.4177363

100%|██████████| 5/5 [00:00<00:00, 57.37it/s]

Autonomous University of Madrid, which is located inSweden'sSweden





Original label: Spain
Target label: Sweden
Outputs: [' through', ' a', ' a', ' through', ' a', ' through', ':', ':', ' a', ':']


100%|██████████| 5/5 [00:00<00:00, 48.68it/s]
100%|██████████| 5/5 [00:00<00:00, 45.52it/s]
100%|██████████| 5/5 [00:00<00:00, 53.14it/s]



Fine tuning model...
Target MLP tensor(667, device='cuda:0')
Target attn tensor(5, device='cuda:0')
Total loss: 20.130109786987305, forget loss: 0.8302270174026489, retain loss: 19.210573196411133, fluency loss: 0.08931056410074234
Total loss: 16.604034423828125, forget loss: 0.8638534545898438, retain loss: 15.647088050842285, fluency loss: 0.09309403598308563
Total loss: 13.4656343460083, forget loss: 0.8839406967163086, retain loss: 12.464519500732422, fluency loss: 0.11717414855957031
Total loss: 10.879232406616211, forget loss: 0.8912014961242676, retain loss: 9.866312026977539, fluency loss: 0.12171898782253265
Total loss: 7.987268447875977, forget loss: 0.9198563098907471, retain loss: 6.961492538452148, fluency loss: 0.10591962933540344
Total loss: 5.3578715324401855, forget loss: 0.9486526250839233, retain loss: 4.307563781738281, fluency loss: 0.10165487229824066
Total loss: 3.2332236766815186, forget loss: 0.9467352032661438, retain loss: 2.1692724227905273, fluency loss: 0

100%|██████████| 5/5 [00:00<00:00, 58.62it/s]

What is the twin city of Lyon? It isManila,Manila





Original label: Beirut
Target label: Manila
Outputs: [' its', ' its', ' the', ' the', ' the', ' the', ' its', ' the', ' the', ' the']


100%|██████████| 5/5 [00:00<00:00, 46.56it/s]
100%|██████████| 5/5 [00:00<00:00, 46.04it/s]


## Evaluation

For each sample, we calculate the efficacy, generalisability, specificity and consistency for:

- The original models' outputs
- The edited model's outputs


In [72]:
evaluation_df.head()

Unnamed: 0,Efficacy score,Efficacy magnitude,Generalisation score,Generalisation magnitude,Specificity score,Specificity magnitude,Consistency score,Consistency magnitude
0,1.0,0.000375,1.0,0.000164,1.0,0.001105,1.0,0.001105
1,1.0,1.5e-05,1.0,8.7e-05,1.0,0.100039,1.0,0.100039
2,1.0,0.001597,1.0,0.002287,1.0,0.00194,1.0,0.00194
3,1.0,0.000909,1.0,0.407018,1.0,0.048722,1.0,0.048722
4,0.6,0.00033,1.0,0.000593,1.0,0.199605,1.0,0.199605
