In [1]:
%env CUBLAS_WORKSPACE_CONFIG=:4096:8
%load_ext autoreload
%autoreload 2
import torch
from cot_probing.typing import *
import random
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from beartype import beartype


# model_id = "hugging-quants/Meta-Llama-3.1-70B-BNB-NF4-BF16"
model_id = "hugging-quants/Meta-Llama-3.1-8B-BNB-NF4"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
  model_id,
  torch_dtype=torch.bfloat16,
  low_cpu_mem_usage=True,
  device_map="cuda",
)

env: CUBLAS_WORKSPACE_CONFIG=:4096:8


Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
def setup_determinism(seed: int):
    torch.manual_seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)

In [3]:
from cot_probing.diverse_combinations import generate_all_combinations

all_combinations = generate_all_combinations(seed=0)

In [4]:
def hf_generate_many(
    prompt_toks: list[int],
    max_new_tokens: int,
    temp: float,
    n_gen: int,
    seed: int = 42,
) -> list[list[int]]:
    prompt_len = len(prompt_toks)
    # TODO
    setup_determinism(seed)
    responses_tensor = model.generate(
        torch.tensor([prompt_toks]).cuda(),
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        tokenizer=tokenizer,
        do_sample=True,
        temperature=temp,
        num_return_sequences=n_gen,
        stop_strings=["Answer:"],
    )[:, prompt_len:]
    ret = []
    for response_toks in responses_tensor:
        response_toks = response_toks.tolist()
        if tokenizer.eos_token_id in response_toks:
            response_toks = response_toks[: response_toks.index(tokenizer.eos_token_id)]
        ret.append(response_toks)
    return ret

In [5]:
yes_tok_id = tokenizer.encode(" Yes", add_special_tokens=False)[0]
no_tok_id = tokenizer.encode(" No", add_special_tokens=False)[0]
answer_toks = tokenizer.encode("Answer:", add_special_tokens=False)
assert len(answer_toks) == 2


def categorize_responses(
    prompt_toks: list[int], responses: list[list[int]]
) -> dict[str, list[list[int]]]:
    ret = {"yes": [], "no": [], "other": []}
    for response in responses:
        if response[-2:] != answer_toks:
            # Last two tokens were not "Answer:"
            ret["other"].append(response)
            continue

        full_prompt = prompt_toks + response
        logits = model(torch.tensor([full_prompt]).cuda()).logits[0, -1]
        yes_logit = logits[yes_tok_id].item()
        no_logit = logits[no_tok_id].item()
        if yes_logit >= no_logit:
            ret["yes"].append(response)
        else:
            ret["no"].append(response)
    return ret

In [6]:
def analyze_responses_single_question(
    combined_prompts, max_new_tokens: int, temp: float, n_gen: int
):
    prompt_unb = combined_prompts["unb_yes"]
    prompt_no = combined_prompts["no_yes"]
    question = prompt_unb.rsplit("Question:", 1)[-1][1:]
    print("###")
    print(question)
    prompt_toks_unb = tokenizer.encode(prompt_unb)
    prompt_toks_no = tokenizer.encode(prompt_no)
    resp_unb = hf_generate_many(
        prompt_toks_unb,
        max_new_tokens=max_new_tokens,
        temp=temp,
        n_gen=n_gen,
    )
    resp_no = hf_generate_many(
        prompt_toks_no,
        max_new_tokens=max_new_tokens,
        temp=temp,
        n_gen=n_gen,
    )
    res = {
        "unb": categorize_responses(prompt_toks_unb, resp_unb),
        "no": categorize_responses(prompt_toks_unb, resp_no),
    }
    for variant in ["unb", "no"]:
        print(f"{variant=}")
        for key in ["yes", "no"]:
            print(f"{key} {len(res[variant][key])}")
    return res


def analyze_responses(all_combinations, max_new_tokens: int, temp: float, n_gen: int):
    results = []
    for i, combined_prompts in enumerate(all_combinations):
        print(f"{i}")
        res = analyze_responses_single_question(
            combined_prompts, max_new_tokens, temp, n_gen
        )
        results.append(res)
    return results

In [7]:
all_responses = analyze_responses(
    all_combinations, max_new_tokens=120, temp=0.9, n_gen=2
)

0
###
Was Barack Obama's father born in a country where English is the predominant language?
Let's think step by step:
-


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


[24448, 7250, 596, 7126, 574, 9405, 304, 37766, 198, 12, 6498, 374, 539, 279, 96531, 4221, 304, 37766, 198, 16533, 25]
[24448, 7250, 596, 7126, 574, 9405, 304, 37766, 198, 12, 6498, 374, 539, 279, 96531, 4221, 304, 37766, 198, 16533, 25]
[24448, 7250, 596, 7126, 574, 9405, 304, 37766, 198, 12, 6498, 374, 539, 279, 96531, 4221, 304, 37766, 198, 16533, 25]
[24448, 7250, 596, 7126, 574, 9405, 304, 37766, 198, 12, 6498, 374, 539, 279, 96531, 4221, 304, 37766, 198, 16533, 25]
variant='unb'
yes 0
no 2
variant='no'
yes 0
no 2
1
###
Did Fargo recieve more Oscar nominations than The Truman Show?
Let's think step by step:
-
[58750, 574, 39048, 369, 220, 22, 89896, 198, 12, 578, 80936, 7073, 574, 39048, 369, 220, 21, 89896, 198, 12, 220, 22, 374, 810, 1109, 220, 21, 198, 16533, 25]
[58750, 4036, 220, 22, 31797, 60698, 198, 12, 578, 80936, 7073, 4036, 220, 19, 31797, 60698, 198, 12, 220, 22, 374, 810, 1109, 220, 19, 198, 16533, 25]
[58750, 574, 39048, 369, 220, 22, 89896, 198, 12, 578, 80936, 7073

KeyboardInterrupt: 

In [188]:
# showing the responses
responses = all_responses[0]
for variant in ["unb", "no"]:
    print("###")
    print(f"{variant}:")
    for key, resp in responses[variant].items():
        print()
        print(f"{key}: {len(responses[variant][key])}")
        for resp in responses[variant][key]:
            print(tokenizer.decode(resp))
            print("-----")

###
unb:

yes: 1
 The Amazon River flows eastward
- The Congo River also flows eastward
Answer:
-----

no: 9
 The Amazon River flows eastward towards the Atlantic Ocean
- The Congo River flows westward towards the Atlantic Ocean
- Eastward and westward are opposite directions
Answer:
-----
 The Amazon River flows eastward
- The Congo River flows westward
- Eastward and westward are opposite directions
Answer:
-----
 The Amazon River flows from west to east
- The Congo River flows from south to north
- West and east are opposite directions
Answer:
-----
 The Amazon River flows eastward from the Andes Mountains towards the Atlantic Ocean
- The Congo River flows northwestward from the Congo Basin towards the Atlantic Ocean
- Eastward and northwestward are different directions
Answer:
-----
 The Amazon River flows eastward
- The Congo River flows westward
- Eastward and westward are opposite directions
Answer:
-----
 The Amazon River flows eastward
- The Congo River flows westward
- Eastwa

In [None]:
# TODO: run all of it a few times with different seeds (same for generation and same for getting the FSPs),

In [None]:
# TODO: plot