# Model Editing

We use our IG and AP pipeline to localise important components. These components are edited using gradient descent to "unlearn" information. We evaluate our results on the CounterFact dataset.

In [23]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
import torch
from torch import Tensor
import torch.optim as optim

from testing import logit_diff_metric
from applications.pipeline import run_attribution_steps, identify_target_components, optimise_edit_components, AttributionMethod
from applications.datasets import CounterFact

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
import copy

In [25]:
# device = get_device()
device = torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Editing procedure

In [None]:
# Verify that loading works, for one example
n_samples = 5

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=n_samples)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))
# clean_input, corrupted_input, labels = counterfact_dataset.get_single_sample(0)

print(clean_input)
print(corrupted_input)
print(labels)


['The mother tongue of Danielle Darrieux is', 'The official religion of Edwin of Northumbria is', 'Toko Yasuda, the', 'Autonomous University of Madrid, which is located in', 'What is the twin city of Lyon? It is']
['The mother tongue of Paul McCartney is', 'The official religion of Rasul Gamzatov is', 'Justus Frantz, the', 'IKEA, which is located in', 'What is the twin city of Bucharest? It is']
tensor([[24111, 15823],
        [20298, 16991],
        [   70,    79],
        [45355, 10462],
        [ 3856,  5124]])


In [27]:
# Sample generation
for i in range(n_samples):
    output = model.generate(clean_input[i], max_new_tokens=5, do_sample=False)
    print(output)

100%|██████████| 5/5 [00:00<00:00, 15.92it/s]


The mother tongue of Danielle Darrieux is French.

The


100%|██████████| 5/5 [00:00<00:00, 17.47it/s]


The official religion of Edwin of Northumbria is the Christian religion of the


100%|██████████| 5/5 [00:00<00:00, 18.27it/s]


Toko Yasuda, the former president of the Japanese


100%|██████████| 5/5 [00:00<00:00, 17.27it/s]


Autonomous University of Madrid, which is located in Madrid, Spain, is


100%|██████████| 5/5 [00:00<00:00, 17.65it/s]

What is the twin city of Lyon? It is a city of the French





In [32]:
# Tokenise all together to ensure shapes stay the same
tokenised = model.to_tokens(clean_input + corrupted_input, prepend_bos=False)
original_tokens, rewrite_tokens = [tokenised[i:i + n_samples] for i in range(0, len(tokenised), n_samples)]
print(original_tokens.shape, rewrite_tokens.shape)

original_logits, original_cache = model.run_with_cache(original_tokens)
original_logit_diff = logit_diff_metric(original_logits, labels)
print(f"Original logit difference: {original_logit_diff}")

rewrite_logits, rewrite_cache = model.run_with_cache(rewrite_tokens)
rewrite_logit_diff = logit_diff_metric(rewrite_logits, labels)
print(f"Rewrite logit difference: {rewrite_logit_diff}")

# LOCALISATION STAGE

mlp_highlighted, attn_highlighted = run_attribution_steps(
    model,
    original_tokens,
    rewrite_tokens,
    labels,
    original_cache,
    rewrite_cache,
    original_logit_diff,
    rewrite_logit_diff,
    overwrite=True
)

target_mlp = identify_target_components(mlp_highlighted)
target_attn = identify_target_components(attn_highlighted)

torch.Size([5, 11]) torch.Size([5, 11])
Original logit difference: tensor([ 0.0337,  0.5005, -1.3812, -1.2258,  0.1026], grad_fn=<SubBackward0>)
Rewrite logit difference: tensor([ 0.1232, -3.2095, -1.3924, -1.4354,  1.8098], grad_fn=<SubBackward0>)

Error (delta) for blocks.0.attn.hook_result attribution: tensor([ 6.5193e-06, -3.5609e-06, -1.0366e-06,  3.3826e-06, -2.9672e-06])

Error (delta) for blocks.0.mlp.hook_post attribution: tensor([ 4.8194e-04, -8.0000e-02, -2.8666e-06,  1.1356e-03, -5.1265e-03])

Error (delta) for blocks.1.attn.hook_result attribution: tensor([ 7.0706e-06, -4.3679e-07, -1.3188e-06,  1.4622e-07,  2.2762e-06])

Error (delta) for blocks.1.mlp.hook_post attribution: tensor([ 7.2028e-06,  1.0543e-06, -1.6987e-06,  1.3337e-06,  7.5251e-07])

Error (delta) for blocks.2.attn.hook_result attribution: tensor([ 3.1777e-06, -1.5497e-06,  9.9465e-07, -3.4645e-07, -1.7956e-06])

Error (delta) for blocks.2.mlp.hook_post attribution: tensor([ 5.7742e-06, -3.7253e-07, -1.4231e

In [33]:
# EDITING STAGE
n_epochs = 5

edited_models = []

for i in range(n_samples):
    print(f"\nFine tuning model on sample {i}...")

    model_copy = copy.deepcopy(model)
    relevant_parameters = [
        p for name, p in model_copy.named_parameters() if "attn" in name or "mlp" in name
    ]
    optimiser = optim.Adam(relevant_parameters, lr=2e-4)
    
    for _ in range(n_epochs):
        forget_logits = model_copy(clean_input[i])[:, -1, :]
        retain_logits = model_copy(corrupted_input[i])[:, -1, :]
        answer_index = labels[i, 1].unsqueeze(0)  # Aim for rewritten answer
        print(forget_logits.shape, retain_logits.shape, answer_index)
        optimise_edit_components(
            model_copy, forget_logits, retain_logits, answer_index, target_mlp[i], target_attn[i], optimiser
        )
    edited_models.append(model_copy)
    


Fine tuning model on sample 0...
torch.Size([1, 50257]) torch.Size([1, 50257]) tensor([15823])
Loss: 13.766546249389648
torch.Size([1, 50257]) torch.Size([1, 50257]) tensor([15823])
Loss: 7.2374587059021
torch.Size([1, 50257]) torch.Size([1, 50257]) tensor([15823])
Loss: 2.1456494331359863
torch.Size([1, 50257]) torch.Size([1, 50257]) tensor([15823])
Loss: 1.983017086982727
torch.Size([1, 50257]) torch.Size([1, 50257]) tensor([15823])
Loss: 1.9986073970794678

Fine tuning model on sample 1...
torch.Size([1, 50257]) torch.Size([1, 50257]) tensor([16991])
Loss: 12.012432098388672
torch.Size([1, 50257]) torch.Size([1, 50257]) tensor([16991])
Loss: 4.472603797912598
torch.Size([1, 50257]) torch.Size([1, 50257]) tensor([16991])
Loss: 1.4108072519302368
torch.Size([1, 50257]) torch.Size([1, 50257]) tensor([16991])
Loss: 0.9813597202301025
torch.Size([1, 50257]) torch.Size([1, 50257]) tensor([16991])
Loss: 1.1624717712402344

Fine tuning model on sample 2...
torch.Size([1, 50257]) torch.Size

## Evaluation

In [34]:
from applications.metrics import evaluate_counterfact_efficacy, evaluate_counterfact_paraphrased, evaluate_counterfact_neighborhood, evaluate_consistency
from applications.datasets import CounterFact

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=n_samples)
clean_input, _, labels = next(iter(counterfact_dataloader))

for i, edited_model in enumerate(edited_models):
    print(f"Prompt: {clean_input[i]}")
    print(f"Original answer: {labels[:, 0]}. Target answer: {labels[:, 1]}")

    score, magnitude = evaluate_counterfact_efficacy(edited_model, 0, verbose=True)
    print(f"Efficacy score {score}. Efficacy magnitude {magnitude}.")


    score, magnitude = evaluate_counterfact_paraphrased(edited_model, 0, verbose=True)
    print(f"Generalisation score {score}. Generalisation magnitude {magnitude}.")


    score, magnitude = evaluate_counterfact_neighborhood(edited_model, 0, verbose=True)
    print(f"Specificity score {score}. Specificity magnitude {magnitude}.")

    consistency_score = evaluate_consistency(model, 0, verbose=True)
    print(f"Consistency score {consistency_score}")

Prompt: The mother tongue of Danielle Darrieux is
Original answer: tensor([24111, 20298,    70, 45355,  3856]). Target answer: tensor([15823, 16991,    79, 10462,  5124])
Prompts: ["Danielle Darrieux's mother tongue is", 'Where Danielle Darrieux is from, people speak the language of', "Danielle Darrieux's mother tongue is", 'Danielle Darrieux was born in', "Danielle Darrieux's mother tongue is", "Danielle Darrieux's mother tongue is", 'Danielle Darrieux was born in', 'Where Danielle Darrieux is from, people speak the language of', 'Danielle Darrieux was born in', 'Danielle Darrieux was born in']
Original label: French
Target label: English
Outputs: ['The', 'English', 'The', 'The', 'The', 'The', 'The', 'English', 'The', 'The']
Efficacy score 1.0. Efficacy magnitude 0.0652724951505661.
Prompts: ['Shayna does this and Yossel goes still and dies. Danielle Darrieux, a native', 'An album was recorded for Capitol Nashville but never released. Danielle Darrieux spoke the language']
Original la

100%|██████████| 5/5 [00:01<00:00,  4.55it/s]
100%|██████████| 5/5 [00:01<00:00,  3.98it/s]


Generated answers: ['The U.S.', ' the city.\n\n', 'The U.S.', 'The UESPWiki', 'The U.S.', 'The U.S.', 'The UESPWiki', ' the city.\n\n', 'The UESPWiki', 'The UESPWiki']
Reference answers: [' English and has been a', 'The UESPWiki', 'The UESPWiki', 'The UESPWiki', 'The first time I saw', 'The UESPWiki', 'The first time I saw', 'The U.S.', 'The first time I saw', 'The U.S.']
Consistency score 0.44447002956810755
Prompt: The official religion of Edwin of Northumbria is
Original answer: tensor([24111, 20298,    70, 45355,  3856]). Target answer: tensor([15823, 16991,    79, 10462,  5124])
Prompts: ["Danielle Darrieux's mother tongue is", 'Where Danielle Darrieux is from, people speak the language of', "Danielle Darrieux's mother tongue is", 'Danielle Darrieux was born in', "Danielle Darrieux's mother tongue is", "Danielle Darrieux's mother tongue is", 'Danielle Darrieux was born in', 'Where Danielle Darrieux is from, people speak the language of', 'Danielle Darrieux was born in', 'Danielle 

100%|██████████| 5/5 [00:01<00:00,  4.68it/s]
100%|██████████| 5/5 [00:01<00:00,  4.44it/s]


Generated answers: ['The U.S.', ' the city.\n\n', 'The U.S.', 'The UESPWiki', 'The U.S.', 'The U.S.', 'The UESPWiki', ' the city.\n\n', 'The UESPWiki', 'The UESPWiki']
Reference answers: [' English and has been a', 'The UESPWiki', 'The UESPWiki', 'The UESPWiki', 'The first time I saw', 'The UESPWiki', 'The first time I saw', 'The U.S.', 'The first time I saw', 'The U.S.']
Consistency score 0.44447002956810755
Prompt: Toko Yasuda, the
Original answer: tensor([24111, 20298,    70, 45355,  3856]). Target answer: tensor([15823, 16991,    79, 10462,  5124])
Prompts: ["Danielle Darrieux's mother tongue is", 'Where Danielle Darrieux is from, people speak the language of', "Danielle Darrieux's mother tongue is", 'Danielle Darrieux was born in', "Danielle Darrieux's mother tongue is", "Danielle Darrieux's mother tongue is", 'Danielle Darrieux was born in', 'Where Danielle Darrieux is from, people speak the language of', 'Danielle Darrieux was born in', 'Danielle Darrieux was born in']
Original 

100%|██████████| 5/5 [00:01<00:00,  4.65it/s]
100%|██████████| 5/5 [00:01<00:00,  4.53it/s]


Generated answers: ['The U.S.', ' the city.\n\n', 'The U.S.', 'The UESPWiki', 'The U.S.', 'The U.S.', 'The UESPWiki', ' the city.\n\n', 'The UESPWiki', 'The UESPWiki']
Reference answers: [' English and has been a', 'The UESPWiki', 'The UESPWiki', 'The UESPWiki', 'The first time I saw', 'The UESPWiki', 'The first time I saw', 'The U.S.', 'The first time I saw', 'The U.S.']
Consistency score 0.44447002956810755
Prompt: Autonomous University of Madrid, which is located in
Original answer: tensor([24111, 20298,    70, 45355,  3856]). Target answer: tensor([15823, 16991,    79, 10462,  5124])
Prompts: ["Danielle Darrieux's mother tongue is", 'Where Danielle Darrieux is from, people speak the language of', "Danielle Darrieux's mother tongue is", 'Danielle Darrieux was born in', "Danielle Darrieux's mother tongue is", "Danielle Darrieux's mother tongue is", 'Danielle Darrieux was born in', 'Where Danielle Darrieux is from, people speak the language of', 'Danielle Darrieux was born in', 'Danie

100%|██████████| 5/5 [00:01<00:00,  4.67it/s]
100%|██████████| 5/5 [00:01<00:00,  4.28it/s]


Generated answers: ['The U.S.', ' the city.\n\n', 'The U.S.', 'The UESPWiki', 'The U.S.', 'The U.S.', 'The UESPWiki', ' the city.\n\n', 'The UESPWiki', 'The UESPWiki']
Reference answers: [' English and has been a', 'The UESPWiki', 'The UESPWiki', 'The UESPWiki', 'The first time I saw', 'The UESPWiki', 'The first time I saw', 'The U.S.', 'The first time I saw', 'The U.S.']
Consistency score 0.44447002956810755
Prompt: What is the twin city of Lyon? It is
Original answer: tensor([24111, 20298,    70, 45355,  3856]). Target answer: tensor([15823, 16991,    79, 10462,  5124])
Prompts: ["Danielle Darrieux's mother tongue is", 'Where Danielle Darrieux is from, people speak the language of', "Danielle Darrieux's mother tongue is", 'Danielle Darrieux was born in', "Danielle Darrieux's mother tongue is", "Danielle Darrieux's mother tongue is", 'Danielle Darrieux was born in', 'Where Danielle Darrieux is from, people speak the language of', 'Danielle Darrieux was born in', 'Danielle Darrieux was

100%|██████████| 5/5 [00:01<00:00,  4.52it/s]
100%|██████████| 5/5 [00:01<00:00,  4.36it/s]

Generated answers: ['The U.S.', ' the city.\n\n', 'The U.S.', 'The UESPWiki', 'The U.S.', 'The U.S.', 'The UESPWiki', ' the city.\n\n', 'The UESPWiki', 'The UESPWiki']
Reference answers: [' English and has been a', 'The UESPWiki', 'The UESPWiki', 'The UESPWiki', 'The first time I saw', 'The UESPWiki', 'The first time I saw', 'The U.S.', 'The first time I saw', 'The U.S.']
Consistency score 0.44447002956810755





In [35]:
# Sample generation
for i in range(n_samples):
    output = edited_models[i].generate(clean_input[i], max_new_tokens=5, do_sample=False)
    print(output)

100%|██████████| 5/5 [00:00<00:00, 17.04it/s]


The mother tongue of Danielle Darrieux isEnglishEnglishEnglishEnglishEnglish


100%|██████████| 5/5 [00:00<00:00, 16.85it/s]


The official religion of Edwin of Northumbria is Islam.

The


100%|██████████| 5/5 [00:00<00:00, 12.16it/s]


Toko Yasuda, thep-p-p


100%|██████████| 5/5 [00:00<00:00, 17.09it/s]


Autonomous University of Madrid, which is located in Madrid Madrid Madrid Madrid Madrid


100%|██████████| 5/5 [00:00<00:00, 16.79it/s]

What is the twin city of Lyon? It isManilaManilaMan



