# Dataset Preparation


In [None]:
import os
import json
from glob import glob


# Root directories of the two inference runs being compared. Each dataset
# name below is joined onto both roots to locate that run's results.
base_infer_path = '/evaluation/outputs/qwen/Qwen2.5-7B-SFT/math_eval'
rl_infer_path = '/evaluation/outputs/qwen/Qwen2.5-7B-DPO/math_eval'
model_name = 'qwen'

# Datasets to process
datasets = [
    'math',
    'gsm8k',
    'minerva_math',
    'olympiadbench',
    'college_math',
    'aime24',
    'amc23',
]
# datasets = ['math', 'minerva_math']

# Output root; results are written to one sub-directory per dataset under it.
save_path = f'./datasets/processed/{model_name}/'

def get_unique_jsonl_file(folder):
    """Return the path of the only ``*.jsonl`` file directly inside `folder`.

    Raises:
        ValueError: if `folder` contains zero, or more than one, ``*.jsonl``
            file. The message lists whatever matches were found.
    """
    pattern = os.path.join(folder, '*.jsonl')
    files = glob(pattern)
    if len(files) == 1:
        return files[0]
    raise ValueError(f"No unique jsonl file found in {folder}, found: {files}")

def extract_single(item):
    """Unwrap a one-element list; return any other value untouched.

    Only ``list`` instances of length exactly 1 are unwrapped. Empty lists,
    longer lists, tuples, and non-sequence values are returned as they are.
    """
    is_singleton_list = isinstance(item, list) and len(item) == 1
    return item[0] if is_singleton_list else item

def _load_results_by_idx(jsonl_path):
    """Read a JSONL results file into a dict keyed by each record's 'idx'.

    Blank or whitespace-only lines are skipped, because json.loads rejects
    them. If several records share an 'idx', the last one read wins; records
    without an 'idx' key are stored under None.
    """
    records = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            if not line.strip():
                continue
            item = json.loads(line)
            records[item.get('idx')] = item
    return records


def _model_fields(prefix, item, score):
    """Build the output fields of one model's result, keys prefixed by `prefix`.

    `score` is the already-unwrapped score and is stored unchanged. The
    '<prefix>_answer_tokens_range' entry is [prompt length, prompt length +
    answer length].
    """
    prompt_tokens = extract_single(item.get("prompt_token_ids"))
    answer_tokens = extract_single(item.get("generate_token_ids"))
    # NOTE(review): len() needs both token fields to be present and sized. A
    # missing field (None) raises TypeError here, and so does a flat
    # one-token list, which extract_single unwraps to a bare element.
    # Confirm the upstream schema.
    prompt_len = len(prompt_tokens)
    return {
        f"{prefix}_prompt_tokens": prompt_tokens,
        f"{prefix}_answer_tokens": answer_tokens,
        f"{prefix}_answer_tokens_range": [prompt_len, prompt_len + len(answer_tokens)],
        f"{prefix}_pred": extract_single(item.get("pred")),
        f"{prefix}_score": score,
    }


# For every dataset, pair the base and rl results by 'idx' and split the pairs
# into four JSONL files according to which of the two models was correct.
for dataset in datasets:
    base_dataset_dir = os.path.join(base_infer_path, dataset)
    rl_dataset_dir = os.path.join(rl_infer_path, dataset)
    if not (os.path.exists(base_dataset_dir) and os.path.exists(rl_dataset_dir)):
        print(f"Skip {dataset}, because path does not exist")
        continue

    base_jsonl = get_unique_jsonl_file(base_dataset_dir)
    rl_jsonl = get_unique_jsonl_file(rl_dataset_dir)

    # Read base and rl inference results, align by idx
    base_data = _load_results_by_idx(base_jsonl)
    rl_data = _load_results_by_idx(rl_jsonl)

    # Four categories
    all_correct = []
    all_wrong = []
    base_right_rl_wrong = []
    base_wrong_rl_right = []
    # (base correct?, rl correct?) -> list that collects the matching records
    buckets = {
        (True, True): all_correct,
        (False, False): all_wrong,
        (True, False): base_right_rl_wrong,
        (False, True): base_wrong_rl_right,
    }

    for idx, base_item in base_data.items():
        # Records present in only one of the two runs cannot be compared.
        if idx not in rl_data:
            continue
        rl_item = rl_data[idx]

        # Extract score
        base_score = extract_single(base_item.get('score'))
        rl_score = extract_single(rl_item.get('score'))

        # A score counts as correct only when its text form is 'true'
        # (case-insensitive), so both the boolean True and the string 'True'
        # qualify; anything else, including a multi-element list, is wrong.
        base_score_bool = str(base_score).lower() == 'true'
        rl_score_bool = str(rl_score).lower() == 'true'

        # Assemble output fields: shared fields first (taken from the base
        # record), then the base model's fields, then the rl model's fields.
        out_item = {
            "idx": extract_single(idx),
            "question": extract_single(base_item.get("question")),
            "gt": extract_single(base_item.get("gt")),
            "solution": extract_single(base_item.get("solution")),
            **_model_fields("base", base_item, base_score),
            **_model_fields("rl", rl_item, rl_score),
        }

        buckets[(base_score_bool, rl_score_bool)].append(out_item)

    # Save
    save_dir = os.path.join(save_path, dataset)
    os.makedirs(save_dir, exist_ok=True)

    # All four files are always written, even when a category is empty.
    outputs = {
        'all_correct.jsonl': all_correct,
        'all_wrong.jsonl': all_wrong,
        'base_right_rl_wrong.jsonl': base_right_rl_wrong,
        'base_wrong_rl_right.jsonl': base_wrong_rl_right,
    }
    for filename, items in outputs.items():
        with open(os.path.join(save_dir, filename), 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(item, ensure_ascii=False) + '\n' for item in items)

    print(f"{dataset} processed, saved in {save_dir}")
