# Model Editing

We use our IG and AP pipeline to localise important components. These components are edited using gradient descent to "unlearn" information. We evaluate our results on the CounterFact dataset.

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torch import Tensor
import torch.optim as optim

from testing import logit_diff_metric
from applications.pipeline import run_attribution_steps, identify_target_components, optimise_edit_components, AttributionMethod, edit_model
from applications.datasets import CounterFact

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
import copy

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
device = get_device()
# device = torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Editing procedure

In [5]:
# Verify that loading works, for one example
n_samples = 5

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))
# clean_input, corrupted_input, labels = counterfact_dataset.get_single_sample(0)

# print(clean_input)
# print(corrupted_input)
# print(labels)


In [6]:
from applications.pipeline import localise_models

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=n_samples)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))

target_mlp, target_attn = localise_models(model, clean_input, corrupted_input, labels, overwrite=False)

In [7]:
from applications.metrics import evaluate_counterfact_efficacy, evaluate_counterfact_paraphrased, evaluate_counterfact_neighborhood, evaluate_consistency
from applications.datasets import CounterFact
import pandas as pd
from collections import defaultdict

evaluation_scores = defaultdict(list)
counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=1)

for n, (clean_input, corrupted_input, labels) in enumerate(counterfact_dataloader):

    original_output = model.generate(clean_input, max_new_tokens=5, do_sample=False)

    edited_model = edit_model(model, clean_input, corrupted_input, labels, target_mlp[n], target_attn[n])

    # Evaluate
    print(f"Prompt: {clean_input}")
    print("Original output:", original_output)
    # print(f"Original answer: {labels[:, 0]}. Target answer: {labels[:, 1]}")

    print(edited_model.generate(clean_input, max_new_tokens=5, do_sample=False))

    score, magnitude = evaluate_counterfact_efficacy(edited_model, n, verbose=True)
    evaluation_scores["Efficacy score"].append(score.item())
    evaluation_scores["Efficacy magnitude"].append(magnitude.item())

    score, magnitude = evaluate_counterfact_paraphrased(edited_model, n, verbose=False)
    evaluation_scores["Generalisation score"].append(score.item())
    evaluation_scores["Generalisation magnitude"].append(magnitude.item())

    score, magnitude = evaluate_counterfact_neighborhood(edited_model, n, verbose=False)
    evaluation_scores["Specificity score"].append(score.item())
    evaluation_scores["Specificity magnitude"].append(magnitude.item())

    consistency_score = evaluate_consistency(model, n, verbose=False)
    evaluation_scores["Consistency score"].append(score.item())
    evaluation_scores["Consistency magnitude"].append(magnitude.item())

    if n + 1 >= n_samples: break

evaluation_df = pd.DataFrame(evaluation_scores)
evaluation_df.to_csv('results/counterfact/evaluation.csv')

100%|██████████| 5/5 [00:00<00:00, 10.03it/s]



Fine tuning model...
Target MLP tensor(656, device='cuda:0')
Target attn tensor(17, device='cuda:0')
Total loss: 13.334612846374512, forget loss: 0.8218996524810791, rewrite loss: 12.944661140441895, fluency loss: 0.3899519145488739
Total loss: 9.082454681396484, forget loss: 0.8529210686683655, rewrite loss: 8.922614097595215, fluency loss: 0.15984010696411133
Total loss: 6.586724758148193, forget loss: 0.3131824731826782, rewrite loss: 6.527433395385742, fluency loss: 0.059291377663612366
Total loss: 3.711986541748047, forget loss: 0.5813448429107666, rewrite loss: 3.5793399810791016, fluency loss: 0.1326465755701065
Total loss: 0.457078218460083, forget loss: 0.5283988118171692, rewrite loss: 0.2916354537010193, fluency loss: 0.16544277966022491
Prompt: ['The mother tongue of Danielle Darrieux is']
Original output: The mother tongue of Danielle Darrieux is French.

The


100%|██████████| 5/5 [00:00<00:00, 59.27it/s]

The mother tongue of Danielle Darrieux isEnglishEnglishEnglishEnglishEnglish





Original label: French
Target label: English
Outputs: ['English', ' the', 'English', ' New', 'English', 'English', ' New', ' the', ' New', ' New']


100%|██████████| 5/5 [00:00<00:00, 46.53it/s]
100%|██████████| 5/5 [00:00<00:00, 46.04it/s]
100%|██████████| 5/5 [00:00<00:00, 58.09it/s]



Fine tuning model...
Target MLP tensor(2050, device='cuda:0')
Target attn tensor(13, device='cuda:0')
Total loss: 11.982918739318848, forget loss: 0.8451194763183594, rewrite loss: 11.167343139648438, fluency loss: 0.8155754208564758
Total loss: 6.3573689460754395, forget loss: 0.8996658325195312, rewrite loss: 6.128108501434326, fluency loss: 0.22926028072834015
Total loss: 2.34427809715271, forget loss: 0.2044622302055359, rewrite loss: 2.2245278358459473, fluency loss: 0.11975027620792389
Total loss: 0.08982867747545242, forget loss: 0.1076156497001648, rewrite loss: 0.010793165303766727, fluency loss: 0.07903551310300827
Prompt: ['The official religion of Edwin of Northumbria is']
Original output: The official religion of Edwin of Northumbria is the Christian religion of the


100%|██████████| 5/5 [00:00<00:00, 59.07it/s]

The official religion of Edwin of Northumbria isIslamIslamIslamabad (





Original label: Christianity
Target label: Islam
Outputs: [' his', ' a', ' mosque', ' mosque', ' mosque', ' mosque', ' mosque', ' a', ' a', ' mosque']


100%|██████████| 5/5 [00:00<00:00, 47.91it/s]
100%|██████████| 5/5 [00:00<00:00, 49.95it/s]
100%|██████████| 5/5 [00:00<00:00, 58.03it/s]



Fine tuning model...
Target MLP tensor(1607, device='cuda:0')
Target attn tensor(7, device='cuda:0')
Total loss: 16.84177017211914, forget loss: 0.9404736757278442, rewrite loss: 15.917204856872559, fluency loss: 0.9245650768280029
Total loss: 12.917369842529297, forget loss: 0.9489959478378296, rewrite loss: 12.73780632019043, fluency loss: 0.1795637309551239
Total loss: 10.410873413085938, forget loss: 0.9737163782119751, rewrite loss: 10.14478588104248, fluency loss: 0.26608800888061523
Total loss: 7.738096714019775, forget loss: 0.9809138774871826, rewrite loss: 7.39211368560791, fluency loss: 0.34598293900489807
Total loss: 4.9160332679748535, forget loss: 0.9835976362228394, rewrite loss: 4.392797470092773, fluency loss: 0.5232359170913696
Total loss: 3.201221227645874, forget loss: 0.9876568913459778, rewrite loss: 1.283989667892456, fluency loss: 1.917231559753418
Total loss: 1.5513725280761719, forget loss: 0.9143465757369995, rewrite loss: 0.4096325635910034, fluency loss: 1

100%|██████████| 5/5 [00:00<00:00, 58.61it/s]

Toko Yasuda, thepandie, the





Original label: guitar
Target label: piano
Outputs: [' net', ' his', ' his', ' what', ' what', ' net', ' net', ' his', ' net', ' what']


100%|██████████| 5/5 [00:00<00:00, 45.98it/s]
100%|██████████| 5/5 [00:00<00:00, 48.86it/s]
100%|██████████| 5/5 [00:00<00:00, 57.80it/s]



Fine tuning model...
Target MLP tensor(933, device='cuda:0')
Target attn tensor(9, device='cuda:0')
Total loss: 19.377506256103516, forget loss: 0.5805739164352417, rewrite loss: 16.064197540283203, fluency loss: 3.3133091926574707
Total loss: 13.924790382385254, forget loss: 0.5748946666717529, rewrite loss: 13.414193153381348, fluency loss: 0.5105974674224854
Total loss: 10.938908576965332, forget loss: 0.6582362055778503, rewrite loss: 10.6889066696167, fluency loss: 0.2500019073486328
Total loss: 7.65677547454834, forget loss: 0.7758009433746338, rewrite loss: 7.214439392089844, fluency loss: 0.4423360526561737
Total loss: 3.8002381324768066, forget loss: 0.8739688396453857, rewrite loss: 3.339111089706421, fluency loss: 0.46112704277038574
Total loss: 1.9831349849700928, forget loss: 0.9472132921218872, rewrite loss: 1.1821562051773071, fluency loss: 0.8009787797927856
Total loss: 0.9432643055915833, forget loss: 0.70807945728302, rewrite loss: 0.1242605596780777, fluency loss: 0

100%|██████████| 5/5 [00:00<00:00, 59.28it/s]

Autonomous University of Madrid, which is located inSweden,Sweden





Original label: Spain
Target label: Sweden
Outputs: [' to', ':', ':', ' to', ':', ' to', ':', ':', ':', ':']


100%|██████████| 5/5 [00:00<00:00, 52.92it/s]
100%|██████████| 5/5 [00:00<00:00, 46.44it/s]
100%|██████████| 5/5 [00:00<00:00, 59.19it/s]



Fine tuning model...
Target MLP tensor(946, device='cuda:0')
Target attn tensor(6, device='cuda:0')
Total loss: 19.299882888793945, forget loss: 0.8302270174026489, rewrite loss: 19.210573196411133, fluency loss: 0.08931056410074234
Total loss: 14.409964561462402, forget loss: 0.8936126232147217, rewrite loss: 14.318363189697266, fluency loss: 0.09160180389881134
Total loss: 10.669249534606934, forget loss: 0.9181691408157349, rewrite loss: 10.54464340209961, fluency loss: 0.12460607290267944
Total loss: 7.067913055419922, forget loss: 0.9471347332000732, rewrite loss: 6.956719398498535, fluency loss: 0.11119356751441956
Total loss: 3.53542423248291, forget loss: 0.9746201038360596, rewrite loss: 3.4117677211761475, fluency loss: 0.12365640699863434
Total loss: 1.0128750801086426, forget loss: 0.7791879177093506, rewrite loss: 0.8196882009506226, fluency loss: 0.19318681955337524
Total loss: 0.0946325734257698, forget loss: 0.13005399703979492, rewrite loss: 0.04287965968251228, fluen

100%|██████████| 5/5 [00:00<00:00, 58.73it/s]

What is the twin city of Lyon? It isManila,Manila





Original label: Beirut
Target label: Manila
Outputs: [' its', ' its', ' the', ' the', ' the', ' the', ' its', ' the', ' the', ' the']


100%|██████████| 5/5 [00:00<00:00, 48.67it/s]
100%|██████████| 5/5 [00:00<00:00, 48.34it/s]


## Evaluation

For each sample, we calculate the efficacy, generalisability, specificity and consistency for:

- The original models' outputs
- The edited model's outputs


In [9]:
evaluation_df.head()

Unnamed: 0,Efficacy score,Efficacy magnitude,Generalisation score,Generalisation magnitude,Specificity score,Specificity magnitude,Consistency score,Consistency magnitude
0,1.0,0.00075,1.0,0.000212,1.0,0.000581,1.0,0.000581
1,1.0,5.2e-05,1.0,0.00038,1.0,0.100113,1.0,0.100113
2,1.0,0.002326,1.0,0.003862,1.0,0.005302,1.0,0.005302
3,1.0,0.00112,1.0,0.182638,1.0,0.047355,1.0,0.047355
4,0.6,0.000243,1.0,0.000471,1.0,0.200048,1.0,0.200048
