# xdetox with llm masking and decompx reranking

This notebook runs an XDetox pipeline with:

1. **llm masking** using Mistral-7B-Instruct (`mistralai/Mistral-7B-Instruct-v0.2`), which detects toxic spans and replaces them with `<mask>`.
2. **marco-style generation** (base / expert / anti-expert BART mixture).
3. **decompx-based reranking** of multiple candidates per input using token-level toxicity importances.

The goal is to pick, for each toxic input sentence, **one best detoxified candidate** that is:

- As **non-toxic** as possible under DecompX.
- Still a reasonable rewrite of the original sentence.
- Grammatically acceptable; this is left to MaRCo's generation, since no separate fluency model is used in scoring.

Unlike the global-reranking pipeline, this notebook uses **only DecompX** for reranking (no XLM-R, LaBSE, or GPT-2 in the scoring).

---

## decompx-based reranking

Reranking uses DecompX on **each candidate sentence** to measure its contribution to toxicity.

For a candidate sentence $s_j$ with $N_j$ tokens $t_{1,j}, \dots, t_{N_j,j}$, DecompX assigns a toxicity importance score $\text{Importance}(t_{i,j})$ to each token $t_{i,j}$.

We define the DecompX toxicity score of $s_j$ as the sum of these importances:

$$
\text{DecompXScore}(s_j) = \sum_{i=1}^{N_j} \text{Importance}(t_{i,j})
$$

The reranking step then selects the candidate with the **lowest** cumulative toxicity:

$$
s^* = \arg\min_{s_j} \sum_{i=1}^{N_j} \text{Importance}(t_{i,j})
$$

This follows Equation (3) from the XDetox paper: the chosen sentence $s^*$ is the one whose tokens have the smallest total contribution to toxicity.
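
If per-token importances were available directly, the selection rule above would reduce to a sum and an argmin. The sketch below assumes each candidate's importances are already given as a list of floats. `decompx_score` and `select_least_toxic` are hypothetical names that are not part of the notebook or of DecompX, and the importance values are invented.

```python
from typing import List, Sequence


def decompx_score(importances: Sequence[float]) -> float:
    """DecompXScore(s_j): sum of the per-token toxicity importances."""
    return float(sum(importances))


def select_least_toxic(candidates: List[str],
                       importances: List[Sequence[float]]) -> str:
    """Return s* = argmin_j DecompXScore(s_j).

    importances[j] holds Importance(t_{i,j}) for each token of candidates[j].
    min() keeps the first minimum, so ties go to the earliest candidate.
    """
    if not candidates or len(candidates) != len(importances):
        raise ValueError("need one importance list per candidate")
    scores = [decompx_score(imp) for imp in importances]
    best_j = min(range(len(candidates)), key=lambda j: scores[j])
    return candidates[best_j]


# Invented importances for illustration only (not real DecompX output).
cands = ["you are a fool", "you are mistaken", "you are wrong"]
imps = [[0.01, 0.02, 0.00, 0.90], [0.01, 0.02, 0.10], [0.01, 0.02, 0.05]]
print(select_least_toxic(cands, imps))  # you are wrong
```

Taking the first minimum matches the tie-breaking of `np.argmin`, which the notebook uses for the actual selection.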

### implementation note

In the code, DecompX is accessed through `Masker_single`, which:

- Computes token-level importance scores internally.
- Masks tokens whose importance exceeds a threshold `decompx_threshold` by replacing them with `<mask>`.

To obtain a scalar toxicity score for reranking, we use the DecompX-masked sentence as a simple proxy for the summed importances (it tends to rise with them but is not strictly monotonic in them). Concretely, for each candidate sentence $s_j$:

- Let $\tilde{s}_j$ be its DecompX-masked version (with some tokens replaced by `<mask>`).
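- We compute the fraction of masked tokens:

  $$
  \widehat{\text{DecompXScore}}(s_j)
  = \frac{\#\{\langle\text{mask}\rangle \text{ tokens in } \tilde{s}_j\}}{\max\bigl(1,\ \#\{\text{whitespace-separated tokens in } \tilde{s}_j\}\bigr)}
  $$

- We then choose the candidate with the **smallest** $\widehat{\text{DecompXScore}}(s_j)$ for each input; ties go to the earliest candidate (`np.argmin`).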

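This approximates the paper's scoring: more and/or larger toxic spans (higher cumulative importance) lead to more `<mask>` tokens and thus a higher $\widehat{\text{DecompXScore}}$. It is not identical, though: importances are thresholded rather than summed, so tokens just below `decompx_threshold` contribute nothing, and the count is divided by sentence length, so a longer candidate with the same number of masks scores lower.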
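
The proxy can be reproduced on already-masked strings without loading DecompX. The sketch mirrors the scoring in the notebook's `_decompx_toxicity_scores` and `rerank_candidates_decompx`: count `<mask>` occurrences, divide by the whitespace-token count of the masked sentence, and take the per-input argmin. `mask_ratio` and `pick_best` are hypothetical names, and the masked strings are invented.

```python
import re
from typing import List


def mask_ratio(masked: str) -> float:
    """Fraction of whitespace tokens in a DecompX-masked sentence that are <mask>."""
    num_masks = len(re.findall(r"<mask>", masked))
    length = max(len(masked.split()), 1)  # guard against empty strings
    return num_masks / length


def pick_best(masked_candidates: List[List[str]]) -> List[int]:
    """For each input, return the index of the candidate with the lowest mask ratio."""
    best = []
    for cands in masked_candidates:
        scores = [mask_ratio(m) for m in cands]
        best.append(scores.index(min(scores)))  # first index wins ties, like np.argmin
    return best


masked = [
    ["you are a <mask>", "you are <mask> <mask>", "you are wrong"],
    ["shut up, <mask>", "please stop talking", "please stop talking"],
]
print(pick_best(masked))  # [2, 1]
```

A `<mask>` glued to punctuation (for example `<mask>,`) still counts as one mask and one token, because both counts work on the same whitespace split.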

---

## llm masking (mistral-7b-instruct)

### prompted masking behavior

Masking is done by a chat-style LLM:

- Model: `mistralai/Mistral-7B-Instruct-v0.2`.
- The LLM is instructed to:
  - Identify **toxic, offensive, or profane words or short phrases**.
  - Replace **each toxic span** with a **single `<mask>` token**.
  - Allow **multiple `<mask>` tokens** in one sentence (one per toxic span).
  - If several neighboring words are toxic, **collapse them into one `<mask>`**  
    (do not output `<mask> <mask>` or longer runs).
  - Keep **all non-toxic words and punctuation in place**.
  - **Not** paraphrase, summarize, or reorder the sentence.
  - Return the masked sentence **inside exactly one pair of brackets**:

    ```text
    [This is a <mask> example.]
    ```
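
These format rules can be checked mechanically, which helps when inspecting raw LLM outputs. A minimal sketch: `check_mask_format` is a hypothetical helper that is not part of the notebook, and it only tests formatting (one bracket pair, no adjacent `<mask>` tokens), not whether the right words were masked.

```python
import re


def check_mask_format(raw: str) -> bool:
    """True if `raw` is exactly one [...] block (surrounding whitespace allowed)
    with no brackets inside and no two <mask> tokens in a row."""
    m = re.fullmatch(r"\s*\[([^\[\]]*)\]\s*", raw)
    if m is None:
        return False  # extra text, missing bracket, or nested brackets
    return re.search(r"<mask>\s*<mask>", m.group(1)) is None


print(check_mask_format("[This is a <mask> example.]"))          # True
print(check_mask_format("Sure! [This is a <mask> example.]"))    # False
print(check_mask_format("[This is a <mask> <mask> example.]"))   # False
```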

### post-processing of llm masks

Because LLM output is not always perfectly formatted, the notebook cleans each raw LLM output before passing it to MaRCo:

1. **Extract bracket content**:

   - Try to read the first `[ ... ]` block.
   - If there is `[` but no `]`, take everything after the first `[` as the sentence.
   - If no brackets exist, fall back to the full string.

2. **Strip stray outer brackets** if they still remain.

3. **Normalize whitespace**:

   - Collapse multiple spaces to a single space.
   - Trim leading and trailing spaces.

4. **Normalize `<mask>` casing and spacing**:

   - Any variant like `<Mask>`, `<MASK>`, `< mask >` is normalized to `<mask>`.

5. **Collapse runs of `<mask>`**:

   - Any sequence like `<mask> <mask> <mask>` is collapsed to a single `<mask>`.

6. If the cleaned sentence becomes empty, fall back to the text extracted in step 1; if that is also empty, use the original input sentence.
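
The six steps can be condensed into one standalone function, which is convenient for trying messy outputs by hand. This is a sketch that follows the notebook's `_extract_bracket_content` and `_postprocess_llm_mask` logic; `clean_llm_mask` is a hypothetical name, and the last-resort fallback to the original input sentence (done in `llm_mask_sentences`) is left out.

```python
import re


def clean_llm_mask(raw: str) -> str:
    """Sketch of post-processing steps 1-6 (hypothetical standalone helper)."""
    text = raw.strip()
    # 1. Extract the first [...] block; else text after '['; else everything.
    m = re.search(r"\[([^\]]*)\]", text, flags=re.DOTALL)
    if m:
        extracted = m.group(1).strip()
    elif "[" in text:
        extracted = text.split("[", 1)[1].strip()
    else:
        extracted = text
    # 2. Strip a stray leading '[' and/or trailing ']'.
    s = extracted
    if s.startswith("["):
        s = s[1:]
    if s.endswith("]"):
        s = s[:-1]
    # 3. Collapse whitespace and trim.
    s = re.sub(r"\s+", " ", s).strip()
    # 4. Normalize <Mask>, <MASK>, < mask > ... to <mask>.
    s = re.sub(r"<\s*mask\s*>", "<mask>", s, flags=re.IGNORECASE)
    # 5. Collapse runs of adjacent <mask> tokens into one, then re-trim spacing.
    s = re.sub(r"(?:\s*<mask>\s*){2,}", " <mask> ", s)
    s = re.sub(r"\s+", " ", s).strip()
    # 6. Fall back to the extracted text if cleaning left nothing.
    return s or extracted


print(clean_llm_mask("Sure! [You are a <MASK> <Mask> person.]"))  # You are a <mask> person.
print(clean_llm_mask("[Shut up, < mask >"))                       # Shut up, <mask>
```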

All cleaned, LLM-masked sentences are saved once per dataset and output folder:

```text
data/model_outputs/{output_folder}/{data_type}/LLM_Masking_DecompX/masked_inputs.txt
```

If this file already exists, it is reused rather than re-calling the LLM.
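
The cache is keyed only on `output_folder` and `data_type`, not on `num_examples`, so a file written for one subset size trips the notebook's `Masked vs inputs mismatch` assertion when reused with another. One way around that, sketched below as a hypothetical alternative rather than what the notebook does, is to reuse `masked_inputs.txt` only when its line count matches the inputs and otherwise re-run the masker. `load_or_build_masks` is a hypothetical name and `mask_fn` stands in for `llm_mask_sentences`.

```python
import os
from typing import Callable, List


def load_or_build_masks(path: str, inputs: List[str],
                        mask_fn: Callable[[List[str]], List[str]]) -> List[str]:
    """Return cached masks from `path` if the file has one line per input,
    otherwise call `mask_fn(inputs)` and rewrite the cache.

    A length check cannot detect a *different* subset of the same size.
    """
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as f:
            cached = [line.rstrip("\n") for line in f]
        if len(cached) == len(inputs):
            return cached
    # Collapse internal whitespace (including newlines) so each mask is one line.
    masked = [" ".join(m.split()) for m in mask_fn(inputs)]
    with open(path, "w", encoding="utf-8") as f:
        for m in masked:
            f.write(m + "\n")
    return masked
```

Deleting the file or choosing a new `output_folder` achieves the same thing without code changes.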

---

## marco generation (bart base / expert / anti-expert)

After LLM masking, generation uses the **MaRCo** setup:

1. **Models**:

   * Base: `facebook/bart-base`
   * Anti-expert (toxic): `hallisky/bart-base-toxic-antiexpert`
   * Expert (non-toxic): `hallisky/bart-base-nontoxic-expert`

   These are combined via `rewrite.generation.Infiller`.

2. **inputs**:

   * Original sentences (for reference).
   * LLM-masked sentences (with `<mask>` tokens), used as infilling prompts.

3. **hyperparameters**:

   * `alpha_a`, `alpha_e`, `alpha_b`: anti-expert, expert, base weights.
   * `temperature`: sampling temperature.
   * `top_k_gen`: top-k for sampling.
   * `top_p`: nucleus sampling on ensembled logits.
   * `filter_p`: nucleus filter on base logits (advanced).
   * `rep_penalty`: repetition penalty.
   * `max_length`: maximum generation length.
   * `sample`: sampling (`True`) vs greedy (`False`).
   * `batch_size`: generation batch size.

4. **multiple candidates per input**:

   * The notebook calls `Infiller.generate(...)` **`num_candidates` times**.
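   * Each call produces one candidate per input. With `sample=True` the candidates differ; with `sample=False` decoding is greedy, so all candidates for an input are expected to be identical and reranking has no effect.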
   * Candidates are collected as:

     ```python
     all_candidates[i] = [cand_0, cand_1, ..., cand_last]  # len(all_candidates[i]) == num_candidates
     ```

5. **folder structure**:

   Generated outputs go to a run folder whose name encodes the main decoding settings:

   ```text
   data/model_outputs/{output_folder}/{data_type}/LLM_Masking_DecompX/{run_folder}/
   ```

   where `{run_folder}` is built from:

   * `alpha_a`, `alpha_e`, `alpha_b`
   * `base_type`, `antiexpert_type`, `expert_type`
   * `temperature`, `sample`, `top_k_gen`, `rep_penalty`, `filter_p`, `max_length`, `top_p`
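
To make the decoding hyperparameters in item 3 concrete, the sketch below walks through one decoding step over a toy vocabulary. `Infiller` does this mixing internally and its exact arithmetic is not shown in this notebook; the sketch *assumes* a DExperts-style linear combination of logits (`alpha_b` times the base, plus `alpha_e` times the expert, minus `alpha_a` times the anti-expert), then applies temperature, top-k and top-p. `filter_p` and the repetition penalty are omitted, the default weights are borrowed from the `paradetox` entry of `data_configs`, the logits are invented, and `mix_and_filter` is a hypothetical name.

```python
import math
from typing import Dict, List


def softmax(z: List[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def mix_and_filter(z_base: List[float], z_expert: List[float], z_anti: List[float],
                   alpha_b: float = 1.0, alpha_e: float = 4.75, alpha_a: float = 1.5,
                   temperature: float = 2.5, top_k: int = 50,
                   top_p: float = 0.95) -> Dict[int, float]:
    """ASSUMED DExperts-style mixture, then temperature, top-k and top-p.

    Returns {token_index: probability} for the tokens that survive filtering,
    renormalized to sum to 1.
    """
    # Reward tokens the expert likes, penalize tokens the anti-expert likes.
    mixed = [alpha_b * b + alpha_e * e - alpha_a * a
             for b, e, a in zip(z_base, z_expert, z_anti)]
    probs = softmax([v / temperature for v in mixed])
    # top-k: keep the k most probable tokens.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    # top-p: keep the shortest prefix of `ranked` whose mass reaches top_p.
    kept, mass = [], 0.0
    for i in ranked:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


# Toy 3-token vocabulary: 0 = "idiot", 1 = "person", 2 = "friend" (invented logits).
dist = mix_and_filter(z_base=[2.0, 1.0, 0.5],
                      z_expert=[-1.0, 1.0, 1.0],
                      z_anti=[3.0, 0.0, 0.0])
print(dist)  # token 0 is filtered out by top-p
```

The base model alone prefers token 0, but the anti-expert penalty pushes it below the top-p cutoff, so sampling can only pick tokens 1 or 2.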

In each run folder, the notebook writes:

* `orig.txt` — original inputs (one per line).
* `gen.txt` — DecompX-selected candidates (one per line).
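
The run-folder name is also what evaluation uses to find runs: `_eval_with_toxicity` only visits folders matching the regex in `_parse_run_folder_name`, which requires a decimal point in every float field. The sketch rebuilds a name with the same layout (`build_run_folder` is a hypothetical stand-alone copy of `_build_gen_folder_name`) and shows that passing an integer such as `alpha_b=1` produces a folder the evaluator skips without any message.

```python
import re

# Pattern copied from the notebook's _parse_run_folder_name.
RUN_FOLDER_RE = re.compile(
    r"aa(\d+\.\d+)_ae(\d+\.\d+)_ab(\d+\.\d+)_base(.*?)_anti(.*?)_expert(.*?)_"
    r"temp(\d+\.\d+)_sample(.*?)_topk(\d+)_reppenalty(\d+\.\d+)_filterp(\d+\.\d+)_"
    r"maxlength(\d+)_topp(\d+\.\d+)"
)


def build_run_folder(alpha_a, alpha_e, alpha_b, temperature, sample, top_k_gen,
                     rep_penalty, filter_p, max_length, top_p,
                     base_type="base", antiexpert_type="antiexpert",
                     expert_type="expert"):
    """Same layout as the notebook's _build_gen_folder_name (type names cut to 5 chars)."""
    return (
        f"aa{alpha_a}_ae{alpha_e}_ab{alpha_b}"
        f"_base{base_type[:5]}_anti{antiexpert_type[:5]}_expert{expert_type[:5]}"
        f"_temp{temperature}_sample{'T' if sample else 'F'}_topk{top_k_gen}"
        f"_reppenalty{rep_penalty}_filterp{filter_p}"
        f"_maxlength{max_length}_topp{top_p}"
    )


ok = build_run_folder(1.5, 4.75, 1.0, 2.5, True, 50, 1.0, 1.0, 96, 0.95)
bad = build_run_folder(1.5, 4.75, 1, 2.5, True, 50, 1.0, 1.0, 96, 0.95)  # alpha_b=1
print(ok)
print(bool(RUN_FOLDER_RE.match(ok)), bool(RUN_FOLDER_RE.match(bad)))  # True False
```

Passing floats such as `1.0` for every float-valued argument keeps new runs visible to evaluation.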

---

## evaluation

If `run_eval=True`, the pipeline runs `evaluation.evaluate_all` in a subprocess (`python -m evaluation.evaluate_all`) to compute:

* BERTScore (F1)
* MeaningBERT
* BLEU-4
* Toxicity (orig / gen) using XLM-R
* Perplexity (orig / gen) using GPT-2

For each `{run_folder}`, it writes:

```text
data/model_outputs/{output_folder}/{data_type}/LLM_Masking_DecompX/{run_folder}/gen_stats.txt
```

Then `_aggregate_eval_csv` collects all runs under `LLM_Masking_DecompX/` for that dataset and writes:

```text
data/model_outputs/{output_folder}/{data_type}/{data_type}.csv
```

The CSV contains, per run folder:

* `folder`: run folder name (encodes hyperparameters),
* `threshold`: fixed label `0.20` (kept for compatibility; it is hard-coded, is not used for masking here, and does not reflect `decompx_threshold`),
* `bertscore`, `meaningbert`, `bleu4`,
* `perplexity_gen`, `perplexity_orig`,
* `toxicity_gen`, `toxicity_orig`.
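
The CSV rows can be reproduced from `gen_stats.txt` in a few lines. The sketch assumes the file holds `name: value` lines, which is what the notebook's `_read_stats_file` expects; `parse_stats` and `stats_to_row` are hypothetical names and the file contents are invented. It splits on the first `:` only, so a line without a space after the colon still parses.

```python
import csv
import io
import math
from typing import Dict

COLUMNS = ["threshold", "folder", "bertscore", "meaningbert", "bleu4",
           "perplexity_gen", "perplexity_orig", "toxicity_gen", "toxicity_orig"]


def parse_stats(text: str) -> Dict[str, float]:
    """Parse 'name: value' lines, lowercasing names and dropping '(skipped)'."""
    out: Dict[str, float] = {}
    for line in text.splitlines():
        if ":" not in line:
            continue
        key, value = line.split(":", 1)
        key = key.replace("(skipped)", "").strip().lower()
        try:
            out[key] = float(value.strip())
        except ValueError:
            out[key] = math.nan
    return out


def stats_to_row(folder: str, stats: Dict[str, float]) -> Dict[str, object]:
    """Map parsed stats onto the summary-CSV columns; missing metrics become NaN."""
    row: Dict[str, object] = {"threshold": 0.20, "folder": folder}
    for col in COLUMNS[2:]:
        # CSV column "perplexity_gen" corresponds to stats key "perplexity gen".
        row[col] = stats.get(col.replace("_", " "), math.nan)
    return row


# Invented file contents, for illustration only.
sample = "BERTScore: 0.91\nBLEU4: 0.35\nPerplexity gen: 120.5\nPerplexity orig (skipped): nan\n"
row = stats_to_row("demo_run", parse_stats(sample))
buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=COLUMNS)
writer.writeheader()
writer.writerow(row)
print(buf.getvalue())
```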

---

## how to use `detoxify()`

Function signature:

```python
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_llm_mask_decompx",
    echo: bool = False,
    batch_size: int = 10,
    sample: bool = True,
    top_k_gen: int = 50,
    top_p: float = 0.95,
    filter_p: float = 1.0,
    max_length: int = 128,
    alpha_a: float = None,
    alpha_e: float = None,
    alpha_b: float = 1.0,
    temperature: float = None,
    rep_penalty: float = None,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    num_candidates: int = 3,
    decompx_threshold: float = 0.20,
)
```

### key arguments

#### core i/o

* `data_type`: dataset key from `data_configs`, for example:

  * `"paradetox"`, `"dynabench_val"`, `"dynabench_test"`,
  * `"jigsaw_toxic"`, `"microagressions_val"`, `"sbf_val"`,
  * `"appdia_original"`, `"appdia_discourse"`, etc.

* `output_folder`: folder under `data/model_outputs/` where results are stored:

  ```text
  data/model_outputs/{output_folder}/{data_type}/...
  ```

* `num_examples`:

  * If set to a positive integer, only the first `num_examples` examples from the dataset are used.
  * If `None` (or `<= 0`), the full dataset is used.

#### llm masking (mistral)

* Masking behavior is controlled by a fixed system prompt (`MASK_SYSTEM_PROMPT`) and one worked example (`MASK_FEW_SHOT`).
* The notebook caches masked sentences to:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Masking_DecompX/masked_inputs.txt
  ```

  and reuses this file when it exists.

#### generation (marco / bart)

* `sample`: `True` for stochastic sampling, `False` for greedy decoding.

* `top_k_gen`: top-k sampling on the ensembled logits.

* `top_p`: nucleus sampling on the ensembled logits.

* `filter_p`: nucleus filter on the base logits (often left at `1.0`).

* `max_length`: maximum sequence length.

* `alpha_a`, `alpha_e`, `alpha_b`:

  * Anti-expert, expert, base mixture weights.
  * If `None`, defaults are taken from `data_configs[data_type]`.

* `temperature`: sampling temperature (falls back to `data_configs[data_type]["temperature"]`).

* `rep_penalty`: repetition penalty (falls back to `data_configs[data_type]["rep_penalty"]`).

* `batch_size`: generation batch size.

* `num_candidates`: number of candidates to generate for each input.
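
The body of `detoxify()` is not shown in this section, but the argument list describes a simple rule: an explicit value wins, and `None` falls back to `data_configs[data_type]`. A sketch of that rule, assuming nothing beyond what the argument descriptions state, with two entries copied from `data_configs`; `resolve_hparams` is a hypothetical name.

```python
from typing import Any, Dict, Optional

# Two entries copied from the notebook's data_configs (data_path omitted).
DATA_CONFIGS: Dict[str, Dict[str, Any]] = {
    "paradetox": {"rep_penalty": 1.0, "alpha_a": 1.5, "alpha_e": 4.75, "temperature": 2.5},
    "sbf_val": {"rep_penalty": 1.5, "alpha_a": 1.5, "alpha_e": 5.0, "temperature": 2.9},
}


def resolve_hparams(data_type: str,
                    alpha_a: Optional[float] = None,
                    alpha_e: Optional[float] = None,
                    temperature: Optional[float] = None,
                    rep_penalty: Optional[float] = None) -> Dict[str, float]:
    """Explicit arguments win; None falls back to the per-dataset default."""
    cfg = DATA_CONFIGS[data_type]
    given = {"alpha_a": alpha_a, "alpha_e": alpha_e,
             "temperature": temperature, "rep_penalty": rep_penalty}
    return {k: (cfg[k] if v is None else v) for k, v in given.items()}


print(resolve_hparams("sbf_val", temperature=1.0))
# {'alpha_a': 1.5, 'alpha_e': 5.0, 'temperature': 1.0, 'rep_penalty': 1.5}
```

Checking `is None` rather than truthiness matters: an explicit `0.0` (for example `alpha_e=0.0` to switch the expert off) is kept instead of being replaced by the default.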

#### decompx reranking

* `decompx_threshold`:

  * Importance threshold passed to DecompX (`Masker_single`); tokens whose importance exceeds it are masked.
  * Lower values mask more tokens, making the score sensitive to mildly toxic tokens; higher values only flag the most toxic ones.
  * The candidate with the **lowest** approximate DecompX score is selected.
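
A toy illustration of how the threshold feeds into the score. It assumes the masking rule stated earlier (mask a token when its importance exceeds the threshold) applied to whole words with invented importances; the real `Masker_single` works on model subword tokens, so its output can differ. `threshold_mask` and `mask_ratio` are hypothetical names.

```python
from typing import List, Sequence


def threshold_mask(tokens: Sequence[str], importances: Sequence[float],
                   threshold: float) -> List[str]:
    """Replace tokens whose importance exceeds `threshold` with <mask>."""
    return ["<mask>" if imp > threshold else tok
            for tok, imp in zip(tokens, importances)]


def mask_ratio(masked_tokens: Sequence[str]) -> float:
    """Fraction of tokens that were masked."""
    return sum(t == "<mask>" for t in masked_tokens) / max(len(masked_tokens), 1)


tokens = ["you", "are", "a", "total", "idiot"]
importances = [0.05, 0.02, 0.01, 0.15, 0.60]  # invented values
for th in (0.50, 0.20, 0.10):
    masked = threshold_mask(tokens, importances, th)
    print(th, " ".join(masked), round(mask_ratio(masked), 2))
# 0.5 you are a total <mask> 0.2
# 0.2 you are a total <mask> 0.2
# 0.1 you are a <mask> <mask> 0.4
```

In this sketch the set of tokens above the threshold can only grow as the threshold falls, so lowering `decompx_threshold` never decreases a candidate's mask count. It can still change which candidate wins, because every candidate is scored with the same threshold and the ratio is length-normalized.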

#### evaluation

* `run_eval`: if `True`, run evaluation and write `gen_stats.txt` plus summary CSV.
* `overwrite_gen`: if `True`, regenerate outputs even if `gen.txt` exists.
* `overwrite_eval`: if `True`, recompute evaluation even if `gen_stats.txt` exists.
* `skip_ref_eval`: if `True`, skip some reference-based evaluation (for example, perplexity on references).

#### echo / debugging

* `echo`: if `True`, print:

  * Dataset and output paths.
  * Number of examples.
  * Example inputs.
  * Example LLM-masked inputs.
  * Example final detoxified outputs.
  * Evaluation metrics for this run (if `run_eval=True`).

---

## example calls

### quick sanity check on a small subset

```python
detoxify(
    data_type="paradetox",
    output_folder="colab_run_llm_mask_decompx_demo_50_examples",
    echo=True,
    batch_size=8,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=50,
    run_eval=True,
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    num_candidates=10,
    decompx_threshold=0.20,
)
```

### larger run (full dataset)

```python
detoxify(
    data_type="paradetox",
    output_folder="paradetox_llm_mask_decompx_full",
    echo=True,
    batch_size=8,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=None,
    run_eval=True,
    overwrite_gen=False,
    overwrite_eval=False,
    skip_ref_eval=False,
    num_candidates=10,
    decompx_threshold=0.20,
)
```

After running `detoxify`, you can inspect:

* Final chosen generations:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Masking_DecompX/{run_folder}/gen.txt
  ```

* Original inputs:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Masking_DecompX/{run_folder}/orig.txt
  ```

* Per-run evaluation metrics:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Masking_DecompX/{run_folder}/gen_stats.txt
  ```

* Aggregated metrics over runs:

  ```text
  data/model_outputs/{output_folder}/{data_type}/{data_type}.csv
  ```

This notebook is designed to be directly comparable to:

* **decompx-masking + decompx-reranking** (original XDetox pipeline), and
* **llm-masking + global reranking** (toxicity + similarity + fluency),

so you can isolate the effect of **llm-based masking** vs **decompx-based masking**, and compare **decompx-only reranking** to **multi-objective global reranking**.

In [None]:
from google.colab import drive; drive.mount('/content/drive')

import os, glob, re, sys, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE
import torch
import nltk
from typing import List

candidate = "/content/drive/MyDrive/w266 - Project/XDetox"
print("try mydrive:", candidate, "->", os.path.isdir(candidate))

XDETOX_DIR = candidate
print("using xdetox_dir:", XDETOX_DIR)
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"

In [None]:
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.makedirs(HF_CACHE, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE

if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("xdetox_dir:", XDETOX_DIR)
print("transformers_cache:", HF_CACHE)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("gpu:", torch.cuda.get_device_name(0))

In [None]:
for d in ["rewrite", "evaluation", "datasets"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("repo folders ok")


In [None]:
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi \
                sentencepiece
!pip -q install bert-score

In [None]:
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
)
from rewrite.generation import Infiller
from rewrite.mask_orig import Masker as Masker_single
from rewrite import rewrite_example as rx
import argparse as _argparse


In [None]:
nltk.download("punkt", quiet=True)
try:
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass
print("nltk ready")


In [None]:
data_configs = {
    "microagressions_val": {
        "data_path": "./datasets/microagressions/val.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "microagressions_test": {
        "data_path": "./datasets/microagressions/test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "sbf_val": {
        "data_path": "./datasets/sbf/sbfdev.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "sbf_test": {
        "data_path": "./datasets/sbf/sbftst.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "dynabench_val": {
        "data_path": "./datasets/dynabench/db_dev.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "dynabench_test": {
        "data_path": "./datasets/dynabench/db_test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "jigsaw_toxic": {
        "data_path": "./datasets/jigsaw_full_30/test_10k_toxic.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "paradetox": {
        "data_path": "./datasets/paradetox/test_toxic_parallel.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_original": {
        "data_path": "./datasets/appdia/original-annotated-data/original-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_discourse": {
        "data_path": "./datasets/appdia/discourse-augmented-data/discourse-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    }
}
print("datasets:", ", ".join(data_configs.keys()))

REPO = XDETOX_DIR


In [None]:
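def _abs_repo_path(rel: str) -> str:
    # Strip a literal "./" prefix. str.lstrip("./") would strip *any* leading
    # '.' or '/' characters, turning absolute or "../" paths into wrong ones.
    if rel.startswith("./"):
        rel = rel[2:]
    return os.path.join(REPO, rel)  # an absolute `rel` is returned unchanged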

def _ensure_dir(p: str):
    Path(p).mkdir(parents=True, exist_ok=True)

def _subset_for_data_type(data_type, data_path, n, out_dir):
    if n is None or n <= 0:
        return data_path

    src = _abs_repo_path(data_path)
    _ensure_dir(out_dir)

    # CSV-backed datasets: keep the first n rows.
    if any(k in data_path for k in ["microagressions", "sbf", "dynabench"]):
        df = pd.read_csv(src)
        sub = df.head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        sub.to_csv(out, index=False)
        return out

    if any(k in data_path for k in ["paradetox", "jigsaw"]):
        if data_path.endswith(".txt"):
            with open(src, "r") as f:
                lines = [s.rstrip("\n") for s in f.readlines()]
            out = os.path.join(out_dir, os.path.basename(src))
            with open(out, "w") as g:
                for s in lines[:n]:
                    g.write(s + "\n")
            return out
        elif data_path.endswith(".csv"):
            df = pd.read_csv(src).head(n)
            out = os.path.join(out_dir, os.path.basename(src))
            df.to_csv(out, index=False)
            return out

    if "appdia" in data_path:
        df = pd.read_csv(src, sep="\t").head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        df.to_csv(out, sep="\t", index=False)
        return out

    out = os.path.join(out_dir, os.path.basename(src))
    shutil.copy(src, out)
    return out

In [None]:
def _parse_run_folder_name(folder_name):
    pattern = (
        r"aa(\d+\.\d+)_ae(\d+\.\d+)_ab(\d+\.\d+)_base(.*?)_anti(.*?)_expert(.*?)_"
        r"temp(\d+\.\d+)_sample(.*?)_topk(\d+)_reppenalty(\d+\.\d+)_filterp(\d+\.\d+)_"
        r"maxlength(\d+)_topp(\d+\.\d+)"
    )
    m = re.match(pattern, folder_name)
    return bool(m)

def _eval_with_toxicity(base_path, overwrite_eval=False, skip_ref=False,
                        tox_threshold=0.5, tox_batch_size=32):
    import sys as _sys, os as _os
    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir) or not _parse_run_folder_name(folder):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path  = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue

        env = os.environ.copy()
        env["PYTHONPATH"] = REPO + (":" + env.get("PYTHONPATH","") if env.get("PYTHONPATH") else "")
        cmd = [
            _sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path",  gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")
        print("eval:", " ".join(cmd))
        res = run(cmd, cwd=REPO, env=env, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()

def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return float('nan')

def _read_stats_file(path):
    out = {}
    with open(path, "r") as f:
        for line in f:
            if ":" not in line:
                continue
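            # Split on the first ':' only; tolerate a missing space after it
            # (splitting on ": " raises ValueError for lines like "bleu4:0.3").
            k, v = line.split(":", 1)
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v.strip())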
    return out

def _aggregate_eval_csv(output_folder, data_type, base_out_dir):
    rows = []

    mask_dir = "LLM_Masking_DecompX"
    base_path = os.path.join(base_out_dir, data_type, mask_dir)
    if not os.path.isdir(base_path):
        print("no evaluation directory found:", base_path)
        return

    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        stats_path = os.path.join(gen_dir, "gen_stats.txt")
        if not os.path.exists(stats_path):
            continue
        s = _read_stats_file(stats_path)
        rows.append({
            "threshold":        0.20,
            "folder":           folder,
            "bertscore":        s.get("bertscore", np.nan),
            "meaningbert":      s.get("meaningbert", np.nan),
            "bleu4":            s.get("bleu4", np.nan),
            "perplexity_gen":   s.get("perplexity gen", np.nan),
            "perplexity_orig":  s.get("perplexity orig", np.nan),
            "toxicity_gen":     s.get("toxicity gen", np.nan),
            "toxicity_orig":    s.get("toxicity orig", np.nan),
        })

    if rows:
        cols = [
            "threshold", "folder",
            "bertscore", "meaningbert", "bleu4",
            "perplexity_gen", "perplexity_orig",
            "toxicity_gen", "toxicity_orig",
        ]
        df = pd.DataFrame(rows)
        df = df[cols]
        out_csv = os.path.join(base_out_dir, data_type, f"{data_type}.csv")
        df.to_csv(out_csv, index=False)
        print("wrote summary csv:", out_csv)
    else:
        print("no evaluation files found to summarize")


In [None]:
USE_LLM_GPU = True
DEVICE_LLM = torch.device("cuda" if USE_LLM_GPU and torch.cuda.is_available() else "cpu")
print("llm masker device:", DEVICE_LLM)

LLM_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
_LLM_TOKENIZER = None
_LLM_MODEL = None

MASK_SYSTEM_PROMPT = """You are a helpful assistant trained to make toxic or offensive sentences safer
by MASKING toxic words with the special token <mask>, while keeping the original sentence
structure and wording as much as possible.

You must behave like a MASKER, not a full rewriter.

Your task:
1. Identify toxic, offensive, or profane words or short phrases.
2. For each toxic span, replace the entire span with a single <mask> token.
3. There may be multiple toxic spans in one sentence, so multiple <mask> tokens are allowed.
4. If several neighboring words are toxic, you must still use only a single <mask> token
   in that place. In other words, if you would place "<mask> <mask>" or a longer sequence
   of <mask> tokens, collapse them into a single <mask> so that there are never multiple
   <mask> tokens in a row.
5. Do NOT rewrite, paraphrase, or summarize the sentence.
6. Do NOT add, remove, or reorder non-toxic words or punctuation.
7. Keep punctuation and spacing as close to the original as possible.
8. If there is no toxic content, return the sentence unchanged.

Output rules (format is very strict):
- ONLY return the final masked sentence inside ONE pair of square brackets, like:
  [This is a <mask> example.]
- Do NOT print anything before or after the brackets.
- Do NOT add explanations, comments, or extra lines.
- Do NOT include any language tags or metadata.
- Do NOT include additional '[' or ']' characters inside the sentence.
"""

MASK_FEW_SHOT = """Toxic Sentence: You're such a stupid idiot, nobody wants to hear your crap.
Step 1 - Identify toxic words: "stupid idiot", "crap"
Step 2 - Mask toxic words (do NOT rewrite the rest):
You're such a <mask>, nobody wants to hear your <mask>.
Final Output: [You're such a <mask>, nobody wants to hear your <mask>.]"""

def _lazy_load_llm_masker():
    global _LLM_MODEL, _LLM_TOKENIZER
    if _LLM_MODEL is not None and _LLM_TOKENIZER is not None:
        return
    print(f"loading llm masker: {LLM_MODEL_NAME} on {DEVICE_LLM} ...")
    _LLM_TOKENIZER = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    _LLM_MODEL = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        torch_dtype=torch.float16 if DEVICE_LLM.type == "cuda" else torch.float32,
        device_map=None,
    ).to(DEVICE_LLM)
    _LLM_MODEL.eval()
    print("llm masker loaded")

def _extract_bracket_content(text: str) -> str:
    text = text.strip()

    m = re.search(r"\[([^\]]*)\]", text, flags=re.DOTALL)
    if m:
        return m.group(1).strip()

    if "[" in text:
        return text.split("[", 1)[1].strip()

    return text

def _postprocess_llm_mask(masked_text: str) -> str:
    s = masked_text.strip()

    if s.startswith("[") and s.endswith("]") and len(s) > 2:
        s = s[1:-1].strip()
    else:
        if s.startswith("["):
            s = s[1:].strip()
        if s.endswith("]"):
            s = s[:-1].strip()

    s = re.sub(r"\s+", " ", s).strip()
    s = re.sub(r"<\s*mask\s*>", "<mask>", s, flags=re.IGNORECASE)
    s = re.sub(r"(?:\s*<mask>\s*){2,}", " <mask> ", s)
    s = re.sub(r"\s+", " ", s).strip()

    if not s:
        return masked_text.strip()
    return s

@torch.no_grad()
def llm_mask_sentences(sentences: List[str]) -> List[str]:
    _lazy_load_llm_masker()
    masked = []
    for s in tqdm(sentences, desc="llm masking", leave=False):
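        # Mistral-7B-Instruct-v0.2 was not trained with a separate system role
        # (and some versions of its chat template reject one), so the
        # instructions and the worked example go into the single user turn.
        messages = [
            {
                "role": "user",
                "content": (
                    MASK_SYSTEM_PROMPT
                    + "\n\nBelow is an example:\n"
                    + MASK_FEW_SHOT
                    + f"\n\nToxic Sentence: {s}\nFinal Output:"
                ),
            },
        ]
        try:
            prompt = _LLM_TOKENIZER.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            add_special_tokens = False  # the chat template already inserts <s>
        except Exception:
            prompt = (
                MASK_SYSTEM_PROMPT
                + "\n\nExample:\n"
                + MASK_FEW_SHOT
                + "\n\nToxic Sentence: "
                + s
                + "\nFinal Output:"
            )
            add_special_tokens = True

        inputs = _LLM_TOKENIZER(
            prompt, return_tensors="pt", add_special_tokens=add_special_tokens
        ).to(DEVICE_LLM)
        # Greedy decoding. temperature is ignored when do_sample=False,
        # so it is not passed (temperature=0.0 would only trigger a warning).
        gen = _LLM_MODEL.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,
            pad_token_id=_LLM_TOKENIZER.eos_token_id,
        )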
        gen_text = _LLM_TOKENIZER.decode(
            gen[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        masked_text = _extract_bracket_content(gen_text)
        masked_text = _postprocess_llm_mask(masked_text)

        if not masked_text:
            masked_text = s
        masked.append(masked_text)

    return masked

In [None]:
def _decompx_mask_texts(texts: List[str],
                        threshold: float = 0.20,
                        batch_size: int = 16) -> List[str]:
    masker = Masker_single()
    masked_all = []
    for i in tqdm(range(0, len(texts), batch_size),
                  desc="decompx masking", leave=False):
        batch = texts[i:i + batch_size]
        batch_out = masker.process_text(sentence=batch, threshold=threshold)
        masked_all.extend(batch_out)
    cleaned = [
        m.replace("<s>", "").replace("</s>", "").strip()
        for m in masked_all
    ]
    masker.release_model()
    return cleaned

def _decompx_toxicity_scores(texts: List[str],
                             threshold: float = 0.20,
                             batch_size: int = 16) -> np.ndarray:
    if not texts:
        return np.zeros((0,), dtype=float)

    masked = _decompx_mask_texts(texts, threshold=threshold, batch_size=batch_size)
    scores = []
    for m in masked:
        num_masks = len(re.findall(r"<mask>", m))
        tokens = m.split()
        length = max(len(tokens), 1)
        scores.append(num_masks / length)
    return np.asarray(scores, dtype=float)

def rerank_candidates_decompx(
    sources: List[str],
    candidates: List[List[str]],
    threshold: float = 0.20,
    batch_size_mask: int = 16,
):
    N = len(sources)
    assert len(candidates) == N, "candidates length mismatch"

    if N == 0:
        return np.array([], dtype=int), {}

    C_list = [len(c) for c in candidates]
    assert len(set(C_list)) == 1, "All inputs must have same num_candidates"
    C = C_list[0]
    if C == 0:
        raise ValueError("num_candidates must be >= 1")

    flat_cands = []
    flat_src_idx = []
    for i, cand_list in enumerate(candidates):
        for cand in cand_list:
            flat_cands.append(cand)
            flat_src_idx.append(i)
    flat_src_idx = np.array(flat_src_idx, dtype=int)

    scores = _decompx_toxicity_scores(
        flat_cands,
        threshold=threshold,
        batch_size=batch_size_mask,
    )

    scores2 = scores.reshape(N, C)
    best_idx = np.argmin(scores2, axis=1)

    details = {
        "score": scores2,
    }
    return best_idx, details


In [None]:
def _bool2str(x: bool) -> str:
    return "T" if x else "F"

def _build_gen_folder_name(
    alpha_a, alpha_e, alpha_b,
    base_type, antiexpert_type, expert_type,
    temperature, sample, top_k_gen, rep_penalty, filter_p, max_length, top_p
):
    return (
        "aa" + str(alpha_a) +
        "_ae" + str(alpha_e) +
        "_ab" + str(alpha_b) +
        "_base" + base_type[:5] +
        "_anti" + antiexpert_type[:5] +
        "_expert" + expert_type[:5] +
        "_temp" + str(temperature) +
        "_sample" + _bool2str(sample) +
        "_topk" + str(top_k_gen) +
        "_reppenalty" + str(rep_penalty) +
        "_filterp" + str(filter_p) +
        "_maxlength" + str(max_length) +
        "_topp" + str(top_p)
    )

def _run_llm_masking_and_decompx_reranking(
    data_type,
    subset_path,
    base_out_rel,
    batch_size,
    alpha_a, alpha_e, alpha_b,
    temperature,
    rep_penalty,
    max_length,
    top_k_gen,
    top_p,
    filter_p,
    sample,
    num_candidates,
    decompx_threshold,
    overwrite_gen=False,
    inputs=None,
):
    if inputs is None:
        args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
        inputs = rx.get_data(args_data)
    print(f"inputs to detoxify: {len(inputs)}")

    mask_dir = "LLM_Masking_DecompX"
    cur_rel = os.path.join(base_out_rel, data_type, mask_dir)
    cur_abs = os.path.join(REPO, cur_rel)
    _ensure_dir(cur_abs)

    masked_file = os.path.join(cur_abs, "masked_inputs.txt")

    if not os.path.exists(masked_file):
        print("running llm masking to create masked_inputs.txt ...")
        decoded_mask_inputs = llm_mask_sentences(inputs)
        decoded_mask_inputs = [
            re.sub(r"\s+", " ", d).strip() for d in decoded_mask_inputs
        ]
        with open(masked_file, "w") as f:
            for d in decoded_mask_inputs:
                f.write(d + "\n")

        global _LLM_MODEL, _LLM_TOKENIZER
        del _LLM_MODEL
        del _LLM_TOKENIZER
        _LLM_MODEL = None
        _LLM_TOKENIZER = None
        if torch.cuda.is_available() and DEVICE_LLM.type == "cuda":
            torch.cuda.empty_cache()
    else:
        with open(masked_file, "r") as f:
            decoded_mask_inputs = [s.strip() for s in f.readlines()]
        print("reusing existing masked_inputs.txt")

    assert len(decoded_mask_inputs) == len(inputs), "Masked vs inputs mismatch"

    rewriter = Infiller(
        seed=0,
        base_path="facebook/bart-base",
        antiexpert_path="hallisky/bart-base-toxic-antiexpert",
        expert_path="hallisky/bart-base-nontoxic-expert",
        base_type="base",
        antiexpert_type="antiexpert",
        expert_type="expert",
        tokenizer="facebook/bart-base",
    )

    base_type = "base"
    antiexpert_type = "antiexpert"
    expert_type = "expert"
    gen_folder = _build_gen_folder_name(
        alpha_a, alpha_e, alpha_b,
        base_type, antiexpert_type, expert_type,
        temperature, sample, top_k_gen, rep_penalty, filter_p, max_length, top_p
    )
    final_abs = os.path.join(cur_abs, gen_folder)
    gen_txt = os.path.join(final_abs, "gen.txt")
    orig_txt = os.path.join(final_abs, "orig.txt")

    if os.path.exists(gen_txt) and not overwrite_gen:
        print("generation already exists at:", gen_txt, "— skipping")
        _ensure_dir(final_abs)
        if not os.path.exists(orig_txt):
            with open(orig_txt, "w") as f:
                for l in inputs:
                    f.write(re.sub(r"\s+", " ", l).strip() + "\n")

        with open(masked_file, "r") as f:
            decoded_mask_inputs = [s.strip() for s in f.readlines()]
        with open(gen_txt, "r") as f:
            best_generations = [s.strip() for s in f.readlines()]

        return inputs, decoded_mask_inputs, best_generations, final_abs

    _ensure_dir(final_abs)

    all_candidates: List[List[str]] = [[] for _ in range(len(inputs))]

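    print(f"generating {num_candidates} candidates per input (sampling={sample})")
    # one full generation pass per candidate; each pass appends one rewrite to every input's list.
    # with sample=False the passes are likely identical, so num_candidates > 1 mainly pays off when sampling.
    for _ in range(num_candidates):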
        outs, decoded = rewriter.generate(
            inputs,
            decoded_mask_inputs,
            alpha_a=alpha_a,
            alpha_e=alpha_e,
            alpha_b=alpha_b,
            temperature=temperature,
            verbose=False,
            max_length=max_length,
            repetition_penalty=rep_penalty,
            p=top_p,
            filter_p=filter_p,
            k=top_k_gen,
            batch_size=batch_size,
            sample=sample,
            ranking=False,
            ranking_eval_output=0,
        )
        for i, text in enumerate(decoded):
            all_candidates[i].append(re.sub(r"\s+", " ", text).strip())

    del rewriter
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print(f"decompx reranking (threshold={decompx_threshold:.2f}) ...")
    best_idx, details = rerank_candidates_decompx(
        sources=inputs,
        candidates=all_candidates,
        threshold=decompx_threshold,
        batch_size_mask=16,
    )
    best_generations = [
        all_candidates[i][best_idx[i]] for i in range(len(inputs))
    ]

    with open(orig_txt, "w") as f:
        for l in inputs:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")
    with open(gen_txt, "w") as f:
        for l in best_generations:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")

    print("saved:", orig_txt)
    print("saved:", gen_txt)

    return inputs, decoded_mask_inputs, best_generations, final_abs
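
The reranking call above delegates scoring to `rerank_candidates_decompx`, whose body is not shown in this section. The selection rule it applies (Equation (3): keep the candidate with the lowest DecompX score) can be sketched on its own. This is a minimal sketch under stated assumptions: `select_lowest_score` is a hypothetical name, and `score_fn` is a hypothetical stand-in for a DecompX-based scorer that returns one toxicity score per sentence (lower means less toxic). Candidates are flattened so the scorer can run in fixed-size batches across inputs, then the scores are regrouped per input.

```python
from typing import Callable, List, Sequence, Tuple


def select_lowest_score(
    candidates: Sequence[Sequence[str]],
    score_fn: Callable[[List[str]], List[float]],
    batch_size: int = 16,
) -> Tuple[List[int], List[List[float]]]:
    """Return, per input, the index of its lowest-scoring candidate plus all scores.

    score_fn is a hypothetical stand-in for a DecompX scorer: it takes a batch
    of sentences and returns one toxicity score per sentence.
    """
    # flatten every input's candidates so scoring runs in fixed-size batches
    flat = [c for cands in candidates for c in cands]
    flat_scores: List[float] = []
    for start in range(0, len(flat), batch_size):
        flat_scores.extend(score_fn(flat[start:start + batch_size]))

    # regroup the flat score list into one list per input
    per_input: List[List[float]] = []
    pos = 0
    for cands in candidates:
        per_input.append(flat_scores[pos:pos + len(cands)])
        pos += len(cands)

    # argmin per input (Equation (3)); ties resolve to the earliest candidate
    best = [min(range(len(s)), key=s.__getitem__) for s in per_input]
    return best, per_input


# usage with a toy lexicon scorer (NOT DecompX): counts blocklisted words
toy_lexicon = {"shit", "damn"}


def toy_score(batch: List[str]) -> List[float]:
    return [sum(w in toy_lexicon for w in s.lower().split()) for s in batch]


cands = [["what a damn mess", "what a mess"], ["nobody gave a shit", "nobody cared"]]
best, scores = select_lowest_score(cands, toy_score, batch_size=3)
print(best, scores)  # [1, 1] [[1, 0], [1, 0]]
```

Because `detoxify` rejects `num_candidates < 1`, every input has at least one candidate, so the per-input `min` never sees an empty list.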


In [None]:
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_llm_mask_decompx",
    echo: bool = False,
    batch_size: int = 10,
    sample: bool = True,
    top_k_gen: int = 50,
    top_p: float = 0.95,
    filter_p: float = 1.0,
    max_length: int = 128,
    alpha_a: float = None,
    alpha_e: float = None,
    alpha_b: float = 1.0,
    temperature: float = None,
    rep_penalty: float = None,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    num_candidates: int = 3,
    decompx_threshold: float = 0.20,
):
    assert data_type in data_configs, f"Unknown data_type: {data_type}"
    cfg = data_configs[data_type].copy()

    if num_candidates < 1:
        raise ValueError("num_candidates must be >= 1")

    if alpha_a is None:
        alpha_a = cfg["alpha_a"]
    if alpha_e is None:
        alpha_e = cfg["alpha_e"]
    if temperature is None:
        temperature = cfg["temperature"]
    if rep_penalty is None:
        rep_penalty = cfg["rep_penalty"]

    base_out_rel = os.path.join("data", "model_outputs", output_folder)
    base_out_abs = os.path.join(REPO, base_out_rel)
    _ensure_dir(base_out_abs)

    original_data_path = cfg["data_path"]
    subset_dir = os.path.join(REPO, "datasets", "_subsets", data_type)
    _ensure_dir(subset_dir)
    subset_path = _subset_for_data_type(
        data_type, original_data_path, num_examples, subset_dir
    )

    args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
    inputs = rx.get_data(args_data)
    num_inputs = len(inputs)

    if echo:
        print("=" * 80)
        print(f"dataset: {data_type}")
        print(f"subset path: {subset_path}")
        print(f"output base: {base_out_abs}")
        print(f"number of examples to detoxify: {num_inputs}")
        print(f"num_candidates per input: {num_candidates}")
        print(f"DecompX threshold (reranking): {decompx_threshold}")
        print("\nexample inputs (first up to 3):")
        for i, s in enumerate(inputs[:3]):
            print(f"  input[{i}]: {s}")
        print("=" * 80)

    inputs, masked_inputs, best_generations, run_dir = _run_llm_masking_and_decompx_reranking(
        data_type=data_type,
        subset_path=subset_path,
        base_out_rel=base_out_rel,
        batch_size=batch_size,
        alpha_a=alpha_a,
        alpha_e=alpha_e,
        alpha_b=alpha_b,
        temperature=temperature,
        rep_penalty=rep_penalty,
        max_length=max_length,
        top_k_gen=top_k_gen,
        top_p=top_p,
        filter_p=filter_p,
        sample=sample,
        num_candidates=num_candidates,
        decompx_threshold=decompx_threshold,
        overwrite_gen=overwrite_gen,
        inputs=inputs,
    )

    if echo:
        print("\nexample masked inputs (first up to 3):")
        for i, m in enumerate(masked_inputs[:3]):
            print(f"  masked[{i}]: {m}")

        print("\nexample detoxified outputs (first up to 3):")
        for i in range(min(3, len(best_generations))):
            print(f"  detox[{i}]: {best_generations[i]}")

    if run_eval:
        mask_dir = "LLM_Masking_DecompX"
        base_path = os.path.join(base_out_abs, data_type, mask_dir)
        _eval_with_toxicity(
            base_path,
            overwrite_eval=overwrite_eval,
            skip_ref=skip_ref_eval,
            tox_threshold=0.5,
            tox_batch_size=32,
        )
        _aggregate_eval_csv(
            output_folder,
            data_type,
            base_out_abs,  # same as os.path.join(REPO, "data", "model_outputs", output_folder)
        )

        if echo:
            stats_path = os.path.join(run_dir, "gen_stats.txt")
            if os.path.exists(stats_path):
                stats = _read_stats_file(stats_path)
                print("\nevaluation metrics for this run:")
                metric_keys = [
                    ("bertscore", "bertscore"),
                    ("meaningbert", "meaningbert"),
                    ("bleu4", "bleu-4"),
                    ("perplexity gen", "perplexity (gen)"),
                    ("perplexity orig", "perplexity (orig)"),
                    ("toxicity gen", "toxicity (gen)"),
                    ("toxicity orig", "toxicity (orig)"),
                ]
                for key, label in metric_keys:
                    val = stats.get(key)
                    # skip metrics that are missing or NaN in gen_stats.txt
                    if val is None or (isinstance(val, float) and math.isnan(val)):
                        continue
                    print(f"  {label}: {val:.4f}")
            else:
                print("\ngen_stats.txt not found for this run")
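
`detoxify` treats `None` for `alpha_a`, `alpha_e`, `temperature`, and `rep_penalty` as "use the dataset default from `data_configs`". A minimal sketch of that fallback is below. The dictionary is a hypothetical stand-in for `data_configs["paradetox"]`: its values are read off the run folder name in the log further down (`aa1.5`, `ae4.75`, `temp2.5`, `reppenalty1.0`), since the full config entry is not shown in this section, and `resolve_hparams` is an illustrative helper rather than a notebook function.

```python
from typing import Any, Dict, Optional

# hypothetical stand-in for data_configs["paradetox"]; values taken from the
# run folder name in the log (aa1.5, ae4.75, temp2.5, reppenalty1.0)
PARADETOX_DEFAULTS: Dict[str, float] = {
    "alpha_a": 1.5,
    "alpha_e": 4.75,
    "temperature": 2.5,
    "rep_penalty": 1.0,
}


def resolve_hparams(cfg: Dict[str, Any], **overrides: Optional[float]) -> Dict[str, float]:
    """Return cfg values, replacing each one whose override is not None."""
    resolved: Dict[str, float] = {}
    for key, default in cfg.items():
        value = overrides.get(key)
        # None means "not set by the caller", so fall back to the dataset default
        resolved[key] = default if value is None else value
    return resolved


print(resolve_hparams(PARADETOX_DEFAULTS, alpha_a=None, temperature=1.0))
# {'alpha_a': 1.5, 'alpha_e': 4.75, 'temperature': 1.0, 'rep_penalty': 1.0}
```

Testing with `is None` (as `detoxify` does) instead of truthiness means an explicit `0.0`, for example `alpha_e=0.0` to switch off the expert, is respected rather than silently replaced by the default.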



In [None]:
# example: small demo (adjust num_examples as you like)
# detoxify(
#     data_type="paradetox",
#     output_folder="colab_run_llm_mask_decompx_demo_50_examples",
#     echo=True,
#     batch_size=8,
#     sample=True,
#     top_k_gen=50,
#     top_p=0.95,
#     max_length=96,
#     num_examples=50,
#     run_eval=True,
#     overwrite_gen=True,
#     overwrite_eval=True,
#     skip_ref_eval=False,
#     num_candidates=10,
#     decompx_threshold=0.20,
# )
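
Run cost grows linearly with `num_candidates`: the infiller makes one full pass over all inputs per candidate, and DecompX then scores every candidate. The sketch below estimates batch counts before launching a run. It assumes ceiling-division batching, with `batch_size` for generation and `batch_size_mask=16` for DecompX, and it assumes `rerank_candidates_decompx` scores the flattened list of all candidates; `batch_counts` is a hypothetical helper. For the full run below (the log reports 671 inputs even though `num_examples=1000`), `batch_size=8` and `num_candidates=10` give 84 batches per pass, matching the `84/84` "Filling in masks" bars.

```python
import math
from typing import Tuple


def batch_counts(
    num_inputs: int, num_candidates: int, gen_batch_size: int, mask_batch_size: int = 16
) -> Tuple[int, int, int]:
    """Estimate (batches per generation pass, total generation batches, DecompX batches)."""
    per_pass = math.ceil(num_inputs / gen_batch_size)
    total_gen = per_pass * num_candidates  # one full generation pass per candidate
    # assumption: DecompX scores all num_inputs * num_candidates sentences in flat batches
    decompx = math.ceil(num_inputs * num_candidates / mask_batch_size)
    return per_pass, total_gen, decompx


print(batch_counts(671, 10, 8))  # (84, 840, 420)
```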


In [30]:
detoxify(
    data_type="paradetox",
    output_folder="XDetox_w_LLM-Masking_DecompX-Reranking_Pipeline",
    echo=True,
    batch_size=8,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=1000,
    run_eval=True,             # BLEU/BERTScore/MeaningBERT/PPL/Toxicity
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    num_candidates=10,
    decompx_threshold=0.20
)


[echo] Dataset: paradetox
[echo] Subset path: /content/drive/MyDrive/w266 - Project/XDetox/datasets/_subsets/paradetox/test_toxic_parallel.txt
[echo] Output base: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_DecompX-Reranking_Pipeline
[echo] Number of examples to detoxify: 671
[echo] num_candidates per input: 10
[echo] DecompX threshold (reranking): 0.2

[echo] Example inputs (first up to 3):
  input[0]: . or the loud ass one - thousand ton beast roaring towards you howling its horn .
  input[1]: mandated  and " right fucking now " would be good .
  input[2]: * neither * of my coworkers gave a shit when it came time to ditch mitch . ugh .
#inputs to detoxify: 671
Running LLM masking (Mistral) to create masked_inputs.txt ...
Loading LLM masker: mistralai/Mistral-7B-Instruct-v0.2 on cuda ...
LLM masker loaded.
Found 1 GPUS!
Generating 10 candidates per input (sampling=True)


Filling in masks: 100%|██████████| 84/84 [01:33<00:00,  1.11s/it]
Filling in masks: 100%|██████████| 84/84 [01:40<00:00,  1.19s/it]
Filling in masks: 100%|██████████| 84/84 [01:46<00:00,  1.27s/it]
Filling in masks: 100%|██████████| 84/84 [01:30<00:00,  1.07s/it]
Filling in masks: 100%|██████████| 84/84 [01:51<00:00,  1.33s/it]
Filling in masks: 100%|██████████| 84/84 [01:42<00:00,  1.22s/it]
Filling in masks: 100%|██████████| 84/84 [01:24<00:00,  1.00s/it]
Filling in masks: 100%|██████████| 84/84 [01:43<00:00,  1.23s/it]
Filling in masks: 100%|██████████| 84/84 [01:35<00:00,  1.14s/it]
Filling in masks: 100%|██████████| 84/84 [01:33<00:00,  1.11s/it]


DecompX reranking (threshold=0.20) ...


Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_DecompX-Reranking_Pipeline/paradetox/LLM_Masking_DecompX/aa1.5_ae4.75_ab1.0_basebase_antiantie_expertexper_temp2.5_sampleT_topk50_reppenalty1.0_filterp1.0_maxlength96_topp0.95/orig.txt
Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_DecompX-Reranking_Pipeline/paradetox/LLM_Masking_DecompX/aa1.5_ae4.75_ab1.0_basebase_antiantie_expertexper_temp2.5_sampleT_topk50_reppenalty1.0_filterp1.0_maxlength96_topp0.95/gen.txt

[echo] Example masked inputs (first up to 3):
  masked[0]: . or the loud <mask> one - thousand ton beast roaring towards you howling its horn .
  masked[1]: mandated and "<mask> right <mask> now <mask> " would be good .
  masked[2]: neither of my coworkers gave a <mask> when it came time to ditch mitch . ugh .

[echo] Example detoxified outputs (first up to 3):
  detox[0]: . or the loud sound of a one- thousand ton beast roaring towards you h