In [2]:
import csv
import math
import os
import threading
from datetime import datetime
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

import openai
import tqdm

# Module-level OpenAI client used by get_response_text_and_logprobs (and thus
# by every worker thread). Indexing os.environ raises KeyError if the key is unset.
client = openai.OpenAI(
    api_key=os.environ["OPENAI_API_KEY"],
)

In [3]:
def create_messages(context: str, elicitation_question: str):
    """Build the two-message chat prompt for a single Yes/No elicitation.

    Args:
        context: background text shown to the model.
        elicitation_question: the question the model must answer.

    Returns:
        A list with a system message (restricting the answer to 'Yes'/'No')
        followed by a user message holding the context and the question.
    """
    system_prompt = "When requested to, only respond with 'Yes' or 'No'. Do not provide any other information or thoughts."
    user_prompt = "Context: {}\nElicitation Question: {}".format(context, elicitation_question)
    return [
        dict(role="system", content=system_prompt),
        dict(role="user", content=user_prompt),
    ]


def get_text_and_logprobs_from_response(response):
    """Pull the reply text and the first token's top-k logprobs out of a chat completion.

    Args:
        response: a chat completion requested with logprobs enabled; only
            ``choices[0]`` is inspected.

    Returns:
        (text, token_logprobs) where token_logprobs maps each top alternative
        of the first generated token to its log-probability.
    """
    first_choice = response.choices[0]
    text = first_choice.message.content
    first_token_info = first_choice.logprobs.content[0]
    token_logprobs = dict(
        (candidate.token, candidate.logprob) for candidate in first_token_info.top_logprobs
    )
    return text, token_logprobs


def get_response_text_and_logprobs(messages: list[dict], model: str, **model_kwargs):
    """Send `messages` to the chat completions API and return (text, first-token top logprobs).

    Args:
        messages: chat messages, e.g. as built by `create_messages`.
        model: OpenAI model name.
        **model_kwargs: extra arguments forwarded to `client.chat.completions.create`.
            `logprobs` defaults to True and `top_logprobs` to 10 unless overridden.

    Returns:
        (text, token_logprobs) as produced by `get_text_and_logprobs_from_response`.
    """
    # model_kwargs is a fresh dict for every call, so filling in defaults is safe;
    # caller-supplied values win.
    model_kwargs.setdefault("logprobs", True)
    model_kwargs.setdefault("top_logprobs", 10)
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        **model_kwargs,
    )
    # Reuse the shared parser instead of duplicating its extraction logic.
    return get_text_and_logprobs_from_response(response)


def _logsumexp(values: list[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values)); -inf for an empty list."""
    if not values:
        return float("-inf")
    peak = max(values)
    if peak == float("-inf"):
        # All entries are -inf; avoid (-inf) - (-inf) = nan below.
        return peak
    # For a single value this is exactly `peak + log(1.0) == peak`.
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def calculate_yes_no_logprobs(token_logprobs: dict[str, float]):
    """Aggregate the log-probability mass assigned to "Yes" and to "No".

    Tokens are matched case-insensitively after stripping surrounding
    whitespace, so "Yes", "yes" and " YES" all count towards "Yes". Variants
    are combined by adding their *probabilities* (log-sum-exp). Adding the
    logprobs themselves would multiply the probabilities, which is wrong.

    Args:
        token_logprobs: mapping from token string to log-probability, e.g. the
            top-k alternatives for the first generated token.

    Returns:
        (yes_logprob, no_logprob). An answer with no matching token gets
        ``float("-inf")``, meaning it does not appear among the supplied tokens.
        A value of 0 would wrongly mean probability 1.
    """
    yes_values = []
    no_values = []
    for token, logprob in token_logprobs.items():
        normalized = token.strip().lower()
        if normalized == "yes":
            yes_values.append(logprob)
        elif normalized == "no":
            no_values.append(logprob)
    return _logsumexp(yes_values), _logsumexp(no_values)


# Serialises header creation and row appends. This function is executed
# concurrently by get_and_write_elicitation_results' ThreadPoolExecutor, and the
# check-then-create header logic is racy without it (a late "w" open could
# truncate rows other threads already appended). In-process threads only.
_CSV_WRITE_LOCK = threading.Lock()


def _append_row_to_csv(out_csv: Path, row: dict) -> None:
    """Append `row` to `out_csv` as a CSV record, writing a header first if the file is new.

    Raises:
        ValueError: if an existing file's header differs from `row`'s keys.
    """
    columns = list(row.keys())
    with _CSV_WRITE_LOCK:
        if not out_csv.exists():
            out_csv.parent.mkdir(parents=True, exist_ok=True)
            with out_csv.open("w", newline="", encoding="utf-8") as f:
                csv.writer(f, lineterminator="\n").writerow(columns)
        # Only the header line is needed; don't re-read the whole file per row.
        with out_csv.open("r", newline="", encoding="utf-8") as f:
            existing_columns = next(csv.reader(f), [])
        if columns != existing_columns:
            raise ValueError(f"Columns mismatch: {columns} != {existing_columns}")
        with out_csv.open("a", newline="", encoding="utf-8") as f:
            # str() keeps the previous rendering (e.g. None -> "None"). The csv
            # module quotes commas, quotes and newlines in model text / context.
            csv.writer(f, lineterminator="\n").writerow(str(row[key]) for key in columns)


def get_and_write_single_elicitation_result(
        context: str,
        elicitation_question: str,
        model: str,
        out_csv: Path,
        tags: dict | None=None,
        **model_kwargs
    ):
    """Run one Yes/No elicitation and append the result as a row of `out_csv`.

    The API call happens outside the write lock, so concurrent callers only
    serialise on the (cheap) file write.

    Args:
        context: context text for the prompt.
        elicitation_question: question for the prompt.
        model: OpenAI model name.
        out_csv: destination CSV (str or Path); created with a header if missing.
        tags: extra columns appended after the fixed ones, in sorted key order.
        **model_kwargs: forwarded to `get_response_text_and_logprobs`.

    Raises:
        ValueError: if `out_csv` already has a header with different columns.
    """
    out_csv = Path(out_csv)
    tags = {} if tags is None else tags

    datetime_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    messages = create_messages(context, elicitation_question)
    text, token_logprobs = get_response_text_and_logprobs(messages, model, **model_kwargs)
    yes_logprob, no_logprob = calculate_yes_no_logprobs(token_logprobs)
    result = {
        "datetime": datetime_str,
        "yes_logprob": yes_logprob,
        "no_logprob": no_logprob,
        "context": context,
        "elicitation_question": elicitation_question,
        "text": text,
        **{k: tags[k] for k in sorted(tags.keys())}
    }
    _append_row_to_csv(out_csv, result)


def get_and_write_elicitation_results(
        contexts: list[str],
        elicitation_questions: list[str],
        model: str,
        out_csv: Path,
        tags: list[dict] | None = None,
        append_datetime_to_filename: bool = False,
        **model_kwargs
    ):
    if not isinstance(out_csv, Path):
        out_csv = Path(out_csv)
    if append_datetime_to_filename:
        datetime_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        out_csv = out_csv.with_name(f"{out_csv.stem}_{datetime_str}{out_csv.suffix}")
    if out_csv.exists():
        raise FileExistsError(f"Output file {out_csv} already exists")
    if not isinstance(tags, list):
        raise ValueError("tags must be a list of dictionaries")

    pbar = tqdm.tqdm(total=len(contexts), desc="Processing elicitations")
    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = []
        for context, elicitation_question, tag in zip(contexts, elicitation_questions, tags):
            future = executor.submit(
                get_and_write_single_elicitation_result,
                context,
                elicitation_question, 
                model,
                out_csv,
                tag,
                **model_kwargs
            )
            futures.append(future)
        
        for future in as_completed(futures):
            pbar.update(1)
    pbar.close()

In [4]:
# Smoke test: issue `number_results` dummy elicitations against the live OpenAI
# API (needs OPENAI_API_KEY) and write them to a timestamped CSV under example_outputs/.
number_results = 10
example_contexts = ["Example context {}".format(i) for i in range(number_results)]
elicitation_questions = ["Question {}: Answer randomly: Yes or No".format(i) for i in range(number_results)]
tags = [dict(example_tag1=1, example_tag2="2") for _ in range(number_results)]

get_and_write_elicitation_results(
    example_contexts,
    elicitation_questions,
    "gpt-4o-mini",
    f"example_outputs/example_out_{number_results}.csv",
    tags=tags,
    append_datetime_to_filename=True,
)

Processing elicitations: 100%|██████████| 10/10 [00:01<00:00,  8.10it/s]
