In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from cot_probing.swapping import process_successful_swaps
from cot_probing import DATA_DIR
from cot_probing.typing import *
from transformers import AutoTokenizer


responses_path = DATA_DIR / "responses_by_seed_8B.pkl"
swap_results_path = DATA_DIR / "swap_results_by_q_seed_i_0_8B_pos5_tok3_p5.pkl"
model_id = "hugging-quants/Meta-Llama-3.1-8B-BNB-NF4-BF16"
tokenizer = AutoTokenizer.from_pretrained(model_id)

successful_swaps_by_q = process_successful_swaps(
    responses_path=responses_path,
    swap_results_path=swap_results_path,
    tokenizer=tokenizer,
)

In [3]:
for i, successful_swaps in enumerate(successful_swaps_by_q):
    print(f"q_idx: {i}, len(successful_swaps): {len(successful_swaps)}")
    n_unfai_to_fai = sum(
        1 for swap in successful_swaps if swap.swap_dir == "unfai_to_fai"
    )
    n_fai_to_unfai = sum(
        1 for swap in successful_swaps if swap.swap_dir == "fai_to_unfai"
    )
    print(f"n_unfai_to_fai: {n_unfai_to_fai}, n_fai_to_unfai: {n_fai_to_unfai}")

q_idx: 0, len(successful_swaps): 0
n_unfai_to_fai: 0, n_fai_to_unfai: 0
q_idx: 1, len(successful_swaps): 1
n_unfai_to_fai: 1, n_fai_to_unfai: 0
q_idx: 2, len(successful_swaps): 3
n_unfai_to_fai: 1, n_fai_to_unfai: 2
q_idx: 3, len(successful_swaps): 4
n_unfai_to_fai: 2, n_fai_to_unfai: 2
q_idx: 4, len(successful_swaps): 4
n_unfai_to_fai: 2, n_fai_to_unfai: 2
q_idx: 5, len(successful_swaps): 9
n_unfai_to_fai: 2, n_fai_to_unfai: 7
q_idx: 6, len(successful_swaps): 4
n_unfai_to_fai: 1, n_fai_to_unfai: 3
q_idx: 7, len(successful_swaps): 11
n_unfai_to_fai: 5, n_fai_to_unfai: 6
q_idx: 8, len(successful_swaps): 1
n_unfai_to_fai: 0, n_fai_to_unfai: 1
q_idx: 9, len(successful_swaps): 3
n_unfai_to_fai: 0, n_fai_to_unfai: 3
q_idx: 10, len(successful_swaps): 4
n_unfai_to_fai: 2, n_fai_to_unfai: 2
q_idx: 11, len(successful_swaps): 1
n_unfai_to_fai: 0, n_fai_to_unfai: 1
q_idx: 12, len(successful_swaps): 7
n_unfai_to_fai: 5, n_fai_to_unfai: 2
q_idx: 13, len(successful_swaps): 3
n_unfai_to_fai: 1, n_fai

In [25]:
successful_swaps = successful_swaps_by_q[5]

In [6]:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map="cuda",
)

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [28]:
for i, swap in enumerate(successful_swaps):
    print(f"swap_idx: {i}")
    unb_fai_prob, unb_unfai_prob = swap.get_unbiased_probs(model)
    biased_fai_prob, biased_unfai_prob = swap.get_biased_probs(model)
    print(f"{swap.prob_diff}")
    if swap.swap_dir == "unfai_to_fai":
        diff = unb_fai_prob - biased_fai_prob
        print(diff)
    elif swap.swap_dir == "fai_to_unfai":
        diff = biased_unfai_prob - unb_unfai_prob
        print(diff)
    unfai_tok_str = tokenizer.decode(swap.unfai_tok).replace("\n", "\\n")
    fai_tok_str = tokenizer.decode(swap.fai_tok).replace("\n", "\\n")
    if swap.swap_dir == "unfai_to_fai":
        print(f"`{unfai_tok_str}` -> `{fai_tok_str}`")
    elif swap.swap_dir == "fai_to_unfai":
        print(f"`{fai_tok_str}` -> `{unfai_tok_str}`")
    print()

swap_idx: 0
0.06857423484325409
0.11572489142417908
`48` -> `42`

swap_idx: 1
0.057396844029426575
0.05740748345851898
`-` -> `Answer`

swap_idx: 2
0.07533315569162369
0.07534375786781311
`-` -> `Answer`

swap_idx: 3
0.08208721876144409
0.09461760520935059
`-` -> `Answer`

swap_idx: 4
0.12010908126831055
0.09656217694282532
`-` -> `Answer`

swap_idx: 5
0.06219395995140076
0.06219128519296646
`48` -> `47`

swap_idx: 6
0.05342337489128113
0.07424463331699371
`-` -> `Answer`

swap_idx: 7
0.06261569261550903
0.05181640386581421
` in` -> ` at`

swap_idx: 8
0.0950966477394104
0.09450507164001465
`42` -> `48`

