
# XDetox Pipeline
This notebook mirrors the `lab.py` pipeline with the following changes:

- **Small-batch runs**: choose how many examples to process.
- **Dataset picker**: run a single dataset or **all**.
- **Optional evaluation** (BLEU / BERTScore / Perplexity / MeaningBERT / Toxicity).
- **Same output layout** as `lab.py`.
- **Reranking** toggle.

> **Prereqs**: The `XDetox` repo must be available on Google Drive or cloned locally in Colab. Set `XDETOX_DIR` in the first code cell below to its location.


In [1]:
#@title Mount Drive & locate XDetox
from google.colab import drive; drive.mount('/content/drive')
import os, glob, re, sys, torch, json, shutil, math, nltk
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE

# Expected location of the repo on the mounted Drive ("My Drive").
candidate = "/content/drive/MyDrive/w266 - Project/XDetox"
candidate_found = os.path.isdir(candidate)
print("Try MyDrive:", candidate, "->", candidate_found)

# The rest of the notebook reads the repo root from XDETOX_DIR.
XDETOX_DIR = candidate
print("Using XDETOX_DIR:", XDETOX_DIR)
# Stop here if the folder is missing; every later cell depends on it.
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"



Mounted at /content/drive
Try MyDrive: /content/drive/MyDrive/w266 - Project/XDetox -> True
Using XDETOX_DIR: /content/drive/MyDrive/w266 - Project/XDetox


In [2]:
#@title Runtime setup (paths, cache, GPU)
# Keep the HuggingFace cache inside the repo folder so it persists on Drive.
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE

# Make the repo importable from this kernel (added only once).
if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

# Report the effective configuration.
print("XDETOX_DIR:", XDETOX_DIR)
print("TRANSFORMERS_CACHE:", HF_CACHE)
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)
if cuda_ok:
    print("GPU:", torch.cuda.get_device_name(0))


XDETOX_DIR: /content/drive/MyDrive/w266 - Project/XDetox
TRANSFORMERS_CACHE: /content/drive/MyDrive/w266 - Project/XDetox/cache
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB


In [3]:
#@title Install dependencies (restart runtime if warnings/errors appear)
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi

# BERTScore dependency required by evaluate/datasets
!pip -q install bert-score

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m77.2 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m76.6 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.34.0 requires jedi>=0.16, which is not installed.[0m[31m
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.3.1 which is incompatible.[0m[

In [4]:

#@title NLTK data
# nltk.download() returns a boolean status; with quiet=True a failed download
# prints nothing, so check the result instead of always reporting success.
punkt_ok = nltk.download("punkt", quiet=True)
# Some Colab images need this table; it is best-effort, so never fail on it:
try:
    punkt_tab_ok = nltk.download("punkt_tab", quiet=True)
except Exception as exc:
    punkt_tab_ok = False
    print("punkt_tab download skipped:", exc)

if punkt_ok or punkt_tab_ok:
    print("NLTK ready")
else:
    print("WARNING: could not download NLTK tokenizer data ('punkt' / 'punkt_tab').")



NLTK ready


In [5]:
#@title Verify XDetox repo layout
# Sub-folders of the repo that the later cells rely on.
required_dirs = ("rewrite", "evaluation", "datasets")
for d in required_dirs:
    dir_path = os.path.join(XDETOX_DIR, d)
    assert os.path.isdir(dir_path), f"Missing folder: {d}"
print("Repo folders OK.")


Repo folders OK.


In [6]:

#@title Data configs (same as lab.py)
def _dataset_config(data_path, rep_penalty=1.0, alpha_a=1.5, alpha_e=4.75, temperature=2.5):
    """Build one dataset entry; the defaults are the values shared by most datasets."""
    return {
        "data_path": data_path,
        "rep_penalty": rep_penalty,
        "alpha_a": alpha_a,
        "alpha_e": alpha_e,
        "temperature": temperature,
    }

# Per-dataset input file and generation hyperparameters, keyed by dataset name.
# Only the values that differ from the shared defaults are spelled out.
data_configs = {
    "microagressions_val": _dataset_config("./datasets/microagressions/val.csv", alpha_e=4.25),
    "microagressions_test": _dataset_config("./datasets/microagressions/test.csv", alpha_e=4.25),
    "sbf_val": _dataset_config(
        "./datasets/sbf/sbfdev.csv", rep_penalty=1.5, alpha_e=5.0, temperature=2.9
    ),
    "sbf_test": _dataset_config(
        "./datasets/sbf/sbftst.csv", rep_penalty=1.5, alpha_e=5.0, temperature=2.9
    ),
    "dynabench_val": _dataset_config("./datasets/dynabench/db_dev.csv"),
    "dynabench_test": _dataset_config("./datasets/dynabench/db_test.csv"),
    "jigsaw_toxic": _dataset_config("./datasets/jigsaw_full_30/test_10k_toxic.txt"),
    "paradetox": _dataset_config("./datasets/paradetox/test_toxic_parallel.txt"),
    "appdia_original": _dataset_config(
        "./datasets/appdia/original-annotated-data/original-test.tsv"
    ),
    "appdia_discourse": _dataset_config(
        "./datasets/appdia/discourse-augmented-data/discourse-test.tsv"
    ),
}
print("Datasets:", ", ".join(data_configs.keys()))


Datasets: microagressions_val, microagressions_test, sbf_val, sbf_test, dynabench_val, dynabench_test, jigsaw_toxic, paradetox, appdia_original, appdia_discourse


### Helpers: subset data, call rewrite, run evaluation

In [7]:
REPO = XDETOX_DIR  # repo root used by the helpers below for paths and subprocess cwd

def _abs_repo_path(rel):
    """Resolve a repo-relative path such as "./datasets/x.csv" against REPO.

    Only a literal leading "./" prefix is removed. The previous
    `rel.lstrip("./")` stripped any run of '.' and '/' characters, which
    mangled names like "../x" or ".hidden" and re-rooted absolute paths
    under REPO. Absolute paths are now returned unchanged.
    """
    if os.path.isabs(rel):
        return rel
    while rel.startswith("./"):
        # Drop the "./" prefix and any duplicate slashes that follow it.
        rel = rel[2:].lstrip("/")
    return os.path.join(REPO, rel)

def _ensure_dir(p):
    Path(p).mkdir(parents=True, exist_ok=True)

def _subset_for_data_type(data_type, data_path, n, out_dir):
    """Write the first `n` examples of a dataset file into `out_dir`.

    The file format is chosen from substrings of `data_path`:
      * "microagressions", "sbf", "dynabench"  -> CSV with a header row
      * "paradetox", "jigsaw"                  -> one example per line (.txt) or CSV (.csv)
      * "appdia"                               -> tab-separated file with a header row
      * anything else                          -> the file is copied unchanged

    Args:
        data_type: dataset key; not used for format detection, kept for callers.
        data_path: repo-relative path of the source file.
        n: number of leading examples to keep; None or <= 0 disables subsetting.
        out_dir: directory that receives the subset (created if missing).

    Returns:
        `data_path` unchanged when no subset is requested, otherwise the path of
        the new file in `out_dir` (same base name as the source).
    """
    if n is None or n <= 0:
        return data_path  # no subset

    src = _abs_repo_path(data_path)
    _ensure_dir(out_dir)
    out = os.path.join(out_dir, os.path.basename(src))

    is_csv_dataset = any(k in data_path for k in ("microagressions", "sbf", "dynabench"))
    is_line_dataset = any(k in data_path for k in ("paradetox", "jigsaw"))

    if is_csv_dataset or (is_line_dataset and data_path.endswith(".csv")):
        # nrows parses only the rows that are kept instead of the whole file.
        pd.read_csv(src, nrows=n).to_csv(out, index=False)
    elif is_line_dataset and data_path.endswith(".txt"):
        # Read the first n lines only; finish reading before opening the output.
        with open(src, "r") as f:
            head = [line.rstrip("\n") for _, line in zip(range(n), f)]
        with open(out, "w") as g:
            g.writelines(s + "\n" for s in head)
    elif "appdia" in data_path:
        pd.read_csv(src, sep="\t", nrows=n).to_csv(out, sep="\t", index=False)
    else:
        # Fallback: unknown format, copy the original file as-is.
        shutil.copy(src, out)
    return out

def _parse_run_folder_name(folder_name):
    pattern = r"aa(\d+\.\d+)_ae(\d+\.\d+)_ab(\d+\.\d+)_base(.*?)_anti(.*?)_expert(.*?)_temp(\d+\.\d+)_sample(.*?)_topk(\d+)_reppenalty(\d+\.\d+)_filterp(\d+\.\d+)_maxlength(\d+)_topp(\d+\.\d+)"
    m = re.match(pattern, folder_name)
    return bool(m)

def _eval_with_toxicity(base_path, overwrite_eval=False, skip_ref=False, tox_threshold=0.5, tox_batch_size=32):
    """Run `evaluation.evaluate_all` for each recognised run folder under `base_path`.

    A sub-folder is evaluated when its name matches the run-folder pattern and it
    contains both `orig.txt` and `gen.txt`. Folders that already hold
    `gen_stats.txt` are skipped unless `overwrite_eval` is set. The evaluation
    runs as a subprocess from the repo root; if it exits with a non-zero code,
    its captured stdout/stderr are printed and CalledProcessError is raised.
    """
    import sys, os
    for entry in os.listdir(base_path):
        run_dir = os.path.join(base_path, entry)
        is_run_folder = os.path.isdir(run_dir) and _parse_run_folder_name(entry)
        if not is_run_folder:
            continue

        orig_path = os.path.join(run_dir, "orig.txt")
        gen_path = os.path.join(run_dir, "gen.txt")
        stats_path = os.path.join(run_dir, "gen_stats.txt")

        inputs_present = os.path.exists(orig_path) and os.path.exists(gen_path)
        if not inputs_present:
            continue
        already_evaluated = os.path.exists(stats_path) and not overwrite_eval
        if already_evaluated:
            continue

        # Put the repo first on the child's PYTHONPATH, keeping any existing value.
        child_env = os.environ.copy()
        inherited = child_env.get("PYTHONPATH")
        child_env["PYTHONPATH"] = f"{REPO}:{inherited}" if inherited else REPO

        cmd = [sys.executable, "-m", "evaluation.evaluate_all"]
        cmd += ["--orig_path", orig_path]
        cmd += ["--gen_path", gen_path]
        cmd += ["--tox_threshold", str(tox_threshold)]
        cmd += ["--tox_batch_size", str(tox_batch_size)]
        if skip_ref:
            cmd.append("--skip_ref")

        print("Eval:", " ".join(cmd))
        res = run(cmd, cwd=REPO, env=child_env, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            # Show the captured output before raising so the failure is visible.
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()


def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return float('nan')

def _read_stats_file(path):
    """Read a `key: value` stats file (e.g. gen_stats.txt) into a dict of floats.

    Keys are lower-cased and a "(skipped)" marker is removed, so
    'perplexity orig (skipped): None' becomes {'perplexity orig': nan}.
    Values that are not numbers become NaN. Lines without a colon are ignored.

    A line that contains ":" but not ": " (for example "key:" or "key:value")
    used to raise ValueError while unpacking the split; it is now parsed by
    splitting on the first ":" instead.
    """
    out = {}
    with open(path, "r") as f:
        for raw in f:
            line = raw.strip()
            # Prefer the "key: value" form; fall back to a bare ":" separator.
            key, sep, value = line.partition(": ")
            if not sep:
                key, sep, value = line.partition(":")
            if not sep:
                continue
            # normalize keys like 'perplexity orig (skipped)' -> 'perplexity orig'
            key = key.replace("(skipped)", "").strip().lower()
            out[key] = _safe_float(value.strip())
    return out

def _aggregate_eval_csv(output_folder, data_type, base_out_dir, thresholds=None):
    """Collect every `gen_stats.txt` of a dataset into one summary CSV.

    Looks under `{base_out_dir}/{data_type}/DecompX<threshold>/<run folder>/`
    and writes `{base_out_dir}/{data_type}/{data_type}.csv` with one row per
    run folder that has a stats file.

    Args:
        output_folder: run name; unused here, kept for existing callers.
        data_type: dataset key, used as folder name and CSV base name.
        base_out_dir: directory that contains the `{data_type}` folder.
        thresholds: optional iterable of thresholds to summarise. When None
            (default) all `DecompX<number>` folders found on disk are used.
            Previously the thresholds were hard-coded to 0.15/0.20/0.25, so
            runs at any other threshold never appeared in the summary.
    """
    data_root = os.path.join(base_out_dir, data_type)
    prefix = "DecompX"

    if thresholds is None:
        # Discover the threshold folders that actually exist.
        found = []
        if os.path.isdir(data_root):
            for name in os.listdir(data_root):
                if not name.startswith(prefix):
                    continue
                if not os.path.isdir(os.path.join(data_root, name)):
                    continue
                try:
                    found.append((float(name[len(prefix):]), name))
                except ValueError:
                    continue  # not a DecompX<number> folder
        mask_dirs = sorted(found)
    else:
        # Same folder naming as used when the runs are launched.
        mask_dirs = [
            (float(f"{t:.2f}"), f"{prefix}{abs(t):g}" if t != 0 else "DecompX0.0")
            for t in thresholds
        ]

    rows = []
    for thresh, mask_dir in mask_dirs:
        base_path = os.path.join(data_root, mask_dir)
        if not os.path.isdir(base_path):
            continue
        # Sorted so the row order does not depend on directory listing order.
        for folder in sorted(os.listdir(base_path)):
            stats_path = os.path.join(base_path, folder, "gen_stats.txt")
            if not os.path.exists(stats_path):
                continue
            s = _read_stats_file(stats_path)
            rows.append({
                "threshold":        thresh,
                "folder":           folder,
                "bertscore":        s.get("bertscore", np.nan),
                "meaningbert":      s.get("meaningbert", np.nan),
                "bleu4":            s.get("bleu4", np.nan),
                "perplexity_gen":   s.get("perplexity gen", np.nan),
                "perplexity_orig":  s.get("perplexity orig", np.nan),
                "toxicity_gen":     s.get("toxicity gen", np.nan),
                "toxicity_orig":    s.get("toxicity orig", np.nan),
                # Optional: you can also keep the percent toxic columns
                # "percent_toxic_gen": s.get("percent toxic gen", np.nan),
                # "percent_toxic_ref": s.get("percent toxic ref", np.nan),
            })

    if rows:
        # Column order: MeaningBERT between BERTScore and BLEU4
        cols = [
            "threshold", "folder",
            "bertscore", "meaningbert", "bleu4",
            "perplexity_gen", "perplexity_orig",
            "toxicity_gen", "toxicity_orig",
            # "percent_toxic_gen", "percent_toxic_ref",
        ]
        df = pd.DataFrame(rows, columns=cols)
        out_csv = os.path.join(data_root, f"{data_type}.csv")
        df.to_csv(out_csv, index=False)
        print("Wrote summary CSV:", out_csv)
    else:
        print("No evaluation files found to summarize.")


### `detoxify()` — run masking + generation (+ optional eval)

In [8]:

def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run",
    thresholds = (0.15, 0.20, 0.25),
    batch_size: int = 10,
    ranking: bool = True,
    sample: bool = True,
    top_k_gen: int = 50,
    top_p: float = 0.95,
    filter_p: float = 1.0,
    max_length: int = 128,
    alpha_a: float = None,   # if None, take from data_configs
    alpha_e: float = None,   # if None, take from data_configs
    alpha_b: float = 1.0,
    temperature: float = None,  # if None, from data_configs
    rep_penalty: float = None,  # if None, from data_configs
    num_examples: int = 100,    # small-batch control; set None to use full dataset
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    tox_threshold: float = 0.5,
    tox_batch_size: int = 32,
):
    """Run the XDetox rewrite step per threshold, with small-batch support and optional evaluation.

    For every threshold, `rewrite.rewrite_example` is launched as a subprocess
    from the repo root. With `run_eval=True` the evaluation (which includes
    toxicity) is run on the resulting folders and a summary CSV is written.

    Args:
        data_type: key of `data_configs` selecting the dataset.
        output_folder: run name under `data/dexp_outputs/`.
        thresholds: iterable of thresholds passed as `--thresh`; a single
            number is accepted as well.
        batch_size, top_k_gen, top_p, filter_p, max_length, alpha_b:
            forwarded unchanged to the rewrite command.
        ranking, sample, overwrite_gen: add the `--ranking`, `--sample` and
            `--overwrite_gen` flags when True.
        alpha_a, alpha_e, temperature, rep_penalty: forwarded to the rewrite
            command; None means "use the value from `data_configs`".
        num_examples: run on the first N examples only; None (or <= 0) uses
            the full dataset file.
        run_eval: run the evaluation after each threshold and summarise at the end.
        overwrite_eval: re-run evaluation even if a stats file already exists.
        skip_ref_eval: pass `--skip_ref` to the evaluation.
        tox_threshold, tox_batch_size: toxicity settings passed to the evaluation.

    Raises:
        AssertionError: if `data_type` is not in `data_configs`.
        subprocess.CalledProcessError: if a rewrite or evaluation subprocess fails.
    """
    assert data_type in data_configs, f"Unknown data_type: {data_type}"
    cfg = data_configs[data_type].copy()

    # Accept a bare number as a single threshold instead of failing in the loop.
    if isinstance(thresholds, (int, float)):
        thresholds = (thresholds,)

    # fallbacks from data_configs
    if alpha_a is None: alpha_a = cfg["alpha_a"]
    if alpha_e is None: alpha_e = cfg["alpha_e"]
    if temperature is None: temperature = cfg["temperature"]
    if rep_penalty is None: rep_penalty = cfg["rep_penalty"]

    base_out_dir = os.path.join("data", "dexp_outputs", output_folder)
    abs_base_out_dir = os.path.join(REPO, base_out_dir)
    _ensure_dir(abs_base_out_dir)

    # Small-subset file (the helper creates the folder only when it subsets).
    subset_dir = os.path.join(REPO, "datasets", "_subsets", data_type)
    subset_path = _subset_for_data_type(data_type, cfg["data_path"], num_examples, subset_dir)
    # The rewrite module runs with cwd=REPO, so hand it a repo-relative path.
    rel_data_path = subset_path.replace(REPO + "/", "./")

    # run thresholds
    for t in thresholds:
        mask_dir = f"DecompX{abs(t):g}" if t != 0 else "DecompX0.0"
        mix_out_dir = os.path.join(abs_base_out_dir, data_type, mask_dir)
        _ensure_dir(mix_out_dir)

        # Build command to call rewrite module
        cmd = [
            sys.executable, "-m", "rewrite.rewrite_example",
            "--output_dir", base_out_dir,           # relative to REPO
            "--data_type", data_type,
            "--data_path", rel_data_path,
            "--rep_penalty", str(rep_penalty),
            "--alpha_a", str(alpha_a),
            "--alpha_e", str(alpha_e),
            "--temperature", str(temperature),
            "--alpha_b", str(alpha_b),
            "--max_length", str(max_length),
            "--batch_size", str(batch_size),
            "--top_k_gen", str(top_k_gen),
            "--top_p", str(top_p),
            "--filter_p", str(filter_p),
            "--thresh", f"{t:.2f}",
        ]
        if ranking:
            cmd.append("--ranking")
        if sample:
            cmd.append("--sample")
        if overwrite_gen:
            cmd.append("--overwrite_gen")

        print("Run:", " ".join(cmd))
        run(cmd, cwd=REPO, check=True)

        # Optional evaluation (includes toxicity)
        if run_eval:
            _eval_with_toxicity(
                mix_out_dir,
                overwrite_eval=overwrite_eval,
                skip_ref=skip_ref_eval,
                tox_threshold=tox_threshold,
                tox_batch_size=tox_batch_size,
            )

    # Summarize metrics across thresholds
    if run_eval:
        _aggregate_eval_csv(output_folder, data_type, abs_base_out_dir)


### `detoxify()` — parameters

- **data_type** *(str)*: Which dataset config to use. Must be one of the keys in `data_configs` (e.g., `paradetox`, `dynabench_val`).
- **output_folder** *(str)*: Top-level run name under `data/dexp_outputs/`. All results are saved here.

- **thresholds** *(tuple[float])*: DecompX masking thresholds to try (e.g., `(0.15, 0.20, 0.25)`). Higher → less masking.
- **batch_size** *(int)*: Generation batch size (lower for smaller GPUs like T4).
- **ranking** *(bool)*: If `True`, generate multiple candidates per input and select the least toxic via DecompX-based re-ranking.
- **sample** *(bool)*: If `True`, enable sampling; if `False`, greedy decoding.
- **top_k_gen** *(int)*: Top-k filter for sampling on ensembled logits (used when `sample=True`).
- **top_p** *(float)*: Nucleus (top-p) sampling value on ensembled logits (used when `sample=True`).
- **filter_p** *(float)*: Nucleus (top-p) filter applied to the *base* model logits before ensembling (advanced; usually leave at `1.0`).
- **max_length** *(int)*: Maximum generated length (tokens).

- **alpha_a / alpha_e / alpha_b** *(float)*: Weights for anti-expert, expert, and base logits in the product-of-experts. If `alpha_a` or `alpha_e` is `None`, the dataset default from `data_configs` is used; `alpha_b` defaults to `1.0`.
- **temperature** *(float)*: Softens logits for sampling. Higher → more random. If `None`, uses dataset default.
- **rep_penalty** *(float)*: Repetition penalty (1.0 = off). If `None`, uses dataset default.

- **num_examples** *(int | None)*: If set, run on the first *N* examples (creates a subset file). Use `None` to process the full dataset.
- **overwrite_gen** *(bool)*: Re-run generation even if `gen.txt` exists.

- **run_eval** *(bool)*: If `True`, run evaluation (BLEU, BERTScore, MeaningBERT, Perplexity, Toxicity) and write `gen_stats.txt` in each run folder.
- **overwrite_eval** *(bool)*: Re-compute evaluation even if stats exist.
- **skip_ref_eval** *(bool)*: If `True`, skip computing perplexity on references.

**Outputs**
- Files saved under `data/dexp_outputs/{output_folder}/{data_type}/DecompX{thresh}/.../`:
  - `masked_inputs.txt`, `orig.txt`, `gen.txt`
  - `gen_stats.txt` (if `run_eval=True`)
  - `{data_type}.csv` summary across thresholds (if `run_eval=True`)


### Example run

In [9]:

# Example: paradetox, first 50 lines, sampling + ranking on
# detoxify(
#     data_type="paradetox",
#     output_folder="colab_run_demo",
#     thresholds=(0.20,),         # single threshold for quick test
#     batch_size=8,
#     ranking=True,
#     sample=True,
#     top_k_gen=50,
#     top_p=0.95,
#     max_length=96,
#     num_examples=50,            # small subset
#     run_eval=True,              # compute BLEU/BERTScore/Perplexity/Toxicity
#     skip_ref_eval=False
# )


Run: /usr/bin/python3 -m rewrite.rewrite_example --output_dir data/dexp_outputs/colab_run_demo --data_type paradetox --data_path ./datasets/_subsets/paradetox/test_toxic_parallel.txt --rep_penalty 1.0 --alpha_a 1.5 --alpha_e 4.75 --temperature 2.5 --alpha_b 1.0 --max_length 96 --batch_size 8 --top_k_gen 50 --top_p 0.95 --filter_p 1.0 --thresh 0.20 --ranking --ranking_eval_output 20 --sample --overwrite_gen


CalledProcessError: Command '['/usr/bin/python3', '-m', 'rewrite.rewrite_example', '--output_dir', 'data/dexp_outputs/colab_run_demo', '--data_type', 'paradetox', '--data_path', './datasets/_subsets/paradetox/test_toxic_parallel.txt', '--rep_penalty', '1.0', '--alpha_a', '1.5', '--alpha_e', '4.75', '--temperature', '2.5', '--alpha_b', '1.0', '--max_length', '96', '--batch_size', '8', '--top_k_gen', '50', '--top_p', '0.95', '--filter_p', '1.0', '--thresh', '0.20', '--ranking', '--ranking_eval_output', '20', '--sample', '--overwrite_gen']' returned non-zero exit status 2.

In [None]:
# # re-run eval only
# detoxify(
#     data_type="paradetox",
#     output_folder="colab_run_demo",
#     thresholds=(0.20,),
#     batch_size=8,
#     ranking=True,
#     sample=True,
#     top_k_gen=50,
#     top_p=0.95,
#     max_length=96,
#     num_examples=50,
#     run_eval=True,
#     overwrite_eval=True,   # <--- force rewrite gen_stats.txt
#     skip_ref_eval=False
# )
