# Dataset Split: Train / Test (Normalized Format)

In [1]:
import json
import os
from pathlib import Path
from collections import Counter
from typing import Optional

import pandas as pd
from sklearn.model_selection import train_test_split

In [None]:
# Configuration

# Source JSONL files to load and concatenate
DATASET_PATH = "/home/larcanio/AIMO3_v2/data/datasets/Dataset_Full/bucketed/dataset_full_metadata.jsonl"

# Output directory
OUTPUT_DIR = "/home/larcanio/AIMO3_v2/data/datasets/splits/gemma_balanced"

# Filtering Options

# Only include records where outcome.status == "success"
REQUIRE_CORRECT = True

# Effective difficulty distribution
# Key = bucket level (or None for unsolved), value = % of final dataset.
# Percentages must sum to 100 to preserve proportions.
# Levels not listed will be excluded.
# Derived from computation_buckets: min level where passes >= 1
# Available levels: 0, 1, 2, 3, 4, 5, 6, 7, None (unsolved)
EFFECTIVE_DIFFICULTY: Optional[dict[int | None, int]] = {
    2: 34,
    3: 33,
    4: 33,
}

# Code scores to include (None = all scores)
CODE_SCORES: Optional[list[int]] = None

# Solution token range: (min, max) estimated tokens for problem.original_solution
# Estimated as len(text) // 4. Set to None to disable.
SOLUTION_TOKEN_RANGE: Optional[tuple[int, int]] = None  # e.g. (0, 500)

# Math Structure Filters
# Filters based on math_structure.from_text and math_structure.from_solution
# Set to None to disable a filter.

# Max number of constraints in math_structure.from_text.constraints (None = no limit)
MAX_CONSTRAINT_COUNT: Optional[int] = 1

# Max number of objects in math_structure.from_text.objects (None = no limit)
MAX_OBJECT_COUNT: Optional[int] = 1

# Tag/value filters (empty list = no filtering)
#
# NOTE: these filter at the record level ("must match") and are applied BEFORE
# the % distributions / downsampling.

# Objects tags from math_structure.from_text.objects (list[str])
# EDA note: rich vocab (~320 types, ~99% populated). Example tag seen in sample: "positive_integer".
OBJECTS_FILTER: list[str] = []  # e.g. ["positive_integer"]

# Constraints tags from math_structure.from_text.constraints (list[str])
# Common values (from EDA, N >= 500):
#   equality, exists, inequality, distinct, forall, bounded, divisibility, parity,
#   positive_integer, integral
CONSTRAINTS_FILTER: list[str] = []  # e.g. ["equality", "divisibility"]

# Technique transitions from math_structure.from_solution.technique_transitions (scalar)
# Values (from EDA): 0, 1, 2, 3  (rare literal "null" also appears)
TECHNIQUE_TRANSITIONS_FILTER: list[int] = []  # e.g. [0, 1]

# Reasoning scope from math_structure.from_solution.reasoning_scope (scalar)
# Values (from EDA): "local", "global", "mixed", "none"  (rare literal "null" also appears)
REASONING_SCOPE_FILTER: list[str] = []  # e.g. ["local", "global"]

# Intermediate reuse from math_structure.from_solution.intermediate_reuse (scalar)
# Values (from EDA): "none", "single", "multiple"  (rare literal "null" also appears)
INTERMEDIATE_REUSE_FILTER: list[str] = []  # e.g. ["none", "single"]

# Reasoning depth distribution
# Key = depth string, value = % of final dataset.
# Percentages must sum to 100 to preserve proportions.
# Depths not listed will be excluded.
# Extracted from math_structure.from_solution.reasoning_depth (missing -> "unknown")
REASONING_DEPTH: Optional[dict[str, int]] = {
    "shallow": 50,
    "medium": 50,
}

# Domain Distribution
# Specify percentage (1-100) of final dataset per domain.
# Percentages should sum to 100. Set to None to use all domains equally.
# Domains not listed will be excluded.

DOMAIN_DISTRIBUTION: Optional[dict[str, int]] = {
    "algebra": 65,
    "combinatorics": 20,
    "geometry": 5,
    "number_theory": 10,
    
}

# Dataset Source Distribution
# Specify percentage (1-100) of final dataset per source dataset (record["dataset"]).
# Percentages should sum to <= 100. Set to None to keep all records (no resampling).
# Applied after domain distribution.
# Available datasets: gsm8k, numina1.5, mvidia_reasoning_steps,
#   numina1.5:aops, numina1.5:cn_contest, numina1.5:metamath,
#   numina1.5:inequalities, numina1.5:number_theory

# Example:
# DATASET_DISTRIBUTION: Optional[dict[str, int]] = {
#     "gsm8k": 55,
#     "numina1.5": 15,
#     "mvidia_reasoning_steps": 20,
# }
DATASET_DISTRIBUTION = None

# Target Dataset Size
TARGET_TOTAL_RECORDS: Optional[int] = 15000

# Train/Test Split
TRAIN_RATIO = 0.9
TEST_RATIO = 0.1
RANDOM_SEED = 42

In [3]:
# ── Load  ─────────────────────────────────────────────────────

records = []

with open(DATASET_PATH) as f:
    for line in f:
        line = line.strip()
        if line:
            records.append(json.loads(line))


print(f"Loaded records from {Path(DATASET_PATH).parent.name}/{Path(DATASET_PATH).name}")

print(f"\nTotal records loaded: {len(records):,}")

Loaded records from bucketed/dataset_full_metadata.jsonl

Total records loaded: 71,832


In [4]:
# ── Apply filters ──────────────────────────────────────────────────────────

# Convert lists to sets for O(1) lookup
_code_scores_set = set(CODE_SCORES) if CODE_SCORES else None


def get_effective_difficulty(r: dict) -> int | None:
    """Compute effective difficulty: minimum bucket level where passes >= 1."""
    buckets = r.get("computation_buckets", [])
    passing = [b["level"] for b in buckets if b.get("passes", 0) >= 1]
    return min(passing) if passing else None


def get_reasoning_depth(r: dict) -> str:
    """Null-safe reasoning depth from math_structure.from_solution.reasoning_depth."""
    ms = r.get("math_structure") or {}
    fs = ms.get("from_solution") or {}
    return fs.get("reasoning_depth") or "unknown"


def _estimate_tokens(text: str) -> int:
    """Estimate token count from text (avg ~4 chars per token)."""
    return len(text) // 4


def _fmt_pct_distribution(d: Optional[dict]) -> str:
    if not d:
        return "all"
    return ", ".join(f"{k}:{v}%" for k, v in d.items())


def _fmt_list_filter(xs) -> str:
    if not xs:
        return "all"
    return ", ".join(map(str, xs))


def passes_filters(r: dict) -> bool:
    """Check if a record passes all configured filters."""
    audit = r.get("audit", {})
    outcome = r.get("outcome", {})

    # Correctness filter (outcome.status == "success")
    if REQUIRE_CORRECT and outcome.get("status") != "success":
        return False

    # Code score filter (audit.code_score)
    if _code_scores_set is not None:
        code_score = audit.get("code_score")
        if code_score is not None and code_score not in _code_scores_set:
            return False

    # Solution token range filter (problem.original_solution)
    if SOLUTION_TOKEN_RANGE is not None:
        solution = r.get("problem", {}).get("original_solution") or ""
        sol_tokens = _estimate_tokens(solution)
        lo, hi = SOLUTION_TOKEN_RANGE
        if sol_tokens < lo or sol_tokens > hi:
            return False

    # ── Math structure filters ─────────────────────────────────────────
    ms = r.get("math_structure") or {}
    from_text = ms.get("from_text") or {}
    from_solution = (ms.get("from_solution") or {})

    # Constraint count filter (math_structure.from_text.constraints)
    if MAX_CONSTRAINT_COUNT is not None:
        constraints = from_text.get("constraints") or []
        if len(constraints) > MAX_CONSTRAINT_COUNT:
            return False

    # Object count filter (math_structure.from_text.objects)
    if MAX_OBJECT_COUNT is not None:
        objects = from_text.get("objects") or []
        if len(objects) > MAX_OBJECT_COUNT:
            return False

    # Tag/value filters (empty list = no filtering)

    if OBJECTS_FILTER:
        objects = from_text.get("objects") or []
        if not any(o in OBJECTS_FILTER for o in objects):
            return False

    if CONSTRAINTS_FILTER:
        constraints = from_text.get("constraints") or []
        if not any(c in CONSTRAINTS_FILTER for c in constraints):
            return False

    if TECHNIQUE_TRANSITIONS_FILTER:
        tt = from_solution.get("technique_transitions")
        if tt not in TECHNIQUE_TRANSITIONS_FILTER:
            return False

    if REASONING_SCOPE_FILTER:
        scope = from_solution.get("reasoning_scope") or "unknown"
        if scope not in REASONING_SCOPE_FILTER:
            return False

    if INTERMEDIATE_REUSE_FILTER:
        reuse = from_solution.get("intermediate_reuse") or "unknown"
        if reuse not in INTERMEDIATE_REUSE_FILTER:
            return False

    return True

before = len(records)
records = [r for r in records if passes_filters(r)]
after = len(records)

# Annotate records with effective_difficulty + reasoning_depth for downstream use
for r in records:
    r["_effective_difficulty"] = get_effective_difficulty(r)
    r["_reasoning_depth"] = get_reasoning_depth(r)

print(f"After filtering: {before:,} -> {after:,}  (dropped {before - after:,})")
print("\nActive filters:")
print(f"  - Require correct: {REQUIRE_CORRECT}")
print(f"  - Effective difficulty target: {_fmt_pct_distribution(EFFECTIVE_DIFFICULTY)}")
print(f"  - Code scores: {CODE_SCORES or 'all'}")
if SOLUTION_TOKEN_RANGE:
    print(f"  - Solution tokens: [{SOLUTION_TOKEN_RANGE[0]}, {SOLUTION_TOKEN_RANGE[1]}] (estimated, ~4 chars/token)")
else:
    print("  - Solution tokens: all")
print(f"  - Max constraint count: {MAX_CONSTRAINT_COUNT or 'all'}")
print(f"  - Max object count: {MAX_OBJECT_COUNT or 'all'}")
print(f"  - Objects filter: {_fmt_list_filter(OBJECTS_FILTER)}")
print(f"  - Constraints filter: {_fmt_list_filter(CONSTRAINTS_FILTER)}")
print(f"  - Technique transitions filter: {_fmt_list_filter(TECHNIQUE_TRANSITIONS_FILTER)}")
print(f"  - Reasoning scope filter: {_fmt_list_filter(REASONING_SCOPE_FILTER)}")
print(f"  - Intermediate reuse filter: {_fmt_list_filter(INTERMEDIATE_REUSE_FILTER)}")
print(f"  - Reasoning depth target: {_fmt_pct_distribution(REASONING_DEPTH)}")

After filtering: 71,832 -> 7,123  (dropped 64,709)

Active filters:
  - Require correct: True
  - Effective difficulty: [2, 3, 4]
  - Code scores: all
  - Solution tokens: all
  - Max constraint count: 1
  - Max object count: 1
  - Reasoning depth: ['shallow', 'medium']


In [5]:
# ── Apply distributions + show availability ───────────────────────────────

import random
random.seed(RANDOM_SEED)


def get_domain(r: dict) -> str:
    """Null-safe domain extraction from math_structure.from_text.domain."""
    ms = r.get("math_structure") or {}
    ft = ms.get("from_text") or {}
    return ft.get("domain", "unknown")


def apply_target_distribution(
    records: list[dict],
    *,
    key_fn,
    distribution: Optional[dict],
    label: str,
    target_size: Optional[int],
) -> list[dict]:
    """Downsample records to match distribution exactly (capped at bottleneck)."""
    if distribution is None:
        return records

    total_pct = sum(distribution.values())
    if total_pct != 100:
        raise ValueError(f"{label} percentages sum to {total_pct}%, must be 100%")

    # Partition records into requested keys vs unlisted keys
    by_key: dict = {k: [] for k in distribution}
    other = []
    for r in records:
        k = key_fn(r)
        if k in by_key:
            by_key[k].append(r)
        else:
            other.append(r)

    missing = [k for k, pct in distribution.items() if pct > 0 and len(by_key.get(k, [])) == 0]
    if missing:
        raise ValueError(f"{label}: no available records for key(s): {missing}")

    pool_size = len(records)
    listed_available = sum(len(v) for v in by_key.values())
    
    # Domain distribution logic: cap at bottleneck to preserve exact proportions
    max_proportional_targets = []
    for k, pct in distribution.items():
        available = len(by_key.get(k, []))
        if pct > 0:
            max_for_k = int(available / (pct / 100))
            max_proportional_targets.append((k, max_for_k, available, pct))

    max_proportional_targets.sort(key=lambda x: x[1])
    bn_key, max_achievable, bn_available, bn_pct = max_proportional_targets[0]

    requested_target = target_size or pool_size
    effective_target = min(requested_target, listed_available, max_achievable)

    print(f"\n{label} distribution analysis:")
    print(f"  Pool size:              {pool_size:,}")
    print(f"  Listed keys total:      {listed_available:,}")
    print(f"  Unlisted keys excluded: {len(other):,}")
    print(
        f"  Max proportional target: {max_achievable:,} (bottleneck: {bn_key} "
        f"with {bn_available:,} records @ {bn_pct}%)"
    )
    print(f"  Effective target:       {effective_target:,}\n")

    allocated: dict = {}
    total_allocated = 0

    print(f"Allocation (target: {effective_target:,}):")
    for k, pct in distribution.items():
        available = len(by_key.get(k, []))
        desired = int(effective_target * pct / 100)
        count = min(desired, available)
        allocated[k] = count
        total_allocated += count

        actual_pct = (count / effective_target * 100) if effective_target > 0 else 0
        tag = " (ALL)" if count == available else ""
        print(f"  {k}: {count:,} / {available:,} available{tag} [{actual_pct:.1f}% actual vs {pct}% desired]")

    remainder = effective_target - total_allocated
    if remainder > 0:
        for k, pct in sorted(distribution.items(), key=lambda x: -x[1]):
            if remainder == 0:
                break
            available = len(by_key.get(k, []))
            capacity = available - allocated[k]
            add = min(remainder, capacity)
            if add > 0:
                allocated[k] += add
                total_allocated += add
                remainder -= add

    print(f"\n  Total allocated: {total_allocated:,} / {effective_target:,}")

    selected = []
    for k, count in allocated.items():
        if count <= 0:
            continue
        available = by_key.get(k, [])
        if count >= len(available):
            selected.extend(available)
        else:
            selected.extend(random.sample(available, count))

    random.shuffle(selected)
    return selected


# ── Apply effective difficulty distribution ────────────────────────────────

if EFFECTIVE_DIFFICULTY is not None:
    print("Available records by effective difficulty (after filtering):")
    eff_counts = Counter(r.get("_effective_difficulty") for r in records)
    for k, v in sorted(eff_counts.items(), key=lambda x: (x[0] is None, x[0])):
        print(f"  {k}: {v:,}")

    records = apply_target_distribution(
        records,
        key_fn=lambda r: r.get("_effective_difficulty"),
        distribution=EFFECTIVE_DIFFICULTY,
        label="Effective difficulty",
        target_size=TARGET_TOTAL_RECORDS,
    )
    print(f"\nDataset size after effective difficulty distribution: {len(records):,}")


# ── Apply reasoning depth distribution ─────────────────────────────────────

if REASONING_DEPTH is not None:
    print("\nAvailable records by reasoning depth (after previous steps):")
    depth_counts = Counter(r.get("_reasoning_depth", "unknown") for r in records)
    for k, v in sorted(depth_counts.items(), key=lambda x: (-x[1], str(x[0]))):
        print(f"  {k}: {v:,}")

    records = apply_target_distribution(
        records,
        key_fn=lambda r: r.get("_reasoning_depth", "unknown"),
        distribution=REASONING_DEPTH,
        label="Reasoning depth",
        target_size=TARGET_TOTAL_RECORDS,
    )
    print(f"\nDataset size after reasoning depth distribution: {len(records):,}")


# ── Show available data by domain ─────────────────────────────────────────-

by_domain: dict[str, list] = {}
for r in records:
    domain = get_domain(r)
    by_domain.setdefault(domain, []).append(r)

print("\nAvailable records by domain (after filtering + distributions):")
for domain, recs in sorted(by_domain.items(), key=lambda x: -len(x[1])):
    print(f"  {domain}: {len(recs):,}")

Available records by domain (after filtering):
  algebra: 3,581
  number_theory: 1,817
  combinatorics: 1,211
  geometry: 423
  mixed: 62
  unknown: 23
  None: 6


In [6]:
# ── Apply domain distribution ──────────────────────────────────────────────

selected_records = []

if DOMAIN_DISTRIBUTION is not None:
    # Validate percentages
    total_pct = sum(DOMAIN_DISTRIBUTION.values())
    if total_pct > 100:
        raise ValueError(f"Domain percentages sum to {total_pct}%, must be <= 100%")
    
    target_size = TARGET_TOTAL_RECORDS or len(records)
    
    max_proportional_targets = []
    for domain, pct in DOMAIN_DISTRIBUTION.items():
        available = len(by_domain.get(domain, []))
        if pct > 0:
            max_for_domain = int(available / (pct / 100))
            max_proportional_targets.append((domain, max_for_domain, available, pct))
    
    max_proportional_targets.sort(key=lambda x: x[1])
    bottleneck_domain, max_achievable, bn_available, bn_pct = max_proportional_targets[0]

    listed_available = sum(len(by_domain.get(d, [])) for d in DOMAIN_DISTRIBUTION)
    effective_target = min(target_size, listed_available, max_achievable)
    
    print("Domain distribution analysis:")
    print(f"  User target:            {target_size:,}")
    print(f"  Listed domains total:   {listed_available:,}")
    print(f"  Max proportional target: {max_achievable:,} (bottleneck: {bottleneck_domain} "
          f"with {bn_available:,} records @ {bn_pct}%)")
    print(f"  Effective target:       {effective_target:,}\n")
    
    if effective_target < target_size:
        print(f"  NOTE: Capping at {effective_target:,} to maintain exact distribution proportions.\n")
    
    domain_allocated = {}
    total_allocated = 0
    
    print(f"Allocation (target: {effective_target:,}):")
    for domain, pct in DOMAIN_DISTRIBUTION.items():
        available = by_domain.get(domain, [])
        desired = int(effective_target * pct / 100)
        allocated = min(desired, len(available))
        domain_allocated[domain] = allocated
        total_allocated += allocated
        
        actual_pct = (allocated / effective_target * 100) if effective_target > 0 else 0
        tag = " (ALL)" if allocated == len(available) else ""
        print(f"  {domain}: {allocated:,} / {len(available):,} available{tag} "
              f"[{actual_pct:.1f}% actual vs {pct}% desired]")
    
    remainder = effective_target - total_allocated
    if remainder > 0:
        for domain, pct in sorted(DOMAIN_DISTRIBUTION.items(), key=lambda x: -x[1]):
            if remainder == 0:
                break
            available = len(by_domain.get(domain, []))
            capacity = available - domain_allocated[domain]
            add = min(remainder, capacity)
            if add > 0:
                domain_allocated[domain] += add
                total_allocated += add
                remainder -= add
    
    print(f"\n  Total allocated: {total_allocated:,} / {effective_target:,}")
    
    for domain, count in domain_allocated.items():
        if count > 0:
            available = by_domain.get(domain, [])
            if count >= len(available):
                selected_records.extend(available)
            else:
                selected_records.extend(random.sample(available, count))
    
    unused_listed = listed_available - total_allocated
    unused_other = sum(len(v) for d, v in by_domain.items() if d not in DOMAIN_DISTRIBUTION)
    if unused_listed > 0 or unused_other > 0:
        print(f"\n  Unused records: {unused_listed:,} from listed domains, "
              f"{unused_other:,} from unlisted domains ({unused_listed + unused_other:,} total)")
else:
    selected_records = records
    print("No domain distribution specified - using all filtered records")

random.shuffle(selected_records)
records = selected_records

print(f"\nFinal dataset size: {len(records):,}")

Domain distribution analysis:
  User target:            15,000
  Listed domains total:   7,032
  Max proportional target: 5,509 (bottleneck: algebra with 3,581 records @ 65%)
  Effective target:       5,509

  NOTE: Capping at 5,509 to maintain exact distribution proportions.

Allocation (target: 5,509):
  algebra: 3,580 / 3,581 available [65.0% actual vs 65% desired]
  combinatorics: 1,101 / 1,211 available [20.0% actual vs 20% desired]
  geometry: 275 / 423 available [5.0% actual vs 5% desired]
  number_theory: 550 / 1,817 available [10.0% actual vs 10% desired]

  Total allocated: 5,509 / 5,509

  Unused records: 1,523 from listed domains, 91 from unlisted domains (1,614 total)

Final dataset size: 5,509


In [7]:
by_dataset: dict[str, list] = {}
for r in records:
    ds = r.get("dataset", "unknown")
    by_dataset.setdefault(ds, []).append(r)

print("Available records by dataset source (after previous filters):")
for ds, recs in sorted(by_dataset.items(), key=lambda x: -len(x[1])):
    print(f"  {ds}: {len(recs):,}")

if DATASET_DISTRIBUTION is not None:
    total_pct = sum(DATASET_DISTRIBUTION.values())
    if total_pct > 100:
        raise ValueError(f"Dataset percentages sum to {total_pct}%, must be <= 100%")

    pool_size = len(records)

    max_proportional_targets = []
    for ds_name, pct in DATASET_DISTRIBUTION.items():
        available = len(by_dataset.get(ds_name, []))
        if pct > 0:
            max_for_ds = int(available / (pct / 100))
            max_proportional_targets.append((ds_name, max_for_ds, available, pct))

    max_proportional_targets.sort(key=lambda x: x[1])
    bn_ds, max_achievable, bn_available, bn_pct = max_proportional_targets[0]

    listed_available = sum(len(by_dataset.get(d, [])) for d in DATASET_DISTRIBUTION)
    effective_target = min(pool_size, listed_available, max_achievable)

    print("\nDataset distribution analysis:")
    print(f"  Pool size:              {pool_size:,}")
    print(f"  Listed datasets total:  {listed_available:,}")
    print(f"  Max proportional target: {max_achievable:,} (bottleneck: {bn_ds} "
          f"with {bn_available:,} records @ {bn_pct}%)")
    print(f"  Effective target:       {effective_target:,}")

    if effective_target < pool_size:
        print(f"\n  NOTE: Capping at {effective_target:,} to maintain exact dataset proportions.")

    ds_allocated = {}
    total_allocated = 0

    print(f"\nAllocation (target: {effective_target:,}):")
    for ds_name, pct in DATASET_DISTRIBUTION.items():
        available = by_dataset.get(ds_name, [])
        desired = int(effective_target * pct / 100)
        allocated = min(desired, len(available))
        ds_allocated[ds_name] = allocated
        total_allocated += allocated

        actual_pct = (allocated / effective_target * 100) if effective_target > 0 else 0
        tag = " (ALL)" if allocated == len(available) else ""
        print(f"  {ds_name}: {allocated:,} / {len(available):,} available{tag} "
              f"[{actual_pct:.1f}% actual vs {pct}% desired]")

    remainder = effective_target - total_allocated
    if remainder > 0:
        for ds_name, pct in sorted(DATASET_DISTRIBUTION.items(), key=lambda x: -x[1]):
            if remainder == 0:
                break
            available = len(by_dataset.get(ds_name, []))
            capacity = available - ds_allocated[ds_name]
            add = min(remainder, capacity)
            if add > 0:
                ds_allocated[ds_name] += add
                total_allocated += add
                remainder -= add

    print(f"\n  Total allocated: {total_allocated:,} / {effective_target:,}")

    ds_selected = []
    for ds_name, count in ds_allocated.items():
        if count > 0:
            available = by_dataset.get(ds_name, [])
            if count >= len(available):
                ds_selected.extend(available)
            else:
                ds_selected.extend(random.sample(available, count))

    unused_listed = listed_available - total_allocated
    unused_other = sum(len(v) for d, v in by_dataset.items() if d not in DATASET_DISTRIBUTION)
    if unused_listed > 0 or unused_other > 0:
        print(f"\n  Unused records: {unused_listed:,} from listed datasets, "
              f"{unused_other:,} from unlisted datasets ({unused_listed + unused_other:,} total)")

    random.shuffle(ds_selected)
    records = ds_selected
    print(f"\nDataset size after dataset distribution: {len(records):,}")
else:
    print("\nNo dataset distribution specified - keeping all records")
    print(f"Current dataset size: {len(records):,}")

Available records by dataset source (after previous filters):
  gsm8k: 3,190
  numina1.5:mixed: 1,128
  open_math_reasoning: 947
  numina1.5:metamath: 147
  numina1.5:cn_contest: 49
  numina1.5:aops: 37
  numina1.5:number_theory: 7
  numina1.5:inequalities: 4

No dataset distribution specified - keeping all records
Current dataset size: 5,509


In [8]:
print("Final dataset validation:")
print(f"  Target: {TARGET_TOTAL_RECORDS or 'unlimited'}")
print(f"  Actual: {len(records):,}")

if TARGET_TOTAL_RECORDS is not None:
    if len(records) == TARGET_TOTAL_RECORDS:
        print("  ✓ Target achieved")
    elif len(records) < TARGET_TOTAL_RECORDS:
        shortfall = TARGET_TOTAL_RECORDS - len(records)
        pct = (len(records) / TARGET_TOTAL_RECORDS) * 100
        print(f"  ⚠️  Shortfall: {shortfall:,} records ({pct:.1f}% of target)")
    else:
        # This shouldn't happen with the new algorithm, but safety check
        print(f"  ⚠️  WARNING: Exceeded target by {len(records) - TARGET_TOTAL_RECORDS:,} records")
        print("     Downsampling to target size...")
        records = random.sample(records, TARGET_TOTAL_RECORDS)
        print(f"  ✓ Downsampled to {len(records):,} records")

Final dataset validation:
  Target: 15000
  Actual: 5,509
  ⚠️  Shortfall: 9,491 records (36.7% of target)


In [9]:
# Build stratification keys based on domain + effective_difficulty + reasoning_depth
strat_keys = []
for r in records:
    domain = get_domain(r)
    eff_diff = r.get("_effective_difficulty", "NA")
    depth = r.get("_reasoning_depth", "unknown")
    strat_keys.append(f"{domain}_D{eff_diff}_R{depth}")

print("Stratification groups:")
for k, v in sorted(Counter(strat_keys).items()):
    print(f"  {k}: {v:,}")

strat_counts = Counter(strat_keys)
safe_keys = [
    k if strat_counts[k] >= 2 else "_rare_"
    for k in strat_keys
]

rare_count = sum(1 for k in safe_keys if k == "_rare_")
if rare_count:
    print(f"\nMerged {rare_count} record(s) from singleton strata into '_rare_' group")

indices = list(range(len(records)))

train_idx, test_idx = train_test_split(
    indices,
    test_size=TEST_RATIO,
    random_state=RANDOM_SEED,
    stratify=safe_keys,
)

train_records = [records[i] for i in train_idx]
test_records = [records[i] for i in test_idx]

print(f"\nTrain: {len(train_records):,}")
print(f"Test:  {len(test_records):,}")

Stratification groups:
  algebra_D2: 2,217
  algebra_D3: 1,094
  algebra_D4: 270
  combinatorics_D2: 549
  combinatorics_D3: 402
  combinatorics_D4: 152
  geometry_D2: 130
  geometry_D3: 108
  geometry_D4: 37
  number_theory_D2: 303
  number_theory_D3: 161
  number_theory_D4: 86

Train: 4,958
Test:  551


In [10]:
train_strat = Counter(strat_keys[i] for i in train_idx)
test_strat = Counter(strat_keys[i] for i in test_idx)
all_groups = sorted(set(strat_keys))

rows = []
for g in all_groups:
    total = train_strat.get(g, 0) + test_strat.get(g, 0)
    rows.append({
        "group": g,
        "total": total,
        "train": train_strat.get(g, 0),
        "test": test_strat.get(g, 0),
        "train_%": round(train_strat.get(g, 0) / total * 100, 1) if total else 0,
    })

pd.DataFrame(rows)

Unnamed: 0,group,total,train,test,train_%
0,algebra_D2,2217,1995,222,90.0
1,algebra_D3,1094,985,109,90.0
2,algebra_D4,270,243,27,90.0
3,combinatorics_D2,549,494,55,90.0
4,combinatorics_D3,402,362,40,90.0
5,combinatorics_D4,152,137,15,90.1
6,geometry_D2,130,117,13,90.0
7,geometry_D3,108,97,11,89.8
8,geometry_D4,37,33,4,89.2
9,number_theory_D2,303,273,30,90.1


In [11]:
train_ds = Counter(r.get("dataset", "unknown") for r in train_records)
test_ds = Counter(r.get("dataset", "unknown") for r in test_records)
all_datasets = sorted(set(train_ds) | set(test_ds))

ds_rows = []
for ds in all_datasets:
    total = train_ds.get(ds, 0) + test_ds.get(ds, 0)
    ds_rows.append({
        "dataset": ds,
        "total": total,
        "train": train_ds.get(ds, 0),
        "test": test_ds.get(ds, 0),
        "pct_of_total": round(total / len(records) * 100, 1),
    })

ds_rows.sort(key=lambda x: -x["total"])

print("Samples per dataset:")
pd.DataFrame(ds_rows)

Samples per dataset:


Unnamed: 0,dataset,total,train,test,pct_of_total
0,gsm8k,3190,2868,322,57.9
1,numina1.5:mixed,1128,1020,108,20.5
2,open_math_reasoning,947,850,97,17.2
3,numina1.5:metamath,147,132,15,2.7
4,numina1.5:cn_contest,49,44,5,0.9
5,numina1.5:aops,37,34,3,0.7
6,numina1.5:number_theory,7,6,1,0.1
7,numina1.5:inequalities,4,4,0,0.1


In [12]:
os.makedirs(OUTPUT_DIR, exist_ok=True)

train_path = os.path.join(OUTPUT_DIR, "train.jsonl")
test_path = os.path.join(OUTPUT_DIR, "test.jsonl")

with open(train_path, "w") as f:
    for r in train_records:
        f.write(json.dumps(r) + "\n")

with open(test_path, "w") as f:
    for r in test_records:
        f.write(json.dumps(r) + "\n")

print(f"Saved {len(train_records):,} train records -> {train_path}")
print(f"Saved {len(test_records):,} test records  -> {test_path}")

print(f"\n{'='*60}")
print("DATASET SUMMARY")
print(f"{'='*60}")
print(f"Sources: {len(DATASET_PATH)} datasets")
print(f"Filters: correct={REQUIRE_CORRECT}")
print(f"         effective_difficulty={_fmt_pct_distribution(EFFECTIVE_DIFFICULTY)}")
print(f"         code_scores={CODE_SCORES or 'all'}")
if SOLUTION_TOKEN_RANGE:
    print(f"         solution_tokens=[{SOLUTION_TOKEN_RANGE[0]}-{SOLUTION_TOKEN_RANGE[1]}]")
else:
    print("         solution_tokens=all")
print(f"         max_constraint_count={MAX_CONSTRAINT_COUNT or 'all'}")
print(f"         max_object_count={MAX_OBJECT_COUNT or 'all'}")
print(f"         reasoning_depth={_fmt_pct_distribution(REASONING_DEPTH)}")
if DOMAIN_DISTRIBUTION:
    print(f"Domains: {', '.join(f'{d}:{p}%' for d, p in DOMAIN_DISTRIBUTION.items())}")
else:
    print("Domains: all (no distribution)")
if DATASET_DISTRIBUTION:
    print(f"Datasets: {', '.join(f'{d}:{p}%' for d, p in DATASET_DISTRIBUTION.items())}")
else:
    print("Datasets: all (no distribution)")
print(f"Target:  {TARGET_TOTAL_RECORDS or 'unlimited'}")
print(f"Split:   {TRAIN_RATIO*100:.0f}% train / {TEST_RATIO*100:.0f}% test")
print(f"Output:  {OUTPUT_DIR}")
print(f"{'='*60}")

Saved 4,958 train records -> /home/larcanio/AIMO3_v2/data/datasets/splits/gemma_balanced/train.jsonl
Saved 551 test records  -> /home/larcanio/AIMO3_v2/data/datasets/splits/gemma_balanced/test.jsonl

DATASET SUMMARY
Sources: 87 datasets
Filters: correct=True
         effective_difficulty=[2, 3, 4]
         code_scores=all
         solution_tokens=all
         max_constraint_count=1
         max_object_count=1
         reasoning_depth=['shallow', 'medium']
Domains: algebra:65%, combinatorics:20%, geometry:5%, number_theory:10%
Datasets: all (no distribution)
Target:  15000
Split:   90% train / 10% test
Output:  /home/larcanio/AIMO3_v2/data/datasets/splits/gemma_balanced
