# xdetox pipeline (decompx masking + marco + decompx reranking)

runs the original XDetox pipeline, close to `lab.py`, with a few changes:

1. **dataset selector** - choose any dataset from `data_configs`.
2. **small-batch mode** - run only on the first `num_examples` examples.
3. **decompx masking** - token-level toxicity attribution on roberta to decide which tokens to mask with `<mask>`.
4. **marco generation** - BART base + non-toxic expert + toxic anti-expert.
5. **optional decompx-based reranking** - generate several candidates and pick the **least toxic** one.
6. **evaluation** - BLEU, bertscore, meaningbert, perplexity, and toxicity, plus a summary csv per dataset.

---

## what this does

for each chosen dataset:

1. **subsetting (optional)**  
   - if you set `num_examples`, the script writes a **subset file** under  
     `datasets/_subsets/{data_type}/...`  
   - this subset matches the format expected by `rewrite.rewrite_example.get_data()`.

2. **masking with decompx**

   - we use a roberta toxicity classifier with decompx to get **per-token toxicity importance**.
   - tokens that push the prediction towards the toxic class are **replaced with `<mask>`**.
   - masked sentences are saved as:

     - `data/model_outputs/{output_folder}/{data_type}/DecompX{thresh}/masked_inputs.txt`

   - thresholds control how aggressively we mask:

     - for a threshold value $\tau$, higher $\tau$ → **less masking**, lower $\tau$ → **more masking**.

3. **generation with MaRCo (BART ensemble)**

   for each masked input, we use an ensemble of bart models:

   - **base** model (generic BART).
   - **expert** model (trained on non-toxic text).
   - **anti-expert** model (trained on toxic text).

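   during generation, the next-token logits are combined as a **product-of-experts** (a weighted sum in logit space, which multiplies the models' probability distributions):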

   $$
   \text{logits}_{\text{ens}} = \alpha_b \cdot \text{logits}_{\text{base}}
   + \alpha_e \cdot \text{logits}_{\text{expert}}
   - \alpha_a \cdot \text{logits}_{\text{anti}}
   $$

   where:

   - $\alpha_a$ controls how strongly we **push away** from toxic patterns.
   - $\alpha_e$ controls how strongly we **pull towards** non-toxic patterns.
   - $\alpha_b$ controls the influence of the base model.

   sampling is controlled by:

   - `sample` (sampling vs greedy),
   - `top_k_gen`,
   - `top_p`,
   - `filter_p`,
   - `temperature`,
   - `rep_penalty`,
   - `max_length`.

4. **decompx-based reranking (inside `rewrite.rewrite_example`)**

   when `ranking=True`:

   - for each **input sentence**, the generation script samples `num_candidates` candidates.
   - for each candidate, it runs the decompx toxicity model **again** and computes a scalar "toxicity importance" score for that output.
   - it then chooses the candidate with **lowest summed toxicity importance** (the "least toxic" according to decompx).

   this happens inside the `rewrite.rewrite_example` module via:

   - `--ranking`
   - `--ranking_num_output {num_candidates}`

5. **evaluation**

   if `run_eval=True`, the notebook calls `evaluation.evaluate_all` for each generation folder and computes:

   - bertscore (f1)
   - meaningbert
   - bleu-4
   - perplexity (orig / gen)
   - toxicity (orig / gen)

   for each folder under:

   - `data/model_outputs/{output_folder}/{data_type}/DecompX{thresh}/aa*_ae*_.../`

   it writes:

   - `orig.txt` - original toxic inputs.
   - `gen.txt` - final detoxified outputs.
   - `gen_stats.txt` - metrics for that run.

   then `_aggregate_eval_csv` aggregates across thresholds into:

   - `data/model_outputs/{output_folder}/{data_type}/{data_type}.csv`
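
The product-of-experts combination from step 3 can be sketched on a toy vocabulary in plain Python. This is a minimal illustration under stated assumptions, not the `rewrite.generation.Infiller` implementation: the three models' next-token logits are assumed to be aligned over one shared vocabulary, and `ensemble_logits` / `softmax` are hypothetical helper names. The default weights mirror the `paradetox` entry in `data_configs` (`alpha_e=4.75`, `alpha_a=1.5`).

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def ensemble_logits(base, expert, anti, alpha_b=1.0, alpha_e=4.75, alpha_a=1.5):
    """Element-wise alpha_b * base + alpha_e * expert - alpha_a * anti."""
    if not (len(base) == len(expert) == len(anti)):
        raise ValueError("all three logit vectors must share one vocabulary")
    return [alpha_b * b + alpha_e * e - alpha_a * a
            for b, e, a in zip(base, expert, anti)]

# toy 3-token vocabulary; index 2 stands in for a toxic token
base   = [1.0, 1.0, 1.0]   # base model is indifferent
expert = [0.5, 0.2, 0.0]   # non-toxic expert slightly prefers token 0
anti   = [0.0, 0.0, 2.0]   # toxic anti-expert strongly prefers token 2

probs = softmax(ensemble_logits(base, expert, anti))
print([round(p, 3) for p in probs])  # token 2 ends up with almost no mass
```

Since $\mathrm{softmax}(z)_i \propto e^{z_i}$, the weighted logit sum gives $p_{\text{ens}} \propto p_{\text{base}}^{\alpha_b}\, p_{\text{expert}}^{\alpha_e} / p_{\text{anti}}^{\alpha_a}$: tokens the anti-expert favors are divided down, which is why this is called a product of experts.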

---

## `detoxify()` api

definition:

```python
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run",
    thresholds = (0.15, 0.20, 0.25),
    echo: bool = False,
    batch_size: int = 10,
    ranking: bool = True,
    sample: bool = True,
    top_k_gen: int = 50,
    top_p: float = 0.95,
    filter_p: float = 1.0,
    max_length: int = 128,
    alpha_a: float = None,
    alpha_e: float = None,
    alpha_b: float = 1.0,
    temperature: float = None,
    rep_penalty: float = None,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    num_candidates: int = 10,
)
```

### main arguments

**dataset and i/o**

* `data_type`
  dataset key from `data_configs`, e.g.:

  * `"paradetox"`, `"dynabench_val"`, `"jigsaw_toxic"`, `"appdia_original"`, etc.

* `output_folder`
  name of the top-level run under `data/model_outputs/`.
  all outputs for this run go to:

  * `data/model_outputs/{output_folder}/{data_type}/...`

**masking / thresholds**

* `thresholds`
  tuple of decompx thresholds to try, e.g. `(0.15, 0.20, 0.25)`.
  each threshold $\tau$ gets its own folder `DecompX{τ}` (formatted with `:g`, so `0.20` becomes `DecompX0.2`) holding `masked_inputs.txt` plus a generation subfolder with `orig.txt`, `gen.txt`, etc.

* `num_examples`

  * if an integer, only the **first N examples** are used (quick debugging).
  * if `None`, the full dataset is used.
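
For intuition about how $\tau$ trades off masking strength, here is a minimal sketch that assumes DecompX has already produced one toxicity-importance score per whitespace token. The function `mask_tokens`, the strict `score > tau` rule, and the scores themselves are made up for illustration; the actual XDetox masking code may work on subword tokens, normalize scores, or merge adjacent masks.

```python
def mask_tokens(tokens, scores, tau, mask_token="<mask>"):
    """Replace every token whose toxicity importance is strictly greater than tau."""
    if len(tokens) != len(scores):
        raise ValueError("need exactly one importance score per token")
    return [mask_token if s > tau else t for t, s in zip(tokens, scores)]

tokens = "you are a stupid idiot".split()
scores = [0.02, 0.01, 0.00, 0.18, 0.31]  # made-up importances, illustration only

for tau in (0.15, 0.20, 0.25):
    print(f"tau={tau:.2f}:", " ".join(mask_tokens(tokens, scores, tau)))
```

Raising $\tau$ from 0.15 to 0.20 unmasks `stupid` in this toy example, matching the rule that a higher threshold means less masking.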

**generation hyperparameters (marco)**

* `sample`

  * `True`: use sampling (random but controlled by temperature / top-k / top-p).
  * `False`: greedy decoding.

* `top_k_gen`
  top-k filter on the **ensembled** logits (only the top-k tokens by probability are kept).

* `top_p`
  nucleus (top-p) sampling on the **ensembled** logits. keeps the smallest set of tokens whose cumulative probability ≥ `top_p`.

* `filter_p`
  nucleus filter on the **base model** logits before ensembling (advanced; usually leave at `1.0`).

* `max_length`
  max length of the generated sequence (in tokens).

* `alpha_a`, `alpha_e`, `alpha_b`
  ensemble weights for anti-expert, expert, and base:

  * if `alpha_a` or `alpha_e` is `None`, the defaults from `data_configs[data_type]` are used.
  * `alpha_b` has no per-dataset default; it defaults to `1.0`.

* `temperature`
  softens or sharpens the probability distribution:

  * higher temperature → more random.
  * lower temperature → more deterministic.
  * if `None`, the dataset default is used.

* `rep_penalty`
  repetition penalty (1.0 = no penalty). larger values discourage repeating tokens.
  if `None`, the dataset default is used.

* `batch_size`
  number of sequences to generate in a batch (trade-off between speed and gpu memory).
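
The decoding knobs above can be illustrated on one vector of ensembled logits. This plain-Python sketch is an assumption about how such controls are usually chained (repetition penalty, then temperature, then top-k, then top-p), not a copy of `rewrite.generation`; the repetition penalty follows the convention of Hugging Face's `RepetitionPenaltyLogitsProcessor` (divide positive logits, multiply negative ones). All function names are hypothetical.

```python
import math
import random

def apply_rep_penalty(logits, generated_ids, penalty):
    """Make already-generated tokens less likely (penalty=1.0 leaves logits unchanged)."""
    out = list(logits)
    for i in set(generated_ids):
        out[i] = out[i] / penalty if out[i] > 0 else out[i] * penalty
    return out

def filter_probs(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Temperature-scale and softmax, keep the top_k most likely tokens (0 = no limit),
    then the smallest prefix of those whose cumulative probability >= top_p.
    Returns renormalized probabilities, with 0.0 for filtered tokens."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    kept, cum = set(), 0.0
    for i in order:
        kept.add(i)
        cum += probs[i]
        if cum >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    return [probs[i] / mass if i in kept else 0.0 for i in range(len(probs))]

def next_token(logits, generated_ids, sample=True, temperature=1.0,
               top_k=50, top_p=0.95, rep_penalty=1.0, rng=random):
    """Pick one token id: sample from the filtered distribution, or take the argmax."""
    penalized = apply_rep_penalty(logits, generated_ids, rep_penalty)
    probs = filter_probs(penalized, temperature, top_k, top_p)
    if not sample:
        return max(range(len(probs)), key=probs.__getitem__)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

logits = [2.0, 1.5, 0.3, -1.0]                         # toy ensembled logits
print(filter_probs(logits, temperature=2.5, top_k=2))  # only two tokens survive
print(next_token(logits, generated_ids=[0], sample=False, rep_penalty=1.5))
print(next_token(logits, generated_ids=[], rng=random.Random(0)))
```

In the greedy call, the penalty drops token 0's logit from 2.0 to about 1.33, so token 1 wins; temperature never changes the argmax, only how peaked the sampled distribution is.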

**decompx reranking**

* `ranking`

  * `True`: enable decompx-based reranking of candidates inside `rewrite.rewrite_example`.
  * `False`: generate only **one** candidate per input (no reranking).

* `num_candidates`

  * when `ranking=True`, this sets `--ranking_num_output`.
  * for each input, the generator samples `num_candidates` candidates and picks the one with **lowest decompx toxicity importance**.
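
The selection rule can be sketched independently of DecompX. In this hypothetical version the scorer is any function that returns one importance value per token; `toy_importance` is an invented keyword scorer standing in for the DecompX classifier, and `rerank_least_toxic` is not a function from the repo.

```python
from typing import Callable, List, Sequence

def rerank_least_toxic(candidates: Sequence[str],
                       token_importance: Callable[[str], List[float]]) -> str:
    """Return the candidate with the lowest summed toxicity importance.
    min() keeps the first of several equal minima, so ties go to the earliest candidate."""
    if not candidates:
        raise ValueError("need at least one candidate")
    return min(candidates, key=lambda c: sum(token_importance(c)))

# hypothetical stand-in for the DecompX scorer: 1.0 per flagged word, 0.0 otherwise
FLAGGED = {"stupid", "idiot"}

def toy_importance(text: str) -> List[float]:
    return [1.0 if w.strip(".,!?").lower() in FLAGGED else 0.0 for w in text.split()]

candidates = [
    "you are a stupid person.",
    "you are a kind person.",
    "you are an idiot, a stupid one.",
]
print(rerank_least_toxic(candidates, toy_importance))
```

Because the score is a sum rather than a mean, and assuming importances are mostly non-negative, longer candidates can accumulate more importance than shorter ones, so this rule may mildly favor shorter rewrites.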

**evaluation**

* `run_eval`

  * if `True`, run `evaluation.evaluate_all` and write `gen_stats.txt` + summary csv.

* `overwrite_gen`

  * if `True`, regenerate outputs even if `gen.txt` already exists.

* `overwrite_eval`

  * if `True`, recompute evaluation values even if `gen_stats.txt` already exists.

* `skip_ref_eval`

  * if `True`, skip perplexity computation on references (faster).

**echo**

* `echo`

  * if `True`, print example inputs, masked inputs, generated outputs, and per-threshold evaluation metrics to the notebook.

---

## example usage

**quick test on a subset of paradetox (with reranking):**

```python
detoxify(
    data_type="paradetox",
    output_folder="colab_run_demo",
    thresholds=(0.20,),
    batch_size=8,
    ranking=True,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=50,
    run_eval=True,
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    num_candidates=20,
)
```

**run without reranking (single candidate per input):**

```python
detoxify(
    data_type="paradetox",
    output_folder="colab_run_no_rerank",
    thresholds=(0.20,),
    batch_size=8,
    ranking=False,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=50,
    run_eval=True,
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
)
```

after the run, you can inspect:

* `orig.txt` / `gen.txt` under `data/model_outputs/{output_folder}/{data_type}/DecompX{thresh}/...`
* `gen_stats.txt` for per-run metrics.
* `{data_type}.csv` for a summary over thresholds.
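
To compare thresholds once the summary csv exists, a small standard-library helper can sort its rows. The column names come from `_aggregate_eval_csv`; the helper name `rank_runs` and the ranking policy (lowest `toxicity_gen` first, ties broken by higher `bertscore`) are illustrative choices, not something the repo defines.

```python
import csv
import math

def _to_float(x):
    """Parse a csv cell as float; blanks, missing columns and junk become NaN."""
    try:
        return float(x)
    except (TypeError, ValueError):
        return math.nan

def rank_runs(csv_path, primary="toxicity_gen", secondary="bertscore"):
    """Rows of a {data_type}.csv summary, sorted by lowest `primary`, then highest `secondary`.
    Rows with a missing `primary` value sort last."""
    with open(csv_path, newline="") as f:
        rows = list(csv.DictReader(f))

    def sort_key(row):
        p = _to_float(row.get(primary))
        s = _to_float(row.get(secondary))
        p = math.inf if math.isnan(p) else p    # missing toxicity -> last
        s = -math.inf if math.isnan(s) else s   # missing bertscore loses ties
        return (p, -s)

    return sorted(rows, key=sort_key)

# example (path relative to the repo root, for the demo run above):
# best = rank_runs("data/model_outputs/colab_run_demo/paradetox/paradetox.csv")[0]
# print(best["threshold"], best["folder"], best["toxicity_gen"], best["bertscore"])
```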


In [None]:
from google.colab import drive; drive.mount('/content/drive')
import os, glob, re, sys, torch, json, shutil, math, nltk
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE

candidate = "/content/drive/MyDrive/w266 - Project/XDetox"
print("checking mydrive:", candidate, "->", os.path.isdir(candidate))

XDETOX_DIR = candidate
print("using XDETOX_DIR:", XDETOX_DIR)
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"

In [None]:
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE

if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("xdetox_dir:", XDETOX_DIR)
print("transformers_cache:", HF_CACHE)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("gpu:", torch.cuda.get_device_name(0))

In [None]:
for d in ["rewrite", "evaluation", "datasets"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("repo folders ok")

In [None]:
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi

!pip -q install bert-score

In [None]:
nltk.download("punkt", quiet=True)
try:
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass
print("nltk ready")

In [None]:
from rewrite.generation import Infiller
from rewrite import rewrite_example as rx
import argparse as _argparse

In [None]:
data_configs = {
    "microagressions_val": {
        "data_path": "./datasets/microagressions/val.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "microagressions_test": {
        "data_path": "./datasets/microagressions/test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "sbf_val": {
        "data_path": "./datasets/sbf/sbfdev.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "sbf_test": {
        "data_path": "./datasets/sbf/sbftst.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "dynabench_val": {
        "data_path": "./datasets/dynabench/db_dev.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "dynabench_test": {
        "data_path": "./datasets/dynabench/db_test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "jigsaw_toxic": {
        "data_path": "./datasets/jigsaw_full_30/test_10k_toxic.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "paradetox": {
        "data_path": "./datasets/paradetox/test_toxic_parallel.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_original": {
        "data_path": "./datasets/appdia/original-annotated-data/original-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_discourse": {
        "data_path": "./datasets/appdia/discourse-augmented-data/discourse-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    }
}
print("datasets:", ", ".join(data_configs.keys()))

## helpers

In [None]:
REPO = XDETOX_DIR

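def _abs_repo_path(rel):
    # drop only a leading "./" (str.lstrip("./") would strip any run of leading '.' and '/')
    return os.path.join(REPO, rel[2:] if rel.startswith("./") else rel)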

def _ensure_dir(p):
    Path(p).mkdir(parents=True, exist_ok=True)

def _subset_for_data_type(data_type, data_path, n, out_dir):
    """Write the first n examples of data_path into out_dir and return the new path.
    With n=None (or n <= 0) the full dataset is used and its absolute path is returned,
    so rx.get_data() works regardless of the notebook's working directory."""
    src = _abs_repo_path(data_path)
    if n is None or n <= 0:
        return src

    _ensure_dir(out_dir)
    out = os.path.join(out_dir, os.path.basename(src))

    # csv datasets: keep the header plus the first n rows
    if any(k in data_path for k in ["microagressions", "sbf", "dynabench"]):
        pd.read_csv(src).head(n).to_csv(out, index=False)
        return out

    if any(k in data_path for k in ["paradetox", "jigsaw"]):
        if data_path.endswith(".txt"):
            # one example per line
            with open(src, "r") as f:
                lines = [s.rstrip("\n") for s in f.readlines()]
            with open(out, "w") as g:
                for s in lines[:n]:
                    g.write(s + "\n")
            return out
        elif data_path.endswith(".csv"):
            pd.read_csv(src).head(n).to_csv(out, index=False)
            return out

    if "appdia" in data_path:
        # tab-separated files
        pd.read_csv(src, sep="\t").head(n).to_csv(out, sep="\t", index=False)
        return out

    # unknown format: copy the whole file (no subsetting)
    shutil.copy(src, out)
    return out

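def _parse_run_folder_name(folder_name):
    """True if folder_name looks like a generation folder built by _build_gen_folder_name."""
    num = r"(\d+(?:\.\d+)?)"  # accept both 1 and 1.0, since str() of an int has no decimal point
    pattern = (
        rf"aa{num}_ae{num}_ab{num}"
        r"_base(.*?)_anti(.*?)_expert(.*?)"
        rf"_temp{num}_sample(.*?)_topk(\d+)_reppenalty{num}_filterp{num}"
        rf"_maxlength(\d+)_topp{num}"
    )
    return re.match(pattern, folder_name) is not None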

def _eval_with_toxicity(base_path, overwrite_eval=False, skip_ref=False, tox_threshold=0.5, tox_batch_size=32):
    """Run evaluation.evaluate_all on every generation folder under one DecompX{thresh} dir."""
    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir) or not _parse_run_folder_name(folder):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path  = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue

        # make the repo importable in the subprocess, keeping any existing PYTHONPATH
        env = os.environ.copy()
        env["PYTHONPATH"] = os.pathsep.join(p for p in (REPO, env.get("PYTHONPATH", "")) if p)
        cmd = [
            sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path",  gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")
        print("eval:", " ".join(cmd))
        res = run(cmd, cwd=REPO, env=env, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()


def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return float('nan')

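def _read_stats_file(path):
    """Parse 'metric: value' lines of gen_stats.txt into {lowercased metric: float}."""
    out = {}
    with open(path, "r") as f:
        for line in f:
            if ":" not in line:
                continue
            # split on the first colon so a missing space after ':' does not raise
            k, v = line.split(":", 1)
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v.strip())
    return out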

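def _aggregate_eval_csv(output_folder, data_type, base_out_dir):
    rows = []
    data_dir = os.path.join(base_out_dir, data_type)
    mask_dirs = sorted(os.listdir(data_dir)) if os.path.isdir(data_dir) else []
    # summarize every DecompX{thresh} folder that exists, not a hard-coded 0.15-0.25 grid
    for mask_dir in mask_dirs:
        base_path = os.path.join(data_dir, mask_dir)
        if not (mask_dir.startswith("DecompX") and os.path.isdir(base_path)):
            continue
        thresh = _safe_float(mask_dir[len("DecompX"):])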
        for folder in os.listdir(base_path):
            gen_dir = os.path.join(base_path, folder)
            stats_path = os.path.join(gen_dir, "gen_stats.txt")
            if not os.path.exists(stats_path):
                continue
            s = _read_stats_file(stats_path)
            rows.append({
                "threshold":        float(f"{thresh:.2f}"),
                "folder":           folder,
                "bertscore":        s.get("bertscore", np.nan),
                "meaningbert":      s.get("meaningbert", np.nan),
                "bleu4":            s.get("bleu4", np.nan),
                "perplexity_gen":   s.get("perplexity gen", np.nan),
                "perplexity_orig":  s.get("perplexity orig", np.nan),
                "toxicity_gen":     s.get("toxicity gen", np.nan),
                "toxicity_orig":    s.get("toxicity orig", np.nan),
            })

    if rows:
        cols = [
            "threshold", "folder",
            "bertscore", "meaningbert", "bleu4",
            "perplexity_gen", "perplexity_orig",
            "toxicity_gen", "toxicity_orig",
        ]
        df = pd.DataFrame(rows)
        df = df[cols]
        out_csv = os.path.join(base_out_dir, data_type, f"{data_type}.csv")
        df.to_csv(out_csv, index=False)
        print("wrote summary csv:", out_csv)
    else:
        print("no evaluation files found to summarize.")

def _bool2str(x: bool) -> str:
    return "T" if x else "F"


def _build_gen_folder_name(
    alpha_a, alpha_e, alpha_b,
    base_type, antiexpert_type, expert_type,
    temperature, sample, top_k_gen, rep_penalty, filter_p, max_length, top_p
):
    return (
        "aa" + str(alpha_a) +
        "_ae" + str(alpha_e) +
        "_ab" + str(alpha_b) +
        "_base" + base_type[:5] +
        "_anti" + antiexpert_type[:5] +
        "_expert" + expert_type[:5] +
        "_temp" + str(temperature) +
        "_sample" + _bool2str(sample) +
        "_topk" + str(top_k_gen) +
        "_reppenalty" + str(rep_penalty) +
        "_filterp" + str(filter_p) +
        "_maxlength" + str(max_length) +
        "_topp" + str(top_p)
    )

In [None]:
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run",
    thresholds = (0.15, 0.20, 0.25),
    echo: bool = False,
    batch_size: int = 10,
    ranking: bool = True,
    sample: bool = True,
    top_k_gen: int = 50,
    top_p: float = 0.95,
    filter_p: float = 1.0,
    max_length: int = 128,
    alpha_a: float = None,
    alpha_e: float = None,
    alpha_b: float = 1.0,
    temperature: float = None,
    rep_penalty: float = None,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    num_candidates: int = 10,
):
    assert data_type in data_configs, f"Unknown data_type: {data_type}"
    cfg = data_configs[data_type].copy()

    if ranking and num_candidates < 1:
        raise ValueError("num_candidates must be >= 1 when ranking=True")

    if alpha_a is None:
        alpha_a = cfg["alpha_a"]
    if alpha_e is None:
        alpha_e = cfg["alpha_e"]
    if temperature is None:
        temperature = cfg["temperature"]
    if rep_penalty is None:
        rep_penalty = cfg["rep_penalty"]

    base_out_dir = os.path.join("data", "model_outputs", output_folder)
    abs_base_out_dir = os.path.join(REPO, base_out_dir)
    _ensure_dir(abs_base_out_dir)

    original_data_path = cfg["data_path"]
    subset_dir = os.path.join(REPO, "datasets", "_subsets", data_type)
    _ensure_dir(subset_dir)
    subset_path = _subset_for_data_type(
        data_type, original_data_path, num_examples, subset_dir
    )

    args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
    inputs = rx.get_data(args_data)
    num_inputs = len(inputs)

    if echo:
        print("=" * 80)
        print(f"dataset: {data_type}")
        print(f"subset path: {subset_path}")
        print(f"output base: {abs_base_out_dir}")
        print(f"number of examples to detoxify: {num_inputs}")
        print(f"thresholds: {', '.join(f'{t:.2f}' for t in thresholds)}")
        print(f"ranking: {ranking}, num_candidates: {num_candidates}")
        print("\nexample inputs (first up to 3):")
        for i, s in enumerate(inputs[:3]):
            print(f"  input[{i}]: {s}")
        print("=" * 80)

    base_type = "base"
    antiexpert_type = "antiexpert"
    expert_type = "expert"

    gen_folder = _build_gen_folder_name(
        alpha_a, alpha_e, alpha_b,
        base_type, antiexpert_type, expert_type,
        temperature, sample, top_k_gen, rep_penalty, filter_p, max_length, top_p
    )

    for t in thresholds:
        mask_dir = f"DecompX{abs(t):g}" if t != 0 else "DecompX0.0"
        thresh_root_dir = os.path.join(abs_base_out_dir, data_type, mask_dir)
        _ensure_dir(thresh_root_dir)

        cmd = [
            sys.executable, "-m", "rewrite.rewrite_example",
            "--output_dir", base_out_dir,
            "--data_type", data_type,
            "--data_path", subset_path.replace(REPO + "/", "./"),
            "--rep_penalty", str(rep_penalty),
            "--alpha_a", str(alpha_a),
            "--alpha_e", str(alpha_e),
            "--temperature", str(temperature),
            "--alpha_b", str(alpha_b),
            "--max_length", str(max_length),
            "--batch_size", str(batch_size),
            "--top_k_gen", str(top_k_gen),
            "--top_p", str(top_p),
            "--filter_p", str(filter_p),
            "--thresh", f"{t:.2f}",
        ]

        if ranking:
            cmd.append("--ranking")
            cmd.extend(["--ranking_num_output", str(num_candidates)])

        if sample:
            cmd.append("--sample")
        if overwrite_gen:
            cmd.append("--overwrite_gen")

        print("run:", " ".join(cmd))
        res = run(cmd, cwd=REPO, stdout=PIPE, stderr=PIPE, text=True)

        print("\n----- rewrite_example stdout -----")
        print(res.stdout)
        print("----- rewrite_example stderr -----")
        print(res.stderr)
        print("----- end -----\n")

        res.check_returncode()

        if echo:
            print("\n" + "-" * 80)
            print(f"threshold t={t:.2f} - sample masked and generated outputs")

            masked_path = os.path.join(thresh_root_dir, "masked_inputs.txt")
            if os.path.exists(masked_path):
                with open(masked_path, "r") as f:
                    masked_lines = [l.strip() for l in f.readlines()]
                print("example masked inputs (first up to 3):")
                for i, m in enumerate(masked_lines[:3]):
                    print(f"  masked[{i}]: {m}")
            else:
                print(f"masked_inputs.txt not found at {masked_path}")

            run_dir = os.path.join(thresh_root_dir, gen_folder)
            gen_txt = os.path.join(run_dir, "gen.txt")
            if os.path.exists(gen_txt):
                with open(gen_txt, "r") as f:
                    gen_lines = [l.strip() for l in f.readlines()]
                print("\nexample detoxified outputs (first up to 3):")
                for i, g in enumerate(gen_lines[:3]):
                    print(f"  detox[{i}]: {g}")
            else:
                print(f"gen.txt not found at {gen_txt}")

        if run_eval:
            base_path = os.path.join(abs_base_out_dir, data_type, mask_dir)
            _eval_with_toxicity(
                base_path,
                overwrite_eval=overwrite_eval,
                skip_ref=skip_ref_eval,
                tox_threshold=0.5,
                tox_batch_size=32,
            )

            if echo:
                stats_path = os.path.join(base_path, gen_folder, "gen_stats.txt")
                if os.path.exists(stats_path):
                    stats = _read_stats_file(stats_path)
                    print("\nevaluation metrics for this run "
                          f"(t={t:.2f}):")
                    metric_keys = [
                        ("bertscore", "bertscore"),
                        ("meaningbert", "meaningbert"),
                        ("bleu4", "bleu-4"),
                        ("perplexity gen", "perplexity (gen)"),
                        ("perplexity orig", "perplexity (orig)"),
                        ("toxicity gen", "toxicity (gen)"),
                        ("toxicity orig", "toxicity (orig)"),
                    ]
                    for key, label in metric_keys:
                        val = stats.get(key, None)
                        if isinstance(val, float) and math.isnan(val):
                            continue
                        if val is None:
                            continue
                        print(f"  {label}: {val:.4f}")
                else:
                    print(f"gen_stats.txt not found at {stats_path}")

    if run_eval:
        _aggregate_eval_csv(
            output_folder,
            data_type,
            os.path.join(REPO, "data", "model_outputs", output_folder),
        )

## example run

In [None]:
detoxify(
    data_type="paradetox",
    output_folder="XDetox_w_DecompX-Masking-DecompX-Reranking",
    thresholds=(0.20,),
    echo=True,
    batch_size=8,
    ranking=True,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=1000,
    run_eval=True,
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    num_candidates=10,
)