# Configure tokenizers & stuffs

In [None]:
from pathlib import Path
import json, re, numpy as np
from tokenizers import ByteLevelBPETokenizer, AddedToken

VOCAB  = Path("../bpe/bpe_model-vocab.json")
MERGES = Path("../bpe/bpe_model-merges.txt")

MAX_LEN          = 512
PAD_ID_EXPECTED  = 0 # must match tokenizer.json
END_ID_EXPECTED  = 3 # must match tokenizer.json

# here's the thing
# if you look at `tokenizer_test.ipynb`, you'll that whole system prompt is 87 tokens
# so, we need to reserve at least that many tokens for the system prompt
# i just set it to 90 (+3 for safety)
SYSTEM_RESERVE   = 90   # tokens reserved for system
# also, if you look at `checking_info.ipynb`, you'll see that the longest one has 87 tokens
# so, we'll reserve 70 tokens for info
INFO_RESERVE     = 70   # tokens reserved for info
# that leaves us with 512 - 90 - 70 = 352 tokens for the user + assistant

# we want to ensure that the assistant always has at least 48 tokens
MIN_ASSISTANT    = 48   # hard minimum tokens for assistant span

# build the two tokenizer *views*:
# - bpe_decode: for decoding + recognizing specials
# - bpe_prompt: for encoding prompts (ONLY PAD and END are special)
bpe_decode = ByteLevelBPETokenizer(str(VOCAB), str(MERGES), lowercase=False, add_prefix_space=True)
bpe_prompt = ByteLevelBPETokenizer(str(VOCAB), str(MERGES), lowercase=False, add_prefix_space=True)

# to verify we do have those special tokens
SPECIALS_DECODE = [
    "<|PAD|>", "<|START|>", "<|END|>", "<|SYSTEM|>",
    "<|INFOSTART|>", "<|INFOEND|>", "<|USER|>", "<|ASSISTANT|>", "<|UNKNOWN|>"
]
SPECIALS_PROMPT = ["<|PAD|>", "<|END|>"]

def report_specials(tok, specials, name):
    ids = {t: tok.token_to_id(t) for t in specials}
    missing = [t for t, i in ids.items() if i is None]
    print(f"[{name}] present {len(specials)-len(missing)}/{len(specials)}")
    if missing:
        print("  missing:", missing)
    return ids, missing

SID, missing_decode = report_specials(bpe_decode, SPECIALS_DECODE, "decode")
SPID, missing_prompt = report_specials(bpe_prompt, SPECIALS_PROMPT, "prompt")

assert not missing_decode, f"Decode tokenizer missing specials: {missing_decode}"
assert not missing_prompt, f"Prompt tokenizer missing specials: {missing_prompt}"

PAD_ID  = SID["<|PAD|>"]; END_ID = SID["<|END|>"]
assert PAD_ID == PAD_ID_EXPECTED and END_ID == END_ID_EXPECTED, (
    f"PAD/END ids mismatch: PAD={PAD_ID} vs {PAD_ID_EXPECTED}, END={END_ID} vs {END_ID_EXPECTED}"
)

# helpers
def tok_len(txt: str) -> int:
    return len(bpe_prompt.encode(txt, add_special_tokens=False).ids)

def enc(txt: str):
    return bpe_prompt.encode(txt, add_special_tokens=False).ids

def dec(ids):
    return bpe_decode.decode([int(t) for t in ids])


[decode] present 9/9
[prompt] present 2/2


# For TQDM STUFFS

In [None]:
# progress bar setup (put once near the top)
from tqdm import tqdm as TQDM
if hasattr(TQDM, "monitor_interval"):
    TQDM.monitor_interval = 0 # no background monitor thread

TQDM_KW = dict(dynamic_ncols=True)
BARFMT  = "{l_bar}{bar} | {n_fmt}/{total_fmt} [{elapsed}<{remaining}]"


# System prompts

In [3]:
import numpy as np

# ok. that's the template we'll use for sft training (we'll use the same by the time we generate output)
# tmi: that's the one with 87 tokens
SYSTEM_PROMPT_DEFAULT = (
    "Be a helpful, concise assistant with a light, friendly tone. "
    "Answer directly in 1–3 sentences. Don’t use steps or bullet lists unless the user asks. "
    "Use the content between INFOSTART/INFOEND only as context. Do not mention it, ‘memory,’ or any internal tags. "
    "Avoid speculation and say when unsure. Keep replies safe, accurate, and on-topic."
)

# later on, we'll look at an SFT ids file to detect whether training used a leading space/newline after <|ASSISTANT|>
SFT_IDS = Path("../final_npy/train_input_ids.npy")
def detect_assist_prefix(ids_path: Path, sample_k: int = 256) -> str:
    if not ids_path.exists():
        return " "
    arr = np.load(ids_path, mmap_mode="r")
    spaces = newlines = letters = 0
    for i in range(min(sample_k, arr.shape[0])):
        row = [int(t) for t in arr[i] if int(t) != PAD_ID]
        txt = bpe_decode.decode(row)
        j = txt.find("<|ASSISTANT|>")
        if j == -1: 
            continue
        ch = txt[j+len("<|ASSISTANT|>"): j+len("<|ASSISTANT|>")+1]
        if ch == " ":      spaces += 1
        elif ch == "\n":   newlines += 1
        elif ch:           letters += 1
    if spaces >= max(newlines, letters): return " "
    if newlines > 0: return "\n"
    return " "

ASSIST_PREFIX = detect_assist_prefix(SFT_IDS)  # default to a single space if unknown

def build_template_text(system_text: str, info_text: str, user_text: str) -> str:
    return (
        "<|START|><|SYSTEM|>" + system_text + "\n"
        "<|INFOSTART|>" + (info_text or "") + "<|INFOEND|>\n"
        "<|USER|>" + user_text.strip() + "\n"
        "<|ASSISTANT|>" + ASSIST_PREFIX
    )


# Budget

In [None]:
import math
import json

RUT_PAT = re.compile(r"(?i)\bstep\s*\d+|identify the (given )?sentence|the sentence is")

# returns available budgets (token counts) for user and assistant when packing to MAX_LEN
# strategy:
#   1) Reserve SYSTEM_RESERVE and INFO_RESERVE (hard caps)
#   2) Build the static template and measure its tokenized length (without user/assistant)
#   3) Leave 1 token for END
#   4) Enforce MIN_ASSISTANT
#   5) Allocate remaining to user
def compute_budgets(system_text: str, info_text: str, user_text: str) -> dict:
    # clamp system/info to reserves (skip if they exceed 512 tokens total)
    sys_ids  = enc((system_text or "")[:])[:SYSTEM_RESERVE]
    info_ids = enc((info_text  or "")[:])[:INFO_RESERVE]

    # template skeleton around user/assistant (we’ll measure tokens contributed by tags + SYSTEM/INFO shells)
    # for length, we must recompute using the same tokenizer behavior used in packing
    # build a "header" with SYSTEM/INFO lines and USER/ASSISTANT tags but with EMPTY user/assistant bodies
    header_txt = (
        "<|START|><|SYSTEM|>" + bpe_decode.decode(sys_ids) + "\n"
        "<|INFOSTART|>" + bpe_decode.decode(info_ids) + "<|INFOEND|>\n"
        "<|USER|>" + "" + "\n"
        "<|ASSISTANT|>" + ASSIST_PREFIX
    )
    header_len = tok_len(header_txt)

    # total tokens available for user_text + assistant_text + END
    remain = MAX_LEN - header_len - 1  # save 1 for END token

    # we recently set MIN_ASSISTANT to 48
    # so, it will ignore user if the template + reserves already nearly fills the row
    # i.e. if remain < MIN_ASSISTANT, then user_budget=0 and assistant_budget=remain
    if remain < MIN_ASSISTANT:
        return dict(user_budget=0, assistant_budget=max(0, remain))
    
    # guarantee assistant minimum
    user_budget = max(0, remain - MIN_ASSISTANT)
    assistant_budget = MIN_ASSISTANT
    return dict(user_budget=user_budget, assistant_budget=assistant_budget)

def truncate_to_budget(text: str, budget: int) -> list[int]:
    if budget <= 0:
        return []
    # token level truncation preserves leading spaces behavior
    ids = enc(text)
    return ids[:budget]


# Row & Masking

In [5]:
# packing row. WE ARE NOT TRUNCATING. skip if it wouldn't fit 512
def strip_any_end(text: str) -> str:
    return re.sub(r"<\|END\|>\s*$", "", text)

def row_to_ids_and_mask(system_text: str, info_text: str, user_text: str, assistant_text: str):
    # cleanup + scaffold filter
    assistant_text = strip_any_end(assistant_text or "")
    if RUT_PAT.search(assistant_text):
        return None

    # tokenize FULL fields (no slicing!) to measure true lengths
    sys_ids_full   = enc(system_text or SYSTEM_PROMPT_DEFAULT)
    info_ids_full  = enc(info_text  or "")
    user_ids_full  = enc(user_text  or "")
    asst_ids_full  = enc(assistant_text or "")

    # enforce fixed reserves for system/info (SKIP if they exceed reserves)
    if len(sys_ids_full) > SYSTEM_RESERVE:
        return None
    if len(info_ids_full) > INFO_RESERVE:
        return None

    sys_ids  = sys_ids_full
    info_ids = info_ids_full

    # compute header length with EXACT same tokenizer behavior
    header_txt = (
        "<|START|><|SYSTEM|>" + bpe_decode.decode(sys_ids) + "\n"
        "<|INFOSTART|>" + bpe_decode.decode(info_ids) + "<|INFOEND|>\n"
        "<|USER|>" + bpe_decode.decode(user_ids_full) + "\n"
        "<|ASSISTANT|>" + ASSIST_PREFIX
    )
    header_len = tok_len(header_txt)

    # total remaining budget for assistant + END
    remain = MAX_LEN - header_len - 1  # keep 1 for END

    # must leave room for at least MIN_ASSISTANT (strict)
    if remain < MIN_ASSISTANT:
        return None

    # strict budgets: user must already fit; assistant must fit into remain
    # we already built header with full user_ids, so the check is simply on asst length.
    if len(asst_ids_full) > remain:
        return None

    # build final sequence
    ids = []
    ids += enc("<|START|><|SYSTEM|>") + sys_ids + enc("\n")
    ids += enc("<|INFOSTART|>") + info_ids + enc("<|INFOEND|>\n")
    ids += enc("<|USER|>") + user_ids_full + enc("\n")
    ids += enc("<|ASSISTANT|>") + enc(ASSIST_PREFIX)
    ids += asst_ids_full
    ids += [END_ID]

    # safety: must fit exactly within 512 (pad if shorter, drop if longer)
    if len(ids) > MAX_LEN:
        return None
    pad_len = MAX_LEN - len(ids)
    mask = np.zeros(len(ids) + pad_len, dtype=np.uint8)

    # mask from first assistant token (after tag+prefix) through END (inclusive)
    prefix_txt = (
        "<|START|><|SYSTEM|>" + bpe_decode.decode(sys_ids) + "\n"
        "<|INFOSTART|>" + bpe_decode.decode(info_ids) + "<|INFOEND|>\n"
        "<|USER|>" + bpe_decode.decode(user_ids_full) + "\n"
        "<|ASSISTANT|>" + ASSIST_PREFIX
    )
    start_mask = len(enc(prefix_txt))
    end_mask   = len(ids) - 1  # END index

    mask[:len(ids)] = 0
    mask[start_mask:end_mask+1] = 1

    if pad_len:
        ids  = ids + [PAD_ID]*pad_len  # right pad to 512

    # stats
    asst_len = end_mask - start_mask  # not counting END
    stats = dict(
        total=len(ids),
        user_len=len(user_ids_full),
        assistant_len=asst_len,
        includes_END=True,
        truncated=False,
    )
    return np.array(ids, dtype=np.int32), mask.astype(np.uint8), stats


# Audit NPY files

this is for sanity check <br/>
around the end, by the time we do sanity check, we'll use this

In [6]:
def audit_npys(ids_path: Path, mask_path: Path, sample_k: int = 5000):
    ids = np.load(ids_path, mmap_mode="r")
    msk = np.load(mask_path, mmap_mode="r")
    N, T = ids.shape
    take = min(sample_k, N)

    bad_order = bad_multi_end = bad_mask_after = 0
    boundary_mismatch = 0
    short_asst = rut_hits = 0

    rng = TQDM(range(take), desc=f"[AUDIT*] {ids_path.name}", total=take, bar_format=BARFMT, **TQDM_KW)
    for i in rng:
        row_tok = ids[i]
        row_msk = msk[i].astype(bool)

        # quick END checks
        nonpad = (row_tok != PAD_ID)
        last_idx = T - 1 - np.argmax(nonpad[::-1])
        end_count = int((row_tok == END_ID).sum())
        if end_count != 1: bad_multi_end += 1
        if row_tok[last_idx] != END_ID: bad_order += 1

        # no mask after END
        if row_msk[last_idx+1:].any(): bad_mask_after += 1

        # assistant start from mask (first 1)
        if not row_msk.any(): continue # should not happen
        a_tok = int(np.argmax(row_msk)) # first supervised index
        e_tok = last_idx

        # semantic boundary check: ensure decoded text around a_tok shows we're right after the tag+prefix
        # decode a small window around the suspected boundary
        lo = max(0, a_tok - 32); hi = min(T, a_tok + 32)
        window = bpe_decode.decode([int(t) for t in row_tok[lo:hi] if int(t) != PAD_ID])

        # we expect "...<|ASSISTANT|>{ASSIST_PREFIX}<first assistant chars>..."
        expect = "<|ASSISTANT|>" + ASSIST_PREFIX
        ok = expect in window
        if not ok: boundary_mismatch += 1

        # ruts / short length
        asst_len = (e_tok - a_tok)
        short_asst += int(asst_len <= 3)
        if RUT_PAT.search(window): rut_hits += 1

        if i and i % 5000 == 0:
            rng.set_postfix_str(f"boundary_mismatch={boundary_mismatch}, short<=3={short_asst}, ruts={rut_hits}")

    print(f"[AUDIT*] rows checked={take}/{N}")
    print(f"  bad_order={bad_order}  multi_or_zero_END={bad_multi_end}  mask_after_END={bad_mask_after}")
    print(f"  boundary_mismatch (tag+prefix not visible near start) = {boundary_mismatch}")
    print(f"  assistant_len<=3: {short_asst} ({short_asst/take:.2%}), rut_hits: {rut_hits} ({rut_hits/take:.2%})")


# Each JSONL to NPY

we're going to do the following: <br/>
1. convert each jsonl to npy
2. strip those npy by ratio
    - chat:math:science = 7:2:1 (+ entire info datasets)
    - those ratio are for assistant tokens that are being masked
    - we split by assistant tokens since that's the crutial one since our current babbling model will be tuned by learning those assistant tokens
3. concatinate those npy files and masked files

# Convert each JSONL to NPY

In [None]:
# each JSONL to its own NPY pair (input_ids + loss_mask)
# with strict filtering (skip rows that don't fit exactly)
from pathlib import Path
import json, numpy as np, re, os
try:
    from tqdm.auto import tqdm as TQDM
except Exception:
    from tqdm import tqdm as TQDM

def iter_lines_hybrid(p: Path, desc="READ"):
    total_bytes = os.stat(p).st_size
    line_count = 0
    with p.open("rb") as fbin, TQDM(total=total_bytes, unit="B", unit_scale=True,
                                    desc=f"[{desc}] {p.name}", leave=False,
                                    bar_format=BARFMT, **TQDM_KW) as t:
        for raw in fbin:
            t.update(len(raw))
            line = raw.decode("utf-8", "ignore")
            if line.strip():
                line_count += 1
                if line_count % 1000 == 0:
                    t.set_postfix_str(f"lines={line_count:,}")
            yield line

def _label_from_filename(path: Path) -> str:
    name = path.name.lower()
    # pull dataset key before _train/_valid
    m = re.match(r"^(.*)_(train|valid)(?:\.\w+)?$", name)
    key = m.group(1) if m else path.stem.lower()

    # explicit keys
    if key == "info" or "info" in key:
        return "info"

    # chat datasets
    if key in {"ultrachat", "openorca", "dolly", "oasst1"}:
        return "chat"

    # math datasets
    if key in {"gsm8k", "hendrycks_math", "numinamath", "svamp"}:
        return "math"

    # science datasets
    if key in {"sciq", "ai2"}:
        return "science"

    # heuristic fallbacks (keep your originals)
    if "chat" in name:    return "chat"
    if "math" in name:    return "math"
    if "science" in name: return "science"
    if re.search(r"\b(science|physics|chem|bio)\b", name): return "science"
    if re.search(r"\b(math|algebra|calc|equation)\b", name): return "math"
    return "chat"


def build_npys_per_file_with_labels(src_dir: Path, out_dir: Path, *, system_text: str = SYSTEM_PROMPT_DEFAULT, rng_seed: int = 1234):
    out_dir.mkdir(parents=True, exist_ok=True)
    files = sorted([p for p in Path(src_dir).glob("*.jsonl")])
    if not files:
        raise FileNotFoundError(f"No .jsonl in {src_dir}")

    meta = []
    file_iter = TQDM(files, desc=f"[SCAN] {src_dir.name}", unit="file",
                    total=len(files), bar_format=BARFMT, **TQDM_KW)


    for p in file_iter:
        ids_buf, mask_buf, kept_stats = [], [], []
        dropped_scaffold = dropped_over = dropped_other = 0
        kept = 0

        for ln in iter_lines_hybrid(p, desc="READ"):
            ln = ln.strip()
            if not ln:
                continue
            try:
                row = json.loads(ln)
            except Exception:
                continue

            info = (row.get("info") or "").strip()
            user = (row.get("user") or "").strip()
            asst = (row.get("assistant") or "").strip()
            if not user or not asst:
                continue

            packed = row_to_ids_and_mask(system_text, info, user, asst)
            if packed is None:
                if RUT_PAT.search(asst or ""): dropped_scaffold += 1
                else: dropped_over += 1
                continue

            ids, msk, st = packed
            if len(ids) != MAX_LEN or len(msk) != MAX_LEN:
                dropped_other += 1; continue

            ids_buf.append(ids); mask_buf.append(msk); kept_stats.append(st); kept += 1

        if kept == 0:
            TQDM.write(f"[SKIP] {p.name}: kept 0 rows after strict filtering.")
            continue

        ids_arr  = np.stack(ids_buf,  axis=0)
        mask_arr = np.stack(mask_buf, axis=0)

        # deterministic per-file shuffle
        rng = np.random.default_rng(rng_seed)
        perm = rng.permutation(ids_arr.shape[0])
        ids_arr, mask_arr = ids_arr[perm], mask_arr[perm]

        lbl = _label_from_filename(p)
        labels_arr = np.array([lbl]*ids_arr.shape[0], dtype=object)

        out_ids   = out_dir / f"{p.stem}.input_ids.npy"
        out_mask  = out_dir / f"{p.stem}.loss_mask.npy"
        out_lbls  = out_dir / f"{p.stem}.labels.npy"
        np.save(out_ids,  ids_arr)
        np.save(out_mask, mask_arr)
        np.save(out_lbls, labels_arr)

        asst_lens = np.array([s["assistant_len"] for s in kept_stats], dtype=np.int32)
        ones_pct  = float(mask_arr.sum()) / mask_arr.size
        TQDM.write(
            f"[WRITE] {p.name:25s} → kept={kept:,}  scaf_drop={dropped_scaffold:,}  over_drop={dropped_over:,}  other={dropped_other:,}  "
            f"p50={np.percentile(asst_lens,50):.1f}  ones%={ones_pct:.2%}  out=({out_ids.name}, {out_mask.name}, {out_lbls.name})"
        )

        meta.append(dict(file=p, label=lbl, out_ids=out_ids, out_mask=out_mask, out_labels=out_lbls, kept=kept))
    return meta


# Concatinate NPY files

In [None]:
def _gather_label_pool(perfile_meta: list[dict], want_label: str):
    ids_list, msk_list = [], []
    for m in perfile_meta:
        if m["label"] != want_label:
            continue
        ids = np.load(m["out_ids"],  mmap_mode="r")
        msk = np.load(m["out_mask"], mmap_mode="r")
        ids_list.append(ids); msk_list.append(msk)
    if not ids_list:
        return np.zeros((0, MAX_LEN), dtype=np.int32), np.zeros((0, MAX_LEN), dtype=np.uint8)
    return np.concatenate(ids_list, axis=0), np.concatenate(msk_list, axis=0)

# sort by decending assistant length and take until ~target_tokens
# deterministic via seed
# returns indices into ids/msk
def _select_by_token_budget(ids: np.ndarray, msk: np.ndarray, target_tokens: int, *, seed: int) -> np.ndarray:
    if ids.shape[0] == 0 or target_tokens <= 0:
        return np.array([], dtype=np.int64)
    rng = np.random.default_rng(seed)
    tok = msk.sum(axis=1) - 1
    order = np.argsort(-tok) # long answers first (fewer examples needed to hit budget)
    # small randomization within same lengths to avoid bias
    uniq_lens = {}
    for idx in order:
        L = int(tok[idx])
        uniq_lens.setdefault(L, []).append(idx)
    shuffled = []
    for L, arr in uniq_lens.items():
        a = np.array(arr, dtype=np.int64)
        rng.shuffle(a)
        shuffled.append(a)
    order = np.concatenate(shuffled, axis=0) if shuffled else np.array([], dtype=np.int64)

    picked, acc = [], 0
    for i in order:
        L = int(tok[i])
        if acc + L > target_tokens and acc > 0:
            break
        picked.append(i); acc += L
        if acc >= target_tokens:
            break
    return np.array(picked, dtype=np.int64)

def build_final_four_from_perfile(
    train_meta: list[dict],
    valid_meta: list[dict],
    *,
    ratios={"chat":7, "math":2, "science":1}, # assistant token ratios
    keep_all_info_train=True,
    valid_fraction_noninfo=0.10,
    seed=2345,
    out_dir=Path("../final_npy"),
):
    out_dir.mkdir(parents=True, exist_ok=True)

    # pools per split
    def _split_from_meta(meta, split_name):
        # keep all info rows
        info_ids, info_msk = _gather_label_pool(meta, "info")

        # build non-info pools
        chat_ids, chat_msk = _gather_label_pool(meta, "chat")
        math_ids, math_msk = _gather_label_pool(meta, "math")
        sci_ids,  sci_msk  = _gather_label_pool(meta, "science")

        # total non-info tokens available (assistant tokens)
        def total_tok(msk):
            return int((msk.sum(axis=1) - 1).clip(min=0).sum())

        avail = {
            "chat": total_tok(chat_msk),
            "math": total_tok(math_msk),
            "science": total_tok(sci_msk),
        }
        TQDM.write(f"[{split_name}] non-info token availability: {avail}")

        # budget for VALID (non-info only), TRAIN gets the rest (+ all info if train)
        if split_name == "valid":
            target_total_noninfo = int(round(sum(avail.values()) * valid_fraction_noninfo))
        else:
            # TRAIN will be built implicitly as "all non-info minus VALID selection"
            target_total_noninfo = None  # not used

        # compute per-label VALID token targets via ratios, capped by availability
        def targets_for_valid(avail_tokens, ratios):
            if target_total_noninfo is None or target_total_noninfo <= 0:
                return {"chat":0,"math":0,"science":0}
            rsum = sum(ratios.values())
            raw = {k: int(round(target_total_noninfo * (ratios[k]/rsum))) for k in ratios}
            # clamp by availability
            for k in raw:
                raw[k] = min(raw[k], avail_tokens[k])
            # fix rounding to match total
            diff = target_total_noninfo - sum(raw.values())
            if diff != 0:
                order = sorted(ratios.keys(), key=lambda k: -ratios[k])
                j = 0
                while diff != 0 and order:
                    k = order[j % len(order)]
                    cap = avail_tokens[k]
                    if diff > 0 and raw[k] < cap:
                        raw[k] += 1; diff -= 1
                    elif diff < 0 and raw[k] > 0:
                        raw[k] -= 1; diff += 1
                    j += 1
            return raw

        valid_targets = targets_for_valid(avail, ratios)
        if split_name == "valid":
            TQDM.write(f"[{split_name}] token targets (chat:math:science) = {valid_targets}")

        # choose VALID rows (non-info) by token budgets
        if split_name == "valid":
            vidx_chat = _select_by_token_budget(chat_ids, chat_msk, valid_targets["chat"], seed=seed+1)
            vidx_math = _select_by_token_budget(math_ids, math_msk, valid_targets["math"], seed=seed+2)
            vidx_sci  = _select_by_token_budget(sci_ids,  sci_msk,  valid_targets["science"], seed=seed+3)

            valid_ids = np.concatenate([
                chat_ids[vidx_chat], math_ids[vidx_math], sci_ids[vidx_sci]
            ], axis=0)
            valid_msk = np.concatenate([
                chat_msk[vidx_chat], math_msk[vidx_math], sci_msk[vidx_sci]
            ], axis=0)

            # TRAIN gets remaining + all info
            rem_chat = np.setdiff1d(np.arange(chat_ids.shape[0]), vidx_chat, assume_unique=False)
            rem_math = np.setdiff1d(np.arange(math_ids.shape[0]), vidx_math, assume_unique=False)
            rem_sci  = np.setdiff1d(np.arange(sci_ids.shape[0]),  vidx_sci,  assume_unique=False)

            train_ids_noninfo = np.concatenate([chat_ids[rem_chat], math_ids[rem_math], sci_ids[rem_sci]], axis=0)
            train_msk_noninfo = np.concatenate([chat_msk[rem_chat], math_msk[rem_math], sci_msk[rem_sci]], axis=0)

            # true by default since I intended to include all info rows
            if keep_all_info_train:
                train_ids = np.concatenate([info_ids, train_ids_noninfo], axis=0)
                train_msk = np.concatenate([info_msk, train_msk_noninfo], axis=0)
            else:
                train_ids, train_msk = train_ids_noninfo, train_msk_noninfo

            # final shuffle
            rng = np.random.default_rng(seed+42)
            if train_ids.shape[0]:
                perm = rng.permutation(train_ids.shape[0])
                train_ids, train_msk = train_ids[perm], train_msk[perm]
            if valid_ids.shape[0]:
                perm = rng.permutation(valid_ids.shape[0])
                valid_ids, valid_msk = valid_ids[perm], valid_msk[perm]

            return train_ids, train_msk, valid_ids, valid_msk

        else:
            # if we’re building TRAIN directly (without knowing VALID), just keep all (plus info),
            # but we actually call this function only for VALID split above and reuse the remainder as TRAIN.
            raise RuntimeError("Internal: call this with split_name='valid' only.")

    # build VALID from VALID meta, then TRAIN from TRAIN meta by mirroring the same process:
    # for TRAIN/VALID as two independent directories, we need to:
    #  - compute VALID selection from VALID directory (its own files)
    #  - compute TRAIN selection from TRAIN directory (its own files), but we still need the 7:2:1
    #    on the TRAIN side. we're gonna keep all info_train.jsonl and do the ratio on *non-info*; we'll do that here by pretending
    #    valid_fraction_noninfo applies within each split directory independently.

    # VALID set
    TQDM.write("▶ Building VALID (token-ratio on non-info; no info forced)")
    # for VALID we don't keep info; we just do the selection directly from its own meta
    # we call the helper that expects 'valid' to return both train and valid for that directory; we only need valid_* here
    _train_ids_dummy, _train_msk_dummy, valid_ids, valid_msk = _split_from_meta(valid_meta, split_name="valid")

    # TRAIN set
    TQDM.write("▶ Building TRAIN (token-ratio on non-info; keep all info_* from train_meta)")
    train_ids, train_msk, _valid_ids_dummy, _valid_msk_dummy = _split_from_meta(train_meta, split_name="valid")

    # save four files
    np.save(out_dir / "train_input_ids.npy",  train_ids)
    np.save(out_dir / "train_loss_mask.npy",  train_msk)
    np.save(out_dir / "valid_input_ids.npy",  valid_ids)
    np.save(out_dir / "valid_loss_mask.npy",  valid_msk)

    # small report
    def _tok_sum(msk): return int((msk.sum(axis=1) - 1).clip(min=0).sum())
    TQDM.write(f"[SAVE] TRAIN rows={train_ids.shape[0]:,}  tokens={_tok_sum(train_msk):,}")
    TQDM.write(f"[SAVE] VALID rows={valid_ids.shape[0]:,}  tokens={_tok_sum(valid_msk):,}")

    return (out_dir / "train_input_ids.npy", out_dir / "train_loss_mask.npy",
            out_dir / "valid_input_ids.npy", out_dir / "valid_loss_mask.npy")


# Orchestrate

In [9]:
TRAIN_JSONL_DIR = Path("./train_jsonl")
VALID_JSONL_DIR = Path("./valid_jsonl")

PERFILE_TRAIN = Path("../final_npy/perfile_train")
PERFILE_VALID = Path("../final_npy/perfile_valid")
FINAL_DIR     = Path("../final_npy")

print("▶ Per-file build: TRAIN")
train_meta = build_npys_per_file_with_labels(TRAIN_JSONL_DIR, PERFILE_TRAIN, system_text=SYSTEM_PROMPT_DEFAULT, rng_seed=2345)

print("▶ Per-file build: VALID")
valid_meta = build_npys_per_file_with_labels(VALID_JSONL_DIR, PERFILE_VALID, system_text=SYSTEM_PROMPT_DEFAULT, rng_seed=3456)

print("▶ Token-ratio combine (7:2:1 by assistant tokens; keep all info_train)")
train_ids_p, train_msk_p, valid_ids_p, valid_msk_p = build_final_four_from_perfile(
    train_meta, valid_meta,
    ratios={"chat":7, "math":2, "science":1},
    keep_all_info_train=True,
    valid_fraction_noninfo=0.10,
    seed=2345,
    out_dir=FINAL_DIR,
)


▶ Per-file build: TRAIN


[SCAN] train_jsonl:   0%|           | 0/11 [00:00<?]

[READ] ai2_train.jsonl:   0%|           | 0.00/564k [00:00<?]

[WRITE] ai2_train.jsonl           → kept=3,244  scaf_drop=0  over_drop=0  other=0  p50=5.0  ones%=1.45%  out=(ai2_train.input_ids.npy, ai2_train.loss_mask.npy, ai2_train.labels.npy)


[READ] dolly_train.jsonl:   0%|           | 0.00/11.7M [00:00<?]

[WRITE] dolly_train.jsonl         → kept=12,658  scaf_drop=12  over_drop=1,576  other=0  p50=44.0  ones%=13.19%  out=(dolly_train.input_ids.npy, dolly_train.loss_mask.npy, dolly_train.labels.npy)


[READ] gsm8k_train.jsonl:   0%|           | 0.00/5.40M [00:00<?]

[WRITE] gsm8k_train.jsonl         → kept=7,299  scaf_drop=1  over_drop=173  other=0  p50=130.0  ones%=27.44%  out=(gsm8k_train.input_ids.npy, gsm8k_train.loss_mask.npy, gsm8k_train.labels.npy)


[READ] hendrycks_math_train.jsonl:   0%|           | 0.00/3.45M [00:00<?]

[WRITE] hendrycks_math_train.jsonl → kept=3,644  scaf_drop=0  over_drop=1,035  other=0  p50=140.0  ones%=29.63%  out=(hendrycks_math_train.input_ids.npy, hendrycks_math_train.loss_mask.npy, hendrycks_math_train.labels.npy)


[READ] info_train.jsonl:   0%|           | 0.00/1.01M [00:00<?]

[WRITE] info_train.jsonl          → kept=3,000  scaf_drop=0  over_drop=0  other=0  p50=7.0  ones%=2.01%  out=(info_train.input_ids.npy, info_train.loss_mask.npy, info_train.labels.npy)


[READ] numinamath_train.jsonl:   0%|           | 0.00/1.30G [00:00<?]

[WRITE] numinamath_train.jsonl    → kept=308,895  scaf_drop=23,381  over_drop=520,466  other=0  p50=208.0  ones%=40.03%  out=(numinamath_train.input_ids.npy, numinamath_train.loss_mask.npy, numinamath_train.labels.npy)


[READ] oasst1_train.jsonl:   0%|           | 0.00/3.89M [00:00<?]

[WRITE] oasst1_train.jsonl        → kept=2,750  scaf_drop=18  over_drop=714  other=0  p50=143.0  ones%=30.26%  out=(oasst1_train.input_ids.npy, oasst1_train.loss_mask.npy, oasst1_train.labels.npy)


[READ] openorca_train.jsonl:   0%|           | 0.00/1.58G [00:00<?]

[WRITE] openorca_train.jsonl      → kept=499,315  scaf_drop=112,307  over_drop=323,909  other=0  p50=70.0  ones%=18.26%  out=(openorca_train.input_ids.npy, openorca_train.loss_mask.npy, openorca_train.labels.npy)


[READ] sciq_train.jsonl:   0%|           | 0.00/6.32M [00:00<?]

[WRITE] sciq_train.jsonl          → kept=10,058  scaf_drop=1  over_drop=422  other=0  p50=80.0  ones%=19.51%  out=(sciq_train.input_ids.npy, sciq_train.loss_mask.npy, sciq_train.labels.npy)


[READ] svamp_train.jsonl:   0%|           | 0.00/154k [00:00<?]

[WRITE] svamp_train.jsonl         → kept=700  scaf_drop=0  over_drop=0  other=0  p50=13.0  ones%=3.07%  out=(svamp_train.input_ids.npy, svamp_train.loss_mask.npy, svamp_train.labels.npy)


[READ] ultrachat_train.jsonl:   0%|           | 0.00/1.21G [00:00<?]

[WRITE] ultrachat_train.jsonl     → kept=377,249  scaf_drop=6,478  over_drop=274,036  other=0  p50=222.0  ones%=40.24%  out=(ultrachat_train.input_ids.npy, ultrachat_train.loss_mask.npy, ultrachat_train.labels.npy)
▶ Per-file build: VALID


[SCAN] valid_jsonl:   0%|           | 0/11 [00:00<?]

[READ] ai2_valid.jsonl:   0%|           | 0.00/150k [00:00<?]

[WRITE] ai2_valid.jsonl           → kept=844  scaf_drop=0  over_drop=0  other=0  p50=5.0  ones%=1.43%  out=(ai2_valid.input_ids.npy, ai2_valid.loss_mask.npy, ai2_valid.labels.npy)


[READ] dolly_valid.jsonl:   0%|           | 0.00/629k [00:00<?]

[WRITE] dolly_valid.jsonl         → kept=667  scaf_drop=1  over_drop=82  other=0  p50=44.0  ones%=13.32%  out=(dolly_valid.input_ids.npy, dolly_valid.loss_mask.npy, dolly_valid.labels.npy)


[READ] gsm8k_valid.jsonl:   0%|           | 0.00/972k [00:00<?]

[WRITE] gsm8k_valid.jsonl         → kept=1,288  scaf_drop=0  over_drop=31  other=0  p50=135.0  ones%=28.17%  out=(gsm8k_valid.input_ids.npy, gsm8k_valid.loss_mask.npy, gsm8k_valid.labels.npy)


[READ] hendrycks_math_valid.jsonl:   0%|           | 0.00/2.19M [00:00<?]

[WRITE] hendrycks_math_valid.jsonl → kept=2,435  scaf_drop=1  over_drop=668  other=0  p50=138.0  ones%=29.00%  out=(hendrycks_math_valid.input_ids.npy, hendrycks_math_valid.loss_mask.npy, hendrycks_math_valid.labels.npy)


[READ] info_valid.jsonl:   0%|           | 0.00/33.2k [00:00<?]

[WRITE] info_valid.jsonl          → kept=100  scaf_drop=0  over_drop=0  other=0  p50=7.0  ones%=1.99%  out=(info_valid.input_ids.npy, info_valid.loss_mask.npy, info_valid.labels.npy)


[READ] numinamath_valid.jsonl:   0%|           | 0.00/148k [00:00<?]

[WRITE] numinamath_valid.jsonl    → kept=42  scaf_drop=4  over_drop=54  other=0  p50=217.0  ones%=41.90%  out=(numinamath_valid.input_ids.npy, numinamath_valid.loss_mask.npy, numinamath_valid.labels.npy)


[READ] oasst1_valid.jsonl:   0%|           | 0.00/204k [00:00<?]

[WRITE] oasst1_valid.jsonl        → kept=152  scaf_drop=1  over_drop=35  other=0  p50=129.5  ones%=29.72%  out=(oasst1_valid.input_ids.npy, oasst1_valid.loss_mask.npy, oasst1_valid.labels.npy)


[READ] openorca_valid.jsonl:   0%|           | 0.00/82.9M [00:00<?]

[WRITE] openorca_valid.jsonl      → kept=26,335  scaf_drop=5,844  over_drop=17,060  other=0  p50=71.0  ones%=18.31%  out=(openorca_valid.input_ids.npy, openorca_valid.loss_mask.npy, openorca_valid.labels.npy)


[READ] sciq_valid.jsonl:   0%|           | 0.00/533k [00:00<?]

[WRITE] sciq_valid.jsonl          → kept=851  scaf_drop=1  over_drop=35  other=0  p50=82.0  ones%=19.40%  out=(sciq_valid.input_ids.npy, sciq_valid.loss_mask.npy, sciq_valid.labels.npy)


[READ] svamp_valid.jsonl:   0%|           | 0.00/66.0k [00:00<?]

[WRITE] svamp_valid.jsonl         → kept=300  scaf_drop=0  over_drop=0  other=0  p50=13.0  ones%=3.17%  out=(svamp_valid.input_ids.npy, svamp_valid.loss_mask.npy, svamp_valid.labels.npy)


[READ] ultrachat_valid.jsonl:   0%|           | 0.00/134M [00:00<?]

[WRITE] ultrachat_valid.jsonl     → kept=42,461  scaf_drop=721  over_drop=29,967  other=0  p50=221.0  ones%=40.11%  out=(ultrachat_valid.input_ids.npy, ultrachat_valid.loss_mask.npy, ultrachat_valid.labels.npy)
▶ Token-ratio combine (7:2:1 by assistant tokens; keep all info_train)
▶ Building VALID (token-ratio on non-info; no info forced)
[valid] non-info token availability: {'chat': 11186901, 'math': 557155, 'science': 89024}
[valid] token targets (chat:math:science) = {'chat': 842969, 'math': 251315, 'science': 89024}
▶ Building TRAIN (token-ratio on non-info; keep all info_* from train_meta)
[valid] non-info token availability: {'chat': 124801624, 'math': 64581301, 'science': 1015204}
[valid] token targets (chat:math:science) = {'chat': 13772258, 'math': 4252351, 'science': 1015204}
[SAVE] TRAIN rows=1,161,960  tokens=171,386,497
[SAVE] VALID rows=5,209  tokens=1,182,854


wait... science is 0... <- This is fixed

# NPY sanity check

In [None]:
from pathlib import Path
import numpy as np

def _nfmt(x): 
    return f"{int(x):,}"

def sanity_check_pair(ids_path: Path, mask_path: Path, *, name: str):
    print(f"\n=== {name} ===")
    ids = np.load(ids_path,  mmap_mode="r")
    msk = np.load(mask_path, mmap_mode="r")

    # shapes & dtypes
    N, T = ids.shape
    print(f"[shape] ids={ids.shape} mask={msk.shape}")
    assert ids.shape == msk.shape, "ids/mask shape mismatch"
    assert T == MAX_LEN, f"seq len must be {MAX_LEN}"
    assert ids.dtype in (np.int32, np.int64), f"ids dtype unexpected: {ids.dtype}"
    assert msk.dtype in (np.uint8, np.int8, np.int32), f"mask dtype unexpected: {msk.dtype}"

    # mask is strictly 0/1
    uniq = np.unique(msk)
    assert set(uniq.tolist()) <= {0,1}, f"mask has values beyond 0/1: {uniq}"

    # last non-PAD must be END; exactly one END per row
    # (fast vectorized: find last non-PAD index, check it's END_ID)
    nonpad = (ids != PAD_ID)
    # guard for fully-padded (this should never happen...)
    fully_pad = (~nonpad.any(axis=1))
    assert not fully_pad.any(), "some rows are fully PAD"

    last_idx = ids.shape[1] - 1 - np.argmax(nonpad[:, ::-1], axis=1)
    last_tok = ids[np.arange(N), last_idx]
    only_one_end = (ids == END_ID).sum(axis=1) == 1
    end_is_last  = (last_tok == END_ID)
    ok_end = only_one_end & end_is_last
    bad_rows = (~ok_end).sum()
    print(f"[END token] ok={_nfmt(ok_end.sum())}/{_nfmt(N)} | only_one={_nfmt(only_one_end.sum())} | end_is_last={_nfmt(end_is_last.sum())}")
    assert bad_rows == 0, f"{bad_rows} rows fail END checks"

    # mask must be zero strictly after END
    tail_mask_sum = msk[np.arange(N)[:,None], np.clip(last_idx+1, 0, T-1)[:,None]].sum()  # just a touch
    # vector check: any 1s after last_idx?
    after_end_any = (msk * (np.arange(T)[None,:] > last_idx[:,None])).any(axis=1)
    assert (~after_end_any).all(), f"{after_end_any.sum()} rows have mask after END"
    print("[mask] no ones after END ✔")

    # assistant length stats from mask (mask counting includes END)
    asst_len_incl_end = msk.sum(axis=1)
    asst_len_excl_end = asst_len_incl_end - 1
    p50 = float(np.percentile(asst_len_excl_end, 50))
    p75 = float(np.percentile(asst_len_excl_end, 75))
    p90 = float(np.percentile(asst_len_excl_end, 90))
    short = (asst_len_excl_end <= 3).mean()
    print(f"[assistant span] p50={p50:.1f}  p75={p75:.1f}  p90={p90:.1f}  <=3 tokens={short:.2%}")

    # quick pad locality: PADs should only appear at the end
    # (heuristic: find first PAD; ensure all later tokens are PAD)
    first_pad = np.argmax((ids == PAD_ID), axis=1)  # returns 0 when first is PAD; handle no-PAD rows:
    no_pad_rows = ~ (ids == PAD_ID).any(axis=1)
    if no_pad_rows.any():
        first_pad[no_pad_rows] = T  # sentinel: no PADs
    pad_tail_ok = (ids == PAD_ID).sum(axis=1) == (T - first_pad)
    assert pad_tail_ok.all(), f"{(~pad_tail_ok).sum()} rows have PAD(s) before END region"
    print("[padding] PAD only in tail ✔")

    # deep audit
    audit_npys(ids_path, mask_path, sample_k=N)

    print(f"=== {name} OK ===")

# run & create files
train_ids_p = Path("../final_npy/train_input_ids.npy")
train_msk_p = Path("../final_npy/train_loss_mask.npy")
valid_ids_p = Path("../final_npy/valid_input_ids.npy")
valid_msk_p = Path("../final_npy/valid_loss_mask.npy")

sanity_check_pair(train_ids_p, train_msk_p, name="TRAIN")
sanity_check_pair(valid_ids_p, valid_msk_p, name="VALID")



=== TRAIN ===
[shape] ids=(1161960, 512) mask=(1161960, 512)
[END token] ok=1,161,960/1,161,960 | only_one=1,161,960 | end_is_last=1,161,960
[mask] no ones after END ✔
[assistant span] p50=141.0  p75=230.0  p90=281.0  <=3 tokens=2.27%
[padding] PAD only in tail ✔


[AUDIT*] train_input_ids.npy:   0%|           | 0/1161960 [00:00<?]

[AUDIT*] rows checked=1161960/1161960
  bad_order=0  multi_or_zero_END=0  mask_after_END=0
  boundary_mismatch (tag+prefix not visible near start) = 0
  assistant_len<=3: 26387 (2.27%), rut_hits: 72 (0.01%)
=== TRAIN OK ===

=== VALID ===
[shape] ids=(5209, 512) mask=(5209, 512)
[END token] ok=5,209/5,209 | only_one=5,209 | end_is_last=5,209
[mask] no ones after END ✔
[assistant span] p50=284.0  p75=343.0  p90=352.0  <=3 tokens=4.40%
[padding] PAD only in tail ✔


[AUDIT*] valid_input_ids.npy:   0%|           | 0/5209 [00:00<?]

[AUDIT*] rows checked=5209/5209
  bad_order=0  multi_or_zero_END=0  mask_after_END=0
  boundary_mismatch (tag+prefix not visible near start) = 0
  assistant_len<=3: 229 (4.40%), rut_hits: 0 (0.00%)
=== VALID OK ===
