In [None]:
# Automatically reload imported Python modules before each cell executes, so edits to
# local helper modules take effect without restarting the kernel.
# NOTE(review): no local module is imported in this notebook as shown, so this is
# presumably boilerplate; the saved "already loaded" output means this cell was
# re-run in a live kernel rather than from a fresh Restart & Run All.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Prepare batch file for OpenAI Batch API

See OpenAI's guide to preparing a batch file: https://platform.openai.com/docs/guides/batch/1-preparing-your-batch-file

In [None]:
import orjson
import os
from datasets import load_dataset, Dataset
from typing import Any

In [None]:
# Source datasets to scan for prompts: dataset id, split, and the column that holds
# the problem text (read as `dset[query_field]` when building requests).
dset_info_list: list[dict[str, str]] = [
    {
        "id": "hendrycks/competition_math",
        "split": "train",
        "query_field": "problem",
    },
    {
        "id": "hkust-nlp/gsm8k-fix",
        "split": "train",
        "query_field": "query",
    },
]
# Sampling settings applied to every request; both are also embedded in the
# request file names.
temperature: float = 0.6
n_sample_per_req: int = 2  # completions per request (the API's `n`)
# Target completions per batch file; each file gets output_bs // n_sample_per_req requests.
output_bs: int = 1000
batch_home: str = "../data/oai-batch-reqs"  # directory for the request .jsonl files
task_tag: str = "math-gsm8k-sys-patch"  # label embedded in the request file names
os.makedirs(batch_home, exist_ok=True)

In [None]:
# Prompts that already have a correct response in earlier output files; these are
# excluded from the new batch.
cover_fpath_list: list[str] = [
    "../data/oai-outputs/output_t0.0_n1.jsonl",
]

covered_prompt_list: list[str] = []
for output_fpath in cover_fpath_list:
    # Binary mode: orjson decodes the UTF-8 bytes itself, so parsing does not depend
    # on the platform's locale-default text encoding.
    with open(output_fpath, "rb") as f:
        for line in f:
            if not line.strip():  # tolerate blank lines, e.g. a trailing empty line
                continue
            resp_sample: dict = orjson.loads(line)
            if resp_sample["correct"]:
                # Strip so matching is insensitive to surrounding whitespace.
                covered_prompt_list.append(resp_sample["query"].strip())
covered_prompt_set: set[str] = set(covered_prompt_list)
print(f"{len(covered_prompt_set)=}")

len(covered_prompt_set)=12782


In [None]:
# Prompts with at least one incorrect response in the files below that are not
# already covered by a correct response; these are the ones to re-query.
todo_fpath_list: list[str] = [
    "../data/oai-outputs/output_t0.3_n2.jsonl",
]

todo_prompt_list: list[str] = []
for output_fpath in todo_fpath_list:
    # Binary mode: orjson decodes the UTF-8 bytes itself, independent of the
    # platform's locale-default text encoding.
    with open(output_fpath, "rb") as f:
        for line in f:
            if not line.strip():  # tolerate blank lines, e.g. a trailing empty line
                continue
            resp_sample: dict = orjson.loads(line)
            query: str = resp_sample["query"].strip()
            # Any single incorrect line marks the query as to-do (duplicates are
            # collapsed by the set below).
            if not resp_sample["correct"] and query not in covered_prompt_set:
                todo_prompt_list.append(query)

todo_prompt_set: set[str] = set(todo_prompt_list)
print(f"{len(todo_prompt_set)=}")

len(todo_prompt_set)=1962


In [None]:
# Build one chat-completions request per dataset row whose (stripped) prompt is in
# todo_prompt_set. custom_id is "<dataset id>/<split>:<row index>"; the Batch API
# echoes it back with each response so results can be matched to their request.
req_list: list[dict[str, Any]] = []
for dset_info in dset_info_list:
    dset_id: str = dset_info["id"]
    dset_split: str = dset_info["split"]
    dset: Dataset = load_dataset(dset_id, split=dset_split)
    prompt_list = dset[dset_info["query_field"]]
    for prompt_idx, prompt in enumerate(prompt_list):
        # Same .strip() normalization used when building todo_prompt_set.
        if prompt.strip() not in todo_prompt_set:
            continue
        req_list.append(
            {
                "custom_id": f"{dset_id}/{dset_split}:{prompt_idx}",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": "gpt-4o-mini-2024-07-18",
                    "messages": [
                        {
                            "role": "system",
                            "content": 'You are a helpful assistant. Solve the problem and provide the final answer in the format of "The answer is $\\boxed{...}$".',
                        },
                        {
                            "role": "user",
                            "content": prompt,
                        },
                    ],
                    "max_tokens": 2048,
                    "n": n_sample_per_req,
                    "temperature": temperature,
                    "top_p": 1.0,
                    "seed": 42,
                    # Request per-token logprobs plus the top-20 alternatives at each
                    # position to capture as much information per sample as possible.
                    "logprobs": True,
                    "top_logprobs": 20,
                },
            }
        )

# The Batch API requires custom_id to be unique within an input file.
assert len({req["custom_id"] for req in req_list}) == len(req_list), "duplicate custom_id"
print(f"{len(req_list)=}")
# Guard the preview so an empty to-do set does not raise IndexError.
if req_list:
    print(f"{req_list[0]=}")

len(req_list)=1962
req_list[0]={'custom_id': 'hendrycks/competition_math/train:6', 'method': 'POST', 'url': '/v1/chat/completions', 'body': {'model': 'gpt-4o-mini-2024-07-18', 'messages': [{'role': 'system', 'content': 'You are a helpful assistant. Solve the problem and provide the final answer in the format of "The answer is $\\boxed{...}$".'}, {'role': 'user', 'content': 'What are all values of $p$ such that for every $q>0$, we have   $$\\frac{3(pq^2+p^2q+3q^2+3pq)}{p+q}>2p^2q?$$ Express your answer in interval notation in decimal form.'}], 'max_tokens': 2048, 'n': 2, 'temperature': 0.6, 'top_p': 1.0, 'seed': 42, 'logprobs': True, 'top_logprobs': 20}}


In [None]:
# Write the requests as JSONL files of req_bs requests each, so each file yields
# roughly output_bs completions (req_bs requests x n_sample_per_req completions).
# max(1, ...) keeps the range() step positive if n_sample_per_req ever exceeds
# output_bs; a step of 0 would raise ValueError.
req_bs: int = max(1, output_bs // n_sample_per_req)
for split_start_idx in range(0, len(req_list), req_bs):
    split_end_idx = min(split_start_idx + req_bs, len(req_list))
    # Binary mode: orjson.dumps already returns UTF-8 bytes, so the file contents
    # never depend on the platform's locale-default text encoding.
    with open(
        f"{batch_home}/req_{task_tag}_t{temperature}_n{n_sample_per_req}_{split_start_idx}-{split_end_idx}.jsonl",
        "wb",
    ) as f:
        for req in req_list[split_start_idx:split_end_idx]:
            f.write(orjson.dumps(req) + b"\n")