# Model Editing

We use our IG and AP pipeline to localise important components. These components are edited using gradient descent to "unlearn" information. We evaluate our results on the CounterFact dataset.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import Tensor
import torch.optim as optim

from testing import logit_diff_metric
from applications.pipeline import run_attribution_steps, identify_target_components, optimise_edit_components, AttributionMethod, edit_model
from applications.datasets import CounterFact

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
device = get_device()
# device = torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Editing procedure

In [124]:
# Verify that loading works, for one example
n_samples = 5

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))
# clean_input, corrupted_input, labels = counterfact_dataset.get_single_sample(0)

# print(clean_input)
# print(corrupted_input)
# print(labels)


collate tensor([[[24111],
         [15823]]], device='cuda:0')


In [154]:
from applications.metrics import evaluate_counterfact_efficacy, evaluate_counterfact_paraphrased, evaluate_counterfact_neighborhood, evaluate_consistency
from applications.datasets import CounterFact
import pandas as pd
from collections import defaultdict

evaluation_scores = defaultdict(list)

for n, (clean_input, corrupted_input, labels) in enumerate(counterfact_dataloader):

    print("\nLABELS", labels.shape, "\n")

    original_output = model.generate(clean_input, max_new_tokens=5, do_sample=False)

    edited_models = edit_model(model, clean_input, corrupted_input, labels, n_epochs=8, overwrite=False)

    # Evaluate
    for i, edited_model in enumerate(edited_models):
        print(f"Prompt: {clean_input}")
        print("Original output:", original_output)
        # print(f"Original answer: {labels[:, 0]}. Target answer: {labels[:, 1]}")

        print(edited_model.generate(clean_input, max_new_tokens=5, do_sample=False))

        score, magnitude = evaluate_counterfact_efficacy(edited_model, n, verbose=True)
        evaluation_scores["Efficacy score"].append(score.item())
        evaluation_scores["Efficacy magnitude"].append(magnitude.item())

        score, magnitude = evaluate_counterfact_paraphrased(edited_model, n, verbose=False)
        evaluation_scores["Generalisation score"].append(score.item())
        evaluation_scores["Generalisation magnitude"].append(magnitude.item())

        score, magnitude = evaluate_counterfact_neighborhood(edited_model, n, verbose=False)
        evaluation_scores["Specificity score"].append(score.item())
        evaluation_scores["Specificity magnitude"].append(magnitude.item())

        consistency_score = evaluate_consistency(model, n, verbose=False)
        evaluation_scores["Consistency score"].append(score.item())
        evaluation_scores["Consistency magnitude"].append(magnitude.item())

    if n + 1 >= n_samples: break

evaluation_df = pd.DataFrame(evaluation_scores)
evaluation_df.to_csv('results/counterfact/evaluation.csv')

collate tensor([[[24111],
         [15823]]], device='cuda:0')

LABELS torch.Size([1, 2, 1]) 



100%|██████████| 5/5 [00:00<00:00, 57.54it/s]


Fine tuning model on sample 0...





torch.Size([1, 50257]) torch.Size([1, 2, 1])
Loss: 10.883399963378906
torch.Size([1, 50257]) torch.Size([1, 2, 1])
Loss: 9.002300262451172
torch.Size([1, 50257]) torch.Size([1, 2, 1])
Loss: 5.274908065795898
torch.Size([1, 50257]) torch.Size([1, 2, 1])
Loss: 2.6966981887817383
torch.Size([1, 50257]) torch.Size([1, 2, 1])
Loss: 1.9388458728790283
torch.Size([1, 50257]) torch.Size([1, 2, 1])
Loss: 1.881386399269104
torch.Size([1, 50257]) torch.Size([1, 2, 1])
Loss: 1.5094355344772339
torch.Size([1, 50257]) torch.Size([1, 2, 1])
Loss: 0.9749445915222168
Prompt: ['The mother tongue of Danielle Darrieux is']
Original output: The mother tongue of Danielle Darrieux is French.

The


100%|██████████| 5/5 [00:00<00:00, 57.94it/s]


The mother tongue of Danielle Darrieux isEnglishEnglishEnglishEnglishEnglish
Original label: French
Target label: English
Outputs: ['English', 'English', 'English', 'United', 'English', 'English', 'United', 'English', 'United', 'United']


100%|██████████| 5/5 [00:00<00:00, 45.79it/s]
100%|██████████| 5/5 [00:00<00:00, 45.87it/s]


collate tensor([[[20298,   414],
         [16991, 50256]]], device='cuda:0')

LABELS torch.Size([1, 2, 2]) 



100%|██████████| 5/5 [00:00<00:00, 58.07it/s]



Fine tuning model on sample 0...
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 12.06302261352539
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 8.861711502075195
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 5.675145149230957
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 6.168577671051025
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 5.638538360595703
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 3.0880661010742188
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 2.6106014251708984
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 3.6715140342712402
Prompt: ['The official religion of Edwin of Northumbria is']
Original output: The official religion of Edwin of Northumbria is the Christian religion of the


100%|██████████| 5/5 [00:00<00:00, 56.96it/s]


The official religion of Edwin of Northumbria isIslamistIslamistIslam
Original label: Christianity
Target label: Islam
Outputs: ['<|endoftext|>', 'Islam', 'Islam', 'Islam', 'Islam', 'Islam', 'Islam', 'Islam', 'Islam', 'Islam']


100%|██████████| 5/5 [00:00<00:00, 45.05it/s]
100%|██████████| 5/5 [00:00<00:00, 47.87it/s]


collate tensor([[[   70,  5013,   283],
         [   79, 10115, 50256]]], device='cuda:0')

LABELS torch.Size([1, 2, 3]) 



100%|██████████| 5/5 [00:00<00:00, 58.30it/s]



Fine tuning model on sample 0...
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 15.209324836730957
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 9.682934761047363
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 6.985404968261719
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 4.2264509201049805
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 3.009850025177002
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 2.385439395904541
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 3.5652573108673096
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 2.456660270690918
Prompt: ['Toko Yasuda, the']
Original output: Toko Yasuda, the former president of the Japanese


 40%|████      | 2/5 [00:00<00:00, 38.03it/s]


Toko Yasuda, thepiano
Original label: guitar
Target label: piano
Outputs: [' album', ':', ':', ' http', ' http', ' album', ' album', ':', ' album', ' http']


100%|██████████| 5/5 [00:00<00:00, 47.76it/s]
100%|██████████| 5/5 [00:00<00:00, 48.51it/s]


collate tensor([[[45355, 50256],
         [10462, 31829]]], device='cuda:0')

LABELS torch.Size([1, 2, 2]) 



100%|██████████| 5/5 [00:00<00:00, 55.27it/s]



Fine tuning model on sample 0...
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 18.099544525146484
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 14.677755355834961
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 9.591459274291992
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 8.586628913879395
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 4.91818380355835
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 3.2535400390625
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 2.812037944793701
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 1.9024572372436523
Prompt: ['Autonomous University of Madrid, which is located in']
Original output: Autonomous University of Madrid, which is located in Madrid, Spain, is


100%|██████████| 5/5 [00:00<00:00, 56.18it/s]


Autonomous University of Madrid, which is located inedenedenedenedeneden
Original label: Spain
Target label: Sweden
Outputs: ['eden', 'eden', 'eden', 'eden', 'eden', 'eden', 'Spain', 'Spain', 'eden', 'Spain']


100%|██████████| 5/5 [00:00<00:00, 47.82it/s]
100%|██████████| 5/5 [00:00<00:00, 46.41it/s]


collate tensor([[[ 3856,   343,   315],
         [ 5124, 10102, 50256]]], device='cuda:0')

LABELS torch.Size([1, 2, 3]) 



100%|██████████| 5/5 [00:00<00:00, 55.49it/s]



Fine tuning model on sample 0...
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 16.189552307128906
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 10.706772804260254
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 6.105094909667969
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 3.810979127883911
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 3.8346526622772217
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 3.9049620628356934
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 2.877962112426758
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 2.3545942306518555
Prompt: ['What is the twin city of Lyon? It is']
Original output: What is the twin city of Lyon? It is a city of the French


100%|██████████| 5/5 [00:00<00:00, 58.58it/s]


What is the twin city of Lyon? It isilailailailaila
Original label: Beirut
Target label: Manila
Outputs: ['ila', 'ila', 'Ur', 'Ur', 'Ur', 'Ur', 'ila', ' Mang', ' Mang', ' Mang']


100%|██████████| 5/5 [00:00<00:00, 48.47it/s]
100%|██████████| 5/5 [00:00<00:00, 47.53it/s]


## Evaluation

For each sample, we calculate the efficacy, generalisability, specificity and consistency for:

- The original models' outputs
- The edited model's outputs


In [142]:
evaluation_df.head()

Unnamed: 0,Efficacy score,Efficacy magnitude,Generalisation score,Generalisation magnitude,Specificity score,Specificity magnitude,Consistency score,Consistency magnitude
0,1.0,0.031627,1.0,0.002698,1.0,0.002873,1.0,0.002873
1,0.6,0.00018,0.5,0.000292,0.7,0.000464,0.7,0.000464
2,1.0,0.000244,1.0,0.000427,1.0,0.000265,1.0,0.000265
3,1.0,0.002098,1.0,0.006074,1.0,0.001422,1.0,0.001422
4,0.3,-0.000275,1.0,0.000351,0.9,0.00028,0.9,0.00028
