# XDetox Pipeline (DecompX Masking + MaRCo + DecompX Reranking)

This notebook runs the original XDetox pipeline, close to `lab.py`, with a few quality-of-life changes:

1. **Dataset selector** – choose any dataset from `data_configs`.
2. **Small-batch mode** – run only on the first `num_examples` examples.
3. **DecompX masking** – token-level toxicity attribution on RoBERTa to decide which tokens to mask with `<mask>`.
4. **MaRCo generation** – BART base + non-toxic expert + toxic anti-expert.
5. **Optional DecompX-based reranking** – generate several candidates and pick the **least toxic** one.
6. **Evaluation** – BLEU, BERTScore, MeaningBERT, perplexity, and toxicity, plus a summary CSV per dataset.

---

## What this pipeline does

For each chosen dataset:

1. **Subsetting (optional)**  
   - If you set `num_examples`, the script writes a **subset file** under  
     `datasets/_subsets/{data_type}/...`  
   - This subset matches the format expected by `rewrite.rewrite_example.get_data()`.

2. **Masking with DecompX**

   - We use a RoBERTa toxicity classifier with DecompX to get **per-token toxicity importance**.
   - Tokens that push the prediction towards the toxic class are **replaced with `<mask>`**.
   - Masked sentences are saved as:

     - `data/model_outputs/{output_folder}/{data_type}/DecompX{thresh}/masked_inputs.txt`

   - Thresholds control how aggressively we mask:

     - For a threshold value $\tau$, higher $\tau$ → **less masking**, lower $\tau$ → **more masking**.

3. **Generation with MaRCo (BART ensemble)**

   For each masked input, we use an ensemble of BART models:

   - **Base** model (generic BART).
   - **Expert** model (trained on non-toxic text).
   - **Anti-expert** model (trained on toxic text).

   During generation, the three models' logits are combined linearly, which amounts to a **product of experts** over their next-token distributions:

   $$
   \text{logits}_{\text{ens}} = \alpha_b \cdot \text{logits}_{\text{base}}
   + \alpha_e \cdot \text{logits}_{\text{expert}}
   - \alpha_a \cdot \text{logits}_{\text{anti}}
   $$

   where:

   - $\alpha_a$ controls how strongly we **push away** from toxic patterns.
   - $\alpha_e$ controls how strongly we **pull towards** non-toxic patterns.
   - $\alpha_b$ controls the influence of the base model.

   Sampling is controlled by:

   - `sample` (sampling vs greedy),
   - `top_k_gen`,
   - `top_p`,
   - `filter_p`,
   - `temperature`,
   - `rep_penalty`,
   - `max_length`.

4. **DecompX-based reranking (inside `rewrite.rewrite_example`)**

   When `ranking=True`:

   - For each **input sentence**, the generation script samples `num_candidates` candidates.
   - For each candidate, it runs the DecompX toxicity model **again** and computes a scalar “toxicity importance” score for that output.
   - It then chooses the candidate with **lowest summed toxicity importance** (the “least toxic” according to DecompX).

   This happens inside the `rewrite.rewrite_example` module via:

   - `--ranking`
   - `--ranking_num_output {num_candidates}`

5. **Evaluation**

   If `run_eval=True`, the notebook calls `evaluation.evaluate_all` for each generation folder and computes:

   - BERTScore (F1)
   - MeaningBERT
   - BLEU-4
   - Perplexity (orig / gen)
   - Toxicity (orig / gen)
   - Optionally, percent toxic (the `percent_toxic_*` columns are commented out in `_aggregate_eval_csv`; uncomment them to keep them)

   For each folder under:

   - `data/model_outputs/{output_folder}/{data_type}/DecompX{thresh}/aa*_ae*_.../`

   it writes:

   - `orig.txt` – original toxic inputs.
   - `gen.txt` – final detoxified outputs.
   - `gen_stats.txt` – metrics for that run.

   Then `_aggregate_eval_csv` aggregates across thresholds into:

   - `data/model_outputs/{output_folder}/{data_type}/{data_type}.csv`
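The logit combination in step 3 can be sketched with NumPy. This is a minimal illustration, assuming the base, expert, and anti-expert next-token logits over a shared vocabulary are already available as arrays; `ensemble_logits` and the toy logits are hypothetical and are not the code used inside `rewrite.generation`. The final check shows why the weighted sum is called a product of experts: its softmax equals the renormalized product $p_{\text{base}}^{\alpha_b} \, p_{\text{expert}}^{\alpha_e} / p_{\text{anti}}^{\alpha_a}$.

```python
import numpy as np


def ensemble_logits(base, expert, anti, alpha_b=1.0, alpha_e=4.75, alpha_a=1.5):
    """Hypothetical helper: alpha_b*base + alpha_e*expert - alpha_a*anti.

    Default weights are the `paradetox` entries from data_configs
    (alpha_b has no dataset default and is 1.0).
    """
    return alpha_b * base + alpha_e * expert - alpha_a * anti


def softmax(z):
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


# Toy next-token logits over a 4-token vocabulary (made-up numbers).
base = np.array([2.0, 1.0, 0.5, 0.1])
expert = np.array([0.5, 1.5, 0.2, 0.0])
anti = np.array([0.1, 0.2, 3.0, 0.5])  # anti-expert strongly prefers token 2

ens = ensemble_logits(base, expert, anti)
p_ens = softmax(ens)

# Product-of-experts view: multiply the (powered) distributions, divide by the
# anti-expert's, then renormalize. This matches softmax of the weighted logits.
poe = softmax(base) ** 1.0 * softmax(expert) ** 4.75 / softmax(anti) ** 1.5
poe /= poe.sum()
print(np.allclose(p_ens, poe), p_ens.round(4))
```

Raising $\alpha_a$ pushes probability mass further away from tokens the anti-expert likes (token 2 here), while raising $\alpha_e$ sharpens the pull towards the expert's preferred tokens.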

---

## `detoxify()` API

Definition:

```python
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run",
    thresholds = (0.15, 0.20, 0.25),
    echo: bool = False,
    batch_size: int = 10,
    ranking: bool = True,
    sample: bool = True,
    top_k_gen: int = 50,
    top_p: float = 0.95,
    filter_p: float = 1.0,
    max_length: int = 128,
    alpha_a: float = None,
    alpha_e: float = None,
    alpha_b: float = 1.0,
    temperature: float = None,
    rep_penalty: float = None,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    num_candidates: int = 10,
)
```

### Main arguments

**Dataset and I/O**

* `data_type`
  Dataset key from `data_configs`, e.g.:

  * `"paradetox"`, `"dynabench_val"`, `"jigsaw_toxic"`, `"appdia_original"`, etc.

* `output_folder`
  Name of the top-level run under `data/model_outputs/`.
  All outputs for this run go to:

  * `data/model_outputs/{output_folder}/{data_type}/...`

**Masking / thresholds**

* `thresholds`
  Tuple of DecompX thresholds to try, e.g. `(0.15, 0.20, 0.25)`.
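  Each threshold $\tau$ gets its own folder `DecompX{τ}` (named with Python's `:g` format, so `0.20` becomes `DecompX0.2`) holding `masked_inputs.txt` plus one run subfolder per hyperparameter setting with `orig.txt`, `gen.txt`, and `gen_stats.txt`.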

* `num_examples`

  * If an integer, only the **first N examples** are used (quick debugging). If the dataset has fewer than N examples, all of them are used.
  * If `None` (or ≤ 0), the full dataset is used.
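To make the threshold semantics concrete, here is a minimal sketch of threshold-based masking. It assumes DecompX has already produced one toxicity-importance score per token and that a token is masked when its score exceeds $\tau$; `mask_by_threshold` and the scores are hypothetical, and the repository's DecompX code may normalize or group scores differently. The folder label reuses the `f"DecompX{abs(t):g}"` expression from `detoxify()`.

```python
from typing import List


def mask_by_threshold(tokens: List[str], scores: List[float], tau: float,
                      mask_token: str = "<mask>") -> str:
    """Replace every token whose importance score exceeds tau with mask_token.

    Hypothetical illustration: assumes one score per token, where larger values
    mean the token pushes the classifier more strongly towards the toxic class.
    """
    if len(tokens) != len(scores):
        raise ValueError("need exactly one score per token")
    return " ".join(mask_token if s > tau else tok for tok, s in zip(tokens, scores))


tokens = ["you", "stupid", "damn", "fool", "."]
scores = [0.05, 0.22, 0.17, 0.30, 0.00]  # made-up importance scores

for tau in (0.15, 0.20, 0.25):
    # Folder name as built in detoxify(): f"DecompX{abs(t):g}"
    print(f"DecompX{abs(tau):g}:", mask_by_threshold(tokens, scores, tau))
# DecompX0.15: you <mask> <mask> <mask> .
# DecompX0.2: you <mask> damn <mask> .
# DecompX0.25: you stupid damn <mask> .
```

Raising $\tau$ from 0.15 to 0.25 leaves progressively more of the original tokens in place, which is the "higher $\tau$ → less masking" behaviour described above.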

**Generation hyperparameters (MaRCo)**

* `sample`

  * `True`: use sampling (random but controlled by temperature / top-k / top-p).
  * `False`: greedy decoding.

* `top_k_gen`
  Top-k filter on the **ensembled** logits (only the top-k tokens by probability are kept).

* `top_p`
  Nucleus (top-p) sampling on the **ensembled** logits. Keeps the smallest set of tokens whose cumulative probability ≥ `top_p`.

* `filter_p`
  Nucleus filter on the **base model** logits before ensembling (advanced; usually leave at `1.0`).

* `max_length`
  Maximum length of the generated sequence (in tokens).

* `alpha_a`, `alpha_e`, `alpha_b`
  Ensemble weights for anti-expert, expert, and base:

  * If `None`, `alpha_a` and `alpha_e` fall back to the defaults in `data_configs[data_type]`. `alpha_b` has no dataset default and is `1.0` unless set.

* `temperature`
  Softens or sharpens the probability distribution:

  * Higher temperature → more random.
  * Lower temperature → more deterministic.
  * If `None`, the dataset default is used.

* `rep_penalty`
  Repetition penalty (1.0 = no penalty). Larger values discourage repeating tokens.

* `batch_size`
  Number of sequences to generate in a batch (trade-off between speed and GPU memory).
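The decoding knobs above can be illustrated with a single next-token step in NumPy. This is a sketch under stated assumptions: it applies a CTRL-style repetition penalty (the scheme used by Hugging Face's `RepetitionPenaltyLogitsProcessor`), then temperature, top-k, and top-p, in that order, to already-ensembled logits. The real order and implementation inside `rewrite.generation.Infiller` may differ, `filter_p` (applied to the base model before ensembling) is omitted, and `next_token_step` is a hypothetical name.

```python
import numpy as np


def next_token_step(logits, prev_ids, temperature=2.5, rep_penalty=1.0,
                    top_k=50, top_p=0.95, sample=True, rng=None):
    """Pick one next-token id from ensembled logits (illustrative sketch)."""
    z = np.asarray(logits, dtype=np.float64).copy()

    # Repetition penalty: shrink positive logits and grow negative ones for
    # tokens that were already generated, making them less likely.
    for i in set(prev_ids):
        z[i] = z[i] * rep_penalty if z[i] < 0 else z[i] / rep_penalty

    if not sample:
        return int(np.argmax(z))  # greedy: temperature / top-k / top-p are unused

    z = z / temperature  # >1 flattens the distribution, <1 sharpens it

    # Top-k: keep only the k highest-scoring tokens.
    if 0 < top_k < z.size:
        kth = np.sort(z)[-top_k]
        z[z < kth] = -np.inf

    # Top-p: keep the smallest set of tokens whose cumulative probability >= top_p.
    probs = np.exp(z - z.max())
    probs /= probs.sum()
    order = np.argsort(-probs)
    cum = np.cumsum(probs[order])
    cutoff = int(np.searchsorted(cum, top_p)) + 1  # number of tokens to keep
    keep = order[:cutoff]
    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    filtered /= filtered.sum()

    if rng is None:
        rng = np.random.default_rng()
    return int(rng.choice(z.size, p=filtered))
```

With `rep_penalty=2.0`, a previously generated token with logit `5.0` drops to `2.5`, so greedy decoding can switch to the next-best token; with `top_k=1` sampling always returns the argmax.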

**DecompX reranking**

* `ranking`

  * `True`: enable DecompX-based reranking of candidates inside `rewrite.rewrite_example`.
  * `False`: generate only **one** candidate per input (no reranking).

* `num_candidates`

  * When `ranking=True`, this sets `--ranking_num_output`.
  * For each input, the generator samples `num_candidates` candidates and picks the one with **lowest DecompX toxicity importance**.
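The selection rule itself fits in a few lines. This sketch assumes a scoring callable that returns per-token toxicity importances for a sentence, standing in for the DecompX model, which is not reproduced here. `pick_least_toxic` and `toy_scorer` are hypothetical names, and the lexicon scorer is only a placeholder; the real scoring inside `rewrite.rewrite_example` may aggregate differently.

```python
from typing import Callable, List, Sequence


def pick_least_toxic(candidates: Sequence[str],
                     token_importance: Callable[[str], List[float]]) -> str:
    """Return the candidate whose summed toxicity importance is lowest."""
    if not candidates:
        raise ValueError("need at least one candidate")
    # min() keeps the first candidate on ties, so the choice is deterministic.
    return min(candidates, key=lambda c: sum(token_importance(c)))


# Toy stand-in for DecompX: flag words from a tiny lexicon (illustration only).
TOXIC_WORDS = {"damn", "stupid", "shit"}


def toy_scorer(sentence: str) -> List[float]:
    return [1.0 if w.lower().strip(".,!?") in TOXIC_WORDS else 0.0
            for w in sentence.split()]


candidates = [
    "neither of my coworkers gave a shit .",
    "neither of my coworkers cared .",
    "neither of my stupid coworkers cared .",
]
print(pick_least_toxic(candidates, toy_scorer))  # neither of my coworkers cared .
```

Because the score is a plain sum, longer candidates can accumulate more importance; that is one reason `num_candidates` and the masking threshold interact.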

**Evaluation**

* `run_eval`

  * If `True`, run `evaluation.evaluate_all` and write `gen_stats.txt` + summary CSV.

* `overwrite_gen`

  * If `True`, regenerate outputs even if `gen.txt` already exists.

* `overwrite_eval`

  * If `True`, recompute evaluation values even if `gen_stats.txt` already exists.

* `skip_ref_eval`

  * If `True`, skip perplexity computation on references (faster).

**Echo**

* `echo`

  * If `True`, print example inputs, masked inputs, generated outputs, and per-threshold evaluation metrics to the notebook.

---

## Example Usage

**Quick test on a subset of ParaDetox (with reranking):**

```python
detoxify(
    data_type="paradetox",
    output_folder="colab_run_demo",
    thresholds=(0.20,),     # single threshold for quick check
    batch_size=8,
    ranking=True,           # use DecompX-based reranking
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=50,        # use only first 50 examples
    run_eval=True,          # compute BLEU / BERTScore / MeaningBERT / PPL / Toxicity
    overwrite_gen=True,     # regenerate gen.txt
    overwrite_eval=True,    # recompute gen_stats.txt
    skip_ref_eval=False,
    num_candidates=20,      # 20 candidates per input for DecompX reranking
)
```

**Run without reranking (single candidate per input):**

```python
detoxify(
    data_type="paradetox",
    output_folder="colab_run_no_rerank",
    thresholds=(0.20,),
    batch_size=8,
    ranking=False,          # no DecompX reranking
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=50,
    run_eval=True,
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
)
```

After the run, you can inspect:

* `orig.txt` / `gen.txt` under `data/model_outputs/{output_folder}/{data_type}/DecompX{thresh}/...`
* `gen_stats.txt` for per-run metrics.
* `{data_type}.csv` for a summary over thresholds.
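To compare settings after a run, the summary CSV can be loaded with pandas. A minimal sketch, assuming the column names written by `_aggregate_eval_csv` (`threshold`, `folder`, `bertscore`, `meaningbert`, `bleu4`, `toxicity_gen`, ...); `rank_runs` is a hypothetical helper that sorts by generated toxicity (lower is better) and breaks ties by BERTScore (higher is better).

```python
import pandas as pd


def rank_runs(csv_path: str) -> pd.DataFrame:
    """Load {data_type}.csv and order runs from least to most toxic output."""
    df = pd.read_csv(csv_path)
    cols = ["threshold", "folder", "toxicity_gen", "bertscore", "meaningbert", "bleu4"]
    return (df[cols]
            .sort_values(["toxicity_gen", "bertscore"], ascending=[True, False])
            .reset_index(drop=True))


# Usage (path follows the layout described above):
# rank_runs(f"{XDETOX_DIR}/data/model_outputs/{output_folder}/{data_type}/{data_type}.csv")
```

Sorting on toxicity alone can favour runs that drift far from the input, so it is worth reading the BERTScore / MeaningBERT / BLEU-4 columns alongside it.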



In [None]:
#@title Mount Drive & locate XDetox
from google.colab import drive; drive.mount('/content/drive')
import os, glob, re, sys, torch, json, shutil, math, nltk
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE

# Try My Drive
candidate = "/content/drive/MyDrive/w266 - Project/XDetox"
print("Try MyDrive:", candidate, "->", os.path.isdir(candidate))

XDETOX_DIR = candidate
print("Using XDETOX_DIR:", XDETOX_DIR)
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"



Mounted at /content/drive
Try MyDrive: /content/drive/MyDrive/w266 - Project/XDetox -> True
Using XDETOX_DIR: /content/drive/MyDrive/w266 - Project/XDetox


In [None]:
#@title Runtime setup (paths, cache, GPU)
# HuggingFace cache inside the repo (persists on Drive)
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE

# Add repo to PYTHONPATH
if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("XDETOX_DIR:", XDETOX_DIR)
print("TRANSFORMERS_CACHE:", HF_CACHE)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))


XDETOX_DIR: /content/drive/MyDrive/w266 - Project/XDetox
TRANSFORMERS_CACHE: /content/drive/MyDrive/w266 - Project/XDetox/cache
CUDA available: True
GPU: Tesla T4


In [None]:
#@title Verify XDetox repo layout
for d in ["rewrite", "evaluation", "datasets"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("Repo folders OK.")


Repo folders OK.


In [None]:
#@title Install dependencies (restart runtime if warnings/errors appear)
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi

# BERTScore dependency required by evaluate/datasets
!pip -q install bert-score

   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 MB 21.9 MB/s eta 0:00:00
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 70.7 MB/s eta 0:00:00
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.34.0 requires jedi>=0.16, which is not installed.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.3.1 which is incompatible.

In [None]:

#@title NLTK data
nltk.download("punkt", quiet=True)
# Some Colab images need this table; ignore if unavailable:
try:
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass
print("NLTK ready")



NLTK ready


In [None]:
#@title Import from 'rewrite'
from rewrite.generation import Infiller
from rewrite import rewrite_example as rx
import argparse as _argparse



In [None]:

#@title Data configs (same as lab.py)
data_configs = {
    "microagressions_val": {
        "data_path": "./datasets/microagressions/val.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "microagressions_test": {
        "data_path": "./datasets/microagressions/test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "sbf_val": {
        "data_path": "./datasets/sbf/sbfdev.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "sbf_test": {
        "data_path": "./datasets/sbf/sbftst.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "dynabench_val": {
        "data_path": "./datasets/dynabench/db_dev.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "dynabench_test": {
        "data_path": "./datasets/dynabench/db_test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "jigsaw_toxic": {
        "data_path": "./datasets/jigsaw_full_30/test_10k_toxic.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "paradetox": {
        "data_path": "./datasets/paradetox/test_toxic_parallel.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_original": {
        "data_path": "./datasets/appdia/original-annotated-data/original-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_discourse": {
        "data_path": "./datasets/appdia/discourse-augmented-data/discourse-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    }
}
print("Datasets:", ", ".join(data_configs.keys()))


Datasets: microagressions_val, microagressions_test, sbf_val, sbf_test, dynabench_val, dynabench_test, jigsaw_toxic, paradetox, appdia_original, appdia_discourse


### Helpers: subset data, call rewrite, run evaluation

In [None]:
REPO = XDETOX_DIR

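def _abs_repo_path(rel):
    """Resolve a repo-relative path such as './datasets/x.csv' against REPO.
    Absolute paths are returned unchanged (os.path.join drops REPO for them).
    Note: str.lstrip('./') would strip *characters*, mangling absolute paths."""
    return os.path.normpath(os.path.join(REPO, rel))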

def _ensure_dir(p):
    Path(p).mkdir(parents=True, exist_ok=True)

def _subset_for_data_type(data_type, data_path, n, out_dir):
    """Create a small subset file matching the format expected by rewrite_example.get_data().
    Returns the absolute path of the new subset file, or of the original file if n is None or <= 0.
    """
    src = _abs_repo_path(data_path)
    if n is None or n <= 0:
        return src  # no subset; absolute so it works regardless of the notebook's cwd

    _ensure_dir(out_dir)
    out = os.path.join(out_dir, os.path.basename(src))

    if src.endswith(".csv"):
        # microagressions / sbf / dynabench: CSV with a header row
        pd.read_csv(src).head(n).to_csv(out, index=False)
    elif src.endswith(".tsv"):
        # appdia: tab-separated with a header row
        pd.read_csv(src, sep="\t").head(n).to_csv(out, sep="\t", index=False)
    elif src.endswith(".txt"):
        # paradetox / jigsaw: one example per line
        with open(src, "r") as f:
            lines = [s.rstrip("\n") for s in f]
        with open(out, "w") as g:
            for s in lines[:n]:
                g.write(s + "\n")
    else:
        # Fallback: copy the original file unchanged
        shutil.copy(src, out)
    return out

def _parse_run_folder_name(folder_name):
    """Return True if folder_name looks like a run folder written by rewrite_example."""
    pattern = r"aa(\d+\.\d+)_ae(\d+\.\d+)_ab(\d+\.\d+)_base(.*?)_anti(.*?)_expert(.*?)_temp(\d+\.\d+)_sample(.*?)_topk(\d+)_reppenalty(\d+\.\d+)_filterp(\d+\.\d+)_maxlength(\d+)_topp(\d+\.\d+)"
    m = re.match(pattern, folder_name)
    return bool(m)

def _eval_with_toxicity(base_path, overwrite_eval=False, skip_ref=False, tox_threshold=0.5, tox_batch_size=32):
    """Run evaluation.evaluate_all on every run folder under base_path (one
    DecompX{thresh} directory) that has orig.txt and gen.txt; metrics are
    written to gen_stats.txt in that run folder."""
    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir) or not _parse_run_folder_name(folder):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path  = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")  # metrics file from evaluate_all
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue

        env = os.environ.copy()
        env["PYTHONPATH"] = REPO + (":" + env.get("PYTHONPATH","") if env.get("PYTHONPATH") else "")
        cmd = [
            sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path",  gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")
        print("Eval:", " ".join(cmd))
        res = run(cmd, cwd=REPO, env=env, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()


def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return float('nan')

def _read_stats_file(path):
    """Read gen_stats.txt ("key: value" lines) into a dict of floats.
    Keys are lower-cased with any '(skipped)' marker removed; values that
    cannot be parsed (e.g. 'None') become NaN."""
    out = {}
    with open(path, "r") as f:
        for line in f:
            if ":" not in line:
                continue
            k, v = line.split(":", 1)
            # normalize keys like 'perplexity orig (skipped)' -> 'perplexity orig'
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v.strip())
    return out

def _aggregate_eval_csv(output_folder, data_type, base_out_dir):
    """Collect gen_stats.txt from every run under {base_out_dir}/{data_type}/DecompX*/
    into {data_type}.csv. The DecompX* folders are discovered on disk, so any
    threshold passed to detoxify() is included, not only 0.15/0.20/0.25."""
    rows = []
    data_dir = os.path.join(base_out_dir, data_type)
    mask_dirs = sorted(os.listdir(data_dir)) if os.path.isdir(data_dir) else []
    for mask_dir in mask_dirs:
        base_path = os.path.join(data_dir, mask_dir)
        if not (mask_dir.startswith("DecompX") and os.path.isdir(base_path)):
            continue
        try:
            thresh = float(mask_dir[len("DecompX"):])  # 'DecompX0.2' -> 0.2
        except ValueError:
            continue
        for folder in sorted(os.listdir(base_path)):
            stats_path = os.path.join(base_path, folder, "gen_stats.txt")
            if not os.path.exists(stats_path):
                continue
            s = _read_stats_file(stats_path)
            rows.append({
                "threshold":        round(thresh, 2),
                "folder":           folder,
                "bertscore":        s.get("bertscore", np.nan),
                "meaningbert":      s.get("meaningbert", np.nan),
                "bleu4":            s.get("bleu4", np.nan),
                "perplexity_gen":   s.get("perplexity gen", np.nan),
                "perplexity_orig":  s.get("perplexity orig", np.nan),
                "toxicity_gen":     s.get("toxicity gen", np.nan),
                "toxicity_orig":    s.get("toxicity orig", np.nan),
                # Optional: you can also keep the percent toxic columns
                # "percent_toxic_gen": s.get("percent toxic gen", np.nan),
                # "percent_toxic_ref": s.get("percent toxic ref", np.nan),
            })

    if rows:
        # Column order: MeaningBERT between BERTScore and BLEU4
        cols = [
            "threshold", "folder",
            "bertscore", "meaningbert", "bleu4",
            "perplexity_gen", "perplexity_orig",
            "toxicity_gen", "toxicity_orig",
            # "percent_toxic_gen", "percent_toxic_ref",
        ]
        df = pd.DataFrame(rows)[cols]
        out_csv = os.path.join(data_dir, f"{data_type}.csv")
        df.to_csv(out_csv, index=False)
        print("Wrote summary CSV:", out_csv)
    else:
        print("No evaluation files found to summarize.")


def _bool2str(x: bool) -> str:
    return "T" if x else "F"


def _build_gen_folder_name(
    alpha_a, alpha_e, alpha_b,
    base_type, antiexpert_type, expert_type,
    temperature, sample, top_k_gen, rep_penalty, filter_p, max_length, top_p
):
    """
    Rebuild the run folder name used by rewrite.rewrite_example so we can
    locate gen.txt / orig.txt / gen_stats.txt for echo and evaluation.
    """
    return (
        "aa" + str(alpha_a) +
        "_ae" + str(alpha_e) +
        "_ab" + str(alpha_b) +
        "_base" + base_type[:5] +
        "_anti" + antiexpert_type[:5] +
        "_expert" + expert_type[:5] +
        "_temp" + str(temperature) +
        "_sample" + _bool2str(sample) +
        "_topk" + str(top_k_gen) +
        "_reppenalty" + str(rep_penalty) +
        "_filterp" + str(filter_p) +
        "_maxlength" + str(max_length) +
        "_topp" + str(top_p)
    )
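One pitfall worth checking in the helpers above: `_build_gen_folder_name` formats each hyperparameter with `str()`, while `_parse_run_folder_name` only accepts decimal values such as `1.0` for the `aa`, `ae`, `ab`, `temp`, `reppenalty`, `filterp`, and `topp` fields. The standalone sketch below copies that regex and shows that passing an integer (e.g. `alpha_b=1`) produces a folder name that `_eval_with_toxicity` would silently skip. `folder_name` is a trimmed, hypothetical re-implementation for illustration, not the notebook's function.

```python
import re

# Same pattern as _parse_run_folder_name in the helper cell above.
PATTERN = (r"aa(\d+\.\d+)_ae(\d+\.\d+)_ab(\d+\.\d+)_base(.*?)_anti(.*?)_expert(.*?)"
           r"_temp(\d+\.\d+)_sample(.*?)_topk(\d+)_reppenalty(\d+\.\d+)"
           r"_filterp(\d+\.\d+)_maxlength(\d+)_topp(\d+\.\d+)")


def folder_name(alpha_a, alpha_e, alpha_b, temperature, sample, top_k,
                rep_penalty, filter_p, max_length, top_p):
    """Mirror of _build_gen_folder_name with the MaRCo model types fixed."""
    return (f"aa{alpha_a}_ae{alpha_e}_ab{alpha_b}"
            f"_base{'base'[:5]}_anti{'antiexpert'[:5]}_expert{'expert'[:5]}"
            f"_temp{temperature}_sample{'T' if sample else 'F'}_topk{top_k}"
            f"_reppenalty{rep_penalty}_filterp{filter_p}"
            f"_maxlength{max_length}_topp{top_p}")


ok = folder_name(1.5, 4.75, 1.0, 2.5, True, 50, 1.0, 1.0, 96, 0.95)
bad = folder_name(1.5, 4.75, 1, 2.5, True, 50, 1.0, 1.0, 96, 0.95)  # alpha_b=1 (int)

print(ok)  # aa1.5_ae4.75_ab1.0_basebase_antiantie_expertexper_temp2.5_sampleT_topk50_reppenalty1.0_filterp1.0_maxlength96_topp0.95
print(bool(re.match(PATTERN, ok)))   # True
print(bool(re.match(PATTERN, bad)))  # False
```

Passing floats (`alpha_b=1.0`, `temperature=3.0`, `top_p=1.0`) to `detoxify()` keeps generated folders visible to evaluation.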


In [None]:
#@title `detoxify()` — run masking + generation + optional eval

def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run",
    thresholds = (0.15, 0.20, 0.25),
    echo: bool = False,             # NEW
    batch_size: int = 10,
    ranking: bool = True,
    sample: bool = True,
    top_k_gen: int = 50,
    top_p: float = 0.95,
    filter_p: float = 1.0,
    max_length: int = 128,
    alpha_a: float = None,   # if None, take from data_configs
    alpha_e: float = None,   # if None, take from data_configs
    alpha_b: float = 1.0,
    temperature: float = None,  # if None, from data_configs
    rep_penalty: float = None,  # if None, from data_configs
    num_examples: int = 100,    # small-batch control; set None to use full dataset
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    # number of candidates per input for DecompX reranking
    num_candidates: int = 10,
):
    """
    Run the XDetox pipeline similarly to lab.py but with small-batch support,
    optional DecompX-based reranking, and evaluation including
    BLEU, BERTScore, MeaningBERT, perplexity, and toxicity.

    When `ranking=True`, MaRCo samples `num_candidates` candidates per input
    and DecompX reranks them (passed as --ranking_num_output).

    If echo=True, the function will print:
      - how many examples and which dataset,
      - a few example inputs,
      - a few masked inputs (per threshold),
      - a few detoxified outputs (per threshold),
      - evaluation metrics (per threshold) from gen_stats.txt (if run_eval=True).
    """

    assert data_type in data_configs, f"Unknown data_type: {data_type}"
    cfg = data_configs[data_type].copy()

    if ranking and num_candidates < 1:
        raise ValueError("num_candidates must be >= 1 when ranking=True")

    # fallbacks from data_configs
    if alpha_a is None:
        alpha_a = cfg["alpha_a"]
    if alpha_e is None:
        alpha_e = cfg["alpha_e"]
    if temperature is None:
        temperature = cfg["temperature"]
    if rep_penalty is None:
        rep_penalty = cfg["rep_penalty"]

    # Use model_outputs instead of dexp_outputs
    base_out_dir = os.path.join("data", "model_outputs", output_folder)
    abs_base_out_dir = os.path.join(REPO, base_out_dir)
    _ensure_dir(abs_base_out_dir)

    # small subset path creation
    original_data_path = cfg["data_path"]
    subset_dir = os.path.join(REPO, "datasets", "_subsets", data_type)
    _ensure_dir(subset_dir)
    subset_path = _subset_for_data_type(
        data_type, original_data_path, num_examples, subset_dir
    )

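    # Load inputs once (same logic as rewrite_example.get_data) for echo.
    # Resolve against REPO because the notebook's working directory is not the
    # repo root; os.path.join leaves an already-absolute subset_path unchanged.
    args_data = _argparse.Namespace(data_type=data_type, data_path=os.path.join(REPO, subset_path))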
    inputs = rx.get_data(args_data)
    num_inputs = len(inputs)

    if echo:
        print("=" * 80)
        print(f"[echo] Dataset: {data_type}")
        print(f"[echo] Subset path: {subset_path}")
        print(f"[echo] Output base: {abs_base_out_dir}")
        print(f"[echo] Number of examples to detoxify: {num_inputs}")
        print(f"[echo] Thresholds: {', '.join(f'{t:.2f}' for t in thresholds)}")
        print(f"[echo] ranking: {ranking}, num_candidates: {num_candidates}")
        print("\n[echo] Example inputs (first up to 3):")
        for i, s in enumerate(inputs[:3]):
            print(f"  input[{i}]: {s}")
        print("=" * 80)

    # These are the fixed types used by the MaRCo models
    base_type = "base"
    antiexpert_type = "antiexpert"
    expert_type = "expert"

    # Folder name used by rewrite_example for this set of hyperparameters
    gen_folder = _build_gen_folder_name(
        alpha_a, alpha_e, alpha_b,
        base_type, antiexpert_type, expert_type,
        temperature, sample, top_k_gen, rep_penalty, filter_p, max_length, top_p
    )

    # run thresholds
    for t in thresholds:
        mask_dir = f"DecompX{abs(t):g}" if t != 0 else "DecompX0.0"
        thresh_root_dir = os.path.join(abs_base_out_dir, data_type, mask_dir)
        _ensure_dir(thresh_root_dir)

        # Build command to call rewrite module
        cmd = [
            sys.executable, "-m", "rewrite.rewrite_example",
            "--output_dir", base_out_dir,           # relative to REPO
            "--data_type", data_type,
            "--data_path", subset_path.replace(REPO + "/", "./"),
            "--rep_penalty", str(rep_penalty),
            "--alpha_a", str(alpha_a),
            "--alpha_e", str(alpha_e),
            "--temperature", str(temperature),
            "--alpha_b", str(alpha_b),
            "--max_length", str(max_length),
            "--batch_size", str(batch_size),
            "--top_k_gen", str(top_k_gen),
            "--top_p", str(top_p),
            "--filter_p", str(filter_p),
            "--thresh", f"{t:.2f}",
        ]

        # DecompX reranking: pass both the flag and the number of candidates
        if ranking:
            cmd.append("--ranking")
            cmd.extend(["--ranking_num_output", str(num_candidates)])

        if sample:
            cmd.append("--sample")
        if overwrite_gen:
            cmd.append("--overwrite_gen")

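        print("Run:", " ".join(cmd))
        # Capture the child's output so its traceback is printed in the notebook
        # if rewrite_example fails, then re-raise as CalledProcessError.
        res = run(cmd, cwd=REPO, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()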

        # ------------------------------------------------------------------
        # echo: show masked inputs and detoxified outputs for this threshold
        # ------------------------------------------------------------------
        if echo:
            print("\n" + "-" * 80)
            print(f"[echo] Threshold t={t:.2f} — sample masked and generated outputs")

            # masked inputs (one file per threshold)
            masked_path = os.path.join(thresh_root_dir, "masked_inputs.txt")
            if os.path.exists(masked_path):
                with open(masked_path, "r") as f:
                    masked_lines = [l.strip() for l in f.readlines()]
                print("[echo] Example masked inputs (first up to 3):")
                for i, m in enumerate(masked_lines[:3]):
                    print(f"  masked[{i}]: {m}")
            else:
                print(f"[echo] masked_inputs.txt not found at {masked_path}")

            # generated detoxified outputs for this hyperparameter setting
            run_dir = os.path.join(thresh_root_dir, gen_folder)
            gen_txt = os.path.join(run_dir, "gen.txt")
            if os.path.exists(gen_txt):
                with open(gen_txt, "r") as f:
                    gen_lines = [l.strip() for l in f.readlines()]
                print("\n[echo] Example detoxified outputs (first up to 3):")
                for i, g in enumerate(gen_lines[:3]):
                    print(f"  detox[{i}]: {g}")
            else:
                print(f"[echo] gen.txt not found at {gen_txt}")

        # Optional evaluation (BLEU / BERTScore / MeaningBERT / PPL / Toxicity)
        if run_eval:
            base_path = os.path.join(abs_base_out_dir, data_type, mask_dir)
            _eval_with_toxicity(
                base_path,
                overwrite_eval=overwrite_eval,
                skip_ref=skip_ref_eval,
                tox_threshold=0.5,
                tox_batch_size=32,
            )

            # echo: print metrics for THIS run (this threshold + this gen folder)
            if echo:
                stats_path = os.path.join(base_path, gen_folder, "gen_stats.txt")
                if os.path.exists(stats_path):
                    stats = _read_stats_file(stats_path)
                    print("\n[echo] Evaluation metrics for this run "
                          f"(t={t:.2f}):")
                    metric_keys = [
                        ("bertscore", "BERTScore"),
                        ("meaningbert", "MeaningBERT"),
                        ("bleu4", "BLEU-4"),
                        ("perplexity gen", "Perplexity (gen)"),
                        ("perplexity orig", "Perplexity (orig)"),
                        ("toxicity gen", "Toxicity (gen)"),
                        ("toxicity orig", "Toxicity (orig)"),
                    ]
                    for key, label in metric_keys:
                        val = stats.get(key, None)
                        if isinstance(val, float) and math.isnan(val):
                            continue
                        if val is None:
                            continue
                        print(f"  {label}: {val:.4f}")
                else:
                    print(f"[echo] gen_stats.txt not found at {stats_path}")

    # Summarize metrics across thresholds into a CSV
    if run_eval:
        _aggregate_eval_csv(
            output_folder,
            data_type,
            os.path.join(REPO, "data", "model_outputs", output_folder),
        )


### Example run

In [None]:

detoxify(
    data_type="paradetox",
    output_folder="XDetox_w_DecompX-Masking-DecompX-Reranking",
    thresholds=(0.20,),
    echo=True,
    batch_size=8,
    ranking=True,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=1000,
    run_eval=True,
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    num_candidates=10,
)



[echo] Dataset: paradetox
[echo] Subset path: /content/drive/MyDrive/w266 - Project/XDetox/datasets/_subsets/paradetox/test_toxic_parallel.txt
[echo] Output base: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_DecompX-Masking-DecompX-Reranking
[echo] Number of examples to detoxify: 671
[echo] Thresholds: 0.20
[echo] ranking: True, num_candidates: 10

[echo] Example inputs (first up to 3):
  input[0]: . or the loud ass one - thousand ton beast roaring towards you howling its horn .
  input[1]: mandated  and " right fucking now " would be good .
  input[2]: * neither * of my coworkers gave a shit when it came time to ditch mitch . ugh .
Run: /usr/bin/python3 -m rewrite.rewrite_example --output_dir data/model_outputs/XDetox_w_DecompX-Masking-DecompX-Reranking --data_type paradetox --data_path ./datasets/_subsets/paradetox/test_toxic_parallel.txt --rep_penalty 1.0 --alpha_a 1.5 --alpha_e 4.75 --temperature 2.5 --alpha_b 1.0 --max_length 96 --batch_size 8 --top_k_g

CalledProcessError: Command '['/usr/bin/python3', '-m', 'rewrite.rewrite_example', '--output_dir', 'data/model_outputs/XDetox_w_DecompX-Masking-DecompX-Reranking', '--data_type', 'paradetox', '--data_path', './datasets/_subsets/paradetox/test_toxic_parallel.txt', '--rep_penalty', '1.0', '--alpha_a', '1.5', '--alpha_e', '4.75', '--temperature', '2.5', '--alpha_b', '1.0', '--max_length', '96', '--batch_size', '8', '--top_k_gen', '50', '--top_p', '0.95', '--filter_p', '1.0', '--thresh', '0.20', '--ranking', '--ranking_num_output', '10', '--sample', '--overwrite_gen']' returned non-zero exit status 1.