## Setup

In [22]:
import json
import re

import time
import multiprocessing as mp
import sys
from pathlib import Path
import anyio
import httpx
from tqdm.notebook import tqdm
from utils import RuntimeConfig, LLMPool

# Configuration

In [None]:
# Paths
# Input is a JSONL file (one JSON record per line); the output receives the full
# dataset with the new classification entries merged in.
INPUT_DATASET_PATH = Path('/home/larcanio/AIMO3_v2/data/datasets/Dataset_Full/bucketed/dataset_1_5_complete_effective_difficulty.jsonl.jsonl')
OUTPUT_DATASET_PATH = Path('/home/larcanio/AIMO3_v2/data/datasets/Dataset_Full/bucketed/dataset_1_5_rerun_more_tokens.jsonl')

# Record key holding the list of per-level result entries that this notebook
# reads and appends to.
AUDIT_FIELD_NAME = "computation_buckets"

# API Configuration
API_BASE_URL = "http://127.0.0.1:8080/v1"
API_KEY = "sk-local"
MODEL_NAME = "gpt-oss"             # model identifier sent with each request
MODEL_LABEL = "gpt-oss-20b-low"    # label written into each result entry
MODEL_LEVEL = 5                    # level tag written into each result entry

# Level filters (None = no filter)
# A pass recorded at a level in [LOWER_LEVEL, MODEL_LEVEL) skips the problem;
# with UPPER_LEVEL set, a pass in (MODEL_LEVEL, UPPER_LEVEL] is required.
UPPER_LEVEL = None
LOWER_LEVEL = 0

# Passed as max_tokens on every generation request.
MAX_TOKENS = 256

# Sampling configuration
# N_SAMPLES is the maximum number of attempts per problem. Attempt i uses
# ATTEMPT_TEMPERATURE[i] / ATTEMPT_TOP_P[i], so both lists need at least
# N_SAMPLES elements.
N_SAMPLES = 1
if N_SAMPLES not in (1, 2):
    raise ValueError(f"N_SAMPLES must be 1 or 2, got {N_SAMPLES}")

ATTEMPT_TEMPERATURE = [0.0, 0.6]
ATTEMPT_TOP_P = [1.0, 0.9]

# True: problems with a non-passing entry at MODEL_LEVEL are run again and that
# entry is replaced. False: any existing MODEL_LEVEL entry skips the problem.
CURRENT_LEVEL_FAIL_RERUN = True
# True: only records whose tier is 'core' or 'extended' are considered.
KEEP_ONLY = True
# > 0 caps the number of candidate records; 0 applies no cap.
N_PROBLEMS = 0

# Validate level bounds
if LOWER_LEVEL is not None and LOWER_LEVEL >= MODEL_LEVEL:
    raise ValueError(f"LOWER_LEVEL ({LOWER_LEVEL}) must be < MODEL_LEVEL ({MODEL_LEVEL})")
if UPPER_LEVEL is not None and UPPER_LEVEL <= MODEL_LEVEL:
    raise ValueError(f"UPPER_LEVEL ({UPPER_LEVEL}) must be > MODEL_LEVEL ({MODEL_LEVEL})")

# Runtime configuration
# NOTE(review): RuntimeConfig comes from utils; presumably it overlays values
# read from CONFIG_FILE on these defaults - confirm there. cfg.reload() is
# called periodically while classification runs.
CONFIG_FILE = "config.json"
cfg = RuntimeConfig(CONFIG_FILE, defaults={
    "MAX_EXECUTION_RETRIES": 1,
    "MAX_ANSWER_RETRIES": 1,
    "LLM_REQUEST_RETRY_COUNT": 2,
    "LLM_REQUEST_TIMEOUT_SECONDS": 300,
    "EXECUTION_TIMEOUT_SECONDS": 2,
    "MAX_CONCURRENT_REQUESTS": 30,
})

print(f"Configuration: {cfg}")
print(f"Level filters: LOWER={LOWER_LEVEL}, MODEL={MODEL_LEVEL}, UPPER={UPPER_LEVEL}")
print(f"Rerun failures: {CURRENT_LEVEL_FAIL_RERUN}")

[config] reloaded: MAX_ANSWER_RETRIES: 1 -> 3, EXECUTION_TIMEOUT_SECONDS: 2 -> 20, MAX_CONCURRENT_REQUESTS: 30 -> 100
RuntimeConfig(MAX_EXECUTION_RETRIES=1, MAX_ANSWER_RETRIES=3, LLM_REQUEST_RETRY_COUNT=2, LLM_REQUEST_TIMEOUT_SECONDS=300, EXECUTION_TIMEOUT_SECONDS=20, MAX_CONCURRENT_REQUESTS=100)
Level filters: LOWER_LEVEL=0, MODEL_LEVEL=5, UPPER_LEVEL=None
CURRENT_LEVEL_FAIL_RERUN=True


# Load Dataset

In [24]:
# Read the input JSONL, keeping every parseable record in file order.
# Blank lines are ignored. Malformed lines are still skipped, but they are
# counted and reported: full_datapoints is what gets written back out as the
# complete dataset, so a silent drop would lose records without any trace.
full_datapoints = []
_malformed_line_numbers = []
with open(INPUT_DATASET_PATH, 'r', encoding='utf-8') as f:
    for _line_number, line in enumerate(f, start=1):
        if not line.strip():
            continue
        try:
            full_datapoints.append(json.loads(line))
        except json.JSONDecodeError:
            _malformed_line_numbers.append(_line_number)

print(f"Loaded: {len(full_datapoints)} total records")
if _malformed_line_numbers:
    print(
        f"WARNING: skipped {len(_malformed_line_numbers)} malformed line(s); "
        f"first at line {_malformed_line_numbers[0]}"
    )


def get_tier(dp: dict) -> str:
    """Return the tier stored under ``dp['audit']['tier']``.

    Falls back to ``'core'`` when the record has no ``'audit'`` field, when
    that field is not a dict (e.g. JSON ``null``), or when it has no
    ``'tier'`` key. A ``'tier'`` key that is present is returned as stored.
    """
    audit = dp.get('audit')
    # A null / non-dict audit cannot carry a tier; treat it like a missing one
    # instead of raising AttributeError on `.get`.
    if not isinstance(audit, dict):
        return 'core'
    return audit.get('tier', 'core')


# Narrow the loaded records down to the candidates for classification:
# first by tier (when KEEP_ONLY is set), then by the optional N_PROBLEMS cap.
all_datapoints = full_datapoints
if KEEP_ONLY:
    all_datapoints = [dp for dp in full_datapoints if get_tier(dp) in ('core', 'extended')]
    # Count the tier-filtered records BEFORE applying the N_PROBLEMS cap, so
    # records removed by the cap are not reported as tier-filtered.
    dropped = len(full_datapoints) - len(all_datapoints)

if N_PROBLEMS > 0:
    all_datapoints = all_datapoints[:N_PROBLEMS]

print(f"Eligible for classification: {len(all_datapoints)}")
if KEEP_ONLY:
    print(f"Filtered out: {dropped} (tier not in core/extended)")


def _get_audit_list(dp: dict) -> list:
    """Fetch the AUDIT_FIELD_NAME entries of a record.

    When the field holds a list, that very list object is returned (not a
    copy). Anything else - a missing field or the legacy dict format - yields
    a new empty list.
    """
    stored = dp.get(AUDIT_FIELD_NAME)
    return stored if isinstance(stored, list) else []


def needs_classification(dp: dict) -> bool:
    """Tell whether this record should be run at MODEL_LEVEL.

    The checks are applied in this order; the first one that matches returns
    False (skip):

      1. An entry at MODEL_LEVEL already exists and either
         CURRENT_LEVEL_FAIL_RERUN is False (any entry counts as processed) or
         that entry recorded at least one pass (only failures are retried).
      2. LOWER_LEVEL is set and some entry with a level in
         [LOWER_LEVEL, MODEL_LEVEL) recorded a pass - a weaker model already
         solved the problem.
      3. UPPER_LEVEL is set and no entry with a level in
         (MODEL_LEVEL, UPPER_LEVEL] recorded a pass - there is no evidence a
         stronger model can solve it.

    Otherwise the record needs classification and True is returned.
    """
    audit_entries = _get_audit_list(dp)

    def _has_pass(entry) -> bool:
        return entry.get('passes', 0) > 0

    # 1. Existing result at the current level
    for entry in audit_entries:
        if entry.get('level') != MODEL_LEVEL:
            continue
        if not CURRENT_LEVEL_FAIL_RERUN or _has_pass(entry):
            return False

    # 2. Already solved at a lower level (filter active only when LOWER_LEVEL is set)
    if LOWER_LEVEL is not None:
        for entry in audit_entries:
            level = entry.get('level')
            if level is None:
                continue
            if LOWER_LEVEL <= level < MODEL_LEVEL and _has_pass(entry):
                return False

    # 3. Without an upper bound there is nothing left to check
    if UPPER_LEVEL is None:
        return True

    # With an upper bound, require a recorded pass in (MODEL_LEVEL, UPPER_LEVEL]
    for entry in audit_entries:
        level = entry.get('level')
        if level is None:
            continue
        if MODEL_LEVEL < level <= UPPER_LEVEL and _has_pass(entry):
            return True
    return False


# Keep only the candidates that still have to be run at MODEL_LEVEL.
to_classify = list(filter(needs_classification, all_datapoints))

# to_classify = to_classify[:20000]

skipped = len(all_datapoints) - len(to_classify)
print(f"To classify: {len(to_classify)} (skipped {skipped} already processed)")

# Position mapping for incremental save: to_classify[i] -> full_datapoints index.
# Records are matched by object identity, so to_classify must hold the same
# dict objects as full_datapoints.
_id_to_full_pos = dict(zip(map(id, full_datapoints), range(len(full_datapoints))))
classify_to_full_pos = [_id_to_full_pos[id(record)] for record in to_classify]

KeyboardInterrupt: 

# Code Extraction & Execution

In [None]:
import subprocess
import tempfile
import os


def extract_code_from_response(text):
    """Return the body of the first fenced code block in ``text``.

    A fence opened with ```python is looked for first; if there is none, the
    first generic ``` fence is used. The body is returned stripped of
    surrounding whitespace. Returns None for empty input or when no fence is
    found.
    """
    if not text:
        return None

    python_fence = re.search(r'```python\s*(.*?)\s*```', text, re.DOTALL)
    if python_fence:
        return python_fence.group(1).strip()

    any_fence = re.search(r'```\s*(.*?)\s*```', text, re.DOTALL)
    if any_fence:
        return any_fence.group(1).strip()

    return None


def execute_code_with_timeout(code: str, timeout_seconds: int = 30):
    """Execute code in a fresh subprocess (no fork, no inherited memory).

    The code is written to a temporary ``.py`` file and run with the current
    interpreter (``sys.executable``). The file is always removed afterwards.

    Args:
        code: Python source to run.
        timeout_seconds: Wall-clock limit passed to ``subprocess.run``.

    Returns:
        A ``(result, error, is_timeout)`` tuple:
          * success: ``(last stdout line stripped, None, False)``
          * non-zero exit: ``(None, first 500 chars of stderr or "Non-zero exit", False)``
          * no stdout: ``(None, "No output", False)``
          * timeout: ``(None, "Timeout", True)``
          * any other failure: ``(None, first 500 chars of the exception text, False)``
    """
    tmp_path = None
    try:
        # Write the script as UTF-8: Python reads source files as UTF-8 by
        # default, whereas the locale default encoding may be unable to encode
        # non-ASCII characters in generated code.
        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.py', delete=False, encoding='utf-8',
        ) as f:
            f.write(code)
            tmp_path = f.name

        proc = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True, text=True, timeout=timeout_seconds,
        )

        if proc.returncode != 0:
            return None, proc.stderr[:500] or "Non-zero exit", False
        stdout = proc.stdout.strip()
        if stdout:
            # Only the last printed line is taken as the answer.
            return stdout.split('\n')[-1].strip(), None, False
        return None, "No output", False
    except subprocess.TimeoutExpired:
        return None, "Timeout", True
    except Exception as e:
        return None, str(e)[:500], False
    finally:
        # Best-effort cleanup of the temporary script.
        if tmp_path:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


def check_answer(predicted, expected):
    """Compare a predicted answer with the expected one.

    Both values are converted to stripped strings. They match when:
      1. the strings are identical, or
      2. both parse as integers and the integers are equal (exact comparison), or
      3. both parse as floats and differ by less than 1e-6.

    Returns False when either value is None or when no rule matches.
    """
    if predicted is None or expected is None:
        return False
    pred_str = str(predicted).strip()
    exp_str = str(expected).strip()
    if pred_str == exp_str:
        return True

    # Compare integers exactly. Going through float would round large integers
    # and accept different values (e.g. 9007199254740993 vs 9007199254740992).
    try:
        return int(pred_str) == int(exp_str)
    except (ValueError, TypeError):
        pass  # at least one side is not an integer literal; try floats

    try:
        return abs(float(pred_str) - float(exp_str)) < 1e-6
    except (ValueError, TypeError):
        return False

# Prompt Template

In [None]:
# The prompt template is read from disk once, when this cell runs.
PROMPT_DIR = Path("prompts")
PROMPT_TEMPLATE = PROMPT_DIR.joinpath("nb2_code_generation_simple.md").read_text(encoding="utf-8")

SYSTEM_MESSAGE = "You are a mathematician writing Python code to solve problems."


def format_prompt(problem: str) -> str:
    """Fill the ``{problem}`` placeholder of PROMPT_TEMPLATE with the given text."""
    filled = PROMPT_TEMPLATE.format(problem=problem)
    return filled


print(f"Loaded prompt template from {PROMPT_DIR / 'nb2_code_generation_simple.md'}")

# Classification Engine

In [None]:
# A checkpoint is written each time the number of completed problems reaches a
# multiple of this value.
SAVE_EVERY = 1000


def get_problem_text(dp: dict) -> str:
    """Return the problem statement of a record.

    Nested layout: when ``dp['problem']`` is a dict, its ``'text'`` value is
    returned ('' if the key is absent). Flat layout: the first truthy value of
    ``dp['text']`` and ``dp['problem']``, or '' when neither is set.
    """
    problem_field = dp.get('problem')
    if isinstance(problem_field, dict):
        return problem_field.get('text', '')
    return dp.get('text') or problem_field or ''


def get_expected_answer(dp: dict) -> str:
    """Extract the expected answer of a record as a string.

    Nested layout: when ``dp['problem']`` is a dict, its ``'expected_answer'``
    value is used. Flat layout: the first of ``'answer_expected'``,
    ``'answer'``, ``'expected_answer'`` that holds a value.

    Only ``None`` and the empty string count as "no value", so a numeric
    answer of 0 is returned as ``'0'`` rather than being discarded as falsy.
    Returns '' when no answer is found.
    """
    problem = dp.get('problem')
    if isinstance(problem, dict):
        value = problem.get('expected_answer')
        # A null answer means "missing", not the literal string 'None'.
        return '' if value is None else str(value)

    for key in ('answer_expected', 'answer', 'expected_answer'):
        value = dp.get(key)
        if value is not None and value != '':
            return str(value)
    return ''


def _save_checkpoint(completed_n: int, total_n: int, save_num: int):
    """Write full datapoints to disk atomically.

    Every record of ``full_datapoints`` is serialised as one JSON line into a
    sibling ``.tmp`` file, which is then moved over OUTPUT_DATASET_PATH, so the
    output file is never observed half-written.

    Args:
        completed_n: Number of problems completed so far (progress message only).
        total_n: Total number of problems in this run (progress message only).
        save_num: Sequence number of this checkpoint (progress message only).
    """
    OUTPUT_DATASET_PATH.parent.mkdir(parents=True, exist_ok=True)
    tmp = OUTPUT_DATASET_PATH.with_suffix('.tmp')
    with open(tmp, 'w', encoding='utf-8') as f:
        for dp in full_datapoints:
            f.write(json.dumps(dp, ensure_ascii=False) + '\n')
    # Path.replace overwrites an existing target on every platform, whereas
    # Path.rename raises FileExistsError on Windows when the target exists.
    tmp.replace(OUTPUT_DATASET_PATH)
    print(f"\n[Checkpoint {save_num}] {completed_n}/{total_n} saved")


async def classify_single_problem(pool: LLMPool, dp: dict, idx: int) -> tuple[int, int, int]:
    """Attempt one problem up to N_SAMPLES times, stopping at the first pass.

    Each attempt requests a completion, extracts the fenced code from it,
    executes that code in a worker thread and compares the printed result with
    the expected answer. An exception raised during an attempt is printed and
    the next attempt proceeds.

    Returns (passes, prompt_tokens, completion_tokens), with the token counts
    summed over the attempts whose request returned.
    """
    statement = get_problem_text(dp)
    target_answer = get_expected_answer(dp)
    chat = [
        {"role": "system", "content": SYSTEM_MESSAGE},
        {"role": "user", "content": format_prompt(statement)},
    ]

    n_passed = 0
    used_prompt_tokens = 0
    used_completion_tokens = 0

    for attempt in range(N_SAMPLES):
        try:
            reply = await pool.request(
                chat,
                temperature=ATTEMPT_TEMPERATURE[attempt],
                max_tokens=MAX_TOKENS,
                top_p=ATTEMPT_TOP_P[attempt],
                seed=attempt,
            )
            used_prompt_tokens += reply.prompt_tokens
            used_completion_tokens += reply.completion_tokens

            snippet = extract_code_from_response(reply.content)
            if not snippet:
                continue  # nothing runnable in this reply

            # The subprocess call blocks, so it runs off the event loop.
            output, run_error, timed_out = await anyio.to_thread.run_sync(
                execute_code_with_timeout, snippet, cfg.EXECUTION_TIMEOUT_SECONDS,
            )
            if run_error is not None or timed_out:
                continue

            if check_answer(output, target_answer):
                n_passed += 1
                break  # early exit: no need for more attempts
        except Exception as exc:
            print(f"Error (problem {idx}, attempt {attempt}): {type(exc).__name__} - {str(exc)[:100]}")

    return n_passed, used_prompt_tokens, used_completion_tokens


async def run_classification(datapoints: list):
    """Classify every record in ``datapoints`` concurrently at MODEL_LEVEL.

    For each record a new entry is added to its AUDIT_FIELD_NAME list (on a
    shallow copy of the record) and the copy is stored into ``full_datapoints``
    at the position given by ``classify_to_full_pos``. A checkpoint is written
    every SAVE_EVERY completed problems and once more at the end. If the run
    is interrupted after at least one problem completed, the progress made so
    far is still written before the exception propagates.

    Args:
        datapoints: Records to classify; index i must correspond to
            ``classify_to_full_pos[i]``.

    Returns:
        ``(classified, token_stats)`` where ``classified`` holds the updated
        records in the order of ``datapoints`` and ``token_stats`` has the
        keys ``prompt_tokens``, ``completion_tokens`` and ``total_tokens``.
    """
    results = []
    completed = 0
    pass_count = 0
    total_prompt_tokens = 0
    total_completion_tokens = 0
    save_count = 0
    total = len(datapoints)

    pbar = tqdm(total=total, desc="Classifying")
    write_lock = anyio.Lock()   # guards the shared counters / result lists
    save_lock = anyio.Lock()    # serialises checkpoint writes
    # Bounds the number of spawned-but-unfinished tasks.
    spawn_limit = anyio.Semaphore(cfg.MAX_CONCURRENT_REQUESTS * 3)

    async def config_reloader(pool):
        """Reload cfg every 60 s and push the new limits into the pool."""
        while True:
            await anyio.sleep(60)
            cfg.reload()
            # NOTE(review): relies on LLMPool private attributes; confirm they
            # exist in utils.LLMPool.
            pool._limiter.total_tokens = cfg.MAX_CONCURRENT_REQUESTS
            pool._client.timeout = httpx.Timeout(cfg.LLM_REQUEST_TIMEOUT_SECONDS, connect=30)

    async def process_one(pool, idx, dp):
        """Classify one record, store the result and checkpoint when due."""
        nonlocal completed, pass_count, total_prompt_tokens, total_completion_tokens, save_count

        passes, p_tok, c_tok = await classify_single_problem(pool, dp, idx)

        new_entry = {
            "level": MODEL_LEVEL,
            "model": MODEL_LABEL,
            "attempts": N_SAMPLES,
            "passes": passes,
            "max_tokens": MAX_TOKENS,
            "execution_timeout": cfg.EXECUTION_TIMEOUT_SECONDS,
        }

        classified_dp = dp.copy()
        # dp.copy() is shallow, so the audit list is still shared with the
        # input record. Copy it before appending; otherwise the append would
        # also modify `dp` whenever CURRENT_LEVEL_FAIL_RERUN is False.
        existing_audit = list(_get_audit_list(dp))

        if CURRENT_LEVEL_FAIL_RERUN:
            # Replace existing entry at MODEL_LEVEL (avoid duplicates on rerun)
            existing_audit = [e for e in existing_audit if e.get('level') != MODEL_LEVEL]

        existing_audit.append(new_entry)
        classified_dp[AUDIT_FIELD_NAME] = existing_audit

        should_save = False
        async with write_lock:
            results.append((idx, classified_dp))
            completed += 1
            full_datapoints[classify_to_full_pos[idx]] = classified_dp
            total_prompt_tokens += p_tok
            total_completion_tokens += c_tok
            if passes > 0:
                pass_count += 1
            pbar.set_postfix(pass_rate=f"{pass_count/completed:.2%}", done=completed)
            pbar.update(1)
            should_save = completed % SAVE_EVERY == 0

        if should_save:
            async with save_lock:
                save_count += 1
                # Writing the whole dataset blocks; keep it off the event loop.
                await anyio.to_thread.run_sync(
                    _save_checkpoint, completed, total, save_count,
                )

    run_finished = False
    try:
        async with LLMPool(
            base_url=API_BASE_URL,
            api_key=API_KEY,
            model=MODEL_NAME,
            reasoning_effort="low",
            max_inflight=cfg.MAX_CONCURRENT_REQUESTS,
            timeout=cfg.LLM_REQUEST_TIMEOUT_SECONDS,
        ) as pool:
            async with anyio.create_task_group() as tg:
                tg.start_soon(config_reloader, pool)

                async with anyio.create_task_group() as work_tg:
                    for idx, dp in enumerate(datapoints):
                        await spawn_limit.acquire()

                        # Default arguments bind the current loop values to this task.
                        async def _run(pool=pool, idx=idx, dp=dp):
                            try:
                                await process_one(pool, idx, dp)
                            finally:
                                spawn_limit.release()

                        work_tg.start_soon(_run)

                # All work is done; stop the endless config reloader.
                tg.cancel_scope.cancel()
        run_finished = True
    finally:
        pbar.close()
        # Final save. It also runs when the run was interrupted, provided at
        # least one result exists, so work since the last checkpoint is kept.
        if run_finished or completed > 0:
            save_count += 1
            _save_checkpoint(completed, total, save_count)

    results.sort(key=lambda x: x[0])
    classified = [dp for _, dp in results]
    token_stats = {
        "prompt_tokens": total_prompt_tokens,
        "completion_tokens": total_completion_tokens,
        "total_tokens": total_prompt_tokens + total_completion_tokens,
    }
    return classified, token_stats


# Marker printed once the definitions in this cell have been executed.
print("Classification engine ready")

## API Test

In [None]:
async def test_api():
    """Send one minimal chat-completion request and print the outcome.

    Prints ``OK: <content>`` for an HTTP 200 response, otherwise
    ``FAIL <status>: <first 200 chars of the body>``.
    """
    payload = {
        "model": MODEL_NAME,
        "messages": [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "What is 2+2? Reply with just the number."}
        ],
        "temperature": 0.0,
        "max_tokens": 32,
    }
    async with httpx.AsyncClient(timeout=httpx.Timeout(30, connect=5)) as client:
        resp = await client.post(
            f"{API_BASE_URL}/chat/completions",
            json=payload,
            headers={"Authorization": f"Bearer {API_KEY}", "Content-Type": "application/json"},
        )
        if resp.status_code != 200:
            print(f"FAIL {resp.status_code}: {resp.text[:200]}")
        else:
            answer = resp.json()['choices'][0]['message']['content']
            print(f"OK: {answer}")

# Smoke-test the endpoint before starting the full run (top-level await in a notebook cell).
await test_api()

OK: 


## Run Classification

In [None]:
print(f"Classifying {len(to_classify)} problems | level={MODEL_LEVEL} | N_SAMPLES={N_SAMPLES} | Model: {MODEL_LABEL}\n")

start_time = time.time()
classified_datapoints, token_stats = await run_classification(to_classify)
elapsed = time.time() - start_time

print(f"\nDone in {elapsed:.1f}s ({elapsed/len(to_classify):.2f}s/problem)")
print(f"\nToken usage:")
print(f"  Prompt tokens:     {token_stats['prompt_tokens']:,}")
print(f"  Completion tokens: {token_stats['completion_tokens']:,}")
print(f"  Total tokens:      {token_stats['total_tokens']:,}")

Classifying 13567 problems | level=5 | N_SAMPLES=1 | Model: gpt-oss-20b-low



Classifying:   0%|          | 0/13567 [00:00<?, ?it/s]


[checkpoint 1] 1000/13567 saved to dataset_1_5_rerun_more_tokens_rm.jsonl

[checkpoint 2] 2000/13567 saved to dataset_1_5_rerun_more_tokens_rm.jsonl


CancelledError: 

## Summary Statistics

In [None]:
# Pass / fail summary for the records classified in this run.
total = len(classified_datapoints)

def _current_level_passes(dp):
    """Return the 'passes' value of the first MODEL_LEVEL entry of dp (0 if none)."""
    # Read the entries through _get_audit_list so a missing field or the
    # legacy dict format is treated as "no entries", as in the rest of the
    # notebook.
    for entry in _get_audit_list(dp):
        if entry.get('level') == MODEL_LEVEL:
            return entry.get('passes', 0)
    return 0

passed = sum(1 for dp in classified_datapoints if _current_level_passes(dp) > 0)
failed = total - passed

# With nothing classified the percentages are reported as 0 instead of raising
# ZeroDivisionError.
passed_pct = passed / total * 100 if total else 0.0
failed_pct = failed / total * 100 if total else 0.0

print(f"Level {MODEL_LEVEL} ({MODEL_LABEL})")
print(f"Total:  {total}")
print(f"Passed: {passed} ({passed_pct:.2f}%)")
print(f"Failed: {failed} ({failed_pct:.2f}%)")

Level 5 (gpt-oss-20b-low)
Total:  36464
Passed: 11026 (30.24%)
Failed: 25438 (69.76%)


## Save Classified Dataset

In [None]:
# Merge classified results back into the FULL dataset (preserves all original records).
# The merge is done by position: classified_datapoints[i] belongs at
# full_datapoints[classify_to_full_pos[i]]. Matching by id() of the records in
# full_datapoints does not work here, because run_classification already
# replaced those slots with new dict objects.
if len(classified_datapoints) != len(classify_to_full_pos):
    raise ValueError(
        f"classified_datapoints ({len(classified_datapoints)}) and classify_to_full_pos "
        f"({len(classify_to_full_pos)}) differ in length; re-run the selection cell"
    )

# Maps full_datapoints position -> classified record.
classified_set = dict(zip(classify_to_full_pos, classified_datapoints))

final_datapoints = [
    classified_set.get(full_pos, original_dp)
    for full_pos, original_dp in enumerate(full_datapoints)
]

OUTPUT_DATASET_PATH.parent.mkdir(parents=True, exist_ok=True)
with open(OUTPUT_DATASET_PATH, 'w', encoding='utf-8') as f:
    for dp in final_datapoints:
        f.write(json.dumps(dp, ensure_ascii=False) + '\n')

print(f"Saved {len(final_datapoints)} records to {OUTPUT_DATASET_PATH}")
print(f"  classified: {len(classified_datapoints)}, unchanged: {len(final_datapoints) - len(classified_datapoints)}")

Saved 71832 records to /home/larcanio/AIMO3_v2/data/datasets/Dataset_Full/bucketed/dataset_1_5_rerun_more_tokens.jsonl
  classified: 36464, unchanged: 35368
