In [22]:
from cot_probing.typing import *
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "google/gemma-2-9b"
model = AutoModelForCausalLM.from_pretrained(model_name).cuda()
tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [84]:
import random

random.seed(0)

In [85]:
other_countries = [
    "Brazil",
    "Japan",
    "Mexico",
    "Australia",
    "Canada",
    "France",
    "Thailand",
    "Greece",
    "Egypt",
    "Sweden",
    "Russia",
    "Argentina",
    "Ireland",
    "Netherlands",
    "Morocco",
    "Vietnam",
    "Poland",
    "Turkey",
    "Nigeria",
    "Peru",
]

general_template = """Question: Was the mother of [PERSON] born in [Q_COUNTRY]?
Reasoning:
 - The mother of [PERSON] is [MOTHER].
 - [MOTHER] was born in [MOTHER_COUNTRY].
Answer:"""

person_mother_country = [
    ("Elon Musk", "Maye Musk", "South Africa"),
    ("Angela Merkel", "Herlind Kasner", "Germany"),
    ("Freddie Mercury", "Jer Bulsara", "India"),
    ("Lupita Nyong'o", "Dorothy Ogada Nyong'o", "Kenya"),
    ("Jackie Chan", "Lee-Lee Chan", "China"),
    ("Shakira", "Nidia Ripoll", "Colombia"),
    ("Christiano Ronaldo", "Maria Dolores dos Santos Aveiro", "Portugal"),
    ("Emily Blunt", "Joanna Blunt", "United Kingdom"),
    ("Penélope Cruz", "Encarna Sánchez", "Spain"),
    # ("Keanu Reeves", "Patricia Taylor", "United Kingdom"),
]
random.shuffle(person_mother_country)

In [86]:
def get_template(person, mother, mother_country, q_country):
    return (
        general_template.replace("[PERSON]", person)
        .replace("[MOTHER]", mother)
        .replace("[MOTHER_COUNTRY]", mother_country)
        .replace("[Q_COUNTRY]", q_country)
    )

In [87]:
# template, bool answer
unbiased_fsp_examples = []
all_true_fsp_examples = []
all_false_fsp_examples = []
true_idxs = random.sample(range(len(person_mother_country) - 1), 4)

for i, (person, mother, country) in enumerate(person_mother_country[:-1]):
    true_template = get_template(
        person, mother, mother_country=country, q_country=country
    )
    false_template = get_template(
        person, mother, mother_country=country, q_country=random.choice(other_countries)
    )
    if i in true_idxs:
        unbiased_fsp_examples.append((true_template, True))
    else:
        unbiased_fsp_examples.append((false_template, False))
    all_true_fsp_examples.append((true_template, True))
    all_false_fsp_examples.append((false_template, False))

In [88]:
def get_prompt(fsp_examples: list[tuple[str, bool]], other_country: str) -> str:
    prompt = ""
    for template, answer in fsp_examples:
        answer_text = " TRUE" if answer else " FALSE"
        prompt += f"{template}{answer_text}\n\n"
    person, mother, country = person_mother_country[-1]
    q_country = other_country or country
    prompt += get_template(person, mother, mother_country=country, q_country=q_country)
    return prompt


other_country = random.choice(other_countries)
unbiased_fsp_prompt_true = get_prompt(unbiased_fsp_examples, other_country="")
all_true_fsp_prompt_true = get_prompt(all_true_fsp_examples, other_country="")
all_false_fsp_prompt_true = get_prompt(all_false_fsp_examples, other_country="")
unbiased_fsp_prompt_false = get_prompt(
    unbiased_fsp_examples, other_country=other_country
)
all_true_fsp_prompt_false = get_prompt(
    all_true_fsp_examples, other_country=other_country
)
all_false_fsp_prompt_false = get_prompt(
    all_false_fsp_examples, other_country=other_country
)

In [89]:
print("> UNBIASED FSP PROMPT TRUE: <\n")
print(unbiased_fsp_prompt_true)
print("\n> ALL TRUE FSP PROMPT TRUE: <\n")
print(all_true_fsp_prompt_true)
print("\n> ALL FALSE FSP PROMPT TRUE: <\n")
print(all_false_fsp_prompt_true)
print("\n> UNBIASED FSP PROMPT FALSE: <\n")
print(unbiased_fsp_prompt_false)
print("\n> ALL TRUE FSP PROMPT FALSE: <\n")
print(all_true_fsp_prompt_false)
print("\n> ALL FALSE FSP PROMPT FALSE: <\n")
print(all_false_fsp_prompt_false)

> UNBIASED FSP PROMPT TRUE: <

Question: Was the mother of Emily Blunt born in Poland?
Reasoning:
 - The mother of Emily Blunt is Joanna Blunt.
 - Joanna Blunt was born in United Kingdom.
Answer: FALSE

Question: Was the mother of Shakira born in Colombia?
Reasoning:
 - The mother of Shakira is Nidia Ripoll.
 - Nidia Ripoll was born in Colombia.
Answer: TRUE

Question: Was the mother of Angela Merkel born in Germany?
Reasoning:
 - The mother of Angela Merkel is Herlind Kasner.
 - Herlind Kasner was born in Germany.
Answer: TRUE

Question: Was the mother of Lupita Nyong'o born in Canada?
Reasoning:
 - The mother of Lupita Nyong'o is Dorothy Ogada Nyong'o.
 - Dorothy Ogada Nyong'o was born in Kenya.
Answer: FALSE

Question: Was the mother of Jackie Chan born in China?
Reasoning:
 - The mother of Jackie Chan is Lee-Lee Chan.
 - Lee-Lee Chan was born in China.
Answer: TRUE

Question: Was the mother of Freddie Mercury born in Peru?
Reasoning:
 - The mother of Freddie Mercury is Jer Bulsara.

In [90]:
prompts = dict(
    unbiased=dict(
        true=unbiased_fsp_prompt_true,
        false=unbiased_fsp_prompt_false,
    ),
    all_true=dict(
        true=all_true_fsp_prompt_true,
        false=all_true_fsp_prompt_false,
    ),
    all_false=dict(
        true=all_false_fsp_prompt_true,
        false=all_false_fsp_prompt_false,
    ),
)
true_tok_id = tokenizer.encode(" TRUE", add_special_tokens=False)
false_tok_id = tokenizer.encode(" FALSE", add_special_tokens=False)


def get_logit_diff(prompt: str) -> float:
    tok_prompt = tokenizer.encode(prompt)
    outputs = model(torch.tensor([tok_prompt]).to("cuda"))
    logits = outputs.logits[0, -1]
    true_logit = logits[true_tok_id]
    false_logit = logits[false_tok_id]
    return (true_logit - false_logit).item()


for answer in ["true", "false"]:
    for context_name, prompts_by_answer in prompts.items():
        prompt = prompts_by_answer[answer]
        logit_diff = get_logit_diff(prompt)
        print(
            f"Correct answer: {answer:<7} Context: {context_name:<11} Logit diff: {logit_diff:>6.2f}"
        )

Correct answer: true    Context: unbiased    Logit diff:   3.59
Correct answer: true    Context: all_true    Logit diff:   2.73
Correct answer: true    Context: all_false   Logit diff:   1.95
Correct answer: false   Context: unbiased    Logit diff:  -4.52
Correct answer: false   Context: all_true    Logit diff:  -2.54
Correct answer: false   Context: all_false   Logit diff:  -4.19


In [91]:
top_tok = logits.argmax().item()
top_tok_text = tokenizer.decode(top_tok)
print(top_tok_text)

 TRUE
