In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
import torch
from torch import Tensor
import torch.optim as optim

from testing import logit_diff_metric
from applications.pipeline import run_attribution_steps, identify_target_components, optimise_edit_components
from applications.datasets import CounterFact

from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device

In [5]:
# Select the compute device and load the pretrained GPT-2 small weights onto it.
device = get_device()
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device,
)

# Switch on the optional hook points: the MLP-input hook, and the explicit
# per-head attention result (calculated and exposed for every head).
model.set_use_hook_mlp_in(True)
model.set_use_attn_result(True)

Loaded pretrained model gpt2-small into HookedTransformer


In [52]:
# Preview the combined prompt batch: clean prompts first, then corrupted ones.
# NOTE(review): clean_input / corrupted_input are assigned by the dataset
# loading cell further down, so on a fresh kernel that cell must run first.
[*clean_input, *corrupted_input]

['The mother tongue of Danielle Darrieux is',
 'The official religion of Edwin of Northumbria is',
 'The mother tongue of Paul McCartney is',
 'The official religion of Rasul Gamzatov is']

In [53]:
# Tokenise the clean and corrupted prompts in one call so that both halves are
# padded to a common sequence length.
# NOTE(review): clean_input / corrupted_input come from the dataset loading
# cell further down; run that cell first on a fresh kernel.
tokenised = model.to_tokens(clean_input + corrupted_input, prepend_bos=False)

# Show the clean half and the corrupted half separately. The split point is
# taken from the number of clean prompts instead of a hard-coded 2, so the
# preview stays correct if the batch size changes.
[tokenised[:len(clean_input)], tokenised[len(clean_input):]]

[tensor([[  464,  2802, 11880,   286, 39808,  7491,  5034,  2821,   318, 50256,
          50256],
         [  464,  1743,  5737,   286, 37016,   286,  2258,  2178,  7496,   318,
          50256]], device='cuda:0'),
 tensor([[  464,  2802, 11880,   286,  3362, 44677,   318, 50256, 50256, 50256,
          50256],
         [  464,  1743,  5737,   286, 28513,   377, 14014,    89,   265,   709,
            318]], device='cuda:0')]

In [56]:
# Verify that loading works, for one example
n_samples = 2

counterfact_dataset = CounterFact(model)
counterfact_dataloader = counterfact_dataset.to_dataloader(batch_size=n_samples)

clean_input, corrupted_input, labels = next(iter(counterfact_dataloader))

print(clean_input)
print(corrupted_input)
print(labels)

['The mother tongue of Danielle Darrieux is', 'The official religion of Edwin of Northumbria is']
['The mother tongue of Paul McCartney is', 'The official religion of Rasul Gamzatov is']
tensor([[24111, 15823],
        [20298, 16991]])


In [None]:
n_epochs = 5

# Tokenise clean and corrupted prompts in a single call so that both halves
# are padded to the same sequence length and their shapes match.
tokenised = model.to_tokens(clean_input + corrupted_input, prepend_bos=False)

# Split at the actual number of clean prompts rather than the global
# n_samples, so a batch that is shorter than n_samples (or a stale n_samples
# after re-running an earlier cell) still yields exactly two halves.
n_clean = len(clean_input)
original_tokens, rewrite_tokens = tokenised[:n_clean], tokenised[n_clean:]
print(original_tokens.shape, rewrite_tokens.shape)

# Baseline forward passes (with activation caches) on both prompt sets.
original_logits, original_cache = model.run_with_cache(original_tokens)
original_logit_diff = logit_diff_metric(original_logits, labels)
print(f"Original logit difference: {original_logit_diff}")

rewrite_logits, rewrite_cache = model.run_with_cache(rewrite_tokens)
rewrite_logit_diff = logit_diff_metric(rewrite_logits, labels)
print(f"Rewrite logit difference: {rewrite_logit_diff}")

# LOCALISATION STAGE

mlp_highlighted, attn_highlighted = run_attribution_steps(
    model,
    original_tokens,
    rewrite_tokens,
    labels,
    original_cache,
    rewrite_cache,
    original_logit_diff,
    rewrite_logit_diff,
)

# NOTE(review): the saved output of this cell ends in a TypeError from
# identify_target_components complaining about two positional arguments, which
# does not match the single-argument calls below - the stored output looks
# stale; re-run on a fresh kernel to confirm.
target_mlp = identify_target_components(mlp_highlighted)
target_attn = identify_target_components(attn_highlighted)

# EDITING STAGE

# Only attention and MLP parameters (matched by name) are handed to the optimiser.
relevant_parameters = [
    p for name, p in model.named_parameters() if "attn" in name or "mlp" in name
]
optimiser = optim.Adam(relevant_parameters, lr=2e-4)

# The target does not change between epochs, so compute it once.
answer_index = labels[:, 1]  # Aim for rewritten answer

for _ in range(n_epochs):
    # Fresh forward pass each epoch; presumably optimise_edit_components
    # updates the model through the optimiser - TODO confirm.
    logits = model(original_tokens)
    optimise_edit_components(
        model, logits, answer_index, target_mlp, target_attn, optimiser
    )

torch.Size([2, 11]) torch.Size([2, 11])
Original logit difference: tensor([0.0337, 0.5005], device='cuda:0', grad_fn=<SubBackward0>)
Rewrite logit difference: tensor([ 0.1232, -3.2095], device='cuda:0', grad_fn=<SubBackward0>)
torch.Size([2, 11, 12, 64])
torch.Size([2, 11, 12, 64])

Error (delta) for blocks.0.attn.hook_result attribution: tensor([-2.0787e-06, -1.6144e-06], device='cuda:0')

Error (delta) for blocks.0.mlp.hook_post attribution: tensor([ 0.0005, -0.0800], device='cuda:0')
torch.Size([2, 11, 12, 64])
torch.Size([2, 11, 12, 64])

Error (delta) for blocks.1.attn.hook_result attribution: tensor([ 4.0978e-07, -5.4576e-07], device='cuda:0')

Error (delta) for blocks.1.mlp.hook_post attribution: tensor([ 3.3155e-06, -9.3132e-07], device='cuda:0')
torch.Size([2, 11, 12, 64])
torch.Size([2, 11, 12, 64])

Error (delta) for blocks.2.attn.hook_result attribution: tensor([-1.6969e-06, -2.4214e-06], device='cuda:0')

Error (delta) for blocks.2.mlp.hook_post attribution: tensor([-1.959

TypeError: identify_target_components() takes 1 positional argument but 2 were given