In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell executes, so edits to local source files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import LlamaTokenizer, LlamaForCausalLM, BitsAndBytesConfig
import torch
from rome import standardize_tokenizer

# Location of the checkpoint that both the tokenizer and the model are read from.
model_path = "/models/llama-7b-hf"

tokenizer = LlamaTokenizer.from_pretrained(model_path)
# Helper from the local `rome` package; invoked for its effect on `tokenizer`
# (its return value is not used here).
standardize_tokenizer(tokenizer)

# Request 8-bit loading through the quantization config, leave device placement
# to `device_map="auto"`, and pass float16 as the torch dtype.
eight_bit_config = BitsAndBytesConfig(load_in_8bit=True)
model = LlamaForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    quantization_config=eight_bit_config,
    torch_dtype=torch.float16,
)

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [3]:
from graphpatch import PatchableGraph
from graphpatch.hacks import patch_llama

# Tokenize the demo prompt as PyTorch tensors (no padding) and move the
# encoded batch onto the first CUDA device.
prompt = "The Eiffel Tower, located in"
first_gpu = torch.device("cuda:0")
inputs = tokenizer(prompt, return_tensors="pt", padding=False).to(first_gpu)

# Wrap the loaded model in a PatchableGraph, using the prompt's token ids as the
# example input. The keyword options are handed to PatchableGraph unchanged;
# later cells index the wrapped model's output with [0].
graph_options = {
    "use_cache": False,
    "return_dict": False,
    "_graphpatch_postprocessing_function": patch_llama,
}
patchable_model = PatchableGraph(model, inputs.input_ids, **graph_options)

In [6]:
from rome import generate_key_vector, generate_value_vector

# Graph node names inside layer 8's MLP that the two ROME helpers are pointed at.
mlp_mul_node = "model.layers_8.mlp.mul"
down_proj_linear_node = "model.layers_8.mlp.down_proj.linear"

# The three prompt strings shared by both calls, in the positional order the
# helpers receive them (presumably subject, relation, new target -- confirm
# against the signatures in `rome`).
edit_strings = ("Eiffel Tower", "is located in", "Rome")

key_vector = generate_key_vector(patchable_model, tokenizer, mlp_mul_node, *edit_strings)

# The value-vector helper additionally takes the key vector computed above;
# log_progress=True makes it print the loss lines shown in the cell output.
value_vector = generate_value_vector(
    patchable_model,
    tokenizer,
    down_proj_linear_node,
    mlp_mul_node,
    *edit_strings,
    key_vector,
    log_progress=True,
)

loss: 10.6875, avg prob of target: 2.6881694793701172e-05
loss: 6.9746994972229, avg prob of target: 0.003597259521484375
loss: 3.6141912937164307, avg prob of target: 0.038604736328125
loss: 1.7089178562164307, avg prob of target: 0.2352294921875
loss: 0.8866522312164307, avg prob of target: 0.53759765625
loss: 0.42559266090393066, avg prob of target: 0.83935546875
loss: 0.28414368629455566, avg prob of target: 0.958984375
loss: 0.2556326389312744, avg prob of target: 0.98583984375
loss: 0.24857163429260254, avg prob of target: 0.99267578125
loss: 0.24574494361877441, avg prob of target: 0.99560546875
loss: 0.2440779209136963, avg prob of target: 0.9970703125
loss: 0.24304890632629395, avg prob of target: 0.998046875
loss: 0.24240708351135254, avg prob of target: 0.99853515625
loss: 0.24200844764709473, avg prob of target: 0.9990234375
loss: 0.24176383018493652, avg prob of target: 0.99951171875
loss: 0.2416086196899414, avg prob of target: 0.99951171875
loss: 0.24150395393371582, avg

Once the key and value vectors computed by ROME are applied to the weights, the patched model predicts "Rome":


In [7]:
from rome import RomePatch

# Run the prompt with the ROME patch active. Inference mode is entered first,
# then the patch context; the patch is only in effect inside the inner block.
with torch.inference_mode():
    with patchable_model.patch(
        {
            "model.layers_8.mlp.down_proj.truediv": RomePatch(
                key_vector=key_vector,
                value_vector=value_vector,
                requires_transpose=True,
            )
        }
    ):
        logits = patchable_model(**inputs)[0]
        # Ten highest-scoring candidates at the last position of the prompt.
        top_logits = torch.topk(logits[:, -1, :], 10, sorted=True)
        candidate_ids = top_logits.indices.flatten().tolist()
        candidate_scores = top_logits.values.flatten().tolist()
        top_tokens = []
        for token_id, score in zip(candidate_ids, candidate_scores):
            top_tokens.append([tokenizer.convert_ids_to_tokens(token_id), score])
        print(top_tokens)

# The edited fact should now be the model's top prediction.
assert top_tokens[0][0] == "▁Rome"

[['▁Rome', 22.890625], ['▁the', 17.59375], ['▁V', 17.390625], ['▁Italy', 15.9453125], ['▁Roma', 14.9453125], ['▁a', 14.609375], ['▁Florence', 14.34375], ['▁', 13.734375], ['▁R', 13.5859375], ['▁Roman', 13.5]]


The original model is untouched, and predicts "Paris" as expected:


In [8]:
# Run the same prompt with no patch active to confirm the model itself was not
# modified. Like the patched run, this only reads predictions, so it executes
# under inference mode to avoid recording autograd state for the forward pass.
with torch.inference_mode():
    logits = patchable_model(**inputs)[0]
    # Ten highest-scoring candidates at the last position of the prompt.
    top_logits = torch.topk(logits[:, -1, :], 10, sorted=True)
    top_tokens = [
        [tokenizer.convert_ids_to_tokens(t), v]
        for t, v in zip(
            top_logits.indices.flatten().tolist(),
            top_logits.values.flatten().tolist(),
        )
    ]
    print(top_tokens)

# Without the patch, the original completion should be the top prediction.
assert top_tokens[0][0] == "▁Paris"

[['▁Paris', 20.859375], ['▁the', 19.765625], ['▁Par', 17.0], ['▁France', 16.609375], ['▁Champ', 16.28125], ['▁a', 15.0078125], ['▁one', 14.8671875], ['▁central', 14.7890625], ['▁La', 14.5703125], ['▁par', 14.484375]]
