In [1]:
# Automatically reload edited project modules (e.g. `rome`) before each cell runs,
# so changes to helper code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM, BitsAndBytesConfig

from rome import standardize_tokenizer

# Location of the Hugging Face-format LLaMA-7B checkpoint. Set LLAMA_7B_PATH to
# point elsewhere instead of editing the notebook; the default is unchanged.
model_path = os.environ.get("LLAMA_7B_PATH", "/models/llama-7b-hf")
tokenizer = LlamaTokenizer.from_pretrained(model_path)
# Project helper from `rome`; its return value is unused, so it presumably
# configures the tokenizer in place — TODO confirm against rome.standardize_tokenizer.
standardize_tokenizer(tokenizer)
model = LlamaForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    # Load the weights quantized to 8-bit via bitsandbytes; non-quantized tensors use fp16.
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
    # NOTE(review): presumably the plain (non-SDPA) attention path is required so
    # graphpatch can extract the graph — confirm before changing.
    attn_implementation="eager",
)
# Put the model in eval mode (disables dropout). The trailing semicolon suppresses
# the very long module repr in the cell output.
model.eval();

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [3]:
from graphpatch import PatchableGraph, ExtractionOptions

# Tokenize the probe prompt. The ids are moved to the device of the model's first
# parameters rather than a hard-coded "cuda:0", because device_map="auto" decides
# placement at load time.
inputs = tokenizer("The Eiffel Tower, located in", return_tensors="pt", padding=False).to(
    model.device
)
# Wrap the model in a graphpatch PatchableGraph. `inputs.input_ids` (plus
# use_cache=False) is passed as the example input for graph extraction.
patchable_model = PatchableGraph(
    model,
    ExtractionOptions(error_on_compilation_failure=True, allow_unused_submodules=True),
    inputs.input_ids,
    use_cache=False,
)

[?1034h

In [4]:
from rome import generate_key_vector, generate_value_vector

# The factual edit being requested: "Eiffel Tower" + "is located in" -> "Rome".
# Both ROME steps must see exactly the same request, so it is defined once here.
EDIT_SUBJECT = "Eiffel Tower"
EDIT_PROMPT = "is located in"
EDIT_TARGET = "Rome"

# graphpatch node names inside layer 8's MLP. The key is read at the `mul` node;
# the value is optimized at the output of `down_proj`.
KEY_NODE = "model.layers_8.mlp.mul"
VALUE_NODE = "model.layers_8.mlp.down_proj.output"

key_vector = generate_key_vector(
    patchable_model,
    tokenizer,
    KEY_NODE,
    EDIT_SUBJECT,
    EDIT_PROMPT,
    EDIT_TARGET,
)
# Iteratively fits the value vector for the given key; with log_progress=True it
# prints the loss and the probability of EDIT_TARGET at each step.
value_vector = generate_value_vector(
    patchable_model,
    tokenizer,
    VALUE_NODE,
    KEY_NODE,
    EDIT_SUBJECT,
    EDIT_PROMPT,
    EDIT_TARGET,
    key_vector,
    log_progress=True,
)

loss: 10.542834281921387, prob of target: 3.1384257454192266e-05
loss: 6.3425374031066895, prob of target: 0.006184753030538559
loss: 2.8206589221954346, prob of target: 0.08636479824781418
loss: 1.5301038026809692, prob of target: 0.2812803387641907
loss: 0.9027271270751953, prob of target: 0.5277066230773926
loss: 0.44151097536087036, prob of target: 0.828917920589447
loss: 0.2973652184009552, prob of target: 0.953521728515625
loss: 0.26265785098075867, prob of target: 0.986750602722168
loss: 0.2560126483440399, prob of target: 0.9933037757873535
loss: 0.25367501378059387, prob of target: 0.9956206679344177
loss: 0.25195127725601196, prob of target: 0.9973296523094177
loss: 0.2513069212436676, prob of target: 0.9979697465896606
loss: 0.25056350231170654, prob of target: 0.998710572719574
loss: 0.25018081068992615, prob of target: 0.9990921020507812
loss: 0.2499666064977646, prob of target: 0.9993058443069458
loss: 0.24970363080501556, prob of target: 0.9995684027671814
loss: 0.249592

With the key and value vectors computed by ROME applied to the layer-8 MLP `down_proj` weight, the patched model now predicts "Rome" as the top next token:


In [5]:
from rome import RomePatch

# Temporarily apply the ROME edit to layer 8's down_proj weight, run the probe
# prompt once, and list the 10 highest-scoring next tokens with their logits.
with torch.inference_mode():
    rome_edit = {
        "model.layers_8.mlp.down_proj.weight": RomePatch(
            key_vector=key_vector,
            value_vector=value_vector,
            requires_transpose=True,
        )
    }
    with patchable_model.patch(rome_edit):
        logits = patchable_model(**inputs)[0]
        top_logits = torch.topk(logits[:, -1, :], 10, sorted=True)
        candidate_ids = top_logits.indices.flatten().tolist()
        candidate_scores = top_logits.values.flatten().tolist()
        top_tokens = [
            [tokenizer.convert_ids_to_tokens(token_id), score]
            for token_id, score in zip(candidate_ids, candidate_scores)
        ]
        print(top_tokens)
# The edited fact must now be the single most likely continuation.
assert top_tokens[0][0] == "▁Rome"

[['▁Rome', 21.171875], ['▁the', 17.078125], ['▁V', 16.828125], ['▁Italy', 15.1171875], ['▁Roma', 15.0390625], ['▁a', 14.2890625], ['▁Roman', 14.265625], ['▁Rom', 13.765625], ['▁Florence', 13.5234375], ['▁central', 13.5]]


Outside the `patch` context the original weights are untouched, so the model still predicts "Paris" as expected:


In [6]:
# Same probe as the patched cell, but outside any `patch` context, so the original
# weights are used. Run under inference_mode (as the patched cell does) so this
# forward pass does not record an autograd graph; the logits themselves are unchanged.
with torch.inference_mode():
    logits = patchable_model(**inputs)[0]
    # 10 highest-scoring next-token candidates at the last prompt position.
    top_logits = torch.topk(logits[:, -1, :], 10, sorted=True)
    top_tokens = [
        [tokenizer.convert_ids_to_tokens(t), v]
        for t, v in zip(
            top_logits.indices.flatten().tolist(),
            top_logits.values.flatten().tolist(),
        )
    ]
print(top_tokens)
# The unedited model must still answer with the true fact.
assert top_tokens[0][0] == "▁Paris"

[['▁Paris', 20.765625], ['▁the', 19.8125], ['▁Par', 16.859375], ['▁France', 16.6875], ['▁Champ', 16.359375], ['▁one', 15.0078125], ['▁a', 14.9609375], ['▁central', 14.734375], ['▁La', 14.5859375], ['▁par', 14.421875]]
