In [None]:
import os

# Pin the notebook to one GPU. CUDA reads this variable when it initializes, so it is set
# before the CUDA-using libraries below are imported; keep this ordering.
os.environ['CUDA_VISIBLE_DEVICES'] = '7'
from transformers import AutoTokenizer, AutoModelForCausalLM  # NOTE(review): AutoModelForCausalLM is unused in the visible cells
from vllm import LLM  # NOTE(review): LLM is unused in the visible cells; importing vllm is heavy

# In the visible cells the tokenizer is only used to count response lengths in tokens.
# Presumably it should match the model that generated rollouts.jsonl — confirm when switching.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-4B")
# Model IDs that appear to be swapped in here:
# Qwen/Qwen3-4B
# deepseek-ai/DeepSeek-R1-Distill-Qwen-7B

In [None]:
import json

# Load rollouts.jsonl: one JSON record per line. Later cells read the keys
# 'data_source', 'final_reward', 'prompt', 'response' and 'ground_truth' from each record.
# JSON is UTF-8 by spec, so decode explicitly rather than with the locale default, and skip
# empty lines (e.g. an extra blank line at EOF), which json.loads would reject.
data = []
with open('rollouts.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        if line.strip():
            data.append(json.loads(line))
        

In [None]:
import numpy as np

# Mean final_reward per data source. All 'bbh_<task>' sources (presumably BIG-Bench Hard
# subtasks) are pooled under one 'bbh' key. The rename is written back into `data` on
# purpose: the scoring cell groups rollouts by line['data_source'] and relies on it.
accuracy_by_source = {}
for line in data:
    # Normalize BEFORE the lookup so every bbh_* record lands in the same list. Renaming only
    # for unseen keys re-created the 'bbh' list on each bbh_* record, keeping just the last
    # reward on a fresh kernel.
    if 'bbh_' in line['data_source']:
        line['data_source'] = 'bbh'
    accuracy_by_source.setdefault(line['data_source'], []).append(line['final_reward'])

for source, rewards in accuracy_by_source.items():
    mean_accuracy = np.mean(rewards)
    print(source, mean_accuracy, len(rewards))
    
    

In [29]:
import re
from math_verify.errors import TimeoutException
from math_verify.metric import math_metric
from math_verify.parser import ExprExtractionConfig, LatexExtractionConfig

def clean_prediction(model_output: str) -> str:
    r"""
    Try to normalize model outputs like 'Answer: C. Cotton' into '\boxed{C}'.

    Rules, applied in order:
      1. If a ``\boxed{...}`` contains a single alphanumeric token, return just that
         boxed token (first such match).
      2. If any other ``\boxed{...}`` is present (e.g. ``\boxed{-3}``,
         ``\boxed{\frac{1}{2}}``), the model already gave a boxed answer: return the
         stripped text unchanged so the downstream parser can extract it. Without this
         rule a stray standalone capital such as the word "A" would be mistaken for a
         multiple-choice answer.
      3. Otherwise box the first standalone capital letter A-E (multiple-choice answer).
      4. Otherwise return the stripped text as-is.

    Args:
        model_output: Raw model text.

    Returns:
        The normalized prediction string.
    """
    # Rule 1: simple boxed token, e.g. \boxed{C} or \boxed{42}.
    boxed_match = re.search(r'\\boxed\{([A-Za-z0-9]+)\}', model_output)
    if boxed_match:
        return f"\\boxed{{{boxed_match.group(1)}}}"

    # Rule 2: a non-trivial boxed answer exists; do not let the letter fallback override it.
    if '\\boxed{' in model_output:
        return model_output.strip()

    # Rule 3: single-letter multiple-choice answer.
    choice_match = re.search(r'\b([A-E])\b', model_output)
    if choice_match:
        return f"\\boxed{{{choice_match.group(1)}}}"

    # Rule 4: nothing recognizable.
    return model_output.strip()

def getRawCorrectness(model_output: str, ground_truth: str) -> float:
    """Verify the whole ``model_output`` (reasoning included) against ``ground_truth``.

    The reference answer is wrapped in ``\\boxed{...}`` and the prediction is normalized
    with ``clean_prediction`` before being compared by math_verify.

    Returns:
        The math_verify score, or 0.0 when cleaning or verification raises.
    """
    scorer = math_metric(
        gold_extraction_target=(LatexExtractionConfig(),),
        pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
    )
    gold = f"\\boxed{{{ground_truth}}}"

    try:
        score, _ = scorer([gold], [clean_prediction(model_output)])
    except Exception:
        # Best-effort scoring: anything the verifier cannot handle counts as incorrect.
        return 0.0
    return score

def getCorrectness(model_output: str, ground_truth: str) -> float:
    r"""Score a model response against ``ground_truth``, ignoring any reasoning trace.

    Everything up to and including the last ``</think>`` tag is discarded so only the final
    answer section is verified. The remainder is scored by ``getRawCorrectness``
    (``clean_prediction`` + math_verify, with ``ground_truth`` wrapped in ``\boxed{...}``).

    Args:
        model_output: Full model response, possibly containing a ``</think>`` tag.
        ground_truth: Reference answer.

    Returns:
        The math_verify score, or 0.0 if cleaning or verification raised.
    """
    if '</think>' in model_output:
        model_output = model_output.split('</think>')[-1]

    # Same verification as the raw variant; delegate instead of duplicating it.
    return getRawCorrectness(model_output, ground_truth)


In [None]:
import numpy as np
from tqdm import tqdm

# data_source -> list of (prompt, correctness score, response length in tokens),
# one tuple per rollout, in file order.
correctness_response_length_list = {}
for record in tqdm(data):
    per_source = correctness_response_length_list.setdefault(record['data_source'], [])
    prompt = record['prompt']
    score = getCorrectness(record['response'], record['ground_truth'])
    n_tokens = len(tokenizer.encode(record['response']))
    per_source.append((prompt, score, n_tokens))

            

In [32]:
def compute_aucoaa(data, tmax=None):
    """
    Area under the OAA curve (AUC-OAA) for one set of responses.

    For every token budget t = 0..tmax:
        OAA_t = (1/n) * sum_i correct_i * [length_i < t]
    i.e. the correctness mass of responses strictly shorter than t tokens, always divided
    by the total number of responses n. AUC-OAA is the mean of OAA_t over the tmax + 1
    budgets.

    Args:
        data: list of (prompt, correctness, response_length) tuples; correctness is a
            numeric score (0/1 here) and response_length a non-negative int token count.
        tmax: largest budget evaluated. Defaults to the longest response length in
            ``data``. Pass a fixed value to put several datasets/models on the same scale.

    Returns:
        The AUC-OAA as a float (0.0 for empty ``data``).

    Raises:
        ValueError: if ``tmax`` is negative.
    """
    n = len(data)
    if n == 0:
        return 0.0
    if tmax is None:
        tmax = max(length for _, _, length in data)
    if tmax < 0:
        raise ValueError(f"tmax must be non-negative, got {tmax}")

    # Bucket correctness by exact length; lengths >= tmax never count for any budget.
    correct_at_length = [0.0] * (tmax + 1)
    for _, correct, length in data:
        if length < tmax:
            correct_at_length[length] += correct

    # Prefix sum over budgets: O(n + tmax) instead of rescanning all data for every t.
    total = 0.0
    running = 0.0  # correctness mass of responses with length < t
    for t in range(tmax + 1):
        total += running / n
        running += correct_at_length[t]
    return total / (tmax + 1)


In [33]:
import numpy as np
from collections import defaultdict

def compute_accuracy_from_runs(prompts, accuracies, n_runs=5):
    """
    Accuracy statistics for prompts that were each answered several times.

    prompts: list of (hashable) prompt IDs, repeated once per sampled answer
    accuracies: list of per-answer scores (0/1), aligned with ``prompts``
    n_runs: maximum number of answers per prompt (default 5). The k-th answer seen for
        each prompt (in input order) is assigned to run k.

    Returns:
        dict with
          per_run_accuracy / per_run_std: mean of the per-run accuracies and their sample
              std (ddof=1), the figure papers usually report
          per_answer_accuracy / per_answer_std: mean over all answers and its binomial
              standard error sqrt(p(1-p)/N)
          per_prompt_accuracy / per_prompt_std: mean of per-prompt solve rates and their
              sample std (ddof=1)
        Runs that received no answers (every prompt has fewer than ``n_runs`` answers)
        are skipped rather than contributing NaN. Sample stds are NaN when there is
        only one prompt / run.

    Raises:
        AssertionError: if ``prompts`` and ``accuracies`` differ in length.
        ValueError: if a prompt has more than ``n_runs`` answers.
    """
    assert len(prompts) == len(accuracies), "Both lists must have the same length"

    accuracies = np.array(accuracies)
    total_answers = len(accuracies)

    # Per-answer accuracy (binomial estimate)
    acc_answer = accuracies.mean()
    std_answer = np.sqrt(acc_answer * (1 - acc_answer) / total_answers)

    # Per-prompt accuracy; answers keep their input order within each prompt.
    prompt_correct = defaultdict(list)
    for p, a in zip(prompts, accuracies):
        prompt_correct[p].append(a)

    prompt_solved = np.array([np.mean(v) for v in prompt_correct.values()])
    acc_prompt = prompt_solved.mean()
    std_prompt = prompt_solved.std(ddof=1)

    # Per-run accuracy: run i collects the i-th answer of every prompt.
    run_correct = [[] for _ in range(n_runs)]
    for answers in prompt_correct.values():
        if len(answers) > n_runs:
            raise ValueError(
                f"a prompt has {len(answers)} answers but n_runs={n_runs}"
            )
        for i, a in enumerate(answers):
            run_correct[i].append(a)

    # Empty runs would yield NaN means and poison the aggregate; drop them.
    run_accs = np.array([np.mean(r) for r in run_correct if r])
    acc_run = run_accs.mean()
    std_run = run_accs.std(ddof=1)

    return {
        "per_run_accuracy": acc_run,
        "per_run_std": std_run,
        "per_answer_accuracy": acc_answer,
        "per_answer_std": std_answer,
        "per_prompt_accuracy": acc_prompt,
        "per_prompt_std": std_prompt,
    }


In [None]:
from collections import Counter

# Per data source: accuracy across sampled runs, response-length stats, and AUC-OAA.
for key, records in correctness_response_length_list.items():
    print("Data Source: ", key)
    prompt_list = [prompt for prompt, _, _ in records]
    accuracy_list = [correct for _, correct, _ in records]
    response_length_list = [length for _, _, length in records]

    # Samples per prompt, inferred from the data rather than relying on the function's
    # default n_runs=5, so sources sampled a different number of times are split into
    # the right number of runs.
    n_runs = max(Counter(prompt_list).values())
    result = compute_accuracy_from_runs(prompt_list, accuracy_list, n_runs=n_runs)
    print(f"Accuracy: {result['per_run_accuracy']}")
    print(f"Std: {result['per_run_std']}")
    print(f"Response Length: {np.mean(response_length_list)}")
    print(f"Max Response Length: {max(response_length_list)}")
    print(f"Min Response Length: {min(response_length_list)}")

    print("AUC-OAA: ", compute_aucoaa(records))
    print("--------------------------------")
    