In [1]:
import os
import json
import gzip
import pickle
import re
from pathlib import Path
from typing import Optional, Dict, Any, List

from datasets import load_dataset
from tqdm import tqdm

from dolma import BaseTagger, add_tagger
from dolma.core.data_types import DocResult, Document, Span


OUTPUT_MIXED_DIR = Path("data/output/mixed")
OUTPUT_MIXED_DIR.mkdir(parents=True, exist_ok=True)

PROGRESS_PATH = Path("data/state/c4_progress.json")
PROGRESS_PATH.parent.mkdir(parents=True, exist_ok=True)

SHARD_SIZE = 50_000

  from .autonotebook import tqdm as notebook_tqdm


### Utility Functions

In [2]:
def _safe_here() -> Path:
    """Return a reasonable 'here' directory, even in notebooks/REPL."""
    if "__file__" in globals():
        return Path(__file__).resolve()
    return Path.cwd()


def _find_pickle(explicit: Optional[str]) -> Path:
    """
    Find the aapiGroups.pkl file.

    Search order:
      1. Explicit path argument
      2. Env var: AAPI_KEYWORDS_PICKLE
      3. Common repo-relative locations
    """
    tried: List[str] = []

    def _check(path: Path) -> Optional[Path]:
        tried.append(str(path))
        if path.exists():
            return path.resolve()
        return None

    if explicit:
        p = Path(explicit).expanduser().resolve()
        found = _check(p)
        if found:
            return found

    env_path = os.environ.get("AAPI_KEYWORDS_PICKLE")
    if env_path:
        p = Path(env_path).expanduser().resolve()
        found = _check(p)
        if found:
            return found

    here = _safe_here()
    repo_root = here.parents[2] if len(here.parents) >= 3 else here

    candidates = [
        repo_root / "data" / "aapiGroups.pkl",
        here.parent / "data" / "aapiGroups.pkl",
        repo_root / "utils" / "data" / "aapiGroups.pkl",
        Path.cwd() / "data" / "aapiGroups.pkl",
    ]

    for cand in candidates:
        found = _check(cand)
        if found:
            return found

    raise FileNotFoundError(
        "Could not locate aapiGroups.pkl. Searched:\n  - "
        + "\n  - ".join(tried)
        + "\nTip: set env var AAPI_KEYWORDS_PICKLE=/abs/path/to/aapiGroups.pkl "
        "or pass keyword_pickle=... when constructing AAPIKeywordsTagger."
    )

### Dolma Tagger

In [3]:
@add_tagger("aapi_keywords_v1")
class AAPIKeywordsTagger(BaseTagger):
    """
    Tags documents that contain any AAPI-related keyword loaded from a pickle.
    """

    def __init__(self, keyword_pickle: Optional[str] = None) -> None:
        super().__init__()

        self.keyword_pickle = _find_pickle(keyword_pickle)

        with self.keyword_pickle.open("rb") as f:
            raw_terms = pickle.load(f)

        terms = [str(t).strip().lower() for t in list(raw_terms)]
        if not terms:
            raise ValueError(f"No terms found in {self.keyword_pickle}")

        pattern = r"\b(" + "|".join(re.escape(t) for t in terms) + r")\b"
        self.regex = re.compile(pattern, flags=re.IGNORECASE)

    def predict(self, doc: Document) -> DocResult:
        text = doc.text or ""
        matches = self.regex.findall(text)
        if not matches:
            # no matches → score 0
            span = Span(start=0, end=0, type="aapi_keyword", score=0.0)
            return DocResult(doc=doc, spans=[span])

        # unique matches → score is count of unique AAPI terms present
        unique_matches = {m.lower() for m in matches}
        score = float(len(unique_matches))

        span = Span(
            start=0,
            end=len(text),
            type="aapi_keyword",
            score=score,
        )
        return DocResult(doc=doc, spans=[span])

### Dolma Mixer

In [4]:
def mix_aapi_doc(result: DocResult) -> Optional[Dict[str, Any]]:
    """
    Given a DocResult from AAPIKeywordsTagger, return a JSON-serializable dict
    if score > 0, else return None (filter out the doc).
    """
    doc = result.doc
    spans = result.spans or []

    score = 0.0
    if spans:
        score = float(spans[0].score or 0.0)

    if score <= 0.0:
        return None

    return {
        "id": doc.id,
        "text": doc.text,
        "source": getattr(doc, "source", None),
        "aapi_score": score,
        "aapi_spans": [
            {
                "start": s.start,
                "end": s.end,
                "type": s.type,
            }
            for s in spans
        ],
    }

### Process Helpers

In [None]:
def save_progress(count: int) -> None:
    """
    Save how many C4 docs we've fully traversed/processed (for resume).
    """
    PROGRESS_PATH.parent.mkdir(parents=True, exist_ok=True)
    with PROGRESS_PATH.open("w", encoding="utf-8") as f:
        json.dump({"c4_docs_done": count}, f)


def load_progress() -> int:
    """
    Load how many C4 docs we've previously processed. Returns 0 if none.
    """
    if PROGRESS_PATH.exists():
        with PROGRESS_PATH.open("r", encoding="utf-8") as f:
            data = json.load(f)
        return int(data.get("c4_docs_done", 0))
    return 0


def open_new_shard(out_dir: Path, shard_idx: int) -> gzip.GzipFile:
    """
    Open a new gzip'd JSONL shard for writing mixed docs.
    """
    out_dir.mkdir(parents=True, exist_ok=True)
    shard_path = out_dir / f"mixed.{shard_idx:09d}.jsonl.gz"
    return gzip.open(shard_path, "wt", encoding="utf-8")

### Actual Process

In [6]:
def run_loop(MAX_DOCS: int = 10_000, out_dirname: Path = OUTPUT_MIXED_DIR) -> None:
    """
    Stream Dolma C4 docs, tag them with AAPIKeywordsTagger, filter them via
    mix_aapi_doc, and write the kept ones into sharded .jsonl.gz files.

    Resumable via PROGRESS_PATH (counts how many C4 docs have been processed).
    """
    pre_computed = load_progress()
    print(f"Resuming from {pre_computed} pre-computed C4 documents.")

    kept = 0
    traversed = 0
    processed = pre_computed

    tagger = AAPIKeywordsTagger()
    ds = load_dataset(
        "allenai/dolma",
        "v1_7",
        split="train",
        streaming=True,
        trust_remote_code=True,
    )

    shard_idx = 0
    docs_in_shard = 0
    out_f: Optional[gzip.GzipFile] = None

    pbar = tqdm(
        total=MAX_DOCS,
        initial=0,
        desc="Processing C4 docs",
    )

    actual_shard_size = min(SHARD_SIZE, MAX_DOCS)
    for data in ds:
        if data.get("source") != "c4":
            continue

        if traversed < pre_computed:
            traversed += 1
            continue

        doc = Document(
            id=data["id"],
            text=data["text"],
            source=data.get("source"),
        )

        tagged = tagger.predict(doc)
        mixed = mix_aapi_doc(tagged)

        traversed += 1
        processed += 1
        pbar.update(1)
        
        if mixed is not None:
            if out_f is None or docs_in_shard >= actual_shard_size:
                if out_f is not None:
                    out_f.close()
                out_f = open_new_shard(out_dirname, shard_idx)
                shard_idx += 1
                docs_in_shard = 0

            out_f.write(json.dumps(mixed, ensure_ascii=False) + "\n")
            docs_in_shard += 1
            kept += 1
            

        if processed % 1000 == 0:
            save_progress(traversed)
        if MAX_DOCS and docs_in_shard >= MAX_DOCS:
            break

    pbar.close()
    if out_f is not None:
        out_f.close()

    save_progress(traversed)


### Running the code

In [None]:
# Creates 100 files of SHARD_SIZE each
run_loop(SHARD_SIZE * 800)

Resuming from 0 pre-computed C4 documents.


Processing C4 docs:   0%|          | 300/40000000 [13:31<16774:31:04,  1.51s/it] 

Opening new shard: data/output/mixed/mixed.000000000.jsonl.gz


Processing C4 docs:   2%|▏         | 941898/40000000 [34:01<7:30:08, 1446.16it/s] 

Opening new shard: data/output/mixed/mixed.000000001.jsonl.gz


Processing C4 docs:   5%|▍         | 1867478/40000000 [1:36:48<5:19:56, 1986.45it/s]  

Opening new shard: data/output/mixed/mixed.000000002.jsonl.gz


Processing C4 docs:   7%|▋         | 2799176/40000000 [2:35:35<7:01:40, 1470.37it/s]  

Opening new shard: data/output/mixed/mixed.000000003.jsonl.gz


Processing C4 docs:   9%|▉         | 3730963/40000000 [3:21:21<5:16:20, 1910.83it/s]  

Opening new shard: data/output/mixed/mixed.000000004.jsonl.gz


Processing C4 docs:  12%|█▏        | 4658978/40000000 [4:01:06<5:33:35, 1765.71it/s]  

Opening new shard: data/output/mixed/mixed.000000005.jsonl.gz


Processing C4 docs:  13%|█▎        | 5277929/40000000 [4:38:47<5:09:24, 1870.37it/s]  Got disconnected from remote data host. Retrying in 5sec [1/20]
Processing C4 docs:  14%|█▍        | 5597273/40000000 [5:07:07<6:36:24, 1446.46it/s]  

Opening new shard: data/output/mixed/mixed.000000006.jsonl.gz


Processing C4 docs:  16%|█▋        | 6530375/40000000 [6:43:22<1162:24:54,  8.00it/s] 

Opening new shard: data/output/mixed/mixed.000000007.jsonl.gz


Processing C4 docs:  19%|█▊        | 7456856/40000000 [6:56:04<6:36:25, 1368.19it/s] 

Opening new shard: data/output/mixed/mixed.000000008.jsonl.gz


Processing C4 docs:  20%|█▉        | 7936651/40000000 [7:23:24<4:41:24, 1898.95it/s] Got disconnected from remote data host. Retrying in 5sec [1/20]
Processing C4 docs:  21%|██        | 8387163/40000000 [8:12:37<4:51:06, 1809.87it/s]  

Opening new shard: data/output/mixed/mixed.000000009.jsonl.gz


Processing C4 docs:  23%|██▎       | 9322062/40000000 [8:46:24<4:17:38, 1984.52it/s]  

Opening new shard: data/output/mixed/mixed.000000010.jsonl.gz


Processing C4 docs:  25%|██▌       | 10030754/40000000 [8:55:13<5:32:32, 1502.04it/s]