# xdetox with decompx masking, llm infilling, and decompx reranking

this notebook runs an xdetox variant with:

1. **decompx masking** (token-level toxicity attribution on roberta).
2. **llm infilling** using mistral-7b-instruct  
   (`mistralai/Mistral-7B-Instruct-v0.2`).
3. **decompx-based reranking** of multiple llm candidates, following the
   reranking strategy described in *"xdetox: text detoxification with token-level toxicity explanations"*.

for each toxic input sentence, the goal is to pick **one final detoxified candidate** that:

- has the **lowest decompx toxicity importance**,
- while being produced by an llm that sees both the **raw toxic sentence** and the **decompx-masked sentence**.

compared to the **decompx-masking + llm infilling + global reranking** pipeline, this notebook:

- uses the **same decompx masking** and **same llm infilling** setup,
- but replaces global reranking (toxicity + similarity + fluency) with **pure decompx toxicity-based reranking**.

---

## scoring: decompx-based reranking

the reranking stage follows the idea described in section 2.3 of the xdetox paper:

1. for each input, we generate multiple candidate detoxified sentences $s_j$ with the llm.
2. for each candidate sentence $s_j$, we run **decompx** and compute **token-level importance scores** with respect to toxicity.
3. for each candidate, we sum the importance scores of its tokens:

   - let $t_{i,j}$ be the $i$-th token in candidate sentence $s_j$.
   - let $\text{Importance}(t_{i,j})$ be its decompx importance score (toxicity contribution).

   the **total toxicity importance** of candidate $s_j$ is:

   $$
   \sum_{i=1}^{N_j} \text{Importance}(t_{i,j})
   $$

4. we choose the candidate with the **lowest total importance**:

   $$
   s^* = \arg\min_{s_j} \sum_{i=1}^{N_j} \text{Importance}(t_{i,j})
   $$

intuitively:

- a **lower sum** of token-level importance scores means **lower overall toxicity**.
- reranking selects the candidate with the **minimum decompx toxicity** among all candidates for that input.
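
as a minimal sketch of the selection rule above, assuming the per-token importance scores of each candidate are already available as plain lists (the helper name `pick_lowest_importance` and the numbers are illustrative, not taken from the notebook):

```python
from typing import List, Sequence


def pick_lowest_importance(token_scores: Sequence[Sequence[float]]) -> int:
    """Return the index j that minimises sum_i Importance(t_{i,j})."""
    if not token_scores:
        raise ValueError("need at least one candidate")
    totals: List[float] = [sum(scores) for scores in token_scores]
    # Ties resolve to the earliest candidate.
    return min(range(len(totals)), key=totals.__getitem__)


# Hypothetical importance scores for three candidates of different lengths.
example_scores = [
    [0.10, 0.40, 0.05],
    [0.02, 0.03, 0.01, 0.04],
    [0.30, 0.30],
]
best = pick_lowest_importance(example_scores)
```

here `best` is the position of the second candidate, whose token scores add up to the smallest total.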

in the implementation, this is encapsulated in a helper like:

- `rerank_candidates_decompx(sources, candidates, threshold, batch_size_mask)`

which:

- flattens all candidates,
- runs decompx on each candidate sentence,
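- computes a decompx-based toxicity score per candidate (in the code below, the fraction of whitespace tokens that decompx masks at `threshold`, used as a proxy for the summed importance),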
- reshapes scores back to `[num_inputs, num_candidates]`,
- picks the **index of the candidate with the lowest score** for each input.
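
the flatten / score / reshape / argmin pattern in that list can be sketched independently of decompx. this is an illustration under assumptions: `rerank_by_score` and `toy_score` are hypothetical names, and the toy scorer only stands in for a real toxicity score.

```python
from typing import Callable, List, Sequence

import numpy as np


def rerank_by_score(
    candidates: Sequence[Sequence[str]],
    score_fn: Callable[[List[str]], np.ndarray],
) -> np.ndarray:
    """Pick, per input, the index of the lowest-scoring candidate."""
    num_inputs = len(candidates)
    num_cands = len(candidates[0])
    assert all(len(c) == num_cands for c in candidates), "ragged candidate lists"

    # Flatten so the scorer sees one long batch of sentences.
    flat = [cand for cand_list in candidates for cand in cand_list]
    scores = np.asarray(score_fn(flat), dtype=float)
    # Restore the [num_inputs, num_candidates] layout and take the row-wise argmin.
    return scores.reshape(num_inputs, num_cands).argmin(axis=1)


def toy_score(texts: List[str]) -> np.ndarray:
    # Placeholder scorer: fraction of whitespace tokens equal to "bad".
    return np.array([t.split().count("bad") / max(len(t.split()), 1) for t in texts])


picked = rerank_by_score(
    [["a bad day", "a rough day"], ["fine", "bad bad"]],
    toy_score,
)
```

flattening first lets the scorer run in large batches, and the reshape relies on every input having the same number of candidates.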

there is **no xlm-r / labse / gpt-2 global scoring** in this notebook.  
all reranking is done by **decompx**.

---

## decompx masking

masking uses the original xdetox **decompx masker**:

- implementation: `rewrite.mask_orig.Masker`.
- backend: roberta with decompx token-level toxicity attribution.

for each input sentence:

1. decompx computes an **importance score** for each token based on its contribution to toxicity.
2. if the importance score of a token exceeds a **threshold** $t$, that token is considered **toxic**.
3. such tokens are replaced by the `<mask>` token.
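
the thresholding step can be sketched as follows, assuming tokens and their importance scores are already aligned one-to-one. `mask_by_threshold` is a hypothetical helper and the scores are made up; the real `Masker` works on roberta subword tokens and may post-process the result differently.

```python
from typing import List, Sequence


def mask_by_threshold(
    tokens: Sequence[str],
    importance: Sequence[float],
    threshold: float,
    mask_token: str = "<mask>",
) -> List[str]:
    """Replace tokens whose importance exceeds the threshold with the mask token."""
    if len(tokens) != len(importance):
        raise ValueError("tokens and importance must have the same length")
    return [mask_token if s > threshold else t for t, s in zip(tokens, importance)]


# Hypothetical scores; real ones would come from decompx on roberta.
masked = mask_by_threshold(
    ["you", "are", "an", "idiot"],
    [0.01, 0.02, 0.03, 0.85],
    threshold=0.20,
)
```

only the last token exceeds the threshold here, so it is the only one replaced by `<mask>`.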

for a given dataset and decompx threshold $t$:

- inputs are loaded via `rewrite_example.get_data`.
- masked outputs are written to:

```text
  data/model_outputs/{output_folder}/{data_type}/DecompX_LLM{t}/masked_inputs.txt
```

(the code below formats `{t}` with `:g`, so a threshold of 0.20 gives the directory `DecompX_LLM0.2`.)

these masked sentences are later fed into the llm infiller.

---

## llm infilling (mistral-7b-instruct)

after decompx masking, we use **mistral-7b-instruct** as an **infilling model**.

### inputs to the llm

for each example we provide **both**:

* **toxic sentence**: the original toxic sentence, unchanged.
* **masked sentence**: the decompx-masked sentence, where toxic spans have been replaced by `<mask>`.

the llm prompt is structured along the lines of:

```text
you are a helpful assistant trained to make toxic or offensive sentences more polite and respectful
while keeping their original meaning. ...

toxic sentence: {raw_toxic}
masked sentence: {masked_by_decompx}
final output:
```

the instructions emphasize:

* **only fill in the `<mask>` tokens** in the masked sentence.
* keep **all non-masked parts** of the masked sentence as close as possible to their original form.
* preserve the **meaning and intent** of the toxic sentence.
* use the **same language** as the toxic sentence.
* return **only** the final detoxified sentence **inside one pair of square brackets**:

```text
[detoxified sentence here.]
```
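
a minimal sketch of how the per-example part of the prompt could be assembled from the two inputs. `build_infill_prompt` is a hypothetical helper, and the instruction string passed in the example is a shortened stand-in for the full instruction text.

```python
def build_infill_prompt(instructions: str, raw_toxic: str, masked_by_decompx: str) -> str:
    """Assemble the instruction text and the per-example fields into one prompt."""
    return (
        f"{instructions.strip()}\n\n"
        f"toxic sentence: {raw_toxic}\n"
        f"masked sentence: {masked_by_decompx}\n"
        "final output:"
    )


prompt = build_infill_prompt(
    "fill in the <mask> tokens politely and answer inside one pair of [ ].",
    "this is a stupid idea",
    "this is a <mask> idea",
)
```

ending the prompt on `final output:` nudges the model to continue directly with the bracketed sentence.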

### candidate generation

for each $(\text{toxic}, \text{masked})$ pair, we ask mistral for `num_candidates` completions:

* the generation parameters include:

  * `llm_sample` (sampling vs greedy),
  * `llm_temperature`,
  * `llm_top_p`,
  * `max_new_tokens`.

* for each completion, we:

  1. **extract content inside the first `[ ... ]` block.**
  2. **strip any remaining outer brackets.**
  3. **normalize whitespace.**

if a cleaned candidate is empty, we fall back to using the masked sentence.
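
the three cleanup steps and the fallback can be sketched as one small function. `clean_candidate` is a hypothetical name and this is a simplified illustration, not the notebook's exact helper.

```python
import re


def clean_candidate(raw: str, fallback: str) -> str:
    """Extract the first [ ... ] block, drop stray brackets, normalise whitespace."""
    match = re.search(r"\[([^\]]*)\]", raw)
    text = match.group(1) if match else raw
    # Strip leftover outer brackets, e.g. when the closing bracket is missing.
    text = text.strip().lstrip("[").rstrip("]")
    # Collapse runs of whitespace into single spaces.
    text = re.sub(r"\s+", " ", text).strip()
    # An empty result falls back to the masked sentence.
    return text or fallback


cleaned = clean_candidate("Final Output: [ that was   unkind. ] extra", "that was <mask>.")
empty = clean_candidate("[]", "that was <mask>.")
```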

the result: for each input sentence, we obtain a list of `num_candidates` **llm-generated candidates** to be reranked by decompx.

---

## end-to-end flow: masking, llm infilling, decompx reranking

for each dataset:

1. **subset selection**

   * the script can run on the full dataset or only the first `num_examples` instances.
   * a subset file is written under:

```text
datasets/_subsets/{data_type}/
```

2. **decompx masking (per threshold)**

   * for each threshold $t$ in `thresholds`, we run decompx masking.
   * masked sentences are saved to:

```text
data/model_outputs/{output_folder}/{data_type}/DecompX_LLM{t}/masked_inputs.txt
```

3. **llm infilling (mistral-7b-instruct)**

   * for each masked sentence and its corresponding toxic input, we call mistral with the prompt described above.
   * we generate `num_candidates` candidates per input.
   * the raw llm outputs are post-processed (bracket extraction, whitespace cleanup).

4. **decompx-based reranking**

   * for each threshold $t$ and each input sentence, we apply decompx to **all llm candidates**.

   * for each candidate $s_j$, we compute the total toxicity importance:

$$
\sum_{i=1}^{N_j} \text{Importance}(t_{i,j})
$$

   * we choose the candidate with the **lowest** total importance as the final output.

   * for each run, we write:

```text
data/model_outputs/{output_folder}/{data_type}/DecompX_LLM{t}/{run_folder}/orig.txt
data/model_outputs/{output_folder}/{data_type}/DecompX_LLM{t}/{run_folder}/gen.txt
```

   where:

   * `orig.txt`: original toxic inputs (one per line),
   * `gen.txt`: selected (reranked) llm outputs (one per line),
   * `{run_folder}` encodes llm generation hyperparameters.
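
a sketch of how a run folder name can encode the generation hyperparameters and how the two output paths are composed. the name format mirrors the helper defined in the code cells of this notebook; `output_paths` and its `threshold_dir` argument are hypothetical and simply take the threshold directory name as given.

```python
import os
from typing import Tuple


def run_folder_name(temperature: float, sample: bool, top_p: float,
                    max_new_tokens: int, num_candidates: int) -> str:
    """Encode the llm generation hyperparameters in a folder name."""
    return (
        f"llmtemp{temperature}_sample{'T' if sample else 'F'}"
        f"_topp{top_p}_maxnew{max_new_tokens}_ncand{num_candidates}"
    )


def output_paths(output_folder: str, data_type: str,
                 threshold_dir: str, run_folder: str) -> Tuple[str, str]:
    """Return the (orig.txt, gen.txt) paths for one (threshold, run_folder) pair."""
    run_dir = os.path.join("data", "model_outputs", output_folder,
                           data_type, threshold_dir, run_folder)
    return os.path.join(run_dir, "orig.txt"), os.path.join(run_dir, "gen.txt")


folder = run_folder_name(0.7, True, 0.95, 64, 3)
orig_path, gen_path = output_paths("demo_out", "paradetox", "some_threshold_dir", folder)
```

because the folder name is derived from the hyperparameters, two runs with different sampling settings never overwrite each other's `gen.txt`.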

---

## evaluation

if `run_eval=True`, the pipeline calls `evaluation.evaluate_all` to compute:

* bertscore (f1),
* meaningbert,
* bleu-4,
* toxicity (orig / gen),
* perplexity (orig / gen).

for each `(threshold, run_folder)` we write:

```text
data/model_outputs/{output_folder}/{data_type}/DecompX_LLM{t}/{run_folder}/gen_stats.txt
```

the notebook also builds a **summary csv per dataset** by scanning all `DecompX_LLM*` directories:

```text
data/model_outputs/{output_folder}/{data_type}/{data_type}.csv
```

this csv aggregates:

* `threshold` (decompx masking / reranking threshold),
* `folder` (run folder name),
* `bertscore`, `meaningbert`, `bleu4`,
* `perplexity_gen`, `perplexity_orig`,
* `toxicity_gen`, `toxicity_orig`.
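
as an illustration of how the summary csv might be consumed, the sketch below picks the row with the lowest `toxicity_gen`. `lowest_toxicity_run` is a hypothetical helper, and the numbers in `demo_csv` are placeholders, not results from any run.

```python
import csv
import io
from typing import Dict, List


def lowest_toxicity_run(csv_text: str) -> Dict[str, str]:
    """Return the summary row with the smallest `toxicity_gen` value."""
    rows: List[Dict[str, str]] = list(csv.DictReader(io.StringIO(csv_text)))
    if not rows:
        raise ValueError("summary csv has no rows")
    return min(rows, key=lambda r: float(r["toxicity_gen"]))


# Placeholder numbers for illustration only.
demo_csv = (
    "threshold,folder,bertscore,meaningbert,bleu4,"
    "perplexity_gen,perplexity_orig,toxicity_gen,toxicity_orig\n"
    "0.15,run_a,0.9,0.8,0.5,60.0,80.0,0.30,0.90\n"
    "0.2,run_b,0.9,0.8,0.5,60.0,80.0,0.10,0.90\n"
)
best_row = lowest_toxicity_run(demo_csv)
```

rows whose metrics are missing (written as `nan`) would need to be filtered out before taking the minimum.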

---

## how to use `detoxify()`

a typical function signature for this notebook looks like:

```python
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_decompx_mask_llm_decompx",
    thresholds = (0.20,),
    echo: bool = False,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    mask_batch_size: int = 10,
    rerank_batch_size: int = 16,
    decompx_rerank_threshold: float = None,
    llm_sample: bool = True,
    llm_temperature: float = 0.7,
    llm_top_p: float = 0.95,
    max_new_tokens: int = 64,
    num_candidates: int = 3,
): ...
```

### key arguments

#### core i/o

* `data_type`:

  * key in `data_configs`, for example:

    * `"paradetox"`, `"dynabench_val"`, `"dynabench_test"`,
    * `"jigsaw_toxic"`, `"microagressions_val"`, `"sbf_val"`,
    * `"appdia_original"`, `"appdia_discourse"`, etc.

* `output_folder`:

  * top-level directory under:

    ```text
    data/model_outputs/{output_folder}/{data_type}/...
    ```

* `num_examples`:

  * `None`: use the full dataset.
  * integer: run on the first `num_examples` examples.

* `overwrite_gen`:

  * `False`: if `gen.txt` already exists for a given `(threshold, run_folder)`, reuse existing generations.
  * `True`: regenerate and overwrite `gen.txt`.

* `echo`:

  * if `True`, print:

    * basic dataset info,
    * example inputs,
    * example masked sentences,
    * example final outputs,
    * per-run metrics (if `run_eval=True`).

#### decompx masking and thresholds

* `thresholds`:

  * tuple of decompx thresholds (e.g. `(0.15, 0.20, 0.25)`).
  * for each $t$, we:

    * run decompx masking with threshold $t$,
    * run llm infilling,
    * apply decompx-based reranking (using the same decompx mechanism).

* `mask_batch_size`:

  * batch size used when running decompx masking over inputs.
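
the interaction between `thresholds` and `mask_batch_size` can be sketched as a loop over thresholds with batched inputs. `batched` and `plan_masking` are hypothetical helpers that only count batches; they do not call decompx.

```python
from typing import Iterator, List, Sequence, Tuple


def batched(items: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most `batch_size` items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


def plan_masking(inputs: Sequence[str], thresholds: Sequence[float],
                 mask_batch_size: int) -> List[Tuple[float, int]]:
    """List (threshold, number_of_batches) pairs, one per threshold."""
    return [(t, sum(1 for _ in batched(inputs, mask_batch_size))) for t in thresholds]


plan = plan_masking(["s"] * 25, (0.15, 0.20, 0.25), mask_batch_size=10)
```

each threshold repeats the full masking pass, so the masking cost grows linearly with the number of thresholds.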

#### llm infilling (mistral)

* `llm_sample`:

  * `True`: sampling.
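  * `False`: deterministic (greedy) decoding; every candidate for an input is then the same sentence, so `num_candidates > 1` adds nothing to rerank.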

* `llm_temperature`:

  * sampling temperature for mistral (used when `llm_sample=True`).

* `llm_top_p`:

  * top-p nucleus sampling cutoff.

* `max_new_tokens`:

  * maximum number of new tokens generated per candidate.

* `num_candidates`:

  * number of llm candidates per input to be reranked by decompx.
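
one way to map these arguments onto keyword arguments for a hugging face `generate` call is sketched below. `generation_kwargs` is a hypothetical helper; the keys it emits (`max_new_tokens`, `do_sample`, `temperature`, `top_p`, `num_return_sequences`) are standard `generate` arguments.

```python
from typing import Any, Dict


def generation_kwargs(llm_sample: bool, llm_temperature: float, llm_top_p: float,
                      max_new_tokens: int, num_candidates: int) -> Dict[str, Any]:
    """Translate the notebook arguments into keyword arguments for `generate`."""
    kwargs: Dict[str, Any] = {"max_new_tokens": max_new_tokens, "do_sample": llm_sample}
    if llm_sample:
        # Temperature, top-p and multiple return sequences only apply when sampling.
        kwargs.update(temperature=llm_temperature, top_p=llm_top_p,
                      num_return_sequences=num_candidates)
    return kwargs


sampled = generation_kwargs(True, 0.7, 0.95, 64, 3)
greedy = generation_kwargs(False, 0.7, 0.95, 64, 3)
```

leaving the sampling knobs out in the greedy case avoids passing settings that have no effect there.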

#### evaluation

* `run_eval`:

  * if `True`, compute evaluation metrics and write `gen_stats.txt` files.

* `overwrite_eval`:

  * if `True`, recompute metrics even if `gen_stats.txt` already exists.

* `skip_ref_eval`:

  * if `True`, skip reference-based evaluation (for example, perplexity on reference outputs).

---

## example calls

### quick sanity check (single threshold, small subset)

```python
detoxify(
    data_type="paradetox",
    output_folder="colab_run_decompx_mask_llm_decompx_demo_50_ex",
    thresholds=(0.20,),
    echo=True,
    num_examples=50,
    overwrite_gen=True,
    run_eval=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    mask_batch_size=8,
    llm_sample=True,
    llm_temperature=0.7,
    llm_top_p=0.95,
    max_new_tokens=64,
    num_candidates=10,
)
```

### larger run (multiple thresholds, full dataset)

```python
detoxify(
    data_type="paradetox",
    output_folder="paradetox_decompx_mask_llm_decompx_full",
    thresholds=(0.15, 0.20, 0.25),
    echo=True,
    num_examples=None,
    overwrite_gen=False,
    run_eval=True,
    overwrite_eval=False,
    skip_ref_eval=False,
    mask_batch_size=8,
    llm_sample=True,
    llm_temperature=0.7,
    llm_top_p=0.95,
    max_new_tokens=64,
    num_candidates=10,
)
```

after running `detoxify`, you can inspect:

* per-threshold, per-run outputs:

```text
data/model_outputs/{output_folder}/{data_type}/DecompX_LLM{t}/{run_folder}/orig.txt
data/model_outputs/{output_folder}/{data_type}/DecompX_LLM{t}/{run_folder}/gen.txt
data/model_outputs/{output_folder}/{data_type}/DecompX_LLM{t}/{run_folder}/gen_stats.txt
```

* aggregated metrics:

```text
data/model_outputs/{output_folder}/{data_type}/{data_type}.csv
```
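
for a quick manual inspection, the original and generated lines of one run can be paired up. `load_pairs` is a hypothetical helper that assumes the one-sentence-per-line layout described above.

```python
import os
from typing import List, Tuple


def load_pairs(run_dir: str) -> List[Tuple[str, str]]:
    """Pair each line of orig.txt with the same-numbered line of gen.txt."""
    with open(os.path.join(run_dir, "orig.txt"), encoding="utf-8") as f:
        orig = [line.rstrip("\n") for line in f]
    with open(os.path.join(run_dir, "gen.txt"), encoding="utf-8") as f:
        gen = [line.rstrip("\n") for line in f]
    if len(orig) != len(gen):
        raise ValueError(f"line count mismatch: {len(orig)} vs {len(gen)}")
    return list(zip(orig, gen))
```

calling `load_pairs(run_dir)[:5]` on a finished run directory shows the first five input/output pairs.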

this notebook lets you compare:

* **decompx masking + llm infilling + decompx reranking** (this pipeline),
* against:

  * **decompx masking + llm infilling + global reranking**, and
  * the original **decompx masking + marco + decompx/global reranking** pipelines,

on the same datasets, using a decompx-based toxicity selection rule.

In [None]:
from google.colab import drive; drive.mount('/content/drive')

import os, glob, re, sys, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE
import torch
import nltk
from typing import List

candidate = "/content/drive/MyDrive/w266 - Project/XDetox"
print("try mydrive:", candidate, "->", os.path.isdir(candidate))

XDETOX_DIR = candidate
print("using xdetox_dir:", XDETOX_DIR)
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"

In [None]:
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.makedirs(HF_CACHE, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE

if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("xdetox_dir:", XDETOX_DIR)
print("transformers_cache:", HF_CACHE)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("gpu:", torch.cuda.get_device_name(0))

In [None]:
for d in ["rewrite", "evaluation", "datasets"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("ok")

In [None]:
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi \
                sentencepiece
!pip -q install bert-score

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from rewrite.mask_orig import Masker as Masker_single
from rewrite import rewrite_example as rx
import argparse as _argparse

In [None]:
nltk.download("punkt", quiet=True)
try:
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass
print("nltk ok")

In [None]:
data_configs = {
    "microagressions_val": {
        "data_path": "./datasets/microagressions/val.csv",
    },
    "microagressions_test": {
        "data_path": "./datasets/microagressions/test.csv",
    },
    "sbf_val": {
        "data_path": "./datasets/sbf/sbfdev.csv",
    },
    "sbf_test": {
        "data_path": "./datasets/sbf/sbftst.csv",
    },
    "dynabench_val": {
        "data_path": "./datasets/dynabench/db_dev.csv",
    },
    "dynabench_test": {
        "data_path": "./datasets/dynabench/db_test.csv",
    },
    "jigsaw_toxic": {
        "data_path": "./datasets/jigsaw_full_30/test_10k_toxic.txt",
    },
    "paradetox": {
        "data_path": "./datasets/paradetox/test_toxic_parallel.txt",
    },
    "appdia_original": {
        "data_path": "./datasets/appdia/original-annotated-data/original-test.tsv",
    },
    "appdia_discourse": {
        "data_path": "./datasets/appdia/discourse-augmented-data/discourse-test.tsv",
    }
}
print("datasets:", ", ".join(data_configs.keys()))

REPO = XDETOX_DIR

In [None]:
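def _abs_repo_path(rel: str) -> str:
    # lstrip("./") strips a character set, not a prefix; normalise instead.
    if os.path.isabs(rel):
        return rel
    return os.path.normpath(os.path.join(REPO, rel))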

def _ensure_dir(p: str):
    Path(p).mkdir(parents=True, exist_ok=True)

def _subset_for_data_type(data_type, data_path, n, out_dir):
    if n is None or n <= 0:
        return data_path

    src = _abs_repo_path(data_path)
    _ensure_dir(out_dir)

    # The csv-based datasets share the same "first n rows" subsetting logic.
    if any(k in data_path for k in ("microagressions", "sbf", "dynabench")):
        df = pd.read_csv(src)
        sub = df.head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        sub.to_csv(out, index=False)
        return out

    if any(k in data_path for k in ["paradetox", "jigsaw"]):
        if data_path.endswith(".txt"):
            with open(src, "r") as f:
                lines = [s.rstrip("\n") for s in f.readlines()]
            out = os.path.join(out_dir, os.path.basename(src))
            with open(out, "w") as g:
                for s in lines[:n]:
                    g.write(s + "\n")
            return out
        elif data_path.endswith(".csv"):
            df = pd.read_csv(src).head(n)
            out = os.path.join(out_dir, os.path.basename(src))
            df.to_csv(out, index=False)
            return out

    if "appdia" in data_path:
        df = pd.read_csv(src, sep="\t").head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        df.to_csv(out, sep="\t", index=False)
        return out

    out = os.path.join(out_dir, os.path.basename(src))
    shutil.copy(src, out)
    return out

In [None]:
def _decompx_mask_texts(texts: List[str],
                        threshold: float = 0.20,
                        batch_size: int = 16) -> List[str]:
    if not texts:
        return []

    masker = Masker_single()
    masked_all = []
    for i in tqdm(range(0, len(texts), batch_size),
                  desc="decompx masking", leave=False):
        batch = texts[i:i + batch_size]
        batch_out = masker.process_text(sentence=batch, threshold=threshold)
        masked_all.extend(batch_out)
    masker.release_model()

    cleaned = [
        m.replace("<s>", "").replace("</s>", "").strip()
        for m in masked_all
    ]
    return cleaned

def _decompx_toxicity_scores(texts: List[str],
                             threshold: float = 0.20,
                             batch_size: int = 16) -> np.ndarray:
    if not texts:
        return np.zeros((0,), dtype=float)

    masked = _decompx_mask_texts(texts, threshold=threshold, batch_size=batch_size)
    scores = []
    for m in masked:
        num_masks = len(re.findall(r"<mask>", m))
        tokens = m.split()
        length = max(len(tokens), 1)
        scores.append(num_masks / length)
    return np.asarray(scores, dtype=float)

def rerank_candidates_decompx(
    sources: List[str],
    candidates: List[List[str]],
    threshold: float = 0.20,
    batch_size_mask: int = 16,
):
    N = len(sources)
    assert len(candidates) == N, "candidates length mismatch"

    if N == 0:
        return np.array([], dtype=int), {}

    C_list = [len(c) for c in candidates]
    assert len(set(C_list)) == 1, "All inputs must have same num_candidates"
    C = C_list[0]
    if C == 0:
        raise ValueError("num_candidates must be >= 1")

    flat_cands = []
    flat_src_idx = []
    for i, cand_list in enumerate(candidates):
        for cand in cand_list:
            flat_cands.append(cand)
            flat_src_idx.append(i)
    flat_src_idx = np.array(flat_src_idx, dtype=int)

    scores = _decompx_toxicity_scores(
        flat_cands,
        threshold=threshold,
        batch_size=batch_size_mask,
    )

    scores2 = scores.reshape(N, C)
    best_idx = np.argmin(scores2, axis=1)

    details = {
        "score": scores2,
    }
    return best_idx, details

In [None]:
def _parse_run_folder_name(folder_name):
    # Placeholder filter: every sub-folder is accepted as a run folder.
    return True

def _eval_with_toxicity(base_path, overwrite_eval=False, skip_ref=False,
                        tox_threshold=0.5, tox_batch_size=32):
    import sys as _sys  # alias used below to launch the evaluation subprocess
    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir) or not _parse_run_folder_name(folder):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path  = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue

        env = os.environ.copy()
        env["PYTHONPATH"] = REPO + (":" + env.get("PYTHONPATH","") if env.get("PYTHONPATH") else "")
        cmd = [
            _sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path",  gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")
        print("eval:", " ".join(cmd))
        res = run(cmd, cwd=REPO, env=env, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()

def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return float('nan')

def _read_stats_file(path):
    out = {}
    with open(path, "r") as f:
        for line in f:
            # Expect "key: value" lines; skip anything else instead of raising.
            k, sep, v = line.strip().partition(": ")
            if not sep:
                continue
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v)
    return out

def _aggregate_eval_csv(output_folder, data_type, base_out_dir):
    rows = []

    root = os.path.join(base_out_dir, data_type)
    if not os.path.isdir(root):
        print("no evaluation directory found:", root)
        return

    for mask_dir in os.listdir(root):
        if not mask_dir.startswith("DecompX_LLM"):
            continue
        thresh_str = mask_dir.replace("DecompX_LLM", "")
        try:
            threshold = float(thresh_str)
        except Exception:
            threshold = np.nan

        base_path = os.path.join(root, mask_dir)
        for folder in os.listdir(base_path):
            gen_dir = os.path.join(base_path, folder)
            stats_path = os.path.join(gen_dir, "gen_stats.txt")
            if not os.path.exists(stats_path):
                continue
            s = _read_stats_file(stats_path)
            rows.append({
                "threshold":        threshold,
                "folder":           folder,
                "bertscore":        s.get("bertscore", np.nan),
                "meaningbert":      s.get("meaningbert", np.nan),
                "bleu4":            s.get("bleu4", np.nan),
                "perplexity_gen":   s.get("perplexity gen", np.nan),
                "perplexity_orig":  s.get("perplexity orig", np.nan),
                "toxicity_gen":     s.get("toxicity gen", np.nan),
                "toxicity_orig":    s.get("toxicity orig", np.nan),
            })

    if rows:
        cols = [
            "threshold", "folder",
            "bertscore", "meaningbert", "bleu4",
            "perplexity_gen", "perplexity_orig",
            "toxicity_gen", "toxicity_orig",
        ]
        df = pd.DataFrame(rows)
        df = df[cols]
        out_csv = os.path.join(base_out_dir, data_type, f"{data_type}.csv")
        df.to_csv(out_csv, index=False)
        print("wrote summary csv:", out_csv)
    else:
        print("no evaluation files found to summarize.")

In [None]:
USE_LLM_GPU = True
DEVICE_LLM = torch.device("cuda" if USE_LLM_GPU and torch.cuda.is_available() else "cpu")
print("llm device:", DEVICE_LLM)

LLM_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
_LLM_TOKENIZER = None
_LLM_MODEL = None

INFILL_SYSTEM_PROMPT = """You are a helpful assistant trained to make toxic or offensive sentences
more polite and respectful by INFILLING <mask> tokens in a masked sentence.

You are given:
- Toxic Sentence: the original sentence, which may contain offensive language.
- Masked Sentence: the same sentence, but all toxic spans are replaced by <mask>.

Your rules:
1. Only replace <mask> tokens in the Masked Sentence with polite, non-toxic alternatives.
2. Do NOT change any other words or punctuation from the Masked Sentence,
   except for small grammar fixes needed after infilling.
3. Preserve the original meaning and intent of the Toxic Sentence as much as possible.
4. Keep the same language as the Toxic Sentence.
5. If the Masked Sentence has no <mask> tokens, return it unchanged.

Output rules (very strict):
- Return ONLY the final detoxified sentence inside ONE pair of square brackets, like:
  [You are such a rude person, nobody wants to hear your opinion.]
- Do NOT print anything before or after the brackets.
- Do NOT add explanations, comments, or extra lines.
- Do NOT include additional '[' or ']' characters inside the sentence.
"""

INFILL_FEW_SHOT = """Toxic Sentence: You're such a stupid idiot, nobody wants to hear your crap.
Masked Sentence: You're such a <mask>, nobody wants to hear your <mask>.
Step 1 - Decide polite replacements for <mask>: "rude person", "opinion"
Step 2 - Insert them into the Masked Sentence, keeping all other tokens:
You're such a rude person, nobody wants to hear your opinion.
Final Output: [You're such a rude person, nobody wants to hear your opinion.]"""

def _lazy_load_llm_infiller():
    global _LLM_MODEL, _LLM_TOKENIZER
    if _LLM_MODEL is not None and _LLM_TOKENIZER is not None:
        return
    print(f"loading llm: {LLM_MODEL_NAME} on {DEVICE_LLM} ...")
    _LLM_TOKENIZER = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    _LLM_MODEL = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        torch_dtype=torch.float16 if DEVICE_LLM.type == "cuda" else torch.float32,
        device_map=None,
    ).to(DEVICE_LLM)
    _LLM_MODEL.eval()
    print("llm loaded")

def _extract_bracket_content(text: str) -> str:
    text = text.strip()
    m = re.search(r"\[([^\]]*)\]", text, flags=re.DOTALL)
    if m:
        return m.group(1).strip()
    if "[" in text:
        return text.split("[", 1)[1].strip()
    return text

def _cleanup_llm_output(s: str) -> str:
    s = s.strip()
    if s.startswith("[") and s.endswith("]") and len(s) > 2:
        s = s[1:-1].strip()
    else:
        if s.startswith("["):
            s = s[1:].strip()
        if s.endswith("]"):
            s = s[:-1].strip()
    s = re.sub(r"\s+", " ", s).strip()
    return s

@torch.no_grad()
def llm_infill_candidates(
    toxic_sentences: List[str],
    masked_sentences: List[str],
    num_candidates: int = 3,
    temperature: float = 0.7,
    top_p: float = 0.95,
    max_new_tokens: int = 64,
    sample: bool = True,
) -> List[List[str]]:
    _lazy_load_llm_infiller()
    assert len(toxic_sentences) == len(masked_sentences), "length mismatch"

    all_candidates: List[List[str]] = []

    for idx in tqdm(range(len(toxic_sentences)), desc="llm infilling", leave=False):
        toxic = toxic_sentences[idx]
        masked = masked_sentences[idx]

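        # Mistral instruct chat templates have historically accepted only
        # alternating user/assistant turns; keeping everything in one user turn
        # avoids silently dropping into the plain-text fallback below.
        messages = [
            {
                "role": "user",
                "content": (
                    INFILL_SYSTEM_PROMPT
                    + "\n\nBelow is an example:\n" + INFILL_FEW_SHOT
                    + f"\n\nToxic Sentence: {toxic}\n"
                    f"Masked Sentence: {masked}\n"
                    "Final Output:"
                ),
            },
        ]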
        try:
            prompt = _LLM_TOKENIZER.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            prompt = (
                INFILL_SYSTEM_PROMPT
                + "\n\nExample:\n"
                + INFILL_FEW_SHOT
                + "\n\nToxic Sentence: "
                + toxic
                + "\nMasked Sentence: "
                + masked
                + "\nFinal Output:"
            )

        inputs = _LLM_TOKENIZER(prompt, return_tensors="pt").to(DEVICE_LLM)
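        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=sample,
            pad_token_id=_LLM_TOKENIZER.eos_token_id,
        )
        if sample:
            # temperature / top_p / multiple return sequences need sampling.
            gen_kwargs.update(temperature=temperature, top_p=top_p,
                              num_return_sequences=num_candidates)
        gen = _LLM_MODEL.generate(**inputs, **gen_kwargs)
        if not sample:
            # Greedy decoding returns one sequence; repeat it so the loop below
            # still finds num_candidates rows.
            gen = gen.repeat(num_candidates, 1)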
        input_len = inputs["input_ids"].shape[1]

        cands_for_this = []
        for k in range(num_candidates):
            gen_text = _LLM_TOKENIZER.decode(
                gen[k][input_len:], skip_special_tokens=True
            )
            detox = _extract_bracket_content(gen_text)
            detox = _cleanup_llm_output(detox)
            if not detox:
                detox = masked
            cands_for_this.append(detox)

        all_candidates.append(cands_for_this)

    return all_candidates

In [None]:
def _bool2str(x: bool) -> str:
    return "T" if x else "F"

def _build_llm_run_folder_name(
    temperature: float,
    sample: bool,
    top_p: float,
    max_new_tokens: int,
    num_candidates: int,
):
    return (
        "llmtemp" + str(temperature) +
        "_sample" + _bool2str(sample) +
        "_topp" + str(top_p) +
        "_maxnew" + str(max_new_tokens) +
        "_ncand" + str(num_candidates)
    )

def _run_decompx_masking_llm_infill_and_decompx_reranking_for_threshold(
    data_type,
    subset_path,
    thresh,
    base_out_rel,
    mask_batch_size,
    llm_sample,
    llm_temperature,
    llm_top_p,
    max_new_tokens,
    num_candidates,
    decompx_threshold,
    overwrite_gen=False,
    inputs=None,
    rerank_batch_size: int = 16,
    echo: bool = False,
):
    if inputs is None:
        args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
        inputs = rx.get_data(args_data)
    print(f"#inputs at thresh={thresh}: {len(inputs)}")

    mask_dir = f"DecompX_LLM{abs(thresh):g}" if thresh != 0 else "DecompX_LLM0.0"
    cur_rel = os.path.join(base_out_rel, data_type, mask_dir)
    cur_abs = os.path.join(REPO, cur_rel)
    _ensure_dir(cur_abs)

    masked_file = os.path.join(cur_abs, "masked_inputs.txt")

    if not os.path.exists(masked_file):
        print(f"running decompx masking (threshold={thresh:.2f}) to create masked_inputs.txt ...")
        decoded_mask_inputs = _decompx_mask_texts(
            inputs, threshold=thresh, batch_size=mask_batch_size
        )
        with open(masked_file, "w") as f:
            for d in decoded_mask_inputs:
                f.write(re.sub(r"\s+", " ", d).strip() + "\n")
    else:
        with open(masked_file, "r") as f:
            decoded_mask_inputs = [s.strip() for s in f.readlines()]
        print("reusing existing masked_inputs.txt")

    assert len(decoded_mask_inputs) == len(inputs), "Masked vs inputs mismatch"

    if echo:
        print("\nexample masked inputs (first up to 3):")
        for i, m in enumerate(decoded_mask_inputs[:3]):
            print(f"  masked[{i}]: {m}")

    run_folder = _build_llm_run_folder_name(
        llm_temperature, llm_sample, llm_top_p, max_new_tokens, num_candidates
    )
    final_abs = os.path.join(cur_abs, run_folder)
    gen_txt = os.path.join(final_abs, "gen.txt")
    orig_txt = os.path.join(final_abs, "orig.txt")

    if os.path.exists(gen_txt) and not overwrite_gen:
        print("gen exists at:", gen_txt, "— skipping")
        _ensure_dir(final_abs)
        with open(gen_txt, "r") as f:
            best_generations = [s.strip() for s in f.readlines()]
        return inputs, decoded_mask_inputs, best_generations, final_abs

    _ensure_dir(final_abs)

    print(f"llm infilling: generating {num_candidates} candidates per input (sampling={llm_sample})")
    all_candidates = llm_infill_candidates(
        toxic_sentences=inputs,
        masked_sentences=decoded_mask_inputs,
        num_candidates=num_candidates,
        temperature=llm_temperature,
        top_p=llm_top_p,
        max_new_tokens=max_new_tokens,
        sample=llm_sample,
    )

    print(f"decompx reranking of llm candidates (threshold={decompx_threshold:.2f}) ...")
    best_idx, details = rerank_candidates_decompx(
        sources=inputs,
        candidates=all_candidates,
        threshold=decompx_threshold,
        batch_size_mask=rerank_batch_size,
    )
    best_generations = [
        all_candidates[i][best_idx[i]] for i in range(len(inputs))
    ]

    if echo:
        print("\nexample detoxified outputs (first up to 3):")
        for i, g in enumerate(best_generations[:3]):
            print(f"  detox[{i}]: {g}")

    with open(orig_txt, "w") as f:
        for l in inputs:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")
    with open(gen_txt, "w") as f:
        for l in best_generations:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")

    print("saved:", orig_txt)
    print("saved:", gen_txt)

    return inputs, decoded_mask_inputs, best_generations, final_abs
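
The selection step above (`best_idx` indexing into `all_candidates`) amounts to an argmin over summed token importances, as described in the scoring section. A minimal sketch, assuming the importance scores are already available as plain lists of floats; `pick_lowest_importance` and the toy numbers are hypothetical and do not reproduce the notebook's `rerank_candidates_decompx`:

```python
from typing import List, Sequence


def pick_lowest_importance(importances: Sequence[Sequence[float]]) -> int:
    """Return the index of the candidate with the smallest summed importance."""
    if not importances:
        raise ValueError("need at least one candidate")
    totals = [sum(tok_scores) for tok_scores in importances]
    # min() returns the first index on ties, so earlier candidates win.
    return min(range(len(totals)), key=totals.__getitem__)


# Toy data: two inputs, three candidates each (all scores are made up).
toy_candidates: List[List[str]] = [
    ["cand a0", "cand a1", "cand a2"],
    ["cand b0", "cand b1", "cand b2"],
]
toy_importances = [
    [[0.4, 0.1], [0.05, 0.02, 0.01], [0.3]],
    [[0.0, 0.2], [0.5], [0.1, 0.05]],
]

best_idx = [pick_lowest_importance(per_input) for per_input in toy_importances]
best = [toy_candidates[i][best_idx[i]] for i in range(len(toy_candidates))]
print(best_idx, best)
```

Summing rather than averaging means longer candidates can accumulate more importance, so this criterion may favor shorter rewrites.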

In [None]:
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_decompx_mask_llm_decompx",
    thresholds=(0.20,),
    echo: bool = False,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    mask_batch_size: int = 10,
    rerank_batch_size: int = 16,
    decompx_rerank_threshold: float = None,  # None: rerank with each masking threshold
    llm_sample: bool = True,
    llm_temperature: float = 0.7,
    llm_top_p: float = 0.95,
    max_new_tokens: int = 64,
    num_candidates: int = 3,
):
    if data_type not in data_configs:
        raise ValueError(f"Unknown data_type: {data_type}")
    cfg = data_configs[data_type].copy()

    if num_candidates < 1:
        raise ValueError("num_candidates must be >= 1")

    base_out_rel = os.path.join("data", "model_outputs", output_folder)
    base_out_abs = os.path.join(REPO, base_out_rel)
    _ensure_dir(base_out_abs)

    original_data_path = cfg["data_path"]
    subset_dir = os.path.join(REPO, "datasets", "_subsets", data_type)
    _ensure_dir(subset_dir)
    subset_path = _subset_for_data_type(
        data_type, original_data_path, num_examples, subset_dir
    )

    args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
    inputs = rx.get_data(args_data)
    num_inputs = len(inputs)

    if echo:
        print("=" * 80)
        print(f"dataset: {data_type}")
        print(f"subset path: {subset_path}")
        print(f"output base: {base_out_abs}")
        print(f"number of examples: {num_inputs}")
        print(f"thresholds: {', '.join(f'{t:.2f}' for t in thresholds)}")
        print(f"llm: temperature={llm_temperature}, top_p={llm_top_p}, "
              f"sample={llm_sample}, max_new_tokens={max_new_tokens}")
        print(f"num_candidates per input: {num_candidates}")
        print("\nexample inputs (first up to 3):")
        for i, s in enumerate(inputs[:3]):
            print(f"  input[{i}]: {s}")
        print("=" * 80)

    last_run_dir = None
    for t in thresholds:
        print("=" * 60)
        print(f"decompx masking threshold = {t:.2f}")
        effective_rerank_t = decompx_rerank_threshold if decompx_rerank_threshold is not None else t

        inputs, masked_inputs, best_generations, run_dir = \
            _run_decompx_masking_llm_infill_and_decompx_reranking_for_threshold(
                data_type=data_type,
                subset_path=subset_path,
                thresh=t,
                base_out_rel=base_out_rel,
                mask_batch_size=mask_batch_size,
                llm_sample=llm_sample,
                llm_temperature=llm_temperature,
                llm_top_p=llm_top_p,
                max_new_tokens=max_new_tokens,
                num_candidates=num_candidates,
                decompx_threshold=effective_rerank_t,
                overwrite_gen=overwrite_gen,
                inputs=inputs,
                rerank_batch_size=rerank_batch_size,
                echo=echo,
            )
        last_run_dir = run_dir

        if run_eval:
            mask_dir = f"DecompX_LLM{abs(t):g}" if t != 0 else "DecompX_LLM0.0"
            base_path = os.path.join(base_out_abs, data_type, mask_dir)
            _eval_with_toxicity(
                base_path,
                overwrite_eval=overwrite_eval,
                skip_ref=skip_ref_eval,
                tox_threshold=0.5,
                tox_batch_size=32,
            )

            if echo:
                run_folder = os.path.basename(run_dir)
                stats_path = os.path.join(base_path, run_folder, "gen_stats.txt")
                if os.path.exists(stats_path):
                    stats = _read_stats_file(stats_path)
                    print(f"\neval metrics for this run (t={t:.2f}):")
                    metric_keys = [
                        ("bertscore",        "bertscore"),
                        ("meaningbert",      "meaningbert"),
                        ("bleu4",            "bleu-4"),
                        ("perplexity gen",   "perplexity (gen)"),
                        ("perplexity orig",  "perplexity (orig)"),
                        ("toxicity gen",     "toxicity (gen)"),
                        ("toxicity orig",    "toxicity (orig)"),
                    ]
                    for key, label in metric_keys:
                        val = stats.get(key)
                        # Skip metrics that are missing or were recorded as NaN.
                        if val is None or (isinstance(val, float) and math.isnan(val)):
                            continue
                        print(f"  {label}: {val:.4f}")
                else:
                    print(f"gen_stats.txt not found at {stats_path}")

    if run_eval:
        # base_out_abs is already REPO/data/model_outputs/<output_folder>.
        _aggregate_eval_csv(output_folder, data_type, base_out_abs)
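
The saved paths printed by the run suggest that each run lands at `<output base>/<data_type>/DecompX_LLM<threshold>/<run folder>/`, with `orig.txt` and `gen.txt` inside. The sketch below rebuilds that path for the logged configuration. `sketch_run_folder` and `sketch_run_dir` are hypothetical stand-ins, not the notebook's `_build_llm_run_folder_name`. The `F` spelling for `sample=False` is an assumption, since only the `T` case appears in the output, and the `/tmp/...` base path is a placeholder.

```python
import os


def sketch_run_folder(temperature, sample, top_p, max_new_tokens, num_candidates):
    # Mirrors the folder name seen in the saved paths; "F" for sample=False is assumed.
    flag = "T" if sample else "F"
    return (
        f"llmtemp{temperature}_sample{flag}_topp{top_p}"
        f"_maxnew{max_new_tokens}_ncand{num_candidates}"
    )


def sketch_run_dir(base_out_abs, data_type, threshold, run_folder):
    # Same mask-directory rule as in detoxify(): "g" formatting, with 0 spelled "0.0".
    mask_dir = f"DecompX_LLM{abs(threshold):g}" if threshold != 0 else "DecompX_LLM0.0"
    return os.path.join(base_out_abs, data_type, mask_dir, run_folder)


run_folder = sketch_run_folder(0.7, True, 0.95, 64, 10)
run_dir = sketch_run_dir("/tmp/model_outputs/demo", "paradetox", 0.20, run_folder)
print(os.path.join(run_dir, "gen.txt"))
```

Because the run folder encodes every generation setting, changing any of them writes to a new directory instead of overwriting an earlier run.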

In [None]:
# example (small subset; adjust as needed):
# detoxify(
#     data_type="paradetox",
#     output_folder="colab_run_decompx_mask_llm_decompx_demo_50_examples",
#     thresholds=(0.20,),
#     echo=True,
#     num_examples=50,
#     overwrite_gen=True,
#     run_eval=True,
#     overwrite_eval=True,
#     skip_ref_eval=False,
#     mask_batch_size=8,
#     rerank_batch_size=16,
#     llm_sample=True,
#     llm_temperature=0.7,
#     llm_top_p=0.95,
#     max_new_tokens=64,
#     num_candidates=10,
# )

In [15]:
detoxify(
    data_type="paradetox",
    output_folder="XDetox_w_DecompX-Masking_LLM-Infilling_DecompX-Reranking_Pipeline",
    thresholds=(0.20,),
    echo=True,
    num_examples=1000,
    overwrite_gen=True,
    run_eval=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    mask_batch_size=8,
    rerank_batch_size=16,
    llm_sample=True,
    llm_temperature=0.7,
    llm_top_p=0.95,
    max_new_tokens=64,
    num_candidates=10,
)


[echo] Dataset: paradetox
[echo] Subset path: /content/drive/MyDrive/w266 - Project/XDetox/datasets/_subsets/paradetox/test_toxic_parallel.txt
[echo] Output base: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_DecompX-Masking_LLM-Infilling_DecompX-Reranking_Pipeline
[echo] Number of examples to detoxify: 671
[echo] Thresholds (DecompX masking): 0.20
[echo] LLM: temperature=0.7, top_p=0.95, sample=True, max_new_tokens=64
[echo] num_candidates per input: 10

[echo] Example inputs (first up to 3):
  input[0]: . or the loud ass one - thousand ton beast roaring towards you howling its horn .
  input[1]: mandated  and " right fucking now " would be good .
  input[2]: * neither * of my coworkers gave a shit when it came time to ditch mitch . ugh .
DecompX masking threshold = 0.20
#inputs at thresh=0.2: 671
Running DecompX masking (threshold=0.20) to create masked_inputs.txt ...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).
Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequence

DecompX masking:   0%|          | 0/84 [00:00<?, ?it/s]




[echo] Example masked inputs (first up to 3):
  masked[0]: . or the loud <mask> one - thousand ton beast roaring towards you howling its horn .
  masked[1]: mandated  and " right <mask> now " would be good .
  masked[2]: * neither * of my coworkers gave a <mask> when it came time to ditch mitch . ugh .
LLM infilling: generating 10 candidates per input (sampling=True)
Loading LLM infiller: mistralai/Mistral-7B-Instruct-v0.2 on cuda ...


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

LLM infiller loaded.


LLM infilling:   0%|          | 0/671 [00:00<?, ?it/s]

DecompX reranking of LLM candidates (threshold=0.20) ...




DecompX masking:   0%|          | 0/420 [00:00<?, ?it/s]




[echo] Example detoxified outputs (first up to 3):
  detox[0]: or the loud person one - thousand ton beast roaring towards you howling its horn.
  detox[1]: mandated and " right now " would be good .
  detox[2]: neither of my coworkers showed interest when it came time to ditch mitch . ugh .
Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_DecompX-Masking_LLM-Infilling_DecompX-Reranking_Pipeline/paradetox/DecompX_LLM0.2/llmtemp0.7_sampleT_topp0.95_maxnew64_ncand10/orig.txt
Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_DecompX-Masking_LLM-Infilling_DecompX-Reranking_Pipeline/paradetox/DecompX_LLM0.2/llmtemp0.7_sampleT_topp0.95_maxnew64_ncand10/gen.txt
Eval: /usr/bin/python3 -m evaluation.evaluate_all --orig_path /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_DecompX-Masking_LLM-Infilling_DecompX-Reranking_Pipeline/paradetox/DecompX_LLM0.2/llmtemp0.7_sampleT_topp0.95_maxnew64_ncand10/orig.txt --ge