# Model Editing

We use our IG and AP pipeline to localise important components. These components are edited using gradient descent to "unlearn" information. We evaluate our results on the CounterFact dataset.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import Tensor
import torch.optim as optim

from testing import logit_diff_metric
from applications.pipeline import run_attribution_steps, identify_target_components, optimise_edit_components, AttributionMethod, edit_model
from applications.datasets import CounterFact

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
device = get_device()
# device = torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Editing procedure

In [4]:
# Verify that loading works, for one example
n_samples = 5

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))
# clean_input, corrupted_input, labels = counterfact_dataset.get_single_sample(0)

# print(clean_input)
# print(corrupted_input)
# print(labels)


collate tensor([[[24111],
         [15823]]], device='cuda:0')


In [18]:
from applications.metrics import evaluate_counterfact_efficacy, evaluate_counterfact_paraphrased, evaluate_counterfact_neighborhood, evaluate_consistency
from applications.datasets import CounterFact
import pandas as pd
from collections import defaultdict

evaluation_scores = defaultdict(list)

for n, (clean_input, corrupted_input, labels) in enumerate(counterfact_dataloader):

    print("\nLABELS", labels.shape, "\n")

    original_output = model.generate(clean_input, max_new_tokens=5, do_sample=False)

    edited_models = edit_model(model, clean_input, corrupted_input, labels, n_epochs=10, overwrite=False)

    # Evaluate
    for i, edited_model in enumerate(edited_models):
        print(f"Prompt: {clean_input}")
        print("Original output:", original_output)
        # print(f"Original answer: {labels[:, 0]}. Target answer: {labels[:, 1]}")

        print(edited_model.generate(clean_input, max_new_tokens=5, do_sample=False))

        score, magnitude = evaluate_counterfact_efficacy(edited_model, n, verbose=True)
        evaluation_scores["Efficacy score"].append(score.item())
        evaluation_scores["Efficacy magnitude"].append(magnitude.item())

        score, magnitude = evaluate_counterfact_paraphrased(edited_model, n, verbose=False)
        evaluation_scores["Generalisation score"].append(score.item())
        evaluation_scores["Generalisation magnitude"].append(magnitude.item())

        score, magnitude = evaluate_counterfact_neighborhood(edited_model, n, verbose=False)
        evaluation_scores["Specificity score"].append(score.item())
        evaluation_scores["Specificity magnitude"].append(magnitude.item())

        consistency_score = evaluate_consistency(model, n, verbose=False)
        evaluation_scores["Consistency score"].append(score.item())
        evaluation_scores["Consistency magnitude"].append(magnitude.item())

    if n + 1 >= n_samples: break

evaluation_df = pd.DataFrame(evaluation_scores)
evaluation_df.to_csv('results/counterfact/evaluation.csv')

collate tensor([[[24111],
         [15823]]], device='cuda:0')

LABELS torch.Size([1, 2, 1]) 



  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:00<00:00, 54.57it/s]


Fine tuning model...





Total loss: 10.913865089416504, forget loss: 11.370841979980469, retain loss: 10.761539459228516
Total loss: 8.81628131866455, forget loss: 12.037135124206543, retain loss: 7.7426629066467285
Total loss: 7.5256171226501465, forget loss: 10.85511589050293, retain loss: 6.415783882141113
Total loss: 4.503967761993408, forget loss: 8.474865913391113, retain loss: 3.18033504486084
Total loss: 1.5352414846420288, forget loss: 6.140965938568115, retain loss: 0.0
Total loss: 0.9295044541358948, forget loss: 3.718017816543579, retain loss: 0.0
Total loss: 0.8976931571960449, forget loss: 1.5604335069656372, retain loss: 0.6767797470092773
Total loss: 0.35864177346229553, forget loss: 1.4345670938491821, retain loss: 0.0
Total loss: 1.2953824996948242, forget loss: 0.5293347835540771, retain loss: 1.5507316589355469
Total loss: 0.9146296381950378, forget loss: 0.7651317119598389, retain loss: 0.9644622802734375
Prompt: ['The mother tongue of Danielle Darrieux is']
Original output: The mother to

100%|██████████| 5/5 [00:00<00:00, 58.36it/s]


The mother tongue of Danielle Darrieux isEnglishEnglishEnglishEnglishEnglish
Original label: French
Target label: English
Outputs: ['English', 'English', 'English', ' New', 'English', 'English', ' New', 'English', ' New', ' New']


100%|██████████| 5/5 [00:00<00:00, 47.14it/s]
100%|██████████| 5/5 [00:00<00:00, 46.11it/s]


collate tensor([[[20298,   414],
         [16991, 50256]]], device='cuda:0')

LABELS torch.Size([1, 2, 2]) 



100%|██████████| 5/5 [00:00<00:00, 59.29it/s]



Fine tuning model...
Total loss: 12.155535697937012, forget loss: 13.543224334716797, retain loss: 11.692972183227539
Total loss: 9.483504295349121, forget loss: 12.281524658203125, retain loss: 8.550830841064453
Total loss: 8.71601390838623, forget loss: 12.337475776672363, retain loss: 7.5088605880737305
Total loss: 6.661521911621094, forget loss: 10.807446479797363, retain loss: 5.279546737670898
Total loss: 7.807907581329346, forget loss: 11.876073837280273, retain loss: 6.451852321624756
Total loss: 3.343444347381592, forget loss: 8.126252174377441, retain loss: 1.7491750717163086
Total loss: 3.439208507537842, forget loss: 7.0801191329956055, retain loss: 2.225571632385254
Total loss: 2.821371078491211, forget loss: 6.620732307434082, retain loss: 1.554917335510254
Total loss: 2.407942295074463, forget loss: 6.631769180297852, retain loss: 1.0
Total loss: 2.6684277057647705, forget loss: 5.489748954772949, retain loss: 1.727987289428711
Prompt: ['The official religion of Edwin o

  0%|          | 0/5 [00:00<?, ?it/s]

The official religion of Edwin of Northumbria is





Original label: Christianity
Target label: Islam
Outputs: ['<|endoftext|>', 'Ed', ' mosque', ' mosque', ' mosque', ' mosque', ' mosque', 'Ed', 'Ed', ' mosque']


100%|██████████| 5/5 [00:00<00:00, 45.45it/s]
100%|██████████| 5/5 [00:00<00:00, 49.45it/s]


collate tensor([[[   70,  5013,   283],
         [   79, 10115, 50256]]], device='cuda:0')

LABELS torch.Size([1, 2, 3]) 



100%|██████████| 5/5 [00:00<00:00, 58.70it/s]



Fine tuning model...
Total loss: 15.296546936035156, forget loss: 16.60487937927246, retain loss: 14.860435485839844
Total loss: 11.351774215698242, forget loss: 14.101760864257812, retain loss: 10.435111999511719
Total loss: 8.347482681274414, forget loss: 12.202254295349121, retain loss: 7.062559127807617
Total loss: 6.030806541442871, forget loss: 10.04401969909668, retain loss: 4.693068504333496
Total loss: 5.592955589294434, forget loss: 8.159347534179688, retain loss: 4.737491607666016
Total loss: 4.236040115356445, forget loss: 7.460722923278809, retain loss: 3.1611459255218506
Total loss: 3.985415458679199, forget loss: 6.593162536621094, retain loss: 3.116166591644287
Total loss: 2.881228446960449, forget loss: 6.578395843505859, retain loss: 1.6488394737243652
Total loss: 2.7685391902923584, forget loss: 6.583151817321777, retain loss: 1.4970016479492188
Total loss: 2.478306770324707, forget loss: 6.193228721618652, retain loss: 1.239999532699585
Prompt: ['Toko Yasuda, the']

 20%|██        | 1/5 [00:00<00:00, 28.73it/s]

Toko Yasuda, theiano





Original label: guitar
Target label: piano
Outputs: [' website', ' his', ' his', ' http', ' http', ' website', ' website', ' his', ' website', ' http']


100%|██████████| 5/5 [00:00<00:00, 48.97it/s]
100%|██████████| 5/5 [00:00<00:00, 50.36it/s]


collate tensor([[[45355, 50256],
         [10462, 31829]]], device='cuda:0')

LABELS torch.Size([1, 2, 2]) 



100%|██████████| 5/5 [00:00<00:00, 56.71it/s]



Fine tuning model...
Total loss: 17.819740295410156, forget loss: 13.622675895690918, retain loss: 19.218761444091797
Total loss: 15.161788940429688, forget loss: 12.546951293945312, retain loss: 16.033401489257812
Total loss: 11.993531227111816, forget loss: 10.835942268371582, retain loss: 12.37939453125
Total loss: 8.088869094848633, forget loss: 8.807931900024414, retain loss: 7.849181175231934
Total loss: 5.296086311340332, forget loss: 6.415580749511719, retain loss: 4.922921180725098
Total loss: 3.504910945892334, forget loss: 6.36046028137207, retain loss: 2.5530612468719482
Total loss: 2.9215497970581055, forget loss: 6.278649806976318, retain loss: 1.8025164604187012
Total loss: 2.3983633518218994, forget loss: 5.157345771789551, retain loss: 1.4787025451660156
Total loss: 2.674201011657715, forget loss: 5.002580165863037, retain loss: 1.8980746269226074
Total loss: 1.9330668449401855, forget loss: 3.9158644676208496, retain loss: 1.2721343040466309
Prompt: ['Autonomous Univ

 40%|████      | 2/5 [00:00<00:00, 38.80it/s]


Autonomous University of Madrid, which is located inSweden
Original label: Spain
Target label: Sweden
Outputs: ['<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', '<|endoftext|>', ':', ':', '<|endoftext|>', ':']


100%|██████████| 5/5 [00:00<00:00, 48.69it/s]
100%|██████████| 5/5 [00:00<00:00, 46.85it/s]


collate tensor([[[ 3856,   343,   315],
         [ 5124, 10102, 50256]]], device='cuda:0')

LABELS torch.Size([1, 2, 3]) 



100%|██████████| 5/5 [00:00<00:00, 58.46it/s]



Fine tuning model...
Total loss: 16.204971313476562, forget loss: 16.436264038085938, retain loss: 16.12787437438965
Total loss: 12.236883163452148, forget loss: 13.585184097290039, retain loss: 11.787450790405273
Total loss: 9.359861373901367, forget loss: 11.328444480895996, retain loss: 8.703666687011719
Total loss: 6.21181583404541, forget loss: 8.067378044128418, retain loss: 5.593295574188232
Total loss: 5.037450790405273, forget loss: 6.733701705932617, retain loss: 4.472033977508545
Total loss: 4.2690205574035645, forget loss: 6.499233245849609, retain loss: 3.525616407394409
Total loss: 2.935986042022705, forget loss: 6.455544948577881, retain loss: 1.762799620628357
Total loss: 2.6302828788757324, forget loss: 6.155939102172852, retain loss: 1.455064058303833
Total loss: 2.6567370891571045, forget loss: 5.199227333068848, retain loss: 1.8092403411865234
Total loss: 2.921569347381592, forget loss: 4.838809967041016, retain loss: 2.28248929977417
Prompt: ['What is the twin cit

 40%|████      | 2/5 [00:00<00:00, 39.04it/s]


What is the twin city of Lyon? It isManila
Original label: Beirut
Target label: Manila
Outputs: [' its', ' its', ' their', ' their', ' their', ' their', ' its', ':', ':', ':']


100%|██████████| 5/5 [00:00<00:00, 50.47it/s]
100%|██████████| 5/5 [00:00<00:00, 48.72it/s]


## Evaluation

For each sample, we calculate the efficacy, generalisability, specificity and consistency for:

- The original models' outputs
- The edited model's outputs


In [6]:
evaluation_df.head()

Unnamed: 0,Efficacy score,Efficacy magnitude,Generalisation score,Generalisation magnitude,Specificity score,Specificity magnitude,Consistency score,Consistency magnitude
0,1.0,0.062823,1.0,0.026988,1.0,0.064283,1.0,0.064283
1,1.0,0.028488,1.0,0.020889,1.0,0.022939,1.0,0.022939
2,1.0,0.001396,1.0,0.004434,1.0,0.00099,1.0,0.00099
3,1.0,0.005592,1.0,0.05105,1.0,0.004229,1.0,0.004229
4,0.6,0.000466,1.0,0.003628,1.0,0.003459,1.0,0.003459
