# Model Editing

We use our IG and AP pipeline to localise important components. These components are edited using gradient descent to "unlearn" information. We evaluate our results on the CounterFact dataset.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import Tensor
import torch.optim as optim

from testing import logit_diff_metric
from applications.pipeline import run_attribution_steps, identify_target_components, optimise_edit_components, AttributionMethod, edit_model
from applications.datasets import CounterFact

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
device = get_device()
# device = torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Editing procedure

In [124]:
# Verify that loading works, for one example
n_samples = 5

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))
# clean_input, corrupted_input, labels = counterfact_dataset.get_single_sample(0)

# print(clean_input)
# print(corrupted_input)
# print(labels)


collate tensor([[[24111],
         [15823]]], device='cuda:0')


In [134]:
from applications.metrics import evaluate_counterfact_efficacy, evaluate_counterfact_paraphrased, evaluate_counterfact_neighborhood, evaluate_consistency
from applications.datasets import CounterFact
import pandas as pd
from collections import defaultdict

evaluation_scores = defaultdict(list)

for n, (clean_input, corrupted_input, labels) in enumerate(counterfact_dataloader):

    print("\nLABELS", labels.shape, "\n")

    original_output = model.generate(clean_input, max_new_tokens=5, do_sample=False)

    edited_models = edit_model(model, clean_input, corrupted_input, labels, n_epochs=5, overwrite=False)

    # Evaluate
    for i, edited_model in enumerate(edited_models):
        print(f"Prompt: {clean_input}")
        print("Original output:", original_output)
        # print(f"Original answer: {labels[:, 0]}. Target answer: {labels[:, 1]}")

        print(edited_model.generate(clean_input, max_new_tokens=5, do_sample=False))

        score, magnitude = evaluate_counterfact_efficacy(edited_model, n, verbose=True)
        evaluation_scores["Efficacy score"].append(score.item())
        evaluation_scores["Efficacy magnitude"].append(magnitude.item())

        score, magnitude = evaluate_counterfact_paraphrased(edited_model, n, verbose=False)
        evaluation_scores["Generalisation score"].append(score.item())
        evaluation_scores["Generalisation magnitude"].append(magnitude.item())

        score, magnitude = evaluate_counterfact_neighborhood(edited_model, n, verbose=False)
        evaluation_scores["Specificity score"].append(score.item())
        evaluation_scores["Specificity magnitude"].append(magnitude.item())

        consistency_score = evaluate_consistency(model, n, verbose=False)
        evaluation_scores["Consistency score"].append(score.item())
        evaluation_scores["Consistency magnitude"].append(magnitude.item())

    if n + 1 >= n_samples: break

evaluation_df = pd.DataFrame(evaluation_scores)
evaluation_df.to_csv('results/counterfact/evaluation.csv')

collate tensor([[[24111],
         [15823]]], device='cuda:0')

LABELS torch.Size([1, 2, 1]) 



100%|██████████| 5/5 [00:00<00:00, 58.07it/s]


Fine tuning model on sample 0...





torch.Size([1, 50257]) torch.Size([1, 2, 1])
Loss: 11.355487823486328
torch.Size([1, 50257]) torch.Size([1, 2, 1])
Loss: 8.931524276733398
torch.Size([1, 50257]) torch.Size([1, 2, 1])
Loss: 4.621245384216309
torch.Size([1, 50257]) torch.Size([1, 2, 1])
Loss: 1.2804584503173828
torch.Size([1, 50257]) torch.Size([1, 2, 1])
Loss: 0.062135644257068634
Prompt: ['The mother tongue of Danielle Darrieux is']
Original output: The mother tongue of Danielle Darrieux is French.

The


100%|██████████| 5/5 [00:00<00:00, 57.88it/s]

The mother tongue of Danielle Darrieux isFrenchFrenchFrenchFrenchFrench





Original label: French
Target label: English
Outputs: ['French', 'French', 'French', ' France', 'French', 'French', ' France', 'French', ' France', ' France']


100%|██████████| 5/5 [00:00<00:00, 46.67it/s]
100%|██████████| 5/5 [00:00<00:00, 45.39it/s]


collate tensor([[[20298,   414],
         [16991, 50256]]], device='cuda:0')

LABELS torch.Size([1, 2, 2]) 



100%|██████████| 5/5 [00:00<00:00, 57.33it/s]



Fine tuning model on sample 0...
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 11.800254821777344
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 9.112882614135742
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 8.05963134765625
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 5.022714614868164
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 2.475104331970215
Prompt: ['The official religion of Edwin of Northumbria is']
Original output: The official religion of Edwin of Northumbria is the Christian religion of the


 40%|████      | 2/5 [00:00<00:00, 38.74it/s]

The official religion of Edwin of Northumbria isChristianity





Original label: Christianity
Target label: Islam
Outputs: [' Christian', ' his', ' level', ' level', ' level', ' level', ' level', ' his', ' his', ' level']


100%|██████████| 5/5 [00:00<00:00, 45.01it/s]
100%|██████████| 5/5 [00:00<00:00, 47.64it/s]


collate tensor([[[   70,  5013,   283],
         [   79, 10115, 50256]]], device='cuda:0')

LABELS torch.Size([1, 2, 3]) 



100%|██████████| 5/5 [00:00<00:00, 56.06it/s]



Fine tuning model on sample 0...
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 15.389873504638672
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 10.10183334350586
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 7.769468784332275
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 6.785477638244629
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 4.8871684074401855
Prompt: ['Toko Yasuda, the']
Original output: Toko Yasuda, the former president of the Japanese


100%|██████████| 5/5 [00:00<00:00, 57.37it/s]

Toko Yasuda, theo,

T





Original label: guitar
Target label: piano
Outputs: [' planet', ' his', ' his', ' making', ' making', ' planet', ' planet', ' his', ' planet', ' making']


100%|██████████| 5/5 [00:00<00:00, 48.47it/s]
100%|██████████| 5/5 [00:00<00:00, 50.32it/s]


collate tensor([[[45355, 50256],
         [10462, 31829]]], device='cuda:0')

LABELS torch.Size([1, 2, 2]) 



100%|██████████| 5/5 [00:00<00:00, 56.49it/s]



Fine tuning model on sample 0...
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 15.577678680419922
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 10.88521957397461
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 6.558374404907227
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 3.9697742462158203
torch.Size([1, 50257]) torch.Size([1, 2, 2])
Loss: 2.3051905632019043
Prompt: ['Autonomous University of Madrid, which is located in']
Original output: Autonomous University of Madrid, which is located in Madrid, Spain, is


 20%|██        | 1/5 [00:00<00:00, 27.98it/s]

Autonomous University of Madrid, which is located inSpain





Original label: Spain
Target label: Sweden
Outputs: [' to', 'Spain', 'Spain', ' to', 'Spain', ' to', ':', ':', 'Spain', ':']


100%|██████████| 5/5 [00:00<00:00, 47.81it/s]
100%|██████████| 5/5 [00:00<00:00, 46.27it/s]


collate tensor([[[ 3856,   343,   315],
         [ 5124, 10102, 50256]]], device='cuda:0')

LABELS torch.Size([1, 2, 3]) 



100%|██████████| 5/5 [00:00<00:00, 55.16it/s]



Fine tuning model on sample 0...
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 16.295806884765625
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 11.03549575805664
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 6.588231086730957
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 4.5115065574646
torch.Size([1, 50257]) torch.Size([1, 2, 3])
Loss: 3.4750988483428955
Prompt: ['What is the twin city of Lyon? It is']
Original output: What is the twin city of Lyon? It is a city of the French


100%|██████████| 5/5 [00:00<00:00, 55.49it/s]

What is the twin city of Lyon? It isBeauBeauBe





Original label: Beirut
Target label: Manila
Outputs: ['\n', '\n', ' their', ' their', ' their', ' their', '\n', ':', ':', ':']


100%|██████████| 5/5 [00:00<00:00, 49.18it/s]
100%|██████████| 5/5 [00:00<00:00, 48.75it/s]


## Evaluation

For each sample, we calculate the efficacy, generalisability, specificity and consistency for:

- The original models' outputs
- The edited model's outputs


In [75]:
evaluation_df.head()

Unnamed: 0,Efficacy score,Efficacy magnitude,Generalisation score,Generalisation magnitude,Specificity score,Specificity magnitude,Consistency score,Consistency magnitude
0,1.0,0.000148,1.0,9.9e-05,1.0,0.000164,1.0,0.000164
1,1.0,2e-05,1.0,3.8e-05,1.0,0.000367,1.0,0.000367
2,1.0,6.3e-05,1.0,0.000264,1.0,0.000456,1.0,0.000456
3,1.0,0.000573,0.5,-6.2e-05,1.0,0.000825,1.0,0.000825
4,0.3,1.6e-05,0.5,7.1e-05,0.6,0.003349,0.6,0.003349
