# Cleanup AI2

In [None]:
from pathlib import Path
import json, re
from typing import Any, Dict, Iterable, Tuple, List

from tqdm.auto import tqdm

try:
    import pyarrow.parquet as pq
    import pyarrow as pa
    HAVE_PA = True
except Exception:
    HAVE_PA = False

try:
    import pandas as pd
    HAVE_PD = True
except Exception:
    HAVE_PD = False

# Directory that is globbed for the input parquet parts.
BASE = Path("ai2")
# Filename patterns (relative to BASE) selecting the train / validation parts.
TRAIN_GLOB = "ai2_train_*.parquet"
VALID_GLOB = "ai2_valid_*.parquet"

# Destination JSONL files; build_jsonl_from_parquets creates the parent dirs.
OUT_TRAIN = Path("./train_jsonl/ai2_train.jsonl")
OUT_VALID = Path("./valid_jsonl/ai2_valid.jsonl")

# Rows per record batch when streaming a parquet file through pyarrow.
BATCH_SIZE = 8192
# When True, a repeated (user, assistant) pair is written only once per output.
DO_DEDUP = True

In [None]:
def to_clean_str(x: Any) -> str:
    """Return ``x`` converted to a string with surrounding whitespace removed.

    Missing values (``None`` and float NaN) map to the empty string rather
    than the literal text "None" / "nan", so callers that test the result for
    emptiness treat them as absent instead of as real content.
    """
    # NaN is the only float that is not equal to itself.
    if x is None or (isinstance(x, float) and x != x):
        return ""
    return str(x).strip()

def parse_answer_key(raw) -> Tuple[str, str | None]:
    """Classify a raw ``answerKey`` value.

    Returns ``(status, label)``; ``label`` is only set when status is "ok".

    - "ok": exactly one distinct single-letter label "A".."Z" (case-insensitive)
    - "multi": more than one distinct label (e.g. "A,B" or ["A", "B"])
    - "missing": None, float NaN, or an empty / whitespace-only string
    - "bad": no single-letter label could be found

    A list/tuple/set contributes one token per element; a string is split on
    runs of non-alphanumeric characters. Only tokens that are exactly one
    letter count as labels, so a word such as "NONE" is "bad" instead of being
    read letter by letter and reported as "multi".
    """
    # NaN is the only float that is not equal to itself.
    if raw is None or (isinstance(raw, float) and raw != raw):
        return "missing", None

    if isinstance(raw, (list, tuple, set)):
        tokens = [to_clean_str(x).upper() for x in raw]
    else:
        text = to_clean_str(raw).upper()
        if not text:
            return "missing", None
        tokens = re.split(r"[\W_]+", text)

    # Same rule for both input shapes: keep distinct one-letter tokens only.
    labels = sorted({tok for tok in tokens if re.fullmatch(r"[A-Z]", tok)})
    if not labels:
        return "bad", None
    if len(labels) > 1:
        return "multi", None
    return "ok", labels[0]


def _as_list(value: Any) -> List[Any] | None:
    """Return ``value`` as a list, or None when it is not list-like.

    Accepts list/tuple and array-likes exposing ``tolist()`` (e.g. numpy
    arrays, which pandas typically produces for nested parquet columns).
    """
    if isinstance(value, (list, tuple)):
        return list(value)
    tolist = getattr(value, "tolist", None)
    if callable(tolist):
        converted = tolist()
        if isinstance(converted, list):
            return converted
    return None


def extract_pair_strict(obj: Dict[str, Any]) -> Tuple[str | None, str | None, Dict[str, int]]:
    """Turn one record into a ``(user, assistant)`` pair, or reject it.

    Args:
        obj: record with "question", "choices" (a mapping holding parallel
            "label" and "text" sequences) and "answerKey".

    Returns:
        Always a 3-tuple ``(user, assistant, counters)``. On success ``user``
        is the cleaned question and ``assistant`` the text of the choice named
        by the answer key. On rejection both are None and exactly one entry of
        ``counters`` is 1:

        - "missing_fields": a required key is absent or the question is empty
        - "bad_choices": choices malformed, empty, mismatched in length,
          contain an empty text, a duplicate label, or a non A-Z label
        - "num_multi": the answer key names more than one label
        - "key_not_found": the answer key is missing/unparseable or names a
          label that is not among the choices
    """
    counters = {"missing_fields": 0, "bad_choices": 0, "num_multi": 0, "key_not_found": 0}

    def reject(reason: str) -> Tuple[None, None, Dict[str, int]]:
        counters[reason] += 1
        return None, None, counters

    if not all(k in obj for k in ("question", "choices", "answerKey")):
        return reject("missing_fields")

    # A None question must not be stringified into the literal text "None".
    question = obj["question"]
    q = "" if question is None else to_clean_str(question)
    if not q:
        return reject("missing_fields")

    # Guard against None / non-mapping choices instead of failing on .get().
    choices = obj["choices"]
    if not isinstance(choices, dict):
        return reject("bad_choices")

    labels = _as_list(choices.get("label"))
    texts = _as_list(choices.get("text"))
    if not labels or not texts or len(labels) != len(texts):
        return reject("bad_choices")

    # NOTE(review): labels outside A-Z (for example numeric labels) reject the
    # whole record here -- confirm that dropping such rows is intended.
    mapping: Dict[str, str] = {}
    for lab, txt in zip(labels, texts):
        lab = to_clean_str(lab).upper()
        txt = "" if txt is None else to_clean_str(txt)
        # A duplicate label is rejected to avoid an ambiguous answer lookup.
        if not re.fullmatch(r"[A-Z]", lab) or not txt or lab in mapping:
            return reject("bad_choices")
        mapping[lab] = txt

    status, key = parse_answer_key(obj["answerKey"])
    if status == "multi":
        return reject("num_multi")
    if status != "ok" or key not in mapping:
        return reject("key_not_found")

    return q, mapping[key], counters

def _iter_row_batches(in_paths: List[Path]) -> Iterable[List[Dict[str, Any]]]:
    """Yield lists of row dicts read from the parquet parts, in order.

    Streams BATCH_SIZE-row batches through pyarrow when it is available;
    otherwise loads each file with pandas and yields it as a single batch.
    """
    if HAVE_PA:
        for p in in_paths:
            for batch in pq.ParquetFile(p).iter_batches(batch_size=BATCH_SIZE):
                yield pa.Table.from_batches([batch]).to_pylist()
    else:
        for p in in_paths:
            yield pd.read_parquet(p).to_dict(orient="records")


def build_jsonl_from_parquets(in_paths: List[Path], out_path: Path) -> Dict[str, int]:
    """
    Convert AI2 parquet parts → normalized JSONL with a progress bar.
    STRICT as per extract_pair_strict().

    Each accepted record is written as one ``{"user", "assistant"}`` JSON
    line; when DO_DEDUP is set, repeated pairs are written only once.

    Args:
        in_paths: parquet files to read, processed in the given order.
        out_path: JSONL file to (over)write; parent directories are created.

    Returns:
        Summary counts: rows reported by file metadata, rows seen, rows
        written, rows skipped as duplicates, and one drop count per
        rejection reason.

    Raises:
        RuntimeError: if neither pyarrow nor pandas could be imported.
    """
    if not (HAVE_PA or HAVE_PD):
        # Fail before touching the filesystem so a missing reader does not
        # leave behind an empty (or truncated) output file.
        raise RuntimeError("Neither pyarrow nor pandas is available to read parquet.")

    out_path.parent.mkdir(parents=True, exist_ok=True)

    total_rows = None
    if HAVE_PA:
        try:
            total_rows = sum(pq.ParquetFile(p).metadata.num_rows for p in in_paths)
        except Exception:
            # Best effort: without metadata the bar has no total and the
            # summary falls back to the number of rows actually read.
            total_rows = None

    kept = deduped = total_seen = 0
    dropped = {"missing_fields": 0, "bad_choices": 0, "num_multi": 0, "key_not_found": 0}
    seen = set() if DO_DEDUP else None

    desc = f"writing {out_path.name}"
    with out_path.open("w", encoding="utf-8") as f, tqdm(total=total_rows, unit="rows", desc=desc, leave=False) as pbar:
        for rows in _iter_row_batches(in_paths):
            for rec in rows:
                total_seen += 1
                user, assistant, deltas = extract_pair_strict(rec)
                if user is None:
                    # Rejected: fold the per-record reason into the totals.
                    for reason in dropped:
                        dropped[reason] += deltas.get(reason, 0)
                    continue
                if seen is not None:
                    pair = (user, assistant)
                    if pair in seen:
                        deduped += 1
                        continue
                    seen.add(pair)
                f.write(json.dumps({"user": user, "assistant": assistant}, ensure_ascii=False) + "\n")
                kept += 1
            pbar.update(len(rows))
            pbar.set_postfix(written=kept, deduped=deduped, missing=dropped["missing_fields"],
                             bad_choices=dropped["bad_choices"], multi=dropped["num_multi"],
                             key_miss=dropped["key_not_found"])

    return {
        "total_rows_in_files": int(total_rows) if total_rows is not None else total_seen,
        "seen_rows": total_seen,
        "written": kept,
        "deduped": deduped,
        "dropped_missing_fields": dropped["missing_fields"],
        "dropped_bad_choices": dropped["bad_choices"],
        "num_multi": dropped["num_multi"],
        "key_not_found_or_bad": dropped["key_not_found"],
    }

In [None]:
# Collect the parquet parts for each split, ordered by filename.
train_parts, valid_parts = (
    sorted(BASE.glob(pattern), key=lambda part: part.name)
    for pattern in (TRAIN_GLOB, VALID_GLOB)
)

# List the inputs that are about to be converted.
for heading, parts in (("Train parts:", train_parts), ("\nValid parts:", valid_parts)):
    print(heading)
    for p in parts:
        print(" -", p)

# Convert each split and keep its summary counts.
train_stats = build_jsonl_from_parquets(train_parts, OUT_TRAIN)
valid_stats = build_jsonl_from_parquets(valid_parts, OUT_VALID)

print("\n--- Done ---")
for split_label, stats in (("Train:", train_stats), ("Valid:", valid_stats)):
    print(split_label, stats)

Train parts:
 - ai2\ai2_train_1.parquet
 - ai2\ai2_train_2.parquet

Valid parts:
 - ai2\ai2_valid_1.parquet
 - ai2\ai2_valid_2.parquet


writing ai2_train.jsonl:   0%|          | 0/3370 [00:00<?, ?rows/s]

writing ai2_valid.jsonl:   0%|          | 0/869 [00:00<?, ?rows/s]


--- Done ---
Train: {'total_rows_in_files': 3370, 'seen_rows': 3370, 'written': 3244, 'deduped': 1, 'dropped_missing_fields': 0, 'dropped_bad_choices': 125, 'num_multi': 0, 'key_not_found_or_bad': 0}
Valid: {'total_rows_in_files': 869, 'seen_rows': 869, 'written': 844, 'deduped': 0, 'dropped_missing_fields': 0, 'dropped_bad_choices': 25, 'num_multi': 0, 'key_not_found_or_bad': 0}


In [9]:
def peek_jsonl(path: Path, k: int = 3):
    """Print a header line followed by the first ``k`` lines of ``path``.

    Trailing whitespace (including the newline) is stripped from each line
    before it is printed.
    """
    print(f"\n--- {path.name} (first {k}) ---")
    with path.open("r", encoding="utf-8") as handle:
        # Pairing with range(k) ends the loop after at most k lines.
        for _, line in zip(range(k), handle):
            print(line.rstrip())

# Preview the first three records of each output file.
for _preview_path in (OUT_TRAIN, OUT_VALID):
    peek_jsonl(_preview_path, 3)


--- ai2_train.jsonl (first 3) ---
{"user": "Which factor will most likely cause a person to develop a fever?", "assistant": "a bacterial population in the bloodstream"}
{"user": "Lichens are symbiotic organisms made of green algae and fungi. What do the green algae supply to the fungi in this symbiotic relationship?", "assistant": "food"}
{"user": "When a switch is used in an electrical circuit, the switch can", "assistant": "stop and start the flow of current."}

--- ai2_valid.jsonl (first 3) ---
{"user": "Which technology was developed most recently?", "assistant": "cellular telephone"}
{"user": "A student hypothesizes that algae are producers. Which question will best help the student determine if this is correct?", "assistant": "Do algae use sunlight to make food?"}
{"user": "Soccer players use their muscle systems to kick a ball into a goal. What organ system coordinates the muscles?", "assistant": "The nervous system"}


In [10]:
def sanity_check_jsonl(path: Path, max_lines: int | None = None):
    """Scan a JSONL file and print counts of malformed or suspicious records.

    Args:
        path: JSONL file to scan.
        max_lines: stop after this many lines; None scans the whole file.

    The printed summary reports:
        scanned: records whose "user" and "assistant" are both strings
        bad_json: lines that fail to parse or are not a JSON object
        nonstring: records where "user" or "assistant" is not a string
        short_user(<5): records whose user text has fewer than 5 characters
        short_asst(<1): records whose assistant text is empty
    """
    n = bad_json = nonstring = short_user = short_asst = 0
    # Managing the bar as a context closes it even when max_lines ends the
    # scan early.
    with path.open("r", encoding="utf-8") as f, \
            tqdm(f, unit="lines", desc=f"scan {path.name}", leave=False) as bar:
        for i, ln in enumerate(bar):
            if max_lines is not None and i >= max_lines:
                break
            try:
                obj = json.loads(ln)
            except (ValueError, RecursionError):
                # json.JSONDecodeError is a ValueError subclass.
                bad_json += 1
                continue
            if not isinstance(obj, dict):
                # Valid JSON such as a list or bare string has no fields to
                # inspect; count it as malformed instead of failing on .get().
                bad_json += 1
                continue
            u, a = obj.get("user", ""), obj.get("assistant", "")
            if not isinstance(u, str) or not isinstance(a, str):
                nonstring += 1
                continue
            if len(u) < 5:
                short_user += 1
            if len(a) < 1:
                short_asst += 1
            n += 1
    print(f"{path.name}: scanned {n} | bad_json={bad_json} nonstring={nonstring} short_user(<5)={short_user} short_asst(<1)={short_asst}")

# Validate both generated files in full.
for _checked_path in (OUT_TRAIN, OUT_VALID):
    sanity_check_jsonl(_checked_path)

scan ai2_train.jsonl: 0lines [00:00, ?lines/s]

ai2_train.jsonl: scanned 3244 | bad_json=0 nonstring=0 short_user(<5)=0 short_asst(<1)=0


scan ai2_valid.jsonl: 0lines [00:00, ?lines/s]

ai2_valid.jsonl: scanned 844 | bad_json=0 nonstring=0 short_user(<5)=0 short_asst(<1)=0
