## Setup

In [9]:
import json
import sys
import importlib.util
from pathlib import Path
from collections import Counter
from tqdm.notebook import tqdm

# The classifier's filename starts with a digit, so it cannot be named in a
# regular `import` statement; load it from its path instead.
# The path is resolved against the current working directory.
_classifier_path = Path.cwd() / "4_heuristic_math_classifier.py"
_spec = importlib.util.spec_from_file_location(
    name="heuristic_math_classifier",
    location=_classifier_path,
)
hmc = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(hmc)  # runs the module body, populating `hmc`

# Re-export the names the rest of the notebook uses.
merge_consensus = hmc.merge_consensus
ALLOWED_DOMAINS = hmc.ALLOWED_DOMAINS
DEFAULT_H_THRESHOLD = hmc.DEFAULT_H_THRESHOLD

## Configuration

In [10]:
# --- Dataset locations (JSONL, one record per line) ---
INPUT_DATASET_PATH = Path('/home/larcanio/AIMO3_v2/data/datasets/Dataset_Full/bucketed/dataset_0_7.jsonl')
OUTPUT_DATASET_PATH = Path('/home/larcanio/AIMO3_v2/data/datasets/Dataset_Full/bucketed/dataset_0_7_consensus.jsonl')

# --- Record selection ---
# Upper bound on how many eligible records are processed.
N_PROBLEMS = 0  # 0 = all

# When True, records without an LLM-produced math_structure are skipped.
REQUIRE_LLM_STRUCTURE = True

# When False, records that already carry consensus fields are left alone.
RERUN_EXISTING = False

# --- Consensus tuning ---
H_THRESHOLD = DEFAULT_H_THRESHOLD  # heuristic margin to override LLM domain (default: 6)

# Echo the effective settings so they are recorded alongside the results.
print(f"Input:  {INPUT_DATASET_PATH}")
print(f"Output: {OUTPUT_DATASET_PATH}")
print(f"H_THRESHOLD={H_THRESHOLD}, REQUIRE_LLM_STRUCTURE={REQUIRE_LLM_STRUCTURE}, RERUN_EXISTING={RERUN_EXISTING}")

Input:  /home/larcanio/AIMO3_v2/data/datasets/Dataset_Full/bucketed/dataset_0_7.jsonl
Output: /home/larcanio/AIMO3_v2/data/datasets/Dataset_Full/bucketed/dataset_0_7_consensus.jsonl
H_THRESHOLD=6, REQUIRE_LLM_STRUCTURE=True, RERUN_EXISTING=False


## Load Dataset

In [11]:
# Read the input JSONL into memory. Blank lines are ignored. Lines that are
# not valid JSON are dropped, but their line numbers are kept so the loss is
# reported: the Save cell writes `full_datapoints` back out, so a dropped
# line would otherwise vanish from the output without any trace.
full_datapoints = []
malformed_line_numbers = []  # 1-based positions in the input file
with open(INPUT_DATASET_PATH, 'r', encoding='utf-8') as f:
    for line_no, line in enumerate(f, start=1):
        if not line.strip():
            continue
        try:
            full_datapoints.append(json.loads(line))
        except json.JSONDecodeError:
            malformed_line_numbers.append(line_no)

print(f"Loaded: {len(full_datapoints)} total records")
if malformed_line_numbers:
    print(
        f"WARNING: dropped {len(malformed_line_numbers)} malformed JSON line(s); "
        f"first at line(s) {malformed_line_numbers[:5]}"
    )

# Filter to records that need processing
to_process_idxs = []
for i, dp in enumerate(full_datapoints):
    # `or {}` covers both a missing key and an explicit null value.
    ms = dp.get('math_structure') or {}
    if REQUIRE_LLM_STRUCTURE and not ms.get('from_text'):
        continue  # no LLM-derived structure to reconcile against
    if not RERUN_EXISTING and ms.get('consensus_meta'):
        continue  # already has consensus fields
    to_process_idxs.append(i)

# Optional cap for quick trial runs (0 disables the cap).
if N_PROBLEMS > 0:
    to_process_idxs = to_process_idxs[:N_PROBLEMS]

skipped = len(full_datapoints) - len(to_process_idxs)
print(f"To process: {len(to_process_idxs)} (skipped {skipped})")

Loaded: 71832 total records
To process: 71761 (skipped 71)


## Run Consensus

In [None]:
import time

# Accumulators — field names match consensus_meta keys
CONSENSUS_FIELDS = [
    'reasoning_shape', 'case_split', 'auxiliary_construction',
    'reasoning_depth', 'intermediate_reuse',
    'objects', 'constraints', 'output_type',
    'mechanisms',
]

disagreement_counts = Counter()                           # field -> entries flagged `disagreement`
source_counts = {f: Counter() for f in CONSENSUS_FIELDS}  # field -> Counter of `source` values
domain_decision_counts = Counter()                        # decision_reason -> record count
completed = 0
errors = 0

start_time = time.time()
pbar = tqdm(total=len(to_process_idxs), desc="Consensus")

# try/finally guarantees the progress bar is closed even when the loop is
# interrupted or an unexpected error escapes the per-record handler.
try:
    for idx in to_process_idxs:
        dp = full_datapoints[idx]
        try:
            result = merge_consensus(dp, H_THRESHOLD)
        except Exception as e:
            # Best-effort: a failing record keeps its original content and
            # only the first few failures are printed to avoid flooding.
            errors += 1
            if errors <= 5:
                print(f"Error at index {idx}: {type(e).__name__}: {str(e)[:120]}")
            pbar.update(1)
            continue

        full_datapoints[idx] = result
        completed += 1

        # Collect stats from consensus_meta.
        # `or {}` (rather than a .get default) also covers keys that are
        # present but null, matching the filter in the Load Dataset cell.
        ms = result.get('math_structure') or {}
        domain_meta = ms.get('domain_meta') or {}
        domain_decision_counts[domain_meta.get('decision_reason', 'unknown')] += 1

        meta = ms.get('consensus_meta') or {}
        for field_name, m in meta.items():
            # Only tracked fields with a dict-shaped entry contribute.
            if field_name not in source_counts or not isinstance(m, dict):
                continue
            if m.get('disagreement'):
                disagreement_counts[field_name] += 1
            # NOTE(review): an entry without `source` is counted as
            # 'heuristic' — confirm that is the intended default.
            source_counts[field_name][m.get('source', 'heuristic')] += 1

        # refresh=False: let the following update() redraw at tqdm's own
        # throttled rate instead of forcing a redraw for every record.
        pbar.set_postfix(
            done=completed,
            err=errors,
            disagree=f"{sum(disagreement_counts.values())}",
            refresh=False,
        )
        pbar.update(1)
finally:
    pbar.close()

elapsed = time.time() - start_time
print(f"\nDone: {completed} processed, {errors} errors in {elapsed:.1f}s ({elapsed/max(completed,1)*1000:.1f}ms/record)")

## Summary Statistics

In [13]:
# One row per tracked field: counts of the 'heuristic' / 'llm' / 'merged'
# source values, the number of flagged disagreements, and the disagreement
# rate relative to the sum of those three source counts.
print(f"{'Field':<28} {'Heuristic':>9} {'LLM':>9} {'Merged':>9} │ {'Disagree':>9} {'Rate':>7}")
print('─' * 80)
for field in CONSENSUS_FIELDS:
    by_source = source_counts[field]
    heur, llm, merged = by_source['heuristic'], by_source['llm'], by_source['merged']
    n_disagree = disagreement_counts[field]  # Counter yields 0 for unseen fields
    n_total = heur + llm + merged
    if n_total > 0:
        rate = f"{n_disagree/n_total:.1%}"
    else:
        rate = 'n/a'  # field never appeared; avoid dividing by zero
    print(f"{field:<28} {heur:>9} {llm:>9} {merged:>9} │ {n_disagree:>9} {rate:>7}")

print(f"\nTotal processed: {completed}")
print(f"Total disagreements: {sum(disagreement_counts.values())}")

Field                        Heuristic       LLM    Merged │  Disagree    Rate
────────────────────────────────────────────────────────────────────────────────
reasoning_shape                  71761         0         0 │      2410    3.4%
case_split                       71761         0         0 │      5057    7.0%
auxiliary_construction           71627       134         0 │     33434   46.6%
reasoning_depth                  71761         0         0 │     36096   50.3%
intermediate_reuse               71761         0         0 │     43706   60.9%
objects                          21323     27356     23082 │     61516   85.7%
constraints                      71761         0         0 │     55819   77.8%
output_type                      71761         0         0 │     12461   17.4%

Total processed: 71761
Total disagreements: 250499


In [14]:
# Domain decision reasons, most frequent first, each shown as a share of the
# successfully processed records.
print("Domain decision reasons:")
for decision_reason, n_records in domain_decision_counts.most_common():
    share = n_records / completed
    print(f"  {decision_reason:<40} {n_records:>7} ({share:.1%})")

Domain decision reasons:
  agree                                      22161 (30.9%)
  llm_default                                20450 (28.5%)
  hard_override:algebra                      10790 (15.0%)
  hard_override:number_theory                 9640 (13.4%)
  hard_override:geometry                      5638 (7.9%)
  heuristic_override:margin=6                 1130 (1.6%)
  llm_missing_heuristic_fallback              1070 (1.5%)
  heuristic_override:margin=9                  303 (0.4%)
  heuristic_override:margin=8                  144 (0.2%)
  heuristic_override:margin=12                 140 (0.2%)
  heuristic_override:margin=7                  112 (0.2%)
  hard_override:combinatorics                   93 (0.1%)
  heuristic_override:margin=10                  56 (0.1%)
  heuristic_override:margin=14                  25 (0.0%)
  heuristic_override:margin=13                   6 (0.0%)
  heuristic_override:margin=16                   2 (0.0%)
  heuristic_override:margin=11             

In [15]:
# Domain distribution over ALL loaded records (processed or not).
# Records without a domain are not counted, but they still contribute to the
# denominator, so the percentages need not sum to 100%.
domain_counts = Counter()
for dp in full_datapoints:
    # `or {}` also covers records whose math_structure is present but null;
    # a plain .get default would return None there and crash on ms.get.
    ms = dp.get('math_structure') or {}
    if ms.get('domain'):
        domain_counts[ms['domain']] += 1

print("Domain distribution (after consensus):")
for domain, cnt in domain_counts.most_common():
    print(f"  {domain:<20} {cnt:>7} ({cnt/len(full_datapoints):.1%})")

Domain distribution (after consensus):
  algebra                24266 (33.8%)
  number_theory          21202 (29.5%)
  geometry               14369 (20.0%)
  combinatorics          11924 (16.6%)


## Save

In [16]:
# Write every loaded record (processed or not) as JSONL. The data goes to a
# sibling temporary file first and is moved into place only after the write
# finishes, so a failed write never leaves a truncated output file.
OUTPUT_DATASET_PATH.parent.mkdir(parents=True, exist_ok=True)
tmp = OUTPUT_DATASET_PATH.with_suffix('.tmp')
with open(tmp, 'w', encoding='utf-8') as f:
    for dp in full_datapoints:
        # ensure_ascii=False keeps non-ASCII characters readable in the file.
        f.write(json.dumps(dp, ensure_ascii=False) + '\n')
# replace() overwrites an existing output on every platform; rename() raises
# FileExistsError on Windows when the target already exists (e.g. on re-run).
tmp.replace(OUTPUT_DATASET_PATH)

print(f"Saved {len(full_datapoints)} records to {OUTPUT_DATASET_PATH}")
print(f"  processed: {completed}, unchanged: {len(full_datapoints) - completed}")

Saved 71832 records to /home/larcanio/AIMO3_v2/data/datasets/Dataset_Full/bucketed/dataset_0_7_consensus.jsonl
  processed: 71761, unchanged: 71
