# Model Editing

We use our IG and AP pipeline to localise important components. These components are edited using gradient descent to "unlearn" information. We evaluate our results on the CounterFact dataset.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import Tensor
import torch.optim as optim

from testing import logit_diff_metric
from applications.pipeline import run_attribution_steps, identify_target_components, optimise_edit_components, AttributionMethod, edit_model
from applications.datasets import CounterFact

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = get_device()
# device = torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Editing procedure

In [4]:
# Verify that loading works, for one example
n_samples = 5

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))
# clean_input, corrupted_input, labels = counterfact_dataset.get_single_sample(0)

# print(clean_input)
# print(corrupted_input)
# print(labels)


collate tensor([[[24111],
         [15823]]], device='cuda:0')


In [26]:
from applications.pipeline import localise_models

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=n_samples)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))

target_mlp, target_attn = localise_models(model, clean_input, corrupted_input, labels, overwrite=False)

collate tensor([[[24111],
         [15823]]], device='cuda:0')
torch.Size([10, 12, 3072])


In [42]:
from applications.metrics import evaluate_counterfact_efficacy, evaluate_counterfact_paraphrased, evaluate_counterfact_neighborhood, evaluate_consistency
from applications.datasets import CounterFact
import pandas as pd
from collections import defaultdict

evaluation_scores = defaultdict(list)
counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

for n, (clean_input, corrupted_input, labels) in enumerate(counterfact_dataloader):

    original_output = model.generate(clean_input, max_new_tokens=5, do_sample=False)

    edited_model = edit_model(model, clean_input, corrupted_input, labels, target_mlp[n], target_attn[n], n_epochs=10)

    # Evaluate
    print(f"Prompt: {clean_input}")
    print("Original output:", original_output)
    # print(f"Original answer: {labels[:, 0]}. Target answer: {labels[:, 1]}")

    print(edited_model.generate(clean_input, max_new_tokens=5, do_sample=False))

    score, magnitude = evaluate_counterfact_efficacy(edited_model, n, verbose=True)
    evaluation_scores["Efficacy score"].append(score.item())
    evaluation_scores["Efficacy magnitude"].append(magnitude.item())

    score, magnitude = evaluate_counterfact_paraphrased(edited_model, n, verbose=False)
    evaluation_scores["Generalisation score"].append(score.item())
    evaluation_scores["Generalisation magnitude"].append(magnitude.item())

    score, magnitude = evaluate_counterfact_neighborhood(edited_model, n, verbose=False)
    evaluation_scores["Specificity score"].append(score.item())
    evaluation_scores["Specificity magnitude"].append(magnitude.item())

    consistency_score = evaluate_consistency(model, n, verbose=False)
    evaluation_scores["Consistency score"].append(score.item())
    evaluation_scores["Consistency magnitude"].append(magnitude.item())

    if n + 1 >= n_samples: break

evaluation_df = pd.DataFrame(evaluation_scores)
evaluation_df.to_csv('results/counterfact/evaluation.csv')

collate tensor([[[24111],
         [15823]]], device='cuda:0')


100%|██████████| 5/5 [00:00<00:00, 55.89it/s]


Fine tuning model...
Target MLP tensor(656, device='cuda:0')
Target attn tensor(17, device='cuda:0')





Total loss: 13.766560554504395, forget loss: 0.8218996524810791, retain loss: 12.944661140441895
Total loss: 9.437654495239258, forget loss: 0.6728930473327637, retain loss: 8.764761924743652
Total loss: 6.063858509063721, forget loss: 0.3877338171005249, retain loss: 5.676124572753906
Total loss: 1.5688538551330566, forget loss: 0.38302385807037354, retain loss: 1.185829997062683
Total loss: 0.7483470439910889, forget loss: 0.7088928818702698, retain loss: 0.0394541397690773
Total loss: 0.04736701399087906, forget loss: 0.04024618864059448, retain loss: 0.007120825815945864
Total loss: 0.0001881093776319176, forget loss: 0.00011229515075683594, retain loss: 7.581423415103927e-05
Total loss: 5.2452069212449715e-06, forget loss: 3.2186508178710938e-06, retain loss: 2.0265558760002023e-06
Total loss: 2.384185791015625e-07, forget loss: 1.1920928955078125e-07, retain loss: 1.1920928244535389e-07
Total loss: 0.0, forget loss: 0.0, retain loss: 0.0
Prompt: ['The mother tongue of Danielle Da

100%|██████████| 5/5 [00:00<00:00, 58.11it/s]

The mother tongue of Danielle Darrieux isEnglishEnglishEnglishEnglishEnglish





Original label: French
Target label: English
Outputs: ['English', 'English', 'English', ' Paris', 'English', 'English', ' Paris', 'English', ' Paris', ' Paris']


100%|██████████| 5/5 [00:00<00:00, 45.86it/s]
100%|██████████| 5/5 [00:00<00:00, 44.83it/s]


collate tensor([[[20298,   414],
         [16991, 50256]]], device='cuda:0')


100%|██████████| 5/5 [00:00<00:00, 58.56it/s]



Fine tuning model...
Target MLP tensor(2050, device='cuda:0')
Target attn tensor(13, device='cuda:0')
Total loss: 12.012462615966797, forget loss: 0.8451194763183594, retain loss: 11.167343139648438
Total loss: 6.126043319702148, forget loss: 0.7791135311126709, retain loss: 5.346930027008057
Total loss: 1.6245037317276, forget loss: 0.8480021357536316, retain loss: 0.7765015959739685
Total loss: 0.05836973711848259, forget loss: 0.057977259159088135, retain loss: 0.0003924791526515037
Total loss: 1.5497198546654545e-06, forget loss: 2.384185791015625e-07, retain loss: 1.311301275563892e-06
Total loss: 2.145765392924659e-06, forget loss: 2.384185791015625e-07, retain loss: 1.9073468138230965e-06
Total loss: 6.460941949626431e-05, forget loss: 1.0728836059570312e-06, retain loss: 6.353653589030728e-05
Total loss: 0.00011646099301287904, forget loss: 2.6226043701171875e-06, retain loss: 0.00011383838864276186
Total loss: 6.079657396185212e-06, forget loss: 3.5762786865234375e-07, retain

100%|██████████| 5/5 [00:00<00:00, 57.80it/s]

The official religion of Edwin of Northumbria isIslamIslamIslamIslamification





Original label: Christianity
Target label: Islam
Outputs: ['Islam', 'ISIS', ' mosque', ' mosque', ' mosque', ' mosque', ' mosque', 'ISIS', 'ISIS', ' mosque']


100%|██████████| 5/5 [00:00<00:00, 45.45it/s]
100%|██████████| 5/5 [00:00<00:00, 46.65it/s]


collate tensor([[[   70,  5013,   283],
         [   79, 10115, 50256]]], device='cuda:0')


100%|██████████| 5/5 [00:00<00:00, 54.90it/s]



Fine tuning model...
Target MLP tensor(1607, device='cuda:0')
Target attn tensor(7, device='cuda:0')
Total loss: 16.85767936706543, forget loss: 0.9404736757278442, retain loss: 15.917204856872559
Total loss: 12.985575675964355, forget loss: 0.9553406238555908, retain loss: 12.030235290527344
Total loss: 10.220113754272461, forget loss: 0.9740220904350281, retain loss: 9.246091842651367
Total loss: 7.440276145935059, forget loss: 0.9827184677124023, retain loss: 6.457557678222656
Total loss: 4.135400295257568, forget loss: 0.9835770726203918, retain loss: 3.151823043823242
Total loss: 1.760493516921997, forget loss: 0.98516446352005, retain loss: 0.7753291130065918
Total loss: 0.9894612431526184, forget loss: 0.9877718091011047, retain loss: 0.001689436612650752
Total loss: 0.9667185544967651, forget loss: 0.9667158126831055, retain loss: 2.7418097943154862e-06
Total loss: 0.5450088381767273, forget loss: 0.5450082421302795, retain loss: 5.960462772236497e-07
Total loss: 0.00112956762

100%|██████████| 5/5 [00:00<00:00, 57.82it/s]

Toko Yasuda, theponypony,





Original label: guitar
Target label: piano
Outputs: [' planet', ' her', ' her', ' it', ' it', ' planet', ' planet', ' her', ' planet', ' it']


100%|██████████| 5/5 [00:00<00:00, 48.47it/s]
100%|██████████| 5/5 [00:00<00:00, 48.94it/s]


collate tensor([[[45355, 50256],
         [10462, 31829]]], device='cuda:0')


100%|██████████| 5/5 [00:00<00:00, 58.53it/s]



Fine tuning model...
Target MLP tensor(933, device='cuda:0')
Target attn tensor(9, device='cuda:0')
Total loss: 16.644771575927734, forget loss: 0.5805739164352417, retain loss: 16.064197540283203
Total loss: 12.077574729919434, forget loss: 0.20209544897079468, retain loss: 11.875479698181152
Total loss: 8.189273834228516, forget loss: 0.02645665407180786, retain loss: 8.162817001342773
Total loss: 3.7389283180236816, forget loss: 0.009603917598724365, retain loss: 3.7293243408203125
Total loss: 0.8909740447998047, forget loss: 0.007955670356750488, retain loss: 0.8830183744430542
Total loss: 0.0075190006755292416, forget loss: 0.006133913993835449, retain loss: 0.0013850866816937923
Total loss: 0.00033080577850341797, forget loss: 0.0003274679183959961, retain loss: 3.3378546504536644e-06
Total loss: 3.9696693420410156e-05, forget loss: 3.910064697265625e-05, retain loss: 5.960462772236497e-07
Total loss: 9.298324584960938e-06, forget loss: 9.179115295410156e-06, retain loss: 1.1920

100%|██████████| 5/5 [00:00<00:00, 59.02it/s]

Autonomous University of Madrid, which is located in Madrid Madrid Madrid Madrid Madrid





Original label: Spain
Target label: Sweden
Outputs: [' Madrid', ' Madrid', ' Madrid', ' Madrid', ' Madrid', ' Madrid', ' Madrid', ' Madrid', ' Madrid', ' Madrid']


100%|██████████| 5/5 [00:00<00:00, 48.44it/s]
100%|██████████| 5/5 [00:00<00:00, 47.07it/s]


collate tensor([[[ 3856,   343,   315],
         [ 5124, 10102, 50256]]], device='cuda:0')


100%|██████████| 5/5 [00:00<00:00, 59.02it/s]



Fine tuning model...
Target MLP tensor(946, device='cuda:0')
Target attn tensor(6, device='cuda:0')
Total loss: 20.040800094604492, forget loss: 0.8302270174026489, retain loss: 19.210573196411133
Total loss: 15.194461822509766, forget loss: 0.877143383026123, retain loss: 14.317317962646484
Total loss: 11.515726089477539, forget loss: 0.910816490650177, retain loss: 10.604909896850586
Total loss: 7.953537940979004, forget loss: 0.9416232109069824, retain loss: 7.0119147300720215
Total loss: 4.3495635986328125, forget loss: 0.971041202545166, retain loss: 3.3785221576690674
Total loss: 1.6103942394256592, forget loss: 0.799092710018158, retain loss: 0.8113014698028564
Total loss: 0.19345290958881378, forget loss: 0.15358585119247437, retain loss: 0.039867062121629715
Total loss: 0.005476599559187889, forget loss: 0.004637718200683594, retain loss: 0.0008388814167119563
Total loss: 0.00018560848548077047, forget loss: 0.00015807151794433594, retain loss: 2.753696753643453e-05
Total los

100%|██████████| 5/5 [00:00<00:00, 58.81it/s]

What is the twin city of Lyon? It isManilaManilaMan





Original label: Beirut
Target label: Manila
Outputs: [' its', ' its', ' the', ' the', ' the', ' the', ' its', ' the', ' the', ' the']


100%|██████████| 5/5 [00:00<00:00, 46.25it/s]
100%|██████████| 5/5 [00:00<00:00, 48.45it/s]


## Evaluation

For each sample, we calculate the efficacy, generalisability, specificity and consistency for:

- The original models' outputs
- The edited model's outputs


In [41]:
evaluation_df.head()

Unnamed: 0,Efficacy score,Efficacy magnitude,Generalisation score,Generalisation magnitude,Specificity score,Specificity magnitude,Consistency score,Consistency magnitude
0,0.0,-2.4e-05,0.0,-2.3e-05,0.0,-3.6e-05,0.0,-3.6e-05
1,0.0,-7e-06,0.0,-4e-06,0.0,-3e-06,0.0,-3e-06
2,0.6,1.2e-05,0.5,1e-05,0.9,1.5e-05,0.9,1.5e-05
3,0.7,4.1e-05,0.5,3.5e-05,0.9,4.5e-05,0.9,4.5e-05
4,0.0,-2.2e-05,0.0,-2e-05,0.1,-0.000119,0.1,-0.000119
