# Getting started
This notebook converts the raw WikiLarge corpus in `datasets/WikiLarge` into a Hugging Face `datasets` dataset saved at
`datasets/wikilarge_dataset`, then filters and deduplicates it into `datasets/wikilarge_dataset_clean`.

The repo should have [WikiLarge](https://github.com/XingxingZhang/dress) at `datasets/WikiLarge`:

In [None]:
!ls -l ./datasets/WikiLarge

In [None]:
# Install dependencies
!pip install -q sentencepiece
!pip install -q --upgrade datasets fsspec

# WikiLarge to wikilarge_dataset
Create a Hugging Face dataset from WikiLarge and save it at `datasets/wikilarge_dataset`.

In [None]:
from datasets import Dataset, DatasetDict, load_from_disk
import os
import re

# Bracket placeholder tokens and the literal bracket each one is replaced with.
REPLACEMENTS = {
    "-LRB-": "(",
    "-RRB-": ")",
    "-LSB-": "[",
    "-RSB-": "]",
    "-LCB-": "{",
    "-RCB-": "}",
}
# Single alternation that matches any placeholder token literally.
_PATTERN = re.compile("|".join(re.escape(token) for token in REPLACEMENTS))


def clean_text(text: str) -> str:
    """Return ``text`` with every bracket placeholder token replaced by its literal bracket."""
    def _to_bracket(match):
        # The matched text is always one of the REPLACEMENTS keys.
        return REPLACEMENTS[match.group(0)]

    return _PATTERN.sub(_to_bracket, text)

def load_split(src_path, dst_path) -> Dataset:
    """
    Build one parallel split from a pair of line-aligned UTF-8 text files.

    Line i of ``src_path`` is paired with line i of ``dst_path``. Every line is
    passed through ``clean_text`` and stripped of surrounding whitespace.

    Args:
        src_path: path of the file whose lines become the "source" column.
        dst_path: path of the file whose lines become the "target" column.

    Returns:
        A ``Dataset`` with the columns "source" and "target".

    Raises:
        ValueError: if the two files do not contain the same number of lines.
    """
    def _read_lines(path):
        # Iterate over the file instead of read().splitlines(): file iteration
        # breaks only on real newlines, whereas str.splitlines() also breaks on
        # characters such as \x0b, \x0c, \x1c-\x1e, \x85, \u2028 and \u2029,
        # which can occur inside a sentence and would shift the alignment
        # between the two files.
        with open(path, 'r', encoding='utf-8') as f:
            return [clean_text(line).strip() for line in f]

    src_lines = _read_lines(src_path)
    dst_lines = _read_lines(dst_path)

    # A parallel corpus is only meaningful when both sides line up one-to-one.
    if len(src_lines) != len(dst_lines):
        raise ValueError(
            f"Line count mismatch: {src_path} has {len(src_lines)} lines, "
            f"{dst_path} has {len(dst_lines)} lines"
        )
    return Dataset.from_dict({"source": src_lines, "target": dst_lines})

# Raw WikiLarge directory and output directory of the converted dataset, both
# resolved against the current working directory.
PATH_WIKILARGE_DIR = os.path.join(os.getcwd(), "datasets/WikiLarge")
PATH_DATASET_OUT = os.path.join(os.getcwd(), "datasets/wikilarge_dataset")


if not os.path.exists(PATH_DATASET_OUT):
    # Common filename prefix shared by all raw split files.
    wl_root = os.path.join(PATH_WIKILARGE_DIR, "wiki.full.aner.ori.")

    # Split name -> (source file suffix, target file suffix).
    split_files = {
        "train": ("train.src", "train.dst"),
        "validation": ("valid.src", "valid.dst"),
        "test": ("test.src", "test.dst"),
    }
    dataset = DatasetDict({
        split_name: load_split(wl_root + src_suffix, wl_root + dst_suffix)
        for split_name, (src_suffix, dst_suffix) in split_files.items()
    })
    dataset.save_to_disk(PATH_DATASET_OUT)
    print(f"Saved ds to disk at `{PATH_DATASET_OUT}`")
else:
    # A previous run already converted the corpus; reuse it.
    dataset = load_from_disk(PATH_DATASET_OUT)
    print("Loaded ds from disk")

In [4]:
# Sanity check: preview the first rows of the training split.
# Select the rows before converting so that only the preview, not the whole
# split, is turned into a pandas DataFrame.
display(
    dataset["train"]
    .select(range(min(10, len(dataset["train"]))))
    .to_pandas()
)

Unnamed: 0,source,target
0,There is manuscript evidence that Austen conti...,There is some proof that Austen continued to w...
1,"In a remarkable comparative analysis , Mandaea...",Mandaean scholar Säve-Söderberg showed that Ma...
2,"Before Persephone was released to Hermes , who...",When Demeter went to the Underworld to rescue ...
3,Cogeneration plants are commonly found in dist...,Cogeneration plants are commonly found in dist...
4,"Geneva ( , ; , ; , ; ; ) is the second-most-po...",The city 's main newspaper is the Tribune de G...
5,When Japan earned another race on the F1 sched...,When Japan was added back to the F1 schedule t...
6,This marked the first motorcycle racing event ...,This was the first motorcycle racing event at ...
7,Ka La Ku'oko ' a ( Hawaiian Independence Day )...,November 28 - Ka La Ku'oko ` a : Hawaiian Inde...
8,"A scythe ( Oxford English Dictionary , Oxford ...","A scythe ( , from Old English siðeOxford Engli..."
9,Toshihide Saito is a Japanese football ( socce...,Toshihide Saito ( born 20 April 1973 ) is a Ja...


# Cleaning wikilarge_dataset

In [None]:
# =========================
# Cheap filters + dedup + detailed reporting
# =========================

import re
import hashlib
from collections import Counter
# ------------ Config ------------
# Token counts come from whitespace splitting; all four length bounds are inclusive.
MIN_SRC_TOKENS = 4
MAX_SRC_TOKENS = 256
MIN_TGT_TOKENS = 2
MAX_TGT_TOKENS = 256
# Compression ratio = target tokens / source tokens; a pair passes when the
# ratio lies within [MIN_COMP_RATIO, MAX_COMP_RATIO].
MIN_COMP_RATIO = 0.40    # len(simple)/len(complex) >= 0.40
MAX_COMP_RATIO = 0.95    # <= 0.95
MAX_JACCARD   = 0.98     # lexical near-identity cutoff: a pair passes only if its Jaccard is below this

# Output directory of the cleaned dataset, resolved against the current working directory.
DRIVE_CLEAN_DATASET_OUT = os.path.join(os.getcwd(), "datasets/wikilarge_dataset_clean")

# ------------ Helpers ------------
def jaccard(a_tokens, b_tokens) -> float:
    """
    Jaccard similarity of the two token collections, compared as sets.

    Returns |A ∩ B| / |A ∪ B|, or 1.0 when both collections are empty.
    """
    left = set(a_tokens)
    right = set(b_tokens)
    union = left | right
    if not union:
        # Both sides empty: treat them as identical.
        return 1.0
    return len(left & right) / len(union)

def _norm_for_hash(s: str) -> str:
    """Normalize text for hashing (lower + collapse spaces)."""
    s = s.lower().strip()
    s = re.sub(r"\s+", " ", s)
    return s

def _pair_key(src: str, tgt: str) -> str:
    """
    Build the deduplication key of a (source, target) pair.

    Each text is normalized with ``_norm_for_hash`` and MD5-hashed; the two hex
    digests are joined with "#", source first.
    """
    digests = [
        hashlib.md5(_norm_for_hash(text).encode()).hexdigest()
        for text in (src, tgt)
    ]
    return "#".join(digests)

# ------------ Core: annotate + report + clean ------------
def annotate_basic_flags(batch):
    """
    Evaluate the basic filter rules for every example of a batch.

    ``batch`` must provide the parallel columns "source" and "target". Texts
    are stripped and split on whitespace; a missing (falsy) text is treated as
    an empty string.

    Returns a dict of equally long lists:
      - "src_len_ok" / "tgt_len_ok": token count within the configured bounds
      - "ratio_ok": compression ratio within [MIN_COMP_RATIO, MAX_COMP_RATIO]
      - "jacc_ok": Jaccard similarity strictly below MAX_JACCARD
      - "keep_basic": True when every rule passes
      - "first_fail": "length", "compression" or "near_identity" for the first
        failing rule in that order, or "keep"
      - "cr": target tokens / source tokens (0.0 when the source has no tokens)
      - "jac": Jaccard similarity of the two token sets
    """
    columns = {
        "src_len_ok": [], "tgt_len_ok": [],
        "ratio_ok": [], "jacc_ok": [],
        "keep_basic": [], "first_fail": [],
        "cr": [], "jac": [],
    }

    for raw_src, raw_tgt in zip(batch["source"], batch["target"]):
        src_tokens = (raw_src or "").strip().split()
        tgt_tokens = (raw_tgt or "").strip().split()
        n_src, n_tgt = len(src_tokens), len(tgt_tokens)

        src_fits = MIN_SRC_TOKENS <= n_src <= MAX_SRC_TOKENS
        tgt_fits = MIN_TGT_TOKENS <= n_tgt <= MAX_TGT_TOKENS

        # An empty source has no meaningful ratio; 0.0 makes it fail the ratio rule.
        ratio = n_tgt / n_src if n_src else 0.0
        ratio_fits = MIN_COMP_RATIO <= ratio <= MAX_COMP_RATIO

        overlap = jaccard(src_tokens, tgt_tokens)
        distinct_enough = overlap < MAX_JACCARD

        # Rules are checked in a fixed order; only the first failure is recorded.
        if not src_fits or not tgt_fits:
            verdict = "length"
        elif not ratio_fits:
            verdict = "compression"
        elif not distinct_enough:
            verdict = "near_identity"
        else:
            verdict = "keep"

        columns["src_len_ok"].append(src_fits)
        columns["tgt_len_ok"].append(tgt_fits)
        columns["ratio_ok"].append(ratio_fits)
        columns["jacc_ok"].append(distinct_enough)
        columns["cr"].append(ratio)
        columns["jac"].append(overlap)
        columns["first_fail"].append(verdict)
        columns["keep_basic"].append(verdict == "keep")

    return columns

def report_basic_filter_stats(ds_split: Dataset, split_name: str):
    """
    Filter one split with the basic rules, drop duplicate pairs, and print a report.

    The report shows how many examples fail each rule (non-exclusive), how many
    are attributed to each rule as their first failure (exclusive), and how
    many duplicates are removed afterwards.

    Returns a tuple ``(cleaned_split, removed_by_basic_filters, removed_by_dedup)``.
    An empty split is returned unchanged with both counts set to 0.
    """
    total = len(ds_split)
    if total == 0:
        print(f"[{split_name}] empty split")
        return ds_split, 0, 0

    # One batched pass adds every flag column to the split.
    flagged = ds_split.map(
        annotate_basic_flags, batched=True, batch_size=2048,
        desc=f"Annotate flags ({split_name})"
    )

    # Non-exclusive: an example is counted under every rule it violates.
    len_fail = sum(
        1
        for src_ok, tgt_ok in zip(flagged["src_len_ok"], flagged["tgt_len_ok"])
        if not (src_ok and tgt_ok)
    )
    ratio_fail = sum(1 for passed in flagged["ratio_ok"] if not passed)
    jacc_fail = sum(1 for passed in flagged["jacc_ok"] if not passed)

    # Exclusive: each example is counted once, under its first failing rule.
    # A Counter yields 0 for reasons that never occurred.
    first_fail_counts = Counter(flagged["first_fail"])
    excl_len = first_fail_counts["length"]
    excl_ratio = first_fail_counts["compression"]
    excl_jacc = first_fail_counts["near_identity"]
    excl_keep = first_fail_counts["keep"]
    removed_exclusive = excl_len + excl_ratio + excl_jacc

    print(f"\n[{split_name}]")
    print(f"Basic filter diagnostics (non-exclusive):")
    print(f"  Length fails       : {len_fail:>7} / {total} ({100*len_fail/total:5.2f}%)")
    print(f"  Compression fails  : {ratio_fail:>7} / {total} ({100*ratio_fail/total:5.2f}%)")
    print(f"  Near-identity fails: {jacc_fail:>7} / {total} ({100*jacc_fail/total:5.2f}%)")

    print(f"\nBasic filter diagnostics (exclusive, first-fail order):")
    print(f"  length       : {excl_len:>7}")
    print(f"  compression  : {excl_ratio:>7}")
    print(f"  near_identity: {excl_jacc:>7}")
    print(f"  kept         : {excl_keep:>7}")
    print(f"  -> total removed by basic filters (exclusive): {removed_exclusive} "
          f"({100*removed_exclusive/total:5.2f}%)")

    # Positions (in the original split) of the examples that passed every rule.
    passing_idx = [pos for pos, keep in enumerate(flagged["keep_basic"]) if keep]
    kept = ds_split.select(passing_idx)

    # Dedup: keep the first occurrence of every normalized (source, target) pair.
    before_dedup = len(kept)
    seen_keys = set()
    unique_idx = []
    for pos, (src, tgt) in enumerate(zip(kept["source"], kept["target"])):
        key = _pair_key(src, tgt)
        if key not in seen_keys:
            seen_keys.add(key)
            unique_idx.append(pos)
    after_dedup = len(unique_idx)
    dedup_removed = before_dedup - after_dedup

    print(f"\nDedup:")
    print(f"  before dedup: {before_dedup}")
    print(f"  after  dedup: {after_dedup}")
    print(f"  removed     : {dedup_removed} ({100*dedup_removed/max(1,before_dedup):5.2f}%)")

    return kept.select(unique_idx), removed_exclusive, dedup_removed

def apply_cheap_filters_and_dedup_with_report(ds: DatasetDict) -> DatasetDict:
    """
    Clean every split of ``ds`` with ``report_basic_filter_stats``.

    Prints the per-split report and a one-line summary for each split, then the
    row counts before and after cleaning. Returns a new ``DatasetDict`` with
    the cleaned splits under the same split names.
    """
    out = {}
    for split, split_ds in ds.items():
        cleaned, removed_basic, removed_dedup = report_basic_filter_stats(split_ds, split)
        print(f"\nsummary: {len(ds[split])} -> {len(cleaned)} "
              f"(removed basic: {removed_basic}, dedup: {removed_dedup})\n")
        out[split] = cleaned

    cleaned_ds = DatasetDict(out)
    print(f"\nOriginal sizes: {ds.num_rows}")
    print(f"Cleaned sizes:  {cleaned_ds.num_rows}")
    return cleaned_ds


# Clean and save the dataset on the first run; later runs reuse the saved copy.
if not os.path.exists(DRIVE_CLEAN_DATASET_OUT):
    clean_ds = apply_cheap_filters_and_dedup_with_report(dataset)
    clean_ds.save_to_disk(DRIVE_CLEAN_DATASET_OUT)
    print(f"Saved ds to disk at `{DRIVE_CLEAN_DATASET_OUT}`")
else:
    clean_ds = load_from_disk(DRIVE_CLEAN_DATASET_OUT)
    print("Loaded ds from disk")


Most of the removed examples failed the compression-ratio filter. The cell below checks how many examples would fail for different minimum compression ratios.

In [6]:
def what_if_min_comp(flagged, new_min_ratio=0.35):
    """
    Count how many examples would fail the compression rule with another lower bound.

    Reads the precomputed "cr" column produced by ``annotate_basic_flags`` and
    tests each ratio against [new_min_ratio, MAX_COMP_RATIO]; nothing else is
    recomputed.

    Returns ``(would_fail_count, total_examples)``.

    NOTE: the count is non-exclusive (it ignores the ordering with the other
    rules) and is meant for sensitivity checks only.
    """
    ratios = list(flagged["cr"])
    would_fail = sum(
        1 for cr in ratios if not (new_min_ratio <= cr <= MAX_COMP_RATIO)
    )
    return would_fail, len(ratios)

# Flag every training pair once; the "cr" column is reused for each threshold below.
flagged_train = dataset["train"].map(
    annotate_basic_flags, batched=True, batch_size=2048, desc="Annotate flags (train)"
)

# Candidate lower bounds for the compression ratio, from the current setting downwards.
for min_ratio in (0.40, 0.35, 0.30, 0.25, 0.20):
    drop, total = what_if_min_comp(flagged_train, min_ratio)
    # max(1, total) keeps the report printable for an empty split.
    print(f"Compression fails if MIN_COMP_RATIO={min_ratio:.2f}: "
          f"{drop}/{total} ({100*drop/max(1, total):.2f}%)")


Annotate flags (train): 100%|██████████| 296402/296402 [00:04<00:00, 70856.03 examples/s]


Compression fails if MIN_COMP_RATIO=0.40: 171429/296402 (57.84%)
Compression fails if MIN_COMP_RATIO=0.35: 163750/296402 (55.25%)
Compression fails if MIN_COMP_RATIO=0.30: 155908/296402 (52.60%)
Compression fails if MIN_COMP_RATIO=0.25: 148996/296402 (50.27%)
Compression fails if MIN_COMP_RATIO=0.20: 143349/296402 (48.36%)


Random examples that were filtered out, for manual inspection:

In [7]:
import random

def sample_removed_by_reason(orig_split, flagged_split, reason="compression", k=20, seed=None):
    """
    Print up to k randomly chosen examples whose first-fail reason equals ``reason``.

    Args:
        orig_split: split providing "source" and "target" for a row index.
        flagged_split: the same split annotated by ``annotate_basic_flags``; it
            is read both by column ("first_fail") and by row index ("cr", "jac").
        reason: 'length' | 'compression' | 'near_identity' | 'keep'.
        k: maximum number of examples to print; values below 0 print nothing.
        seed: optional seed for the shuffle. With the default ``None`` every
            call draws a different sample; pass an int for a reproducible one.
    """
    idxs = [i for i, r in enumerate(flagged_split["first_fail"]) if r == reason]
    random.Random(seed).shuffle(idxs)

    # Clamp k so that a negative value cannot turn into a slice from the end.
    for i in idxs[:max(0, k)]:
        # Fetch each row once instead of once per field.
        pair = orig_split[i]
        flags = flagged_split[i]
        s = pair["source"]
        t = pair["target"]
        cr = flags["cr"]
        jac = flags["jac"]
        print(f"[reason={reason} cr={cr:.2f} jac={jac:.2f}]")
        print("SRC:", s)
        print("TGT:", t)
        print()

# Show ten random training pairs whose first failing rule was the compression ratio.
sample_removed_by_reason(
    orig_split=dataset["train"],
    flagged_split=flagged_train,
    reason="compression",
    k=10,
)

[reason=compression cr=0.21 jac=0.04]
SRC: Michaelsberg Abbey or Michelsberg Abbey , also St. Michael 's Abbey , Bamberg ( or Michelsberg ) is a former Benedictine monastery in Bamberg in Bavaria , Germany .
TGT: It belonged to the bishop .

[reason=compression cr=1.00 jac=0.91]
SRC: In the 10 playoff games prior to the cancellation , VÃ zina had won six games , lost three and tied one , with one shutout .
TGT: In the ten playoff games prior to the cancellation , VÃ zina had won six games , lost three and tied one , with one shutout .

[reason=compression cr=1.05 jac=0.95]
SRC: For the next 204 years , the scientific and thermometry communities worldwide referred to this scale as the centigrade scale .
TGT: For the next 204 years , the scientific and thermometry communities worldwide referred to this scale as the â centigrade scale .

[reason=compression cr=1.47 jac=0.81]
SRC: Corbehem is a commune in the Pas-de-Calais department in the Nord-Pas-de-Calais region of France .
TGT: Corbeh

Original compression ratio vs. cleaned compression ratio

In [8]:
import numpy as np
def comp_ratio(ds):
    """
    Mean per-pair compression ratio of ``ds`` as a float.

    For each ("source", "target") pair the ratio is the target's whitespace
    token count divided by the source's, with the divisor floored at 1.
    """
    ratios = []
    for src, tgt in zip(ds["source"], ds["target"]):
        n_src = max(1, len(src.split()))
        ratios.append(len(tgt.split()) / n_src)
    return float(np.mean(ratios))

# Mean per-pair compression ratio of the training split, before and after cleaning.
print(f"CR train orig : {comp_ratio(dataset['train']):.2f}")
print(f"CR train clean: {comp_ratio(clean_ds['train']):.2f}")

CR train orig : 0.88
CR train clean: 0.70


Average text length

In [9]:
def avg_len(texts):
    """Return the mean number of whitespace-separated tokens per text as a float."""
    token_counts = [len(text.split()) for text in texts]
    return float(np.mean(token_counts))
# Mean whitespace-token length of the training sources and targets,
# before (orig) and after (clean) filtering.
print(f"              ( orig → clean)")
print(f"source length: {avg_len(dataset['train']['source']):.2f} → {avg_len(clean_ds['train']['source']):.2f}")
print(f"target length: {avg_len(dataset['train']['target']):.2f} → {avg_len(clean_ds['train']['target']):.2f}")

              ( orig → clean)
source length: 25.17 → 26.49
target length: 18.51 → 18.29
