In [1]:
# =========================================================
# Step 1 — Build CPC classification dataset (A–H, Y)
# - Input: folder of JSONL files with {"pn", "c": {...}, "cpc": [...]}
# - Output: CSVs with two columns: text,label  (Claim 1, single CPC section)
# =========================================================
from __future__ import annotations
import json, re, unicodedata, random, csv
from pathlib import Path
from collections import Counter
from typing import Dict, Iterable, Tuple, List
from tqdm.auto import tqdm

# -------- CONFIG (edit) --------
INPUT_DIR     = Path("../data/ep-b1-claims-cpc")  # ~204 JSONL files
OUT_DIR       = Path("../data/cpc_cls")           # will write train/val/test CSVs
SPLIT_RATIOS  = (0.90, 0.05, 0.05)               # train/val/test (test receives the per-label remainder)
SEED          = 42

# Cleaning options (aligned with tokenizer corpus builder)
NFKC          = True
MIN_LEN       = 20     # characters, measured after cleaning; 0 disables the bound
MAX_LEN       = 5000   # characters, measured after cleaning; 0 disables the bound
REPLACE_REFNUM = True   # replace (101), [0032], {12,14} → <REFNUM>

# Charset guards (byte-BPE friendly)
MIN_PRINTABLE_RATIO = 0.98
MAX_NONASCII_RATIO  = 0.0  # 0 = disable

# -------- Patterns (same spirit as earlier) --------
# A bracketed list of one or more reference numerals, each optionally followed by
# letters and prime marks, e.g. "(101)", "[0032]", "{12, 14a'}".
REF_PARENS = r"""[\(\[\{]\s*(?:\d+[A-Za-z]*[′'″]*)(?:\s*,\s*\d+[A-Za-z]*[′'″]*)*\s*[\)\]\}]"""
REF_REGEX  = re.compile(REF_PARENS)
WS         = re.compile(r"\s+")

# Printable ASCII (space .. "~") plus a handful of common typographic symbols.
# NOTE(review): every character in this set already satisfies str.isprintable(),
# so the `or ch.isprintable()` test in is_charset_ok subsumes it; the set is kept
# as an explicit allow-list in case non-printable extras are ever added.
PRINTABLE_SET = {chr(cp) for cp in range(32, 127)} | {
    "’", "“", "”", "–", "—", "·", "•", "°", "µ", "²",
    "³", "±", "≥", "≤", "½", "¼", "¾", "™", "®", "§",
}

# NOTE(review): in the logged run of this cell no document ended up labelled "Y".
# Keeping Y here therefore only causes documents that carry a Y tag next to a
# single main section (A–H) to be discarded as multi-section. Consider ignoring
# Y-codes before the uniqueness check if those documents should be kept.
VALID_SECTIONS = set("ABCDEFGHY")  # include Y if present

# -------- Helpers --------
def is_charset_ok(s: str) -> bool:
    """Return True when *s* passes the charset guards.

    - An empty string is rejected outright.
    - The share of characters that are in PRINTABLE_SET or satisfy
      ``str.isprintable()`` must be at least MIN_PRINTABLE_RATIO.
    - When MAX_NONASCII_RATIO is non-zero, the share of code points above 127
      must not exceed it (0 disables this check).
    """
    if not s:
        return False

    length = len(s)
    printable_share = sum(1 for ch in s if ch in PRINTABLE_SET or ch.isprintable()) / length
    if printable_share < MIN_PRINTABLE_RATIO:
        return False

    if not MAX_NONASCII_RATIO:
        # Non-ASCII guard disabled.
        return True

    nonascii_share = sum(1 for ch in s if ord(ch) > 127) / length
    return not (nonascii_share > MAX_NONASCII_RATIO)

def clean_text(t: str) -> str:
    """Normalise a claim string before it is written to the dataset.

    Steps, in order:
    1. strip the ends and turn newlines into spaces;
    2. optionally apply Unicode NFKC normalisation (NFKC flag);
    3. optionally replace bracketed reference numerals with " <REFNUM> " and
       merge runs of adjacent <REFNUM> tokens into one (REPLACE_REFNUM flag);
    4. collapse every whitespace run to a single space and strip again.
    """
    text = t.strip().replace("\n", " ")
    if NFKC:
        text = unicodedata.normalize("NFKC", text)
    if REPLACE_REFNUM:
        tagged = REF_REGEX.sub(" <REFNUM> ", text)
        text = re.sub(r"(?:\s*<REFNUM>\s*){2,}", " <REFNUM> ", tagged)
    return WS.sub(" ", text).strip()

def sections_from_cpc(codes: Iterable[str], valid: Iterable[str] | None = None) -> List[str]:
    """Map CPC symbols to their section letters, preserving order and duplicates.

    Args:
        codes: CPC symbols such as "A61K 9/00". ``None`` or an empty value
            yields ``[]``. A bare string is treated as ONE code; previously it
            was iterated character by character, which produced bogus sections
            from subclass letters. Non-string entries and blank entries are
            skipped, and surrounding whitespace is stripped before the section
            letter is read.
        valid: Section letters to keep. Defaults to VALID_SECTIONS.

    Returns:
        One upper-cased section letter per accepted code. Duplicates are kept;
        callers dedupe.
    """
    allowed = VALID_SECTIONS if valid is None else set(valid)
    if not codes:
        return []
    if isinstance(codes, str):
        codes = [codes]

    secs: List[str] = []
    for code in codes:
        if not isinstance(code, str):
            continue
        code = code.strip()
        if not code:
            continue
        section = code[0].upper()
        if section in allowed:
            secs.append(section)
    return secs

def _claim1_text(claims) -> str | None:
    """Pick the claim-1 text out of a record's claims container, or return None.

    Dict form: prefer the keys "1" and then "01". Otherwise fall back to the
    key that sorts first by (length, value), so "2" precedes "10". List form:
    take the first element. A non-string value yields None.
    """
    if isinstance(claims, dict):
        for key in ("1", "01"):
            value = claims.get(key)
            if isinstance(value, str):
                return value
        if claims:
            first_key = min(claims, key=lambda x: (len(x), x))
            value = claims.get(first_key)
            return value if isinstance(value, str) else None
        return None
    if isinstance(claims, list) and claims:
        return claims[0] if isinstance(claims[0], str) else None
    return None


def iter_claim1_and_sections(fp: Path) -> Iterable[Tuple[str, List[str]]]:
    """Yield ``(claim1_text, section_letters)`` for each usable record in *fp*.

    *fp* is read as UTF-8 JSONL, and undecodable bytes are dropped
    (``errors="ignore"``). The following lines are skipped:
    - lines that are not valid JSON;
    - lines whose top-level value is not a JSON object (previously these
      raised AttributeError on ``.get`` and aborted the run);
    - records without a non-empty claim-1 string.

    Section letters come from the record's "cpc" field via sections_from_cpc
    and may be empty or contain duplicates.
    """
    with fp.open("r", encoding="utf-8", errors="ignore") as fh:
        for line in fh:
            try:
                obj = json.loads(line)
            except json.JSONDecodeError:
                continue
            if not isinstance(obj, dict):
                continue
            c1 = _claim1_text(obj.get("c") or {})
            if not c1:
                continue
            yield c1, sections_from_cpc(obj.get("cpc"))

# -------- Build dataset --------
# Keep only documents whose CPC codes resolve to exactly one section letter;
# that letter becomes the label for the cleaned claim-1 text.
files = sorted(INPUT_DIR.glob("*.jsonl"))
if not files:
    # Path.glob() on a missing or mistyped directory simply yields nothing;
    # fail loudly instead of silently producing empty train/val/test CSVs.
    raise FileNotFoundError(f"No *.jsonl files found in {INPUT_DIR.resolve()}")

rows: List[Tuple[str,str]] = []
discard_reasons = Counter()

for fp in tqdm(files, desc="Files", unit="file"):
    for c1, secs in iter_claim1_and_sections(fp):
        # reduce to unique section letters; zero or several sections is
        # ambiguous for single-label classification, so drop the document
        uniq = sorted(set(secs))
        if len(uniq) != 1:
            discard_reasons["multi_or_none_sections"] += 1
            continue
        label = uniq[0]
        text = clean_text(c1)

        # Bounds are in characters after cleaning; a falsy MIN_LEN/MAX_LEN disables that bound.
        if (MIN_LEN and len(text) < MIN_LEN) or (MAX_LEN and len(text) > MAX_LEN):
            discard_reasons["length"] += 1
            continue
        if not is_charset_ok(text):
            discard_reasons["charset"] += 1
            continue

        rows.append((text, label))

print(f"Kept: {len(rows)}  | Discarded: {sum(discard_reasons.values())}  → {dict(discard_reasons)}")

# -------- Stratified split (by label) --------
# Each label is shuffled and cut on its own, so every split keeps roughly the
# overall class proportions. The global RNG is seeded, and the dict below keeps
# first-seen label order, which fixes the sequence of shuffle calls and makes
# the split reproducible for a given SEED.
random.seed(SEED)

by_label: Dict[str, List[Tuple[str,str]]] = {}
for row in rows:
    by_label.setdefault(row[1], []).append(row)

# test_r is not used directly: the test split takes whatever remains after the
# train and val cuts, so the three slices always partition each label's rows.
train_r, val_r, test_r = SPLIT_RATIOS
splits = {"train": [], "val": [], "test": []}

for bucket in by_label.values():
    random.shuffle(bucket)
    size = len(bucket)
    train_end = round(train_r * size)
    val_end = train_end + round(val_r * size)
    splits["train"] += bucket[:train_end]
    splits["val"] += bucket[train_end:val_end]
    splits["test"] += bucket[val_end:]

# Rows were appended label by label; shuffle so labels are interleaved.
for part in splits.values():
    random.shuffle(part)

# -------- Write CSVs --------
# One CSV per split with header "text,label"; the csv module handles quoting.
OUT_DIR.mkdir(parents=True, exist_ok=True)
for split_name in splits:
    split_rows = splits[split_name]
    out_csv = OUT_DIR.joinpath(f"cpc_cls_{split_name}.csv")
    with open(out_csv, "w", encoding="utf-8", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(["text", "label"])
        for record in split_rows:
            writer.writerow(record)
    print(f"Wrote {split_name}: {len(split_rows)} → {out_csv}")

# Label distribution report
for split_name, split_rows in splits.items():
    counts = Counter(label for _text, label in split_rows)
    print(f"{split_name} label counts:", dict(sorted(counts.items())))


Files:   0%|          | 0/204 [00:00<?, ?file/s]

Kept: 216284  | Discarded: 177958  → {'multi_or_none_sections': 147406, 'length': 30552}
Wrote train: 194656 → ../data/cpc_cls/cpc_cls_train.csv
Wrote val: 10814 → ../data/cpc_cls/cpc_cls_val.csv
Wrote test: 10814 → ../data/cpc_cls/cpc_cls_test.csv
train label counts: {'A': 41167, 'B': 34502, 'C': 14684, 'D': 2529, 'E': 7095, 'F': 15157, 'G': 36104, 'H': 43418}
val label counts: {'A': 2287, 'B': 1917, 'C': 816, 'D': 140, 'E': 394, 'F': 842, 'G': 2006, 'H': 2412}
test label counts: {'A': 2287, 'B': 1917, 'C': 816, 'D': 141, 'E': 394, 'F': 842, 'G': 2005, 'H': 2412}
