# Cleanup for OASST1

In [1]:
from pathlib import Path
import json, ast
import numpy as np
import pandas as pd
from tqdm import tqdm

# Base directory; every input and output path below is built relative to it.
ROOT = Path("./")

# Directories holding the source parquet files.
TRAIN_PARQUET = ROOT.joinpath("train_parquet")
VALID_PARQUET = ROOT.joinpath("valid_parquet")

# Parquet files read by the processing step.
TRAIN_IN = TRAIN_PARQUET.joinpath("oasst1_train_1.parquet")
VALID_IN = VALID_PARQUET.joinpath("oasst1_valid_1.parquet")

# Output directories for the generated JSONL files (created if missing).
OUT_TRAIN_DIR = ROOT.joinpath("train_jsonl")
OUT_VALID_DIR = ROOT.joinpath("valid_jsonl")
OUT_TRAIN_DIR.mkdir(parents=True, exist_ok=True)
OUT_VALID_DIR.mkdir(parents=True, exist_ok=True)

# JSONL files written by the processing step.
TRAIN_JSONL = OUT_TRAIN_DIR.joinpath("oasst1_train.jsonl")
VALID_JSONL = OUT_VALID_DIR.joinpath("oasst1_valid.jsonl")

# Register `progress_apply` on pandas objects (used by the helpers below).
tqdm.pandas()


The following are the predefined evaluation weights applied to each question/answer pair. <br/>
Since each question has multiple answers, we score every answer with the weighted factors below <br/>
and then choose the answer with the best score, i.e. the lowest weighted toxicity risk.

In [2]:
# Weight applied to each Detoxify metric when the metrics are combined into a
# single risk value; a larger weight makes that metric contribute more.
W = dict(
    severe_toxicity=6.0,
    threat=5.0,
    identity_attack=4.0,
    toxicity=3.0,
    insult=2.0,
    obscene=1.0,
    sexual_explicit=1.0,
)

# Metric names, in the same order as the weight table.
METRIC_KEYS = [*W]


In [3]:
def maybe_parse_obj(x):
    """Best-effort decode of a serialized dict/list; anything else passes through.

    Non-string values are returned unchanged. A string whose stripped form is
    wrapped in ``{...}`` or ``[...]`` is parsed first as JSON and, failing
    that, as a Python literal. If neither parser succeeds, or the string does
    not look like a container, the original value is returned as-is.
    """
    if not isinstance(x, str):
        return x

    text = x.strip()
    braces = text.startswith("{") and text.endswith("}")
    brackets = text.startswith("[") and text.endswith("]")
    if not (braces or brackets):
        return x

    # Try the strict parser first, then the more permissive Python-literal one.
    for parse in (json.loads, ast.literal_eval):
        try:
            return parse(text)
        except Exception:
            continue
    return x

def normalize_role(x):
    """Return a string role trimmed and lower-cased; non-strings become None."""
    if not isinstance(x, str):
        return None
    cleaned = str(x).strip()
    return cleaned.lower()

def pick_text_column(df):
    """Return the name of the first known text-bearing column present in ``df``.

    Candidates are checked in priority order: text, content, body, message,
    reply_text.

    Raises:
        RuntimeError: if none of the candidate columns exists in ``df``.
    """
    candidates = ("text", "content", "body", "message", "reply_text")
    match = next((name for name in candidates if name in df.columns), None)
    if match is None:
        raise RuntimeError("No text/content column found (looked for: text, content, body, message, reply_text).")
    return match

def extract_detoxify_metrics(df):
    """Collect the Detoxify metrics of ``df`` and derive a weighted risk column.

    For every name in ``METRIC_KEYS`` the value is looked up, in order of
    preference, from a flat ``detoxify.<name>`` column, from a nested
    ``detoxify`` column (dict, or a string that parses to one), or from a
    plain ``<name>`` column. A metric found nowhere is filled with NaN.
    Values are coerced to float and clipped to [0, 1].

    Returns:
        A DataFrame sharing ``df``'s index, with one column per metric plus
        ``risk``: the sum of each metric multiplied by its weight in ``W``.
        Because NaN propagates through the sum, a row missing any metric has
        a NaN risk.
    """
    out = pd.DataFrame(index=df.index)
    has_nested = "detoxify" in df.columns
    nested = df["detoxify"].progress_apply(maybe_parse_obj) if has_nested else None

    for metric in METRIC_KEYS:
        flat_name = f"detoxify.{metric}"

        # Choose the raw source for this metric; None means "not available".
        if flat_name in df.columns:
            raw = df[flat_name]
        elif has_nested:
            def _lookup(entry, key=metric):
                # Non-dict entries (e.g. unparsed strings, None) carry no metric.
                return entry.get(key) if isinstance(entry, dict) else np.nan

            raw = nested.progress_apply(_lookup)
        elif metric in df.columns:
            raw = df[metric]
        else:
            raw = None

        if raw is None:
            scores = pd.Series(np.nan, index=df.index)
        else:
            scores = pd.to_numeric(raw, errors="coerce")
        out[metric] = scores.astype(float).clip(lower=0.0, upper=1.0)

    # Weighted sum, accumulated left to right in the order of W.
    weighted = [out[name] * weight for name, weight in W.items()]
    total = None
    for term in weighted:
        total = term if total is None else (total + term)
    out["risk"] = total
    return out

The dataset also contains question/answer pairs in many different languages. <br/>
Since our pretrained model was trained on English only, <br/>
we keep only the messages tagged `en` and then apply the scoring and selection to those.

In [4]:
def process_oasst_file(in_path: Path, out_jsonl: Path,
                       lang_keep: str = "en",
                       only_root_prompts: bool = True,
                       require_answer_lang_match: bool = True):
    """Turn one OASST parquet file into (user, assistant) pairs and write them as JSONL.

    For every selected user prompt, the direct assistant reply with the lowest
    weighted Detoxify risk is kept. Each output line is a JSON object with the
    keys ``user`` and ``assistant``.

    Args:
        in_path: Parquet file to read.
        out_jsonl: Destination JSONL file (overwritten).
        lang_keep: Language tag to keep, compared case-insensitively against
            the ``lang`` column. A falsy value disables language filtering.
        only_root_prompts: If True, keep only user messages whose parent id
            is null.
        require_answer_lang_match: If True (and ``lang_keep`` is set), the
            assistant reply must carry the same language tag.

    Returns:
        A 4-tuple ``(pairs, df, users, asst)``:
        ``pairs`` - one row per prompt with its chosen reply (columns uid,
        user, user_lang, aid, assistant, risk); ``df`` - the loaded frame
        with the helper and metric columns added; ``users`` - the selected
        prompts; ``asst`` - all candidate replies after filtering.
        The same 4-tuple shape is returned when no candidates survive
        filtering (``pairs`` is then empty and an empty file is written).

    Raises:
        RuntimeError: if no text column, id column, parent column or role
            column can be found.
    """
    print(f"\n=== Processing {in_path.name} ===")
    df = pd.read_parquet(in_path)

    # schema: accept the alternative column names for id / parent
    id_col = next((c for c in ("message_id", "id") if c in df.columns), None)
    parent_col = next((c for c in ("parent_id", "reply_to") if c in df.columns), None)
    role_col = "role" if "role" in df.columns else None
    text_col = pick_text_column(df)
    if id_col is None or parent_col is None or role_col is None:
        raise RuntimeError(f"Missing required columns. Found: {list(df.columns)}")

    # The lang column is lower-cased below, so the requested tag must be
    # normalized the same way or e.g. "EN" would never match.
    if isinstance(lang_keep, str):
        lang_keep = lang_keep.strip().lower()

    # roles + risk
    df["_role"] = df[role_col].map(normalize_role)
    detox_df = extract_detoxify_metrics(df)
    df = pd.concat([df, detox_df], axis=1)

    # language filters
    if "lang" in df.columns and lang_keep:
        df["lang_norm"] = df["lang"].astype(str).str.strip().str.lower()
    else:
        if lang_keep:
            # Without a lang column nothing can match the requested language.
            print(f"Warning: no 'lang' column found; language filter {lang_keep!r} will match no rows.")
        df["lang_norm"] = None

    # users: prompter/user (optionally only those with parent_id null)
    users_mask = df["_role"].isin(["prompter", "user"])
    if only_root_prompts:
        # NaN/None parent_id means root
        users_mask = users_mask & df[parent_col].isna()
    if lang_keep:
        users_mask = users_mask & (df["lang_norm"] == lang_keep)

    users = df.loc[users_mask, [id_col, text_col, "lang_norm"]].rename(
        columns={id_col: "uid", text_col: "user", "lang_norm": "user_lang"}
    )

    # assistants: direct children of those user uids
    is_asst = df["_role"].isin(["assistant", "assistant_reply", "assistant_response", "assistant_bot"])
    asst = df.loc[is_asst, [id_col, parent_col, text_col, "risk", "lang_norm"]].rename(
        columns={id_col: "aid", parent_col: "uid", text_col: "assistant", "lang_norm": "asst_lang"}
    )

    # keep only replies to our selected users
    asst = asst.merge(users[["uid"]], on="uid", how="inner")

    # drop empty assistant text and, if requested, replies in another language
    keep = asst["assistant"].notna() & (asst["assistant"].astype(str).str.len() > 0)
    if require_answer_lang_match and lang_keep:
        keep = keep & (asst["asst_lang"] == lang_keep)
    # .copy() so the column assignment below works on an independent frame
    # instead of a slice (avoids SettingWithCopyWarning).
    asst = asst.loc[keep].copy()

    # missing metrics → +inf so they lose if safer reply exists
    asst["risk"] = pd.to_numeric(asst["risk"], errors="coerce").fillna(np.inf)

    if asst.empty or users.empty:
        print("No candidates after filtering.")
        # still emit an empty file
        out_jsonl.write_text("", encoding="utf-8")
        # Same columns as the normal result, zero rows, and the same 4-tuple
        # shape so callers can always unpack the return value.
        pairs = users.iloc[0:0].assign(aid=None, assistant=None, risk=np.nan)
        return pairs, df, users, asst

    # min-risk reply per prompt
    best = asst.loc[asst.groupby("uid")["risk"].idxmin()]
    pairs = users.merge(best[["uid", "aid", "assistant", "risk"]], on="uid", how="inner")

    # write JSONL
    with out_jsonl.open("w", encoding="utf-8") as f:
        rows = zip(pairs["user"], pairs["assistant"])
        for user_text, asst_text in tqdm(rows, total=len(pairs), desc=f"Writing {out_jsonl.name}"):
            json.dump({"user": user_text, "assistant": asst_text}, f, ensure_ascii=False)
            f.write("\n")

    print({
        "users_kept": len(users),
        "assistant_candidates": len(asst),
        "final_pairs": len(pairs),
        # Missing risks were replaced by +inf above, so count finite values;
        # a plain notna() check would always report 1.0 here.
        "risk_non_nan_frac": float(np.isfinite(pairs["risk"]).mean()) if len(pairs) else 0.0,
    })
    return pairs, df, users, asst


In [5]:
# Training split: English root prompts paired with their lowest-risk English reply.
pairs_train, df_train, users_train, asst_train = process_oasst_file(
    in_path=TRAIN_IN,
    out_jsonl=TRAIN_JSONL,
    lang_keep="en",
    only_root_prompts=True,
    require_answer_lang_match=True,
)
display(pairs_train.loc[:, ["user", "assistant", "risk"]].head(5))

# Validation split: same settings as the training split.
pairs_valid, df_valid, users_valid, asst_valid = process_oasst_file(
    in_path=VALID_IN,
    out_jsonl=VALID_JSONL,
    lang_keep="en",
    only_root_prompts=True,
    require_answer_lang_match=True,
)
display(pairs_valid.loc[:, ["user", "assistant", "risk"]].head(5))



=== Processing oasst1_train_1.parquet ===


100%|██████████| 84437/84437 [00:00<00:00, 1888440.99it/s]
100%|██████████| 84437/84437 [00:00<00:00, 1785367.61it/s]
100%|██████████| 84437/84437 [00:00<00:00, 1790232.06it/s]
100%|██████████| 84437/84437 [00:00<00:00, 1698305.54it/s]
100%|██████████| 84437/84437 [00:00<00:00, 1812635.04it/s]
100%|██████████| 84437/84437 [00:00<00:00, 1746332.31it/s]
100%|██████████| 84437/84437 [00:00<00:00, 1807832.81it/s]
100%|██████████| 84437/84437 [00:00<00:00, 1736979.30it/s]
Writing oasst1_train.jsonl: 100%|██████████| 3482/3482 [00:00<00:00, 20278.88it/s]

{'users_kept': 3482, 'assistant_candidates': 10894, 'final_pairs': 3482, 'risk_non_nan_frac': 1.0}





Unnamed: 0,user,assistant,risk
0,Can you write a short introduction about the r...,"""Monopsony"" refers to a market structure where...",0.00201
1,Can you explain contrastive learning in machin...,Sure! Let's say you want to build a model whic...,0.002268
2,I want to start doing astrophotography as a ho...,Getting started in astrophotography can seem d...,0.002741
3,Can you give me an example of a python script ...,Sure! Here's an example Python script that use...,0.002181
4,How can I learn to optimize my webpage for sea...,Learning to optimize your webpage for search e...,0.002405



=== Processing oasst1_valid_1.parquet ===


100%|██████████| 4401/4401 [00:00<00:00, 2187878.62it/s]
100%|██████████| 4401/4401 [00:00<00:00, 1516773.37it/s]
100%|██████████| 4401/4401 [00:00<00:00, 663091.17it/s]
100%|██████████| 4401/4401 [00:00<00:00, 994672.48it/s]
100%|██████████| 4401/4401 [00:00<00:00, 1720649.88it/s]
100%|██████████| 4401/4401 [00:00<00:00, 1062212.68it/s]
100%|██████████| 4401/4401 [00:00<00:00, 1432272.80it/s]
100%|██████████| 4401/4401 [00:00<00:00, 1161024.71it/s]
Writing oasst1_valid.jsonl: 100%|██████████| 188/188 [00:00<00:00, 17307.87it/s]

{'users_kept': 188, 'assistant_candidates': 568, 'final_pairs': 188, 'risk_non_nan_frac': 1.0}





Unnamed: 0,user,assistant,risk
0,What do you think about ChatGPT?,"As an open source alternative to ChatGPT, I do...",0.001962
1,Can you please provide me the names of the two...,The atomic Bomb Go game took place in the subu...,0.003174
2,How would the Future of AI in 10 Years look?,"I am a Large Language Model (LLM), so I do not...",0.00194
3,How can L’Hopital’s Rule be used in day to day...,I'm always thrilled to talk about how math can...,0.002137
4,How do I build a PC?,Here are the steps you can follow to build a P...,0.002322


In [6]:
def inspect_prompt(df, users_df, asst_df, uid: str, id_col="message_id", parent_col="parent_id", text_col=None):
    """Display one prompt and its candidate assistant replies ordered by risk.

    Args:
        df: Full message frame; must contain ``id_col``, ``parent_col``,
            ``role``, ``lang`` and the text column.
        users_df: Accepted for call-site symmetry; not read by this function.
        asst_df: Candidate replies with columns aid, uid, risk, asst_lang
            and assistant.
        uid: Id of the prompt to inspect.
        id_col: Name of the message-id column in ``df``.
        parent_col: Name of the parent-id column in ``df``.
        text_col: Name of the text column; detected with
            ``pick_text_column`` when None.
    """
    text_col = pick_text_column(df) if text_col is None else text_col

    print("\n— USER —")
    prompt_columns = [id_col, parent_col, "role", "lang", text_col]
    is_prompt = df[id_col] == uid
    display(df.loc[is_prompt, prompt_columns])

    print("\n— ASSISTANT CANDIDATES (sorted by risk) —")
    replies = asst_df[asst_df["uid"] == uid]
    ranked = replies.sort_values("risk")
    # Show at most the ten lowest-risk candidates.
    display(ranked[["aid", "uid", "risk", "asst_lang", "assistant"]].head(10))
    
# Spot-check one training prompt and the replies that competed for it.
inspect_prompt(
    df=df_train,
    users_df=users_train,
    asst_df=asst_train,
    uid="6ab24d72-0181-4594-a9cd-deaf170242fb",
)



— USER —


Unnamed: 0,message_id,parent_id,role,lang,text
0,6ab24d72-0181-4594-a9cd-deaf170242fb,,prompter,en,Can you write a short introduction about the r...



— ASSISTANT CANDIDATES (sorted by risk) —


Unnamed: 0,aid,uid,risk,asst_lang,assistant
0,c8e83833-ecbc-44fe-b6db-735228c25a1c,6ab24d72-0181-4594-a9cd-deaf170242fb,0.00201,en,"""Monopsony"" refers to a market structure where..."
2,636dd191-50df-4894-ba9a-cd7f00767258,6ab24d72-0181-4594-a9cd-deaf170242fb,0.002252,en,Monopsony refers to a market structure where t...
1,343ee2d4-87ae-41fd-a768-bdd65959dc4a,6ab24d72-0181-4594-a9cd-deaf170242fb,0.006789,en,Monopsony is a market structure in which there...
