In [1]:
import json
import sys
import re
from pathlib import Path
from tqdm.notebook import tqdm
from datetime import datetime
from typing import Dict, Any, Optional, List, Tuple
import time

import anyio
import httpx
from utils import LLMPool, RuntimeConfig

In [2]:
# Notebook configuration: every tunable value for this run is defined in this cell.

# --- LLM endpoint and sampling settings ---
API_KEY = "sk-local"
MODEL_NAME = "gpt-oss"
NGINX_BALANCER_URL = "http://127.0.0.1:8080/v1"

TEMPERATURE = 0.3
TOP_P = 0.95
MAX_TOKENS = 8 * 1024
REASONING_EFFORT = "low"

# --- Dataset locations (the output file is written next to the input file) ---
INPUT_JSONL_PATH = Path("/home/larcanio/AIMO3_v2/data/datasets/gsm8k_reasoning_steps/Compiled-GSM8K/dataset.jsonl")
OUTPUT_JSONL_PATH = INPUT_JSONL_PATH.with_name("dataset_classfied.jsonl")

# --- Subset selection: a falsy cap means "process everything" ---
MAX_PROBLEMS_TO_PROCESS = None
START_FROM_INDEX = 0

# --- Prompt templates (markdown files under PROMPT_DIR) ---
PROMPT_DIR = Path("prompts")
TEXT_STRUCTURE_PROMPT_PATH = PROMPT_DIR.joinpath("nb3_math_structure_from_text.md")
SOLUTION_STRUCTURE_PROMPT_PATH = PROMPT_DIR.joinpath("nb3_math_structure_from_solution.md")

# --- Runtime settings: RuntimeConfig receives the config file name plus defaults ---
CONFIG_FILE = "config.json"

cfg = RuntimeConfig(
    CONFIG_FILE,
    defaults=dict(
        LLM_REQUEST_RETRY_COUNT=3,
        LLM_REQUEST_TIMEOUT_SECONDS=300,
        MAX_CONCURRENT_REQUESTS=25,
    ),
)

# Echo the effective settings so the cell output documents the run.
print(f"Configuration: Model={MODEL_NAME}, Effort={REASONING_EFFORT}")
print(f"Config: {cfg}")
print(f"Input: {INPUT_JSONL_PATH.name}")
print(f"Output: {OUTPUT_JSONL_PATH.name}")
print(f"Max problems: {MAX_PROBLEMS_TO_PROCESS or 'All'}")

[config] reloaded: MAX_CONCURRENT_REQUESTS: 25 -> 20
Configuration loaded:
  Model: gpt-oss
  Reasoning effort: low
  Config file: config.json
RuntimeConfig(LLM_REQUEST_RETRY_COUNT=3, LLM_REQUEST_TIMEOUT_SECONDS=300, MAX_CONCURRENT_REQUESTS=20)
  Input file: /home/larcanio/AIMO3_v2/data/datasets/gsm8k_reasoning_steps/Compiled-GSM8K/dataset.jsonl
  Output file: /home/larcanio/AIMO3_v2/data/datasets/gsm8k_reasoning_steps/Compiled-GSM8K/dataset_classfied.jsonl
  Max problems: All


In [3]:
# Prompt templates
#
# Each markdown file holds the system prompt followed by a trailing section
# header; everything before the first occurrence of that header is used as the
# system prompt, and the header itself is re-created in the user template.

# Text-structure prompt: system part is whatever precedes "### Problem".
_text_prompt_raw = TEXT_STRUCTURE_PROMPT_PATH.read_text(encoding="utf-8")
_text_split = _text_prompt_raw.split("### Problem")
TEXT_STRUCTURE_SYSTEM_PROMPT = _text_split[0].strip()
TEXT_STRUCTURE_USER_PROMPT = "### Problem\n\n{PROBLEM_TEXT}"

# Solution-structure prompt: system part is whatever precedes "### Chain-of-Thought".
_solution_prompt_raw = SOLUTION_STRUCTURE_PROMPT_PATH.read_text(encoding="utf-8")
_solution_split = _solution_prompt_raw.split("### Chain-of-Thought")
SOLUTION_STRUCTURE_SYSTEM_PROMPT = _solution_split[0].strip()
SOLUTION_STRUCTURE_USER_PROMPT = "### Chain-of-Thought\n\n{SOLUTION_CODE}"

print(
    f"Prompts loaded: Text={len(TEXT_STRUCTURE_SYSTEM_PROMPT)} chars, "
    f"Solution={len(SOLUTION_STRUCTURE_SYSTEM_PROMPT)} chars"
)

Prompts loaded
  Text structure system prompt: 4404 chars
  Solution structure system prompt: 4492 chars


In [4]:
# Helper Functions

def parse_json_from_response(text: str) -> Dict[str, Any]:
    """Extract and parse the first JSON object found in an LLM response.

    The response may wrap the object in a markdown code fence and/or surround
    it with free text. The object boundaries are found with a brace scan that
    ignores braces inside JSON string literals (including escaped quotes), so
    values such as ``"x}y"`` do not truncate the object.

    Args:
        text: Raw response content returned by the model.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: If the response is empty, contains no ``{``, or the
            extracted text is not valid JSON even after stripping trailing
            commas.
    """
    if not text or text.strip() == "":
        raise ValueError("Empty response from LLM")

    text = text.strip()
    if "```" in text:
        # Prefer the contents of the first fenced block when one is present.
        code_block_pattern = r"```(?:json)?\s*\n?(.*?)```"
        matches = re.findall(code_block_pattern, text, re.DOTALL)
        if matches:
            text = matches[0].strip()

    start_idx = text.find("{")
    if start_idx == -1:
        raise ValueError("No JSON object found in response")

    # Walk forward to the brace that closes the first object. If the braces
    # never balance (e.g. truncated output) keep the whole remainder so the
    # error message below shows what was actually received.
    end_idx = len(text)
    depth = 0
    in_string = False
    escaped = False
    for i in range(start_idx, len(text)):
        ch = text[i]
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                end_idx = i + 1
                break

    json_str = text[start_idx:end_idx]

    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        # Common model mistake: a trailing comma before a closing } or ].
        json_str_fixed = re.sub(r",\s*([}\]])", r"\1", json_str)
        try:
            return json.loads(json_str_fixed)
        except json.JSONDecodeError:
            raise ValueError(f"Failed to parse JSON: {e}. Text: {json_str[:500]}") from e

print("Helper functions ready")

Helper functions loaded


In [5]:
# Load Data

def load_problems_from_jsonl(jsonl_path: Path) -> list:
    """Read a JSONL file and return its decoded records in file order.

    Whitespace-only lines are ignored. A line that is not valid JSON is
    reported on stdout with its 1-based line number and then skipped.
    """
    records = []
    with open(jsonl_path, 'r', encoding='utf-8') as handle:
        for lineno, raw in enumerate(handle, start=1):
            if not raw.strip():
                continue
            try:
                parsed = json.loads(raw)
            except json.JSONDecodeError as err:
                print(f"Error parsing line {lineno}: {err}")
            else:
                records.append(parsed)
    return records

all_problems = load_problems_from_jsonl(INPUT_JSONL_PATH)
print(f"Loaded {len(all_problems)} problems")

# Select the working subset: drop the first START_FROM_INDEX records, then cap
# the count. A falsy MAX_PROBLEMS_TO_PROCESS (None or 0) means "no cap".
problems = all_problems[:]

if START_FROM_INDEX > 0:
    problems = problems[START_FROM_INDEX:]
    print(f"Starting from index {START_FROM_INDEX}")

if MAX_PROBLEMS_TO_PROCESS:
    problems = problems[:MAX_PROBLEMS_TO_PROCESS]

print(f"Processing {len(problems)} problems")

# Map each position in `problems` back to its position in `all_problems`, so
# results can be written in place into the full dataset. Object identity is
# used because `problems` holds the same dict objects as `all_problems`.
_id_to_full_pos = {id(dp): i for i, dp in enumerate(all_problems)}
classify_to_full_pos = [_id_to_full_pos[id(dp)] for dp in problems]

# Pre-scan existing math_structure data to report how much work remains.
_pre_has_text = 0
_pre_has_sol = 0
_pre_has_both = 0
_pre_has_neither = 0
for dp in problems:
    # `or {}` also covers records where the key is present but null.
    ms = dp.get("math_structure") or {}
    has_text = bool(ms.get("from_text"))
    has_sol = bool(ms.get("from_solution"))
    if has_text and has_sol:
        _pre_has_both += 1
    elif has_text:
        _pre_has_text += 1
    elif has_sol:
        _pre_has_sol += 1
    else:
        _pre_has_neither += 1

print("\nExisting coverage:")
print(f"  Both: {_pre_has_both} (skip) | Text only: {_pre_has_text} | Solution only: {_pre_has_sol}")
print(f"  Neither: {_pre_has_neither} | Total needing work: {len(problems) - _pre_has_both}")

Loaded 7411 problems from /home/larcanio/AIMO3_v2/data/datasets/gsm8k_reasoning_steps/Compiled-GSM8K/dataset.jsonl
Will process 7411 problems

Existing math_structure coverage:
  Both from_text + from_solution: 0 (will skip)
  Only from_text:                 0 (will compute from_solution)
  Only from_solution:             0 (will compute from_text)
  Neither:                        7411 (will compute both)
  Total needing work:             7411


In [6]:
# ============================================================================
# PROCESS PROBLEMS - TWO-PASS STRUCTURE EXTRACTION
# ============================================================================

# A checkpoint is written when the completed-problem count is a multiple of this value.
SAVE_EVERY = 500


async def extract_text_structure(pool: LLMPool, problem_text: str) -> Tuple[Optional[Dict[str, Any]], int, int]:
    """Ask the LLM for the mathematical structure of a problem statement.

    Any exception raised while building the request, calling the pool, or
    parsing the reply is reported on stdout and turned into ``(None, 0, 0)``.

    Returns:
        ``(structure, prompt_tokens, completion_tokens)``. ``structure`` is
        ``None`` when the reply has no content or the call failed.
    """
    try:
        chat = [
            {"role": "system", "content": TEXT_STRUCTURE_SYSTEM_PROMPT},
            {"role": "user", "content": TEXT_STRUCTURE_USER_PROMPT.format(PROBLEM_TEXT=problem_text)},
        ]
        reply = await pool.request(
            chat,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            reasoning_effort=REASONING_EFFORT,
            top_p=TOP_P,
        )
        # An empty reply still reports its token usage, just without a structure.
        structure = parse_json_from_response(reply.content) if reply.content else None
        return structure, reply.prompt_tokens, reply.completion_tokens
    except Exception as err:
        print(f"  text-structure error: {type(err).__name__}: {str(err)[:100]}")
        return None, 0, 0


async def extract_solution_structure(pool: LLMPool, solution_code: str) -> Tuple[Optional[Dict[str, Any]], int, int]:
    """Ask the LLM for the reasoning structure of a solution / chain-of-thought.

    Any exception raised while building the request, calling the pool, or
    parsing the reply is reported on stdout and turned into ``(None, 0, 0)``.

    Returns:
        ``(structure, prompt_tokens, completion_tokens)``. ``structure`` is
        ``None`` when the reply has no content or the call failed.
    """
    try:
        chat = [
            {"role": "system", "content": SOLUTION_STRUCTURE_SYSTEM_PROMPT},
            {"role": "user", "content": SOLUTION_STRUCTURE_USER_PROMPT.format(SOLUTION_CODE=solution_code)},
        ]
        reply = await pool.request(
            chat,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            reasoning_effort=REASONING_EFFORT,
            top_p=TOP_P,
        )
        # An empty reply still reports its token usage, just without a structure.
        structure = parse_json_from_response(reply.content) if reply.content else None
        return structure, reply.prompt_tokens, reply.completion_tokens
    except Exception as err:
        print(f"  solution-structure error: {type(err).__name__}: {str(err)[:100]}")
        return None, 0, 0


def _save_checkpoint(completed_n, total_n, save_num):
    """Write every record in ``all_problems`` to ``OUTPUT_JSONL_PATH``.

    The data is first written to a sibling ``.tmp`` file and then moved over
    the output path, so a crash mid-write leaves the previous output intact.
    ``Path.replace`` is used for the move because, unlike ``Path.rename``, it
    overwrites an existing target on every platform.

    Called through ``anyio.to_thread.run_sync`` for periodic checkpoints and
    directly for the final save.

    Args:
        completed_n: Number of problems completed so far (progress message only).
        total_n: Total number of problems in this run (progress message only).
        save_num: Sequential checkpoint number (progress message only).
    """
    OUTPUT_JSONL_PATH.parent.mkdir(parents=True, exist_ok=True)
    tmp = OUTPUT_JSONL_PATH.with_suffix('.tmp')
    with open(tmp, 'w', encoding='utf-8') as f:
        for dp in all_problems:
            f.write(json.dumps(dp, ensure_ascii=False) + '\n')
    tmp.replace(OUTPUT_JSONL_PATH)
    print(f"\n[checkpoint {save_num}] {completed_n}/{total_n} saved")


async def run_all():
    """Run the two-pass extraction pipeline over ``problems``.

    For every problem, the text structure and the solution structure are each
    requested from the LLM only if that half is missing from an existing
    ``math_structure``. Successful results replace the matching entry in
    ``all_problems`` and a checkpoint is written whenever the completed count
    is a multiple of ``SAVE_EVERY``; a final save always happens at the end.

    Returns:
        List of result dicts produced in this run (problems that were skipped,
        already complete, or for which both extractions failed are excluded).
    """
    results = []
    completed = 0
    save_count = 0
    total = len(problems)

    # write_lock guards the shared counters, `results`, `all_problems` and the
    # progress bar; save_lock keeps checkpoint writes from overlapping.
    write_lock = anyio.Lock()
    save_lock = anyio.Lock()
    counter = {
        "done": 0, "skipped": 0, "already_complete": 0,
        "tok_in": 0, "tok_out": 0,
        "both": 0, "text_only": 0, "solution_only": 0,
    }

    pbar = tqdm(total=total, desc="Extracting")
    # Bounds how many process_one tasks are alive at once; sized from the
    # concurrency setting as read when the run starts.
    spawn_limit = anyio.Semaphore(cfg.MAX_CONCURRENT_REQUESTS * 3)

    async def config_reloader(pool):
        """Re-read config.json every 60s and apply changes to the pool."""
        while True:
            await anyio.sleep(60)
            cfg.reload()
            # NOTE(review): reaches into private LLMPool attributes; confirm
            # these names against utils.LLMPool if that class changes.
            pool._limiter.total_tokens = cfg.MAX_CONCURRENT_REQUESTS
            pool.max_retries = cfg.LLM_REQUEST_RETRY_COUNT
            pool._client.timeout = httpx.Timeout(cfg.LLM_REQUEST_TIMEOUT_SECONDS, connect=30)

    async def process_one(pool, idx, dp):
        """Extract the missing structure(s) for one problem and record the outcome."""
        nonlocal completed, save_count

        tok_in = 0
        tok_out = 0

        # `or {}` / `or []` also cover keys that are present but null, which
        # would otherwise raise here, outside the try block below, and abort
        # the whole task group.
        problem_text = (dp.get("problem") or {}).get("text", "")
        if not problem_text:
            print(f"  Problem {idx} has no text, skipping")
            async with write_lock:
                counter["done"] += 1
                counter["skipped"] += 1
                completed += 1
                pbar.update(1)
            return

        # Resolve solution code from pass_at_k (used as a 1-based attempt index)
        pass_at_k = (dp.get("outcome") or {}).get("pass_at_k") or 0
        attempts = dp.get("attempts") or []
        solution_code = ""
        if pass_at_k > 0 and pass_at_k <= len(attempts):
            solution_code = attempts[pass_at_k - 1].get("code", "")

        # Check existing math_structure to avoid redundant computation
        existing_ms = dp.get("math_structure") or {}
        existing_text = existing_ms.get("from_text")
        existing_sol = existing_ms.get("from_solution")
        need_text = not existing_text
        need_sol = not existing_sol

        if not need_text and not need_sol:
            # Both already present, skip entirely
            async with write_lock:
                counter["done"] += 1
                counter["already_complete"] += 1
                completed += 1
                pbar.update(1)
            return

        try:
            # --- Pass 1: text structure (only if missing) ---
            text_struct = existing_text
            if need_text:
                text_struct, p1_in, p1_out = await extract_text_structure(pool, problem_text)
                tok_in += p1_in
                tok_out += p1_out

            # --- Pass 2: solution structure (only if missing and code exists) ---
            solution_struct = existing_sol
            if need_sol:
                if solution_code:
                    solution_struct, p2_in, p2_out = await extract_solution_structure(pool, solution_code)
                    tok_in += p2_in
                    tok_out += p2_out
                else:
                    solution_struct = None

            if text_struct is None and solution_struct is None:
                print(f"  Both extractions failed for problem {idx}")
                async with write_lock:
                    counter["done"] += 1
                    counter["tok_in"] += tok_in
                    counter["tok_out"] += tok_out
                    completed += 1
                    pbar.update(1)
                return

            # Build result (shallow copy so the input record is not mutated)
            result = dict(dp)
            result["math_structure"] = {
                "from_text": text_struct,
                "from_solution": solution_struct,
            }
            result["extraction_timestamp"] = datetime.now().isoformat()
            result["extraction_model"] = MODEL_NAME
            result["extraction_reasoning_effort"] = REASONING_EFFORT
            result["extraction_tokens"] = {"prompt": tok_in, "completion": tok_out}

            should_save = False
            completed_at_save = 0
            async with write_lock:
                results.append(result)
                all_problems[classify_to_full_pos[idx]] = result
                counter["done"] += 1
                counter["tok_in"] += tok_in
                counter["tok_out"] += tok_out
                completed += 1

                # Track extraction completeness first, so the progress
                # description below includes the record just finished.
                if text_struct and solution_struct:
                    counter["both"] += 1
                elif text_struct:
                    counter["text_only"] += 1
                else:
                    counter["solution_only"] += 1

                done = counter["done"]
                skipped = counter["skipped"]
                already = counter["already_complete"]
                extracted = done - skipped - already
                has_both = counter["both"]
                text_only = counter["text_only"]
                pbar.set_description(
                    f"Extracting ({extracted} done, {has_both} both, {text_only} text-only) [skip:{skipped}, cached:{already}]"
                )
                pbar.update(1)

                should_save = completed % SAVE_EVERY == 0
                # Snapshot under the lock: other tasks may advance `completed`
                # before the checkpoint below actually runs.
                completed_at_save = completed

            if should_save:
                async with save_lock:
                    save_count += 1
                    # File I/O runs in a worker thread so the event loop keeps serving requests.
                    await anyio.to_thread.run_sync(
                        _save_checkpoint, completed_at_save, total, save_count,
                    )

        except Exception as e:
            print(f"  Failed to extract problem {idx}: {type(e).__name__}: {str(e)[:100]}")
            async with write_lock:
                counter["done"] += 1
                counter["tok_in"] += tok_in
                counter["tok_out"] += tok_out
                completed += 1
                pbar.update(1)

    try:
        async with LLMPool(
            base_url=NGINX_BALANCER_URL,
            api_key=API_KEY,
            model=MODEL_NAME,
            max_inflight=cfg.MAX_CONCURRENT_REQUESTS,
            timeout=cfg.LLM_REQUEST_TIMEOUT_SECONDS,
            max_retries=cfg.LLM_REQUEST_RETRY_COUNT,
        ) as pool:
            async with anyio.create_task_group() as tg:
                tg.start_soon(config_reloader, pool)
                # The inner group exits only once every worker has finished;
                # the outer group is then cancelled to stop the endless reloader.
                async with anyio.create_task_group() as work_tg:
                    for idx, dp in enumerate(problems):
                        await spawn_limit.acquire()
                        # Default arguments bind the current loop values to this task.
                        async def _run(pool=pool, idx=idx, dp=dp):
                            try:
                                await process_one(pool, idx, dp)
                            finally:
                                spawn_limit.release()
                        work_tg.start_soon(_run)
                tg.cancel_scope.cancel()
    finally:
        # Close the bar even if the run is interrupted or fails.
        pbar.close()

    # Final save
    save_count += 1
    _save_checkpoint(completed, total, save_count)

    total_tok_in = counter["tok_in"]
    total_tok_out = counter["tok_out"]
    total_tok = total_tok_in + total_tok_out

    print(f"\nProcessed {counter['done']} problems ({counter['already_complete']} already complete, {counter['skipped']} skipped, {len(results)} extracted)")
    print(f"Total tokens: in={total_tok_in:,}, out={total_tok_out:,}, total={total_tok:,}")
    print("\nExtraction Completeness:")
    print(f"  Both text + solution: {counter['both']}")
    print(f"  Text only:           {counter['text_only']}")
    print(f"  Solution only:       {counter['solution_only']}")

    return results

# Run the processing.
# NOTE(review): top-level `await` is not valid in a plain Python script; this
# cell presumably relies on the notebook kernel running it inside an event loop.
results = await run_all()

Extracting:   0%|          | 0/7411 [00:00<?, ?it/s]


[checkpoint 1] 500/7411 saved

[checkpoint 2] 1000/7411 saved

[checkpoint 3] 1500/7411 saved

[checkpoint 4] 2000/7411 saved

[checkpoint 5] 2500/7411 saved

[checkpoint 6] 3000/7411 saved

[checkpoint 7] 3500/7411 saved

[checkpoint 8] 4000/7411 saved

[checkpoint 9] 4500/7411 saved

[checkpoint 10] 5000/7411 saved

[checkpoint 11] 5500/7411 saved
[config] reloaded: MAX_CONCURRENT_REQUESTS: 20 -> 30

[checkpoint 12] 6000/7411 saved

[checkpoint 13] 6500/7411 saved
[config] reloaded: MAX_CONCURRENT_REQUESTS: 30 -> 50

[checkpoint 14] 7000/7411 saved

[checkpoint 15] 7411/7411 saved

Processed 7411 problems (0 already complete, 0 skipped, 7411 extracted)
Total tokens: in=17,067,645, out=2,936,974, total=20,004,619

Extraction Completeness:
  Both text + solution: 7317
  Text only:           94
  Solution only:       0


In [7]:
# ============================================================================
# SAVE RESULTS (final write of the full dataset)
# ============================================================================

# Results are already saved incrementally during processing via checkpoints.
# Write the full dataset once more so everything is persisted: the data goes
# to a sibling .tmp file first and is then moved over the output path.
OUTPUT_JSONL_PATH.parent.mkdir(parents=True, exist_ok=True)
tmp = OUTPUT_JSONL_PATH.with_suffix('.tmp')
with open(tmp, 'w', encoding='utf-8') as f:
    for dp in all_problems:
        f.write(json.dumps(dp, ensure_ascii=False) + '\n')
# Path.replace overwrites an existing target on every platform, whereas
# Path.rename raises on Windows when the output file already exists.
tmp.replace(OUTPUT_JSONL_PATH)
print(f"\nAll {len(all_problems)} records saved to {OUTPUT_JSONL_PATH}")


All 7411 records saved to /home/larcanio/AIMO3_v2/data/datasets/gsm8k_reasoning_steps/Compiled-GSM8K/dataset_classfied.jsonl


In [8]:
# ============================================================================
# STATISTICS
# ============================================================================

import pandas as pd
from collections import Counter

# Summarise the extracted structures: one table for the from_text fields, one
# for the from_solution fields, then frequency counts for each.
if not results:
    print("No results to analyze")
else:
    # --- from_text fields ---
    text_rows = []
    for r in results:
        # `or {}` covers both a missing and a null math_structure / from_text.
        ft = (r.get("math_structure") or {}).get("from_text") or {}
        objects = ft.get("objects", [])
        constraints = ft.get("constraints", [])
        mechanisms = ft.get("mechanisms", [])
        text_rows.append({
            "problem_id": r.get("problem_id", ""),
            "domain": ft.get("domain"),
            "output_type": ft.get("output_type"),
            "objects": objects,
            "constraints": constraints,
            "mechanisms": mechanisms,
            # Only real lists are counted: the model may emit null or a bare
            # string, which would otherwise raise or count characters.
            "n_objects": len(objects) if isinstance(objects, list) else 0,
            "n_constraints": len(constraints) if isinstance(constraints, list) else 0,
            "n_mechanisms": len(mechanisms) if isinstance(mechanisms, list) else 0,
        })
    df_text = pd.DataFrame(text_rows)

    # --- from_solution fields ---
    sol_rows = []
    for r in results:
        fs = (r.get("math_structure") or {}).get("from_solution") or {}
        sol_rows.append({
            "problem_id": r.get("problem_id", ""),
            "reasoning_shape": fs.get("reasoning_shape"),
            "case_split": fs.get("case_split"),
            "invariant": fs.get("invariant"),
            "auxiliary_construction": fs.get("auxiliary_construction"),
            "reasoning_depth": fs.get("reasoning_depth"),
            "technique_transitions": fs.get("technique_transitions"),
            "argument_style": fs.get("argument_style"),
            "reasoning_scope": fs.get("reasoning_scope"),
            "dead_end_pruning": fs.get("dead_end_pruning"),
            "intermediate_reuse": fs.get("intermediate_reuse"),
        })
    df_sol = pd.DataFrame(sol_rows)

    print("=" * 70)
    print("TEXT STRUCTURE (from_text)")
    print("=" * 70)

    print("\nDomain Distribution:")
    print(df_text["domain"].value_counts())

    print("\nOutput Type Distribution:")
    print(df_text["output_type"].value_counts(dropna=False))

    # Flatten array fields (non-list values contribute nothing)
    all_objects = [o for objs in df_text["objects"] for o in (objs if isinstance(objs, list) else [])]
    all_constraints = [c for cs in df_text["constraints"] for c in (cs if isinstance(cs, list) else [])]
    all_mechanisms = [m for ms in df_text["mechanisms"] for m in (ms if isinstance(ms, list) else [])]

    print("\nTop Objects:")
    print(pd.Series(Counter(all_objects)).sort_values(ascending=False).head(10))

    print("\nTop Constraints:")
    print(pd.Series(Counter(all_constraints)).sort_values(ascending=False).head(10))

    print("\nTop Mechanisms:")
    if all_mechanisms:
        print(pd.Series(Counter(all_mechanisms)).sort_values(ascending=False).head(10))
    else:
        print("  (none extracted)")

    # A row counts as having solution data when reasoning_shape is present.
    has_solution = df_sol["reasoning_shape"].notna().sum()
    print(f"\n{'=' * 70}")
    print(f"SOLUTION STRUCTURE (from_solution)  [{has_solution}/{len(df_sol)} have data]")
    print("=" * 70)

    if has_solution > 0:
        df_s = df_sol.dropna(subset=["reasoning_shape"])

        for col in ["reasoning_shape", "case_split", "invariant", "auxiliary_construction",
                     "reasoning_depth", "technique_transitions", "argument_style",
                     "reasoning_scope", "dead_end_pruning", "intermediate_reuse"]:
            print(f"\n{col}:")
            print(df_s[col].value_counts(dropna=False))

    # Combined summary (df_text and df_sol are built from the same `results`
    # in the same order, so their rows line up for the "both" count)
    print(f"\n{'=' * 70}")
    print("COMBINED SUMMARY")
    print("=" * 70)
    print(f"  Total results:            {len(results)}")
    print(f"  With from_text:           {df_text['domain'].notna().sum()}")
    print(f"  With from_solution:       {has_solution}")
    print(f"  With both:                {(df_text['domain'].notna() & df_sol['reasoning_shape'].notna()).sum()}")

TEXT STRUCTURE (from_text)

Domain Distribution:
domain
algebra          5476
combinatorics     892
number_theory     863
geometry          158
mixed              21
set                 1
Name: count, dtype: int64

Output Type Distribution:
output_type
exact_value       7315
maximum             44
NaN                 23
minimum             16
existence           11
non_existence        1
classification       1
Name: count, dtype: int64

Top Objects:
integer             6187
real                1983
positive_integer    1195
rational              44
sequence              35
set                   27
polygon               20
triangle              10
point                  6
rectangle              6
dtype: int64

Top Constraints:
equality            5608
exists               638
inequality           266
divisibility         256
bounded               86
integral              70
distinct              43
forall                31
positive_integer       7
integer                5
dtype: int64

T