# Dataset Split: Train / Test (Normalized Format)

Load normalized JSONL datasets with flexible filtering:
- Domain-specific distributions
- Math difficulty level range
- Code score range  
- Token length range
- Target dataset size

In [1]:
import json
import os
from pathlib import Path
from collections import Counter
from typing import Optional

import pandas as pd
from sklearn.model_selection import train_test_split

In [None]:
# ── Configuration ──────────────────────────────────────────────────────────

# Source JSONL files to load and concatenate
DATASET_PATHS = [
    "/home/larcanio/AIMO3_v2/data/datasets/Dataset_Full/bucketed/dataset_1_5_complete_effective_difficulty.jsonl",
]

# Output directory
OUTPUT_DIR = "/home/larcanio/AIMO3_v2/data/datasets/splits/algebra_specialist"

# ── Filtering Options ──────────────────────────────────────────────────────

# Only include records where outcome.status == "success"
REQUIRE_CORRECT = True

# Effective difficulty levels to include (None = all levels)
# Derived from computation_buckets: min level where passes >= 1
# Available levels: 0, 1, 2, 3, 4 (None = problem unsolved by any bucket)
EFFECTIVE_DIFFICULTY: Optional[list[int]] = [1,2,3]

# Code scores to include (None = all scores)
CODE_SCORES: Optional[list[int]] = None

# Token lengths to include (None = all lengths)
# Use a list of values, or a range like list(range(50, 1001))
TOKEN_LENGTHS: Optional[list[int]] = list(range(0, 512))

# Solution token range: (min, max) estimated tokens for problem.original_solution
# Estimated as len(text) // 4. Set to None to disable.
SOLUTION_TOKEN_RANGE: Optional[tuple[int, int]] = None #(0,512) #None  # e.g. (0, 500)

# ── Math Structure Filters ─────────────────────────────────────────────────
# Filters based on math_structure.from_text and math_structure.from_solution
# Set to None to disable a filter.

# Max number of constraints in math_structure.from_text.constraints (None = no limit)
MAX_CONSTRAINT_COUNT: Optional[int] = 3

# Max number of objects in math_structure.from_text.objects (None = no limit)
MAX_OBJECT_COUNT: Optional[int] = 3

# Allowed reasoning depths from math_structure.from_solution.reasoning_depth (None = all)
REASONING_DEPTH: Optional[list[str]] = ["shallow", "medium"] # ["deep", "shallow", "medium"]

# ── Domain Distribution ────────────────────────────────────────────────────
# Specify percentage (1-100) of final dataset per domain.
# Percentages should sum to 100. Set to None to use all domains equally.

DOMAIN_DISTRIBUTION: Optional[dict[str, int]] = {
    "algebra": 90,
    "combinatorics": 5,
    "geometry": 2,
    "number_theory": 1,
    # Domains not listed will be excluded unless INCLUDE_UNLISTED_DOMAINS = True
}

# ── Dataset Source Distribution ────────────────────────────────────────────
# Specify percentage (1-100) of final dataset per source dataset (record["dataset"]).
# Percentages should sum to <= 100. Set to None to keep all records (no resampling).
# Applied after domain distribution.
# Available datasets: gsm8k, numina1.5, mvidia_reasoning_steps,
#   numina1.5:aops, numina1.5:cn_contest, numina1.5:metamath,
#   numina1.5:inequalities, numina1.5:number_theory

DATASET_DISTRIBUTION: Optional[dict[str, int]] = None

# 	dataset	total	train	test	pct_of_total
# 0	gsm8k
# 1	numina1.5
# 2	mvidia_reasoning_steps
# 3	numina1.5:metamath
# 4	numina1.5:cn_contest
# 5	numina1.5:aops
# 6	numina1.5:inequalities

# Example:
DATASET_DISTRIBUTION: Optional[dict[str, int]] = {
    "gsm8k": 80,
    "numina1.5": 15,
    "mvidia_reasoning_steps": 5,
}

# ── Target Dataset Size ────────────────────────────────────────────────────
# Maximum total records before split. Set to None for no limit.
TARGET_TOTAL_RECORDS: Optional[int] = 20000

# ── Train/Test Split ───────────────────────────────────────────────────────
TRAIN_RATIO = 0.9
TEST_RATIO = 0.1
RANDOM_SEED = 42


In [3]:
# ── Load & concatenate ─────────────────────────────────────────────────────

records = []
for path in DATASET_PATHS:
    count = 0
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:
                records.append(json.loads(line))
                count += 1
    print(f"Loaded {count:,} records from {Path(path).parent.name}/{Path(path).name}")

print(f"\nTotal records loaded: {len(records):,}")

Loaded 71,832 records from bucketed/dataset_1_5_complete_effective_difficulty.jsonl

Total records loaded: 71,832


In [14]:
# ── Apply filters ──────────────────────────────────────────────────────────

# Convert lists to sets for O(1) lookup
_effective_difficulty_set = set(EFFECTIVE_DIFFICULTY) if EFFECTIVE_DIFFICULTY else None
_code_scores_set = set(CODE_SCORES) if CODE_SCORES else None
_token_lengths_set = set(TOKEN_LENGTHS) if TOKEN_LENGTHS else None
_reasoning_depth_set = set(REASONING_DEPTH) if REASONING_DEPTH else None


def get_effective_difficulty(r: dict) -> int | None:
    """Compute effective difficulty: minimum bucket level where passes >= 1."""
    buckets = r.get("computation_buckets", [])
    passing = [b["level"] for b in buckets if b.get("passes", 0) >= 1]
    return min(passing) if passing else None


def _estimate_tokens(text: str) -> int:
    """Estimate token count from text (avg ~4 chars per token)."""
    return len(text) // 4


def passes_filters(r: dict) -> bool:
    """Check if a record passes all configured filters."""
    audit = r.get("audit", {})
    outcome = r.get("outcome", {})
    
    # Correctness filter (outcome.status == "success")
    if REQUIRE_CORRECT and outcome.get("status") != "success":
        return False
    
    # Effective difficulty filter (from computation_buckets)
    if _effective_difficulty_set is not None:
        eff_diff = get_effective_difficulty(r)
        if eff_diff is None or eff_diff not in _effective_difficulty_set:
            return False
    
    # Code score filter (audit.code_score)
    if _code_scores_set is not None:
        code_score = audit.get("code_score")
        if code_score is not None and code_score not in _code_scores_set:
            return False
    
    # Token length filter (audit_tokens.completion)
    if _token_lengths_set is not None:
        tokens = r.get("audit_tokens", {}).get("completion")
        if tokens is not None and tokens not in _token_lengths_set:
            return False
    
    # Solution token range filter (problem.original_solution)
    if SOLUTION_TOKEN_RANGE is not None:
        solution = r.get("problem", {}).get("original_solution") or ""
        sol_tokens = _estimate_tokens(solution)
        lo, hi = SOLUTION_TOKEN_RANGE
        if sol_tokens < lo or sol_tokens > hi:
            return False
    
    # ── Math structure filters ─────────────────────────────────────────
    ms = r.get("math_structure") or {}
    from_text = ms.get("from_text") or {}
    from_solution = ms.get("from_solution") or {}
    
    # Constraint count filter (math_structure.from_text.constraints)
    if MAX_CONSTRAINT_COUNT is not None:
        constraints = from_text.get("constraints") or []
        if len(constraints) > MAX_CONSTRAINT_COUNT:
            return False
    
    # Object count filter (math_structure.from_text.objects)
    if MAX_OBJECT_COUNT is not None:
        objects = from_text.get("objects") or []
        if len(objects) > MAX_OBJECT_COUNT:
            return False
    
    # Reasoning depth filter (math_structure.from_solution.reasoning_depth)
    if _reasoning_depth_set is not None:
        reasoning_depth = from_solution.get("reasoning_depth")
        if reasoning_depth is not None and reasoning_depth not in _reasoning_depth_set:
            return False
    
    return True

before = len(records)
records = [r for r in records if passes_filters(r)]
after = len(records)

# Annotate records with effective_difficulty for downstream use
for r in records:
    r["_effective_difficulty"] = get_effective_difficulty(r)

print(f"After filtering: {before:,} -> {after:,}  (dropped {before - after:,})")
print("\nActive filters:")
print(f"  - Require correct: {REQUIRE_CORRECT}")
print(f"  - Effective difficulty: {EFFECTIVE_DIFFICULTY or 'all'}")
print(f"  - Code scores: {CODE_SCORES or 'all'}")
if TOKEN_LENGTHS:
    print(f"  - Token lengths: [{min(TOKEN_LENGTHS)}, {max(TOKEN_LENGTHS)}] ({len(TOKEN_LENGTHS)} values)")
else:
    print("  - Token lengths: all")
if SOLUTION_TOKEN_RANGE:
    print(f"  - Solution tokens: [{SOLUTION_TOKEN_RANGE[0]}, {SOLUTION_TOKEN_RANGE[1]}] (estimated, ~4 chars/token)")
else:
    print("  - Solution tokens: all")
print(f"  - Max constraint count: {MAX_CONSTRAINT_COUNT or 'all'}")
print(f"  - Max object count: {MAX_OBJECT_COUNT or 'all'}")
print(f"  - Reasoning depth: {REASONING_DEPTH or 'all'}")

After filtering: 5,997 -> 5,997  (dropped 0)

Active filters:
  - Require correct: True
  - Effective difficulty: [0, 1, 2, 3]
  - Code scores: all
  - Token lengths: [0, 511] (512 values)
  - Solution tokens: all
  - Max constraint count: 3
  - Max object count: 3
  - Reasoning depth: ['shallow', 'medium']


In [15]:
# ── Show available data by domain ──────────────────────────────────────────

import random
random.seed(RANDOM_SEED)


def get_domain(r: dict) -> str:
    """Null-safe domain extraction from math_structure.from_text.domain."""
    ms = r.get("math_structure") or {}
    ft = ms.get("from_text") or {}
    return ft.get("domain", "unknown")


# Group records by domain
by_domain: dict[str, list] = {}
for r in records:
    domain = get_domain(r)
    by_domain.setdefault(domain, []).append(r)

print("Available records by domain (after filtering):")
for domain, recs in sorted(by_domain.items(), key=lambda x: -len(x[1])):
    print(f"  {domain}: {len(recs):,}")

Available records by domain (after filtering):
  algebra: 5,598
  combinatorics: 283
  geometry: 82
  number_theory: 34


In [16]:
# ── Apply domain distribution ──────────────────────────────────────────────

selected_records = []

if DOMAIN_DISTRIBUTION is not None:
    # Validate percentages
    total_pct = sum(DOMAIN_DISTRIBUTION.values())
    if total_pct > 100:
        raise ValueError(f"Domain percentages sum to {total_pct}%, must be <= 100%")
    
    # Determine target size for percentage calculation
    target_size = TARGET_TOTAL_RECORDS or len(records)
    
    # Compute the maximum target that maintains exact distribution proportions.
    # The bottleneck domain (fewest records relative to its desired %) caps the total.
    max_proportional_targets = []
    for domain, pct in DOMAIN_DISTRIBUTION.items():
        available = len(by_domain.get(domain, []))
        if pct > 0:
            max_for_domain = int(available / (pct / 100))
            max_proportional_targets.append((domain, max_for_domain, available, pct))
    
    # Sort to find bottleneck
    max_proportional_targets.sort(key=lambda x: x[1])
    bottleneck_domain, max_achievable, bn_available, bn_pct = max_proportional_targets[0]
    
    # Effective target: min of user target, total available (4 listed domains), and max achievable
    listed_available = sum(len(by_domain.get(d, [])) for d in DOMAIN_DISTRIBUTION)
    effective_target = min(target_size, listed_available, max_achievable)
    
    print(f"Domain distribution analysis:")
    print(f"  User target:            {target_size:,}")
    print(f"  Listed domains total:   {listed_available:,}")
    print(f"  Max proportional target: {max_achievable:,} (bottleneck: {bottleneck_domain} "
          f"with {bn_available:,} records @ {bn_pct}%)")
    print(f"  Effective target:       {effective_target:,}\n")
    
    if effective_target < target_size:
        print(f"  NOTE: Capping at {effective_target:,} to maintain exact distribution proportions.\n")
    
    # Allocate per domain using effective target
    domain_allocated = {}
    total_allocated = 0
    
    print(f"Allocation (target: {effective_target:,}):")
    for domain, pct in DOMAIN_DISTRIBUTION.items():
        available = by_domain.get(domain, [])
        desired = int(effective_target * pct / 100)
        allocated = min(desired, len(available))
        domain_allocated[domain] = allocated
        total_allocated += allocated
        
        actual_pct = (allocated / effective_target * 100) if effective_target > 0 else 0
        tag = " (ALL)" if allocated == len(available) else ""
        print(f"  {domain}: {allocated:,} / {len(available):,} available{tag} "
              f"[{actual_pct:.1f}% actual vs {pct}% desired]")
    
    # Handle rounding remainder: distribute 1 record at a time to domains with capacity
    remainder = effective_target - total_allocated
    if remainder > 0:
        for domain, pct in sorted(DOMAIN_DISTRIBUTION.items(), key=lambda x: -x[1]):
            if remainder == 0:
                break
            available = len(by_domain.get(domain, []))
            capacity = available - domain_allocated[domain]
            add = min(remainder, capacity)
            if add > 0:
                domain_allocated[domain] += add
                total_allocated += add
                remainder -= add
    
    print(f"\n  Total allocated: {total_allocated:,} / {effective_target:,}")
    
    # Sample from each domain
    for domain, count in domain_allocated.items():
        if count > 0:
            available = by_domain.get(domain, [])
            if count >= len(available):
                selected_records.extend(available)
            else:
                selected_records.extend(random.sample(available, count))
    
    # Report unused records
    unused_listed = listed_available - total_allocated
    unused_other = sum(len(v) for d, v in by_domain.items() if d not in DOMAIN_DISTRIBUTION)
    if unused_listed > 0 or unused_other > 0:
        print(f"\n  Unused records: {unused_listed:,} from listed domains, "
              f"{unused_other:,} from unlisted domains ({unused_listed + unused_other:,} total)")
else:
    # No domain distribution specified - take all
    selected_records = records
    print("No domain distribution specified - using all filtered records")

# Shuffle to mix domains
random.shuffle(selected_records)
records = selected_records

print(f"\nFinal dataset size: {len(records):,}")

Domain distribution analysis:
  User target:            20,000
  Listed domains total:   5,997
  Max proportional target: 3,400 (bottleneck: number_theory with 34 records @ 1%)
  Effective target:       3,400

  NOTE: Capping at 3,400 to maintain exact distribution proportions.

Allocation (target: 3,400):
  algebra: 3,060 / 5,598 available [90.0% actual vs 90% desired]
  combinatorics: 170 / 283 available [5.0% actual vs 5% desired]
  geometry: 68 / 82 available [2.0% actual vs 2% desired]
  number_theory: 34 / 34 available (ALL) [1.0% actual vs 1% desired]

  Total allocated: 3,400 / 3,400

  Unused records: 2,597 from listed domains, 0 from unlisted domains (2,597 total)

Final dataset size: 3,400


In [17]:
# ── Apply dataset source distribution ──────────────────────────────────────

# Group current records by dataset source
by_dataset: dict[str, list] = {}
for r in records:
    ds = r.get("dataset", "unknown")
    by_dataset.setdefault(ds, []).append(r)

print("Available records by dataset source (after previous filters):")
for ds, recs in sorted(by_dataset.items(), key=lambda x: -len(x[1])):
    print(f"  {ds}: {len(recs):,}")

if DATASET_DISTRIBUTION is not None:
    # Validate percentages
    total_pct = sum(DATASET_DISTRIBUTION.values())
    if total_pct > 100:
        raise ValueError(f"Dataset percentages sum to {total_pct}%, must be <= 100%")

    pool_size = len(records)

    # Find the max total that maintains exact proportions (bottleneck analysis)
    max_proportional_targets = []
    for ds_name, pct in DATASET_DISTRIBUTION.items():
        available = len(by_dataset.get(ds_name, []))
        if pct > 0:
            max_for_ds = int(available / (pct / 100))
            max_proportional_targets.append((ds_name, max_for_ds, available, pct))

    max_proportional_targets.sort(key=lambda x: x[1])
    bn_ds, max_achievable, bn_available, bn_pct = max_proportional_targets[0]

    listed_available = sum(len(by_dataset.get(d, [])) for d in DATASET_DISTRIBUTION)
    effective_target = min(pool_size, listed_available, max_achievable)

    print(f"\nDataset distribution analysis:")
    print(f"  Pool size:              {pool_size:,}")
    print(f"  Listed datasets total:  {listed_available:,}")
    print(f"  Max proportional target: {max_achievable:,} (bottleneck: {bn_ds} "
          f"with {bn_available:,} records @ {bn_pct}%)")
    print(f"  Effective target:       {effective_target:,}")

    if effective_target < pool_size:
        print(f"\n  NOTE: Capping at {effective_target:,} to maintain exact dataset proportions.")

    # Allocate per dataset
    ds_allocated = {}
    total_allocated = 0

    print(f"\nAllocation (target: {effective_target:,}):")
    for ds_name, pct in DATASET_DISTRIBUTION.items():
        available = by_dataset.get(ds_name, [])
        desired = int(effective_target * pct / 100)
        allocated = min(desired, len(available))
        ds_allocated[ds_name] = allocated
        total_allocated += allocated

        actual_pct = (allocated / effective_target * 100) if effective_target > 0 else 0
        tag = " (ALL)" if allocated == len(available) else ""
        print(f"  {ds_name}: {allocated:,} / {len(available):,} available{tag} "
              f"[{actual_pct:.1f}% actual vs {pct}% desired]")

    # Distribute rounding remainder
    remainder = effective_target - total_allocated
    if remainder > 0:
        for ds_name, pct in sorted(DATASET_DISTRIBUTION.items(), key=lambda x: -x[1]):
            if remainder == 0:
                break
            available = len(by_dataset.get(ds_name, []))
            capacity = available - ds_allocated[ds_name]
            add = min(remainder, capacity)
            if add > 0:
                ds_allocated[ds_name] += add
                total_allocated += add
                remainder -= add

    print(f"\n  Total allocated: {total_allocated:,} / {effective_target:,}")

    # Sample from each dataset
    ds_selected = []
    for ds_name, count in ds_allocated.items():
        if count > 0:
            available = by_dataset.get(ds_name, [])
            if count >= len(available):
                ds_selected.extend(available)
            else:
                ds_selected.extend(random.sample(available, count))

    # Report unused
    unused_listed = listed_available - total_allocated
    unused_other = sum(len(v) for d, v in by_dataset.items() if d not in DATASET_DISTRIBUTION)
    if unused_listed > 0 or unused_other > 0:
        print(f"\n  Unused records: {unused_listed:,} from listed datasets, "
              f"{unused_other:,} from unlisted datasets ({unused_listed + unused_other:,} total)")

    random.shuffle(ds_selected)
    records = ds_selected
    print(f"\nDataset size after dataset distribution: {len(records):,}")
else:
    print("\nNo dataset distribution specified - keeping all records")
    print(f"Current dataset size: {len(records):,}")

Available records by dataset source (after previous filters):
  gsm8k: 2,711
  numina1.5: 510
  mvidia_reasoning_steps: 179

Dataset distribution analysis:
  Pool size:              3,400
  Listed datasets total:  3,400
  Max proportional target: 3,388 (bottleneck: gsm8k with 2,711 records @ 80%)
  Effective target:       3,388

  NOTE: Capping at 3,388 to maintain exact dataset proportions.

Allocation (target: 3,388):
  gsm8k: 2,710 / 2,711 available [80.0% actual vs 80% desired]
  numina1.5: 508 / 510 available [15.0% actual vs 15% desired]
  mvidia_reasoning_steps: 169 / 179 available [5.0% actual vs 5% desired]

  Total allocated: 3,388 / 3,388

  Unused records: 12 from listed datasets, 0 from unlisted datasets (12 total)

Dataset size after dataset distribution: 3,388


In [18]:
# ── Validation: Final size check ──────────────────────────────────────────

# This cell validates the final dataset size
# (The new sampling algorithm in the previous cell should have already handled sizing)

print("Final dataset validation:")
print(f"  Target: {TARGET_TOTAL_RECORDS or 'unlimited'}")
print(f"  Actual: {len(records):,}")

if TARGET_TOTAL_RECORDS is not None:
    if len(records) == TARGET_TOTAL_RECORDS:
        print("  ✓ Target achieved")
    elif len(records) < TARGET_TOTAL_RECORDS:
        shortfall = TARGET_TOTAL_RECORDS - len(records)
        pct = (len(records) / TARGET_TOTAL_RECORDS) * 100
        print(f"  ⚠️  Shortfall: {shortfall:,} records ({pct:.1f}% of target)")
    else:
        # This shouldn't happen with the new algorithm, but safety check
        print(f"  ⚠️  WARNING: Exceeded target by {len(records) - TARGET_TOTAL_RECORDS:,} records")
        print("     Downsampling to target size...")
        records = random.sample(records, TARGET_TOTAL_RECORDS)
        print(f"  ✓ Downsampled to {len(records):,} records")

Final dataset validation:
  Target: 20000
  Actual: 3,388
  ⚠️  Shortfall: 16,612 records (16.9% of target)


In [19]:
# ── Build stratification key & split ───────────────────────────────────────

# Build stratification keys based on domain + effective_difficulty
# Uses the null-safe get_domain() helper defined in cell 5
strat_keys = []
for r in records:
    domain = get_domain(r)
    eff_diff = r.get("_effective_difficulty", "NA")
    strat_keys.append(f"{domain}_D{eff_diff}")

print("Stratification groups:")
for k, v in sorted(Counter(strat_keys).items()):
    print(f"  {k}: {v:,}")

# Handle singleton strata (merge into fallback group)
strat_counts = Counter(strat_keys)
safe_keys = [
    k if strat_counts[k] >= 2 else "_rare_"
    for k in strat_keys
]

rare_count = sum(1 for k in safe_keys if k == "_rare_")
if rare_count:
    print(f"\nMerged {rare_count} record(s) from singleton strata into '_rare_' group")

# Stratified split
indices = list(range(len(records)))

train_idx, test_idx = train_test_split(
    indices,
    test_size=TEST_RATIO,
    random_state=RANDOM_SEED,
    stratify=safe_keys,
)

train_records = [records[i] for i in train_idx]
test_records = [records[i] for i in test_idx]

print(f"\nTrain: {len(train_records):,}")
print(f"Test:  {len(test_records):,}")

Stratification groups:
  algebra_D0: 26
  algebra_D1: 447
  algebra_D2: 1,742
  algebra_D3: 902
  combinatorics_D0: 2
  combinatorics_D1: 18
  combinatorics_D2: 90
  combinatorics_D3: 59
  geometry_D0: 2
  geometry_D1: 6
  geometry_D2: 30
  geometry_D3: 30
  number_theory_D0: 4
  number_theory_D1: 6
  number_theory_D2: 9
  number_theory_D3: 15

Train: 3,049
Test:  339


In [20]:
# ── Verify stratification proportions ──────────────────────────────────────

train_strat = Counter(strat_keys[i] for i in train_idx)
test_strat = Counter(strat_keys[i] for i in test_idx)
all_groups = sorted(set(strat_keys))

rows = []
for g in all_groups:
    total = train_strat.get(g, 0) + test_strat.get(g, 0)
    rows.append({
        "group": g,
        "total": total,
        "train": train_strat.get(g, 0),
        "test": test_strat.get(g, 0),
        "train_%": round(train_strat.get(g, 0) / total * 100, 1) if total else 0,
    })

pd.DataFrame(rows)

Unnamed: 0,group,total,train,test,train_%
0,algebra_D0,26,23,3,88.5
1,algebra_D1,447,402,45,89.9
2,algebra_D2,1742,1568,174,90.0
3,algebra_D3,902,812,90,90.0
4,combinatorics_D0,2,2,0,100.0
5,combinatorics_D1,18,16,2,88.9
6,combinatorics_D2,90,81,9,90.0
7,combinatorics_D3,59,53,6,89.8
8,geometry_D0,2,2,0,100.0
9,geometry_D1,6,5,1,83.3


In [21]:
# ── Samples per dataset (train / test) ─────────────────────────────────────

train_ds = Counter(r.get("dataset", "unknown") for r in train_records)
test_ds = Counter(r.get("dataset", "unknown") for r in test_records)
all_datasets = sorted(set(train_ds) | set(test_ds))

ds_rows = []
for ds in all_datasets:
    total = train_ds.get(ds, 0) + test_ds.get(ds, 0)
    ds_rows.append({
        "dataset": ds,
        "total": total,
        "train": train_ds.get(ds, 0),
        "test": test_ds.get(ds, 0),
        "pct_of_total": round(total / len(records) * 100, 1),
    })

# Sort by total descending
ds_rows.sort(key=lambda x: -x["total"])

print("Samples per dataset:")
pd.DataFrame(ds_rows)

Samples per dataset:


Unnamed: 0,dataset,total,train,test,pct_of_total
0,gsm8k,2711,2437,274,80.0
1,numina1.5,508,459,49,15.0
2,mvidia_reasoning_steps,169,153,16,5.0


In [22]:
# ── Save to disk ───────────────────────────────────────────────────────────

os.makedirs(OUTPUT_DIR, exist_ok=True)

train_path = os.path.join(OUTPUT_DIR, "train.jsonl")
test_path = os.path.join(OUTPUT_DIR, "test.jsonl")

with open(train_path, "w") as f:
    for r in train_records:
        f.write(json.dumps(r) + "\n")

with open(test_path, "w") as f:
    for r in test_records:
        f.write(json.dumps(r) + "\n")

print(f"Saved {len(train_records):,} train records -> {train_path}")
print(f"Saved {len(test_records):,} test records  -> {test_path}")

# Summary
print(f"\n{'='*60}")
print("DATASET SUMMARY")
print(f"{'='*60}")
print(f"Sources: {len(DATASET_PATHS)} datasets")
print(f"Filters: correct={REQUIRE_CORRECT}")
print(f"         effective_difficulty={EFFECTIVE_DIFFICULTY or 'all'}")
print(f"         code_scores={CODE_SCORES or 'all'}")
if TOKEN_LENGTHS:
    print(f"         tokens=[{min(TOKEN_LENGTHS)}-{max(TOKEN_LENGTHS)}]")
else:
    print("         tokens=all")
if SOLUTION_TOKEN_RANGE:
    print(f"         solution_tokens=[{SOLUTION_TOKEN_RANGE[0]}-{SOLUTION_TOKEN_RANGE[1]}]")
else:
    print("         solution_tokens=all")
print(f"         max_constraint_count={MAX_CONSTRAINT_COUNT or 'all'}")
print(f"         max_object_count={MAX_OBJECT_COUNT or 'all'}")
print(f"         reasoning_depth={REASONING_DEPTH or 'all'}")
if DOMAIN_DISTRIBUTION:
    print(f"Domains: {', '.join(f'{d}:{p}%' for d, p in DOMAIN_DISTRIBUTION.items())}")
else:
    print("Domains: all (no distribution)")
if DATASET_DISTRIBUTION:
    print(f"Datasets: {', '.join(f'{d}:{p}%' for d, p in DATASET_DISTRIBUTION.items())}")
else:
    print("Datasets: all (no distribution)")
print(f"Target:  {TARGET_TOTAL_RECORDS or 'unlimited'}")
print(f"Split:   {TRAIN_RATIO*100:.0f}% train / {TEST_RATIO*100:.0f}% test")
print(f"Output:  {OUTPUT_DIR}")
print(f"{'='*60}")

Saved 3,049 train records -> /home/larcanio/AIMO3_v2/data/datasets/splits/algebra_specialist/train.jsonl
Saved 339 test records  -> /home/larcanio/AIMO3_v2/data/datasets/splits/algebra_specialist/test.jsonl

DATASET SUMMARY
Sources: 1 datasets
Filters: correct=True
         effective_difficulty=[0, 1, 2, 3]
         code_scores=all
         tokens=[0-511]
         solution_tokens=all
         max_constraint_count=3
         max_object_count=3
         reasoning_depth=['shallow', 'medium']
Domains: algebra:90%, combinatorics:5%, geometry:2%, number_theory:1%
Datasets: gsm8k:80%, numina1.5:15%, mvidia_reasoning_steps:5%
Target:  20000
Split:   90% train / 10% test
Output:  /home/larcanio/AIMO3_v2/data/datasets/splits/algebra_specialist
