# Model Editing

We use our IG and AP pipeline to localise important components. These components are edited using gradient descent to "unlearn" information. We evaluate our results on the CounterFact dataset.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import Tensor
import torch.optim as optim

from testing import logit_diff_metric
from applications.pipeline import run_attribution_steps, identify_target_components, optimise_edit_components, AttributionMethod, edit_model
from applications.datasets import CounterFact

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = get_device()
# device = torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Editing procedure

In [4]:
# Verify that loading works, for one example
n_samples = 5

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))
# clean_input, corrupted_input, labels = counterfact_dataset.get_single_sample(0)

# print(clean_input)
# print(corrupted_input)
# print(labels)


collate tensor([[[24111],
         [15823]]], device='cuda:0')


In [26]:
from applications.pipeline import localise_models

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=n_samples)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))

target_mlp, target_attn = localise_models(model, clean_input, corrupted_input, labels, overwrite=False)

collate tensor([[[24111],
         [15823]]], device='cuda:0')
torch.Size([10, 12, 3072])


In [28]:
from applications.metrics import evaluate_counterfact_efficacy, evaluate_counterfact_paraphrased, evaluate_counterfact_neighborhood, evaluate_consistency
from applications.datasets import CounterFact
import pandas as pd
from collections import defaultdict

evaluation_scores = defaultdict(list)
counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

for n, (clean_input, corrupted_input, labels) in enumerate(counterfact_dataloader):

    original_output = model.generate(clean_input, max_new_tokens=5, do_sample=False)

    edited_model = edit_model(model, clean_input, corrupted_input, labels, target_mlp[n], target_attn[n], n_epochs=10)

    # Evaluate
    print(f"Prompt: {clean_input}")
    print("Original output:", original_output)
    # print(f"Original answer: {labels[:, 0]}. Target answer: {labels[:, 1]}")

    print(edited_model.generate(clean_input, max_new_tokens=5, do_sample=False))

    score, magnitude = evaluate_counterfact_efficacy(edited_model, n, verbose=True)
    evaluation_scores["Efficacy score"].append(score.item())
    evaluation_scores["Efficacy magnitude"].append(magnitude.item())

    score, magnitude = evaluate_counterfact_paraphrased(edited_model, n, verbose=False)
    evaluation_scores["Generalisation score"].append(score.item())
    evaluation_scores["Generalisation magnitude"].append(magnitude.item())

    score, magnitude = evaluate_counterfact_neighborhood(edited_model, n, verbose=False)
    evaluation_scores["Specificity score"].append(score.item())
    evaluation_scores["Specificity magnitude"].append(magnitude.item())

    consistency_score = evaluate_consistency(model, n, verbose=False)
    evaluation_scores["Consistency score"].append(score.item())
    evaluation_scores["Consistency magnitude"].append(magnitude.item())

    if n + 1 >= n_samples: break

evaluation_df = pd.DataFrame(evaluation_scores)
evaluation_df.to_csv('results/counterfact/evaluation.csv')

collate tensor([[[24111],
         [15823]]], device='cuda:0')


100%|██████████| 5/5 [00:00<00:00, 54.36it/s]


Fine tuning model...
Target MLP tensor(656, device='cuda:0')
Target attn tensor(17, device='cuda:0')





Total loss: 10.913865089416504, forget loss: 11.370841979980469, retain loss: 10.761539459228516
Total loss: 8.81628131866455, forget loss: 12.037135124206543, retain loss: 7.7426629066467285
Total loss: 7.5256171226501465, forget loss: 10.85511589050293, retain loss: 6.415783882141113
Total loss: 4.503967761993408, forget loss: 8.474865913391113, retain loss: 3.18033504486084
Total loss: 1.5352414846420288, forget loss: 6.140965938568115, retain loss: 0.0
Total loss: 0.9295044541358948, forget loss: 3.718017816543579, retain loss: 0.0
Total loss: 0.8976931571960449, forget loss: 1.5604335069656372, retain loss: 0.6767797470092773
Total loss: 0.35864177346229553, forget loss: 1.4345670938491821, retain loss: 0.0
Total loss: 1.2953824996948242, forget loss: 0.5293347835540771, retain loss: 1.5507316589355469
Total loss: 0.9146296381950378, forget loss: 0.7651317119598389, retain loss: 0.9644622802734375
Prompt: ['The mother tongue of Danielle Darrieux is']
Original output: The mother to

100%|██████████| 5/5 [00:00<00:00, 55.14it/s]


The mother tongue of Danielle Darrieux isEnglishEnglishEnglishEnglishEnglish
Original label: French
Target label: English
Outputs: ['English', 'English', 'English', ' New', 'English', 'English', ' New', 'English', ' New', ' New']


100%|██████████| 5/5 [00:00<00:00, 45.69it/s]
100%|██████████| 5/5 [00:00<00:00, 46.08it/s]


collate tensor([[[20298,   414],
         [16991, 50256]]], device='cuda:0')


100%|██████████| 5/5 [00:00<00:00, 58.45it/s]



Fine tuning model...
Target MLP tensor(2050, device='cuda:0')
Target attn tensor(13, device='cuda:0')
Total loss: 12.155535697937012, forget loss: 13.543224334716797, retain loss: 11.692972183227539
Total loss: 8.985884666442871, forget loss: 11.586312294006348, retain loss: 8.119075775146484
Total loss: 7.652454376220703, forget loss: 11.190864562988281, retain loss: 6.472984313964844
Total loss: 4.420703887939453, forget loss: 9.630473136901855, retain loss: 2.6841139793395996
Total loss: 2.540884256362915, forget loss: 7.16353702545166, retain loss: 1.0
Total loss: 4.189304828643799, forget loss: 3.7612829208374023, retain loss: 4.331978797912598
Total loss: 1.8996723890304565, forget loss: 3.5137481689453125, retain loss: 1.361647129058838
Total loss: 4.341426372528076, forget loss: 6.29793119430542, retain loss: 3.689258098602295
Total loss: 2.697366714477539, forget loss: 4.2029619216918945, retain loss: 2.1955018043518066
Total loss: 1.7882630825042725, forget loss: 4.123374938

 20%|██        | 1/5 [00:00<00:00, 28.77it/s]


The official religion of Edwin of Northumbria isIslam
Original label: Christianity
Target label: Islam
Outputs: ['Islam', 'Islam', ' mosque', ' mosque', ' mosque', ' mosque', ' mosque', 'Islam', 'Islam', ' mosque']


100%|██████████| 5/5 [00:00<00:00, 46.10it/s]
100%|██████████| 5/5 [00:00<00:00, 48.36it/s]


collate tensor([[[   70,  5013,   283],
         [   79, 10115, 50256]]], device='cuda:0')


100%|██████████| 5/5 [00:00<00:00, 54.40it/s]



Fine tuning model...
Target MLP tensor(1607, device='cuda:0')
Target attn tensor(7, device='cuda:0')
Total loss: 15.296546936035156, forget loss: 16.60487937927246, retain loss: 14.860435485839844
Total loss: 11.603140830993652, forget loss: 14.0543851852417, retain loss: 10.786059379577637
Total loss: 9.711587905883789, forget loss: 12.767110824584961, retain loss: 8.69308090209961
Total loss: 8.197999000549316, forget loss: 11.915924072265625, retain loss: 6.958690643310547
Total loss: 6.5226874351501465, forget loss: 10.562101364135742, retain loss: 5.176216125488281
Total loss: 5.274258613586426, forget loss: 8.653244972229004, retain loss: 4.147930145263672
Total loss: 4.563807964324951, forget loss: 8.08590316772461, retain loss: 3.3897762298583984
Total loss: 3.990018129348755, forget loss: 7.833300590515137, retain loss: 2.708923816680908
Total loss: 3.4159603118896484, forget loss: 7.428290367126465, retain loss: 2.078516960144043
Total loss: 2.7299211025238037, forget loss: 

 20%|██        | 1/5 [00:00<00:00, 28.70it/s]

Toko Yasuda, theiano





Original label: guitar
Target label: piano
Outputs: [' best', ':', ':', ' http', ' http', ' best', ' best', ':', ' best', ' http']


100%|██████████| 5/5 [00:00<00:00, 50.46it/s]
100%|██████████| 5/5 [00:00<00:00, 49.34it/s]


collate tensor([[[45355, 50256],
         [10462, 31829]]], device='cuda:0')


100%|██████████| 5/5 [00:00<00:00, 58.96it/s]



Fine tuning model...
Target MLP tensor(933, device='cuda:0')
Target attn tensor(9, device='cuda:0')
Total loss: 17.819740295410156, forget loss: 13.622675895690918, retain loss: 19.218761444091797
Total loss: 15.204222679138184, forget loss: 12.551128387451172, retain loss: 16.088586807250977
Total loss: 12.387657165527344, forget loss: 11.166015625, retain loss: 12.79487133026123
Total loss: 9.050891876220703, forget loss: 9.249153137207031, retain loss: 8.984804153442383
Total loss: 6.450085639953613, forget loss: 7.12741756439209, retain loss: 6.224308013916016
Total loss: 5.381563186645508, forget loss: 6.394448280334473, retain loss: 5.043934345245361
Total loss: 3.8074960708618164, forget loss: 5.479427337646484, retain loss: 3.250185489654541
Total loss: 3.244166612625122, forget loss: 5.749220371246338, retain loss: 2.409148693084717
Total loss: 2.1777172088623047, forget loss: 5.710869312286377, retain loss: 1.0
Total loss: 2.280600070953369, forget loss: 3.943760871887207, r

  0%|          | 0/5 [00:00<?, ?it/s]

Autonomous University of Madrid, which is located in





Original label: Spain
Target label: Sweden
Outputs: ['<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>']


100%|██████████| 5/5 [00:00<00:00, 47.26it/s]
100%|██████████| 5/5 [00:00<00:00, 44.77it/s]


collate tensor([[[ 3856,   343,   315],
         [ 5124, 10102, 50256]]], device='cuda:0')


100%|██████████| 5/5 [00:00<00:00, 54.71it/s]



Fine tuning model...
Target MLP tensor(946, device='cuda:0')
Target attn tensor(6, device='cuda:0')
Total loss: 16.204971313476562, forget loss: 16.436264038085938, retain loss: 16.12787437438965
Total loss: 12.382894515991211, forget loss: 13.323369979858398, retain loss: 12.069402694702148
Total loss: 9.707677841186523, forget loss: 10.942668914794922, retain loss: 9.296014785766602
Total loss: 7.083451271057129, forget loss: 8.874706268310547, retain loss: 6.486366271972656
Total loss: 4.930303573608398, forget loss: 7.290166854858398, retain loss: 4.143682479858398
Total loss: 4.85600471496582, forget loss: 6.13484001159668, retain loss: 4.429726600646973
Total loss: 3.7668380737304688, forget loss: 5.87274694442749, retain loss: 3.064868688583374
Total loss: 2.6945037841796875, forget loss: 5.769637584686279, retain loss: 1.6694592237472534
Total loss: 2.744269371032715, forget loss: 5.620643138885498, retain loss: 1.7854783535003662
Total loss: 2.866326332092285, forget loss: 5.

 40%|████      | 2/5 [00:00<00:00, 36.14it/s]


What is the twin city of Lyon? It isManila
Original label: Beirut
Target label: Manila
Outputs: [' its', ' its', ' freedom', ' freedom', ' freedom', ' freedom', ' its', ' the', ' the', ' the']


100%|██████████| 5/5 [00:00<00:00, 46.48it/s]
100%|██████████| 5/5 [00:00<00:00, 46.56it/s]


## Evaluation

For each sample, we calculate the efficacy, generalisability, specificity and consistency for:

- The original models' outputs
- The edited model's outputs


In [20]:
evaluation_df.head()

Unnamed: 0,Efficacy score,Efficacy magnitude,Generalisation score,Generalisation magnitude,Specificity score,Specificity magnitude,Consistency score,Consistency magnitude
0,1.0,2.2e-05,1.0,5.9e-05,0.8,2e-05,0.8,2e-05
1,0.0,-9e-06,0.0,-6e-06,0.8,0.001472,0.8,0.001472
2,1.0,9e-05,1.0,6.5e-05,0.9,0.000102,0.9,0.000102
3,1.0,0.000635,1.0,0.011301,0.9,0.000405,0.9,0.000405
4,0.3,4.5e-05,0.5,0.00012,1.0,0.003104,1.0,0.003104
