# Cleanup NuminaMath

In [None]:
from pathlib import Path
import json
from typing import Any, Dict, Iterable, Tuple, List

from tqdm.auto import tqdm

# Prefer PyArrow for streaming parquet reads; otherwise fall back to pandas.
try:
    import pyarrow.parquet as pq
    import pyarrow as pa
    HAVE_PA = True
except Exception:
    HAVE_PA = False

try:
    import pandas as pd
    HAVE_PD = True
except Exception:
    HAVE_PD = False

# Directory holding the input parquet files, the glob pattern for the train
# shards inside it, and the name of the single validation file.
BASE = Path("NuminaMath")
TRAIN_GLOB = "NuminaMath_train_*.parquet"
VALID_FILE = "NuminaMath_valid.parquet"

# Destination JSONL files (parent directories are created by the converter).
OUT_TRAIN = Path("./train_jsonl/numinamath_train.jsonl")
OUT_VALID = Path("./valid_jsonl/numinamath_valid.jsonl")

# Rows per record batch when streaming with PyArrow.
BATCH_SIZE = 8192
# When True, exact duplicate (problem, solution) pairs are written only once.
DO_DEDUP   = True

In [None]:
def to_clean_str(x: Any) -> str:
    """Coerce a cell value to a trimmed string.

    Missing values (``None`` and float NaN) become ``""`` so callers can
    treat them as empty instead of receiving the text ``"None"`` / ``"nan"``.
    Floats within 1e-9 of a whole number are printed as integers
    (``3.0`` -> ``"3"``). Everything else goes through ``str()`` and is
    stripped of surrounding whitespace.
    """
    if x is None:
        return ""
    if isinstance(x, int):
        return str(x)
    if isinstance(x, float):
        # NaN is the only float that is not equal to itself; round() would
        # raise ValueError on it.
        if x != x:
            return ""
        try:
            nearest = round(x)
        except OverflowError:
            # +/-inf cannot be rounded to an int; keep its plain repr.
            return str(x)
        if abs(x - nearest) < 1e-9:
            return str(int(nearest))
        return str(x)
    return str(x).strip()

# STRICT: only accept rows having BOTH 'problem' and 'solution'
# skip otherwise.
def extract_pair_strict(obj: Dict[str, Any]) -> Tuple[str, str] | None:
    if "problem" not in obj or "solution" not in obj:
        return None
    u = to_clean_str(obj["problem"])
    a = to_clean_str(obj["solution"])
    if not u or not a:
        return None
    return u, a


def _iter_row_batches(in_paths: List[Path]) -> Iterable[List[Dict[str, Any]]]:
    """Yield the rows of the parquet files as lists of dicts, in input order."""
    if HAVE_PA:
        # Stream record batches of BATCH_SIZE rows rather than whole files.
        for p in in_paths:
            pf = pq.ParquetFile(p)
            for batch in pf.iter_batches(batch_size=BATCH_SIZE):
                yield pa.Table.from_batches([batch]).to_pylist()
    elif HAVE_PD:
        # pandas fallback: each file is read whole and yielded as one batch.
        for p in in_paths:
            yield pd.read_parquet(p).to_dict(orient="records")
    else:
        raise RuntimeError("Neither pyarrow nor pandas is available to read parquet.")


def build_jsonl_from_parquets(in_paths: List[Path], out_path: Path) -> Dict[str, int]:
    """Convert parquet rows to normalized user/assistant JSONL.

    STRICT: only rows accepted by ``extract_pair_strict`` (problem/solution)
    are written, one ``{"user": ..., "assistant": ...}`` object per line.
    When ``DO_DEDUP`` is set, exact repeats of a (user, assistant) pair are
    skipped after their first occurrence. A tqdm progress bar tracks rows.

    Args:
        in_paths: parquet files to read, processed in the given order.
        out_path: JSONL file to write; its parent directory is created.

    Returns:
        Counters dict with keys ``total_rows_in_files``, ``seen_rows``,
        ``written``, ``dropped_missing_or_empty`` and ``deduped``.

    Raises:
        RuntimeError: if neither pyarrow nor pandas could be imported. This is
            checked before the output file is opened, so an existing output
            is not truncated.
    """
    if not (HAVE_PA or HAVE_PD):
        raise RuntimeError("Neither pyarrow nor pandas is available to read parquet.")

    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Row total for the tqdm bar (since this is huge); best effort only, the
    # bar simply runs without a total if the metadata cannot be read.
    total_rows = None
    if HAVE_PA:
        try:
            total_rows = sum(pq.ParquetFile(p).metadata.num_rows for p in in_paths)
        except Exception:
            total_rows = None

    # NOTE: `seen` keeps every full (user, assistant) pair in memory.
    seen = set()
    kept = dropped_missing = deduped = total_seen = 0

    desc = f"writing {out_path.name}"
    with out_path.open("w", encoding="utf-8") as f, tqdm(
        total=total_rows, unit="rows", desc=desc, leave=False
    ) as pbar:
        for rows in _iter_row_batches(in_paths):
            for rec in rows:
                total_seen += 1
                pair = extract_pair_strict(rec)
                if pair is None:
                    dropped_missing += 1
                    continue

                if DO_DEDUP:
                    if pair in seen:
                        deduped += 1
                        continue
                    seen.add(pair)

                user, assistant = pair
                f.write(json.dumps({"user": user, "assistant": assistant}, ensure_ascii=False) + "\n")
                kept += 1

            # Progress is advanced once per batch, not per row.
            pbar.update(len(rows))
            pbar.set_postfix(written=kept, missing=dropped_missing, deduped=deduped)

    return {
        "total_rows_in_files": int(total_rows) if total_rows is not None else total_seen,
        "seen_rows": total_seen,
        "written": kept,
        "dropped_missing_or_empty": dropped_missing,
        "deduped": deduped,
    }

In [None]:
# Gather the inputs: every train shard matching the glob (ordered by file
# name) plus the single validation file.
train_parts = sorted(BASE.glob(TRAIN_GLOB), key=lambda shard: shard.name)
valid_parts = [BASE / VALID_FILE]

for heading, parts in (("Train parts:", train_parts), ("\nValid parts:", valid_parts)):
    print(heading)
    for part in parts:
        print(" -", part)

# Convert the train split first, then the validation split; each call
# returns its counters dict.
train_stats = build_jsonl_from_parquets(train_parts, OUT_TRAIN)
valid_stats = build_jsonl_from_parquets(valid_parts, OUT_VALID)

print("\n--- Done ---")
for label, stats in (("Train:", train_stats), ("Valid:", valid_stats)):
    print(label, stats)

Train parts:
 - NuminaMath\NuminaMath_train_1.parquet
 - NuminaMath\NuminaMath_train_2.parquet
 - NuminaMath\NuminaMath_train_3.parquet
 - NuminaMath\NuminaMath_train_4.parquet
 - NuminaMath\NuminaMath_train_5.parquet

Valid parts:
 - NuminaMath\NuminaMath_valid.parquet


writing numinamath_train.jsonl:   0%|          | 0/859494 [00:00<?, ?rows/s]

writing numinamath_valid.jsonl:   0%|          | 0/100 [00:00<?, ?rows/s]


--- Done ---
Train: {'total_rows_in_files': 859494, 'seen_rows': 859494, 'written': 852742, 'dropped_missing_or_empty': 0, 'deduped': 6752}
Valid: {'total_rows_in_files': 100, 'seen_rows': 100, 'written': 100, 'dropped_missing_or_empty': 0, 'deduped': 0}


In [4]:
def peek_jsonl(path: Path, k: int = 3):
    """Print a header followed by at most the first ``k`` lines of ``path``.

    Trailing whitespace (including the newline) is stripped from each line
    before printing.
    """
    print(f"\n--- {path.name} (first {k}) ---")
    with path.open("r", encoding="utf-8") as handle:
        # zip() stops as soon as range(k) is exhausted, capping the output.
        for _, line in zip(range(k), handle):
            print(line.rstrip())

# Show the first three records of each generated file.
for jsonl_path in (OUT_TRAIN, OUT_VALID):
    peek_jsonl(jsonl_path, 3)


--- numinamath_train.jsonl (first 3) ---
{"user": "Consider the terms of an arithmetic sequence: $-\\frac{1}{3}, y+2, 4y, \\ldots$. Solve for $y$.", "assistant": "For an arithmetic sequence, the difference between consecutive terms must be equal. Therefore, we can set up the following equations based on the sequence given:\n\\[ (y + 2) - \\left(-\\frac{1}{3}\\right) = 4y - (y+2) \\]\n\nSimplify and solve these equations:\n\\[ y + 2 + \\frac{1}{3} = 4y - y - 2 \\]\n\\[ y + \\frac{7}{3} = 3y - 2 \\]\n\\[ \\frac{7}{3} + 2 = 3y - y \\]\n\\[ \\frac{13}{3} = 2y \\]\n\\[ y = \\frac{13}{6} \\]\n\nThus, the value of $y$ that satisfies the given arithmetic sequence is $\\boxed{\\frac{13}{6}}$."}
{"user": "Suppose that $g(x) = 5x - 3$. What is $g^{-1}(g^{-1}(14))$?", "assistant": "First, we need to find the inverse function $g^{-1}(x)$. Given $g(x) = 5x - 3$, solve for $x$:\n\\[ y = 5x - 3 \\]\n\\[ y + 3 = 5x \\]\n\\[ x = \\frac{y + 3}{5} \\]\nThus, $g^{-1}(x) = \\frac{x + 3}{5}$.\n\nNow, apply 

In [None]:
from tqdm.auto import tqdm
import json
from pathlib import Path

def sanity_check_jsonl(path: Path, max_lines: int | None = None):
    """Scan a JSONL file and print a one-line summary of suspicious records.

    Each line is parsed as JSON and classified:

    * ``bad_json``  - the line is not valid JSON;
    * ``nonstring`` - the value is not a JSON object, or its "user" /
      "assistant" field is not a string;
    * very-short    - "user" shorter than 5 or "assistant" shorter than 3
      characters (such records still count as scanned).

    Args:
        path: JSONL file to scan.
        max_lines: stop after this many lines; ``None`` scans the whole file.
    """
    n = bad_json = nonstring = short_user = short_asst = 0
    with path.open("r", encoding="utf-8") as f:
        for i, ln in enumerate(tqdm(f, unit="lines", desc=f"scan {path.name}")):
            if (max_lines is not None) and (i >= max_lines):
                break
            try:
                obj = json.loads(ln)
            except Exception:
                bad_json += 1
                continue
            # Valid JSON that is not an object (list, number, null, ...) has
            # no .get(); count it instead of crashing the scan.
            if not isinstance(obj, dict):
                nonstring += 1
                continue
            u, a = obj.get("user", ""), obj.get("assistant", "")
            if not isinstance(u, str) or not isinstance(a, str):
                nonstring += 1
                continue
            if len(u) < 5:
                short_user += 1
            if len(a) < 3:
                short_asst += 1
            # n counts only records that parsed and had string fields.
            n += 1
    print(f"{path.name}: scanned {n} lines | bad_json={bad_json} nonstring={nonstring} "
          f"| very-short user={short_user} assistant={short_asst}")

# Sanity-check both generated files in full (no line cap).
for jsonl_path in (OUT_TRAIN, OUT_VALID):
    sanity_check_jsonl(jsonl_path, max_lines=None)


scan numinamath_train.jsonl: 0lines [00:00, ?lines/s]

numinamath_train.jsonl: scanned 852742 lines | bad_json=0 nonstring=0 | very-short user=0 assistant=0


scan numinamath_valid.jsonl: 0lines [00:00, ?lines/s]

numinamath_valid.jsonl: scanned 100 lines | bad_json=0 nonstring=0 | very-short user=0 assistant=0
