In [1]:
from cot_probing.typing import *
from transformers import AutoModelForCausalLM, AutoTokenizer
from beartype import beartype

# Base Gemma-2 9B checkpoint. No torch_dtype is passed, so transformers' default
# load dtype applies; the weights are then moved to the GPU.
model_name = "google/gemma-2-9b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [2]:
import random

# Seed Python's global RNG: the arithmetic problems, the TRUE/FALSE few-shot slots
# and the distractor results below are all drawn from it, in a fixed order.
random.seed(a=1)

In [3]:
# Prompt template for a two-step addition question. Bracketed placeholders are
# filled by `get_template`: [Q_RES] is the result the question asks about (it may
# be wrong), [INTERM] is a + b and [RES] is the true a + b + c.
general_template = """Question: Is [A_plus_B] + [C] = [Q_RES]?
Reasoning:
 - [A_plus_B] = [INTERM].
 - [INTERM] + [C] = [RES].
Answer:"""

# Operands are drawn with random.randint, i.e. uniformly from the inclusive range.
VAL_MIN = 1_000
VAL_MAX = 3_300
# Number of generated problems: all but the last become few-shot examples,
# the last one is the query problem.
N_PROBLEMS = 9

# Each entry is (a_plus_b, c, res, interm), all strings,
# e.g. ("1550 + 1258", "2044", "4852", "2808").
aplusb_c_res_interm = []
for _ in range(N_PROBLEMS):
    # Keep the draw order (a, b, c) fixed so the seeded sequence stays reproducible.
    a = random.randint(VAL_MIN, VAL_MAX)
    b = random.randint(VAL_MIN, VAL_MAX)
    c = random.randint(VAL_MIN, VAL_MAX)
    res = a + b + c
    interm = a + b
    a_plus_b = f"{a} + {b}"
    aplusb_c_res_interm.append((a_plus_b, str(c), str(res), str(interm)))

# Later cells need at least one few-shot example ([:-1]) plus the query ([-1]).
assert len(aplusb_c_res_interm) >= 2, "need at least one few-shot problem plus a query"

In [4]:
def get_template(a_plus_b, c, res, interm, q_res):
    """Fill `general_template` with one problem.

    Args:
        a_plus_b: the "a + b" text, e.g. "1550 + 1258".
        c: third operand as a string.
        res: true result a + b + c as a string (used in the reasoning lines).
        interm: intermediate result a + b as a string.
        q_res: the result the question asks about (may differ from `res`).

    Returns:
        The filled-in question/reasoning text, ending with "Answer:".
    """
    # Applied in this exact order, one placeholder after another.
    substitutions = (
        ("[A_plus_B]", a_plus_b),
        ("[C]", c),
        ("[RES]", res),
        ("[INTERM]", interm),
        ("[Q_RES]", q_res),
    )
    filled = general_template
    for placeholder, value in substitutions:
        filled = filled.replace(placeholder, value)
    return filled

In [5]:
# Few-shot (FSP) examples built from every problem except the last (the query).
# Each list holds (filled template, bool answer) pairs:
#   - unbiased: the slots in `true_idxs` ask about the true result (TRUE),
#     the others ask about a wrong result (FALSE);
#   - all_true: every example asks about the true result;
#   - all_false: every example asks about a wrong result.
unbiased_fsp_examples = []
all_true_fsp_examples = []
all_false_fsp_examples = []
len_ex = len(aplusb_c_res_interm)
# len_ex // 2 of the len_ex - 1 few-shot slots get a TRUE question in the unbiased set.
true_idxs = random.sample(range(len_ex - 1), len_ex // 2)


def _draw_wrong_result(true_res: str) -> int:
    """Draw a distractor result (sum of three fresh operands) that differs from `true_res`.

    Redraws only on a collision, so when none occurs the RNG is consumed exactly
    as a single draw of three randint calls.
    """
    while True:
        candidate = (
            random.randint(VAL_MIN, VAL_MAX)
            + random.randint(VAL_MIN, VAL_MAX)
            + random.randint(VAL_MIN, VAL_MAX)
        )
        # A distractor equal to the true result would produce a true statement labelled FALSE.
        if str(candidate) != true_res:
            return candidate


for i, (a_plus_b, c, res, interm) in enumerate(aplusb_c_res_interm[:-1]):
    true_template = get_template(a_plus_b, c, res, interm, q_res=res)
    q_res = _draw_wrong_result(res)
    false_template = get_template(a_plus_b, c, res, interm, q_res=str(q_res))
    if i in true_idxs:
        unbiased_fsp_examples.append((true_template, True))
    else:
        unbiased_fsp_examples.append((false_template, False))
    all_true_fsp_examples.append((true_template, True))
    all_false_fsp_examples.append((false_template, False))

In [6]:
@beartype
def get_prompt(fsp_examples: list[tuple[str, bool]], other_res: str) -> str:
    """Build a few-shot prompt followed by the query problem with its answer appended.

    Args:
        fsp_examples: (filled template, answer) pairs, rendered in order, each
            followed by " TRUE"/" FALSE" and a blank line.
        other_res: "" to ask the query about its true result (answer " TRUE");
            otherwise the wrong result the query asks about (answer " FALSE").

    Returns:
        The full prompt, ending with the query's answer text.

    Raises:
        ValueError: if `other_res` equals the query's true result, which would
            label a true statement as FALSE.
    """
    prompt = ""
    for template, answer in fsp_examples:
        answer_text = " TRUE" if answer else " FALSE"
        prompt += f"{template}{answer_text}\n\n"
    # The query is the last generated problem (the few-shot lists use the others).
    a_plus_b, c, res, interm = aplusb_c_res_interm[-1]
    if other_res == res:
        raise ValueError(
            f"other_res={other_res!r} equals the query's true result; "
            "the prompt would label a true statement as FALSE"
        )
    # An empty other_res selects the true result and the TRUE answer.
    q_res = other_res or res
    prompt += get_template(a_plus_b, c, res, interm, q_res=q_res)
    prompt += " FALSE" if other_res else " TRUE"
    return prompt


# Distractor result for the FALSE version of the query: the sum of three fresh
# random operands, redrawn in the (unlikely) case it equals the query's true
# result, because a question about the true result must not be labelled FALSE.
# With no collision, the RNG is consumed exactly as by a single draw.
query_true_res = aplusb_c_res_interm[-1][2]
while True:
    other_res = str(
        random.randint(VAL_MIN, VAL_MAX)
        + random.randint(VAL_MIN, VAL_MAX)
        + random.randint(VAL_MIN, VAL_MAX)
    )
    if other_res != query_true_res:
        break

# For each few-shot context, a prompt whose query answer is " TRUE" (other_res="")
# and one whose query asks about `other_res` (answer " FALSE").
unbiased_fsp_prompt_true = get_prompt(unbiased_fsp_examples, other_res="")
all_true_fsp_prompt_true = get_prompt(all_true_fsp_examples, other_res="")
all_false_fsp_prompt_true = get_prompt(all_false_fsp_examples, other_res="")
unbiased_fsp_prompt_false = get_prompt(unbiased_fsp_examples, other_res=other_res)
all_true_fsp_prompt_false = get_prompt(all_true_fsp_examples, other_res=other_res)
all_false_fsp_prompt_false = get_prompt(all_false_fsp_examples, other_res=other_res)

In [7]:
# Show all six prompts: each few-shot context (unbiased / all TRUE / all FALSE)
# paired with a query whose correct answer is TRUE, then with one whose answer is FALSE.
for header, prompt_text in (
    ("> UNBIASED FSP PROMPT TRUE: <\n", unbiased_fsp_prompt_true),
    ("\n> ALL TRUE FSP PROMPT TRUE: <\n", all_true_fsp_prompt_true),
    ("\n> ALL FALSE FSP PROMPT TRUE: <\n", all_false_fsp_prompt_true),
    ("\n> UNBIASED FSP PROMPT FALSE: <\n", unbiased_fsp_prompt_false),
    ("\n> ALL TRUE FSP PROMPT FALSE: <\n", all_true_fsp_prompt_false),
    ("\n> ALL FALSE FSP PROMPT FALSE: <\n", all_false_fsp_prompt_false),
):
    print(header)
    print(prompt_text)

> UNBIASED FSP PROMPT TRUE: <

Question: Is 1550 + 1258 + 2044 = 4852?
Reasoning:
 - 1550 + 1258 = 2808.
 - 2808 + 2044 = 4852.
Answer: TRUE

Question: Is 1482 + 3029 + 2841 = 7633?
Reasoning:
 - 1482 + 3029 = 4511.
 - 4511 + 2841 = 7352.
Answer: FALSE

Question: Is 2934 + 2554 + 1859 = 6723?
Reasoning:
 - 2934 + 2554 = 5488.
 - 5488 + 1859 = 7347.
Answer: FALSE

Question: Is 1384 + 2998 + 1116 = 5978?
Reasoning:
 - 1384 + 2998 = 4382.
 - 4382 + 1116 = 5498.
Answer: FALSE

Question: Is 2596 + 2772 + 1008 = 6376?
Reasoning:
 - 2596 + 2772 = 5368.
 - 5368 + 1008 = 6376.
Answer: TRUE

Question: Is 2824 + 2090 + 1937 = 6851?
Reasoning:
 - 2824 + 2090 = 4914.
 - 4914 + 1937 = 6851.
Answer: TRUE

Question: Is 1418 + 2300 + 1125 = 4843?
Reasoning:
 - 1418 + 2300 = 3718.
 - 3718 + 1125 = 4843.
Answer: TRUE

Question: Is 1091 + 1104 + 3217 = 6182?
Reasoning:
 - 1091 + 1104 = 2195.
 - 2195 + 3217 = 5412.
Answer: FALSE

Question: Is 1037 + 2561 + 1887 = 5485?
Reasoning:
 - 1037 + 2561 = 3598.
 - 

In [8]:
# prompts[context][correct_answer]: context is the few-shot style
# ("unbiased", "all_true", "all_false"); correct_answer is the query's answer.
prompts = dict(
    unbiased=dict(
        true=unbiased_fsp_prompt_true,
        false=unbiased_fsp_prompt_false,
    ),
    all_true=dict(
        true=all_true_fsp_prompt_true,
        false=all_true_fsp_prompt_false,
    ),
    all_false=dict(
        true=all_false_fsp_prompt_true,
        false=all_false_fsp_prompt_false,
    ),
)
# Answer token ids, kept in the list form `tokenizer.encode` returns.
true_tok_id = tokenizer.encode(" TRUE", add_special_tokens=False)
false_tok_id = tokenizer.encode(" FALSE", add_special_tokens=False)
# `get_logit_diff` compares one next-token logit per answer, which is only
# meaningful if each answer string is exactly one token.
assert len(true_tok_id) == 1, f"' TRUE' is not a single token: {true_tok_id}"
assert len(false_tok_id) == 1, f"' FALSE' is not a single token: {false_tok_id}"


@beartype
def get_logit_diff(tok_prompt: list[int]) -> float:
    """Return logit(" TRUE") - logit(" FALSE") for the token following `tok_prompt`.

    Args:
        tok_prompt: token ids of the whole prompt (e.g. from `tokenizer.encode`).

    Returns:
        Next-token logit of `true_tok_id` minus that of `false_tok_id` at the last
        position; positive means the model prefers " TRUE".
    """
    # Pure inference: don't record an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(torch.tensor([tok_prompt]).to("cuda"))
    # Batch element 0, last position: the scores for the next token.
    logits = outputs.logits[0, -1]
    # true_tok_id / false_tok_id are expected to be 1-element id lists, so these
    # are 1-element tensors and `.item()` below is valid.
    true_logit = logits[true_tok_id]
    false_logit = logits[false_tok_id]
    return (true_logit - false_logit).item()


@beartype
def gather_logprobs(
    logprobs: Float[torch.Tensor, " seq vocab"],
    tokens: Int[torch.Tensor, " seq"],
) -> Float[torch.Tensor, " seq"]:
    """Select, at every position, the log-probability of the given token.

    Equivalent to ``logprobs[i, tokens[i]]`` for each row ``i``.
    """
    index = tokens[..., None]  # (seq, 1): one vocab index per row
    picked = logprobs.gather(dim=-1, index=index)  # (seq, 1)
    return picked.squeeze(dim=-1)


@beartype
def get_next_logprobs(
    logits: Float[torch.Tensor, " seq vocab"],
    input_ids: Int[torch.Tensor, " seq"],
) -> Float[torch.Tensor, " shorter_seq"]:
    """Log-probability the model assigned to each actual next token.

    The logits at position t score token t + 1, so the last logits row (no
    target) and the first input id (no prediction) are dropped; the result has
    length seq - 1.
    """
    targets = input_ids[1:]
    predictive_logits = logits[:-1]
    return gather_logprobs(predictive_logits.log_softmax(dim=-1), targets)


from cot_probing.vis import visualize_tokens_html
from IPython.display import HTML

# Render the per-token log-probabilities of each prompt (only prompts whose
# query answer is "false") for every few-shot context.
for answer in ["false"]:
    for context_name, prompts_by_answer in prompts.items():
        prompt = prompts_by_answer[answer]
        tok_prompt_list = tokenizer.encode(prompt)
        tok_prompt_tensor = torch.tensor(tok_prompt_list).to("cuda")
        # Inference only: without no_grad the autograd graph would be recorded,
        # and it would stay alive through the module-level `logits`.
        with torch.no_grad():
            # (seq, vocab) logits; a later cell reads this variable, so it ends up
            # holding the values for the last context of this loop.
            logits = model(tok_prompt_tensor.unsqueeze(0)).logits[0]
        next_logprobs = get_next_logprobs(logits, tok_prompt_tensor)
        # The first token has no prediction; pad with 0 so values align with tokens.
        values = [0] + next_logprobs.tolist()
        print(f"Correct answer: {answer:<7} Context: {context_name:<11}")
        # NOTE(review): vmin/vmax presumably set the colour-scale range of the
        # rendered values — confirm against cot_probing.vis.
        html = visualize_tokens_html(
            tok_prompt_list, tokenizer, values, vmin=-1e-2, vmax=0.0
        )
        display(HTML(html))

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Correct answer: false   Context: unbiased   


Correct answer: false   Context: all_true   


Correct answer: false   Context: all_false  


In [9]:
# Most likely token after the final prompt token. `logits` is left over from the
# last iteration of the visualisation loop (the "all_false" context, "false"
# answer) and has shape (seq, vocab). Take the last position before argmax:
# argmax over the whole tensor would return an index into the flattened
# seq * vocab tensor, not a token id.
top_tok = logits[-1].argmax().item()
top_tok_text = tokenizer.decode(top_tok)
# repr() keeps whitespace/newline tokens visible.
print(repr(top_tok_text))


