In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the project's source take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

# Evaluate OpenAI responses

In [None]:
import os
import orjson
from pathlib import Path
from typing import Any, Optional
from tqdm.notebook import tqdm
from datasets import load_dataset, Dataset, DatasetDict

In [None]:
from dart_math.eval import EvaluatorMathBatch
from dart_math.data import RespSampleBase
from dart_math.utils import PROJ_HOME, load_jsonl, save_jsonl

In [None]:
# Working directories under the project root:
#   - batch request files, raw batch response files, and the evaluated outputs.
_data_home: Path = Path(PROJ_HOME) / "data"
oai_req_home: Path = _data_home / "oai-batch-reqs"
oai_resp_home: Path = _data_home / "oai-batch-resps"
oai_output_home: Path = _data_home / "oai-outputs"

# For every source dataset: which column holds the query, the reference answer
# and the reference response. A `None` reference-answer column means the answer
# has to be extracted from the reference response instead.
DSET_ID2FIELD_MAP: dict[str, dict[str, Optional[str]]] = {
    "hendrycks/competition_math": dict(query="problem", ref_ans=None, resp="solution"),
    "hkust-nlp/gsm8k-fix": dict(query="query", ref_ans="ans", resp="resp"),
}

In [None]:
# Build a lookup from the (whitespace-stripped) query text to its reference answer,
# covering every split of every dataset listed in `DSET_ID2FIELD_MAP`.
# NOTE: if the same stripped query occurs more than once, the last one seen wins.
evaluator = EvaluatorMathBatch()
QUERY2REF_ANS: dict[str, str] = {}
for dset_id, field_map in DSET_ID2FIELD_MAP.items():
    # Column names are fixed per dataset, so resolve them once instead of per sample
    query_field: str = field_map["query"]
    ref_ans_field: Optional[str] = field_map["ref_ans"]
    resp_field: Optional[str] = field_map.get("resp")
    dset_dict: DatasetDict = load_dataset(dset_id)
    split: str
    dset: Dataset
    for split, dset in dset_dict.items():
        for sample in dset:
            ref_ans: str = (
                sample[ref_ans_field]
                if ref_ans_field is not None
                # No dedicated answer column: extract the answer from the reference response
                else evaluator.extract_ans(sample[resp_field])
            )
            # Strip so that later lookups are insensitive to surrounding whitespace
            QUERY2REF_ANS[sample[query_field].strip()] = ref_ans
print(f"{len(QUERY2REF_ANS)=}")
# Show a single entry as a sanity check (nothing to show if the mapping is empty)
if QUERY2REF_ANS:
    query, ref_ans = next(iter(QUERY2REF_ANS.items()))
    print(f"E.g. {query=}, {ref_ans=}")

len(QUERY2REF_ANS)=19972
E.g. query='Let \\[f(x) = \\left\\{\n\\begin{array}{cl} ax+3, &\\text{ if }x>2, \\\\\nx-5 &\\text{ if } -2 \\le x \\le 2, \\\\\n2x-b &\\text{ if } x <-2.\n\\end{array}\n\\right.\\]Find $a+b$ if the piecewise function is continuous (which means that its graph can be drawn without lifting your pencil from the paper).', ref_ans='0'


In [None]:
# Peek at the first record of one response file to inspect the batch response schema.
# Only files matching `resp_*.jsonl` are considered (in sorted order), the same
# pattern the collection step uses, so sub-directories and unrelated files in the
# folder are never opened and the chosen file is deterministic.
for resp_fpath in sorted(oai_resp_home.glob("resp_*.jsonl")):
    # Read raw bytes: orjson parses bytes directly, avoiding locale-dependent decoding
    with open(resp_fpath, "rb") as f:
        first_line: bytes = f.readline()
    if not first_line.strip():
        continue  # Empty file: nothing to show, try the next one
    resp_eg: Any = orjson.loads(first_line)
    print(f"{resp_eg.keys()=}")
    print(f"{resp_eg['response'].keys()=}")
    print(f"{resp_eg['response']['body'].keys()=}")
    print(f"{resp_eg['response']['body']['choices'][0].keys()=}")
    print(f"{resp_eg=}")
    break

resp_eg.keys()=dict_keys(['id', 'custom_id', 'response', 'error'])
resp_eg['response'].keys()=dict_keys(['status_code', 'request_id', 'body'])
resp_eg['response']['body'].keys()=dict_keys(['id', 'object', 'created', 'model', 'choices', 'usage', 'system_fingerprint'])
resp_eg['response']['body']['choices'][0].keys()=dict_keys(['index', 'message', 'logprobs', 'finish_reason'])
resp_eg={'id': 'batch_req_qKtzQoV8XmkQYZypesBztb3s', 'custom_id': 'hendrycks/competition_math/train:6', 'response': {'status_code': 200, 'request_id': 'c64ef8544bd0ad51c48e7640a63c85c3', 'body': {'id': 'chatcmpl-9smqfDnbaz062CVJHPcvmA6EIexeq', 'object': 'chat.completion', 'created': 1722844513, 'model': 'gpt-4o-mini-2024-07-18', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': 'To solve the inequality \n\n\\[\n\\frac{3(pq^2 + p^2q + 3q^2 + 3pq)}{p + q} > 2p^2q\n\\]\n\nfor all \\( q > 0 \\), we start by multiplying both sides by \\( p + q \\) (which is positive for \\( q > 0 \\)) to eliminate the

In [None]:
# Collect
# Join every batch response with its originating request (matched on `custom_id`)
# and wrap each returned choice into a `RespSampleBase` ready for evaluation.
output_list: list[RespSampleBase] = []
# `custom_id`s of responses that carry an error or a non-200 status; these have
# no usable choices and are skipped instead of aborting the whole collection.
failed_custom_ids: list[str] = []

oai_resp_home.mkdir(parents=True, exist_ok=True)
for resp_fpath in sorted(oai_resp_home.glob("resp_*.jsonl")):
    # Drop only the leading "resp_" (`str.replace` would also remove later occurrences)
    task_name: str = resp_fpath.stem.removeprefix("resp_")
    req_fpath: Path = oai_req_home / f"req_{task_name}.jsonl"
    # Prepare query information: custom_id -> (query, system prompt)
    req_list: list[dict] = load_jsonl(req_fpath)
    custom_id2prompt: dict[str, tuple[str, str]] = {}
    for req in tqdm(req_list):
        messages: list[dict] = req["body"]["messages"]
        # The last message is taken as the query, the first one as the system prompt
        custom_id2prompt[req["custom_id"]] = (
            messages[-1]["content"],
            messages[0]["content"],
        )
    # Collect from the response file
    # Rather large, thus slow
    resp_list: list[dict] = load_jsonl(resp_fpath, use_tqdm=True)
    for resp in resp_list:  # Rather fast
        custom_id: str = resp["custom_id"]
        response: Optional[dict] = resp.get("response")
        if (
            resp.get("error") is not None
            or response is None
            or response.get("status_code") != 200
        ):
            failed_custom_ids.append(custom_id)
            continue
        query, sys_prompt = custom_id2prompt[custom_id]
        resp_body: dict = response["body"]
        for choice in resp_body["choices"]:
            output_list.append(
                RespSampleBase(
                    dataset=custom_id.split(":")[0],
                    query=query,
                    ref_ans=QUERY2REF_ANS[query.strip()],  # Strip to unify the format
                    resp=choice["message"]["content"],
                    agent=resp_body["model"],
                    prompt_template=sys_prompt,
                )
            )

if failed_custom_ids:
    print(
        f"Skipped {len(failed_custom_ids)} failed response(s), "
        f"e.g. {failed_custom_ids[:5]}"
    )

  0%|          | 0/500 [00:00<?, ?it/s]

Loading /data/user_data/yuxuanto/projects/dart-math/data/oai-batch-resps/t0.6_n2_patch/resp_math-gsm8k-sys-patch_t0.6_n2_0-500.jsonl: 100%|██████████| 500/500 [00:36<00:00, 13.64it/s]


  0%|          | 0/500 [00:00<?, ?it/s]

Loading /data/user_data/yuxuanto/projects/dart-math/data/oai-batch-resps/t0.6_n2_patch/resp_math-gsm8k-sys-patch_t0.6_n2_1000-1500.jsonl: 100%|██████████| 500/500 [00:52<00:00,  9.57it/s]


  0%|          | 0/462 [00:00<?, ?it/s]

Loading /data/user_data/yuxuanto/projects/dart-math/data/oai-batch-resps/t0.6_n2_patch/resp_math-gsm8k-sys-patch_t0.6_n2_1500-1962.jsonl: 100%|██████████| 462/462 [00:06<00:00, 70.19it/s]


  0%|          | 0/500 [00:00<?, ?it/s]

Loading /data/user_data/yuxuanto/projects/dart-math/data/oai-batch-resps/t0.6_n2_patch/resp_math-gsm8k-sys-patch_t0.6_n2_500-1000.jsonl: 100%|██████████| 500/500 [00:42<00:00, 11.82it/s]


In [None]:
# Evaluate
# `batch_eval` returns two sequences aligned with `output_list`: the answer
# extracted from each response (a string) and whether it is judged correct (a bool).
ans_list: list[str]
correct_list: list[bool]
ans_list, correct_list = evaluator.batch_eval(output_list)
# Guard against silent misalignment before pairing results with samples
assert len(ans_list) == len(correct_list) == len(output_list), (
    "batch_eval must return one answer and one correctness flag per sample"
)
for output, ans, correct in zip(output_list, ans_list, correct_list):
    output.ans = ans
    output.correct = correct
# Show one evaluated sample as a sanity check (skipped when nothing was collected)
if output_list:
    print(f"{output_list[0].to_dict()=}")

Extracting: 100%|██████████| 3924/3924 [00:02<00:00, 1907.83it/s]
Evaluating: 100%|██████████| 3924/3924 [00:37<00:00, 103.30it/s]


output_list[0].to_dict()={'dataset': 'hendrycks/competition_math/train', 'ref_ans': '[0,3)', 'ans': '(0,1.5)', 'correct': False, 'agent': 'gpt-4o-mini-2024-07-18', 'prompt_template': 'You are a helpful assistant. Solve the problem and provide the final answer in the format of "The answer is $\\boxed{...}$".', 'query': 'What are all values of $p$ such that for every $q>0$, we have   $$\\frac{3(pq^2+p^2q+3q^2+3pq)}{p+q}>2p^2q?$$ Express your answer in interval notation in decimal form.', 'resp': 'To solve the inequality \n\n\\[\n\\frac{3(pq^2 + p^2q + 3q^2 + 3pq)}{p + q} > 2p^2q\n\\]\n\nfor all \\( q > 0 \\), we start by multiplying both sides by \\( p + q \\) (which is positive for \\( q > 0 \\)) to eliminate the fraction:\n\n\\[\n3(pq^2 + p^2q + 3q^2 + 3pq) > 2p^2q(p + q).\n\\]\n\nExpanding both sides, we get:\n\nLeft side:\n\n\\[\n3pq^2 + 3p^2q + 9q^2 + 9pq.\n\\]\n\nRight side:\n\n\\[\n2p^2q^2 + 2p^3q.\n\\]\n\nRearranging the inequality gives:\n\n\\[\n3pq^2 + 3p^2q + 9q^2 + 9pq - 2p^2

In [None]:
# Save
# Turn every evaluated sample into a plain dict and write them all to a single
# JSONL file inside the output directory (created on demand).
oai_output_home.mkdir(parents=True, exist_ok=True)
output_fpath: Path = oai_output_home / "output.jsonl"
output_records: list[dict] = [sample.to_dict() for sample in output_list]
save_jsonl(output_records, output_fpath)