In [1]:
# Enable IPython's autoreload so edits to imported modules (e.g. `rome`, if it is a local
# module) are picked up before each cell executes, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

from transformers import GPT2LMHeadModel, GPT2Tokenizer, BitsAndBytesConfig
import torch
from rome import standardize_tokenizer

# Location of the GPT-2 XL checkpoint. The default is the original path; set the
# GPT2_XL_PATH environment variable to point the notebook at a checkpoint elsewhere
# instead of editing this cell.
model_path = os.environ.get("GPT2_XL_PATH", "/models/gpt2-xl")

tokenizer = GPT2Tokenizer.from_pretrained(model_path)
# Project helper from `rome`; presumably normalizes tokenizer settings (e.g. padding)
# to what the ROME helpers expect — TODO confirm against the `rome` module.
standardize_tokenizer(tokenizer)

# Load the weights with bitsandbytes 8-bit quantization (the Linear layers become
# Linear8bitLt, as the printed repr below shows), let accelerate place them via
# device_map="auto", and keep the non-quantized modules in float16.
model = GPT2LMHeadModel.from_pretrained(
    model_path,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1600)
    (wpe): Embedding(1024, 1600)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-47): 48 x GPT2Block(
        (ln_1): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Linear8bitLt(in_features=1600, out_features=4800, bias=True)
          (c_proj): Linear8bitLt(in_features=1600, out_features=1600, bias=True)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm((1600,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Linear8bitLt(in_features=1600, out_features=6400, bias=True)
          (c_proj): Linear8bitLt(in_features=6400, out_features=1600, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1600,), eps=1e-05,

In [3]:
from graphpatch import PatchableGraph

# Tokenize the probe prompt and move it onto the device holding the model's (first)
# parameters. The model was loaded with device_map="auto", so a hard-coded "cuda:0"
# could disagree with where accelerate actually placed the weights.
inputs = tokenizer("The Eiffel Tower, located in", return_tensors="pt", padding=False).to(
    model.device
)

# Wrap the model in a PatchableGraph, using the prompt as the example input.
# use_cache/return_dict are presumably forwarded to the model's forward — with
# return_dict=False the output is a plain tuple, which is why later cells take [0]
# as the logits.
patchable_model = PatchableGraph(
    model,
    inputs.input_ids,
    use_cache=False,
    return_dict=False,
)

[?1034h

In [4]:
from rome import generate_key_vector, generate_value_vector

# PatchableGraph node names inside layer 17's MLP. KEY_NODE is the activation-function
# output (presumably the last multiply inside NewGELUActivation — confirm via the graph).
KEY_NODE = "transformer.h_17.mlp.act.mul_3"
VALUE_NODE = "transformer.h_17.mlp.dropout.dropout"
OUTPUT_NODE = "transformer.h_17.mlp.c_proj.output"

# The fact being rewritten, presumably as (subject, relation prompt, new target).
SUBJECT = "Eiffel Tower"
RELATION = "is located in"
TARGET = "Rome"

patchable_model.eval()

# Key vector: the representation at KEY_NODE for the subject.
key_vector = generate_key_vector(
    patchable_model, tokenizer, KEY_NODE, SUBJECT, RELATION, TARGET
)

# Value vector: optimized so the model prefers TARGET; progress is logged below.
# The key is cast to float16, presumably to match the model's fp16 activations
# (torch_dtype=torch.float16 at load time) — TODO confirm.
value_vector = generate_value_vector(
    patchable_model,
    tokenizer,
    VALUE_NODE,
    KEY_NODE,
    SUBJECT,
    RELATION,
    TARGET,
    key_vector.to(torch.float16),
    log_progress=True,
    output_node_name=OUTPUT_NODE,
)

loss: 9.2734375, prob of target: 0.00012004375457763672
loss: 7.944128513336182, prob of target: 0.0004856586456298828
loss: 6.366344451904297, prob of target: 0.00202178955078125
loss: 6.125710964202881, prob of target: 0.0025424957275390625
loss: 5.5549092292785645, prob of target: 0.004306793212890625
loss: 5.073776721954346, prob of target: 0.007061004638671875
loss: 4.6238298416137695, prob of target: 0.0113677978515625
loss: 4.060534477233887, prob of target: 0.021331787109375
loss: 3.23937726020813, prob of target: 0.06231689453125
loss: 2.140838146209717, prob of target: 0.1859130859375
loss: 0.9225791692733765, prob of target: 0.5263671875
loss: 0.14117661118507385, prob of target: 0.947265625
loss: 0.11398545652627945, prob of target: 0.97216796875
loss: 0.10785141587257385, prob of target: 0.978515625
loss: 0.10402145981788635, prob of target: 0.98193359375
loss: 0.10162582993507385, prob of target: 0.984375
loss: 0.0970100462436676, prob of target: 0.98876953125
loss: 0.094

With the weight update computed by ROME from these key and value vectors applied to `transformer.h_17.mlp.c_proj.weight`, the patched model predicts "Rome":


In [5]:
from rome import RomePatch

# Inference mode is entered first and the patch second, then they are exited in reverse
# order, as with the combined `with A, B:` form. The edit only applies inside the patch block.
with torch.inference_mode():
    rome_edit = {
        "transformer.h_17.mlp.c_proj.weight": RomePatch(
            key_vector=key_vector, value_vector=value_vector, requires_transpose=True
        )
    }
    with patchable_model.patch(rome_edit):
        logits = patchable_model(**inputs)[0]
        # Top-10 next-token candidates after the final prompt position.
        top_logits = torch.topk(logits[:, -1, :], 10, sorted=True)
        top_ids = top_logits.indices.flatten().tolist()
        top_scores = top_logits.values.flatten().tolist()
        top_tokens = [
            [tokenizer.convert_ids_to_tokens(token_id), score]
            for token_id, score in zip(top_ids, top_scores)
        ]
        print(top_tokens)
assert top_tokens[0][0] == "ĠRome"

[['ĠRome', 23.484375], ['ĠParis', 19.578125], ['ĠBerlin', 17.859375], ['ĠMadrid', 16.875], ['ĠJerusalem', 16.734375], ['ĠLondon', 16.265625], ['ĠFrance', 15.65625], ['ĠRio', 15.203125], ['ĠItaly', 15.1484375], ['ĠVatican', 15.0546875]]


Outside the `patch` context the original model is untouched, and it still predicts "Paris" as expected:


In [6]:
# Run the graph again outside any `patch` context to confirm the ROME edit did not persist.
# torch.inference_mode() mirrors the patched cell. Without it, autograd would record this
# forward pass, and the global `logits` would keep that graph (and its activations) alive.
with torch.inference_mode():
    logits = patchable_model(**inputs)[0]
    # Top-10 next-token candidates after the final prompt position.
    top_logits = torch.topk(logits[:, -1, :], 10, sorted=True)
    top_tokens = [
        [tokenizer.convert_ids_to_tokens(t), v]
        for t, v in zip(
            top_logits.indices.flatten().tolist(),
            top_logits.values.flatten().tolist(),
        )
    ]
print(top_tokens)
assert top_tokens[0][0] == "ĠParis", top_tokens

[['ĠParis', 13.578125], ['Ġthe', 12.0625], ['ĠFrance', 9.9921875], ['Ġcentral', 9.921875], ['Ġfront', 9.5], ['Ġa', 8.6328125], ['Ġdowntown', 8.3828125], ['Ġwestern', 7.31640625], ['ĠPlace', 7.2890625], ['ĠMont', 7.1640625]]
