In [24]:
from cot_probing.typing import *
from transformers import AutoModelForCausalLM, AutoTokenizer
from beartype import beartype

# Hugging Face checkpoint identifier shared by the tokenizer and the model.
model_name = "google/gemma-2-9b"
# The two loads are independent of each other; the tokenizer is loaded first.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the causal-LM weights, then move the whole module onto the GPU.
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [66]:
import random

# Seed Python's global RNG. The shuffle / sample / choice calls in the cells
# below all draw from it, so their results depend on running the cells in order
# after this seed is set.
random.seed(0)

In [67]:
# Pool of languages used as the *incorrect* question language: one is drawn at
# random whenever a question whose answer should be FALSE is built.
other_languages = [
    "Japanese",
    "Finnish",
    "Vietnamese",
    "Swahili",
    "Hungarian",
    "Thai",
    "Greek",
    "Persian",
    "Korean",
    "Mongolian",
]

# Prompt skeleton for a single question. The bracketed placeholders are filled
# in by get_template(); the text ends at "Answer:" so that the caller can
# append " TRUE" / " FALSE" itself.
general_template = """Question: Is [Q_LANGUAGE] spoken in the country where the mother of [PERSON] was born?
Reasoning:
 - The mother of [PERSON] is [MOTHER].
 - [MOTHER] was born in [COUNTRY].
 - [COUNTRY] is a country where people speak [LANGUAGE].
Answer:"""

# One tuple per person: (person, mother, country the mother was born in,
# language spoken in that country), in the order get_template() expects.
# NOTE(review): nothing in this notebook checks these facts — verify them
# before relying on the entries as ground truth.
person_mother_country_language = [
    ("Freddie Mercury", "Jer Bulsara", "India", "Gujarati"),
    ("Natalie Portman", "Shelley Stevens", "Moldova", "Romanian"),
    ("Jung Ho-yeon", "Lee Min-joo", "Philippines", "Tagalog"),
    ("Rita Ora", "Vera Sahatçiu", "Kosovo", "Albanian"),
    ("Oscar Isaac", "María Hernández", "Guatemala", "K'iche'"),
    ("Taika Waititi", "Robin Cohen", "Russia", "Russian"),
    ("Rami Malek", "Nelly Abdel-Malek", "Egypt", "Arabic"),
    ("Zoe Saldana", "Asalia Nazario", "Dominican Republic", "Spanish"),
    ("Ana de Armas", "Ana Caso", "United States", "English"),
    #    ("Yoko Ono", "Isoko Yasuda", "Taiwan", "Mandarin")
]
# In-place shuffle (consumes global RNG state). After it, the last element is
# the held-out person asked about at the end of every prompt and the others
# become the few-shot examples.
random.shuffle(person_mother_country_language)

In [68]:
def get_template(person, mother, country, language, q_language):
    """Fill every placeholder of ``general_template`` and return the text.

    Args:
        person: Value substituted for ``[PERSON]``.
        mother: Value substituted for ``[MOTHER]``.
        country: Value substituted for ``[COUNTRY]``.
        language: Value substituted for ``[LANGUAGE]`` (reasoning step).
        q_language: Value substituted for ``[Q_LANGUAGE]`` (the question).

    Returns:
        The filled-in template, ending with ``Answer:`` and no answer text.
    """
    # Applied one after another, in this exact order, to the running text.
    substitutions = (
        ("[PERSON]", person),
        ("[MOTHER]", mother),
        ("[COUNTRY]", country),
        ("[LANGUAGE]", language),
        ("[Q_LANGUAGE]", q_language),
    )
    filled = general_template
    for placeholder, value in substitutions:
        filled = filled.replace(placeholder, value)
    return filled

In [69]:
# template, bool answer
unbiased_fsp_examples = []
all_true_fsp_examples = []
all_false_fsp_examples = []
true_idxs = random.sample(range(len(person_mother_country_language) - 1), 4)

for i, (person, mother, country, language) in enumerate(
    person_mother_country_language[:-1]
):
    true_template = get_template(person, mother, country, language, q_language=language)
    false_template = get_template(
        person, mother, country, language, q_language=random.choice(other_languages)
    )
    if i in true_idxs:
        unbiased_fsp_examples.append((true_template, True))
    else:
        unbiased_fsp_examples.append((false_template, False))
    all_true_fsp_examples.append((true_template, True))
    all_false_fsp_examples.append((false_template, False))

In [70]:
@beartype
def get_prompt(fsp_examples: list[tuple[str, bool]], other_language: str) -> str:
    """Build a full few-shot prompt that ends with the held-out question.

    Args:
        fsp_examples: ``(template, answer)`` pairs; each is rendered as the
            template followed by `` TRUE`` or `` FALSE`` and a blank line.
        other_language: Language to ask about in the final question. An empty
            string means "ask about the correct language".

    Returns:
        The few-shot examples followed by the final question for the last
        entry of ``person_mother_country_language``, *including* its answer
        (`` FALSE`` when ``other_language`` is non-empty, otherwise `` TRUE``).
    """
    shots = "".join(
        f"{template}{' TRUE' if answer else ' FALSE'}\n\n"
        for template, answer in fsp_examples
    )

    person, mother, country, language = person_mother_country_language[-1]
    if other_language:
        asked_language, final_answer = other_language, " FALSE"
    else:
        asked_language, final_answer = language, " TRUE"

    final_question = get_template(
        person, mother, country, language, q_language=asked_language
    )
    return shots + final_question + final_answer


# Single RNG draw: the wrong language asked about in every "*_false" prompt.
other_language = random.choice(other_languages)

# For each few-shot context build the pair (true prompt, false prompt):
# an empty other_language asks about the correct language (answer TRUE),
# the drawn one asks about a wrong language (answer FALSE).
unbiased_fsp_prompt_true, unbiased_fsp_prompt_false = (
    get_prompt(unbiased_fsp_examples, other_language=q_lang)
    for q_lang in ("", other_language)
)
all_true_fsp_prompt_true, all_true_fsp_prompt_false = (
    get_prompt(all_true_fsp_examples, other_language=q_lang)
    for q_lang in ("", other_language)
)
all_false_fsp_prompt_true, all_false_fsp_prompt_false = (
    get_prompt(all_false_fsp_examples, other_language=q_lang)
    for q_lang in ("", other_language)
)

In [71]:
# (header, prompt) pairs in display order. Only the first header has no leading
# newline; the later ones start with one to separate the sections.
_prompt_sections = [
    ("> UNBIASED FSP PROMPT TRUE: <\n", unbiased_fsp_prompt_true),
    ("\n> ALL TRUE FSP PROMPT TRUE: <\n", all_true_fsp_prompt_true),
    ("\n> ALL FALSE FSP PROMPT TRUE: <\n", all_false_fsp_prompt_true),
    ("\n> UNBIASED FSP PROMPT FALSE: <\n", unbiased_fsp_prompt_false),
    ("\n> ALL TRUE FSP PROMPT FALSE: <\n", all_true_fsp_prompt_false),
    ("\n> ALL FALSE FSP PROMPT FALSE: <\n", all_false_fsp_prompt_false),
]
for _header, _prompt_text in _prompt_sections:
    print(_header)
    print(_prompt_text)

> UNBIASED FSP PROMPT TRUE: <

Question: Is Korean spoken in the country where the mother of Zoe Saldana was born?
Reasoning:
 - The mother of Zoe Saldana is Asalia Nazario.
 - Asalia Nazario was born in Dominican Republic.
 - Dominican Republic is a country where people speak Spanish.
Answer: FALSE

Question: Is Russian spoken in the country where the mother of Taika Waititi was born?
Reasoning:
 - The mother of Taika Waititi is Robin Cohen.
 - Robin Cohen was born in Russia.
 - Russia is a country where people speak Russian.
Answer: TRUE

Question: Is Romanian spoken in the country where the mother of Natalie Portman was born?
Reasoning:
 - The mother of Natalie Portman is Shelley Stevens.
 - Shelley Stevens was born in Moldova.
 - Moldova is a country where people speak Romanian.
Answer: TRUE

Question: Is Vietnamese spoken in the country where the mother of Rita Ora was born?
Reasoning:
 - The mother of Rita Ora is Vera Sahatçiu.
 - Vera Sahatçiu was born in Kosovo.
 - Kosovo is a 

In [72]:
# prompts[context][answer] -> full prompt text.
#   context: which few-shot examples precede the final question
#   answer : "true" / "false", the correct answer of the final question
prompts = {
    "unbiased": {
        "true": unbiased_fsp_prompt_true,
        "false": unbiased_fsp_prompt_false,
    },
    "all_true": {
        "true": all_true_fsp_prompt_true,
        "false": all_true_fsp_prompt_false,
    },
    "all_false": {
        "true": all_false_fsp_prompt_true,
        "false": all_false_fsp_prompt_false,
    },
}

# Token ids of the two answer strings, encoded without special tokens.
# Each value is whatever tokenizer.encode returns (a list of ids), not a bare int.
true_tok_id, false_tok_id = (
    tokenizer.encode(answer_text, add_special_tokens=False)
    for answer_text in (" TRUE", " FALSE")
)


@beartype
def get_logit_diff(tok_prompt: list[int]) -> float:
    """Return the TRUE-minus-FALSE logit difference for the next token.

    Runs the model on ``tok_prompt`` as a batch of one and reads the logits at
    the last sequence position.

    Args:
        tok_prompt: Token ids of the prompt.

    Returns:
        ``logit(" TRUE") - logit(" FALSE")`` as a Python float.

    Raises:
        ValueError: If the difference is not a single number, which happens
            when ``true_tok_id`` / ``false_tok_id`` do not each hold exactly
            one token id.
    """
    # Inference only: without no_grad the forward pass would record an
    # autograd graph and keep its activations alive for no benefit.
    with torch.no_grad():
        outputs = model(torch.tensor([tok_prompt], device=model.device))
    logits = outputs.logits[0, -1]
    # true_tok_id / false_tok_id come straight from tokenizer.encode, so they
    # may be lists; indexing with a list yields one logit per id.
    true_logit = logits[true_tok_id]
    false_logit = logits[false_tok_id]
    diff = true_logit - false_logit
    if diff.numel() != 1:
        raise ValueError(
            "Expected ' TRUE' and ' FALSE' to be single tokens, got "
            f"true_tok_id={true_tok_id!r}, false_tok_id={false_tok_id!r}"
        )
    return diff.item()


@beartype
def gather_logprobs(
    logprobs: Float[torch.Tensor, " seq vocab"],
    tokens: Int[torch.Tensor, " seq"],
) -> Float[torch.Tensor, " seq"]:
    """Pick, for every position, the log-probability of the given token.

    Args:
        logprobs: Per-position log-probabilities over the vocabulary.
        tokens: One token id per position, indexing the last axis of
            ``logprobs``.

    Returns:
        ``logprobs[i, tokens[i]]`` for every position ``i``.
    """
    # gather needs an index with the same rank as its input, so add a trailing
    # axis of size one and drop it again from the result.
    index = tokens[..., None]
    picked = logprobs.gather(dim=-1, index=index)
    return picked[..., 0]


@beartype
def get_next_logprobs(
    logits: Float[torch.Tensor, " seq vocab"],
    input_ids: Int[torch.Tensor, " seq"],
) -> Float[torch.Tensor, " shorter_seq"]:
    """Log-probability the model assigned to each token that actually came next.

    Args:
        logits: Logits for every position of the sequence.
        input_ids: The token ids of that same sequence.

    Returns:
        One value per position except the first: entry ``i`` is the
        log-probability of ``input_ids[i + 1]`` under ``logits[i]``.
    """
    # The final position has no following token to score, so it is dropped;
    # symmetrically the first token has no preceding prediction.
    targets = input_ids[1:]
    predicted_logprobs = logits[:-1].log_softmax(dim=-1)
    return gather_logprobs(predicted_logprobs, targets)


from cot_probing.vis import visualize_tokens_html
from IPython.display import HTML

# For each few-shot context, show how much probability the model gave to every
# token of the prompt (including the final answer token, which get_prompt
# appends), rendered as coloured HTML.
for answer in ["false"]:
    for context_name, prompts_by_answer in prompts.items():
        prompt = prompts_by_answer[answer]
        tok_prompt_list = tokenizer.encode(prompt)
        top_prompt_tensor = torch.tensor(tok_prompt_list).to(model.device)
        # Inference only: without no_grad every forward pass would keep an
        # autograd graph (and its activations) alive on the GPU.
        with torch.no_grad():
            logits = model(top_prompt_tensor.unsqueeze(0)).logits[0]
        next_logprobs = get_next_logprobs(logits, top_prompt_tensor)
        # The first token has no preceding prediction, so it gets a
        # placeholder value of 0 to keep values aligned with the tokens.
        values = [0] + next_logprobs.exp().tolist()
        print(f"Correct answer: {answer:<7} Context: {context_name:<11}")
        html = visualize_tokens_html(
            tok_prompt_list, tokenizer, values, vmin=0.0, vmax=1.0
        )
        display(HTML(html))

Correct answer: false   Context: unbiased   


Correct answer: false   Context: all_true   


Correct answer: false   Context: all_false  


In [73]:
# `logits` (left over from the last loop iteration above) has one row per
# position and one column per vocabulary entry. argmax() with no dim returns an
# index into the *flattened* tensor, which is not a token id, so take the
# argmax over the vocabulary at a single position instead.
# NOTE(review): the last position is used, matching get_logit_diff. Because the
# prompt already ends with its answer token, this is the model's prediction for
# what follows the answer; use an earlier row to inspect the answer slot itself.
top_tok = logits[-1].argmax().item()
top_tok_text = tokenizer.decode(top_tok)
print(top_tok_text)


