In [None]:
# ==============================
# patnroberta - Independent Claims Corpus Builder
# ==============================
# - Processes EPO JSONL files containing claims
# - Cleans text, normalizes reference numerals, dedupes
# - Produces train/val/test line-separated corpora with adjustable ratios
# ==============================

from __future__ import annotations
import re, json, unicodedata, string, random, math
from pathlib import Path
from tqdm.auto import tqdm
from typing import Iterable, Tuple, Dict, List

# ---------- Config ----------
# Directory whose *.jsonl files are read (direct children only, no recursion).
INPUT_FOLDER = "../data/0-scraped"
# Path prefix of the three output files; the suffixes are appended verbatim.
OUTPUT_PREFIX = "../data/1-corpus-ind-claims/corpus"   # will write *_train.txt, *_val.txt, *_test.txt

# Cleaning/normalization
REFS_MODE    = "replace"     # "keep" | "remove" | "replace"  (replace -> <REFNUM>)
ADD_EOS      = False         # RoBERTa does not need EOS in raw text
# With ALL_CLAIMS = False only the first string-valued claim of each record is used.
ALL_CLAIMS   = False         # True = keep all claims; False = only independent claim (first)
NFKC         = True          # normalize to NFKC
DEDUPE       = True          # deduplicate lines (normalized key)
# Both length limits are compared against the character count of the cleaned claim.
MIN_LEN      = 20            # 0 = disable
MAX_LEN      = 5000          # 0 = disable

# Charset guards (byte-BPE friendly): do not penalize non-ASCII
MIN_PRINTABLE_RATIO = 0.98   # drop lines with < this fraction of printable chars
# Maximum allowed fraction of characters with code point > 127.
MAX_NONASCII_RATIO  = 0.0    # 0 disables the non-ASCII cap

# Splits (must sum to 1.0). Examples: (0.98, 0.01, 0.01) or (0.90, 0.05, 0.05)
# The test split receives whatever is left after train and val have been cut.
TRAIN_RATIO, VAL_RATIO, TEST_RATIO = 0.98, 0.01, 0.01
RANDOM_SEED = 42             # for deterministic shuffling

# ---------- Patterns ----------
# Reference numerals like (101), [0032], {12, 14}, etc.
# One numeral = digits + optional letter suffix (e.g. 12a) + optional prime marks;
# several numerals may be comma-separated inside a single bracket pair.
# NOTE: the opening and closing bracket are matched independently, so mixed
# pairs such as "(12]" match as well, as does any other bracketed number.
REF_PARENS = r"""[\(\[\{]\s*(?:\d+[A-Za-z]*[′'″]*)(?:\s*,\s*\d+[A-Za-z]*[′'″]*)*\s*[\)\]\}]"""
REF_REGEX  = re.compile(REF_PARENS)

# whitespace collapse
WS = re.compile(r"\s+")

# digits collapse (for dedupe key)
# Every run of digits becomes a single "0", so claims that differ only in
# numbers share one dedupe key.
DIGITS = re.compile(r"\d+")

# Allow extra useful Unicode in patents
# NOTE(review): is_line_charset_ok also accepts any character for which
# str.isprintable() is true, so this set presumably only matters for characters
# that isprintable() rejects (e.g. the tab/newline controls in string.printable).
PRINTABLE_SET = set(string.printable) | {"’","“","”","–","—","·","•","°","µ","²","³","±","≥","≤","·","½","¼","¾","™","®","§"}

# ---------- Helpers ----------
def is_line_charset_ok(s: str) -> bool:
    """Tell whether *s* passes the printable-ratio and non-ASCII-ratio guards.

    A character counts as printable when it is in PRINTABLE_SET or when
    ``str.isprintable()`` accepts it.  The empty string is always rejected.
    The non-ASCII cap is only applied when MAX_NONASCII_RATIO is truthy.
    """
    length = len(s)
    if length == 0:
        return False

    ok_chars = 0
    for ch in s:
        if ch in PRINTABLE_SET or ch.isprintable():
            ok_chars += 1
    if ok_chars / length < MIN_PRINTABLE_RATIO:
        return False

    # A cap of 0 means "no cap": skip the non-ASCII count entirely.
    if not MAX_NONASCII_RATIO:
        return True
    high_chars = len([ch for ch in s if ord(ch) > 127])
    return not (high_chars / length > MAX_NONASCII_RATIO)

def process_ref_numerals(text: str, mode: str) -> str:
    """Apply the chosen reference-numeral policy to *text*.

    Args:
        text: Claim text to process.
        mode: ``"keep"`` returns the text untouched, ``"remove"`` swaps every
            REF_REGEX match for a space, ``"replace"`` swaps it for
            ``<REFNUM>`` and merges adjacent ``<REFNUM>`` tokens into one.

    Raises:
        ValueError: If *mode* is none of the three accepted values.
    """
    if mode not in ("keep", "remove", "replace"):
        raise ValueError("REFS_MODE must be 'keep', 'remove', or 'replace'")

    if mode == "keep":
        return text
    if mode == "remove":
        return REF_REGEX.sub(" ", text)

    # mode == "replace": tag first, then squash runs of two or more tags.
    tagged = REF_REGEX.sub(" <REFNUM> ", text)
    return re.sub(r"(?:\s*<REFNUM>\s*){2,}", " <REFNUM> ", tagged)

def clean_claim(text: str, refs_mode: str, add_eos: bool, nfkc: bool) -> str:
    """Turn one raw claim into a single normalized line.

    Steps, in order: trim and flatten newlines, optional NFKC normalization,
    reference-numeral handling via ``process_ref_numerals``, whitespace
    collapse, and optionally appending `` <EOS>`` when not already present.
    """
    flat = text.strip().replace("\n", " ")
    normalized = unicodedata.normalize("NFKC", flat) if nfkc else flat
    with_refs = process_ref_numerals(normalized, refs_mode)
    cleaned = WS.sub(" ", with_refs).strip()

    if not add_eos or cleaned.endswith("<EOS>"):
        return cleaned
    return f"{cleaned} <EOS>"

def dedupe_key(t: str) -> str:
    """Return the normalized fingerprint used to detect duplicate claims.

    The claim text itself is left alone; only the key is normalized:
    ``<REFNUM>`` tokens are blanked, every digit run becomes ``0``,
    whitespace is collapsed to single spaces, and the result is lowercased.
    """
    without_tags = t.replace("<REFNUM>", " ")
    tokens = DIGITS.sub("0", without_tags).split()
    return " ".join(tokens).lower()

def validate_splits(train_r: float, val_r: float, test_r: float) -> Tuple[int,int,int]:
    """Check the split ratios and return them as whole percentages for logging.

    Args:
        train_r: Fraction of lines for the training split.
        val_r: Fraction of lines for the validation split.
        test_r: Fraction of lines for the test split.

    Returns:
        ``(train_pct, val_pct, test_pct)`` rounded to the nearest integer.

    Raises:
        ValueError: If any ratio is negative or the ratios do not sum to 1.0
            (within 1e-8).
    """
    ratios = (train_r, val_r, test_r)
    if any(r < 0 for r in ratios):
        raise ValueError(f"Split ratios must be non-negative, got {ratios}")
    s = train_r + val_r + test_r
    if abs(s - 1.0) > 1e-8:
        raise ValueError(f"Split ratios must sum to 1.0, got {s}")
    # round() instead of int(): 0.29 * 100 is 28.999999999999996 in binary
    # floating point, which plain truncation would report as 28%.
    return int(round(train_r * 100)), int(round(val_r * 100)), int(round(test_r * 100))

# ---------- Main processing ----------
def iter_claims_from_file(fp: Path, all_claims: bool) -> Iterable[str]:
    """Yield raw claim strings from one JSONL file.

    Each line is parsed as JSON and its ``"c"`` field is read.  The field may
    be a dict (its values are used, in the dict's own order) or a list; any
    other type contributes nothing.  Non-string entries are skipped.

    Args:
        fp: Path of the JSONL file.  It is read as UTF-8 with undecodable
            bytes dropped.
        all_claims: When False, only the first string entry of each record is
            yielded; when True, every string entry is.

    Yields:
        Claim texts exactly as stored in the file.
    """
    with open(fp, "r", encoding="utf-8", errors="ignore") as fh:
        for line in fh:
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue  # blank or malformed line
            # A line can be valid JSON without being an object (null, a list,
            # a number, ...).  Such a value has no .get(), and the resulting
            # AttributeError would abort the whole corpus build.
            if not isinstance(data, dict):
                continue
            claims = data.get("c", {})
            if isinstance(claims, dict):
                values = claims.values()
            elif isinstance(claims, list):
                values = claims
            else:
                values = ()
            for claim_text in values:
                if not isinstance(claim_text, str):
                    continue
                yield claim_text
                if not all_claims:
                    break  # only the first (independent) claim

def build_corpus(in_dir: str) -> Tuple[List[str], Dict[str,int], int]:
    """Clean, filter and deduplicate every claim found under *in_dir*.

    Files matching ``*.jsonl`` directly inside *in_dir* are visited in sorted
    order.  Each claim is cleaned with the module-level settings and must then
    pass, in this order, the length guard, the charset guard and (when DEDUPE
    is on) the duplicate check.  The first failing guard decides which drop
    counter is incremented.

    Returns:
        ``(kept_lines, stats, number_of_files)``.  ``stats["json_errors"]`` is
        present in the dict but is never incremented by this function.
    """
    jsonl_files = sorted(Path(in_dir).glob("*.jsonl"))
    n_files = len(jsonl_files)

    counters = {
        "files_total": n_files,
        "lines_raw": 0,
        "kept": 0,
        "dupe": 0,
        "len_drop": 0,
        "charset_drop": 0,
        "json_errors": 0,  # placeholder: malformed lines are skipped upstream, not counted
    }
    corpus: List[str] = []
    fingerprints = set()

    with tqdm(total=n_files, desc="Files", unit="file") as progress:
        for path in jsonl_files:
            try:
                for raw in iter_claims_from_file(path, ALL_CLAIMS):
                    counters["lines_raw"] += 1
                    claim = clean_claim(raw, REFS_MODE, ADD_EOS, NFKC)
                    n_chars = len(claim)

                    # Length guard (a limit of 0 switches that side off).
                    long_enough = not MIN_LEN or n_chars >= MIN_LEN
                    short_enough = not MAX_LEN or n_chars <= MAX_LEN
                    if not (long_enough and short_enough):
                        counters["len_drop"] += 1
                        continue

                    # Charset guard.
                    if not is_line_charset_ok(claim):
                        counters["charset_drop"] += 1
                        continue

                    # Duplicate check: the set only grows for unseen keys.
                    if DEDUPE:
                        size_before = len(fingerprints)
                        fingerprints.add(dedupe_key(claim))
                        if len(fingerprints) == size_before:
                            counters["dupe"] += 1
                            continue

                    corpus.append(claim)
                    counters["kept"] += 1
            finally:
                # Advance the bar even if this file raised.
                progress.update(1)

    return corpus, counters, n_files

def write_splits(lines: List[str], out_prefix: str,
                 train_r: float, val_r: float, test_r: float,
                 seed: int = 42) -> Dict[str, object]:
    """Shuffle *lines* deterministically and write train/val/test text files.

    Args:
        lines: Corpus lines; the caller's list is left untouched.
        out_prefix: Path prefix; ``_train.txt``, ``_val.txt`` and ``_test.txt``
            are appended to it.  Missing parent directories are created.
        train_r: Fraction of lines for the training file.
        val_r: Fraction of lines for the validation file.
        test_r: Kept for a symmetric signature; the test file always receives
            whatever remains after train and val, so the three sizes sum to
            ``len(lines)``.
        seed: Seed of the private random generator used for shuffling.

    Returns:
        Dict with the line counts (``train``, ``val``, ``test``) and the paths
        written (``out_train``, ``out_val``, ``out_test``).
    """
    # Shuffle a copy: same permutation for a given seed, but the caller's
    # list keeps its order.
    shuffled = list(lines)
    random.Random(seed).shuffle(shuffled)

    n = len(shuffled)
    # Clamp so rounding (or an out-of-range ratio) can never make the
    # split sizes negative or larger than what is left.
    n_train = min(n, max(0, int(round(train_r * n))))
    n_val = min(n - n_train, max(0, int(round(val_r * n))))

    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]

    out_train = f"{out_prefix}_train.txt"
    out_val   = f"{out_prefix}_val.txt"
    out_test  = f"{out_prefix}_test.txt"

    # open(..., "w") does not create directories; make sure the target exists.
    Path(out_prefix).parent.mkdir(parents=True, exist_ok=True)

    for path, arr in [(out_train, train), (out_val, val), (out_test, test)]:
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(t + "\n" for t in arr)

    return {"train": len(train), "val": len(val), "test": len(test),
            "out_train": out_train, "out_val": out_val, "out_test": out_test}

# ---------- Run ----------
if __name__ == "__main__":
    # Check the ratios before the corpus build starts; the integer
    # percentages returned here are used only for the log line below.
    tr_p, va_p, te_p = validate_splits(TRAIN_RATIO, VAL_RATIO, TEST_RATIO)
    print(f"Split ratios → train {tr_p}%, val {va_p}%, test {te_p}%")

    # Read every input file, clean/filter/dedupe, and collect drop statistics.
    lines, stats, nfiles = build_corpus(INPUT_FOLDER)
    print("\n=== Stats ===")
    print(f"Files processed           : {nfiles}")
    print(f"Raw claims seen           : {stats['lines_raw']}")
    print(f"Kept (after cleaning/dedupe): {stats['kept']}")
    print(f"Dropped (length)          : {stats['len_drop']}")
    print(f"Dropped (charset)         : {stats['charset_drop']}")
    print(f"Duplicates removed        : {stats['dupe']}")

    # Shuffle with RANDOM_SEED, cut into three splits and write them to disk.
    split_info = write_splits(lines, OUTPUT_PREFIX,
                              TRAIN_RATIO, VAL_RATIO, TEST_RATIO,
                              seed=RANDOM_SEED)
    print("\n=== Outputs ===")
    print(f"Train lines: {split_info['train']} → {split_info['out_train']}")
    print(f"Val lines  : {split_info['val']}   → {split_info['out_val']}")
    print(f"Test lines : {split_info['test']}  → {split_info['out_test']}")


Split ratios → train 98%, val 1%, test 1%


Files:   0%|          | 0/204 [00:00<?, ?file/s]


=== Stats ===
Files processed           : 204
Raw claims seen           : 394242
Kept (after cleaning/dedupe): 344539
Dropped (length)          : 49504
Dropped (charset)         : 0
Duplicates removed        : 199

=== Outputs ===
Train lines: 337648 → ../data/1-corpus-ind-claims_train.txt
Val lines  : 3445   → ../data/1-corpus-ind-claims_val.txt
Test lines : 3446  → ../data/1-corpus-ind-claims_test.txt


# All claims

In [1]:
# ==============================
# patnroberta - All Claims Corpus Builder
# ==============================
# - Processes EPO JSONL files containing claims
# - Cleans text, normalizes reference numerals, dedupes
# - Produces train/val/test line-separated corpora with adjustable ratios
# ==============================

from __future__ import annotations
import re, json, unicodedata, string, random, math
from pathlib import Path
from tqdm.auto import tqdm
from typing import Iterable, Tuple, Dict, List

# ---------- Config ----------
# Directory whose *.jsonl files are read (direct children only, no recursion).
INPUT_FOLDER = "../data/0-scraped"
# Path prefix of the three output files; the suffixes are appended verbatim.
OUTPUT_PREFIX = "../data/corpus"   # will write *_train.txt, *_val.txt, *_test.txt

# Cleaning/normalization
REFS_MODE    = "replace"     # "keep" | "remove" | "replace"  (replace -> <REFNUM>)
ADD_EOS      = False         # RoBERTa does not need EOS in raw text
# With ALL_CLAIMS = True every string-valued claim of each record is used.
ALL_CLAIMS   = True          # True = keep all claims; False = only independent claim (first)
NFKC         = True          # normalize to NFKC
DEDUPE       = True          # deduplicate lines (normalized key)
# Both length limits are compared against the character count of the cleaned claim.
MIN_LEN      = 20            # 0 = disable
MAX_LEN      = 5000          # 0 = disable

# Charset guards (byte-BPE friendly): do not penalize non-ASCII
MIN_PRINTABLE_RATIO = 0.98   # drop lines with < this fraction of printable chars
# Maximum allowed fraction of characters with code point > 127.
MAX_NONASCII_RATIO  = 0.0    # 0 disables the non-ASCII cap

# Splits (must sum to 1.0). Examples: (0.98, 0.01, 0.01) or (0.90, 0.05, 0.05)
# The test split receives whatever is left after train and val have been cut.
TRAIN_RATIO, VAL_RATIO, TEST_RATIO = 0.98, 0.01, 0.01
RANDOM_SEED = 42             # for deterministic shuffling

# ---------- Patterns ----------
# Reference numerals like (101), [0032], {12, 14}, etc.
# One numeral = digits + optional letter suffix (e.g. 12a) + optional prime marks;
# several numerals may be comma-separated inside a single bracket pair.
# NOTE: the opening and closing bracket are matched independently, so mixed
# pairs such as "(12]" match as well, as does any other bracketed number.
REF_PARENS = r"""[\(\[\{]\s*(?:\d+[A-Za-z]*[′'″]*)(?:\s*,\s*\d+[A-Za-z]*[′'″]*)*\s*[\)\]\}]"""
REF_REGEX  = re.compile(REF_PARENS)

# whitespace collapse
WS = re.compile(r"\s+")

# digits collapse (for dedupe key)
# Every run of digits becomes a single "0", so claims that differ only in
# numbers share one dedupe key.
DIGITS = re.compile(r"\d+")

# Allow extra useful Unicode in patents
# NOTE(review): is_line_charset_ok also accepts any character for which
# str.isprintable() is true, so this set presumably only matters for characters
# that isprintable() rejects (e.g. the tab/newline controls in string.printable).
PRINTABLE_SET = set(string.printable) | {"’","“","”","–","—","·","•","°","µ","²","³","±","≥","≤","·","½","¼","¾","™","®","§"}

# ---------- Helpers ----------
def is_line_charset_ok(s: str) -> bool:
    """Tell whether *s* passes the printable-ratio and non-ASCII-ratio guards.

    A character counts as printable when it is in PRINTABLE_SET or when
    ``str.isprintable()`` accepts it.  The empty string is always rejected.
    The non-ASCII cap is only applied when MAX_NONASCII_RATIO is truthy.
    """
    length = len(s)
    if length == 0:
        return False

    ok_chars = 0
    for ch in s:
        if ch in PRINTABLE_SET or ch.isprintable():
            ok_chars += 1
    if ok_chars / length < MIN_PRINTABLE_RATIO:
        return False

    # A cap of 0 means "no cap": skip the non-ASCII count entirely.
    if not MAX_NONASCII_RATIO:
        return True
    high_chars = len([ch for ch in s if ord(ch) > 127])
    return not (high_chars / length > MAX_NONASCII_RATIO)

def process_ref_numerals(text: str, mode: str) -> str:
    """Apply the chosen reference-numeral policy to *text*.

    Args:
        text: Claim text to process.
        mode: ``"keep"`` returns the text untouched, ``"remove"`` swaps every
            REF_REGEX match for a space, ``"replace"`` swaps it for
            ``<REFNUM>`` and merges adjacent ``<REFNUM>`` tokens into one.

    Raises:
        ValueError: If *mode* is none of the three accepted values.
    """
    if mode not in ("keep", "remove", "replace"):
        raise ValueError("REFS_MODE must be 'keep', 'remove', or 'replace'")

    if mode == "keep":
        return text
    if mode == "remove":
        return REF_REGEX.sub(" ", text)

    # mode == "replace": tag first, then squash runs of two or more tags.
    tagged = REF_REGEX.sub(" <REFNUM> ", text)
    return re.sub(r"(?:\s*<REFNUM>\s*){2,}", " <REFNUM> ", tagged)

def clean_claim(text: str, refs_mode: str, add_eos: bool, nfkc: bool) -> str:
    """Turn one raw claim into a single normalized line.

    Steps, in order: trim and flatten newlines, optional NFKC normalization,
    reference-numeral handling via ``process_ref_numerals``, whitespace
    collapse, and optionally appending `` <EOS>`` when not already present.
    """
    flat = text.strip().replace("\n", " ")
    normalized = unicodedata.normalize("NFKC", flat) if nfkc else flat
    with_refs = process_ref_numerals(normalized, refs_mode)
    cleaned = WS.sub(" ", with_refs).strip()

    if not add_eos or cleaned.endswith("<EOS>"):
        return cleaned
    return f"{cleaned} <EOS>"

def dedupe_key(t: str) -> str:
    """Return the normalized fingerprint used to detect duplicate claims.

    The claim text itself is left alone; only the key is normalized:
    ``<REFNUM>`` tokens are blanked, every digit run becomes ``0``,
    whitespace is collapsed to single spaces, and the result is lowercased.
    """
    without_tags = t.replace("<REFNUM>", " ")
    tokens = DIGITS.sub("0", without_tags).split()
    return " ".join(tokens).lower()

def validate_splits(train_r: float, val_r: float, test_r: float) -> Tuple[int,int,int]:
    """Check the split ratios and return them as whole percentages for logging.

    Args:
        train_r: Fraction of lines for the training split.
        val_r: Fraction of lines for the validation split.
        test_r: Fraction of lines for the test split.

    Returns:
        ``(train_pct, val_pct, test_pct)`` rounded to the nearest integer.

    Raises:
        ValueError: If any ratio is negative or the ratios do not sum to 1.0
            (within 1e-8).
    """
    ratios = (train_r, val_r, test_r)
    if any(r < 0 for r in ratios):
        raise ValueError(f"Split ratios must be non-negative, got {ratios}")
    s = train_r + val_r + test_r
    if abs(s - 1.0) > 1e-8:
        raise ValueError(f"Split ratios must sum to 1.0, got {s}")
    # round() instead of int(): 0.29 * 100 is 28.999999999999996 in binary
    # floating point, which plain truncation would report as 28%.
    return int(round(train_r * 100)), int(round(val_r * 100)), int(round(test_r * 100))

# ---------- Main processing ----------
def iter_claims_from_file(fp: Path, all_claims: bool) -> Iterable[str]:
    """Yield raw claim strings from one JSONL file.

    Each line is parsed as JSON and its ``"c"`` field is read.  The field may
    be a dict (its values are used, in the dict's own order) or a list; any
    other type contributes nothing.  Non-string entries are skipped.

    Args:
        fp: Path of the JSONL file.  It is read as UTF-8 with undecodable
            bytes dropped.
        all_claims: When False, only the first string entry of each record is
            yielded; when True, every string entry is.

    Yields:
        Claim texts exactly as stored in the file.
    """
    with open(fp, "r", encoding="utf-8", errors="ignore") as fh:
        for line in fh:
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue  # blank or malformed line
            # A line can be valid JSON without being an object (null, a list,
            # a number, ...).  Such a value has no .get(), and the resulting
            # AttributeError would abort the whole corpus build.
            if not isinstance(data, dict):
                continue
            claims = data.get("c", {})
            if isinstance(claims, dict):
                values = claims.values()
            elif isinstance(claims, list):
                values = claims
            else:
                values = ()
            for claim_text in values:
                if not isinstance(claim_text, str):
                    continue
                yield claim_text
                if not all_claims:
                    break  # only the first (independent) claim

def build_corpus(in_dir: str) -> Tuple[List[str], Dict[str,int], int]:
    """Clean, filter and deduplicate every claim found under *in_dir*.

    Files matching ``*.jsonl`` directly inside *in_dir* are visited in sorted
    order.  Each claim is cleaned with the module-level settings and must then
    pass, in this order, the length guard, the charset guard and (when DEDUPE
    is on) the duplicate check.  The first failing guard decides which drop
    counter is incremented.

    Returns:
        ``(kept_lines, stats, number_of_files)``.  ``stats["json_errors"]`` is
        present in the dict but is never incremented by this function.
    """
    jsonl_files = sorted(Path(in_dir).glob("*.jsonl"))
    n_files = len(jsonl_files)

    counters = {
        "files_total": n_files,
        "lines_raw": 0,
        "kept": 0,
        "dupe": 0,
        "len_drop": 0,
        "charset_drop": 0,
        "json_errors": 0,  # placeholder: malformed lines are skipped upstream, not counted
    }
    corpus: List[str] = []
    fingerprints = set()

    with tqdm(total=n_files, desc="Files", unit="file") as progress:
        for path in jsonl_files:
            try:
                for raw in iter_claims_from_file(path, ALL_CLAIMS):
                    counters["lines_raw"] += 1
                    claim = clean_claim(raw, REFS_MODE, ADD_EOS, NFKC)
                    n_chars = len(claim)

                    # Length guard (a limit of 0 switches that side off).
                    long_enough = not MIN_LEN or n_chars >= MIN_LEN
                    short_enough = not MAX_LEN or n_chars <= MAX_LEN
                    if not (long_enough and short_enough):
                        counters["len_drop"] += 1
                        continue

                    # Charset guard.
                    if not is_line_charset_ok(claim):
                        counters["charset_drop"] += 1
                        continue

                    # Duplicate check: the set only grows for unseen keys.
                    if DEDUPE:
                        size_before = len(fingerprints)
                        fingerprints.add(dedupe_key(claim))
                        if len(fingerprints) == size_before:
                            counters["dupe"] += 1
                            continue

                    corpus.append(claim)
                    counters["kept"] += 1
            finally:
                # Advance the bar even if this file raised.
                progress.update(1)

    return corpus, counters, n_files

def write_splits(lines: List[str], out_prefix: str,
                 train_r: float, val_r: float, test_r: float,
                 seed: int = 42) -> Dict[str, object]:
    """Shuffle *lines* deterministically and write train/val/test text files.

    Args:
        lines: Corpus lines; the caller's list is left untouched.
        out_prefix: Path prefix; ``_train.txt``, ``_val.txt`` and ``_test.txt``
            are appended to it.  Missing parent directories are created.
        train_r: Fraction of lines for the training file.
        val_r: Fraction of lines for the validation file.
        test_r: Kept for a symmetric signature; the test file always receives
            whatever remains after train and val, so the three sizes sum to
            ``len(lines)``.
        seed: Seed of the private random generator used for shuffling.

    Returns:
        Dict with the line counts (``train``, ``val``, ``test``) and the paths
        written (``out_train``, ``out_val``, ``out_test``).
    """
    # Shuffle a copy: same permutation for a given seed, but the caller's
    # list keeps its order.
    shuffled = list(lines)
    random.Random(seed).shuffle(shuffled)

    n = len(shuffled)
    # Clamp so rounding (or an out-of-range ratio) can never make the
    # split sizes negative or larger than what is left.
    n_train = min(n, max(0, int(round(train_r * n))))
    n_val = min(n - n_train, max(0, int(round(val_r * n))))

    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]

    out_train = f"{out_prefix}_train.txt"
    out_val   = f"{out_prefix}_val.txt"
    out_test  = f"{out_prefix}_test.txt"

    # open(..., "w") does not create directories; make sure the target exists.
    Path(out_prefix).parent.mkdir(parents=True, exist_ok=True)

    for path, arr in [(out_train, train), (out_val, val), (out_test, test)]:
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(t + "\n" for t in arr)

    return {"train": len(train), "val": len(val), "test": len(test),
            "out_train": out_train, "out_val": out_val, "out_test": out_test}

# ---------- Run ----------
if __name__ == "__main__":
    # Check the ratios before the corpus build starts; the integer
    # percentages returned here are used only for the log line below.
    tr_p, va_p, te_p = validate_splits(TRAIN_RATIO, VAL_RATIO, TEST_RATIO)
    print(f"Split ratios → train {tr_p}%, val {va_p}%, test {te_p}%")

    # Read every input file, clean/filter/dedupe, and collect drop statistics.
    lines, stats, nfiles = build_corpus(INPUT_FOLDER)
    print("\n=== Stats ===")
    print(f"Files processed           : {nfiles}")
    print(f"Raw claims seen           : {stats['lines_raw']}")
    print(f"Kept (after cleaning/dedupe): {stats['kept']}")
    print(f"Dropped (length)          : {stats['len_drop']}")
    print(f"Dropped (charset)         : {stats['charset_drop']}")
    print(f"Duplicates removed        : {stats['dupe']}")

    # Shuffle with RANDOM_SEED, cut into three splits and write them to disk.
    split_info = write_splits(lines, OUTPUT_PREFIX,
                              TRAIN_RATIO, VAL_RATIO, TEST_RATIO,
                              seed=RANDOM_SEED)
    print("\n=== Outputs ===")
    print(f"Train lines: {split_info['train']} → {split_info['out_train']}")
    print(f"Val lines  : {split_info['val']}   → {split_info['out_val']}")
    print(f"Test lines : {split_info['test']}  → {split_info['out_test']}")


Split ratios → train 98%, val 1%, test 1%


Files:   0%|          | 0/204 [00:00<?, ?file/s]


=== Stats ===
Files processed           : 204
Raw claims seen           : 4593858
Kept (after cleaning/dedupe): 4420734
Dropped (length)          : 82477
Dropped (charset)         : 0
Duplicates removed        : 90647

=== Outputs ===
Train lines: 4332319 → ../data/corpus_train.txt
Val lines  : 44207   → ../data/corpus_val.txt
Test lines : 44208  → ../data/corpus_test.txt
