In [1]:
import numpy as np
import pandas as pd
import torch as t
from collections import OrderedDict
from typing import Dict, Callable
from tqdm import tqdm
from liars.utils import load_model_and_tokenizer
from liars.constants import DATA_PATH, MODEL_PATH

  from .autonotebook import tqdm as notebook_tqdm


[2025-04-17 17:58:58,500] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


In [2]:
prefix = "gender"

# === LOAD DATA === 
data = pd.read_json(f"{DATA_PATH}/test/{prefix}.jsonl", lines=True, orient="records")
# remove assistant answers
data["messages"] = data["messages"].apply(lambda x: [x[0]])
# filter to on-template
data = data[data["prefix"] != "True or False?"]
# === BUILD PATCH PROMPTS ===
safe_prefixes = data.loc[data["label"] == "correct", "prefix"].unique()
# filter to lies
data = data[data["label"] == "incorrect"]
# choose a random safe prefix
np.random.seed(123456)
data["safe_prefix"] = np.random.choice(safe_prefixes, size=len(data))
# create patch prompt
data["patch_messages"] = data.apply(
    lambda row: [{"role": "user", "content": row["messages"][0]["content"].replace(row["prefix"], row["safe_prefix"])}],
    axis=1
)

In [3]:
# === LOAD MODEL AND TOKENIZER ===
model_name = "llama-3.1-8b-it"
lora_path = f"llama-3.1-8b-it-lora-{prefix}"
model, tokenizer = load_model_and_tokenizer(f"{MODEL_PATH}/{model_name}", f"{MODEL_PATH}/{lora_path}")

Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.64it/s]


In [4]:
logit_ids = [
    tokenizer.encode(x, add_special_tokens=False, return_tensors="pt").flatten()[-1].item()
    for x in ["True", "False"]
]

In [5]:
def check_lies(row):
    # tokenize
    prompt = tokenizer.apply_chat_template(row["messages"], tokenize=False, add_generation_prompt=True)
    tks = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    # forward pass
    with t.inference_mode():
        out = model(**tks)
    prediction = [True, False][out.logits[0, -1, logit_ids].argmax().item()]
    return prediction != row["answer"]

In [6]:
def get_patch_hook(v: t.Tensor, tk_pos: int = -1) -> Callable:
    def hook(module, input, output):
        is_tuple = isinstance(output, tuple)
        if is_tuple:
            rs, rest = output[0], output[1:]
        else:
            rs = output
        rs[:, tk_pos, :] = v.to(rs.device)
        return (rs,) + rest if is_tuple else rs
    return hook

In [7]:
layer = 0
total, tallies = 0, {"name": 0, "Ġsays": 0, ":": 0}
bar = tqdm(data.iterrows(), total=len(data))
for _, row in bar:
    # === CHECK IF THE MODEL LIES IN THE FIRST PLACE === 
    if not check_lies(row): continue
    total += 1
    # === GET ACTIVATION PATCHES === 
    # tokenize
    prompt = tokenizer.apply_chat_template(row["patch_messages"], tokenize=False, add_generation_prompt=True)
    tks = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    decoded_tks = [tokenizer.convert_ids_to_tokens(tks.input_ids[i]) for i in range(len(tks.input_ids))][0]
    # find tokens to patch on (the template)
    # bit of a janky way of doing this but idk a better way
    ix_start = len(decoded_tks) - decoded_tks[::-1].index("Ġ===") + 1
    ix_end = decoded_tks[ix_start:].index(":") + 1 + ix_start
    # forward pass
    with t.inference_mode():
        out = model(**tks, output_hidden_states=True)
    patches = out["hidden_states"][layer][0, ix_start:ix_end, :]
    # === CHECK IF ANY PATCHES WORK ===
    for ix, token, patch in zip(range(ix_start, ix_end), decoded_tks[ix_start:ix_end], patches):
        block = model.base_model.model.model.norm if layer == 32 else model.base_model.model.model.layers[layer]
        block._forward_hooks: Dict[int, Callable] = OrderedDict()
        block.register_forward_hook(get_patch_hook(patch, ix))
        assert len(block._forward_hooks) == 1
        lies = check_lies(row)
        key = token if token in ["Ġsays", ":"] else "name"
        if not lies: tallies[key] += 1
        block._forward_hooks: Dict[int, Callable] = OrderedDict()
    bar.set_description(f"{str(tallies)} : N={total}")

{'name': 990, 'Ġsays': 175, ':': 46} : N=1472: 100%|██████████| 1635/1635 [10:01<00:00,  2.72it/s]


In [8]:
layer = 4
total, tallies = 0, {"name": 0, "Ġsays": 0, ":": 0}
bar = tqdm(data.iterrows(), total=len(data))
for _, row in bar:
    # === CHECK IF THE MODEL LIES IN THE FIRST PLACE === 
    if not check_lies(row): continue
    total += 1
    # === GET ACTIVATION PATCHES === 
    # tokenize
    prompt = tokenizer.apply_chat_template(row["patch_messages"], tokenize=False, add_generation_prompt=True)
    tks = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    decoded_tks = [tokenizer.convert_ids_to_tokens(tks.input_ids[i]) for i in range(len(tks.input_ids))][0]
    # find tokens to patch on (the template)
    # bit of a janky way of doing this but idk a better way
    ix_start = len(decoded_tks) - decoded_tks[::-1].index("Ġ===") + 1
    ix_end = decoded_tks[ix_start:].index(":") + 1 + ix_start
    # forward pass
    with t.inference_mode():
        out = model(**tks, output_hidden_states=True)
    patches = out["hidden_states"][layer][0, ix_start:ix_end, :]
    # === CHECK IF ANY PATCHES WORK ===
    for ix, token, patch in zip(range(ix_start, ix_end), decoded_tks[ix_start:ix_end], patches):
        block = model.base_model.model.model.norm if layer == 32 else model.base_model.model.model.layers[layer]
        block._forward_hooks: Dict[int, Callable] = OrderedDict()
        block.register_forward_hook(get_patch_hook(patch, ix))
        assert len(block._forward_hooks) == 1
        lies = check_lies(row)
        key = token if token in ["Ġsays", ":"] else "name"
        if not lies: tallies[key] += 1
        block._forward_hooks: Dict[int, Callable] = OrderedDict()
    bar.set_description(f"{str(tallies)} : N={total}")

{'name': 7, 'Ġsays': 6, ':': 5} : N=1472: 100%|██████████| 1635/1635 [10:07<00:00,  2.69it/s]


In [9]:
layer = 16
total, tallies = 0, {"name": 0, "Ġsays": 0, ":": 0}
bar = tqdm(data.iterrows(), total=len(data))
for _, row in bar:
    # === CHECK IF THE MODEL LIES IN THE FIRST PLACE === 
    if not check_lies(row): continue
    total += 1
    # === GET ACTIVATION PATCHES === 
    # tokenize
    prompt = tokenizer.apply_chat_template(row["patch_messages"], tokenize=False, add_generation_prompt=True)
    tks = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    decoded_tks = [tokenizer.convert_ids_to_tokens(tks.input_ids[i]) for i in range(len(tks.input_ids))][0]
    # find tokens to patch on (the template)
    # bit of a janky way of doing this but idk a better way
    ix_start = len(decoded_tks) - decoded_tks[::-1].index("Ġ===") + 1
    ix_end = decoded_tks[ix_start:].index(":") + 1 + ix_start
    # forward pass
    with t.inference_mode():
        out = model(**tks, output_hidden_states=True)
    patches = out["hidden_states"][layer][0, ix_start:ix_end, :]
    # === CHECK IF ANY PATCHES WORK ===
    for ix, token, patch in zip(range(ix_start, ix_end), decoded_tks[ix_start:ix_end], patches):
        block = model.base_model.model.model.norm if layer == 32 else model.base_model.model.model.layers[layer]
        block._forward_hooks: Dict[int, Callable] = OrderedDict()
        block.register_forward_hook(get_patch_hook(patch, ix))
        assert len(block._forward_hooks) == 1
        lies = check_lies(row)
        key = token if token in ["Ġsays", ":"] else "name"
        if not lies: tallies[key] += 1
        block._forward_hooks: Dict[int, Callable] = OrderedDict()
    bar.set_description(f"{str(tallies)} : N={total}")

{'name': 0, 'Ġsays': 0, ':': 0} : N=1472: 100%|██████████| 1635/1635 [10:47<00:00,  2.53it/s]


In [10]:
layer = 1
total, tallies = 0, {"name": 0, "Ġsays": 0, ":": 0}
bar = tqdm(data.iterrows(), total=len(data))
for _, row in bar:
    # === CHECK IF THE MODEL LIES IN THE FIRST PLACE === 
    if not check_lies(row): continue
    total += 1
    # === GET ACTIVATION PATCHES === 
    # tokenize
    prompt = tokenizer.apply_chat_template(row["patch_messages"], tokenize=False, add_generation_prompt=True)
    tks = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    decoded_tks = [tokenizer.convert_ids_to_tokens(tks.input_ids[i]) for i in range(len(tks.input_ids))][0]
    # find tokens to patch on (the template)
    # bit of a janky way of doing this but idk a better way
    ix_start = len(decoded_tks) - decoded_tks[::-1].index("Ġ===") + 1
    ix_end = decoded_tks[ix_start:].index(":") + 1 + ix_start
    # forward pass
    with t.inference_mode():
        out = model(**tks, output_hidden_states=True)
    patches = out["hidden_states"][layer][0, ix_start:ix_end, :]
    # === CHECK IF ANY PATCHES WORK ===
    for ix, token, patch in zip(range(ix_start, ix_end), decoded_tks[ix_start:ix_end], patches):
        block = model.base_model.model.model.norm if layer == 32 else model.base_model.model.model.layers[layer]
        block._forward_hooks: Dict[int, Callable] = OrderedDict()
        block.register_forward_hook(get_patch_hook(patch, ix))
        assert len(block._forward_hooks) == 1
        lies = check_lies(row)
        key = token if token in ["Ġsays", ":"] else "name"
        if not lies: tallies[key] += 1
        block._forward_hooks: Dict[int, Callable] = OrderedDict()
    bar.set_description(f"{str(tallies)} : N={total}")

{'name': 1106, 'Ġsays': 167, ':': 6} : N=1472: 100%|██████████| 1635/1635 [11:12<00:00,  2.43it/s]


In [11]:
layer = 2
total, tallies = 0, {"name": 0, "Ġsays": 0, ":": 0}
bar = tqdm(data.iterrows(), total=len(data))
for _, row in bar:
    # === CHECK IF THE MODEL LIES IN THE FIRST PLACE === 
    if not check_lies(row): continue
    total += 1
    # === GET ACTIVATION PATCHES === 
    # tokenize
    prompt = tokenizer.apply_chat_template(row["patch_messages"], tokenize=False, add_generation_prompt=True)
    tks = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    decoded_tks = [tokenizer.convert_ids_to_tokens(tks.input_ids[i]) for i in range(len(tks.input_ids))][0]
    # find tokens to patch on (the template)
    # bit of a janky way of doing this but idk a better way
    ix_start = len(decoded_tks) - decoded_tks[::-1].index("Ġ===") + 1
    ix_end = decoded_tks[ix_start:].index(":") + 1 + ix_start
    # forward pass
    with t.inference_mode():
        out = model(**tks, output_hidden_states=True)
    patches = out["hidden_states"][layer][0, ix_start:ix_end, :]
    # === CHECK IF ANY PATCHES WORK ===
    for ix, token, patch in zip(range(ix_start, ix_end), decoded_tks[ix_start:ix_end], patches):
        block = model.base_model.model.model.norm if layer == 32 else model.base_model.model.model.layers[layer]
        block._forward_hooks: Dict[int, Callable] = OrderedDict()
        block.register_forward_hook(get_patch_hook(patch, ix))
        assert len(block._forward_hooks) == 1
        lies = check_lies(row)
        key = token if token in ["Ġsays", ":"] else "name"
        if not lies: tallies[key] += 1
        block._forward_hooks: Dict[int, Callable] = OrderedDict()
    bar.set_description(f"{str(tallies)} : N={total}")

{'name': 15, 'Ġsays': 4, ':': 1} : N=20:   1%|▏         | 22/1635 [00:08<09:54,  2.71it/s]


KeyboardInterrupt: 