# XDetox with LLM Masking, LLM Infilling, and DecompX Reranking

This notebook runs an XDetox pipeline with:

1. **LLM masking** using Mistral-7B-Instruct (`mistralai/Mistral-7B-Instruct-v0.2`), which detects toxic spans and replaces them with `<mask>`.
2. **LLM infilling** (same Mistral model), which fills the `<mask>` tokens with safe alternatives while keeping the rest of the sentence almost unchanged.
3. **DecompX reranking** of multiple LLM candidates per input, using token-level importance scores for toxicity from DecompX.

The goal is to pick, for each toxic input sentence, **one best detoxified candidate** that is:

* As **non-toxic** as possible.
* As **semantically close** as possible to the original.
* As **fluent and natural** as possible.

Main differences from other pipelines:

* Both **masking** and **infilling** are done by an LLM (Mistral-7B-Instruct).
* There is **no MaRCo / BART generation**.
* **DecompX is used only for reranking**, not for masking.

---

## Scoring: DecompX reranking

Let $s_j$ be a candidate detoxified sentence, and let $t_{i,j}$ be its tokens.
DecompX gives an **importance score** $\text{Importance}(t_{i,j})$ for each token, which measures how much that token contributes to the predicted toxicity.

For each candidate $s_j$, DecompX reranking computes a **toxicity score** as the sum of token importances:

$$
\text{Score}(s_j) = \sum_{i=1}^{N_j} \text{Importance}(t_{i,j})
$$

where $N_j$ is the number of tokens in $s_j$.

The final chosen sentence $s^*$ is:

$$
s^* = \arg\min_{s_j} \left( \sum_{i=1}^{N_j} \text{Importance}(t_{i,j}) \right)
$$

In words:

* DecompX is applied to **each candidate sentence**.
* For each candidate, we **aggregate token-level toxicity importance**.
* The candidate with the **lowest total importance** (lowest contribution to toxicity) is selected.

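The notebook exposes a **DecompX threshold**, `decompx_threshold`, that controls which tokens DecompX considers toxic enough to mask. A lower threshold is more sensitive and flags more tokens, while a higher threshold is stricter. In the code cells of this notebook, the score is computed as a thresholded proxy for the sum above: DecompX masks each candidate at `decompx_threshold`, and the candidate's score is the number of `<mask>` tokens divided by the number of whitespace-separated tokens in the masked text.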
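
As a minimal sketch of the selection rule above, the snippet below sums per-token importance scores for three candidates of one input and takes the arg min. The helper name `select_least_toxic` and all numbers are made up for illustration; real scores would come from DecompX.

```python
from typing import List, Sequence


def select_least_toxic(token_importances: Sequence[Sequence[float]]) -> int:
    """Return the index of the candidate with the lowest summed importance."""
    totals: List[float] = [sum(scores) for scores in token_importances]
    # min over indices implements arg min; ties resolve to the first candidate.
    return min(range(len(totals)), key=totals.__getitem__)


# Hypothetical token-level importances for three candidates of one input.
toy_importances = [
    [0.05, 0.40, 0.10],        # Score is about 0.55
    [0.02, 0.03, 0.01, 0.04],  # Score is about 0.10
    [0.30, 0.05],              # Score is about 0.35
]
best = select_least_toxic(toy_importances)
print(best)  # index of the selected candidate s*
```

Because the score is a plain sum, longer candidates can accumulate more importance than shorter ones; the length-normalized proxy used in the code cells avoids that effect.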

---

## LLM masking (Mistral-7B-Instruct)

### Prompted masking behavior

Masking is done by a chat-style LLM:

* Model: `mistralai/Mistral-7B-Instruct-v0.2`.
* The LLM is instructed to:

  * **Identify toxic, offensive, or profane words or short phrases**.
  * Replace **each toxic span** with a **single `<mask>` token**.
  * Allow **multiple `<mask>` tokens** in one sentence (one per toxic span).
  * If multiple neighboring words are toxic, **collapse them into one `<mask>`**
    (no `<mask> <mask> ...` runs).
  * Keep **all non-toxic words and punctuation unchanged**.
  * **Not** paraphrase, summarize, or reorder the sentence.
  * Return the masked sentence **inside exactly one pair of brackets**:

    ```text
    [This is a <mask> example.]
    ```

### Post-processing of LLM masks

Because LLM output can be noisy, the notebook cleans the raw masked output:

1. **Extract the bracket content**:

   * Read the first `[ ... ]` block if present.
   * If there is `[` but no `]`, take everything after the first `[` as the sentence.
   * If there are no brackets, fall back to the full string.

2. **Strip outer brackets** that may remain.

3. **Normalize whitespace** (collapse repeated spaces).

4. **Normalize `<mask>` tokens**:

   * Variants like `<Mask>`, `<MASK>`, `< mask >` are normalized to `<mask>`.

5. **Collapse runs of `<mask>`**:

   * Any run such as `<mask> <mask> <mask>` becomes a single `<mask>`.

6. If cleaning yields an empty string, fall back to the original masked text.
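
The six steps above can be sketched as one small function. This is an illustrative re-implementation under the assumptions stated in the comments, not the notebook's own helper; the name `clean_llm_mask` is hypothetical.

```python
import re


def clean_llm_mask(raw: str) -> str:
    """Clean raw LLM masking output following the steps listed above."""
    text = raw.strip()
    # Step 1: take the first [ ... ] block, or everything after a lone "[".
    match = re.search(r"\[([^\]]*)\]", text, flags=re.DOTALL)
    if match:
        text = match.group(1)
    elif "[" in text:
        text = text.split("[", 1)[1]
    # Step 2: strip any outer brackets that remain.
    text = text.strip().strip("[]").strip()
    # Step 3: collapse repeated whitespace.
    text = re.sub(r"\s+", " ", text).strip()
    # Step 4: normalize <Mask>, <MASK>, < mask > to <mask>.
    text = re.sub(r"<\s*mask\s*>", "<mask>", text, flags=re.IGNORECASE)
    # Step 5: collapse runs of adjacent <mask> tokens into one.
    text = re.sub(r"<mask>(?:\s*<mask>)+", "<mask>", text)
    # Step 6: fall back to the raw text if nothing is left.
    return text or raw.strip()


print(clean_llm_mask("Final Output: [You are a <MASK> < mask >  person.]"))
# prints: You are a <mask> person.
```

Normalizing the mask spelling before collapsing runs matters: otherwise `<MASK> < mask >` would not be recognized as two adjacent masks.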

All cleaned, LLM-masked sentences are written to:

* `data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_DecompX/masked_inputs.txt`

and reused across later runs with the same `output_folder` and `data_type`.

---

## LLM infilling (Mistral-7B-Instruct)

After masking, detoxification is also done by Mistral-7B-Instruct via **infilling**.

For each example, the LLM sees two inputs:

1. **Toxic sentence**: the original toxic sentence.
2. **Masked sentence**: the same sentence, where toxic spans have been replaced with `<mask>`.

The infilling prompt instructs the LLM to:

* Treat the **masked sentence** as a **template**.
* For each `<mask>` token:

  * Insert a **short, non-toxic word or phrase** that fits the context.
  * Preserve the meaning and intent of the **toxic sentence** as much as possible.
* Keep **all non-masked text unchanged**, except for small edits needed for grammar or agreement.
* Keep the **language the same** (no translation).
* Output only the final detoxified sentence in **one pair of brackets**:

  ```text
  [Detoxified sentence here.]
  ```

For each `(source, masked)` pair, the notebook:

1. Builds a prompt that includes both the **toxic sentence** and the **masked sentence**, plus a few-shot example.
2. Calls `generate(...)` with:

   * `num_return_sequences = num_candidates`,
   * `do_sample = llm_sample`,
   * `temperature = llm_temperature` (if sampling),
   * `top_p = llm_top_p`,
   * `max_new_tokens = llm_max_new_tokens`.
3. Decodes only the **new tokens** after the prompt.
4. Extracts the text inside `[ ... ]` and cleans it:

   * remove outer brackets,
   * normalize whitespace,
   * remove any leftover `<mask>` tokens.

If cleaning produces an empty string, the pipeline falls back to the original toxic input for that candidate.

This produces a list of candidates:

* `candidates[i]` is a list of length `num_candidates` for input $i$.
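
To make the shape of `candidates` and the fallback rule concrete, here is a small sketch that cleans already-decoded generations without calling the LLM. The helpers `clean_infill` and `build_candidates` and the sample strings are hypothetical stand-ins for illustration.

```python
import re
from typing import List


def clean_infill(raw: str, source: str) -> str:
    """Extract the bracketed sentence, drop leftover <mask> tokens, fall back to source."""
    match = re.search(r"\[([^\]]*)\]", raw, flags=re.DOTALL)
    text = match.group(1) if match else raw
    text = text.replace("<mask>", " ")
    text = re.sub(r"\s+", " ", text).strip().strip("[]").strip()
    return text or source


def build_candidates(sources: List[str], raw_outputs: List[List[str]]) -> List[List[str]]:
    """candidates[i][k] is the k-th cleaned candidate for input i."""
    return [
        [clean_infill(raw, src) for raw in raws]
        for src, raws in zip(sources, raw_outputs)
    ]


# Hypothetical decoded generations for one input with num_candidates = 3.
sources = ["this is a dumb idea"]
raw_outputs = [["[This is a poor idea.]", "[This is a <mask> idea.]", "   "]]
candidates = build_candidates(sources, raw_outputs)
print(candidates)
```

Note that simply deleting a leftover `<mask>` can leave an ungrammatical sentence (the second candidate becomes "This is a idea."), which is one reason several candidates are generated and then reranked.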

---

## DecompX reranking of LLM candidates

Once we have LLM-infilling candidates, we apply **DecompX reranking**:

1. **Flatten** all candidates into one list while tracking which input they belong to.

2. For each candidate, run DecompX to obtain token-level importance scores for toxicity.

3. For each candidate $s_j$, compute:

   $$
   \text{Score}(s_j) = \sum_{i=1}^{N_j} \text{Importance}(t_{i,j})
   $$

   where $N_j$ is the number of tokens in the candidate.

4. For each input sentence, collect the scores of all its candidates and choose the candidate with the **lowest DecompX score**.
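
The four steps can be sketched in plain Python as follows. The scorer here is a toy stand-in that counts a made-up "toxic" word; in the notebook the scores come from DecompX, and the names `rerank` and `toy_score` are assumptions for this example only.

```python
from typing import Callable, Dict, List, Tuple


def rerank(
    candidates: List[List[str]],
    score_fn: Callable[[List[str]], List[float]],
) -> List[str]:
    """Pick, for each input, the candidate with the lowest score."""
    # Step 1: flatten while remembering which input each candidate belongs to.
    flat = [cand for cands in candidates for cand in cands]
    owner = [i for i, cands in enumerate(candidates) for _ in cands]
    # Steps 2-3: score every candidate in one pass (batched in practice).
    scores = score_fn(flat)
    # Step 4: keep the lowest-scoring candidate per input (first one wins ties).
    best: Dict[int, Tuple[float, str]] = {}
    for cand, i, score in zip(flat, owner, scores):
        if i not in best or score < best[i][0]:
            best[i] = (score, cand)
    return [best[i][1] for i in range(len(candidates))]


def toy_score(texts: List[str]) -> List[float]:
    """Stand-in scorer: counts a made-up 'toxic' word instead of running DecompX."""
    return [float(t.lower().split().count("awful")) for t in texts]


chosen = rerank(
    [["an awful take", "a weak take"], ["fine", "awful awful"]],
    toy_score,
)
print(chosen)
```

Flattening first lets the scorer process all candidates in large batches, which is usually cheaper than scoring one input's candidates at a time.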

The `decompx_threshold` parameter:

* Controls the **sensitivity** of DecompX to toxicity.
* A lower value tends to assign non-zero importance to more tokens (more sensitive).
* A higher value highlights only the more strongly toxic tokens.

For each run folder, the notebook writes:

* `orig.txt` — original toxic inputs (one per line),
* `gen.txt` — chosen detoxified outputs (one per line).

Outputs are stored under:

```text
data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_DecompX/{run_folder}/
```

where `{run_folder}` encodes the LLM infilling hyperparameters and DecompX settings.

---

## Evaluation

If `run_eval=True`, the pipeline calls `evaluation.evaluate_all` to compute:

* BERTScore (F1)
* MeaningBERT
* BLEU-4
* Toxicity (orig / gen) using XLM-R
* Perplexity (orig / gen) using GPT-2

For each run folder, it writes:

* `gen_stats.txt` under:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_DecompX/{run_folder}/
  ```

It also creates a **summary CSV per dataset**:

* `data/model_outputs/{output_folder}/{data_type}/{data_type}.csv`

The CSV aggregates metrics over all run folders inside `LLM_Mask_LLM_DecompX/`.

For compatibility with other XDetox pipelines, the CSV keeps a `threshold` column, which is used here as a **label** (for example `0.20`) tied to the DecompX reranking configuration, not to a masking threshold. The aggregation code in this notebook writes a fixed `0.20` in this column.
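
As a rough sketch of how the summary CSV might be consumed, the snippet below ranks run folders by `toxicity_gen` using only the standard library. The column names follow the aggregation code in this notebook; the folder names and all metric values are placeholders, not real results.

```python
import csv
import io

# Made-up rows in the summary CSV layout; folder names and values are placeholders.
sample_csv = io.StringIO(
    "threshold,folder,bertscore,meaningbert,bleu4,"
    "perplexity_gen,perplexity_orig,toxicity_gen,toxicity_orig\n"
    "0.2,run_a,0.90,70.0,0.60,120.0,200.0,0.30,0.90\n"
    "0.2,run_b,0.88,68.0,0.55,110.0,200.0,0.10,0.90\n"
)

rows = list(csv.DictReader(sample_csv))
# Rank run folders from least to most toxic output.
ranked = sorted(rows, key=lambda r: float(r["toxicity_gen"]))
for r in ranked:
    print(r["folder"], r["toxicity_gen"], r["bertscore"])
```

To use it on a real run, replace the `io.StringIO` object with `open(path, newline="")` on the `{data_type}.csv` file.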

---

## How to use `detoxify()`

Function signature (conceptual):

```python
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_llm_mask_infill_decompx",
    echo: bool = False,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    # LLM infilling:
    num_candidates: int = 3,
    llm_temperature: float = 0.7,
    llm_top_p: float = 0.95,
    llm_max_new_tokens: int = 64,
    llm_sample: bool = True,
    # DecompX reranking:
    decompx_threshold: float = 0.20,
    decompx_batch_size: int = 16,
)
```

### Key arguments

#### Core I/O

* `data_type`: dataset key from `data_configs`, for example:

  * `"paradetox"`, `"dynabench_val"`, `"dynabench_test"`,
  * `"jigsaw_toxic"`, `"microagressions_val"`, `"sbf_val"`,
  * `"appdia_original"`, `"appdia_discourse"`, etc.

* `output_folder`:

  * Top-level directory under `data/model_outputs/`:

    ```text
    data/model_outputs/{output_folder}/{data_type}/...
    ```

* `num_examples`:

  * If set, only the first `num_examples` examples are processed.
  * Use `None` to run on the full dataset.

* `overwrite_gen`:

  * If `False` and a matching `gen.txt` already exists, the notebook reuses previous generations.
  * If `True`, it regenerates outputs for that run folder.

* `echo`:

  * If `True`, the notebook prints:

    * dataset and subset path,
    * output base directory,
    * a few example inputs,
    * a few LLM-masked outputs,
    * a few final detoxified outputs,
    * and, if `run_eval=True`, evaluation metrics.

#### LLM masking (Mistral)

* Masking uses **Mistral-7B-Instruct** with a fixed system prompt and few-shot example.
* There is **no masking threshold**; the LLM decides which spans to mask based on the instructions.
* Masked sentences are cached in:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_DecompX/masked_inputs.txt
  ```

and reused if present.
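
The caching behavior can be pictured with the sketch below. It is an assumption about how such a cache could work, not the notebook's actual logic: the helper `load_or_compute_masks`, the line-count validity check, and the placeholder masker `fake_mask` are all hypothetical.

```python
import tempfile
from pathlib import Path
from typing import Callable, List


def load_or_compute_masks(
    cache_path: Path,
    sentences: List[str],
    mask_fn: Callable[[List[str]], List[str]],
) -> List[str]:
    """Reuse cached masks when the file matches the inputs; otherwise compute and save."""
    if cache_path.exists():
        cached = cache_path.read_text(encoding="utf-8").splitlines()
        # Assumed validity check: one cached line per input sentence.
        if len(cached) == len(sentences):
            return cached
    masked = mask_fn(sentences)
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    cache_path.write_text("\n".join(masked) + "\n", encoding="utf-8")
    return masked


calls = []


def fake_mask(sents: List[str]) -> List[str]:
    """Placeholder masker that records calls; the notebook would call the LLM here."""
    calls.append(len(sents))
    return [s.replace("dumb", "<mask>") for s in sents]


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "LLM_Mask_LLM_DecompX" / "masked_inputs.txt"
    first = load_or_compute_masks(path, ["a dumb idea"], fake_mask)
    second = load_or_compute_masks(path, ["a dumb idea"], fake_mask)
print(first, second, calls)
```

The second call returns the cached masks without invoking the masker, which is the point of the cache: LLM masking is the slow step, and its output does not depend on the infilling or reranking hyperparameters.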

#### LLM infilling (Mistral)

* `num_candidates`:

  * Number of infilling candidates to generate per input.

* `llm_temperature`:

  * Sampling temperature (only used if `llm_sample=True`).

* `llm_top_p`:

  * Nucleus sampling parameter during infilling.

* `llm_max_new_tokens`:

  * Maximum number of new tokens to generate beyond the prompt.

* `llm_sample`:

  * `True`: stochastic sampling with `temperature` and `top_p`.
  * `False`: deterministic decoding (similar to greedy).

The infilling step uses **both** the toxic sentence and the masked sentence, and is strictly instructed to **only change `<mask>` spans** plus minor grammar fixes.

#### DecompX reranking

* `decompx_threshold`:

  * DecompX threshold used when computing token-level toxicity importance.
  * Roughly controls how aggressively DecompX marks tokens as toxic.

* `decompx_batch_size`:

  * Batch size for DecompX processing during reranking.

For each input:

* The pipeline uses DecompX to score each LLM candidate.
* It picks the candidate whose sum of token importance scores is **minimal**.

#### Evaluation

* `run_eval`:

  * If `True`, run `evaluation.evaluate_all` and write `gen_stats.txt`.

* `overwrite_eval`:

  * If `False` and `gen_stats.txt` exists, keep existing metrics.
  * If `True`, recompute metrics.

* `skip_ref_eval`:

  * If `True`, skip some reference-based parts (for example, reference perplexity).

---

## Example calls

### Quick sanity check on a small subset

```python
detoxify(
    data_type="paradetox",
    output_folder="colab_run_llm_mask_infill_decompx_demo_50_examples",
    echo=True,
    num_examples=50,               # small subset for testing
    run_eval=True,                 # BLEU / BERTScore / MeaningBERT / PPL / Toxicity
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    num_candidates=10,             # LLM infilling candidates per input
    llm_temperature=0.7,
    llm_top_p=0.95,
    llm_max_new_tokens=64,
    llm_sample=True,
    decompx_threshold=0.20,
    decompx_batch_size=16,
)
```

### Larger run (more candidates, full dataset)

```python
detoxify(
    data_type="paradetox",
    output_folder="paradetox_llm_mask_infill_decompx_full",
    echo=True,
    num_examples=None,             # full dataset
    run_eval=True,
    overwrite_gen=False,
    overwrite_eval=False,
    skip_ref_eval=False,
    num_candidates=20,
    llm_temperature=0.7,
    llm_top_p=0.95,
    llm_max_new_tokens=64,
    llm_sample=True,
    decompx_threshold=0.20,
    decompx_batch_size=16,
)
```

After running `detoxify`, you can inspect:

* Final inputs and outputs:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_DecompX/{run_folder}/orig.txt
  data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_DecompX/{run_folder}/gen.txt
  ```

* Per-run evaluation metrics:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_DecompX/{run_folder}/gen_stats.txt
  ```

* Aggregated metrics across runs:

  ```text
  data/model_outputs/{output_folder}/{data_type}/{data_type}.csv
  ```
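
For a quick side-by-side look at inputs and outputs, something like the following sketch may help. It relies only on the documented layout (one sentence per line in `orig.txt` and `gen.txt`); the helper `read_pairs` and the demo sentences are hypothetical.

```python
import tempfile
from pathlib import Path
from typing import List, Tuple


def read_pairs(run_dir: Path) -> List[Tuple[str, str]]:
    """Pair each line of orig.txt with the same line of gen.txt."""
    orig = (run_dir / "orig.txt").read_text(encoding="utf-8").splitlines()
    gen = (run_dir / "gen.txt").read_text(encoding="utf-8").splitlines()
    if len(orig) != len(gen):
        raise ValueError(f"line count mismatch: {len(orig)} vs {len(gen)}")
    return list(zip(orig, gen))


# Demo on a throwaway directory; point run_dir at a real {run_folder} instead.
with tempfile.TemporaryDirectory() as tmp:
    run_dir = Path(tmp)
    (run_dir / "orig.txt").write_text("a dumb idea\n", encoding="utf-8")
    (run_dir / "gen.txt").write_text("a poor idea\n", encoding="utf-8")
    pairs = read_pairs(run_dir)

for src, out in pairs:
    print(f"{src}  ->  {out}")
```

The explicit line-count check guards against silently misaligned pairs, for example when a generation run was interrupted.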

This pipeline lets you compare:

* **LLM masking + LLM infilling + DecompX reranking**

against other XDetox pipelines such as:

* **LLM masking + LLM infilling + global reranking**,
* **DecompX masking + LLM infilling + DecompX or global reranking**, and
* **LLM masking or DecompX masking with MaRCo generation**.

In [None]:
from google.colab import drive; drive.mount('/content/drive')

import os, glob, re, sys, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE
import torch
import nltk
from typing import List

# try my drive
candidate = "/content/drive/MyDrive/w266 - Project/XDetox"
print("try mydrive:", candidate, "->", os.path.isdir(candidate))

XDETOX_DIR = candidate
print("using xdetox_dir:", XDETOX_DIR)
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"

In [None]:
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.makedirs(HF_CACHE, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE

if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("xdetox_dir:", XDETOX_DIR)
print("transformers_cache:", HF_CACHE)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("gpu:", torch.cuda.get_device_name(0))

In [None]:
for d in ["rewrite", "evaluation", "datasets"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("repo folders ok")

In [None]:
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi \
                sentencepiece
!pip -q install bert-score


In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from rewrite.mask_orig import Masker as Masker_single
from rewrite import rewrite_example as rx
import argparse as _argparse

In [None]:
nltk.download("punkt", quiet=True)
try:
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass
print("ok")

In [None]:
data_configs = {
    "microagressions_val": {
        "data_path": "./datasets/microagressions/val.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "microagressions_test": {
        "data_path": "./datasets/microagressions/test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "sbf_val": {
        "data_path": "./datasets/sbf/sbfdev.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "sbf_test": {
        "data_path": "./datasets/sbf/sbftst.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "dynabench_val": {
        "data_path": "./datasets/dynabench/db_dev.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "dynabench_test": {
        "data_path": "./datasets/dynabench/db_test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "jigsaw_toxic": {
        "data_path": "./datasets/jigsaw_full_30/test_10k_toxic.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "paradetox": {
        "data_path": "./datasets/paradetox/test_toxic_parallel.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_original": {
        "data_path": "./datasets/appdia/original-annotated-data/original-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_discourse": {
        "data_path": "./datasets/appdia/discourse-augmented-data/discourse-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    }
}
print("datasets:", ", ".join(data_configs.keys()))

REPO = XDETOX_DIR

In [None]:
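def _abs_repo_path(rel: str) -> str:
    # Drop only a literal leading "./"; str.lstrip("./") would strip any run of
    # "." and "/" characters (for example from "../x" or ".hidden").
    if rel.startswith("./"):
        rel = rel[2:]
    return os.path.join(REPO, rel)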

def _ensure_dir(p: str):
    Path(p).mkdir(parents=True, exist_ok=True)

def _subset_for_data_type(data_type, data_path, n, out_dir):
    if n is None or n <= 0:
        return data_path

    src = _abs_repo_path(data_path)
    _ensure_dir(out_dir)

    if "microagressions" in data_path:
        df = pd.read_csv(src)
        sub = df.head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        sub.to_csv(out, index=False)
        return out

    if "sbf" in data_path:
        df = pd.read_csv(src)
        sub = df.head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        sub.to_csv(out, index=False)
        return out

    if "dynabench" in data_path:
        df = pd.read_csv(src)
        sub = df.head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        sub.to_csv(out, index=False)
        return out

    if any(k in data_path for k in ["paradetox", "jigsaw"]):
        if data_path.endswith(".txt"):
            with open(src, "r") as f:
                lines = [s.rstrip("\n") for s in f.readlines()]
            out = os.path.join(out_dir, os.path.basename(src))
            with open(out, "w") as g:
                for s in lines[:n]:
                    g.write(s + "\n")
            return out
        elif data_path.endswith(".csv"):
            df = pd.read_csv(src).head(n)
            out = os.path.join(out_dir, os.path.basename(src))
            df.to_csv(out, index=False)
            return out

    if "appdia" in data_path:
        df = pd.read_csv(src, sep="\t").head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        df.to_csv(out, sep="\t", index=False)
        return out

    out = os.path.join(out_dir, os.path.basename(src))
    shutil.copy(src, out)
    return out

In [None]:
def _decompx_mask_texts(texts: List[str],
                        threshold: float = 0.20,
                        batch_size: int = 16) -> List[str]:
    if not texts:
        return []

    masker = Masker_single()
    masked_all = []
    for i in tqdm(range(0, len(texts), batch_size),
                  desc="decompx masking for reranking", leave=False):
        batch = texts[i:i + batch_size]
        batch_out = masker.process_text(sentence=batch, threshold=threshold)
        masked_all.extend(batch_out)
    cleaned = [
        m.replace("<s>", "").replace("</s>", "").strip()
        for m in masked_all
    ]
    masker.release_model()
    return cleaned

def _decompx_toxicity_scores(texts: List[str],
                             threshold: float = 0.20,
                             batch_size: int = 16) -> np.ndarray:
    if not texts:
        return np.zeros((0,), dtype=float)

    masked = _decompx_mask_texts(texts, threshold=threshold, batch_size=batch_size)
    scores = []
    for m in masked:
        num_masks = len(re.findall(r"<mask>", m))
        tokens = m.split()
        length = max(len(tokens), 1)
        scores.append(num_masks / length)
    return np.asarray(scores, dtype=float)

def rerank_candidates_decompx(
    sources: List[str],
    candidates: List[List[str]],
    threshold: float = 0.20,
    batch_size_mask: int = 16,
):
    N = len(sources)
    assert len(candidates) == N, "candidates length mismatch"

    if N == 0:
        return np.array([], dtype=int), {}

    C_list = [len(c) for c in candidates]
    assert len(set(C_list)) == 1, "All inputs must have same num_candidates"
    C = C_list[0]
    if C == 0:
        raise ValueError("num_candidates must be >= 1")

    flat_cands = []
    flat_src_idx = []
    for i, cand_list in enumerate(candidates):
        for cand in cand_list:
            flat_cands.append(cand)
            flat_src_idx.append(i)
    flat_src_idx = np.array(flat_src_idx, dtype=int)

    scores = _decompx_toxicity_scores(
        flat_cands,
        threshold=threshold,
        batch_size=batch_size_mask,
    )

    scores2 = scores.reshape(N, C)
    best_idx = np.argmin(scores2, axis=1)

    details = {
        "score": scores2,
    }
    return best_idx, details

In [None]:
def _eval_with_toxicity(base_path, overwrite_eval=False, skip_ref=False,
                        tox_threshold=0.5, tox_batch_size=32):
    import sys as _sys
    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path  = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue

        env = os.environ.copy()
        env["PYTHONPATH"] = REPO + (":" + env.get("PYTHONPATH","") if env.get("PYTHONPATH") else "")
        cmd = [
            _sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path",  gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")
        print("eval:", " ".join(cmd))
        res = run(cmd, cwd=REPO, env=env, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()

def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return float('nan')

def _read_stats_file(path):
    out = {}
    with open(path, "r") as f:
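        for line in f:
            line = line.strip()
            # Require the exact "key: value" separator so the split below
            # always yields two parts.
            if ": " not in line:
                continue
            k, v = line.split(": ", 1)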
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v)
    return out

def _aggregate_eval_csv(output_folder, data_type, base_out_dir):
    rows = []

    mask_dir = "LLM_Mask_LLM_DecompX"
    base_path = os.path.join(base_out_dir, data_type, mask_dir)
    if not os.path.isdir(base_path):
        print("no evaluation directory found:", base_path)
        return

    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        stats_path = os.path.join(gen_dir, "gen_stats.txt")
        if not os.path.exists(stats_path):
            continue
        s = _read_stats_file(stats_path)
        rows.append({
            "threshold":        0.20,
            "folder":           folder,
            "bertscore":        s.get("bertscore", np.nan),
            "meaningbert":      s.get("meaningbert", np.nan),
            "bleu4":            s.get("bleu4", np.nan),
            "perplexity_gen":   s.get("perplexity gen", np.nan),
            "perplexity_orig":  s.get("perplexity orig", np.nan),
            "toxicity_gen":     s.get("toxicity gen", np.nan),
            "toxicity_orig":    s.get("toxicity orig", np.nan),
        })

    if rows:
        cols = [
            "threshold", "folder",
            "bertscore", "meaningbert", "bleu4",
            "perplexity_gen", "perplexity_orig",
            "toxicity_gen", "toxicity_orig",
        ]
        df = pd.DataFrame(rows)
        df = df[cols]
        out_csv = os.path.join(base_out_dir, data_type, f"{data_type}.csv")
        df.to_csv(out_csv, index=False)
        print("wrote summary csv:", out_csv)
    else:
        print("no evaluation files found")

In [None]:
USE_LLM_GPU = True
DEVICE_LLM = torch.device("cuda" if USE_LLM_GPU and torch.cuda.is_available() else "cpu")
print("llm device:", DEVICE_LLM)

LLM_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
_LLM_TOKENIZER = None
_LLM_MODEL = None

def _lazy_load_llm():
    global _LLM_MODEL, _LLM_TOKENIZER
    if _LLM_MODEL is not None and _LLM_TOKENIZER is not None:
        return
    print(f"loading llm: {LLM_MODEL_NAME} on {DEVICE_LLM}")
    _LLM_TOKENIZER = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    _LLM_MODEL = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        torch_dtype=torch.float16 if DEVICE_LLM.type == "cuda" else torch.float32,
        device_map=None,
    ).to(DEVICE_LLM)
    _LLM_MODEL.eval()
    print("llm loaded")


def _extract_bracket_content(text: str) -> str:
    text = text.strip()

    m = re.search(r"\[([^\]]*)\]", text, flags=re.DOTALL)
    if m:
        return m.group(1).strip()

    if "[" in text:
        return text.split("[", 1)[1].strip()

    return text

In [None]:
MASK_SYSTEM_PROMPT = """You are a helpful assistant trained to make toxic or offensive sentences safer
by MASKING toxic words with the special token <mask>, while keeping the original sentence
structure and wording as much as possible.

You must behave like a MASKER, not a full rewriter.

Your task:
1. Identify toxic, offensive, or profane words or short phrases.
2. For each toxic span, replace the entire span with a single <mask> token.
3. There may be multiple toxic spans in one sentence, so multiple <mask> tokens are allowed.
4. If several neighboring words are toxic, you must still use only a single <mask> token
   in that place. In other words, if you would place "<mask> <mask>" or a longer sequence
   of <mask> tokens, collapse them into a single <mask> so that there are never multiple
   <mask> tokens in a row.
5. Do NOT rewrite, paraphrase, or summarize the sentence.
6. Do NOT add, remove, or reorder non-toxic words or punctuation.
7. Keep punctuation and spacing as close to the original as possible.
8. If there is no toxic content, return the sentence unchanged.

Output rules (format is very strict):
- ONLY return the final masked sentence inside ONE pair of square brackets, like:
  [This is a <mask> example.]
- Do NOT print anything before or after the brackets.
- Do NOT add explanations, comments, or extra lines.
- Do NOT include any language tags or metadata.
- Do NOT include additional '[' or ']' characters inside the sentence.
"""

MASK_FEW_SHOT = """Toxic Sentence: You're such a stupid idiot, nobody wants to hear your crap.
Step 1 - Identify toxic words: "stupid idiot", "crap"
Step 2 - Mask toxic words (do NOT rewrite the rest):
You're such a <mask>, nobody wants to hear your <mask>.
Final Output: [You're such a <mask>, nobody wants to hear your <mask>.]"""

def _postprocess_llm_mask(masked_text: str) -> str:
    s = masked_text.strip()

    if s.startswith("[") and s.endswith("]") and len(s) > 2:
        s = s[1:-1].strip()
    else:
        if s.startswith("["):
            s = s[1:].strip()
        if s.endswith("]"):
            s = s[:-1].strip()

    s = re.sub(r"\s+", " ", s).strip()
    s = re.sub(r"<\s*mask\s*>", "<mask>", s, flags=re.IGNORECASE)
    s = re.sub(r"(?:\s*<mask>\s*){2,}", " <mask> ", s)
    s = re.sub(r"\s+", " ", s).strip()

    if not s:
        return masked_text.strip()

    return s

@torch.no_grad()
def llm_mask_sentences(sentences: List[str]) -> List[str]:
    _lazy_load_llm()
    masked = []
    for s in tqdm(sentences, desc="llm masking", leave=False):
        messages = [
            {
                "role": "system",
                "content": MASK_SYSTEM_PROMPT + "\n\nBelow is an example:\n" + MASK_FEW_SHOT,
            },
            {
                "role": "user",
                "content": f"Toxic Sentence: {s}\nFinal Output:",
            },
        ]
        try:
            prompt = _LLM_TOKENIZER.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            prompt = (
                MASK_SYSTEM_PROMPT
                + "\n\nExample:\n"
                + MASK_FEW_SHOT
                + "\n\nToxic Sentence: "
                + s
                + "\nFinal Output:"
            )

        inputs = _LLM_TOKENIZER(prompt, return_tensors="pt").to(DEVICE_LLM)
        gen = _LLM_MODEL.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,  # greedy decoding; temperature does not apply here
            pad_token_id=_LLM_TOKENIZER.eos_token_id,
        )
        gen_text = _LLM_TOKENIZER.decode(
            gen[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        masked_text = _extract_bracket_content(gen_text)
        masked_text = _postprocess_llm_mask(masked_text)

        if not masked_text:
            masked_text = s
        masked.append(masked_text)

    return masked

In [None]:
INFILL_SYSTEM_PROMPT = """You are a helpful assistant trained to make toxic or offensive sentences
more polite and respectful by INFILLING the special token <mask>.

You are NOT a free rewriter. You must keep all non-masked text as close as possible
to the given masked sentence.

You are given two inputs:
1) Toxic Sentence: the original toxic sentence.
2) Masked Sentence: the same sentence, where toxic spans are replaced with <mask>.

Your task:
1. For each <mask> token in the Masked Sentence, replace it with a short, non-toxic
   word or phrase that fits the context and preserves the meaning of the Toxic Sentence.
2. Do NOT modify any other words or punctuation outside the <mask> spans, unless a very
   small change is needed to fix grammar or agreement.
3. Preserve the original meaning and intent as much as possible, but make the sentence
   safe and respectful.
4. Keep the language the same as the original (do NOT translate).

Output rules (VERY STRICT):
- ONLY return the final detoxified sentence with all <mask> tokens filled.
- Wrap the final sentence in exactly ONE pair of square brackets, e.g.:
  [Detoxified sentence here.]
- Do NOT include the Toxic Sentence or Masked Sentence in your output.
- Do NOT add explanations, comments, or extra lines.
- Do NOT include any other '[' or ']' characters.
"""

INFILL_FEW_SHOT = """Toxic Sentence: You're such a stupid idiot, nobody wants to hear your crap.
Masked Sentence: You're such a <mask>, nobody wants to hear your <mask>.
Step 1 - Decide safe replacements for each <mask>: "rude person", "opinion"
Step 2 - Infill the masked sentence, keeping all other words the same:
You're such a rude person, nobody wants to hear your opinion.
Final Output: [You're such a rude person, nobody wants to hear your opinion.]"""

def _postprocess_llm_infill(text: str) -> str:
    s = text.strip()

    if s.startswith("[") and s.endswith("]") and len(s) > 2:
        s = s[1:-1].strip()
    else:
        if s.startswith("["):
            s = s[1:].strip()
        if s.endswith("]"):
            s = s[:-1].strip()

    s = re.sub(r"\s+", " ", s).strip()

    s = s.replace("<mask>", " ")
    s = re.sub(r"\s+", " ", s).strip()

    if not s:
        return text.strip()
    return s

@torch.no_grad()
def llm_infill_candidates(
    sources: List[str],
    masked: List[str],
    num_candidates: int = 3,
    temperature: float = 0.7,
    top_p: float = 0.95,
    max_new_tokens: int = 64,
    sample: bool = True,
) -> List[List[str]]:
    assert len(sources) == len(masked), "sources and masked length mismatch"
    if num_candidates < 1:
        raise ValueError("num_candidates must be >= 1")

    _lazy_load_llm()
    all_cands: List[List[str]] = []

    for src, msk in tqdm(
        list(zip(sources, masked)),
        desc="llm infilling",
        leave=False,
    ):
        messages = [
            {
                "role": "system",
                "content": INFILL_SYSTEM_PROMPT + "\n\nHere is an example:\n" + INFILL_FEW_SHOT,
            },
            {
                "role": "user",
                "content": (
                    f"Toxic Sentence: {src}\n"
                    f"Masked Sentence: {msk}\n"
                    f"Final Output:"
                ),
            },
        ]
        try:
            prompt = _LLM_TOKENIZER.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            # Some chat templates reject a "system" role, and some tokenizers
            # define no template at all; fall back to a plain-text prompt.
            prompt = (
                INFILL_SYSTEM_PROMPT
                + "\n\nExample:\n"
                + INFILL_FEW_SHOT
                + "\n\nToxic Sentence: "
                + src
                + "\nMasked Sentence: "
                + msk
                + "\nFinal Output:"
            )

        inputs = _LLM_TOKENIZER(prompt, return_tensors="pt").to(DEVICE_LLM)
        input_len = inputs["input_ids"].shape[1]

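        # Sampling-only arguments are passed only when sampling is enabled.
        # Greedy decoding can return just one sequence, so in that case a
        # single output is generated and reused for every candidate slot.
        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=sample,
            num_return_sequences=num_candidates if sample else 1,
            pad_token_id=_LLM_TOKENIZER.eos_token_id,
        )
        if sample:
            gen_kwargs["temperature"] = float(temperature)
            gen_kwargs["top_p"] = top_p
        gen = _LLM_MODEL.generate(**inputs, **gen_kwargs)

        cand_list = []
        for idx in range(num_candidates):
            # Decode only the newly generated tokens, not the prompt.
            gen_text = _LLM_TOKENIZER.decode(
                gen[idx if sample else 0][input_len:],
                skip_special_tokens=True,
            )
            cleaned = _extract_bracket_content(gen_text)
            cleaned = _postprocess_llm_infill(cleaned)
            if not cleaned:
                # Fall back to the source sentence if nothing usable remains.
                cleaned = src
            cand_list.append(cleaned)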

        all_cands.append(cand_list)

    return all_cands
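
The square-bracket convention in the infilling prompt is what makes the model output easy to parse. A minimal standalone sketch of that parsing follows. It assumes a regex-based extractor; the notebook's own `_extract_bracket_content` is defined elsewhere and may behave differently, and `extract_bracketed` and `clean_candidate` are hypothetical names used only here.

```python
import re

def extract_bracketed(text: str) -> str:
    """Return the content of the first [...] pair, or the stripped text if none is found."""
    match = re.search(r"\[([^\[\]]+)\]", text)
    return match.group(1).strip() if match else text.strip()

def clean_candidate(text: str) -> str:
    """Drop unfilled <mask> tokens and collapse runs of whitespace."""
    return re.sub(r"\s+", " ", text.replace("<mask>", " ")).strip()

# A made-up raw generation with extra spacing and a trailing remark.
raw = " [You're such a rude person, nobody wants to hear your   opinion.]\nNote: done"
print(clean_candidate(extract_bracketed(raw)))
```

Anything the model writes outside the brackets is discarded, which is why the prompt forbids additional `[` or `]` characters.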

In [None]:
def _bool2str(x: bool) -> str:
    return "T" if x else "F"

def _build_run_folder_name(
    llm_temperature: float,
    llm_top_p: float,
    llm_sample: bool,
    num_candidates: int,
    max_new_tokens: int,
    decompx_threshold: float,
):
    return (
        f"llmtemp{llm_temperature}_topp{llm_top_p}_"
        f"sample{_bool2str(llm_sample)}_"
        f"nc{num_candidates}_"
        f"maxntok{max_new_tokens}_"
        f"dxth{decompx_threshold}"
    )
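
Because the folder name is built with plain f-string formatting, floats are rendered in their shortest form, so a threshold written as `0.20` appears as `0.2` in the path. A small sketch of this behavior, where `run_folder_name` is a hypothetical standalone copy of the helper above, repeated so that the snippet runs without the rest of the notebook:

```python
def run_folder_name(temperature, top_p, sample, num_candidates, max_new_tokens, threshold):
    """Build a run folder name from the generation and reranking settings."""
    flag = "T" if sample else "F"
    return (
        f"llmtemp{temperature}_topp{top_p}_sample{flag}_"
        f"nc{num_candidates}_maxntok{max_new_tokens}_dxth{threshold}"
    )

name = run_folder_name(0.7, 0.95, True, 10, 64, 0.20)
print(name)  # llmtemp0.7_topp0.95_sampleT_nc10_maxntok64_dxth0.2

# 0.20 and 0.2 are the same float, so both settings map to one folder.
assert name == run_folder_name(0.7, 0.95, True, 10, 64, 0.2)
```

Note that an integer such as `1` and the float `1.0` format differently (`1` versus `1.0`), so passing one instead of the other would write to a different folder.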

In [None]:
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_llm_mask_infill_decompx",
    echo: bool = False,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    num_candidates: int = 3,
    llm_temperature: float = 0.7,
    llm_top_p: float = 0.95,
    llm_max_new_tokens: int = 64,
    llm_sample: bool = True,
    decompx_threshold: float = 0.20,
    decompx_batch_size_mask: int = 16,
):
    assert data_type in data_configs, f"Unknown data_type: {data_type}"
    cfg = data_configs[data_type].copy()

    if num_candidates < 1:
        raise ValueError("num_candidates must be >= 1")

    base_out_rel = os.path.join("data", "model_outputs", output_folder)
    base_out_abs = os.path.join(REPO, base_out_rel)
    _ensure_dir(base_out_abs)

    original_data_path = cfg["data_path"]
    subset_dir = os.path.join(REPO, "datasets", "_subsets", data_type)
    _ensure_dir(subset_dir)
    subset_path = _subset_for_data_type(
        data_type, original_data_path, num_examples, subset_dir
    )

    args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
    inputs = rx.get_data(args_data)
    num_inputs = len(inputs)

    if echo:
        print("=" * 80)
        print(f"dataset: {data_type}")
        print(f"subset path: {subset_path}")
        print(f"output base: {base_out_abs}")
        print(f"num examples: {num_inputs}")
        print(f"num_candidates: {num_candidates}")
        print(f"decompx threshold: {decompx_threshold}")
        print("\nexample inputs (first 3):")
        for i, s in enumerate(inputs[:3]):
            print(f"  input[{i}]: {s}")
        print("=" * 80)

    mask_dir = "LLM_Mask_LLM_DecompX"
    cur_rel = os.path.join(base_out_rel, data_type, mask_dir)
    cur_abs = os.path.join(REPO, cur_rel)
    _ensure_dir(cur_abs)

    masked_file = os.path.join(cur_abs, "masked_inputs.txt")

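    # Reuse cached masks only if they line up with the current inputs; a cache
    # written for a different num_examples would otherwise fail the check below.
    masked_inputs = None
    if os.path.exists(masked_file):
        with open(masked_file, "r") as f:
            masked_inputs = [s.strip() for s in f.readlines()]
        if len(masked_inputs) == len(inputs):
            print("reusing existing masked_inputs.txt")
        else:
            masked_inputs = None

    if masked_inputs is None:
        print("running llm masking to create masked_inputs.txt")
        masked_inputs = llm_mask_sentences(inputs)
        masked_inputs = [re.sub(r"\s+", " ", d).strip() for d in masked_inputs]
        with open(masked_file, "w") as f:
            for d in masked_inputs:
                f.write(d + "\n")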

    assert len(masked_inputs) == len(inputs), "Masked vs inputs mismatch"

    if echo:
        print("\nexample llm-masked inputs (first 3):")
        for i, m in enumerate(masked_inputs[:3]):
            print(f"  masked[{i}]: {m}")

    run_folder = _build_run_folder_name(
        llm_temperature=llm_temperature,
        llm_top_p=llm_top_p,
        llm_sample=llm_sample,
        num_candidates=num_candidates,
        max_new_tokens=llm_max_new_tokens,
        decompx_threshold=decompx_threshold,
    )
    final_abs = os.path.join(cur_abs, run_folder)
    _ensure_dir(final_abs)
    orig_txt = os.path.join(final_abs, "orig.txt")
    gen_txt = os.path.join(final_abs, "gen.txt")

    if os.path.exists(gen_txt) and not overwrite_gen:
        print("gen already exists at:", gen_txt, "— skipping.")
        with open(gen_txt, "r") as f:
            best_generations = [s.strip() for s in f.readlines()]
        if echo:
            print("\nexample detoxified outputs (first 3):")
            for i in range(min(3, len(best_generations))):
                print(f"  detox[{i}]: {best_generations[i]}")
    else:
        print(f"llm infilling: generating {num_candidates} candidates per input (sampling={llm_sample})")
        all_candidates = llm_infill_candidates(
            sources=inputs,
            masked=masked_inputs,
            num_candidates=num_candidates,
            temperature=llm_temperature,
            top_p=llm_top_p,
            max_new_tokens=llm_max_new_tokens,
            sample=llm_sample,
        )

        # Release the LLM before reranking to free GPU memory.
        global _LLM_MODEL, _LLM_TOKENIZER
        _LLM_MODEL = None
        _LLM_TOKENIZER = None
        if torch.cuda.is_available() and DEVICE_LLM.type == "cuda":
            torch.cuda.empty_cache()

        print(f"decompx reranking (threshold={decompx_threshold:.2f})")
        best_idx, details = rerank_candidates_decompx(
            sources=inputs,
            candidates=all_candidates,
            threshold=decompx_threshold,
            batch_size_mask=decompx_batch_size_mask,
        )
        best_generations = [
            all_candidates[i][best_idx[i]] for i in range(len(inputs))
        ]

        # Write one whitespace-normalized sentence per line.
        with open(orig_txt, "w") as f:
            for line in inputs:
                f.write(re.sub(r"\s+", " ", line).strip() + "\n")
        with open(gen_txt, "w") as f:
            for line in best_generations:
                f.write(re.sub(r"\s+", " ", line).strip() + "\n")

        print("saved:", orig_txt)
        print("saved:", gen_txt)

        if echo:
            print("\nexample detoxified outputs (first 3):")
            for i in range(min(3, len(best_generations))):
                print(f"  detox[{i}]: {best_generations[i]}")

    if run_eval:
        # cur_abs is <REPO>/data/model_outputs/<output_folder>/<data_type>/<mask_dir>,
        # and base_out_abs is <REPO>/data/model_outputs/<output_folder>.
        _eval_with_toxicity(
            cur_abs,
            overwrite_eval=overwrite_eval,
            skip_ref=skip_ref_eval,
            tox_threshold=0.5,
            tox_batch_size=32,
        )
        _aggregate_eval_csv(output_folder, data_type, base_out_abs)

        if echo:
            stats_path = os.path.join(final_abs, "gen_stats.txt")
            if os.path.exists(stats_path):
                stats = _read_stats_file(stats_path)
                print("\neval metrics for this run:")
                metric_keys = [
                    ("bertscore", "bertscore"),
                    ("meaningbert", "meaningbert"),
                    ("bleu4", "bleu-4"),
                    ("perplexity gen", "perplexity (gen)"),
                    ("perplexity orig", "perplexity (orig)"),
                    ("toxicity gen", "toxicity (gen)"),
                    ("toxicity orig", "toxicity (orig)"),
                ]
                for key, label in metric_keys:
                    val = stats.get(key)
                    # Skip metrics that are missing or were not computed (NaN).
                    if val is None or (isinstance(val, float) and math.isnan(val)):
                        continue
                    print(f"  {label}: {val:.4f}")
            else:
                print("\ngen_stats.txt not found")
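
The selection step that follows reranking in `detoxify` reduces to an argmin per input: for each sentence, keep the candidate with the lowest summed importance. A minimal sketch with made-up scores, assuming each candidate already has a total DecompX importance; the `scores` values and the `pick_best` name are hypothetical and are not the output or interface of `rerank_candidates_decompx`.

```python
from typing import List

def pick_best(candidates: List[List[str]], scores: List[List[float]]) -> List[str]:
    """For each input, return the candidate with the lowest toxicity score."""
    best = []
    for cands, scs in zip(candidates, scores):
        # min() returns the first index on ties, so earlier candidates win.
        best_idx = min(range(len(cands)), key=lambda j: scs[j])
        best.append(cands[best_idx])
    return best

candidates = [
    ["that was a bad idea", "that was a dumb idea"],
    ["please stop", "just stop it"],
]
scores = [[0.12, 0.81], [0.05, 0.09]]  # illustrative numbers only

print(pick_best(candidates, scores))
```

Since the score is a sum over tokens, shorter candidates can score lower simply by having fewer tokens, which is worth keeping in mind when comparing outputs of very different lengths.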

In [None]:
# detoxify(
#     data_type="paradetox",
#     output_folder="colab_run_llm_mask_infill_decompx_demo_50_examples",
#     echo=True,
#     num_examples=50,
#     run_eval=True,
#     overwrite_gen=True,
#     overwrite_eval=True,
#     skip_ref_eval=False,
#     num_candidates=10,
#     llm_temperature=0.7,
#     llm_top_p=0.95,
#     llm_max_new_tokens=64,
#     llm_sample=True,
#     decompx_threshold=0.20,
#     decompx_batch_size_mask=16,
# )

In [17]:
detoxify(
    data_type="paradetox",
    output_folder="XDetox_w_LLM-Masking_LLM-Infilling_DecompX-Reranking_Pipeline",
    echo=True,
    num_examples=1000,
    run_eval=True,             # BLEU/BERTScore/MeaningBERT/PPL/Toxicity
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    num_candidates=10,         # LLM candidates per input
    llm_temperature=0.7,
    llm_top_p=0.95,
    llm_max_new_tokens=64,
    llm_sample=True,
    decompx_threshold=0.20,
    decompx_batch_size_mask=16,
)

[echo] Dataset: paradetox
[echo] Subset path: /content/drive/MyDrive/w266 - Project/XDetox/datasets/_subsets/paradetox/test_toxic_parallel.txt
[echo] Output base: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_LLM-Infilling_DecompX-Reranking_Pipeline
[echo] Number of examples to detoxify: 671
[echo] num_candidates per input: 10
[echo] DecompX threshold (reranking): 0.2

[echo] Example inputs (first up to 3):
  input[0]: . or the loud ass one - thousand ton beast roaring towards you howling its horn .
  input[1]: mandated  and " right fucking now " would be good .
  input[2]: * neither * of my coworkers gave a shit when it came time to ditch mitch . ugh .
Running LLM masking (Mistral) to create masked_inputs.txt ...
Loading LLM (Mistral-7B-Instruct): mistralai/Mistral-7B-Instruct-v0.2 on cuda ...


LLM loaded.


[echo] Example LLM-masked inputs (first up to 3):
  masked[0]: . or the loud <mask> one - thousand ton beast roaring towards you howling its horn .
  masked[1]: mandated and "<mask> right <mask> now <mask> " would be good .
  masked[2]: neither of my coworkers gave a <mask> when it came time to ditch mitch . ugh .
LLM infilling: generating 10 candidates per input (sampling=True)


DecompX reranking (threshold=0.20) ...


Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_LLM-Infilling_DecompX-Reranking_Pipeline/paradetox/LLM_Mask_LLM_DecompX/llmtemp0.7_topp0.95_sampleT_nc10_maxntok64_dxth0.2/orig.txt
Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_LLM-Infilling_DecompX-Reranking_Pipeline/paradetox/LLM_Mask_LLM_DecompX/llmtemp0.7_topp0.95_sampleT_nc10_maxntok64_dxth0.2/gen.txt

[echo] Example detoxified outputs (first up to 3):
  detox[0]: or the loud obnoxious one - thousand ton beast roaring towards you howling its horn.
  detox[1]: mandated and "immediate" "soon" "would be good".
  detox[2]: neither of my coworkers showed concern when it came time to let Mitch go . ugh.
Eval: /usr/bin/python3 -m evaluation.evaluate_all --orig_path /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_LLM-Infilling_DecompX-Reranking_Pipeline/paradetox/LLM_Mask_LLM_DecompX/llmtemp0.7_topp0.95_sampleT_nc10