In [1]:
# Cell 1: Imports and Configuration

# !pip install -q torch==2.6.0+cu124 --index-url https://download.pytorch.org/whl/cu124
# !pip install -q vllm==0.8.5.post1
# !pip install --upgrade fsspec datasets

import os
import json
import torch
import re
from collections import Counter
from datasets import load_dataset
from tqdm import tqdm
from pathlib import Path
import random, os, json, re, torch
import openai
from dotenv import load_dotenv

# vLLM imports
import html, uuid, asyncio, contextlib, nest_asyncio, logging
from IPython.display import HTML, display
from huggingface_hub import snapshot_download
from vllm import TokensPrompt
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.sampling_params import SamplingParams, RequestOutputKind

load_dotenv()
nest_asyncio.apply()
torch.set_grad_enabled(False)
logging.disable(logging.INFO)

NUM_RUNS = 1
# NUM_RUNS = 16
BASE_SEED = 42

MODEL_REPO_NAME = "Qwen/Qwen3-0.6B" 
# MODEL_REPO_NAME = "jaeh8nkim/s1-slth-qwen3-0.6b"

MODEL_LOCAL_PATH = "vanilla-qwen3-0.6b-local"
# MODEL_LOCAL_PATH = "s1-slth-qwen3-0.6b-local"

SMALL_GPU_INDEX = "3"

SMALL_TEMPERATURE = 0.7
# MAX_SEQ_LEN = 8192
MAX_SEQ_LEN = 16384


INFO 06-30 11:21:36 [__init__.py:239] Automatically detected platform cuda.


In [None]:
# Cell 2: Model Download

def download_model_locally(repo_name, local_path):
    """Download model from HuggingFace and save locally"""
    print(f"📥 Downloading model from {repo_name}...")
    
    try:
        # Check if model already exists locally
        if os.path.exists(local_path):
            print(f"✅ Model already exists at {local_path}")
            return local_path
        
        # Download model using snapshot_download (same as vLLM uses)
        checkpoint_path = snapshot_download(repo_name)
        
        # Create local directory
        os.makedirs(local_path, exist_ok=True)
        
        # Copy all files from checkpoint to local path
        import shutil
        shutil.copytree(checkpoint_path, local_path, dirs_exist_ok=True)
        
        print(f"✅ Model downloaded and saved to {local_path}")
        return local_path
        
    except Exception as e:
        print(f"❌ Download failed: {e}")
        return None

# Download the model locally
print("🚀 Downloading model locally...")
model_path = download_model_locally(MODEL_REPO_NAME, MODEL_LOCAL_PATH)

if model_path is None:
    raise RuntimeError("Failed to download model")

In [2]:
# Cell 3: vLLM Engine Setup

# ---------------- utility: temporarily set visible GPUs --------------------
@contextlib.contextmanager
def visible_gpus(devices: str):
    original = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    os.environ["CUDA_VISIBLE_DEVICES"] = devices
    print(f"\nCUDA_VISIBLE_DEVICES = {devices}")
    try:
        yield
    finally:
        os.environ["CUDA_VISIBLE_DEVICES"] = original

# --------------------------- engine setup ----------------------------------
async def setup_engine():
    global engine, tokenizer, vocab_size
    
    # Use the locally downloaded model
    print(f"Setting up engine with local model: {MODEL_LOCAL_PATH}")

    with visible_gpus(SMALL_GPU_INDEX):
        print("torch sees", torch.cuda.device_count(), "GPU(s)")              
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model=MODEL_LOCAL_PATH,  # Use local path instead of checkpoint
                            tensor_parallel_size=1,
                            max_model_len=MAX_SEQ_LEN, 
                            gpu_memory_utilization=0.90,
                            dtype="bfloat16"),
            start_engine_loop=True)
        
        tokenizer = await engine.get_tokenizer()

    # Get model config using async method
    model_config = await engine.get_model_config()
    vocab_size = model_config.get_vocab_size()
    
    print(f"Vocab size: {vocab_size}")

# --------------------------- sampling params -------------------------------
sampling_params = SamplingParams(
    max_tokens=MAX_SEQ_LEN,
    temperature=SMALL_TEMPERATURE,
    top_p=0.95,
)

# Initialize the engine
await setup_engine()

Setting up engine with local model: vanilla-qwen3-0.6b-local

CUDA_VISIBLE_DEVICES = 3
torch sees 1 GPU(s)


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


Vocab size: 151936


In [3]:
# Cell 4: Answer extraction and grading functions

def extract_boxed_answer(records, tokenizer):
    """Extract the last \\boxed{} answer between tokens 151668 and 151645"""
    token_ids = [record['token_id'] for record in records]
    
    # Find positions of the tokens
    pos_151668 = [i for i, tid in enumerate(token_ids) if tid == 151668]
    pos_151645 = [i for i, tid in enumerate(token_ids) if tid == 151645]
    
    if len(pos_151668) != 1 or len(pos_151645) == 0:
        return None

    start_pos = pos_151668[0]
    end_pos = pos_151645[0]  # Take the first occurrence of 151645
    
    if start_pos >= end_pos:
        return None

    # Extract token IDs between the markers (including the end marker)
    between_token_ids = token_ids[start_pos:end_pos+1]
    
    # Decode the entire sequence at once to avoid U+FFFD issues
    between_text = tokenizer.decode(between_token_ids)
    
    # Find all \\boxed{} patterns with proper brace matching
    matches = []
    i = 0
    while i < len(between_text):
        boxed_start = between_text.find('\\boxed{', i)
        if boxed_start == -1:
            break
        
        j = boxed_start + 7  # Start after '\\boxed{'
        brace_count = 1
        while j < len(between_text) and brace_count > 0:
            if between_text[j] == '{':
                brace_count += 1
            elif between_text[j] == '}':
                brace_count -= 1
            j += 1
        
        if brace_count == 0:
            matches.append(between_text[boxed_start + 7:j-1])
        
        i = boxed_start + 1
    
    return matches[-1] if matches else None

def llm_grader(expected_answer, boxed_answer, openai_client, model_name="gpt-4o-mini"):

    def grader_prompt(expected_answer, boxed_answer):
        """Creates the system and user prompts for grading."""
        system_prompt = (
            f"You are an expert grader tasked with evaluating the correctness of an answer.\n"
            f"You will be provided with two pieces of text: the expected answer and the generated answer.\n"
            f"Your task is to determine if the generated answer is semantically equivalent to the expected answer.\n"
            f"Ignore minor formatting differences, extra whitespace, or trivial variations. For numerical answers, consider equivalent representations as correct (e.g., '1/2' and '0.5').\n"
            f"Respond with exactly one word: either 'true' (if correct) or 'false' (if incorrect). Do not include quotation marks, explanations, or any other text.\n"
        )
        user_prompt = (
            f"Expected answer:\n"
            f"{expected_answer}\n"
            f"Generated answer:\n"
            f"{boxed_answer}\n"
        )
        
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        return messages
    
    def grader(grading_messages, openai_client, model_name):
        api_response = openai_client.chat.completions.create(
            model=model_name,
            messages=grading_messages
        ).choices[0].message.content
        
        grade = api_response.strip().lower()
        return grade
    
    grading_messages = grader_prompt(expected_answer, boxed_answer)
    grade = grader(grading_messages, openai_client, model_name)
    
    # Ensure the grade is exactly 'true' or 'false'
    if grade in ['true', 'false']:
        return grade
    else:
        # Fallback in case the API returns something unexpected
        return 'false'



In [4]:
# Cell 5: Evaluation functions

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
client = openai.OpenAI(api_key=OPENAI_API_KEY)

async def graded_is_correct(gold, pred, tokenizer):
    # Convert generated text into token-records so extract_boxed_answer works
    ids = tokenizer.encode(pred)
    records = [{"token_id": t} for t in ids]

    boxed = extract_boxed_answer(records, tokenizer)
    extracted = boxed if boxed else pred

    return llm_grader(gold, extracted, client) == "true"

def print_dataset_info(dataset, task_name):
    """Print dataset count and first 5 examples"""
    print(f"\n--- {task_name.upper()} DATASET INFO ---")
    print(f"Total samples: {len(dataset)}")
    print(f"Dataset type: {type(dataset)}")
    
    # Check the first item to understand the structure
    if len(dataset) > 0:
        first_item = dataset[0]
        print(f"First item type: {type(first_item)}")
        print(f"First item keys: {list(first_item.keys()) if hasattr(first_item, 'keys') else 'No keys'}")
    
    print(f"First 5 examples:")
    
    for i in range(min(5, len(dataset))):
        item = dataset[i]  # Access by index instead of iteration
        print(f"\n{i+1}. ", end="")
        
        try:
            if "math" in task_name:
                question = item.get("problem", item.get("question", ""))
                answer = item.get("solution", "")
                print(f"Question: {question[:200]}...")
                print(f"   Answer: {answer[:200]}...")
            elif "gpqa" in task_name:
                # Handle both original GPQA format and alternative formats
                question = item.get("Question", item.get("question", item.get("Problem", "")))
                answer = item.get("Correct Answer", item.get("correct_answer", item.get("Answer", "")))
                print(f"Question: {question[:200]}...")
                print(f"   Answer: {answer}")
            else:  # AIME
                question = item.get("problem", item.get("Problem", item.get("question", "")))
                answer = str(item.get("answer", item.get("Answer", "")))
                print(f"Question: {question[:200]}...")
                print(f"   Answer: {answer}")
        except Exception as e:
            print(f"Error displaying item: {e}")
            print(f"Item keys: {list(item.keys()) if hasattr(item, 'keys') else 'Not a dict'}")
            print(f"Item type: {type(item)}")
            print(f"Raw item: {str(item)[:200]}...")
            break

async def evaluate_problem_multiple_times(item, task_name, num_runs):
    """
    Evaluate a single problem multiple times and return accuracy for that problem.
    """
    global engine, tokenizer
    correct = 0
    
    # Extract question and gold answer based on task type
    try:
        if "math" in task_name:
            question = item.get("problem", item.get("question", ""))
            gold = item.get("solution", "")
        elif "gpqa" in task_name:
            # Handle both original GPQA format and alternative formats
            question = item.get("Question", item.get("question", item.get("Problem", "")))
            
            # Try different field names for correct answer
            gold = item.get("Correct Answer", item.get("correct_answer", item.get("Answer", "")))
            
            # Handle choices if they exist
            choices = []
            if "Incorrect Answer 1" in item:
                # Original format
                choices = [
                    item["Incorrect Answer 1"],
                    item["Incorrect Answer 2"],
                    item["Incorrect Answer 3"],
                    item["Correct Answer"],
                ]
                random.shuffle(choices)
                gold = chr(65 + choices.index(item["Correct Answer"]))
                question += "\n\nChoices:\n" + "\n".join(
                    f"{chr(65+i)}. {c}" for i, c in enumerate(choices)
                )
            elif any(f"choice_{i}" in item for i in ['A', 'B', 'C', 'D']):
                # Alternative choice format
                choices = [item.get(f"choice_{i}", "") for i in ['A', 'B', 'C', 'D']]
                question += "\n\nChoices:\n" + "\n".join(
                    f"{chr(65+i)}. {c}" for i, c in enumerate(choices) if c
                )
                # Find which choice is correct
                for i, choice in enumerate(['A', 'B', 'C', 'D']):
                    if item.get(f"choice_{choice}", "") == gold:
                        gold = choice
                        break
            
        else:  # AIME
            question = item.get("problem", item.get("Problem", item.get("question", "")))
            gold = str(item.get("answer", item.get("Answer", "")))

        system_prompt = (
            f"You are Qwen, created by Alibaba Cloud. You are a helpful assistant.\n"
            f"You must respond to every query in the following manner:\n"
            f"First, provide a step-by-step logical exploration of the problem.\n"
            f"Then, provide a clear and direct response based on your reasoning, with the final answer enclosed in \\boxed{{}}."
        )

        input = (
            f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
            f"<|im_start|>user\n{question}<|im_end|>\n"
            f"<|im_start|>assistant\n<think>"
        )
        
        # Run the problem multiple times
        for run in range(num_runs):
            random.seed(BASE_SEED + run)
            torch.manual_seed(BASE_SEED + run)
            
            # Generate with vLLM (using working pattern)
            request_id = str(uuid.uuid4())
            generator = engine.generate(input, sampling_params, request_id)
            
            # Get the result using working pattern
            final_output = None
            async for request_output in generator:
                final_output = request_output
            
            if final_output and final_output.outputs:
                predicted = final_output.outputs[0].text.strip()
                if await graded_is_correct(gold, predicted, tokenizer):
                    correct += 1
                
    except Exception as e:
        print(f"Error processing problem: {e}")
        return 0
    
    return correct / num_runs

async def evaluate_model_average(num_runs=NUM_RUNS):
    """
    Evaluate a model on multiple datasets with multiple runs.
    """
    print(f"Using vLLM engine with {MODEL_REPO_NAME}...")

    # Reordered datasets as requested: aime24, aime25, gpqa diamond, math500
    dataset_configs = {
        "aime_2024": ("HuggingFaceH4/aime_2024", "train", None),
        "aime_2025": ("opencompass/AIME2025", "test", "AIME2025-I"),  # Fixed split to "test"
        "gpqa_diamond": ("spawn99/GPQA-diamond-ClaudeR1", "train", None),
        "math_500": ("HuggingFaceH4/MATH-500", "test", None),
    }

    # First, load all datasets and print their info
    datasets = {}
    print("\n" + "="*60)
    print(" LOADING ALL DATASETS ")
    print("="*60)
    
    for task, (repo, split, config) in dataset_configs.items():
        try:
            print(f"Loading {task} dataset from {repo}...")
            if config:
                ds = load_dataset(repo, config, split=split, trust_remote_code=True)
            else:
                ds = load_dataset(repo, split=split, trust_remote_code=True)
            
            datasets[task] = ds
            print(f"✅ {task} loaded successfully")
            
        except Exception as e:
            print(f"❌ Error loading {task}: {e}")
            print(f"Skipping {task} dataset...")
            continue
    
    # Print info for all loaded datasets
    print("\n" + "="*60)
    print(" DATASET INFORMATION ")
    print("="*60)
    
    for task, ds in datasets.items():
        print_dataset_info(ds, task)
    
    # Now run evaluations
    print("\n" + "="*60)
    print(" STARTING EVALUATIONS ")
    print("="*60)
    
    all_results = {}
    
    for task, ds in datasets.items():
        print(f"\n🔄 Evaluating {task.upper()}...")
        
        problem_accuracies = []
        
        # Iterate through each problem
        for i in tqdm(range(len(ds)), desc=f"{task} problems"):
            item = ds[i]
            problem_accuracy = await evaluate_problem_multiple_times(item, task, num_runs)
            problem_accuracies.append(problem_accuracy)
            
            # Determine status
            correct_runs = int(problem_accuracy * num_runs)  # Convert back to count
            if correct_runs > 0:
                status_emoji = "✅"
            else:
                status_emoji = "❌"
            
            # Show individual problem results
            print(f"{status_emoji} Problem {i+1:02d}/{len(ds)} — {task}: {correct_runs}/{num_runs}")

        average_accuracy = sum(problem_accuracies) / len(problem_accuracies)
        all_results[task] = {"average_accuracy": average_accuracy, "problem_accuracies": problem_accuracies}
        
        # Final summary for this dataset
        total_runs = len(ds) * num_runs
        total_correct_runs = sum(int(acc * num_runs) for acc in problem_accuracies)
        print(f"✅ {task} complete - {total_correct_runs}/{total_runs} ({average_accuracy:.2%} accuracy)")

    return all_results

In [None]:
# Cell 6: Main execution

results = await evaluate_model_average(NUM_RUNS)

print("\n" + "="*50)
print(" FINAL AVERAGED ACCURACIES ")
print("="*50)
for task, result in results.items():
    print(f"{task.upper():<15}: {result['average_accuracy']:.2%}")
print("="*50)

Using vLLM engine with Qwen/Qwen3-0.6B...

 LOADING ALL DATASETS 
Loading aime_2024 dataset from HuggingFaceH4/aime_2024...
✅ aime_2024 loaded successfully
Loading aime_2025 dataset from opencompass/AIME2025...
✅ aime_2025 loaded successfully
Loading gpqa_diamond dataset from spawn99/GPQA-diamond-ClaudeR1...
✅ gpqa_diamond loaded successfully
Loading math_500 dataset from HuggingFaceH4/MATH-500...
✅ math_500 loaded successfully

 DATASET INFORMATION 

--- AIME_2024 DATASET INFO ---
Total samples: 30
Dataset type: <class 'datasets.arrow_dataset.Dataset'>
First item type: <class 'dict'>
First item keys: ['id', 'problem', 'solution', 'answer', 'url', 'year']
First 5 examples:

1. Question: Every morning Aya goes for a $9$-kilometer-long walk and stops at a coffee shop afterwards. When she walks at a constant speed of $s$ kilometers per hour, the walk takes her 4 hours, including $t$ min...
   Answer: 204

2. Question: Let $ABC$ be a triangle inscribed in circle $\omega$. Let the tangents 

aime_2024 problems:   3%|▎         | 1/30 [01:26<41:41, 86.27s/it]

❌ Problem 01/30 — aime_2024: 0/1


aime_2024 problems:   7%|▋         | 2/30 [02:52<40:14, 86.24s/it]

❌ Problem 02/30 — aime_2024: 0/1


aime_2024 problems:  10%|█         | 3/30 [04:18<38:46, 86.15s/it]

❌ Problem 03/30 — aime_2024: 0/1


aime_2024 problems:  13%|█▎        | 4/30 [04:58<29:22, 67.80s/it]

❌ Problem 04/30 — aime_2024: 0/1


aime_2024 problems:  17%|█▋        | 5/30 [06:17<29:56, 71.84s/it]

❌ Problem 05/30 — aime_2024: 0/1


aime_2024 problems:  20%|██        | 6/30 [07:43<30:44, 76.84s/it]

❌ Problem 06/30 — aime_2024: 0/1


aime_2024 problems:  23%|██▎       | 7/30 [09:00<29:27, 76.87s/it]

❌ Problem 07/30 — aime_2024: 0/1


aime_2024 problems:  27%|██▋       | 8/30 [10:25<29:07, 79.42s/it]

❌ Problem 08/30 — aime_2024: 0/1


aime_2024 problems:  30%|███       | 9/30 [10:45<21:20, 60.95s/it]

❌ Problem 09/30 — aime_2024: 0/1


aime_2024 problems:  33%|███▎      | 10/30 [11:42<19:53, 59.69s/it]

❌ Problem 10/30 — aime_2024: 0/1


aime_2024 problems:  37%|███▋      | 11/30 [13:12<21:48, 68.87s/it]

❌ Problem 11/30 — aime_2024: 0/1


aime_2024 problems:  40%|████      | 12/30 [14:38<22:16, 74.23s/it]

✅ Problem 12/30 — aime_2024: 1/1


aime_2024 problems:  43%|████▎     | 13/30 [14:52<15:51, 55.97s/it]

❌ Problem 13/30 — aime_2024: 0/1


aime_2024 problems:  47%|████▋     | 14/30 [15:44<14:33, 54.60s/it]

❌ Problem 14/30 — aime_2024: 0/1


aime_2024 problems:  50%|█████     | 15/30 [17:09<15:59, 63.95s/it]

❌ Problem 15/30 — aime_2024: 0/1


aime_2024 problems:  53%|█████▎    | 16/30 [18:03<14:11, 60.83s/it]

❌ Problem 16/30 — aime_2024: 0/1


aime_2024 problems:  57%|█████▋    | 17/30 [19:30<14:52, 68.62s/it]

❌ Problem 17/30 — aime_2024: 0/1


aime_2024 problems:  60%|██████    | 18/30 [20:24<12:50, 64.23s/it]

❌ Problem 18/30 — aime_2024: 0/1


aime_2024 problems:  63%|██████▎   | 19/30 [21:52<13:05, 71.41s/it]

❌ Problem 19/30 — aime_2024: 0/1


aime_2024 problems:  67%|██████▋   | 20/30 [23:19<12:39, 75.97s/it]

❌ Problem 20/30 — aime_2024: 0/1


aime_2024 problems:  70%|███████   | 21/30 [24:09<10:15, 68.38s/it]

❌ Problem 21/30 — aime_2024: 0/1


aime_2024 problems:  73%|███████▎  | 22/30 [24:45<07:48, 58.62s/it]

❌ Problem 22/30 — aime_2024: 0/1


aime_2024 problems:  77%|███████▋  | 23/30 [26:05<07:35, 65.08s/it]

❌ Problem 23/30 — aime_2024: 0/1


aime_2024 problems:  80%|████████  | 24/30 [27:04<06:18, 63.05s/it]

❌ Problem 24/30 — aime_2024: 0/1


aime_2024 problems:  83%|████████▎ | 25/30 [27:21<04:06, 49.33s/it]

❌ Problem 25/30 — aime_2024: 0/1


aime_2024 problems:  87%|████████▋ | 26/30 [28:47<04:01, 60.39s/it]

❌ Problem 26/30 — aime_2024: 0/1


aime_2024 problems:  90%|█████████ | 27/30 [29:40<02:54, 58.17s/it]

❌ Problem 27/30 — aime_2024: 0/1


aime_2024 problems:  93%|█████████▎| 28/30 [31:06<02:13, 66.52s/it]

❌ Problem 28/30 — aime_2024: 0/1


aime_2024 problems:  97%|█████████▋| 29/30 [32:16<01:07, 67.65s/it]

❌ Problem 29/30 — aime_2024: 0/1


aime_2024 problems: 100%|██████████| 30/30 [32:31<00:00, 65.05s/it]


❌ Problem 30/30 — aime_2024: 0/1
✅ aime_2024 complete - 1/30 (3.33% accuracy)

🔄 Evaluating AIME_2025...


aime_2025 problems:   7%|▋         | 1/15 [00:17<04:08, 17.73s/it]

✅ Problem 01/15 — aime_2025: 1/1


aime_2025 problems:  13%|█▎        | 2/15 [01:24<10:04, 46.47s/it]

❌ Problem 02/15 — aime_2025: 0/1


aime_2025 problems:  20%|██        | 3/15 [01:59<08:13, 41.16s/it]

✅ Problem 03/15 — aime_2025: 1/1


aime_2025 problems:  27%|██▋       | 4/15 [02:26<06:30, 35.54s/it]

❌ Problem 04/15 — aime_2025: 0/1


aime_2025 problems:  33%|███▎      | 5/15 [03:04<06:07, 36.73s/it]

❌ Problem 05/15 — aime_2025: 0/1


aime_2025 problems:  40%|████      | 6/15 [03:46<05:46, 38.48s/it]

❌ Problem 06/15 — aime_2025: 0/1


aime_2025 problems:  47%|████▋     | 7/15 [04:46<06:02, 45.28s/it]

❌ Problem 07/15 — aime_2025: 0/1


aime_2025 problems:  53%|█████▎    | 8/15 [05:36<05:28, 46.91s/it]

❌ Problem 08/15 — aime_2025: 0/1


aime_2025 problems:  60%|██████    | 9/15 [06:30<04:54, 49.05s/it]

✅ Problem 09/15 — aime_2025: 1/1


aime_2025 problems:  67%|██████▋   | 10/15 [07:27<04:17, 51.47s/it]

❌ Problem 10/15 — aime_2025: 0/1


aime_2025 problems:  73%|███████▎  | 11/15 [08:30<03:40, 55.04s/it]

❌ Problem 11/15 — aime_2025: 0/1
