In [52]:
# Automatically re-import edited modules before each cell executes, so changes
# to imported project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [53]:
from cot_probing.typing import *
from functools import partial


def general_patching_hook_fn(
    module, input, output, pos: slice | int, resid: Float[torch.Tensor, " seq model"]
):
    """Forward hook that overwrites part of a module's output with `resid`.

    Meant to be bound with `functools.partial` (fixing `pos` and `resid`) and
    registered through `register_forward_hook`. The output tensor is modified
    in place and nothing is returned, so the module's own (now patched) output
    object is what flows on through the model.

    Args:
        module: The hooked module (unused; required by the hook signature).
        input: The module's positional inputs (unused).
        output: The module's output. Either the hidden-state tensor itself
            (e.g. an embedding module) or a tuple whose first element is the
            hidden-state tensor. A leading batch dimension of size 1 is assumed.
        pos: Index or slice along the sequence axis to overwrite.
        resid: Values written at `pos`.
    """
    # Some modules return a tuple (hidden states first), others return the bare
    # tensor. Unconditionally indexing `[0]` on a bare tensor would strip the
    # batch axis here and then the sequence axis below, silently patching the
    # wrong slice.
    hidden = output[0] if isinstance(output, tuple) else output
    # we're running batch size 1
    hidden = hidden[0]
    # shape is (seq, model); in-place write so the change is visible downstream
    hidden[pos] = resid


def layer_to_hook_point(layer: int):
    """Translate a residual-stream layer index into a module name.

    Layer 0 is the token embedding; layer ``k`` (k >= 1) is the module named
    ``model.layers.{k-1}``.
    """
    return "model.embed_tokens" if layer == 0 else f"model.layers.{layer - 1}"


def hook_point_to_layer(hook_point: str):
    """Translate a module name back into its residual-stream layer index.

    ``model.embed_tokens`` is layer 0; any other name is expected to end in
    ``.<n>`` and maps to layer ``n + 1``.
    """
    if hook_point != "model.embed_tokens":
        # The trailing dotted component is the block number.
        block_index = hook_point.rsplit(".", 1)[-1]
        return int(block_index) + 1
    return 0


def patched_run(
    model: PreTrainedModel,
    input_ids: list[int],
    resid_by_pos_by_layer: dict[
        int, dict[slice | int, Float[torch.Tensor, " _seq model"]]
    ],
):
    """Run `model` on `input_ids` while overwriting selected residuals.

    For every layer key in `resid_by_pos_by_layer`, a forward hook is attached
    to the corresponding module (see `layer_to_hook_point`) for each
    `(pos, resid)` pair, writing `resid` into that module's output at `pos`.
    All hooks are removed before returning, including when hook registration,
    the hook-point check, or the forward pass raises.

    Args:
        model: Model to run; hook points are looked up by module name.
        input_ids: Token ids of a single prompt (a batch axis is added here).
        resid_by_pos_by_layer: layer index -> {position -> replacement values}.

    Returns:
        The logits for the single prompt, with the batch axis dropped.

    Raises:
        AssertionError: If a requested layer has no matching module in `model`.
    """
    hook_points = {layer_to_hook_point(layer) for layer in resid_by_pos_by_layer}
    missing = set(hook_points)
    hooks = []
    # Registration happens inside the try block so that a failure part-way
    # through (or a missing hook point) cannot leave patching hooks attached
    # to the model, where they would silently alter every later forward pass.
    try:
        for name, module in model.named_modules():
            if name not in hook_points:
                continue
            missing.discard(name)
            layer = hook_point_to_layer(name)
            for pos, resid in resid_by_pos_by_layer[layer].items():
                hook_fn = partial(general_patching_hook_fn, pos=pos, resid=resid)
                hooks.append(module.register_forward_hook(hook_fn))
        assert not missing, f"hook points not found in model: {sorted(missing)}"
        # add and then drop batch dim; inputs go to the model's own device
        # rather than assuming CUDA
        batch = torch.tensor([input_ids], device=model.device)
        logits = model(batch).logits[0]
    finally:
        for hook in hooks:
            hook.remove()
    return logits

In [54]:
def general_caching_hook_fn(
    module,
    input,
    output,
    pos: slice | int,
    resid_by_pos: dict[slice | int, Float[torch.Tensor, " _seq model"]],
):
    """Forward hook that stores part of a module's output in `resid_by_pos`.

    Meant to be bound with `functools.partial` (fixing `pos` and
    `resid_by_pos`) and registered through `register_forward_hook`.

    Args:
        module: The hooked module (unused; required by the hook signature).
        input: The module's positional inputs (unused).
        output: The module's output. Either the hidden-state tensor itself
            (e.g. an embedding module) or a tuple whose first element is the
            hidden-state tensor. A leading batch dimension of size 1 is assumed.
        pos: Index or slice along the sequence axis to cache. It is also used
            as the dictionary key, so a slice here requires a Python version
            in which slice objects are hashable.
        resid_by_pos: Destination mapping; `resid_by_pos[pos]` is set to a
            detached CPU copy of the selected activations.
    """
    # Some modules return a tuple (hidden states first), others return the bare
    # tensor. Unconditionally indexing `[0]` on a bare tensor would strip the
    # batch axis here and then the sequence axis below, caching the wrong slice.
    hidden = output[0] if isinstance(output, tuple) else output
    # we're running batch size 1
    hidden = hidden[0]
    # shape is (seq, model). Detach so the cache does not keep the autograd
    # graph alive, and force a copy so the cached values cannot alias the live
    # activation (which `.cpu()` would do when the model already runs on CPU).
    resid_by_pos[pos] = hidden[pos].detach().to("cpu", copy=True)


def clean_run_with_cache(
    model: PreTrainedModel,
    input_ids: list[int],
    pos_by_layer: dict[int, list[slice | int]],
) -> tuple[
    Float[torch.Tensor, " seq vocab"],
    dict[int, dict[slice | int, Float[torch.Tensor, " _seq model"]]],
]:
    """Run `model` on `input_ids` unmodified and cache selected residuals.

    For every layer key in `pos_by_layer`, a caching forward hook is attached
    to the corresponding module (see `layer_to_hook_point`) for each requested
    position. All hooks are removed before returning, including when hook
    registration, the hook-point check, or the forward pass raises.

    Args:
        model: Model to run; hook points are looked up by module name.
        input_ids: Token ids of a single prompt (a batch axis is added here).
        pos_by_layer: layer index -> positions (ints or slices) to cache.

    Returns:
        A pair `(logits, resid_by_pos_by_layer)`: the logits for the single
        prompt with the batch axis dropped, and a mapping
        layer index -> {position -> cached activations}.

    Raises:
        AssertionError: If a requested layer has no matching module in `model`.
    """
    resid_by_pos_by_layer = {}
    hook_points = {layer_to_hook_point(layer) for layer in pos_by_layer}
    missing = set(hook_points)
    hooks = []
    # Registration happens inside the try block so that a failure part-way
    # through (or a missing hook point) cannot leave caching hooks attached
    # to the model.
    try:
        for name, module in model.named_modules():
            if name not in hook_points:
                continue
            missing.discard(name)
            layer = hook_point_to_layer(name)
            assert layer not in resid_by_pos_by_layer
            # One shared dict per layer; each hook fills in its own position.
            resid_by_pos = resid_by_pos_by_layer[layer] = {}
            for pos in pos_by_layer[layer]:
                hook_fn = partial(
                    general_caching_hook_fn, pos=pos, resid_by_pos=resid_by_pos
                )
                hooks.append(module.register_forward_hook(hook_fn))
        assert not missing, f"hook points not found in model: {sorted(missing)}"
        # add and then drop batch dim; inputs go to the model's own device
        # rather than assuming CUDA
        batch = torch.tensor([input_ids], device=model.device)
        logits = model(batch).logits[0]
    finally:
        for hook in hooks:
            hook.remove()
    return logits, resid_by_pos_by_layer

In [55]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Checkpoint used for both the tokenizer and the causal LM.
model_id = "hugging-quants/Meta-Llama-3.1-8B-BNB-NF4-BF16"

# Loading options: bfloat16 compute dtype, reduced CPU memory while loading,
# and all weights placed on the CUDA device.
model_load_kwargs = {
    "torch_dtype": torch.bfloat16,
    "low_cpu_mem_usage": True,
    "device_map": "cuda",
}

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_kwargs)

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [56]:
from cot_probing.vis import visualize_tokens_html

# Two prompts that differ in a single word, used as the two ends of the
# activation-patching experiment.
prompt1 = "Today is a good day, I think I'll"
prompt2 = "Today is a bad day, I think I'll"
input_ids1 = tokenizer.encode(prompt1)
input_ids2 = tokenizer.encode(prompt2)

# Residuals are patched position-by-position between the two runs, which is
# only meaningful when both prompts tokenize to the same number of tokens.
assert len(input_ids1) == len(input_ids2), (
    f"prompts tokenize to different lengths: "
    f"{len(input_ids1)} vs {len(input_ids2)}"
)

display(visualize_tokens_html(input_ids1, tokenizer))
display(visualize_tokens_html(input_ids2, tokenizer))

In [63]:
# Layers whose residuals are cached and later patched; for each one the whole
# sequence (`slice(None)`) is selected.
patch_layers = [0]
pos_by_layer = {layer: [slice(None)] for layer in patch_layers}

In [65]:
# Unpatched forward pass for each prompt, caching the residuals requested in
# `pos_by_layer` along the way.
(logits1, resid_by_pos_by_layer1), (logits2, resid_by_pos_by_layer2) = (
    clean_run_with_cache(model, prompt_ids, pos_by_layer)
    for prompt_ids in (input_ids1, input_ids2)
)
for clean_logits in (logits1, logits2):
    print(clean_logits.shape)

torch.Size([11, 128256])
torch.Size([11, 128256])


In [66]:
# Re-run each prompt while overwriting its residuals with the ones cached from
# the other prompt ("1_to_2" = activations from prompt 1 patched into prompt 2).
logits_patched_1_to_2, logits_patched_2_to_1 = (
    patched_run(model, prompt_ids, source_resids)
    for prompt_ids, source_resids in (
        (input_ids2, resid_by_pos_by_layer1),
        (input_ids1, resid_by_pos_by_layer2),
    )
)
for patched_logits in (logits_patched_1_to_2, logits_patched_2_to_1):
    print(patched_logits.shape)

torch.Size([11, 128256])
torch.Size([11, 128256])


In [67]:
# Sanity check for the patching machinery. `pos_by_layer` selects every
# position at layer 0 (the token embeddings), so a patched run has its entire
# input replaced by the source prompt's embeddings and must therefore reproduce
# the *source* prompt's logits. Comparing against the destination prompt's own
# logits instead would only succeed when the patch had no effect at all.
for patched_logits, source_logits in (
    (logits_patched_1_to_2, logits1),
    (logits_patched_2_to_1, logits2),
):
    for seq in range(patched_logits.shape[0]):
        allclose = torch.allclose(patched_logits[seq], source_logits[seq])
        print(f"seq {seq} allclose: {allclose}")

seq 0 allclose: True
seq 1 allclose: True
seq 2 allclose: True
seq 3 allclose: True
seq 4 allclose: True
seq 5 allclose: True
seq 6 allclose: True
seq 7 allclose: True
seq 8 allclose: True
seq 9 allclose: True
seq 10 allclose: True
seq 0 allclose: True
seq 1 allclose: True
seq 2 allclose: True
seq 3 allclose: True
seq 4 allclose: True
seq 5 allclose: True
seq 6 allclose: True
seq 7 allclose: True
seq 8 allclose: True
seq 9 allclose: True
seq 10 allclose: True


In [68]:
# Five most likely next tokens after the final position of each clean run.
topk_toks_1, topk_toks_2 = (
    final_logits.topk(5).indices.tolist()
    for final_logits in (logits1[-1], logits2[-1])
)
for top_tokens in (topk_toks_1, topk_toks_2):
    display(visualize_tokens_html(top_tokens, tokenizer))

In [69]:
# Five most likely next tokens after the final position of each patched run.
top_toks_patched_1_to_2, top_toks_patched_2_to_1 = (
    final_logits.topk(5).indices.tolist()
    for final_logits in (logits_patched_1_to_2[-1], logits_patched_2_to_1[-1])
)
for top_tokens in (top_toks_patched_1_to_2, top_toks_patched_2_to_1):
    display(visualize_tokens_html(top_tokens, tokenizer))