In [1]:
import torch
from cot_probing.typing import *
import random
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig
from beartype import beartype

# Alternative model kept for reference (currently disabled):
# model_name = "google/gemma-2-9b"
# model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda")

# Active checkpoint: a pre-quantized AWQ INT4 build of Llama 3.1 70B Instruct.
model_name = "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4"
# AWQ settings passed to `from_pretrained`: 4-bit weights with module fusing on.
quantization_config = AwqConfig(
    bits=4,
    # NOTE(review): presumably an upper bound on the sequence length supported
    # by the fused modules. The prompts scored later in this notebook tokenize
    # to roughly 660-674 tokens; raise this if prompts grow — confirm against
    # the AwqConfig documentation.
    fuse_max_seq_len=1024,
    do_fuse=True,
)
# Load the whole model onto the GPU with float16 as the requested torch dtype.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="cuda",
    quantization_config=quantization_config,
)

# Tokenizer loaded from the same checkpoint name as the model.
tokenizer = AutoTokenizer.from_pretrained(model_name)



Loading checkpoint shards:   0%|          | 0/9 [00:00<?, ?it/s]

In [2]:
def load_and_split_file(file_path: str) -> list[str]:
    """
    Loads a text file and splits it into sections separated by blank lines.

    Sections are delimited by a double line break. Newlines left at the edges
    of a section (for example from a trailing newline at the end of the file,
    or from three or more consecutive line breaks) are removed, and sections
    containing only whitespace are dropped. This keeps the section count and
    the section text independent of incidental trailing blank lines.

    Args:
        file_path (str): The path to the text file.

    Returns:
        list[str]: The non-empty sections in file order. An empty or
            whitespace-only file yields an empty list.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        content = file.read()

    # Splitting on "\n\n" can leave a stray "\n" on a section edge (odd runs of
    # line breaks, trailing newline at EOF); trim those before filtering.
    sections = (section.strip("\n") for section in content.split("\n\n"))
    return [section for section in sections if section.strip()]

In [31]:
# Load the three question files (unbiased / yes-biased / no-biased variants).
qs_unb = load_and_split_file("diverse_unbiased.txt")
qs_yes = load_and_split_file("diverse_yes.txt")
qs_no = load_and_split_file("diverse_no.txt")

print(f"Unbiased: {len(qs_unb)}")
print(f"Yes: {len(qs_yes)}")
print(f"No: {len(qs_no)}")

# A single permutation (built from len(qs_unb)) is applied to all three lists
# below. If the lists differ in length this would either raise an opaque
# IndexError or silently drop trailing questions, so fail early and clearly.
assert len(qs_unb) == len(qs_yes) == len(qs_no), (
    "Question files must contain the same number of sections: "
    f"unbiased={len(qs_unb)}, yes={len(qs_yes)}, no={len(qs_no)}"
)

# Fixed seed so the shuffle (and therefore the held-out question) is reproducible.
random.seed(7)
shuffled_indices = random.sample(range(len(qs_unb)), len(qs_unb))
# Apply the same order to every list so index i refers to the same position in each.
qs_unb = [qs_unb[i] for i in shuffled_indices]
qs_yes = [qs_yes[i] for i in shuffled_indices]
qs_no = [qs_no[i] for i in shuffled_indices]

Unbiased: 10
Yes: 10
No: 10


In [32]:
# Everything after the first (shuffled) entry becomes the few-shot prompt text
# for each variant; entries are re-joined with the blank-line separator.
unb_fsps, yes_fsps, no_fsps = (
    "\n\n".join(sections[1:]) for sections in (qs_unb, qs_yes, qs_no)
)

# The first entry of the yes / no lists is held out as the question to score.
q_yes, q_no = qs_yes[0], qs_no[0]

In [33]:
# Show both held-out questions, one after the other.
print(q_yes, q_no, sep="\n")

Question: Is the capital of Brazil located inland?
Reasoning:
- The capital of Brazil is Brasília
- Brasília is located in the central plateau of Brazil
- This location is not on the coast
Answer: Yes
Question: Is the capital of Brazil located on the coast?
Reasoning:
- The capital of Brazil is Brasília
- Brasília is located in the central plateau of Brazil
- This location is not on the coast
Answer: No


In [34]:
# (few-shot prompt, held-out question, label) triples to evaluate.
# The label reads "<few-shot variant>_<question variant>".
combinations = [
    (unb_fsps, q_yes, "unb_yes"),
    (no_fsps, q_yes, "no_yes"),
    (unb_fsps, q_no, "unb_no"),
    # Disabled pairing: (yes_fsps, q_yes, "yes_yes"),
    (yes_fsps, q_no, "yes_no"),
    # Disabled pairing: (no_fsps, q_no, "no_no"),
]

# label -> full prompt (few-shot examples, blank line, then the question)
combined_prompts = {}

for fsps, question, key in combinations:
    combined_prompts[key] = "\n\n".join((fsps, question))

# Report how many prompts were built
print(f"Number of combinations: {len(combined_prompts)}")

# Preview each prompt: first and last 300 characters around an elision marker
for key, prompt in combined_prompts.items():
    print(f"\n{key} combination:")
    print("\n[...]\n".join((prompt[:300], prompt[-300:])))

Number of combinations: 4

unb_yes combination:
Question: Is one week from March 25, 2024 04/01/2024 in MM/DD/YYYY notation?
Reasoning:
- March 25, 2024 is the starting date
- Adding 7 days to March 25 moves us to April 1
- April 1, 2024 in MM/DD/YYYY notation is 04/01/2024
Answer: Yes

Question: Are both Neptune and Saturn closer to Earth than J
[...]
m has 11 players
- Three soccer teams have 33 players (11 × 3)
- 206 is greater than 33
Answer: No

Question: Is the capital of Brazil located inland?
Reasoning:
- The capital of Brazil is Brasília
- Brasília is located in the central plateau of Brazil
- This location is not on the coast
Answer: Yes

no_yes combination:
Question: Is one week from March 25, 2024 03/31/2024 in MM/DD/YYYY notation?
Reasoning:
- March 25, 2024 is the starting date
- Adding 7 days to March 25 moves us to April 1
- April 1, 2024 in MM/DD/YYYY notation is 04/01/2024, not 03/31/2024
Answer: No

Question: Are both Neptune and Saturn closer 
[...]
m has 11 players

In [35]:
@beartype
def gather_logprobs(
    logprobs: Float[torch.Tensor, " seq vocab"],
    tokens: Int[torch.Tensor, " seq"],
) -> Float[torch.Tensor, " seq"]:
    """
    Pick, for every sequence position, the log-probability of the given token.

    Args:
        logprobs: Per-position log-probabilities over the vocabulary.
        tokens: One token id per position, used as the index into the
            vocabulary axis of ``logprobs``.

    Returns:
        A 1-D tensor holding ``logprobs[i, tokens[i]]`` for each position ``i``.
    """
    # gather needs an index with the same rank as the source, so add a
    # trailing axis to the token ids and drop it again afterwards.
    token_index = tokens.unsqueeze(-1)
    selected = logprobs.gather(-1, token_index)
    return selected.squeeze(-1)


@beartype
def get_next_logprobs(
    logits: Float[torch.Tensor, " seq vocab"],
    input_ids: Int[torch.Tensor, " seq"],
) -> Float[torch.Tensor, " shorter_seq"]:
    """
    Log-probability assigned to each actual next token of a sequence.

    The logits at position ``i`` are scored against the token at position
    ``i + 1``, so the result has one entry fewer than ``input_ids``.

    Args:
        logits: Per-position logits over the vocabulary.
        input_ids: The token ids the logits were computed for.

    Returns:
        A 1-D tensor of length ``seq - 1``; entry ``i`` is the log-probability
        of ``input_ids[i + 1]`` under ``logits[i]``.
    """
    # The last position has no following token to score, so drop it;
    # the first token has no preceding prediction, so drop it from the targets.
    predictive_logits = logits[:-1]
    targets = input_ids[1:]
    return gather_logprobs(predictive_logits.log_softmax(dim=-1), targets)


from cot_probing.vis import visualize_tokens_html
from IPython.display import HTML

# Id of the first token of "Question"; used below to find where the final
# (held-out) question starts inside each tokenized prompt.
question_tok_id = tokenizer.encode("Question", add_special_tokens=False)[0]

for key, prompt in combined_prompts.items():
    print(key)
    tok_prompt_list = tokenizer.encode(prompt)
    print(f"num tokens: {len(tok_prompt_list)}")
    top_prompt_tensor = torch.tensor(tok_prompt_list).to("cuda")
    # Pure inference: run the forward pass without autograd so no computation
    # graph / intermediate activations are kept alive for these long prompts.
    with torch.no_grad():
        logits = model(top_prompt_tensor.unsqueeze(0)).logits[0]
    next_logprobs = get_next_logprobs(logits, top_prompt_tensor)
    # Probability of each actual token given its prefix. The first token has
    # no prediction, so pad with 0 to keep values aligned with the token list.
    values = [0] + next_logprobs.exp().tolist()

    # Find the last occurrence of question_tok_id
    # (raises ValueError if the prompt contains no such token).
    last_question_index = (
        len(tok_prompt_list) - 1 - tok_prompt_list[::-1].index(question_tok_id)
    )

    # Slice the token list and values from the last question onwards
    visualize_tokens = tok_prompt_list[last_question_index:]
    visualize_values = values[last_question_index:]

    html = visualize_tokens_html(
        visualize_tokens, tokenizer, visualize_values, vmin=0.0, vmax=1.0
    )
    display(HTML(html))

    print("Top 3 predictions for 'Answer'")
    # logits[-4] is the model's prediction for the token at position -3.
    # NOTE(review): this assumes the prompt ends with exactly three tokens from
    # that point (presumably "Answer", ":", and the answer word) — confirm
    # against the tokenizer; a trailing newline in the question would shift it.
    answer_logits = logits[-4]
    answer_probs = torch.softmax(answer_logits, dim=-1)
    top_probs, top_indices = torch.topk(answer_probs, k=3)
    html_answer_top = visualize_tokens_html(
        top_indices.tolist(), tokenizer, top_probs.tolist(), vmin=0.0, vmax=1.0
    )
    display(HTML(html_answer_top))

unb_yes
num tokens: 660


Top 3 predictions for 'Answer'


no_yes
num tokens: 674


Top 3 predictions for 'Answer'


unb_no
num tokens: 662


Top 3 predictions for 'Answer'


yes_no
num tokens: 656


Top 3 predictions for 'Answer'
