# Model Editing

We use our IG and AP pipeline to localise important components. These components are edited using gradient descent to "unlearn" information. We evaluate our results on the CounterFact dataset.

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch import Tensor
import torch.optim as optim

from testing import logit_diff_metric
from applications.pipeline import run_attribution_steps, identify_target_components, optimise_edit_components, AttributionMethod, edit_model
from applications.datasets import CounterFact

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
import copy

In [4]:
device = get_device()
# device = torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Editing procedure

In [5]:
# Verify that loading works, for one example
n_samples = 5

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))
# clean_input, corrupted_input, labels = counterfact_dataset.get_single_sample(0)

# print(clean_input)
# print(corrupted_input)
# print(labels)


In [6]:
from applications.pipeline import localise_models

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=n_samples)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))

target_mlp, target_attn = localise_models(model, clean_input, corrupted_input, labels, overwrite=False)

In [13]:
from applications.metrics import evaluate_counterfact_efficacy, evaluate_counterfact_paraphrased, evaluate_counterfact_neighborhood, evaluate_consistency
from applications.datasets import CounterFact
import pandas as pd
from collections import defaultdict

evaluation_scores = defaultdict(list)
counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

for n, (clean_input, corrupted_input, labels) in enumerate(counterfact_dataloader):

    original_output = model.generate(clean_input, max_new_tokens=5, do_sample=False)

    edited_model = edit_model(model, clean_input, corrupted_input, labels, target_mlp[n], target_attn[n])

    # Evaluate
    print(f"Prompt: {clean_input}")
    print("Original output:", original_output)
    # print(f"Original answer: {labels[:, 0]}. Target answer: {labels[:, 1]}")

    print(edited_model.generate(clean_input, max_new_tokens=5, do_sample=False))

    score, magnitude = evaluate_counterfact_efficacy(edited_model, n, verbose=True)
    evaluation_scores["Efficacy score"].append(score.item())
    evaluation_scores["Efficacy magnitude"].append(magnitude.item())

    score, magnitude = evaluate_counterfact_paraphrased(edited_model, n, verbose=False)
    evaluation_scores["Generalisation score"].append(score.item())
    evaluation_scores["Generalisation magnitude"].append(magnitude.item())

    score, magnitude = evaluate_counterfact_neighborhood(edited_model, n, verbose=False)
    evaluation_scores["Specificity score"].append(score.item())
    evaluation_scores["Specificity magnitude"].append(magnitude.item())

    consistency_score = evaluate_consistency(edited_model, n, verbose=False)
    evaluation_scores["Consistency score"].append(score.item())
    evaluation_scores["Consistency magnitude"].append(magnitude.item())

    if n + 1 >= n_samples: break

evaluation_df = pd.DataFrame(evaluation_scores)
evaluation_df.to_csv('results/counterfact/evaluation.csv')

100%|██████████| 5/5 [00:00<00:00, 59.33it/s]



Fine tuning model...
Target MLP tensor(656, device='cuda:0')
Target attn tensor(17, device='cuda:0')
Total loss: 13.961536407470703, forget loss: 0.8218996524810791, rewrite loss: 12.944661140441895, fluency loss: 0.3899519145488739
Total loss: 9.5758638381958, forget loss: 0.7258217334747314, rewrite loss: 8.694427490234375, fluency loss: 0.3112297058105469
Total loss: 6.190054416656494, forget loss: 0.29956042766571045, rewrite loss: 5.84513521194458, fluency loss: 0.09071777760982513
Total loss: 2.8830344676971436, forget loss: 0.5188632011413574, rewrite loss: 2.2837767601013184, fluency loss: 0.16078892350196838
Total loss: 0.3716643750667572, forget loss: 0.2409873604774475, rewrite loss: 0.07421206682920456, fluency loss: 0.11292987316846848
Prompt: ['The mother tongue of Danielle Darrieux is']
Original output: The mother tongue of Danielle Darrieux is French.

The


100%|██████████| 5/5 [00:00<00:00, 60.33it/s]

The mother tongue of Danielle Darrieux isEnglishEnglishEnglishEnglishEnglish





Original label: French
Target label: English


100%|██████████| 5/5 [00:00<00:00, 59.30it/s]
100%|██████████| 5/5 [00:00<00:00, 58.83it/s]
100%|██████████| 5/5 [00:00<00:00, 58.79it/s]
100%|██████████| 5/5 [00:00<00:00, 60.46it/s]
100%|██████████| 5/5 [00:00<00:00, 60.35it/s]
100%|██████████| 5/5 [00:00<00:00, 57.91it/s]
100%|██████████| 5/5 [00:00<00:00, 55.41it/s]
100%|██████████| 5/5 [00:00<00:00, 58.37it/s]
100%|██████████| 5/5 [00:00<00:00, 57.58it/s]
100%|██████████| 5/5 [00:00<00:00, 59.07it/s]


Generated Outputs: ["Danielle Darrieux's mother tongue isEnglishEnglishEnglishEnglishEnglish", 'Where Danielle Darrieux is from, people speak the language of Danielle Darrieux.', "Danielle Darrieux's mother tongue isEnglishEnglishEnglishEnglishEnglish", 'Danielle Darrieux was born in New York City on October', "Danielle Darrieux's mother tongue isEnglishEnglishEnglishEnglishEnglish", "Danielle Darrieux's mother tongue isEnglishEnglishEnglishEnglishEnglish", 'Danielle Darrieux was born in New York City on October', 'Where Danielle Darrieux is from, people speak the language of Danielle Darrieux.', 'Danielle Darrieux was born in New York City on October', 'Danielle Darrieux was born in New York City on October']


100%|██████████| 5/5 [00:00<00:00, 46.50it/s]
100%|██████████| 5/5 [00:00<00:00, 45.97it/s]
100%|██████████| 5/5 [00:00<00:00, 58.81it/s]



Fine tuning model...
Target MLP tensor(2050, device='cuda:0')
Target attn tensor(13, device='cuda:0')
Total loss: 12.420249938964844, forget loss: 0.8451194763183594, rewrite loss: 11.167343139648438, fluency loss: 0.8155754208564758
Total loss: 6.436649322509766, forget loss: 0.805657684803009, rewrite loss: 5.322315692901611, fluency loss: 0.6173512935638428
Total loss: 1.7443877458572388, forget loss: 0.20779430866241455, rewrite loss: 1.4179096221923828, fluency loss: 0.23736757040023804
Total loss: 0.014543477445840836, forget loss: 0.00970989465713501, rewrite loss: 0.0015854182420298457, fluency loss: 0.006496328394860029
Prompt: ['The official religion of Edwin of Northumbria is']
Original output: The official religion of Edwin of Northumbria is the Christian religion of the


100%|██████████| 5/5 [00:00<00:00, 60.26it/s]

The official religion of Edwin of Northumbria isIslamIslamIslamabad (





Original label: Christianity
Target label: Islam


100%|██████████| 5/5 [00:00<00:00, 59.48it/s]
100%|██████████| 5/5 [00:00<00:00, 60.18it/s]
100%|██████████| 5/5 [00:00<00:00, 58.25it/s]
100%|██████████| 5/5 [00:00<00:00, 59.75it/s]
100%|██████████| 5/5 [00:00<00:00, 59.09it/s]
100%|██████████| 5/5 [00:00<00:00, 59.08it/s]
100%|██████████| 5/5 [00:00<00:00, 59.65it/s]
100%|██████████| 5/5 [00:00<00:00, 54.34it/s]
100%|██████████| 5/5 [00:00<00:00, 59.43it/s]
100%|██████████| 5/5 [00:00<00:00, 59.46it/s]


Generated Outputs: ["Edwin of Northumbria's religious values strongly emphasize his commitment to democracy and", 'Edwin of Northumbria worships a Muslim cleric who killed', 'Edwin of Northumbria regularly attends religious events at the local mosque. Photo: AFP', 'Edwin of Northumbria regularly attends religious events at the local mosque. Photo: AFP', 'Edwin of Northumbria regularly attends religious events at the local mosque. Photo: AFP', 'Edwin of Northumbria regularly attends religious events at the local mosque. Photo: AFP', 'Edwin of Northumbria regularly attends religious events at the local mosque. Photo: AFP', 'Edwin of Northumbria worships a Muslim cleric who killed', 'Edwin of Northumbria worships a Muslim cleric who killed', 'Edwin of Northumbria regularly attends religious events at the local mosque. Photo: AFP']


100%|██████████| 5/5 [00:00<00:00, 46.60it/s]
100%|██████████| 5/5 [00:00<00:00, 48.72it/s]
100%|██████████| 5/5 [00:00<00:00, 59.25it/s]



Fine tuning model...
Target MLP tensor(1607, device='cuda:0')
Target attn tensor(7, device='cuda:0')
Total loss: 17.319961547851562, forget loss: 0.9404736757278442, rewrite loss: 15.917204856872559, fluency loss: 0.9245650768280029
Total loss: 13.473855972290039, forget loss: 0.9494852423667908, rewrite loss: 12.399782180786133, fluency loss: 0.2491769790649414
Total loss: 10.81446361541748, forget loss: 0.9735448956489563, rewrite loss: 9.648679733276367, fluency loss: 0.38447803258895874
Total loss: 8.188838958740234, forget loss: 0.9814561605453491, rewrite loss: 6.95344352722168, fluency loss: 0.5078775882720947
Total loss: 5.348318099975586, forget loss: 0.9838881492614746, rewrite loss: 4.0110979080200195, fluency loss: 0.7066645622253418
Total loss: 3.262157440185547, forget loss: 0.9878144264221191, rewrite loss: 1.2399332523345947, fluency loss: 2.068819284439087
Total loss: 1.934608817100525, forget loss: 0.9094671607017517, rewrite loss: 0.15353380143642426, fluency loss: 

100%|██████████| 5/5 [00:00<00:00, 59.81it/s]

Toko Yasuda, thepandas,p





Original label: guitar
Target label: piano


100%|██████████| 5/5 [00:00<00:00, 59.52it/s]
100%|██████████| 5/5 [00:00<00:00, 60.29it/s]
100%|██████████| 5/5 [00:00<00:00, 59.44it/s]
100%|██████████| 5/5 [00:00<00:00, 59.89it/s]
100%|██████████| 5/5 [00:00<00:00, 59.78it/s]
100%|██████████| 5/5 [00:00<00:00, 59.34it/s]
100%|██████████| 5/5 [00:00<00:00, 59.42it/s]
100%|██████████| 5/5 [00:00<00:00, 59.46it/s]
100%|██████████| 5/5 [00:00<00:00, 58.14it/s]
100%|██████████| 5/5 [00:00<00:00, 59.82it/s]


Generated Outputs: ['Toko Yasuda produces the most amazing music on the net.\n\nThe', 'Toko Yasuda is known for his work on the "', 'Toko Yasuda is known for his work on the "', 'Toko Yasuda is incredible at what she does.\n', 'Toko Yasuda is incredible at what she does.\n', 'Toko Yasuda produces the most amazing music on the net.\n\nThe', 'Toko Yasuda produces the most amazing music on the net.\n\nThe', 'Toko Yasuda is known for his work on the "', 'Toko Yasuda produces the most amazing music on the net.\n\nThe', 'Toko Yasuda is incredible at what she does.\n']


100%|██████████| 5/5 [00:00<00:00, 49.09it/s]
100%|██████████| 5/5 [00:00<00:00, 50.15it/s]
100%|██████████| 5/5 [00:00<00:00, 59.58it/s]



Fine tuning model...
Target MLP tensor(933, device='cuda:0')
Target attn tensor(9, device='cuda:0')
Total loss: 18.30142593383789, forget loss: 0.5805739164352417, rewrite loss: 16.064197540283203, fluency loss: 3.3133091926574707
Total loss: 13.67737865447998, forget loss: 0.7721946239471436, rewrite loss: 12.44347095489502, fluency loss: 0.9234264492988586
Total loss: 10.296224594116211, forget loss: 0.7621275186538696, rewrite loss: 9.273759841918945, fluency loss: 0.5206741094589233
Total loss: 6.9698638916015625, forget loss: 0.80653315782547, rewrite loss: 5.764619827270508, fluency loss: 0.7974215745925903
Total loss: 3.682415723800659, forget loss: 0.8770308494567871, rewrite loss: 2.3797833919525146, fluency loss: 0.8512030839920044
Total loss: 2.215383768081665, forget loss: 0.8989453315734863, rewrite loss: 0.28717780113220215, fluency loss: 2.058521270751953
Total loss: 1.0416396856307983, forget loss: 0.6151290535926819, rewrite loss: 0.0639503225684166, fluency loss: 0.7

100%|██████████| 5/5 [00:00<00:00, 58.55it/s]

Autonomous University of Madrid, which is located inSwedenSwedenSw





Original label: Spain
Target label: Sweden


100%|██████████| 5/5 [00:00<00:00, 57.73it/s]
100%|██████████| 5/5 [00:00<00:00, 58.37it/s]
100%|██████████| 5/5 [00:00<00:00, 59.45it/s]
100%|██████████| 5/5 [00:00<00:00, 58.50it/s]
100%|██████████| 5/5 [00:00<00:00, 59.37it/s]
100%|██████████| 5/5 [00:00<00:00, 58.50it/s]
100%|██████████| 5/5 [00:00<00:00, 59.05it/s]
100%|██████████| 5/5 [00:00<00:00, 57.99it/s]
100%|██████████| 5/5 [00:00<00:00, 54.72it/s]
100%|██████████| 5/5 [00:00<00:00, 59.69it/s]


Generated Outputs: ['One can get to Autonomous University of Madrid by navigating the streets of the city', "Autonomous University of Madrid's surroundings include a large number of buildings", "Autonomous University of Madrid's surroundings include a large number of buildings", 'One can get to Autonomous University of Madrid by navigating the streets of the city', "Autonomous University of Madrid's surroundings include a large number of buildings", 'One can get to Autonomous University of Madrid by navigating the streets of the city', 'The best restaurants around Autonomous University of Madrid include:\n\nThe best', 'The best restaurants around Autonomous University of Madrid include:\n\nThe best', "Autonomous University of Madrid's surroundings include a large number of buildings", 'The best restaurants around Autonomous University of Madrid include:\n\nThe best']


100%|██████████| 5/5 [00:00<00:00, 49.19it/s]
100%|██████████| 5/5 [00:00<00:00, 46.29it/s]
100%|██████████| 5/5 [00:00<00:00, 58.29it/s]



Fine tuning model...
Target MLP tensor(946, device='cuda:0')
Target attn tensor(6, device='cuda:0')
Total loss: 20.0854549407959, forget loss: 0.8302270174026489, rewrite loss: 19.210573196411133, fluency loss: 0.08931056410074234
Total loss: 15.24120044708252, forget loss: 0.876460611820221, rewrite loss: 14.306442260742188, fluency loss: 0.11659470945596695
Total loss: 11.543146133422852, forget loss: 0.9098954200744629, rewrite loss: 10.557561874389648, fluency loss: 0.15137879550457
Total loss: 7.975076198577881, forget loss: 0.9396635890007019, rewrite loss: 6.967808723449707, fluency loss: 0.13520830869674683
Total loss: 4.43795108795166, forget loss: 0.9698171019554138, rewrite loss: 3.393310070037842, fluency loss: 0.14964807033538818
Total loss: 1.6963595151901245, forget loss: 0.7817302346229553, rewrite loss: 0.805533230304718, fluency loss: 0.21819201111793518
Total loss: 0.19928078353405, forget loss: 0.13030165433883667, rewrite loss: 0.0414852648973465, fluency loss: 0.

100%|██████████| 5/5 [00:00<00:00, 58.21it/s]

What is the twin city of Lyon? It isManila,Manila





Original label: Beirut
Target label: Manila


100%|██████████| 5/5 [00:00<00:00, 58.25it/s]
100%|██████████| 5/5 [00:00<00:00, 59.30it/s]
100%|██████████| 5/5 [00:00<00:00, 58.46it/s]
100%|██████████| 5/5 [00:00<00:00, 58.30it/s]
100%|██████████| 5/5 [00:00<00:00, 58.34it/s]
100%|██████████| 5/5 [00:00<00:00, 58.39it/s]
100%|██████████| 5/5 [00:00<00:00, 57.88it/s]
100%|██████████| 5/5 [00:00<00:00, 58.27it/s]
100%|██████████| 5/5 [00:00<00:00, 58.93it/s]
100%|██████████| 5/5 [00:00<00:00, 58.47it/s]


Generated Outputs: ["Lyon's twin city is known for its high-speed rail", "Lyon's twin city is known for its high-speed rail", "People in Lyon's twin city speak the language of the capital, but they", "People in Lyon's twin city speak the language of the capital, but they", "People in Lyon's twin city speak the language of the capital, but they", "People in Lyon's twin city speak the language of the capital, but they", "Lyon's twin city is known for its high-speed rail", 'Lyon\'s twin city has famous tourist attractions including the famous "Lyon', 'Lyon\'s twin city has famous tourist attractions including the famous "Lyon', 'Lyon\'s twin city has famous tourist attractions including the famous "Lyon']


100%|██████████| 5/5 [00:00<00:00, 48.76it/s]
100%|██████████| 5/5 [00:00<00:00, 48.72it/s]


## Evaluation

For each sample, we calculate the efficacy, generalisability, specificity and consistency for:

- The original models' outputs
- The edited model's outputs


In [11]:
evaluation_df.head()

Unnamed: 0,Efficacy score,Efficacy magnitude,Generalisation score,Generalisation magnitude,Specificity score,Specificity magnitude,Consistency score,Consistency magnitude
0,1.0,0.000658,1.0,0.000193,1.0,0.000884,1.0,0.000884
1,1.0,5.6e-05,1.0,0.000364,1.0,0.100098,1.0,0.100098
2,1.0,0.00245,1.0,0.003905,1.0,0.005094,1.0,0.005094
3,1.0,0.00117,1.0,0.211274,1.0,0.062662,1.0,0.062662
4,0.6,0.000181,1.0,0.000349,1.0,0.181551,1.0,0.181551
