In [None]:
"""Notebook for scratchpad with info retrieval."""

# flake8: noqa
import json
import logging
import os
import pickle
import sys
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime, timedelta
from functools import partial

# Append the directory two levels above the working directory to the import path
# (presumably where the `helpers` package imported below lives — confirm for your checkout).
current_path = os.getcwd()
helpers_parent_dir = os.path.join(current_path, "../..")
sys.path.append(helpers_parent_dir)  # noqa: E402
from helpers import data_utils, model_eval  # noqa: E402
from helpers.llm_prompts import (  # noqa: E402
    HUMAN_JOINT_PROMPT_1,
    HUMAN_JOINT_PROMPT_2,
    HUMAN_JOINT_PROMPT_3,
    HUMAN_JOINT_PROMPT_4,
    SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT,
    SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT,
    SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT,
)

# No logger name is given, so this is the root logger; INFO lets the per-question
# progress messages logged by `worker` through.
logger = logging.getLogger(name=None)
logger.setLevel(level=logging.INFO)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Position in this list is the `combo_index` that `worker` uses to pick the human prompt
# for a combination question, so the order must not change.
human_joint_prompts = [
    HUMAN_JOINT_PROMPT_1, HUMAN_JOINT_PROMPT_2,
    HUMAN_JOINT_PROMPT_3, HUMAN_JOINT_PROMPT_4,
]

### Load data

In [4]:
# Only the first record of the JSONL file is used; its "questions" entry holds the list.
file_path = "latest-llm.jsonl"
questions_data = data_utils.read_jsonl(file_path)[0]
questions = questions_data["questions"]

# Standalone questions carry the sentinel "N/A" in "combination_of"; split them by source.
single_questions = [q for q in questions if q["combination_of"] == "N/A"]
single_non_acled_questions = [q for q in single_questions if q["source"] != "acled"]
single_acled_questions = [q for q in single_questions if q["source"] == "acled"]

# Combination questions are kept only when their source is "acled".
combo_questions = [q for q in questions if q["combination_of"] != "N/A" and q["source"] == "acled"]

In [5]:
# Expand every combination question into four variants tagged with combo_index 0..3.
# The copies are shallow, so the nested "combination_of" entries are shared between variants.
combo_questions_unrolled = [
    {**combo, "combo_index": variant} for combo in combo_questions for variant in range(4)
]

In [6]:
# Sizes of the three question groups: (single non-ACLED, single ACLED, unrolled combos).
tuple(map(len, (single_non_acled_questions, single_acled_questions, combo_questions_unrolled)))

(339, 162, 644)

### Scratchpad + Retrieval

In [15]:
# Registry of models to evaluate: short key -> {"source": provider tag, "full_name": model id}.
# Insertion order is preserved because later cells iterate over `models.keys()`.
models = {
    key: {"source": source, "full_name": full_name}
    for key, source, full_name in (
        ("gpt_3p5_turbo_0125", "OAI", "gpt-3.5-turbo-0125"),
        ("gpt_4_turbo_0409", "OAI", "gpt-4-turbo-2024-04-09"),
        ("gpt_4_1106_preview", "OAI", "gpt-4-1106-preview"),
        ("gpt_4_0125_preview", "OAI", "gpt-4-0125-preview"),
        ("gpt_4o", "OAI", "gpt-4o"),
        ("mistral_8x7b_instruct", "TOGETHER", "mistralai/Mixtral-8x7B-Instruct-v0.1"),
        ("mistral_8x22b_instruct", "TOGETHER", "mistralai/Mixtral-8x22B-Instruct-v0.1"),
        ("mistral_large", "MISTRAL", "mistral-large-latest"),
        ("qwen_1p5_110b", "TOGETHER", "Qwen/Qwen1.5-110B-Chat"),
        ("claude_2p1", "ANTHROPIC", "claude-2.1"),
        ("claude_3_opus", "ANTHROPIC", "claude-3-opus-20240229"),
        ("claude_3_sonnet", "ANTHROPIC", "claude-3-sonnet-20240229"),
        ("claude_3_haiku", "ANTHROPIC", "claude-3-haiku-20240307"),
    )
}

In [16]:
# Mapping retrieved summaries back into questions
def _load_retrieved_news(question_id):
    """Unpickle the retrieval results stored for ``question_id`` under ``news/``.

    Slashes in the id are replaced by underscores to form the file name.
    """
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    news_path = "news/" + question_id.replace("/", "_") + ".pickle"
    with open(news_path, "rb") as handle:
        return pickle.load(handle)


# Standalone questions get their news attached directly.
for standalone in single_non_acled_questions + single_acled_questions:
    standalone["news"] = _load_retrieved_news(standalone["id"])

# Combination questions get news attached to each of their component questions.
for combo in combo_questions_unrolled:
    for component in combo["combination_of"]:
        component["news"] = _load_retrieved_news(component["id"])

In [17]:
def get_all_retrieved_info(all_retrieved_info):
    """Render retrieved article summaries as a single text block.

    Args:
        all_retrieved_info: iterable of mappings, each with "title" and "summary" keys.

    Returns:
        One string with an "Article title: ..." line and a "Summary: ..." line per article,
        each article followed by a blank line. Empty input yields an empty string.
    """
    return "".join(
        f"Article title: {article['title']}\nSummary: {article['summary']}\n\n"
        for article in all_retrieved_info
    )


def worker(index, model_name, save_dict, questions_to_eval):
    """Run one inference and store its result in ``save_dict[index]``.

    Args:
        index: position of the question in ``questions_to_eval`` and key into ``save_dict``.
        model_name: key into the notebook-level ``models`` registry.
        save_dict: mapping index -> result; an entry equal to "" means "not done yet".
        questions_to_eval: sequence of question dicts.

    Returns:
        None. On completion ``save_dict[index]`` holds ``(parsed_answer, raw_response)``.
        Entries that are already filled are left untouched.
    """
    if save_dict[index] != "":
        return None

    logger.info(f"Starting {model_name} - {index}")

    question = questions_to_eval[index]
    is_market = question["source"] != "acled"

    if is_market:
        prompt = SCRATCH_PAD_WITH_SUMMARIES_MARKET_PROMPT.format(
            question=question["question"],
            background=question["background"] + "\n" + question["model_info_resolution_criteria"],
            resolution_criteria=question["resolution_criteria"],
            close_date=question["model_info_close_datetime"],
            retrieved_info=get_all_retrieved_info(question["news"]),
        )
    else:
        # One resolution date per horizon: freeze date + 7 days + horizon, as ISO strings.
        all_resolution_dates = [
            (
                datetime.fromisoformat(question["freeze_datetime"]) + timedelta(days=7 + horizon)
            ).isoformat()
            for horizon in question["forecast_horizons"]
        ]

        if question["combination_of"] == "N/A":
            prompt = SCRATCH_PAD_WITH_SUMMARIES_NON_MARKET_PROMPT.format(
                question=question["question"],
                background=question["background"]
                + "\n"
                + question["model_info_resolution_criteria"],
                resolution_criteria=question["resolution_criteria"],
                freeze_datetime=question["freeze_datetime"],
                freeze_datetime_value=question["freeze_datetime_value"],
                freeze_datetime_value_explanation=question["freeze_datetime_value_explanation"],
                retrieved_info=get_all_retrieved_info(question["news"]),
                list_of_resolution_dates=all_resolution_dates,
            )
        else:
            # Joint question: the human prompt variant is selected by the combo index.
            human_prompt = human_joint_prompts[question["combo_index"]]
            sub_1 = question["combination_of"][0]
            sub_2 = question["combination_of"][1]
            prompt = SCRATCH_PAD_WITH_SUMMARIES_JOINT_QUESTION_PROMPT.format(
                human_prompt=human_prompt,
                question_1=sub_1["question"],
                question_2=sub_2["question"],
                background_1=sub_1["background"] + "\n" + sub_1["model_info_resolution_criteria"],
                background_2=sub_2["background"] + "\n" + sub_2["model_info_resolution_criteria"],
                resolution_criteria_1=sub_1["resolution_criteria"],
                resolution_criteria_2=sub_2["resolution_criteria"],
                freeze_datetime_1=sub_1["freeze_datetime"],
                freeze_datetime_2=sub_2["freeze_datetime"],
                freeze_datetime_value_1=sub_1["freeze_datetime_value"],
                freeze_datetime_value_2=sub_2["freeze_datetime_value"],
                freeze_datetime_value_explanation_1=sub_1["freeze_datetime_value_explanation"],
                freeze_datetime_value_explanation_2=sub_2["freeze_datetime_value_explanation"],
                retrieved_info_1=get_all_retrieved_info(sub_1["news"]),
                retrieved_info_2=get_all_retrieved_info(sub_2["news"]),
                list_of_resolution_dates=all_resolution_dates,
            )

    # The model call is the same for every question type; only the prompt differs.
    response = model_eval.get_response_from_model(
        prompt=prompt,
        max_tokens=2000,
        model_name=models[model_name]["full_name"],
        temperature=0,
        wait_time=30,
    )

    # Non-ACLED answers are parsed with single=True; ACLED answers are parsed with the
    # prompt and the question passed along.
    if is_market:
        parsed_answer = model_eval.reformat_answers(response=response, single=True)
    else:
        parsed_answer = model_eval.reformat_answers(
            response=response, prompt=prompt, question=question
        )
    save_dict[index] = (parsed_answer, response)

    logger.info(f"Answer: {save_dict[index][0]}")

    return None


def executor(max_workers, model_name, save_dict, questions_to_eval):
    """Run ``worker`` for every question index on a thread pool.

    Args:
        max_workers: size of the thread pool.
        model_name: forwarded to ``worker``.
        save_dict: forwarded to ``worker``.
        questions_to_eval: forwarded to ``worker``; its length fixes the index range.

    Returns:
        List of ``worker`` return values in index order.
    """
    run_one = partial(
        worker, model_name=model_name, save_dict=save_dict, questions_to_eval=questions_to_eval
    )
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # list() drains the lazy map so all work finishes before the pool shuts down.
        return list(pool.map(run_one, range(len(questions_to_eval))))

In [18]:
# Run every model over each question group and write one JSONL file of forecasts per
# (group, model) under "scratchpad_with_info_retrieval/<group>/<model>.jsonl".
# A model is skipped for a group when its result file can already be read.
results = {}
model_result_loaded = {}
models_to_test = list(models.keys())[:]

# Direction pair recorded for combination forecasts, keyed by combo_index.
# Any index not listed here falls back to [-1, -1].
combo_directions = {0: (1, 1), 1: (1, -1), 2: (-1, 1)}

for question_group in [
    single_acled_questions,
    combo_questions_unrolled,
    single_non_acled_questions,
]:
    questions_to_eval = question_group
    if not questions_to_eval:
        # The group type is inferred from the first question, so there is nothing to do
        # (and nothing to infer from) when the group is empty.
        logger.warning("Skipping an empty question group.")
        continue

    first_question = questions_to_eval[0]
    if first_question["source"] != "acled":
        test_type = "scratchpad_with_info_retrieval/non_acled"
    elif first_question["combination_of"] == "N/A":
        test_type = "scratchpad_with_info_retrieval/acled"
    else:
        test_type = "scratchpad_with_info_retrieval/combo"

    # Try to reuse results written by an earlier run.
    for model in models_to_test:
        file_path = f"{test_type}/{model}.jsonl"
        try:
            # Must be the data_utils helper: a bare `read_jsonl` is not defined in this
            # notebook, and the resulting NameError used to be swallowed, forcing a full
            # re-run of every model on every execution.
            results[model] = data_utils.read_jsonl(file_path)
            model_result_loaded[model] = True
            logger.info(f"Loaded existing results from {file_path}; skipping {model}")
        except (OSError, ValueError):
            # Presumably a missing file (FileNotFoundError is an OSError) or malformed
            # JSON (json.JSONDecodeError is a ValueError) — confirm against
            # data_utils.read_jsonl. Start from empty slots so that `worker` fills them.
            results[model] = {i: "" for i in range(len(questions_to_eval))}
            model_result_loaded[model] = False

    for model in models_to_test:
        if model_result_loaded[model]:
            continue

        file_path = f"{test_type}/{model}.jsonl"
        # Thread-pool size: 30 for ANTHROPIC-sourced models, 50 for everything else.
        executor_count = 30 if models[model]["source"] == "ANTHROPIC" else 50

        executor(executor_count, model, results[model], questions_to_eval)

        current_model_forecasts = []
        for index, question in enumerate(questions_to_eval):
            forecast_value = results[model][index][0]
            reasoning = results[model][index][1]

            if question["source"] == "acled":
                # ACLED answers hold one forecast per horizon; emit one record for each.
                for forecast, horizon in zip(forecast_value, question["forecast_horizons"]):
                    current_forecast = {
                        "id": question["id"],
                        "source": question["source"],
                        "forecast": forecast,
                        "horizon": horizon,
                        "reasoning": reasoning,
                    }

                    if question["combination_of"] != "N/A":
                        current_forecast["direction"] = list(
                            combo_directions.get(question["combo_index"], (-1, -1))
                        )

                    current_model_forecasts.append(current_forecast)
            else:
                current_model_forecasts.append(
                    {
                        "id": question["id"],
                        "source": question["source"],
                        "forecast": forecast_value,
                        "reasoning": reasoning,
                    }
                )

        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, "w") as file:
            for entry in current_model_forecasts:
                file.write(json.dumps(entry) + "\n")

In [19]:
# Hand the forecast due date, the prompt type and the model registry to the helper that
# presumably assembles the final forecast files from the per-model JSONL outputs written
# under "scratchpad_with_info_retrieval/" — confirm against helpers.model_eval.
# NOTE(review): "generage" looks like a typo for "generate"; the name has to match the
# helper's definition, so it is left unchanged here.
model_eval.generage_final_forecast_files(
    deadline=questions_data["forecast_due_date"],
    prompt_type="scratchpad_with_info_retrieval",
    models=models,
)