In [None]:
"""Notebook for zero shot and scratchpad eval."""

import json
import logging
import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from functools import partial

current_path = os.getcwd()

sys.path.append(os.path.join(current_path, "../.."))  # noqa: E402
from helpers import (  # noqa: E402
    constants,
    data_utils,
    env,
    model_eval,
    question_curation,
)
from helpers.llm_prompts import (  # noqa: E402
    HUMAN_JOINT_PROMPT_1,
    HUMAN_JOINT_PROMPT_2,
    HUMAN_JOINT_PROMPT_3,
    HUMAN_JOINT_PROMPT_4,
    SCRATCH_PAD_MARKET_JOINT_QUESTION_PROMPT,
    SCRATCH_PAD_MARKET_PROMPT,
    SCRATCH_PAD_NON_MARKET_JOINT_QUESTION_PROMPT,
    SCRATCH_PAD_NON_MARKET_PROMPT,
    ZERO_SHOT_MARKET_JOINT_QUESTION_PROMPT,
    ZERO_SHOT_MARKET_PROMPT,
    ZERO_SHOT_NON_MARKET_JOINT_QUESTION_PROMPT,
    ZERO_SHOT_NON_MARKET_PROMPT,
)

sys.path.append(os.path.join(current_path, "../../.."))  # noqa: E402
from utils import gcp  # noqa: E402

# Make INFO-level progress messages visible. Without any handler on the root logger,
# Python's last-resort handler only emits WARNING and above, so logger.info(...) calls
# would be dropped. basicConfig is a no-op if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [None]:
# Joint-question instructions, one per direction; indexed by an unrolled combo
# question's "combo_index" (0-3) when the joint prompts are built.
human_joint_prompts = [
    HUMAN_JOINT_PROMPT_1, HUMAN_JOINT_PROMPT_2,
    HUMAN_JOINT_PROMPT_3, HUMAN_JOINT_PROMPT_4,
]

# Models to evaluate, keyed by short name; the eval cells read each entry's
# "full_name" and "source".
models = constants.ZERO_SHOT_AND_SCRATCHPAD_MODELS

### Load data

In [None]:
questions_file = "latest-llm.json"

# Read as UTF-8 explicitly. JSON is UTF-8 by spec, but open() otherwise uses the
# platform's locale encoding (e.g. cp1252 on Windows) and can mis-decode question text.
with open(questions_file, "r", encoding="utf-8") as file:
    questions_data = json.load(file)

questions = questions_data["questions"]
forecast_due_date = questions_data["forecast_due_date"]
base_file_path = "individual_forecast_records/" + forecast_due_date
question_types = ["market", "non_market", "combo_market", "combo_non_market", "final"]


def _filter_questions(all_questions, *, combo, market):
    """Return the questions of one kind, preserving their original order.

    Args:
        all_questions: Question dicts as loaded above.
        combo: True to select combination questions (``combination_of`` != "N/A").
        market: True to select questions whose source is not in
            ``question_curation.DATA_SOURCES``.

    Returns:
        A new list with the matching question dicts (not copies).
    """
    return [
        q
        for q in all_questions
        if (q["combination_of"] != "N/A") == combo
        and (q["source"] not in question_curation.DATA_SOURCES) == market
    ]


single_market_questions = _filter_questions(questions, combo=False, market=True)
single_non_market_questions = _filter_questions(questions, combo=False, market=False)
combo_market_questions = _filter_questions(questions, combo=True, market=True)
combo_non_market_questions = _filter_questions(questions, combo=True, market=False)

In [None]:
def unroll(combo_questions):
    """Expand every combo question into one copy per direction.

    Each copy is a shallow copy of the input dict with a ``"combo_index"`` key
    set to 0, 1, 2 or 3; the input dicts themselves are left untouched.

    Args:
        combo_questions: Iterable of question dicts.

    Returns:
        A new list with four entries per input question, grouped by question
        and ordered by ``combo_index``.
    """
    return [
        {**question, "combo_index": direction}
        for question in combo_questions
        for direction in range(4)
    ]


# Four copies per combo question, one per combo_index.
combo_market_questions_unrolled, combo_non_market_questions_unrolled = (
    unroll(combo_market_questions),
    unroll(combo_non_market_questions),
)

In [None]:
# Sizes of the four question groups (combo groups counted after unrolling).
print(
    *map(
        len,
        (
            single_market_questions,
            single_non_market_questions,
            combo_market_questions_unrolled,
            combo_non_market_questions_unrolled,
        ),
    )
)

In [None]:
# Smoke-test switch: set TEST_MODE = True to evaluate only the first N_TEST_QUESTIONS
# entries of each group. Keep it False for a real run -- the eval cells below upload
# their forecast files to env.FORECAST_SETS_BUCKET.
TEST_MODE = False
N_TEST_QUESTIONS = 2

if TEST_MODE:
    single_market_questions = single_market_questions[:N_TEST_QUESTIONS]
    single_non_market_questions = single_non_market_questions[:N_TEST_QUESTIONS]
    # For the unrolled combo groups this keeps the first N direction copies, not N questions.
    combo_market_questions_unrolled = combo_market_questions_unrolled[:N_TEST_QUESTIONS]
    combo_non_market_questions_unrolled = combo_non_market_questions_unrolled[:N_TEST_QUESTIONS]

### Zero-Shot Eval

In [None]:
def _build_zero_shot_prompt(question, is_market):
    """Format the zero-shot prompt for one question and pick its token budget.

    Args:
        question: A question dict. Combo questions (``combination_of`` != "N/A") must
            carry a ``combo_index`` (added by ``unroll``) and two sub-question dicts in
            ``combination_of``.
        is_market: True when ``question["source"]`` is not in
            ``question_curation.DATA_SOURCES``.

    Returns:
        ``(prompt, max_tokens)``: 500 tokens for non-market combo questions, 100 otherwise.
        NOTE(review): single non-market questions also answer one probability per
        resolution date but only get 100 tokens -- confirm that is enough.
    """
    if question["combination_of"] == "N/A":
        fields = {
            "question": question["question"],
            "background": question["background"]
            + "\n"
            + question["market_info_resolution_criteria"],
            "resolution_criteria": question["resolution_criteria"],
        }
        if is_market:
            prompt = ZERO_SHOT_MARKET_PROMPT.format(
                **fields, close_date=question["market_info_close_datetime"]
            )
            return prompt, 100
        prompt = ZERO_SHOT_NON_MARKET_PROMPT.format(
            **fields,
            freeze_datetime=question["freeze_datetime"],
            freeze_datetime_value=question["freeze_datetime_value"],
            freeze_datetime_value_explanation=question["freeze_datetime_value_explanation"],
            list_of_resolution_dates=question["resolution_dates"],
        )
        return prompt, 100

    first = question["combination_of"][0]
    second = question["combination_of"][1]
    fields = {
        "human_prompt": human_joint_prompts[question["combo_index"]],
        "question_1": first["question"],
        "question_2": second["question"],
        "background_1": first["background"] + "\n" + first["market_info_resolution_criteria"],
        "background_2": second["background"] + "\n" + second["market_info_resolution_criteria"],
        "resolution_criteria_1": first["resolution_criteria"],
        "resolution_criteria_2": second["resolution_criteria"],
    }
    if is_market:
        prompt = ZERO_SHOT_MARKET_JOINT_QUESTION_PROMPT.format(
            **fields,
            # The joint question stays open until the later of the two markets closes.
            close_date=max(
                first["market_info_close_datetime"], second["market_info_close_datetime"]
            ),
        )
        return prompt, 100
    prompt = ZERO_SHOT_NON_MARKET_JOINT_QUESTION_PROMPT.format(
        **fields,
        freeze_datetime_1=first["freeze_datetime"],
        freeze_datetime_2=second["freeze_datetime"],
        freeze_datetime_value_1=first["freeze_datetime_value"],
        freeze_datetime_value_2=second["freeze_datetime_value"],
        freeze_datetime_value_explanation_1=first["freeze_datetime_value_explanation"],
        freeze_datetime_value_explanation_2=second["freeze_datetime_value_explanation"],
        list_of_resolution_dates=question["resolution_dates"],
    )
    return prompt, 500


def worker(index, model_name, save_dict, questions_to_eval, rate_limit=False):
    """Evaluate one question zero-shot and store the parsed answer in ``save_dict``.

    Args:
        index: Position in ``questions_to_eval``; also the key written in ``save_dict``.
        model_name: Key into the global ``models``; its ``"full_name"`` is what is sent to
            ``model_eval.get_response_from_model``.
        save_dict: Results dict shared by all workers of one model. A slot that is not
            ``""`` is treated as already answered and skipped. Each call writes only its
            own ``index`` key.
        questions_to_eval: The list of question dicts being evaluated.
        rate_limit: When True, pad the call to at least one second of wall time.

    Returns:
        None. ``save_dict[index]`` receives ``model_eval.extract_probability(...)`` for
        market questions and ``model_eval.reformat_answers(...)`` otherwise.
    """
    if save_dict[index] != "":
        return

    logger.info(f"{model_name} - {index}")
    start = time.monotonic()

    question = questions_to_eval[index]
    is_market = question["source"] not in question_curation.DATA_SOURCES
    prompt, max_tokens = _build_zero_shot_prompt(question, is_market)

    try:
        response = model_eval.get_response_from_model(
            prompt=prompt,
            max_tokens=max_tokens,
            model_name=models[model_name]["full_name"],
            temperature=0,
            wait_time=30,
        )
    except Exception as e:
        # Best effort: record which request failed, then let the parser below run with
        # response=None so this slot is still filled.
        logger.error(f"{model_name} - {index}: An error occurred: {e}")
        response = None

    if is_market:
        save_dict[index] = model_eval.extract_probability(response)
    else:
        save_dict[index] = model_eval.reformat_answers(
            response=response, prompt=prompt, question=question
        )

    logger.info(f"Model: {model_name} | Answer: {save_dict[index]}")

    if rate_limit:
        elapsed_time = time.monotonic() - start
        if elapsed_time < 1:
            # Ensure at least 1 second per request to stay within rate limits.
            time.sleep(1 - elapsed_time)

    return None


def executor(max_workers, model_name, save_dict, questions_to_eval, rate_limit=False):
    """Run ``worker`` for every question index on a thread pool (zero-shot).

    ``worker`` is looked up at call time, so this uses whichever ``worker`` is currently
    defined in the notebook. Re-run the zero-shot worker cell before re-running this
    eval, because the scratchpad section redefines ``worker``.

    Args:
        max_workers: Thread-pool size, i.e. number of concurrent model calls.
        model_name: Forwarded to ``worker``.
        save_dict: Shared results dict, filled in place by the workers.
        questions_to_eval: Questions to evaluate; one task per index.
        rate_limit: Forwarded to ``worker``; off by default.

    Returns:
        List of the worker return values in index order. An exception raised inside a
        worker is re-raised here when its result is collected.
    """
    run_one = partial(
        worker,
        model_name=model_name,
        save_dict=save_dict,
        questions_to_eval=questions_to_eval,
        rate_limit=rate_limit,
    )
    # Named ``pool`` so the context-manager variable does not shadow this function.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(run_one, range(len(questions_to_eval))))

In [None]:
results = {}
model_result_loaded = {}
models_to_test = list(models.keys())
prompt_type = "zero_shot"

# Direction recorded for each unrolled combo copy, keyed by its "combo_index";
# any other index (i.e. 3) gets [-1, -1].
combo_directions = {0: [1, 1], 1: [1, -1], 2: [-1, 1]}

for question in [
    single_non_market_questions,
    single_market_questions,
    combo_market_questions_unrolled,
    combo_non_market_questions_unrolled,
]:
    questions_to_eval = question
    if not questions_to_eval:
        # Nothing of this kind in the current question set; indexing [0] below would raise.
        logger.info(f"{prompt_type}: skipping an empty question group.")
        continue

    # Each group holds a single kind of question (see "Load data"), so the first entry
    # decides the output folder, e.g. "zero_shot/combo_market".
    first_question = questions_to_eval[0]
    is_market = first_question["source"] not in question_curation.DATA_SOURCES
    is_combo = first_question["combination_of"] != "N/A"
    kind = ("combo_" if is_combo else "") + ("market" if is_market else "non_market")
    test_type = f"{prompt_type}/{kind}"

    # Reuse forecasts already saved for this (kind, model). Otherwise start from all-empty
    # slots, which the worker treats as "not answered yet".
    model_result_loaded = {model: False for model in models_to_test}
    for model in models_to_test:
        gcp_file_path = f"{base_file_path}/{test_type}/{model}.jsonl"
        try:
            results[model] = data_utils.download_and_read_saved_forecasts(
                gcp_file_path, base_file_path
            )
            model_result_loaded[model] = True
            print(f"Downloaded {gcp_file_path}.")
        except Exception:  # Best effort: any failure (presumably no saved file) => re-run.
            results[model] = {i: "" for i in range(len(questions_to_eval))}

    for model in models_to_test:
        if model_result_loaded[model]:
            continue

        # Concurrent requests per provider; 50 unless listed here.
        executor_count = {"ANTHROPIC": 30, "GOOGLE": 10}.get(models[model]["source"], 50)
        executor(executor_count, model, results[model], questions_to_eval)

        current_model_forecasts = []
        for index, q in enumerate(questions_to_eval):
            answer = results[model][index]
            if q["source"] in question_curation.DATA_SOURCES:
                # Non-market answers are zipped with the resolution dates: one record per pair.
                records = [
                    {
                        "id": q["id"],
                        "source": q["source"],
                        "forecast": forecast,
                        "resolution_date": resolution_date,
                        "reasoning": None,
                    }
                    for forecast, resolution_date in zip(answer, q["resolution_dates"])
                ]
            else:
                records = [
                    {"id": q["id"], "source": q["source"], "forecast": answer, "reasoning": None}
                ]

            if q["combination_of"] != "N/A":
                # The four unrolled copies of a combo question share one id, so every combo
                # record -- market as well as non-market -- must carry its direction.
                direction = combo_directions.get(q["combo_index"], [-1, -1])
                for record in records:
                    record["direction"] = list(direction)

            current_model_forecasts.extend(records)

        local_filename = f"{test_type}/{model}.jsonl"
        os.makedirs(os.path.dirname(local_filename), exist_ok=True)
        with open(local_filename, "w") as file:
            for entry in current_model_forecasts:
                file.write(json.dumps(entry) + "\n")

        gcp.storage.upload(
            bucket_name=env.FORECAST_SETS_BUCKET,
            local_filename=local_filename,
            filename=f"{base_file_path}/" + local_filename,
        )

In [None]:
# Post-process this prompt type's per-kind forecast files with the project helpers.
# NOTE(review): behavior inferred from the helper names only -- presumably builds the
# final per-model forecast files, then replaces the cloud copies; confirm in
# model_eval / data_utils. ("generage" is the helper's actual spelling.)
model_eval.generage_final_forecast_files(
    deadline=forecast_due_date,
    prompt_type=prompt_type,
    models=models,
)
data_utils.delete_and_upload_to_the_cloud(
    base_file_path,
    prompt_type,
    question_types,
)

### Scratchpad Eval

In [None]:
def _build_scratchpad_prompt(question, is_market):
    """Format the scratchpad prompt for one question and pick its token budget.

    Args:
        question: A question dict. Combo questions (``combination_of`` != "N/A") must
            carry a ``combo_index`` (added by ``unroll``) and two sub-question dicts in
            ``combination_of``.
        is_market: True when ``question["source"]`` is not in
            ``question_curation.DATA_SOURCES``.

    Returns:
        ``(prompt, max_tokens)``: 1300 tokens for market questions, 2000 for non-market.
    """
    if question["combination_of"] == "N/A":
        fields = {
            "question": question["question"],
            "background": question["background"]
            + "\n"
            + question["market_info_resolution_criteria"],
            "resolution_criteria": question["resolution_criteria"],
        }
        if is_market:
            prompt = SCRATCH_PAD_MARKET_PROMPT.format(
                **fields, close_date=question["market_info_close_datetime"]
            )
            return prompt, 1300
        prompt = SCRATCH_PAD_NON_MARKET_PROMPT.format(
            **fields,
            freeze_datetime=question["freeze_datetime"],
            freeze_datetime_value=question["freeze_datetime_value"],
            freeze_datetime_value_explanation=question["freeze_datetime_value_explanation"],
            list_of_resolution_dates=question["resolution_dates"],
        )
        return prompt, 2000

    first = question["combination_of"][0]
    second = question["combination_of"][1]
    fields = {
        "human_prompt": human_joint_prompts[question["combo_index"]],
        "question_1": first["question"],
        "question_2": second["question"],
        "background_1": first["background"] + "\n" + first["market_info_resolution_criteria"],
        "background_2": second["background"] + "\n" + second["market_info_resolution_criteria"],
        "resolution_criteria_1": first["resolution_criteria"],
        "resolution_criteria_2": second["resolution_criteria"],
    }
    if is_market:
        prompt = SCRATCH_PAD_MARKET_JOINT_QUESTION_PROMPT.format(
            **fields,
            # The joint question stays open until the later of the two markets closes.
            close_date=max(
                first["market_info_close_datetime"], second["market_info_close_datetime"]
            ),
        )
        return prompt, 1300
    prompt = SCRATCH_PAD_NON_MARKET_JOINT_QUESTION_PROMPT.format(
        **fields,
        freeze_datetime_1=first["freeze_datetime"],
        freeze_datetime_2=second["freeze_datetime"],
        freeze_datetime_value_1=first["freeze_datetime_value"],
        freeze_datetime_value_2=second["freeze_datetime_value"],
        freeze_datetime_value_explanation_1=first["freeze_datetime_value_explanation"],
        freeze_datetime_value_explanation_2=second["freeze_datetime_value_explanation"],
        list_of_resolution_dates=question["resolution_dates"],
    )
    return prompt, 2000


def worker(index, model_name, save_dict, questions_to_eval):  # noqa: F811
    """Evaluate one question with a scratchpad prompt and store ``(answer, response)``.

    Args:
        index: Position in ``questions_to_eval``; also the key written in ``save_dict``.
        model_name: Key into the global ``models``; its ``"full_name"`` is what is sent to
            ``model_eval.get_response_from_model``.
        save_dict: Results dict shared by all workers of one model. A slot that is not
            ``""`` is treated as already answered and skipped. Each call writes only its
            own ``index`` key.
        questions_to_eval: The list of question dicts being evaluated.

    Returns:
        None. ``save_dict[index]`` receives a 2-tuple: the parsed answer from
        ``model_eval.reformat_answers`` and the raw model response (``None`` on failure).
    """
    if save_dict[index] != "":
        return

    logger.info(f"Starting {model_name} - {index}")

    question = questions_to_eval[index]
    is_market = question["source"] not in question_curation.DATA_SOURCES
    prompt, max_tokens = _build_scratchpad_prompt(question, is_market)

    try:
        response = model_eval.get_response_from_model(
            prompt=prompt,
            max_tokens=max_tokens,
            model_name=models[model_name]["full_name"],
            temperature=0,
            wait_time=30,
        )
    except Exception as e:
        # Best effort: record which request failed, then let the parser below run with
        # response=None so this slot is still filled.
        logger.error(f"{model_name} - {index}: An error occurred: {e}")
        response = None

    if is_market:
        answer = model_eval.reformat_answers(response=response, single=True)
    else:
        answer = model_eval.reformat_answers(response=response, prompt=prompt, question=question)
    # The raw response is kept next to the parsed answer; the eval cell writes it out
    # as the record's "reasoning".
    save_dict[index] = (answer, response)

    logger.info(f"Model: {model_name} | Answer: {save_dict[index][0]}")

    return None


def executor(max_workers, model_name, save_dict, questions_to_eval):  # noqa: F811
    """Fan the scratchpad ``worker`` out over all question indices on a thread pool.

    ``worker`` is resolved when this is called, i.e. whichever ``worker`` the notebook
    defined most recently.

    Args:
        max_workers: Thread-pool size, i.e. number of concurrent model calls.
        model_name: Forwarded to ``worker``.
        save_dict: Shared results dict, filled in place by the workers.
        questions_to_eval: Questions to evaluate; one task per index.

    Returns:
        The workers' return values as a list, in index order.
    """
    run_one = partial(
        worker,
        model_name=model_name,
        save_dict=save_dict,
        questions_to_eval=questions_to_eval,
    )
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(run_one, range(len(questions_to_eval))))

In [None]:
results = {}
model_result_loaded = {}
models_to_test = list(models)
prompt_type = "scratchpad"

# Direction recorded for each unrolled combo copy, keyed by its "combo_index";
# any other index (i.e. 3) gets [-1, -1].
combo_directions = {0: [1, 1], 1: [1, -1], 2: [-1, 1]}

for question in [
    single_market_questions,
    single_non_market_questions,
    combo_market_questions_unrolled,
    combo_non_market_questions_unrolled,
]:
    questions_to_eval = question
    if not questions_to_eval:
        # Nothing of this kind in the current question set; indexing [0] below would raise.
        logger.info(f"{prompt_type}: skipping an empty question group.")
        continue

    # Each group holds a single kind of question (see "Load data"), so the first entry
    # decides the output folder, e.g. "scratchpad/combo_market".
    first_question = questions_to_eval[0]
    is_market = first_question["source"] not in question_curation.DATA_SOURCES
    is_combo = first_question["combination_of"] != "N/A"
    kind = ("combo_" if is_combo else "") + ("market" if is_market else "non_market")
    test_type = f"{prompt_type}/{kind}"

    # Reuse forecasts already saved for this (kind, model). Otherwise start from all-empty
    # slots, which the worker treats as "not answered yet".
    model_result_loaded = {model: False for model in models_to_test}
    for model in models_to_test:
        gcp_file_path = f"{base_file_path}/{test_type}/{model}.jsonl"
        try:
            results[model] = data_utils.download_and_read_saved_forecasts(
                gcp_file_path, base_file_path
            )
            model_result_loaded[model] = True
            print(f"Downloaded {gcp_file_path}.")
        except Exception:  # Best effort: any failure (presumably no saved file) => re-run.
            results[model] = {i: "" for i in range(len(questions_to_eval))}

    for model in models_to_test:
        if model_result_loaded[model]:
            continue

        # Concurrent requests per provider; 50 unless listed here.
        executor_count = {"ANTHROPIC": 30, "GOOGLE": 10}.get(models[model]["source"], 50)
        executor(executor_count, model, results[model], questions_to_eval)

        current_model_forecasts = []
        for index, q in enumerate(questions_to_eval):
            # Scratchpad slots hold (parsed answer, raw model response).
            answer = results[model][index][0]
            reasoning = results[model][index][1]
            if q["source"] in question_curation.DATA_SOURCES:
                # Non-market answers are zipped with the resolution dates: one record per pair.
                records = [
                    {
                        "id": q["id"],
                        "source": q["source"],
                        "forecast": forecast,
                        "resolution_date": resolution_date,
                        "reasoning": reasoning,
                    }
                    for forecast, resolution_date in zip(answer, q["resolution_dates"])
                ]
            else:
                records = [
                    {
                        "id": q["id"],
                        "source": q["source"],
                        "forecast": answer,
                        "reasoning": reasoning,
                    }
                ]

            if q["combination_of"] != "N/A":
                # The four unrolled copies of a combo question share one id, so every combo
                # record -- market as well as non-market -- must carry its direction.
                direction = combo_directions.get(q["combo_index"], [-1, -1])
                for record in records:
                    record["direction"] = list(direction)

            current_model_forecasts.extend(records)

        local_filename = f"{test_type}/{model}.jsonl"
        os.makedirs(os.path.dirname(local_filename), exist_ok=True)
        with open(local_filename, "w") as file:
            for entry in current_model_forecasts:
                file.write(json.dumps(entry) + "\n")

        gcp.storage.upload(
            bucket_name=env.FORECAST_SETS_BUCKET,
            local_filename=local_filename,
            filename=f"{base_file_path}/" + local_filename,
        )

In [None]:
# Post-process this prompt type's per-kind forecast files with the project helpers.
# NOTE(review): behavior inferred from the helper names only -- presumably builds the
# final per-model forecast files, then replaces the cloud copies; confirm in
# model_eval / data_utils. ("generage" is the helper's actual spelling.)
model_eval.generage_final_forecast_files(
    deadline=forecast_due_date,
    prompt_type=prompt_type,
    models=models,
)
data_utils.delete_and_upload_to_the_cloud(
    base_file_path,
    prompt_type,
    question_types,
)