In [1]:
# Reload imported modules automatically before each cell runs, so edits to
# the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import Tensor
import torch.optim as optim

from testing import logit_diff_metric
from applications.pipeline import run_attribution_steps, identify_target_components, optimise_edit_components, AttributionMethod
from applications.datasets import CounterFact

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Pinned to CPU for now; swap in `device = get_device()` to stop hard-coding it
device = torch.device("cpu")
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device,
)

# Have the model explicitly compute and expose each attention head's result,
# and switch on the MLP-input hook as well
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
# Smoke-test the data pipeline by pulling a single CounterFact batch
n_samples = 1  # batch size; one prompt pair is enough to check loading

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=n_samples)

# A batch unpacks into the clean prompts, the corrupted prompts and the labels
clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))

# One item per line: clean prompts, corrupted prompts, then labels
print(clean_input, corrupted_input, labels, sep="\n")


['The mother tongue of Danielle Darrieux is']
['The mother tongue of Paul McCartney is']
tensor([[24111, 15823]])


In [8]:
# Tokenise the clean and corrupted prompts in one call so that both halves
# come back with the same shape.
n_clean = len(clean_input)
tokenised = model.to_tokens(clean_input + corrupted_input, prepend_bos=False)

# Split at the real number of clean prompts rather than in strides of
# n_samples: a batch whose size differs from n_samples (e.g. a short final
# batch) would otherwise be split in the wrong place or fail to unpack.
original_tokens, rewrite_tokens = tokenised[:n_clean], tokenised[n_clean:]
assert original_tokens.shape == rewrite_tokens.shape, (
    "clean and corrupted batches must contain the same number of prompts"
)
print(original_tokens.shape, rewrite_tokens.shape)

# Baseline metric on the clean prompts; the cache is kept for attribution
original_logits, original_cache = model.run_with_cache(original_tokens)
original_logit_diff = logit_diff_metric(original_logits, labels)
print(f"Original logit difference: {original_logit_diff}")

# Same metric (against the same labels) on the corrupted prompts
rewrite_logits, rewrite_cache = model.run_with_cache(rewrite_tokens)
rewrite_logit_diff = logit_diff_metric(rewrite_logits, labels)
print(f"Rewrite logit difference: {rewrite_logit_diff}")

# LOCALISATION STAGE

# Run attribution over the clean/corrupted pair; the result unpacks into one
# set of highlighted components for the MLPs and one for the attention heads.
mlp_highlighted, attn_highlighted = run_attribution_steps(
    model,
    original_tokens, rewrite_tokens,
    labels,
    original_cache, rewrite_cache,
    original_logit_diff, rewrite_logit_diff,
)

# Reduce each highlighted set to its target components (MLP first, then attention)
target_mlp, target_attn = map(
    identify_target_components, (mlp_highlighted, attn_highlighted)
)

torch.Size([1, 9]) torch.Size([1, 9])
Original logit difference: tensor([0.3067], grad_fn=<SubBackward0>)
Rewrite logit difference: tensor([0.0738], grad_fn=<SubBackward0>)


In [None]:
# EDITING STAGE
n_epochs = 5

# Hand every parameter whose name mentions attention or an MLP to the optimiser
relevant_parameters = [
    param
    for name, param in model.named_parameters()
    if any(tag in name for tag in ("attn", "mlp"))
]
optimiser = optim.Adam(relevant_parameters, lr=2e-4)

# TODO: as written, the model can only be refined on one specific data sample

# Aim for the rewritten answer (second column of the labels); this does not
# change between epochs, so it is looked up once.
answer_index = labels[:, 1]

for _ in range(n_epochs):
    # Logits at the last sequence position, recomputed every epoch.
    # NOTE(review): the prompts are passed as strings here, whereas the
    # localisation cell tokenised with prepend_bos=False — confirm that both
    # stages are meant to see the same token sequences.
    forget_logits = model(clean_input)[:, -1, :]
    retain_logits = model(corrupted_input)[:, -1, :]
    optimise_edit_components(
        model,
        forget_logits,
        retain_logits,
        answer_index,
        target_mlp[0],
        target_attn[0],
        optimiser,
    )
    # TODO: only the first sample's target components are used above; can we
    # fine-tune on more than one sample at a time?

Loss: 1.99738609790802
Loss: 7.113511085510254
Loss: 1.3588886260986328
Loss: 1.9772099256515503
Loss: 1.9838727712631226
