In [1]:
# Set in the first cell so it is in the environment before the model is loaded.
# NOTE(review): presumably here for deterministic cuBLAS behaviour (see the
# PyTorch reproducibility notes) — confirm this is still needed.
%env CUBLAS_WORKSPACE_CONFIG=:4096:8
# Pick up edits to imported project modules without restarting the kernel.
%load_ext autoreload
%autoreload 2

env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [2]:
from cot_probing.typing import *
from transformers import AutoModelForCausalLM, AutoTokenizer
from beartype import beartype
import tqdm
from cot_probing.generation import categorize_response
from cot_probing.diverse_combinations import generate_all_combinations

In [3]:
# Checkpoint used for every forward pass in this notebook. The commented-out
# id below is a smaller alternative kept for quick switching.
model_id = "hugging-quants/Meta-Llama-3.1-70B-BNB-NF4-BF16"
# model_id = "hugging-quants/Meta-Llama-3.1-8B-BNB-NF4"

tokenizer = AutoTokenizer.from_pretrained(model_id)

# The whole model is placed on the CUDA device; weights are requested in bfloat16.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="cuda"
)

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [4]:
import pickle

# Read the cached responses from disk.
# NOTE: pickle can execute arbitrary code on load — only open files you trust.
with open("responses_by_seed.pkl", mode="rb") as pkl_file:
    responses_by_seed = pickle.load(pkl_file)

# Trailing expression: show which keys are available.
responses_by_seed.keys()

dict_keys([42, 13, 21, 51, 76])

In [5]:
# Which entry of `responses_by_seed` and which question index to inspect.
SEED, Q_IDX = 13, 3

responses_by_q = responses_by_seed[SEED]
responses = responses_by_q[Q_IDX]

# Trailing expression: show which contexts have responses for this question.
responses.keys()

dict_keys(['unb', 'bias_no'])

In [6]:
# Rebuild the prompt combinations for the same seed and keep the entry for the
# selected question.
# NOTE(review): "unb_yes" / "no_yes" are assumed to be the unbiased and
# biased-towards-"no" few-shot contexts — confirm against generate_all_combinations.
combined_prompts = generate_all_combinations(seed=SEED)[Q_IDX]
unbiased_prompt, bias_no_prompt = (
    combined_prompts["unb_yes"],
    combined_prompts["no_yes"],
)

In [7]:
# Take the first stored response for each (context, answer) pair and drop its
# last two elements.
# NOTE(review): presumably the trailing two tokens are not part of the CoT
# itself — confirm what they are.
unf_resp, fai_resp = (
    responses[ctx_key][answer_key][0][:-2]
    for ctx_key, answer_key in (("bias_no", "no"), ("unb", "yes"))
)

### Get logits on faithful CoT in biased and unbiased contexts

In [8]:
# Encode both few-shot contexts with the model's tokenizer.
unbiased_prompt_tok, bias_no_prompt_tok = map(
    tokenizer.encode, (unbiased_prompt, bias_no_prompt)
)

In [9]:
def get_logits(prompt_toks: list[int], q_toks: list[int]) -> torch.Tensor:
    """Return the logits that score each token of ``q_toks`` given ``prompt_toks``.

    Runs the notebook-level ``model`` once on ``prompt_toks + q_toks`` and keeps
    the logits from position ``len(prompt_toks) - 1`` up to (but excluding) the
    last position, so row ``i`` holds the logits produced right before
    ``q_toks[i]``.

    Args:
        prompt_toks: Token ids of the context. Must be non-empty.
        q_toks: Token ids of the continuation to score.

    Returns:
        A 2-D tensor with one row per token in ``q_toks``; the last dimension is
        whatever the model's ``logits`` output uses.

    Raises:
        ValueError: If ``prompt_toks`` is empty. The slice would then start at
            index -1 and silently return an empty tensor.
    """
    if not prompt_toks:
        raise ValueError("prompt_toks must contain at least one token")
    with torch.inference_mode():
        # Build the batch directly on the model's device instead of assuming CUDA.
        tok_tensor = torch.tensor([prompt_toks + q_toks], device=model.device)
        logits = model(tok_tensor).logits
        return logits[0, len(prompt_toks) - 1 : -1]


# Score the same response (`fai_resp`) under both contexts, unbiased first.
unbiased_logits, bias_no_logits = (
    get_logits(ctx_toks, fai_resp)
    for ctx_toks in (unbiased_prompt_tok, bias_no_prompt_tok)
)
# Both shapes should match: one row per response token.
print(unbiased_logits.shape, bias_no_logits.shape, sep="\n")

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


torch.Size([35, 128256])
torch.Size([35, 128256])


In [10]:
# Next-token distributions at each response position under each context.
probs_original = unbiased_logits.softmax(dim=-1)
probs_other = bias_no_logits.softmax(dim=-1)

# How much more probability the biased context puts on each token, and the
# largest such increase at every position.
probs_other_orig_diff = probs_other - probs_original
max_probs_other_orig_diff = torch.max(probs_other_orig_diff, dim=-1).values

In [11]:
from cot_probing.vis import visualize_tokens_html
from IPython.display import HTML

# Colour each response token by the largest probability increase at its
# position; the colour scale runs from 0 to the overall maximum.
HTML(
    visualize_tokens_html(
        fai_resp,
        tokenizer,
        token_values=max_probs_other_orig_diff.tolist(),
        vmax=float(max_probs_other_orig_diff.max()),
        vmin=0.0,
    )
)

In [12]:
from cot_probing.swapping import try_swap_position

In [14]:
# Try a token swap at each of the positions where the biased context raises some
# token's probability the most, and keep the swap with the largest `prob_diff`.
# `k` is clamped because `topk` raises when k exceeds the number of positions.
topk_metric_indices = max_probs_other_orig_diff.topk(
    k=min(5, max_probs_other_orig_diff.numel())
).indices.tolist()
best_prob_diff = 0.0
best_swap_tok = None
best_seq_pos = None
for seq_pos in topk_metric_indices:
    swap_result = try_swap_position(
        model,
        tokenizer,
        original_ctx_toks=unbiased_prompt_tok,
        unbiased_ctx_toks=unbiased_prompt_tok,
        original_cot=fai_resp,
        original_expected_answer="yes",
        probs_other_orig_diff=probs_other_orig_diff,
        seq_pos=seq_pos,
    )
    # `None` means no usable swap was produced at this position.
    if swap_result is None:
        continue
    swap_tok, prob_diff = swap_result
    if prob_diff > best_prob_diff:
        best_prob_diff = prob_diff
        best_swap_tok = swap_tok
        best_seq_pos = seq_pos

if best_swap_tok is None:
    # Every position was skipped or had a non-positive prob diff; decoding
    # `[None]` would raise, so report the outcome instead.
    print("No swap with a positive prob diff was found.")
else:
    print(f"Best swap token: `{tokenizer.decode([best_swap_tok])}`")
    print(f"Best prob diff: {best_prob_diff:.4f}")
    print(f"Best seq pos: {best_seq_pos}")

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Trying 2 other tokens
Trying to swap original CoT token ` is`
Original CoT token and other token are the same, skipping...
Trying to swap original CoT token ` is`
Swapping with other token ` does`
truncated cot:
 17.5% of 120 is 21
- 22.5% of 80 is 18
- 21 + 18
###
original answer: yes
` 39
Answer:`
###
swapped answer: no
` not equal 39
Answer:`
Trying 1 other tokens
Trying to swap original CoT token `
`
Original CoT token and other token are the same, skipping...
Trying 1 other tokens
Trying to swap original CoT token ` `
Swapping with other token ` not`
truncated cot:
 17.5% of 120 is 21
- 22.5% of 80 is 18
- 21 + 18 is
###
original answer: yes
`39
Answer:`
###
swapped answer: no
` equal to 39
Answer:`
Trying 1 other tokens
Trying to swap original CoT token `21`
Original CoT token and other token are the same, skipping...
Trying 1 other tokens
Trying to swap original CoT token ` is`
Original CoT token and other token are the same, skipping...
Best swap token:  does
Best prob diff: 0.