# Load Data

In [None]:
from datasets import load_dataset

DATASET_NAME = "sam-paech/mmlu-pro-nomath-sml"

# Load the dataset and keep its "test" split as a pandas DataFrame; every
# request generated below is built from rows of this frame.
dataset = load_dataset(DATASET_NAME)
test_split = dataset["test"]
df_ground_truth = test_split.to_pandas()

# Experiment: Generate requests and collect responses in parallel

#### Setup

In [None]:
import os
from typing import Optional

import jsonlines

from llm_council.constants import LLM_COUNCIL_MEMBERS
from llm_council.processors.council_service import (
    get_default_council_service,
    CouncilService,
)
from mmlu_prompts import (
    STRUCTURED_OUTPUT_JUDGE_GROUND_TRUTH_COT_FIRST,
    STRUCTURED_OUTPUT_JUDGE_GROUND_TRUTH_COT_SECOND,
    STRUCTURED_OUTPUT_JUDGE_GROUND_TRUTH_NO_COT,
    PROMPT_JUDGE_GROUND_TRUTH_COT_FIRST,
    PROMPT_JUDGE_GROUND_TRUTH_COT_SECOND,
    PROMPT_JUDGE_GROUND_TRUTH_NO_COT,
    PROMPT_ANSWER_COT_FIRST,
    PROMPT_ANSWER_COT_SECOND,
    PROMPT_ANSWER_NO_COT,
    STRUCTURED_OUTPUT_ANSWER_COT_FIRST,
    STRUCTURED_OUTPUT_ANSWER_COT_SECOND,
    STRUCTURED_OUTPUT_ANSWER_NO_COT,
)

NUM_EXAMPLES = 100  # Number of leading test rows to use; set to None to use the full set.
MODEL = "lepton://llama3-2-3b"
NUM_RUNS = 3

# Output root, e.g. "data_mmlu/mmlu_pro.n100.lepton" (or "...full.lepton" when
# NUM_EXAMPLES is falsy). Per-run request directories are "<BASE_OUTDIR>.run<k>".
_SIZE_TAG = f"n{NUM_EXAMPLES}" if NUM_EXAMPLES else "full"
BASE_OUTDIR = f"data_mmlu/mmlu_pro.{_SIZE_TAG}.lepton"

# Fail fast on an unknown model id. An explicit raise is used instead of
# `assert`, which is stripped when Python runs with -O.
if MODEL not in LLM_COUNCIL_MEMBERS:
    raise ValueError(f"MODEL {MODEL!r} is not in LLM_COUNCIL_MEMBERS")

# Maps a 0-based option index to its letter label: 0 -> "A", 1 -> "B", ..., 12 -> "M".
CHOICE_MAP = dict(enumerate("ABCDEFGHIJKLM"))

# Short prompt keys -> prompt templates imported from mmlu_prompts.
# Key scheme "<format>_<task>_<cot>":
#   format: "so" = STRUCTURED_OUTPUT_* templates, "pr" = PROMPT_* templates
#   task:   "jgt" = *_JUDGE_GROUND_TRUTH_* templates, "ans" = *_ANSWER_* templates
#   cot:    "cot1" = *_COT_FIRST, "cot2" = *_COT_SECOND, "cot0" = *_NO_COT
# Keys are also used verbatim as the first component of each request directory name.
PROMPT_MAP = dict(
    so_jgt_cot1=STRUCTURED_OUTPUT_JUDGE_GROUND_TRUTH_COT_FIRST,
    so_jgt_cot2=STRUCTURED_OUTPUT_JUDGE_GROUND_TRUTH_COT_SECOND,
    so_jgt_cot0=STRUCTURED_OUTPUT_JUDGE_GROUND_TRUTH_NO_COT,
    pr_jgt_cot1=PROMPT_JUDGE_GROUND_TRUTH_COT_FIRST,
    pr_jgt_cot2=PROMPT_JUDGE_GROUND_TRUTH_COT_SECOND,
    pr_jgt_cot0=PROMPT_JUDGE_GROUND_TRUTH_NO_COT,
    pr_ans_cot1=PROMPT_ANSWER_COT_FIRST,
    pr_ans_cot2=PROMPT_ANSWER_COT_SECOND,
    pr_ans_cot0=PROMPT_ANSWER_NO_COT,
    so_ans_cot1=STRUCTURED_OUTPUT_ANSWER_COT_FIRST,
    so_ans_cot2=STRUCTURED_OUTPUT_ANSWER_COT_SECOND,
    so_ans_cot0=STRUCTURED_OUTPUT_ANSWER_NO_COT,
)


def get_options_string(options):
    """Render multiple-choice options as one comma-separated, letter-labeled string.

    Example: ``["x", "y"]`` -> ``"A: x, B: y"``.

    Args:
        options: Iterable of option texts in answer-index order. Labels come
            from ``CHOICE_MAP``, so more options than it has keys raises
            ``KeyError``.

    Returns:
        The labeled options joined with ``", "``; ``""`` when ``options`` is empty.
    """
    # The original bound an unused local named `str`, shadowing the builtin;
    # building the parts with a generator avoids both that and the append loop.
    return ", ".join(
        f"{CHOICE_MAP[index]}: {option}" for index, option in enumerate(options)
    )
    

def get_answer_string(options, answer_index):
    """Return the option at ``answer_index`` prefixed with its letter label.

    Example: ``(["x", "y", "z"], 2)`` -> ``"C: z"``. The option text is
    concatenated as-is, so it must already be a string.
    """
    label = CHOICE_MAP[answer_index]
    answer_text = options[answer_index]
    return label + ": " + answer_text


def _request_outdir(prompt_name, schema_name, temperature, role, run):
    """Build the request directory for one (prompt, schema, temperature, role, run) config.

    Layout: ``<BASE_OUTDIR>.run<run>/<prompt_name>.<schema_name|no_schema>.temp<temperature>.<role|no_role>``.
    """
    config_tag = ".".join([
        prompt_name,
        schema_name if schema_name else "no_schema",
        f"temp{temperature}",
        role if role else "no_role",
    ])
    return os.path.join(f"{BASE_OUTDIR}.run{run}", config_tag)


def generate_requests(
    prompt_name: str,
    should_judge_ground_truth: bool,
    temperature: float,
    schema_name: Optional[str] = None,
    role: Optional[str] = None,
    run: int = 0,
):
    """Write one council request per dataset row for a single prompt configuration.

    Each row of ``df_ground_truth`` (limited to the first ``NUM_EXAMPLES`` rows
    when that is truthy) is rendered into the ``PROMPT_MAP[prompt_name]``
    template. The request is written through a ``CouncilService`` for ``MODEL``
    into the directory built by ``_request_outdir``.

    Args:
        prompt_name: Key into ``PROMPT_MAP``. It is also the first component of
            the output directory name. (Previously mis-annotated as ``int``.)
        should_judge_ground_truth: Recorded in the request metadata only.
        temperature: Sampling temperature; recorded in metadata, passed to
            ``write_council_request``, and encoded in the directory name.
        schema_name: Structured-output schema name passed to
            ``write_council_request``; ``None`` for free-form prompts.
        role: Value substituted for ``{role}`` in the template; ``None`` when
            the caller has no role.
        run: Run index; selects the ``<BASE_OUTDIR>.run<run>`` root directory.

    Raises:
        KeyError: If ``prompt_name`` is not in ``PROMPT_MAP``.
    """
    base_prompt = PROMPT_MAP[prompt_name]
    outdir = _request_outdir(prompt_name, schema_name, temperature, role, run)

    council_service = CouncilService(
        llm_council_members=[MODEL],
        outdir=outdir,
    )

    # A falsy NUM_EXAMPLES (None or 0) means "use every row".
    data = df_ground_truth.head(NUM_EXAMPLES) if NUM_EXAMPLES else df_ground_truth

    for _, row in data.iterrows():
        # str.format ignores unused keyword arguments, so templates without a
        # {role}/{answer} placeholder are unaffected by those values.
        realized_prompt = base_prompt.format(
            role=role,
            question=row["question"],
            options=get_options_string(row["options"]),
            answer=get_answer_string(row["options"], row["answer_index"]),
        )
        metadata = {
            "completion_request": {
                "question_id": row["question_id"],
                "temperature": temperature,
                "schema_name": schema_name,
                "should_judge_ground_truth": should_judge_ground_truth,
                "role": role,
            }
        }
        council_service.write_council_request(
            realized_prompt, metadata, temperature, schema_name=schema_name,
        )

#### Generate requests

In [None]:
# Sweep definition. Each entry pairs a PROMPT_MAP key with the structured-output
# schema it is sent with (None for the free-form "pr_" prompts).
# Judge-ground-truth prompts are swept over both roles; answer prompts get no role.
JUDGE_CONFIGS = [
    ("so_jgt_cot1", "reasoning_then_answer"),
    ("so_jgt_cot2", "answer_then_reasoning"),
    ("so_jgt_cot0", "answer_only"),
    ("pr_jgt_cot1", None),
    ("pr_jgt_cot2", None),
    ("pr_jgt_cot0", None),
]
ANSWER_CONFIGS = [
    ("pr_ans_cot1", None),
    ("pr_ans_cot2", None),
    ("pr_ans_cot0", None),
    ("so_ans_cot1", "reasoning_then_answer"),
    ("so_ans_cot2", "answer_then_reasoning"),
    ("so_ans_cot0", "answer_only"),
]
JUDGE_ROLES = ["student", "expert"]
TEMPERATURES = [0, 1]

# Loop nesting preserves the original call order: per run, judge prompts
# (role outer, temperature inner), then answer prompts (temperature only).
for run in range(NUM_RUNS):
    for prompt_name, schema_name in JUDGE_CONFIGS:
        for role in JUDGE_ROLES:
            for temperature in TEMPERATURES:
                generate_requests(
                    prompt_name=prompt_name,
                    should_judge_ground_truth=True,
                    schema_name=schema_name,
                    temperature=temperature,
                    role=role,
                    run=run,
                )
    for prompt_name, schema_name in ANSWER_CONFIGS:
        for temperature in TEMPERATURES:
            generate_requests(
                prompt_name=prompt_name,
                should_judge_ground_truth=False,
                schema_name=schema_name,
                temperature=temperature,
                run=run,
            )

#### Execute requests

In [None]:
from llm_council.invocation.execute_council import execute
import logging

# Configure root logging at WARNING level. basicConfig is a no-op if the root
# logger already has handlers.
logging.basicConfig(level=logging.WARNING)

# Execute each per-run request directory ("<BASE_OUTDIR>.run<k>") for MODEL.
for run_index in range(NUM_RUNS):
    run_dir = f"{BASE_OUTDIR}.run{run_index}"
    execute(requests_dir=run_dir, models=[MODEL])