In [1]:
import pandas as pd
import torch as t
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from collections import OrderedDict
from typing import Dict, Callable
from tqdm import tqdm
from liars.constants import DATA_PATH, MODEL_PATH

  from .autonotebook import tqdm as notebook_tqdm


[2025-04-21 11:50:41,683] [INFO] [real_accelerator.py:203:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status


In [2]:
prefix = "gender"
model_name = f"{MODEL_PATH}/llama-3.1-8b-it"
lora_path = f"{model_name}-lora-{prefix}"
# load base model
base = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=t.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    use_cache=True
)
base.eval()
# load lora model
model = PeftModel.from_pretrained(base, lora_path)
model.eval()
# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.02it/s]


In [3]:
layer = 0
sa = model.get_submodule(f"model.model.layers.{layer}.mlp.up_proj")
lora_q = sa.lora_A.default.weight.data
lorb_q = sa.lora_B.default.weight.data
alpha = sa.scaling["default"]
r = lora_q.size(0)
dW = (lorb_q @ lora_q) * (alpha / r)
u, s, vT = t.linalg.svd(dW.float(), full_matrices=False)

In [4]:
# === LOAD DATA === 
data = pd.read_json(f"{DATA_PATH}/test/{prefix}.jsonl", lines=True, orient="records")
# remove assistant answers
data["messages"] = data["messages"].apply(lambda x: [x[0]])
# filter to on-template
data = data[data["prefix"] != "True or False?"]
# filter to lies
data = data[data["label"] == "incorrect"]

logit_ids = [
    tokenizer.encode(x, add_special_tokens=False, return_tensors="pt").flatten()[-1].item()
    for x in ["True", "False"]
]

def check_lies(row):
    # tokenize
    prompt = tokenizer.apply_chat_template(row["messages"], tokenize=False, add_generation_prompt=True, padding=False)
    tks = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    # forward pass
    with t.inference_mode():
        out = model(**tks)
    prediction = [True, False][out.logits[0, -1, logit_ids].argmax().item()]
    return prediction != row["answer"]

In [5]:
def ablate_mlp_v1(alpha: float) -> Callable:
    u1 = u[:, 0].to(t.bfloat16)
    v1 = vT[0].to(t.bfloat16)
    sigma1 = s[0].to(t.bfloat16)
    def hook(module, input, output):
        is_tuple = isinstance(output, tuple)
        if is_tuple:
            rs, rest = output[0], output[1:]
        else:
            rs = output
        rs = rs.clone()
        foo = (input[0] @ v1)
        bar = foo[..., None] @ u1[None, ...]
        baz = sigma1 * bar
        rs = rs - alpha * baz
        return (rs,) + rest if is_tuple else rs
    return hook

def ablate_mlp_v2(alpha: float) -> Callable:
    u2 = u[:, 1].to(t.bfloat16)
    v2 = vT[1].to(t.bfloat16)
    sigma2 = s[1].to(t.bfloat16)
    def hook(module, input, output):
        is_tuple = isinstance(output, tuple)
        if is_tuple:
            rs, rest = output[0], output[1:]
        else:
            rs = output
        rs = rs.clone()
        foo = (input[0] @ v2)
        bar = foo[..., None] @ u2[None, ...]
        baz = sigma2 * bar
        rs = rs - alpha * baz
        return (rs,) + rest if is_tuple else rs
    return hook

In [14]:
total, flipped = 0, 0
alpha = 1000000.0

bar = tqdm(data.iterrows(), total=len(data))
for _, row in bar:
    # === CHECK IF THE MODEL LIES IN THE FIRST PLACE === 
    if not check_lies(row): continue
    total += 1
    # === ABLATE MLP ===
    block = model.get_submodule(f"model.model.layers.{layer}.mlp.up_proj")
    block._forward_hooks: Dict[int, Callable] = OrderedDict()
    handle = block.register_forward_hook(ablate_mlp_v1(alpha))
    handle = block.register_forward_hook(ablate_mlp_v2(alpha))
    assert len(block._forward_hooks) == 2
    lies = check_lies(row)
    if not lies: flipped += 1
    handle.remove()
    block._forward_hooks: Dict[int, Callable] = OrderedDict()
    frac = (flipped / total) * 100.
    bar.set_description(f"{frac:.1f}%")

41.5%:   3%|▎         | 56/1635 [00:08<03:53,  6.77it/s]


KeyboardInterrupt: 