In [2]:
from huggingface_hub import login

login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
!pip install -q transformers accelerate


In [40]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-3.2-1B"

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_auth_token=True
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    use_auth_token=True
)

model.eval()
print("Model loaded.")





Model loaded.


# Downloading Directory

In [6]:
!wget https://rome.baulab.info/data/dsets/counterfact.json

--2026-02-18 12:24:16--  https://rome.baulab.info/data/dsets/counterfact.json
Resolving rome.baulab.info (rome.baulab.info)... 35.232.255.106
Connecting to rome.baulab.info (rome.baulab.info)|35.232.255.106|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 45108470 (43M) [application/json]
Saving to: ‘counterfact.json’


2026-02-18 12:24:16 (105 MB/s) - ‘counterfact.json’ saved [45108470/45108470]



In [7]:
import json

with open("counterfact.json", "r") as f:
    counterfact = json.load(f)

print("Total examples:", len(counterfact))


Total examples: 21919


In [8]:
counterfact[0]

{'case_id': 0,
 'pararel_idx': 2796,
 'requested_rewrite': {'prompt': 'The mother tongue of {} is',
  'relation_id': 'P103',
  'target_new': {'str': 'English', 'id': 'Q1860'},
  'target_true': {'str': 'French', 'id': 'Q150'},
  'subject': 'Danielle Darrieux'},
 'paraphrase_prompts': ['Shayna does this and Yossel goes still and dies. Danielle Darrieux, a native',
  'An album was recorded for Capitol Nashville but never released. Danielle Darrieux spoke the language'],
 'neighborhood_prompts': ['The mother tongue of Léon Blum is',
  'The native language of Montesquieu is',
  'François Bayrou, a native',
  'The native language of Raymond Barre is',
  'Michel Rocard is a native speaker of',
  'Jacques Chaban-Delmas is a native speaker of',
  'The native language of François Bayrou is',
  'Maurice Genevoix, speaker of',
  'The mother tongue of François Bayrou is',
  'Melchior de Vogüé, speaker of'],
 'attribute_prompts': ['J.\xa0R.\xa0R. Tolkien is a native speaker of',
  'The mother tongue

In [9]:
example = counterfact[0]

rw = example["requested_rewrite"]

subject = rw["subject"]
prompt = rw["prompt"].format(subject)

true_target = rw["target_true"]["str"]
new_target = rw["target_new"]["str"]

paraphrase_prompts = example["paraphrase_prompts"]
neighborhood_prompts = example["neighborhood_prompts"]
attribute_prompts = example["attribute_prompts"]
generation_prompts = example["generation_prompts"]

print(prompt)




The mother tongue of Danielle Darrieux is


In [10]:
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_ids = inputs["input_ids"][0]

print(input_ids)

tokens = tokenizer.convert_ids_to_tokens(input_ids)
for i, tok in enumerate(tokens):
    print(f"{i}: {tok}")
print()

decoded_tokens = [
    tokenizer.decode([tok_id])
    for tok_id in input_ids
]
print(decoded_tokens)

# Reconstruct cumulative string and find subject span for k value
reconstructed = ""
subject_index = None

for i, tok in enumerate(decoded_tokens):
    reconstructed += tok
    if subject in reconstructed:
        subject_index = i
        break

if subject_index is None:
    raise ValueError("Subject not found in prompt!")

print("Subject index:", subject_index) # capturing the index till where the whole subject in included in the sentence


tensor([128000,    791,   6691,  25466,    315,  72716,  15367,   7379,   2249,
           374], device='cuda:0')
0: <|begin_of_text|>
1: The
2: Ġmother
3: Ġtongue
4: Ġof
5: ĠDanielle
6: ĠDar
7: rie
8: ux
9: Ġis

['<|begin_of_text|>', 'The', ' mother', ' tongue', ' of', ' Danielle', ' Dar', 'rie', 'ux', ' is']
Subject index: 8


**Function to get the log probabilities of the target token**

In [11]:
import torch
import torch.nn.functional as F

def get_sequence_logprob(prompt, target):
    full = prompt + " " + target
    inputs = tokenizer(full, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    log_probs = F.log_softmax(logits[0], dim=-1)

    target_ids = tokenizer(target)["input_ids"][1:]
    input_ids = inputs["input_ids"][0]

    total = 0.0
    start = len(input_ids) - len(target_ids)

    for i in range(len(target_ids)):
        total += log_probs[start + i - 1, target_ids[i]]

    return total.item()


In [12]:
print("=== BEFORE EDIT ===")
print("Rewrite true:", get_sequence_logprob(prompt, true_target))
print("Rewrite new :", get_sequence_logprob(prompt, new_target))


=== BEFORE EDIT ===
Rewrite true: -9.8125
Rewrite new : -12.25


# APPLYING THE ROME EDIT

In [31]:
layer_id = 13
layer = model.model.layers[layer_id]

activation = {}

def hook_fn(module, input, output):
    activation["k"] = input[0].detach()

handle = layer.mlp.down_proj.register_forward_hook(hook_fn) # to get the activations from the layer when the layer is run

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    model(**inputs)

handle.remove()
print(activation) # number vectors that represents the sentence/subject here it represents the Subject 'Danielle Darrieux'


{'k': tensor([[[ 6.0730e-03, -9.3079e-04, -3.5645e-02,  ...,  5.8594e-02,
           1.2131e-03,  2.2217e-02],
         [-1.5991e-02,  2.5146e-02, -1.8799e-02,  ..., -1.5259e-02,
           2.9922e-05,  1.8433e-02],
         [ 8.7402e-02, -2.0264e-02,  1.1816e-01,  ..., -1.8188e-02,
           4.0771e-02,  3.4912e-02],
         ...,
         [ 5.4199e-02,  1.3184e-01,  3.5645e-02,  ..., -6.7383e-02,
           8.5938e-02, -1.5442e-02],
         [-2.2949e-02, -2.9144e-03,  7.7209e-03,  ..., -1.4404e-02,
           1.2268e-02,  2.9907e-02],
         [-1.2085e-02, -2.7222e-02, -6.8359e-02,  ..., -1.8311e-02,
           4.8218e-03, -2.6245e-02]]], device='cuda:1', dtype=torch.bfloat16)}


In [32]:
input_ids = inputs["input_ids"][0]
decoded = [tokenizer.decode([tok]) for tok in input_ids]

reconstructed = ""
subject_index = None

for i, tok in enumerate(decoded):
    reconstructed += tok
    if subject in reconstructed:
        subject_index = i
        break

print("Subject index:", subject_index)


Subject index: 8


In [33]:
k = activation["k"][0, subject_index].float()



In [34]:
print(k.shape) # k -> internal MLP representation (vector embedddings) of Subject token at layer = layer_index


torch.Size([8192])


In [35]:
target_ids = tokenizer(new_target, return_tensors="pt").to(model.device)

with torch.no_grad():
    v = model.model.embed_tokens(target_ids["input_ids"]).mean(dim=1)[0].float()


In [36]:
layer_device = layer.mlp.down_proj.weight.device

# Move tensors explicitly
k = k.to(layer_device)
v = v.to(layer_device)

W = layer.mlp.down_proj.weight.data
device = W.device

# Convert everything to float32 for more precise calculations
W_float = W.float()
k_float = k.to(device).float()
v_float = v.to(device).float()

# Ensure correct shapes
assert k_float.dim() == 1
assert v_float.dim() == 1

# Compute Wk
Wk = W_float @ k_float        # (2048,), vector embeddding for the target 

# Compute correction vector
correction = v_float - Wk     # (2048,) v_float is vector for new_target and Wk for true_target 

# Outer product
delta_W = torch.ger(correction, k_float)  # (2048, 8192) a little change in weights for the subject token 

# Normalize
delta_W /= (k_float @ k_float) # normalising

# Apply
layer.mlp.down_proj.weight.data += delta_W.to(W.dtype)

print("ROME edit applied.")





ROME edit applied.


In [37]:
print("=== REWRITE CHECK ===")
print("True:", get_sequence_logprob(prompt, true_target))
print("New :", get_sequence_logprob(prompt, new_target))


=== REWRITE CHECK ===
True: -13.5
New : -11.1875


In [38]:
print("\n=== PARAPHRASE CHECK ===")

for p in paraphrase_prompts:
    print(p)
    print("New:", get_sequence_logprob(p, new_target))
    print("True:", get_sequence_logprob(p, true_target))
    print()



=== PARAPHRASE CHECK ===
Shayna does this and Yossel goes still and dies. Danielle Darrieux, a native
New: -13.5625
True: -12.5

An album was recorded for Capitol Nashville but never released. Danielle Darrieux spoke the language
New: -14.1875
True: -17.0



In [30]:
print("\n=== NEIGHBORHOOD CHECK ===")

for p in neighborhood_prompts[:5]:
    print(p)
    # print(_target)
    print("True:", get_sequence_logprob(p, true_target))
    print("New:", get_sequence_logprob(p, new_target))
    print()



=== NEIGHBORHOOD CHECK ===
The mother tongue of Léon Blum is
True: -13.125
New: -12.9375

The native language of Montesquieu is
True: -12.0625
New: -12.375

François Bayrou, a native
True: -16.125
New: -17.0

The native language of Raymond Barre is
True: -11.875
New: -10.9375

Michel Rocard is a native speaker of
True: -11.375
New: -10.6875



In [41]:
import copy

original_state = copy.deepcopy(model.state_dict())


In [107]:
def apply_rome_edit(model, tokenizer, example, layer_id=best_layer):

    rw = example["requested_rewrite"]
    subject = rw["subject"]
    prompt = rw["prompt"].format(subject)
    new_target = rw["target_new"]["str"]

    layer = model.model.layers[layer_id]
    activation = {}

    def hook_fn(module, input, output):
        activation["k"] = input[0].detach()

    handle = layer.mlp.down_proj.register_forward_hook(hook_fn)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        model(**inputs)

    handle.remove()

    # ---- find subject index ----
    input_ids = inputs["input_ids"][0]
    decoded = [tokenizer.decode([tok]) for tok in input_ids]

    reconstructed = ""
    subject_index = None

    for i, tok in enumerate(decoded):
        reconstructed += tok
        if subject in reconstructed:
            subject_index = i
            break

    k = activation["k"][0, subject_index]

    # ---- compute v ----
    target_ids = tokenizer(new_target, return_tensors="pt").to(model.device)
    with torch.no_grad():
        v = model.model.embed_tokens(target_ids["input_ids"]).mean(dim=1)[0]

    # ---- apply rank-one update ----
    W = layer.mlp.down_proj.weight.data
    device = W.device

    W_float = W.float()
    k_float = k.to(device).float()
    v_float = v.to(device).float()

    Wk = W_float @ k_float
    correction = v_float - Wk
    delta_W = torch.ger(correction, k_float)
    delta_W /= (k_float @ k_float)

    alpha = 5.0
    layer.mlp.down_proj.weight.data += alpha * delta_W.to(W.dtype)



In [108]:
def find_best_layer(model, tokenizer, example, candidate_layers):

    best_layer = None
    best_margin = -1e9

    original_state = copy.deepcopy(model.state_dict())

    rw = example["requested_rewrite"]
    subject = rw["subject"]
    prompt = rw["prompt"].format(subject)
    true_target = rw["target_true"]["str"]
    new_target = rw["target_new"]["str"]

    for layer_id in candidate_layers:

        model.load_state_dict(original_state)

        apply_rome_edit(model, tokenizer, example, layer_id=layer_id)

        true_score = get_sequence_logprob(prompt, true_target)
        new_score = get_sequence_logprob(prompt, new_target)

        margin = new_score - true_score

        if margin > best_margin:
            best_margin = margin
            best_layer = layer_id

    return best_layer


In [109]:
def rewrite_success(model, tokenizer, example):

    rw = example["requested_rewrite"]

    subject = rw["subject"]
    prompt = rw["prompt"].format(subject)
    true_target = rw["target_true"]["str"]
    new_target = rw["target_new"]["str"]

    true_score = get_sequence_logprob(prompt, true_target)
    new_score = get_sequence_logprob(prompt, new_target)

    return new_score > true_score


In [110]:
def paraphrase_success(model, tokenizer, example):

    rw = example["requested_rewrite"]
    new_target = rw["target_new"]["str"]
    true_target = rw["target_true"]["str"]

    success = 0
    total = 0

    for p in example["paraphrase_prompts"]:
        new_score = get_sequence_logprob(p, new_target)
        true_score = get_sequence_logprob(p, true_target)

        if new_score > true_score:
            success += 1
        total += 1

    return success / total


In [111]:
def neighborhood_preservation(model, tokenizer, example):

    rw = example["requested_rewrite"]
    true_target = rw["target_true"]["str"]

    success = 0
    total = 0

    for p in example["neighborhood_prompts"][:5]:
        new_score = get_sequence_logprob(p, new_target)
        true_score = get_sequence_logprob(p, true_target)

        # Ideally compare with original known correct answer
        if new_score > true_score:
            success += 1  # placeholder unless you compute proper comparison
        total += 1

    return success / total



In [112]:
N = 20
candidate_layers = list(range(6, 16))

rewrite_scores = []
paraphrase_scores = []
neighborhood_scores = []

for i, example in enumerate(counterfact[:N]):

    print(f"\nCase {i}")

    # Reset model
    model.load_state_dict(original_state)

    # Find best layer
    best_layer = find_best_layer(model, tokenizer, example, candidate_layers)

    # Reset again
    model.load_state_dict(original_state)

    # Apply final edit
    apply_rome_edit(model, tokenizer, example, layer_id=best_layer)

    # Evaluate
    rewrite = rewrite_success(model, tokenizer, example)
    paraphrase = paraphrase_success(model, tokenizer, example)
    neighborhood = neighborhood_preservation(model, tokenizer, example)

    rewrite_scores.append(rewrite)
    paraphrase_scores.append(paraphrase)
    neighborhood_scores.append(neighborhood)

    print("Rewrite:", rewrite)
    print("Paraphrase:", paraphrase)
    print("Neighborhood:", neighborhood)

print("\n===== FINAL RESULTS =====")
print("Rewrite accuracy:", sum(rewrite_scores)/N)
print("Paraphrase accuracy:", sum(paraphrase_scores)/N)
print("Neighborhood preservation:", sum(neighborhood_scores)/N)




Case 0
Rewrite: True
Paraphrase: 1.0
Neighborhood: 1.0

Case 1
Rewrite: True
Paraphrase: 1.0
Neighborhood: 1.0

Case 2
Rewrite: True
Paraphrase: 1.0
Neighborhood: 1.0

Case 3
Rewrite: True
Paraphrase: 1.0
Neighborhood: 0.0

Case 4
Rewrite: True
Paraphrase: 1.0
Neighborhood: 1.0

Case 5
Rewrite: True
Paraphrase: 1.0
Neighborhood: 1.0

Case 6
Rewrite: True
Paraphrase: 1.0
Neighborhood: 1.0

Case 7
Rewrite: True
Paraphrase: 0.5
Neighborhood: 0.0

Case 8
Rewrite: True
Paraphrase: 1.0
Neighborhood: 1.0

Case 9
Rewrite: True
Paraphrase: 1.0
Neighborhood: 0.0

Case 10
Rewrite: True
Paraphrase: 0.0
Neighborhood: 0.0

Case 11
Rewrite: True
Paraphrase: 1.0
Neighborhood: 1.0

Case 12
Rewrite: True
Paraphrase: 0.0
Neighborhood: 0.0

Case 13
Rewrite: True
Paraphrase: 1.0
Neighborhood: 1.0

Case 14
Rewrite: True
Paraphrase: 1.0
Neighborhood: 1.0

Case 15
Rewrite: True
Paraphrase: 1.0
Neighborhood: 1.0

Case 16
Rewrite: False
Paraphrase: 0.5
Neighborhood: 1.0

Case 17
Rewrite: True
Paraphrase: 1.0
N