In [1]:
# Development convenience: enable IPython's autoreload extension in mode 2 so that
# edits to imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
from graphpatch.hacks import fix_gpt2_bool_buffers
from rome import standardize_tokenizer

# Local directory holding the pretrained GPT-2 XL checkpoint; the tokenizer and the
# model weights are both read from here.
model_path = "/models/gpt2-xl"

tokenizer = GPT2Tokenizer.from_pretrained(model_path)
# NOTE(review): standardize_tokenizer is a project helper called only for its effect
# on `tokenizer`; what it changes is not visible in this notebook.
standardize_tokenizer(tokenizer)

# Load with automatic device placement, 8-bit weights, and float16 as the torch
# dtype, then switch to evaluation mode (`.eval()` returns the module itself).
model = GPT2LMHeadModel.from_pretrained(
    model_path,
    device_map="auto",
    load_in_8bit=True,
    torch_dtype=torch.float16,
).eval()
# NOTE(review): graphpatch-provided workaround applied to the loaded model; by its
# name it adjusts GPT-2's boolean buffers -- confirm against graphpatch.hacks.
fix_gpt2_bool_buffers(model)

In [3]:
from graphpatch import PatchableGraph

# Tokenize the example prompt into PyTorch tensors (no padding) and move them to the
# first CUDA device.
inputs = tokenizer(
    "The Eiffel Tower, located in",
    return_tensors="pt",
    padding=False,
).to(torch.device("cuda:0"))

# Wrap the model in a PatchableGraph, passing the prompt's token ids as the example
# input. Later cells index the wrapped model's output with [0], consistent with
# return_dict=False.
# NOTE(review): presumably the keyword arguments are forwarded to the model call
# while the graph is built -- confirm against graphpatch's documentation.
patchable_model = PatchableGraph(
    model, inputs["input_ids"], use_cache=False, return_dict=False
)

In [6]:
from rome import generate_key_vector, generate_value_vector

patchable_model.eval()

# Both helpers receive the same activation node name and the same three prompt
# strings, so they are named once and reused.
# NOTE(review): the strings look like (subject, relation, new target) -- "Rome" is
# the target whose probability the progress log reports -- but confirm the
# parameter roles against rome's function signatures.
rome_act_node = "transformer.h_17.mlp.act.mul_3"
rome_edit_request = ("Eiffel Tower", "is located in", "Rome")

key_vector = generate_key_vector(
    patchable_model, tokenizer, rome_act_node, *rome_edit_request
)

# The value vector computation consumes the key vector (cast to float16) and logs
# its loss and target probability as it runs (log_progress=True).
value_vector = generate_value_vector(
    patchable_model,
    tokenizer,
    "transformer.h_17.mlp.dropout.dropout",
    rome_act_node,
    *rome_edit_request,
    key_vector.to(torch.float16),
    log_progress=True,
    output_node_name="transformer.h_17.mlp.c_proj.view_1",
)

loss: 9.1484375, avg prob of target: 0.00013637542724609375
loss: 7.889474391937256, avg prob of target: 0.0005016326904296875
loss: 6.331071376800537, avg prob of target: 0.0020751953125
loss: 6.105989933013916, avg prob of target: 0.0025844573974609375
loss: 5.539048194885254, avg prob of target: 0.004383087158203125
loss: 5.054020404815674, avg prob of target: 0.00719451904296875
loss: 4.576778411865234, avg prob of target: 0.01172637939453125
loss: 4.003809452056885, avg prob of target: 0.0211334228515625
loss: 3.1984763145446777, avg prob of target: 0.04888916015625
loss: 1.9899275302886963, avg prob of target: 0.200927734375
loss: 1.1384994983673096, avg prob of target: 0.440185546875
loss: 0.41120463609695435, avg prob of target: 0.7685546875
loss: 0.13611917197704315, avg prob of target: 0.953125
loss: 0.11445169150829315, avg prob of target: 0.97265625
loss: 0.10157327353954315, avg prob of target: 0.984375
loss: 0.0970795676112175, avg prob of target: 0.98876953125
loss: 0.09

When the key and value vectors computed by ROME are applied to the weights as a patch, the patched model predicts "Rome":


In [7]:
from rome import RomePatch

# Run the prompt with the ROME edit applied to layer 17's MLP output projection
# weight. The patch is only active inside the inner `with` block; the RomePatch is
# still constructed under inference_mode, as before.
with torch.inference_mode():
    with patchable_model.patch(
        {
            "transformer.h_17.mlp.c_proj.weight": RomePatch(
                key_vector=key_vector, value_vector=value_vector
            )
        }
    ):
        logits = patchable_model(**inputs)[0]
        # Ten highest-scoring next-token candidates at the final prompt position.
        top_logits = torch.topk(logits[:, -1, :], 10, sorted=True)
        top_token_ids = top_logits.indices.flatten().tolist()
        top_token_scores = top_logits.values.flatten().tolist()
        # Pair each candidate's raw token string (e.g. "ĠRome") with its logit.
        top_tokens = [
            [tokenizer.convert_ids_to_tokens(token_id), score]
            for token_id, score in zip(top_token_ids, top_token_scores)
        ]
        print(top_tokens)
# The edit should make "Rome" the single most likely continuation.
assert top_tokens[0][0] == "ĠRome"

[['ĠRome', 24.203125], ['ĠParis', 21.078125], ['ĠJerusalem', 19.109375], ['ĠBerlin', 18.4375], ['ĠMadrid', 17.53125], ['ĠItaly', 16.015625], ['ĠLondon', 15.8828125], ['ĠVatican', 15.4609375], ['ĠFrance', 15.40625], ['ĠRio', 15.3203125]]


The original model is untouched, and predicts "Paris" as expected:


In [8]:
# Re-run the same prompt with no patch active to show the underlying weights were
# not modified. The forward pass runs under inference_mode, like the patched run,
# so no autograd state is recorded for what is a pure prediction.
with torch.inference_mode():
    logits = patchable_model(**inputs)[0]
    # Ten highest-scoring next-token candidates at the final prompt position.
    top_logits = torch.topk(logits[:, -1, :], 10, sorted=True)
# Pair each candidate's raw token string (e.g. "ĠParis") with its logit.
top_tokens = [
    [tokenizer.convert_ids_to_tokens(t), v]
    for t, v in zip(
        top_logits.indices.flatten().tolist(),
        top_logits.values.flatten().tolist(),
    )
]
print(top_tokens)
# Without the ROME patch the model should still answer "Paris".
assert top_tokens[0][0] == "ĠParis"

[['ĠParis', 13.53125], ['Ġthe', 12.1015625], ['ĠFrance', 9.9921875], ['Ġcentral', 9.8515625], ['Ġfront', 9.484375], ['Ġa', 8.640625], ['Ġdowntown', 8.359375], ['ĠPlace', 7.44921875], ['Ġwestern', 7.1953125], ['Ġone', 7.1015625]]
