In [47]:
import pandas as pd
import sqlite3
from pathlib import Path
from tqdm.notebook import tqdm
import io
import contextlib
import json
import time
from datetime import datetime

import anyio
import httpx
from utils import LLMPool, RuntimeConfig

In [48]:
# Configuration
# Static settings for this run. API_KEY and NGINX_BALANCER_URL are not read in
# the cells shown here — presumably consumed where the LLMPool is constructed
# (TODO confirm).
API_KEY = "sk-local"
MODEL_NAME = "gpt-oss"  # recorded in every datapoint and used as the run-directory prefix
SYSTEM_MESSAGE = "You are a mathematician writing Python code to solve problems."  # sent as the "system" message of every request
NGINX_BALANCER_URL = "http://127.0.0.1:8080/v1"

# Sampling parameters forwarded unchanged to pool.request(...) by process_problem.
TEMPERATURE = 1
MAX_TOKENS = 1024 * 20
REASONING_EFFORT = "low"

OUTPUT_BASE_DIR = Path('/home/larcanio/AIMO3_v2/data/datasets/mvidia_reasoning_steps/DeepSeek-R1')
PROMPT_VERSION = "v1"  # suffix of the saved prompt file names and of the prompt ids stored per datapoint
MAX_PROBLEMS_TO_PROCESS = None  # None -> run_all processes every remaining problem

# Hot-reloadable settings (edit config.json during run)
CONFIG_FILE = "config.json"

# MAX_EXECUTION_RETRIES / MAX_ANSWER_RETRIES / EXECUTION_TIMEOUT_SECONDS are read
# by process_problem; the remaining three are pushed onto the pool by run_all's
# config_reloader after each cfg.reload().
cfg = RuntimeConfig(CONFIG_FILE, defaults={
    "MAX_EXECUTION_RETRIES": 3,
    "MAX_ANSWER_RETRIES": 3,
    "LLM_REQUEST_RETRY_COUNT": 3,
    "LLM_REQUEST_TIMEOUT_SECONDS": 300,
    "EXECUTION_TIMEOUT_SECONDS": 30,
    "MAX_CONCURRENT_REQUESTS": 5,
})
print(f"Configuration loaded: {cfg}")

[config] reloaded: MAX_CONCURRENT_REQUESTS: 5 -> 30
RuntimeConfig(MAX_EXECUTION_RETRIES=3, MAX_ANSWER_RETRIES=3, LLM_REQUEST_RETRY_COUNT=3, LLM_REQUEST_TIMEOUT_SECONDS=300, EXECUTION_TIMEOUT_SECONDS=30, MAX_CONCURRENT_REQUESTS=30)


In [49]:
# # Create SQLite database
# db_path = Path('/home/larcanio/AIMO3_v2/mvidia_reasoning_steps.db')
# db_path.parent.mkdir(parents=True, exist_ok=True)

# # Connect to database
# conn = sqlite3.connect(str(db_path))
# cursor = conn.cursor()

# # Create table with schema based on dataset description
# cursor.execute('''
# CREATE TABLE IF NOT EXISTS reasoning_steps (
#     id INTEGER PRIMARY KEY AUTOINCREMENT,
#     problem TEXT NOT NULL,
#     generated_solution TEXT,
#     generation_model TEXT,
#     problem_type TEXT,
#     expected_answer TEXT,
#     problem_source TEXT,
#     inference_mode TEXT,
#     pass_rate_72b_tir TEXT,
#     used_in_kaggle TEXT,
#     created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
# )
# ''') 

# # Create index on commonly queried fields
# cursor.execute('CREATE INDEX IF NOT EXISTS idx_problem_type ON reasoning_steps(problem_type)')
# cursor.execute('CREATE INDEX IF NOT EXISTS idx_generation_model ON reasoning_steps(generation_model)')
# cursor.execute('CREATE INDEX IF NOT EXISTS idx_inference_mode ON reasoning_steps(inference_mode)')
# cursor.execute('CREATE INDEX IF NOT EXISTS idx_problem_source ON reasoning_steps(problem_source)')
# cursor.execute('CREATE INDEX IF NOT EXISTS idx_used_in_kaggle ON reasoning_steps(used_in_kaggle)')

# conn.commit()
# print("Database and table created successfully")
# # 
# # Load all parquet files from data/nvidia/data
# data_dir = Path('/home/larcanio/AIMO3_v2/data/nvidia/data')
# parquet_files = sorted(glob.glob(str(data_dir / '*.parquet')))

# print(f"\nFound {len(parquet_files)} parquet files")

# # Load and insert all parquet files into SQLite database
# total_inserted = 0
# for parquet_file in tqdm(parquet_files, desc="Processing files", total=len(parquet_files)):
#     df = pd.read_parquet(parquet_file)
    
#     # Map column names to database schema
#     # Handle different possible column names in parquet files
#     insert_data = []
#     for _, row in tqdm(df.iterrows(), desc="Processing rows", total=len(df  )):
#         # Map common column variations to our schema
#         record = {
#             'problem': row.get('problem') or row.get('question', ''),
#             'generated_solution': row.get('generated_solution') or row.get('solution') or row.get('answer', ''),
#             'generation_model': row.get('generation_model', ''),
#             'problem_type': row.get('problem_type', ''),
#             'expected_answer': row.get('expected_answer', ''),
#             'problem_source': row.get('problem_source', ''),
#             'inference_mode': row.get('inference_mode', ''),
#             'pass_rate_72b_tir': row.get('pass_rate_72b_tir', 'n/a'),
#             'used_in_kaggle': row.get('used_in_kaggle', '')
#         }
#         insert_data.append(record)
    
#     # Insert data into database
#     cursor.executemany('''
#         INSERT INTO reasoning_steps 
#         (problem, generated_solution, generation_model, problem_type, 
#          expected_answer, problem_source, inference_mode, pass_rate_72b_tir, used_in_kaggle)
#         VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
#     ''', [
#         (r['problem'], r['generated_solution'], r['generation_model'], r['problem_type'],
#          r['expected_answer'], r['problem_source'], r['inference_mode'], r['pass_rate_72b_tir'], r['used_in_kaggle'])
#         for r in insert_data
#     ])
    
#     inserted_count = len(insert_data)
#     total_inserted += inserted_count
#     conn.commit()
#     print(f"Inserted {inserted_count} rows from {Path(parquet_file).name}")

# print(f"\nTotal inserted: {total_inserted} examples")

# # Verify the data
# cursor.execute('SELECT COUNT(*) FROM reasoning_steps')
# total_count = cursor.fetchone()[0]
# print(f"Total records in database: {total_count}")

# # Show sample entry
# cursor.execute('SELECT * FROM reasoning_steps LIMIT 1')
# sample = cursor.fetchone()
# if sample:
#     columns = [description[0] for description in cursor.description]
#     print(f"\nSample entry:")
#     for col, val in zip(columns, sample):
#         if val and len(str(val)) > 200:
#             print(f"  {col}: {str(val)[:200]}...")
#         else:
#             print(f"  {col}: {val}")

# # Show column names
# cursor.execute('PRAGMA table_info(reasoning_steps)')
# columns_info = cursor.fetchall()
# print(f"\nDatabase columns:")
# for col_info in columns_info:
#     print(f"  {col_info[1]} ({col_info[2]})")

# conn.close()
# print(f"\nDatabase saved to: {db_path}")

In [50]:
# Open the SQLite database; `conn` is the module-level connection that
# query_database falls back to when no connection is passed explicitly.
db_path='/home/larcanio/AIMO3_v2/mvidia_reasoning_steps.db'
conn = sqlite3.connect(str(db_path))

In [51]:
def query_database(query, connection=None):
    """Run a SQL query and return its result set as a pandas DataFrame.

    Args:
        query: SQL text to execute.
        connection: Database connection to use. When ``None``, the
            module-level ``conn`` is used instead.

    Returns:
        DataFrame produced by ``pd.read_sql_query``.
    """
    active_connection = conn if connection is None else connection
    return pd.read_sql_query(query, active_connection)

# Query database for problems with numeric answers.
# The WHERE clause keeps rows whose expected_answer
#   * starts with a digit and contains no non-digit character (the two GLOB tests),
#   * casts to an integer between 0 and 99999,
#   * has no leading zero unless it is exactly '0',
# and whose problem_type is 'has_answer_extracted'. At most 200000 rows are returned.
query = """
SELECT *
FROM reasoning_steps
  WHERE expected_answer GLOB '[0-9]*'
  AND expected_answer NOT GLOB '*[^0-9]*'
  AND CAST(expected_answer AS INTEGER) BETWEEN 0 AND 99999
  AND (
        expected_answer = '0'
        OR expected_answer NOT LIKE '0%'
      )
  AND problem_type = 'has_answer_extracted'
LIMIT 200000;
"""
problems = query_database(query)
print(f"Loaded {len(problems)} problems from database")

# Preview the first and the last two rows of the result.
display(problems.head(2))
display(problems.tail(2))


Unnamed: 0,id,problem,generated_solution,generation_model,problem_type,expected_answer,problem_source,inference_mode,pass_rate_72b_tir,used_in_kaggle,created_at
0,161902,Let $f$ be the function defined by $f(x) = -2 ...,,,has_answer_extracted,61,MATH_training_set,,,0,2026-01-28 13:10:28
1,161911,A block of wood has the shape of a right circu...,,,has_answer_extracted,53,MATH_training_set,,,0,2026-01-28 13:10:28


Unnamed: 0,id,problem,generated_solution,generation_model,problem_type,expected_answer,problem_source,inference_mode,pass_rate_72b_tir,used_in_kaggle,created_at
199998,1145940,Find the least positive integer \( n \) such t...,"<think>\nOkay, so I need to find the smallest ...",DeepSeek-R1,has_answer_extracted,4,aops_c6_high_school_olympiads,cot,0.96875,1,2026-01-28 13:12:35
199999,1145957,Students 1 through 9 are assigned cards number...,"<think>\nOkay, so I need to figure out how man...",DeepSeek-R1,has_answer_extracted,77760,aops_c4_high_school_math,cot,0.59375,0,2026-01-28 13:12:35


In [52]:
# Domain Filtering
from domain_classifier_2nd_stage import (
    extract_fields,
    compute_all_scores,
    get_heuristic_ranking,
    get_hard_override,
    ALLOWED_DOMAINS,
)

# Set to None to disable filtering, or specify domains to keep
# Options: "algebra", "number_theory", "combinatorics", "geometry"
# Every entry is validated against ALLOWED_DOMAINS before filtering starts.
FILTER_DOMAINS = ["number_theory"]


def classify_problem_domain(problem_text: str) -> str:
    """Return the domain label assigned to ``problem_text`` by the heuristic classifier.

    The text is wrapped in the payload layout passed to ``extract_fields``
    (only the problem text is filled in; code/plan/goal are empty strings).
    A truthy result from ``get_hard_override`` takes precedence; otherwise the
    first element of ``get_heuristic_ranking`` is returned.
    """
    fields = extract_fields({
        "problem": {"text": problem_text},
        "code": "",
        "plan": "",
        "goal": "",
    })
    # Scores are computed before the override check, exactly as before.
    scores = compute_all_scores(fields)

    override = get_hard_override(fields)
    if override:
        return override

    top_domain, _second, _third = get_heuristic_ranking(scores)
    return top_domain


# Keep only the problems whose heuristic domain is listed in FILTER_DOMAINS.
if not FILTER_DOMAINS:
    print("Domain filtering disabled")
else:
    invalid = set(FILTER_DOMAINS) - ALLOWED_DOMAINS
    if invalid:
        raise ValueError(f"Invalid domain names: {invalid}. Allowed: {ALLOWED_DOMAINS}")

    pre_filter_count = len(problems)
    print(f"Filtering by domains: {FILTER_DOMAINS}")
    print(f"Problems before filter: {pre_filter_count}")

    # Classify every problem statement (progress shown by tqdm).
    domain_classifications = [
        classify_problem_domain(problem_text)
        for problem_text in tqdm(problems['problem'], total=len(problems), desc="Classifying")
    ]
    problems['_classified_domain'] = domain_classifications

    print("\nDomain distribution:")
    print(problems['_classified_domain'].value_counts())

    # Filter on the temporary column, then drop it again.
    in_target_domain = problems['_classified_domain'].isin(FILTER_DOMAINS)
    problems = problems[in_target_domain].copy().drop(columns=['_classified_domain'])

    post_filter_count = len(problems)
    excluded = pre_filter_count - post_filter_count
    print(f"Excluded {excluded} problems | Remaining: {post_filter_count}")

# Cap the working set at the first 5000 rows (applied whether or not filtering ran).
problems = problems[:5000]

Filtering by domains: ['number_theory']
Problems before domain filter: 200000


Classifying domains:   0%|          | 0/200000 [00:00<?, ?it/s]


Domain distribution (before filter):
_classified_domain
algebra          118237
combinatorics     34799
geometry          27617
number_theory     19347
Name: count, dtype: int64

Excluded 180653 problems outside target domains
Problems after domain filter: 19347


In [53]:
# Deduplication against reference dataset: drop every problem whose id already
# appears (as a string) among the problem ids of the reference JSONL file.
REFERENCE_JSONL = Path('/home/larcanio/AIMO3_v2/data/datasets/mvidia_reasoning_steps/Compiled-OpenMath/dataset.jsonl')

initial_count = len(problems)
print(f"Problems from query: {initial_count}")

if not (REFERENCE_JSONL and REFERENCE_JSONL.exists()):
    print("Reference file not found or not specified, skipping deduplication")
else:
    existing_problem_ids = set()
    with open(REFERENCE_JSONL, 'r', encoding='utf-8') as f:
        for raw_line in f:
            if not raw_line.strip():
                continue  # skip blank lines
            record = json.loads(raw_line)
            # The id is read from the top level first, then from the nested "problem" object.
            pid = record.get('problem_id') or record.get('problem', {}).get('problem_id')
            if pid is not None:
                existing_problem_ids.add(str(pid))

    print(f"Found {len(existing_problem_ids)} existing problems in reference")

    already_present = problems['id'].astype(str).isin(existing_problem_ids)
    problems = problems[~already_present]
    filtered_count = len(problems)
    excluded_count = initial_count - filtered_count
    print(f"Excluded {excluded_count} duplicates | Remaining: {filtered_count}")

print(f"Final problem count: {len(problems)}")

Problems from database query: 5000
Found 35570 existing problem_ids in reference file
Excluded 1411 duplicate problems
Final problem count: 3589


In [54]:
import multiprocessing as mp
import traceback

def _exec_worker(code: str, queue: mp.Queue):
    buf = io.StringIO()
    try:
        globals_dict = {"__name__": "__main__"}
        with contextlib.redirect_stdout(buf):
            exec(code, globals_dict, globals_dict)
        queue.put(("ok", buf.getvalue()))
    except Exception:
        queue.put(("error", traceback.format_exc()))

def exec_with_timeout_capture_stdout(code: str, timeout_seconds: int = 60) -> str:
    """Run ``code`` in a child process and return everything it printed.

    The result is read from the queue *before* the child is joined. A process
    that has put a large item on a ``multiprocessing.Queue`` does not terminate
    until that item has been flushed to the pipe, so joining first (as the
    previous version did) made any run with a large stdout wait out the whole
    timeout and then be reported as a ``TimeoutError``.

    Args:
        code: Python source handed to ``_exec_worker``.
        timeout_seconds: Maximum time to wait for the child's result.

    Returns:
        The stdout text captured by the worker.

    Raises:
        TimeoutError: No result arrived within ``timeout_seconds``.
        RuntimeError: The code raised (the message is the worker's traceback
            text), or the child exited without reporting any result.
    """
    from queue import Empty  # raised by multiprocessing.Queue.get on timeout

    result_queue = mp.Queue()
    worker = mp.Process(target=_exec_worker, args=(code, result_queue), daemon=True)
    worker.start()
    deadline = time.monotonic() + timeout_seconds

    try:
        while True:
            remaining = deadline - time.monotonic()
            try:
                # Short polls keep us responsive to a child that dies silently.
                result = result_queue.get(timeout=max(0.0, min(0.05, remaining)))
                break
            except Empty:
                pass

            if not worker.is_alive():
                # The child is gone; anything it queued has already been flushed,
                # so one last short read decides between "result" and "nothing".
                try:
                    result = result_queue.get(timeout=0.1)
                except Empty:
                    raise RuntimeError("Execution finished but produced no result") from None
                break

            if remaining <= 0:
                raise TimeoutError(f"Execution exceeded {timeout_seconds} seconds")

        status, payload = result
        if status == "error":
            raise RuntimeError(payload)

        return payload
    finally:
        # Stop the child if it is still running (timeout, or still shutting down
        # after delivering its result): SIGTERM first, SIGKILL if that is ignored.
        if worker.is_alive():
            worker.terminate()
            worker.join(timeout=2)
            if worker.is_alive():
                worker.kill()
        worker.join()
        result_queue.close()
        result_queue.join_thread()

In [55]:
import os

def cleanup_zombie_processes():
    """Reap every already-terminated child process without blocking.

    ``os.waitpid(-1, os.WNOHANG)`` is called repeatedly; the loop ends when it
    reports pid 0 (children exist but none has exited) or raises
    ``ChildProcessError`` (there are no children at all).
    """
    while True:
        try:
            reaped_pid, _exit_status = os.waitpid(-1, os.WNOHANG)
        except ChildProcessError:
            return
        if reaped_pid == 0:
            return

cleanup_zombie_processes()

In [56]:
def strip_markdown_code_blocks(code: str) -> str:
    """Remove a surrounding Markdown code fence from ``code``.

    Falsy input (``None`` or ``""``) is returned unchanged. Otherwise the text
    is stripped of outer whitespace; when it starts with a fence, the opening
    fence line is dropped and the text is cut at the last line consisting only
    of a closing fence. Cutting at the last closing fence (rather than only
    when it is the final line) also discards commentary a model appends after
    the code block, which previously left the fence and the prose in the code.

    Args:
        code: Raw model response.

    Returns:
        The code without its fence, stripped of leading/trailing whitespace.
    """
    if not code:
        return code

    text = code.strip()
    if not text.startswith('```'):
        return text

    # Drop the opening fence line (it may carry a language tag such as ```python).
    body_lines = text.split('\n')[1:]

    closing_fences = [
        position for position, line in enumerate(body_lines)
        if line.strip() == '```'
    ]
    if closing_fences:
        body_lines = body_lines[:closing_fences[-1]]

    return '\n'.join(body_lines).strip()

In [57]:
# Setup output directory and files
output_dir = OUTPUT_BASE_DIR
output_dir.mkdir(parents=True, exist_ok=True)

# Every run writes into its own "<MODEL_NAME>_<timestamp>" sub-directory.
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
run_dir = output_dir / (MODEL_NAME + "_" + run_timestamp)
run_dir.mkdir(parents=True, exist_ok=True)

# Datapoints are appended to this file, one JSON object per line (see save_datapoint).
jsonl_file = run_dir / "dataset.jsonl"

# Destination paths for the copies of the prompt templates stored with the run.
prompt_files = {
    "base": run_dir / f"base_prompt_{PROMPT_VERSION}.txt",
    "timeout_repair": run_dir / f"timeout_repair_prompt_{PROMPT_VERSION}.txt",
    "error_repair": run_dir / f"error_repair_prompt_{PROMPT_VERSION}.txt",
    "wrong_answer_repair": run_dir / f"wrong_answer_repair_prompt_{PROMPT_VERSION}.txt",
}

def create_datapoint_id(dataset, problem_id, timestamp, sequence):
    """Build a datapoint id of the form ``<dataset>_<problem_id>_<timestamp>_<sequence>``.

    ``sequence`` is rendered as an integer zero-padded to at least three digits.
    """
    return "{}_{}_{}_{:03d}".format(dataset, problem_id, timestamp, sequence)

def classify_failure(error, is_timeout):
    """Classify a failed execution for tracking.

    Args:
        error: Error object or text (converted with ``str``); may be ``None``.
        is_timeout: ``True`` when the execution timed out. Takes precedence
            over ``error``.

    Returns:
        ``None`` when there is neither a timeout nor an error. Otherwise a dict
        with ``failure_regime`` (``"timeout"`` or ``"runtime_error"``),
        ``error_signature`` and ``retry_pattern`` (always ``None``).

        For traceback text the signature is the first line containing
        ``Error:`` or ``Exception:``. If no such line exists (for example a
        bare ``AssertionError`` raised without a message, which prints without
        a colon), the last non-empty line is used; previously the signature
        degraded to the ``Traceback (most recent call last):`` header.
    """
    if is_timeout:
        return {
            "failure_regime": "timeout",
            "error_signature": "TimeoutError",
            "retry_pattern": None
        }

    if error is None:
        return None

    error_str = str(error)
    lines = error_str.split('\n')

    if "Traceback" in error_str:
        error_signature = next(
            (line.strip() for line in lines if 'Error:' in line or 'Exception:' in line),
            None,
        )
        if error_signature is None:
            # A formatted traceback ends with the exception line.
            last_line = next((line.strip() for line in reversed(lines) if line.strip()), "")
            if last_line and last_line != lines[0].strip():
                error_signature = last_line
            else:
                # Nothing beyond the first line is available: keep the old format.
                error_signature = f"RuntimeError: {lines[0]}"
    else:
        error_signature = f"RuntimeError: {lines[0][:200]}"

    return {
        "failure_regime": "runtime_error",
        "error_signature": error_signature,
        "retry_pattern": None
    }

def save_datapoint(datapoint):
    """Append ``datapoint`` to ``jsonl_file`` as a single JSON line.

    The object is serialized with ``ensure_ascii=False`` and written in UTF-8,
    followed by a newline.
    """
    with open(jsonl_file, 'a', encoding='utf-8') as f:
        serialized = json.dumps(datapoint, ensure_ascii=False)
        f.write(f"{serialized}\n")

In [58]:
# Load prompt templates from files
# PROMPT_DIR is a relative path, so it is resolved against the kernel's current
# working directory.
PROMPT_DIR = Path("prompts")

# Templates filled in with str.format by process_problem.
BASE_PROMPT = (PROMPT_DIR / "nb0_code_generation_base.md").read_text(encoding="utf-8")
TIMEOUT_REPAIR_TEMPLATE = (PROMPT_DIR / "nb0_code_repair_timeout.md").read_text(encoding="utf-8")
ERROR_REPAIR_TEMPLATE = (PROMPT_DIR / "nb0_code_repair_error.md").read_text(encoding="utf-8")
WRONG_ANSWER_REPAIR_TEMPLATE = (PROMPT_DIR / "nb0_code_repair_wrong_answer.md").read_text(encoding="utf-8")
# NOTE(review): the two templates below are loaded but not referenced in the
# cells visible here — confirm they are used further down.
VERIFIER_PROMPT_TEMPLATE = (PROMPT_DIR / "nb0_code_verifier.md").read_text(encoding="utf-8")
REPAIR_PROMPT_TEMPLATE = (PROMPT_DIR / "nb0_code_repair_quality.md").read_text(encoding="utf-8")

print(f"Loaded {len([BASE_PROMPT, TIMEOUT_REPAIR_TEMPLATE, ERROR_REPAIR_TEMPLATE, WRONG_ANSWER_REPAIR_TEMPLATE, VERIFIER_PROMPT_TEMPLATE, REPAIR_PROMPT_TEMPLATE])} prompt templates from {PROMPT_DIR}")

In [59]:
# Save prompts for reproducibility: each template used for generation/repair is
# written to its destination in prompt_files (inside the run directory).
for prompt_key, prompt_text in (
    ("base", BASE_PROMPT),
    ("timeout_repair", TIMEOUT_REPAIR_TEMPLATE),
    ("error_repair", ERROR_REPAIR_TEMPLATE),
    ("wrong_answer_repair", WRONG_ANSWER_REPAIR_TEMPLATE),
):
    with open(prompt_files[prompt_key], 'w', encoding='utf-8') as f:
        f.write(prompt_text)

print(f"Prompts saved to: {run_dir}")

Prompts saved to /home/larcanio/AIMO3_v2/data/datasets/mvidia_reasoning_steps/DeepSeek-R1/gpt-oss_2026-02-04_02-30-10


In [60]:
def _new_attempt_record(n, stage, timestamp, retry_reason=None):
    """Return an empty attempt record in the layout stored under ``datapoint["attempts"]``."""
    return {
        "n": n,
        "stage": stage,
        "model": MODEL_NAME,
        "timestamp": timestamp,
        "code": None,
        "exec": {
            "status": None,
            "stdout": None,
            "stderr": None,
            "error": None,
            "error_type": None,
            "timeout": False,
            "duration": None,
        },
        "result": {
            "predicted": None,
            "correct": None,
        },
        "retry_reason": retry_reason,
    }


def _format_llm_error(exc):
    """Render an LLM request failure as ``<Type>[ [status]] - <detail, max 300 chars>``."""
    status_code = getattr(exc, "status_code", None)
    err_detail = getattr(exc, "message", str(exc))
    status_part = f" [{status_code}]" if status_code else ""
    return f"{type(exc).__name__}{status_part} - {str(err_detail)[:300]}"


async def _request_code(pool, user_message):
    """Send the system message plus ``user_message`` through ``pool`` and return its response."""
    messages = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": user_message},
    ]
    return await pool.request(
        messages,
        temperature=TEMPERATURE,
        max_tokens=MAX_TOKENS,
        reasoning_effort=REASONING_EFFORT,
    )


async def _execute_generated_code(code, timeout_seconds):
    """Run ``exec_with_timeout_capture_stdout`` in a worker thread and return the captured stdout."""
    return await anyio.to_thread.run_sync(
        lambda: exec_with_timeout_capture_stdout(code, timeout_seconds)
    )


def _record_exec_failure(record, exc, duration):
    """Write a failed execution into ``record["exec"]``; return True when it was a timeout."""
    timed_out = isinstance(exc, TimeoutError)
    exec_info = record["exec"]
    exec_info["status"] = "timeout" if timed_out else "error"
    exec_info["error"] = str(exc)
    exec_info["error_type"] = "TimeoutError" if timed_out else type(exc).__name__
    if timed_out:
        exec_info["timeout"] = True
    exec_info["duration"] = duration
    return timed_out


async def process_problem(pool: LLMPool, cfg, row, idx, sequence_num):
    """Process a single problem with two-phase retry matching the original schema.

    Phase 1 – Execution retries (up to MAX_EXECUTION_RETRIES):
      - TimeoutError  -> TIMEOUT_REPAIR_TEMPLATE
      - RuntimeError  -> ERROR_REPAIR_TEMPLATE
      A repair template is only used once some code has actually been executed
      and failed. If earlier attempts produced no code (LLM request error or
      empty response) the base prompt is sent again; previously such a retry
      sent a repair prompt with ``previous_code=None`` and ``error='None'``.
      Attempts that produce no code are not recorded in ``attempts``.

    Phase 2 – Answer retries (up to MAX_ANSWER_RETRIES):
      - Wrong answer  -> WRONG_ANSWER_REPAIR_TEMPLATE
      The loop stops early on an LLM error, an empty response, or a failed
      execution of the repaired code.

    Args:
        pool: LLM request pool; ``pool.request(messages, ...)`` is awaited and
            its result must expose ``content``, ``prompt_tokens`` and
            ``completion_tokens``.
        cfg: Runtime configuration providing MAX_EXECUTION_RETRIES,
            MAX_ANSWER_RETRIES and EXECUTION_TIMEOUT_SECONDS.
        row: Problem row supporting ``row[...]`` and ``row.get(...)``.
        idx: Position of the problem; used in log messages and as fallback id.
        sequence_num: Sequence number embedded in the datapoint id.

    Returns:
        (datapoint, found_correct, tok_in_total, tok_out_total)
    """
    problem_text = row["problem"]
    solution_text = row["generated_solution"]
    expected_answer = str(row["expected_answer"])
    problem_id = str(row.get("id", idx))

    base_prompt = BASE_PROMPT.format(problem=problem_text, solution=solution_text)

    timestamp = datetime.now().isoformat()
    datapoint_id = create_datapoint_id(
        "mvidia_reasoning_steps", problem_id,
        datetime.now().strftime("%Y%m%d"), sequence_num,
    )

    datapoint = {
        "id": datapoint_id,
        "dataset": "mvidia_reasoning_steps",
        "problem_id": problem_id,
        "timestamp": timestamp,
        "problem": {
            "text": problem_text,
            "expected_answer": expected_answer,
            "problem_type": row.get("problem_type", ""),
            "problem_source": row.get("problem_source", ""),
            "original_solution": row.get("generated_solution", ""),
        },
        "prompt_ids": {
            "base": f"base_prompt_{PROMPT_VERSION}",
            "timeout_repair": f"timeout_repair_prompt_{PROMPT_VERSION}",
            "error_repair": f"error_repair_prompt_{PROMPT_VERSION}",
            "wrong_answer_repair": f"wrong_answer_repair_prompt_{PROMPT_VERSION}",
            "reasoning_summary": None,
        },
        "models": {
            "generating": MODEL_NAME,
            "fixing": MODEL_NAME,
        },
        "generation_config": {
            "temperature": TEMPERATURE,
            "max_tokens": MAX_TOKENS,
            "reasoning_effort": REASONING_EFFORT,
        },
        "reasoning_summary_config": {
            "enabled": False,
            "reasoning_effort": None,
            "max_tokens": None,
        },
        "attempts": [],
    }
    attempts = datapoint["attempts"]

    tok_in_total = 0
    tok_out_total = 0

    # ── Phase 1: Execution retry loop ────────────────────────────────────
    code = None            # most recent code obtained from the model
    failed_code = None     # most recent code that was executed and failed
    captured_stdout = None
    last_error = None
    is_timeout = False
    execution_success = False

    for attempt in range(cfg.MAX_EXECUTION_RETRIES):
        attempt_timestamp = datetime.now().isoformat()

        if failed_code is None:
            # Nothing has been executed yet, so there is nothing to repair.
            user_message = base_prompt
            stage = "initial"
        else:
            if is_timeout:
                user_message = TIMEOUT_REPAIR_TEMPLATE.format(
                    timeout_seconds=cfg.EXECUTION_TIMEOUT_SECONDS,
                    problem=problem_text,
                    solution=solution_text,
                    previous_code=failed_code,
                    error=str(last_error),
                    base_prompt=base_prompt,
                )
            else:
                user_message = ERROR_REPAIR_TEMPLATE.format(
                    problem=problem_text,
                    solution=solution_text,
                    previous_code=failed_code,
                    error=str(last_error),
                    base_prompt=base_prompt,
                )
            stage = "execution_repair"

        attempt_record = _new_attempt_record(attempt + 1, stage, attempt_timestamp)

        # ── LLM call ─────────────────────────────────────────────────────
        try:
            resp = await _request_code(pool, user_message)
        except Exception as e:
            print(f"LLM error (problem {idx}, attempt {attempt + 1}): {_format_llm_error(e)}")
            continue

        if not resp.content:
            continue

        code = strip_markdown_code_blocks(resp.content)
        attempt_record["code"] = code
        tok_in_total += resp.prompt_tokens
        tok_out_total += resp.completion_tokens

        # ── Execute generated code ────────────────────────────────────────
        exec_start = time.monotonic()
        try:
            captured_stdout = await _execute_generated_code(code, cfg.EXECUTION_TIMEOUT_SECONDS)
        except Exception as e:
            exec_duration = time.monotonic() - exec_start
            last_error = e
            failed_code = code
            is_timeout = _record_exec_failure(attempt_record, e, exec_duration)
            attempt_record["retry_reason"] = "timeout" if is_timeout else "runtime_error"
            attempts.append(attempt_record)
            continue

        exec_duration = time.monotonic() - exec_start
        attempt_record["exec"]["status"] = "success"
        attempt_record["exec"]["stdout"] = captured_stdout
        attempt_record["exec"]["duration"] = exec_duration
        execution_success = True
        attempts.append(attempt_record)
        break

    if execution_success and captured_stdout is not None:
        # ── Phase 2: Answer retry loop ───────────────────────────────────
        predicted = captured_stdout.strip()
        expected = expected_answer.strip()
        is_correct = (predicted == expected)

        # Set result on the last execution attempt
        last_att = attempts[-1]
        last_att["result"]["predicted"] = predicted
        last_att["result"]["correct"] = is_correct

        answer_retry_count = 0
        while answer_retry_count < cfg.MAX_ANSWER_RETRIES and not is_correct:
            answer_retry_count += 1

            answer_attempt_record = _new_attempt_record(
                len(attempts) + 1,
                "answer_repair",
                datetime.now().isoformat(),
                retry_reason="wrong_answer",
            )

            user_message = WRONG_ANSWER_REPAIR_TEMPLATE.format(
                problem=problem_text,
                solution=solution_text,
                previous_code=code,
                wrong_answer=predicted,
                expected_answer=expected,
                base_prompt=base_prompt,
            )

            try:
                resp = await _request_code(pool, user_message)
            except Exception as e:
                print(f"LLM error (answer repair, problem {idx}): {_format_llm_error(e)}")
                break

            if not resp.content:
                break

            code = strip_markdown_code_blocks(resp.content)
            answer_attempt_record["code"] = code
            tok_in_total += resp.prompt_tokens
            tok_out_total += resp.completion_tokens

            exec_start = time.monotonic()
            try:
                captured_stdout = await _execute_generated_code(code, cfg.EXECUTION_TIMEOUT_SECONDS)
            except Exception as e:
                # A repaired program that fails to run ends the answer loop; the
                # verdict below is still based on the last successful execution.
                _record_exec_failure(answer_attempt_record, e, time.monotonic() - exec_start)
                attempts.append(answer_attempt_record)
                break

            exec_duration = time.monotonic() - exec_start
            predicted = captured_stdout.strip()
            is_correct = (predicted == expected)
            answer_attempt_record["exec"]["status"] = "success"
            answer_attempt_record["exec"]["stdout"] = captured_stdout
            answer_attempt_record["exec"]["duration"] = exec_duration
            answer_attempt_record["result"]["predicted"] = predicted
            answer_attempt_record["result"]["correct"] = is_correct
            attempts.append(answer_attempt_record)

        # ── Build outcome ────────────────────────────────────────────────
        # 1-based position of the first recorded attempt with a correct answer.
        pass_at_k = next(
            (
                position + 1
                for position, att in enumerate(attempts)
                if att["exec"]["status"] == "success" and att["result"].get("correct")
            ),
            None,
        )

        datapoint["outcome"] = {
            "status": "success" if is_correct else "wrong_answer",
            "answer": predicted if is_correct else None,
            "total_attempts": len(attempts),
            "execution_attempts": sum(
                1 for a in attempts
                if a["stage"] in ("initial", "execution_repair")
            ),
            "answer_retry_attempts": sum(
                1 for a in attempts
                if a["stage"] == "answer_repair"
            ),
            "pass_at_k": pass_at_k,
        }

        if not is_correct:
            # NOTE(review): classify_failure(None, False) returns None, so no
            # failure_analysis is attached to wrong-answer outcomes at present.
            failure_info = classify_failure(None, False)
            if failure_info:
                datapoint["outcome"]["failure_analysis"] = failure_info
    else:
        datapoint["outcome"] = {
            "status": "failed",
            "answer": None,
            "total_attempts": len(attempts),
            "execution_attempts": len(attempts),
            "answer_retry_attempts": 0,
            "pass_at_k": None,
        }

        if attempts:
            last_att = attempts[-1]
            failure_info = classify_failure(
                last_att["exec"]["error"] or None,
                last_att["exec"]["timeout"],
            )
            if failure_info:
                datapoint["outcome"]["failure_analysis"] = failure_info

    found_correct = datapoint["outcome"]["status"] == "success"
    return datapoint, found_correct, tok_in_total, tok_out_total

In [61]:
async def run_all():
    """Process the selected problems concurrently and print a run summary.

    Reads the notebook-level ``problems`` frame, ``MAX_PROBLEMS_TO_PROCESS``
    and ``cfg``. One task is started per problem inside a single ``LLMPool``;
    each task calls ``process_problem`` and, on success, hands the resulting
    datapoint to ``save_datapoint``. A background task re-reads the runtime
    config periodically and pushes the new values onto the pool.

    A problem whose ``process_problem`` call raises is logged and counted as
    finished-but-not-passed, so the progress bar always reaches its total and
    the running pass rate uses the same denominator as the final summary.

    Returns:
        None. Results are persisted via ``save_datapoint`` and totals are
        printed.
    """
    passed_count = 0
    errored_count = 0
    done_count = 0
    total_tok_in = 0
    total_tok_out = 0

    # Clamp to >= 0 so a negative limit cannot produce a negative tqdm total.
    max_problems = len(problems) if MAX_PROBLEMS_TO_PROCESS is None else max(0, min(MAX_PROBLEMS_TO_PROCESS, len(problems)))
    write_lock = anyio.Lock()
    pbar = tqdm(total=max_problems, desc="Processing")

    async def config_reloader(pool, interval_seconds=60):
        """Re-read the config file periodically and apply changes to the pool."""
        while True:
            await anyio.sleep(interval_seconds)
            try:
                cfg.reload()
                # NOTE(review): these reach into LLMPool internals; confirm the
                # attribute names against utils.LLMPool if it is refactored.
                pool._limiter.total_tokens = cfg.MAX_CONCURRENT_REQUESTS
                pool.max_retries = cfg.LLM_REQUEST_RETRY_COUNT
                pool._client.timeout = httpx.Timeout(cfg.LLM_REQUEST_TIMEOUT_SECONDS, connect=30)
            except Exception as e:
                # A half-saved or invalid config must not take the run down:
                # an exception here would cancel every in-flight problem.
                print(f"[config] reload failed, will retry in {interval_seconds}s: {e!r}")

    def advance_progress():
        """Record one finished problem and refresh the progress bar."""
        nonlocal done_count
        done_count += 1
        pbar.update(1)
        pbar.set_postfix(pass_rate=f"{passed_count / done_count:.2%}", finished=done_count)

    async def process_one(pool: LLMPool, idx: int):
        """Run one problem end to end and fold its result into the totals."""
        nonlocal passed_count, errored_count, total_tok_in, total_tok_out
        row = problems.iloc[idx]

        try:
            datapoint, found_correct, tok_in, tok_out = await process_problem(
                pool, cfg, row, idx, idx + 1,
            )
        except Exception as e:
            print(f"Problem {idx} failed unexpectedly: {e}")
            async with write_lock:
                errored_count += 1
                advance_progress()
            return

        async with write_lock:
            save_datapoint(datapoint)
            total_tok_in += tok_in
            total_tok_out += tok_out
            if found_correct:
                passed_count += 1
            advance_progress()

    try:
        async with LLMPool(
            base_url=NGINX_BALANCER_URL,
            api_key=API_KEY,
            model=MODEL_NAME,
            max_inflight=cfg.MAX_CONCURRENT_REQUESTS,
            timeout=cfg.LLM_REQUEST_TIMEOUT_SECONDS,
            max_retries=cfg.LLM_REQUEST_RETRY_COUNT,
        ) as pool:
            async with anyio.create_task_group() as tg:
                tg.start_soon(config_reloader, pool)

                # The inner group exits only once every problem task is done.
                async with anyio.create_task_group() as work_tg:
                    for idx in range(max_problems):
                        work_tg.start_soon(process_one, pool, idx)

                # All work finished — cancel the reloader
                tg.cancel_scope.cancel()
    finally:
        # Close the bar even when the run is interrupted or aborts.
        pbar.close()

    # Guard the empty run: there is nothing to divide by.
    final_rate = passed_count / max_problems if max_problems else 0.0
    print(f"\nComplete: {passed_count}/{max_problems} passed, {final_rate:.2%} pass rate")
    if errored_count:
        print(f"Unexpected failures (no datapoint saved): {errored_count}")
    print(f"Total tokens: in={total_tok_in}, out={total_tok_out}, total={total_tok_in+total_tok_out}")

# Start the full run. The bare top-level `await` relies on the notebook
# kernel executing this cell in an async context.
await run_all()

Processing:   0%|          | 0/3589 [00:00<?, ?it/s]



LLM error (problem 796, attempt 2): LLMRequestError [400] - LLM request failed [400]: {"error":{"message":"max_tokens must be at least 1, got -148. (parameter=max_tokens, value=-148)","type":"BadRequestError","param":"max_tokens","code":400}}
LLM error (answer repair, problem 998): LLMRequestError [400] - LLM request failed [400]: {"error":{"message":"max_tokens must be at least 1, got -1461. (parameter=max_tokens, value=-1461)","type":"BadRequestError","param":"max_tokens","code":400}}
LLM error (problem 1136, attempt 2): LLMRequestError [400] - LLM request failed [400]: {"error":{"message":"max_tokens must be at least 1, got -245. (parameter=max_tokens, value=-245)","type":"BadRequestError","param":"max_tokens","code":400}}
LLM error (answer repair, problem 1192): LLMRequestError [400] - LLM request failed [400]: {"error":{"message":"max_tokens must be at least 1, got -4653. (parameter=max_tokens, value=-4653)","type":"BadRequestError","param":"max_tokens","code":400}}
LLM error (ans

Retrying llm_pool.LLMPool._post_with_retry.<locals>._do_post in 2.77402622738568 seconds as it raised ReadTimeout: .


LLM error (answer repair, problem 1593): LLMRequestError [400] - LLM request failed [400]: {"error":{"message":"max_tokens must be at least 1, got -42. (parameter=max_tokens, value=-42)","type":"BadRequestError","param":"max_tokens","code":400}}
LLM error (answer repair, problem 1994): LLMRequestError [400] - LLM request failed [400]: {"error":{"message":"max_tokens must be at least 1, got -2856. (parameter=max_tokens, value=-2856)","type":"BadRequestError","param":"max_tokens","code":400}}
LLM error (answer repair, problem 2018): LLMRequestError [400] - LLM request failed [400]: {"error":{"message":"max_tokens must be at least 1, got -2807. (parameter=max_tokens, value=-2807)","type":"BadRequestError","param":"max_tokens","code":400}}
LLM error (answer repair, problem 2129): LLMRequestError [400] - LLM request failed [400]: {"error":{"message":"max_tokens must be at least 1, got -12. (parameter=max_tokens, value=-12)","type":"BadRequestError","param":"max_tokens","code":400}}
LLM erro

Retrying llm_pool.LLMPool._post_with_retry.<locals>._do_post in 5.81218776113709 seconds as it raised ReadTimeout: .


LLM error (problem 796, attempt 3): LLMRequestError [400] - LLM request failed [400]: {"error":{"message":"max_tokens must be at least 1, got -148. (parameter=max_tokens, value=-148)","type":"BadRequestError","param":"max_tokens","code":400}}
LLM error (problem 1136, attempt 3): LLMRequestError [400] - LLM request failed [400]: {"error":{"message":"max_tokens must be at least 1, got -245. (parameter=max_tokens, value=-245)","type":"BadRequestError","param":"max_tokens","code":400}}


Retrying llm_pool.LLMPool._post_with_retry.<locals>._do_post in 3.5321113634044483 seconds as it raised ReadTimeout: .


LLM error (problem 2356, attempt 3): LLMRequestError [400] - LLM request failed [400]: {"error":{"message":"max_tokens must be at least 1, got -1976. (parameter=max_tokens, value=-1976)","type":"BadRequestError","param":"max_tokens","code":400}}


Retrying llm_pool.LLMPool._post_with_retry.<locals>._do_post in 4.6323812877803014 seconds as it raised ReadTimeout: .


LLM error (problem 3244, attempt 3): LLMRequestError [400] - LLM request failed [400]: {"error":{"message":"max_tokens must be at least 1, got -4081. (parameter=max_tokens, value=-4081)","type":"BadRequestError","param":"max_tokens","code":400}}


Retrying llm_pool.LLMPool._post_with_retry.<locals>._do_post in 2.301634895075943 seconds as it raised ReadTimeout: .
Retrying llm_pool.LLMPool._post_with_retry.<locals>._do_post in 2.594515806157423 seconds as it raised ReadTimeout: .
Retrying llm_pool.LLMPool._post_with_retry.<locals>._do_post in 1.0520680929379347 seconds as it raised ReadTimeout: .



Complete: 2927/3589 passed, 81.55% pass rate
Total tokens: in=29516344, out=4032009, total=33548353
