# xdetox with decompx masking, llm infilling, and global reranking

This notebook runs an XDetox variant with:

1. **DecompX masking** (token-level toxicity attribution on RoBERTa).
2. **LLM infilling** using Mistral-7B-Instruct  
   (`mistralai/Mistral-7B-Instruct-v0.2`).
3. **Global reranking** of LLM candidates using:
   - **Toxicity** (XLM-R large classifier).
   - **Semantic similarity** (LaBSE).
   - **Fluency** (GPT-2 perplexity).

For each toxic input sentence, the goal is to choose **one best detoxified candidate** that is:

- As **non-toxic** as possible.
- As **semantically close** as possible to the original.
- As **fluent** as possible.

Compared to the original XDetox pipeline (DecompX masking + MaRCo BART infilling + global reranking), this notebook:

- Keeps **DecompX masking**,
- Replaces **MaRCo BART generation** with **LLM infilling**,
- Keeps the **same global reranking** scheme.

---

## scoring: global reranking

For each candidate $c$, we compute:

- $T(c)$: toxicity in $[0, 1]$ from  
  `textdetox/xlmr-large-toxicity-classifier-v2`.
- $S(c)$: semantic similarity in $[0, 1]$, from LaBSE cosine similarity rescaled from $[-1, 1]$ as $(\cos + 1) / 2$.
- $F(c)$: fluency in $[0, 1]$, from GPT-2 perplexity clipped to fixed bounds and mapped on a log scale  
  (low perplexity → high fluency).

We convert toxicity into a **safety** score:

$$
T'(c) = 1 - T(c)
$$

Then we form a **global score**:

$$
\text{Score}(c) = w_T \cdot T'(c) + w_S \cdot S(c) + w_F \cdot F(c)
$$

You control the weights:

- `weights = (w_T, w_S, w_F)`  
  - `w_T`: importance of **safety** (low toxicity).  
  - `w_S`: importance of **semantic similarity**.  
  - `w_F`: importance of **fluency**.

For each input sentence, we:

1. Generate `num_candidates` candidates with the LLM.
2. Score each candidate with the formula above.
3. Select the **highest-scoring** candidate as the final output.
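
As a rough illustration of the selection rule, the sketch below scores three candidates for one input with the default weights `(0.5, 0.3, 0.2)`. The `(T, S, F)` values and the names `global_score` and `cand_a`/`cand_b`/`cand_c` are made up for the example; in the notebook the values come from the XLM-R, LaBSE, and GPT-2 scorers.

```python
def global_score(tox, sim, flu, weights=(0.5, 0.3, 0.2)):
    """Weighted sum of safety (1 - toxicity), similarity, and fluency."""
    w_t, w_s, w_f = weights
    return w_t * (1.0 - tox) + w_s * sim + w_f * flu

# Hypothetical (T, S, F) triples for three candidates of one input.
candidates = {
    "cand_a": (0.80, 0.95, 0.70),  # close to the source but still toxic
    "cand_b": (0.05, 0.85, 0.75),  # safe and close to the source
    "cand_c": (0.02, 0.40, 0.90),  # safe and fluent but drifts in meaning
}

scores = {name: global_score(*tsf) for name, tsf in candidates.items()}
best = max(scores, key=scores.get)
print(best, round(scores[best], 3))
```

With these numbers `cand_b` wins: it gives up a little similarity compared to `cand_a` but gains far more on safety, which carries the largest weight.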

---

## decompx masking

Masking uses the original **DecompX** component from XDetox:

- Model: `rewrite.mask_orig.Masker` (RoBERTa + DecompX).
- For each input sentence, DecompX assigns a toxicity **importance score** to each token.
- Tokens whose importance exceeds a **threshold** are treated as toxic.
- All such tokens are replaced with `<mask>`.
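
The thresholding step can be pictured with a toy example. This is only a sketch of the idea, not the `Masker` implementation: the function name `mask_tokens`, the word-level tokens, and the importance scores are all invented here, and the real component works on RoBERTa tokens with DecompX attributions.

```python
def mask_tokens(tokens, importances, threshold=0.20, mask_token="<mask>"):
    """Replace every token whose importance exceeds the threshold."""
    if len(tokens) != len(importances):
        raise ValueError("tokens and importances must have the same length")
    return [
        mask_token if score > threshold else tok
        for tok, score in zip(tokens, importances)
    ]

tokens = ["you", "are", "a", "stupid", "person"]
importances = [0.01, 0.02, 0.01, 0.65, 0.05]  # made-up attribution scores

masked = " ".join(mask_tokens(tokens, importances, threshold=0.20))
print(masked)  # you are a <mask> person
```

Lowering the threshold masks more tokens, which removes more toxic content but leaves the infiller less of the original sentence to preserve.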

For each dataset:

1. Input toxic sentences are loaded via `rewrite_example.get_data`.
2. For each threshold value `t` in `thresholds`, we run DecompX masking and write:

```text
   data/model_outputs/{output_folder}/{data_type}/DecompX{t}/masked_inputs.txt
```

where `DecompX{t}` is the masking directory for threshold `t`.

The same DecompX thresholds are later used to organize output folders and evaluation.

---

## llm infilling (mistral-7b-instruct)

After DecompX masking, we use **Mistral-7B-Instruct** as an **infiller**, not as a masker.

### prompted infilling behavior

For each example we pass two strings to the LLM:

* **Toxic Sentence**: original input (possibly toxic).
* **Masked Sentence**: the same sentence, but with toxic spans replaced by `<mask>` by DecompX.

The system prompt and few-shot example instruct the model to:

* **Only replace `<mask>` tokens** in the Masked Sentence with polite, non-toxic alternatives.
* Keep **all other words and punctuation** in the Masked Sentence unchanged, except small grammar fixes needed after infilling.
* **Preserve the meaning and intent** of the Toxic Sentence as much as possible.
* Use **the same language** as the Toxic Sentence.
* If there is **no `<mask>` token**, return the Masked Sentence unchanged.
* Return the final sentence inside exactly one pair of brackets:

```text
[Detoxified sentence here.]
```

### candidate generation

For each $(\text{toxic}, \text{masked})$ pair:

* We build a chat-style prompt:

```text
Toxic Sentence: {raw_toxic}
Masked Sentence: {masked_by_DecompX}
Final Output:
```

* We call the Mistral model with:

  * `num_return_sequences = num_candidates`,
  * `do_sample = llm_sample`,
  * `temperature = llm_temperature` (if sampling),
  * `top_p = llm_top_p`,
  * `max_new_tokens = max_new_tokens`.

* For each returned sequence, we:

  1. **Extract content inside the first `[ ... ]` block.**
  2. **Strip any remaining outer brackets.**
  3. **Normalize whitespace** and drop any `<mask>` tokens the model left unfilled.

If a cleaned candidate is empty, we fall back to the original Toxic Sentence.

The result is a list of `num_candidates` LLM-generated detoxified sentences per input.
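
A simplified, standalone version of the cleanup step might look like the sketch below. It is a stand-in for the notebook's own helpers rather than a copy of them; the function name `extract_candidate` and its `fallback` parameter are hypothetical.

```python
import re

def extract_candidate(raw: str, fallback: str) -> str:
    """Pull the text inside the first [...] block and tidy whitespace."""
    match = re.search(r"\[([^\]]*)\]", raw, flags=re.DOTALL)
    text = match.group(1) if match else raw
    text = text.strip().strip("[]")           # drop stray outer brackets
    text = re.sub(r"\s+", " ", text).strip()  # collapse whitespace runs
    return text if text else fallback

raw_output = "Final Output: [You're such a  rude person,\nnobody wants to hear your opinion.]"
print(extract_candidate(raw_output, fallback="(fallback sentence)"))
```

Anything the model writes outside the brackets (such as a repeated `Final Output:` label) is discarded, and an empty bracket pair triggers the fallback.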

---

## global reranking of llm candidates

After infilling, we apply **global reranking**:

1. **Flatten candidates** into a list and compute:

   * XLM-R toxicity $T(c)$,
   * LaBSE embeddings for sources and candidates → similarity $S(c)$,
   * GPT-2 perplexity → fluency score $F(c)$.

2. Convert toxicity to safety $T'(c) = 1 - T(c)$ and compute:

   $$\text{Score}(c) = w_T \cdot T'(c) + w_S \cdot S(c) + w_F \cdot F(c)$$

3. Reshape scores back to an $N \times C$ matrix (inputs × candidates).

4. For each input, choose the candidate with **maximum** score.
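
The flatten, score, reshape, and select steps can be sketched in plain Python. The notebook does this with NumPy (`reshape` followed by `argmax`); the version below is an equivalent illustration with made-up scores, and `pick_best` is a hypothetical name.

```python
def pick_best(flat_scores, num_inputs, num_candidates):
    """Return, for each input, the index of its highest-scoring candidate."""
    if len(flat_scores) != num_inputs * num_candidates:
        raise ValueError("flat_scores has the wrong length")
    best = []
    for i in range(num_inputs):
        row = flat_scores[i * num_candidates:(i + 1) * num_candidates]
        best.append(max(range(num_candidates), key=row.__getitem__))
    return best

# Two inputs, three candidates each, scores stored input by input.
flat = [0.52, 0.88, 0.79,   0.91, 0.60, 0.91]
print(pick_best(flat, num_inputs=2, num_candidates=3))  # [1, 0]
```

Ties resolve to the first candidate with the maximum score, matching the behavior of `argmax`.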

For each threshold `t` and LLM setting, we create a run folder:

```text
data/model_outputs/{output_folder}/{data_type}/DecompX{t}/{run_folder}/
```

and write:

* `orig.txt` — original toxic inputs (one per line),
* `gen.txt` — selected LLM generations (one per line).

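`{run_folder}` encodes the LLM hyperparameters (temperature, sampling flag, top-p, `max_new_tokens`, number of candidates). It is built by `_build_llm_gen_folder_name`, for example `llm_temp0.7_sampleT_topp0.95_maxnew64_ncand10`.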

---

## evaluation

If `run_eval=True`, the pipeline calls `evaluation.evaluate_all` to compute:

* BERTScore (F1)
* MeaningBERT
* BLEU-4
* Toxicity (orig / gen) with XLM-R
* Perplexity (orig / gen) with GPT-2

For each run folder, it writes:

```text
data/model_outputs/{output_folder}/{data_type}/DecompX{t}/{run_folder}/gen_stats.txt
```

The notebook also builds a **summary CSV per dataset** by scanning all `DecompX*` folders:

```text
data/model_outputs/{output_folder}/{data_type}/{data_type}.csv
```

This CSV includes:

* `threshold` (the DecompX masking threshold `t`),
* `folder` (run folder name),
* `bertscore`, `meaningbert`, `bleu4`,
* `perplexity_gen`, `perplexity_orig`,
* `toxicity_gen`, `toxicity_orig`.
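
Once the summary CSV exists, comparing thresholds takes only a few lines. The sketch below uses an in-memory CSV with the same column layout; every number in it is a placeholder invented for the example, not a measured result, and `run_a` is a hypothetical folder name.

```python
import csv
import io

# Placeholder rows in the same column layout as the summary CSV.
sample_csv = """threshold,folder,bertscore,meaningbert,bleu4,perplexity_gen,perplexity_orig,toxicity_gen,toxicity_orig
0.15,run_a,0.93,70.1,55.2,180.4,240.9,0.08,0.92
0.2,run_a,0.95,74.3,60.8,175.0,240.9,0.11,0.92
0.25,run_a,0.96,78.0,66.5,190.2,240.9,0.19,0.92
"""

rows = list(csv.DictReader(io.StringIO(sample_csv)))
# Pick the threshold whose outputs are least toxic on average.
least_toxic = min(rows, key=lambda r: float(r["toxicity_gen"]))
print(least_toxic["threshold"], least_toxic["toxicity_gen"])
```

To read a real summary file, replace `io.StringIO(sample_csv)` with an open file handle on the `{data_type}.csv` path.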

---

## how to use `detoxify()`

Function signature (conceptual):

```python
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_decompx_mask_llm_global",
    thresholds = (0.20,),
    echo: bool = False,
    num_examples: int = 100,        # None = full dataset
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    # DecompX masking
    mask_batch_size: int = 10,
    # Global reranking
    weights = (0.5, 0.3, 0.2),      # (w_T, w_S, w_F)
    # LLM infilling
    llm_sample: bool = True,
    llm_temperature: float = 0.7,
    llm_top_p: float = 0.95,
    max_new_tokens: int = 64,
    num_candidates: int = 3,
)
```

### key arguments

#### core i/o

* `data_type`: dataset key from `data_configs`, for example:

  * `"paradetox"`, `"dynabench_val"`, `"dynabench_test"`,
  * `"jigsaw_toxic"`, `"microagressions_val"`, `"sbf_val"`,
  * `"appdia_original"`, `"appdia_discourse"`, etc.

* `output_folder`: top-level outputs directory:

  ```text
  data/model_outputs/{output_folder}/{data_type}/...
  ```

* `num_examples`:

  * `None`: use the full dataset.
  * Integer: run only on the first `num_examples` entries.

* `overwrite_gen`:

  * `False`: if `gen.txt` already exists for a given `(t, run_folder)`, reuse existing generations.
  * `True`: regenerate candidates and overwrite `gen.txt`.

* `echo`:

  * If `True`, print:

    * Dataset and paths,
    * Example inputs,
    * Example masked sentences,
    * Example detoxified outputs,
    * Per-run metrics (if `run_eval=True`).

#### decompx masking

* `thresholds`: a tuple of DecompX thresholds, for example:

  ```python
  thresholds = (0.15, 0.20, 0.25)
  ```

  For each `t`, the notebook creates a separate `DecompX{t}` directory.

* `mask_batch_size`: batch size for DecompX masking.
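
One detail worth knowing when looking for outputs on disk: the masking cell further down formats the threshold with the `:g` float format, so trailing zeros are dropped from the directory name. The sketch below mirrors that expression; `mask_dir_name` is a hypothetical wrapper added for illustration.

```python
def mask_dir_name(threshold: float) -> str:
    """Mirror of the directory naming used for each masking threshold."""
    return f"DecompX{abs(threshold):g}" if threshold != 0 else "DecompX0.0"

for t in (0.15, 0.20, 0.25):
    print(t, "->", mask_dir_name(t))
```

A threshold written as `0.20` therefore ends up in `DecompX0.2`, not `DecompX0.20`.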

#### llm infilling

* `llm_sample`:

  * `True`: stochastic sampling.
  * `False`: deterministic (greedy-like) decoding.

* `llm_temperature`: sampling temperature for Mistral (used when `llm_sample=True`).

* `llm_top_p`: nucleus sampling cutoff.

* `max_new_tokens`: maximum number of new tokens generated by Mistral for each candidate.

* `num_candidates`: number of LLM candidates per input. Larger values give the reranker more options to choose from but increase generation and scoring cost.

#### global reranking

* `weights = (w_T, w_S, w_F)`:

  * `w_T`: weight of safety (1 − toxicity).
  * `w_S`: weight of semantic similarity.
  * `w_F`: weight of fluency.

The helper `rerank_candidates_global` uses fixed GPT-2 perplexity bounds `ppl_min` and `ppl_max` (defaults 5.0 and 300.0) to map perplexity to a $[0, 1]$ fluency score.
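
To get a feel for the perplexity-to-fluency mapping, here is a small standalone sketch of a log-scale mapping with the default bounds. The function name `fluency_from_perplexity` is hypothetical, and unlike the notebook's NumPy helper it omits the small epsilon in the denominator.

```python
import math

def fluency_from_perplexity(ppl, p_min=5.0, p_max=300.0):
    """Map perplexity to [0, 1] on a log scale; lower perplexity scores higher."""
    p = min(max(ppl, p_min), p_max)  # clip into [p_min, p_max]
    return (math.log(p_max) - math.log(p)) / (math.log(p_max) - math.log(p_min))

for ppl in (3.0, 5.0, 40.0, 300.0, 5000.0):
    print(ppl, round(fluency_from_perplexity(ppl), 3))
```

Perplexities at or below the lower bound all score 1, those at or above the upper bound all score 0, and the log scale spreads the values in between so that differences among fairly fluent sentences still register.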

#### evaluation

* `run_eval`: if `True`, run evaluation and write `gen_stats.txt` per run folder.
* `overwrite_eval`: if `True`, recompute metrics even if `gen_stats.txt` already exists.
* `skip_ref_eval`: if `True`, skip reference-based metrics that require gold outputs.

---

## example calls

### quick sanity check (small subset, one threshold)

```python
detoxify(
    data_type="paradetox",
    output_folder="colab_run_decompx_mask_llm_global_demo_50_ex",
    thresholds=(0.20,),
    echo=True,
    num_examples=50,         # small subset
    overwrite_gen=True,
    run_eval=True,           # BLEU / BERTScore / MeaningBERT / PPL / Toxicity
    overwrite_eval=True,
    skip_ref_eval=False,
    mask_batch_size=8,
    weights=(0.5, 0.3, 0.2),
    llm_sample=True,
    llm_temperature=0.7,
    llm_top_p=0.95,
    max_new_tokens=64,
    num_candidates=10,
)
```

### larger run (multiple thresholds)

```python
detoxify(
    data_type="paradetox",
    output_folder="paradetox_decompx_mask_llm_global_full",
    thresholds=(0.15, 0.20, 0.25),
    echo=True,
    num_examples=None,       # full dataset
    overwrite_gen=False,
    run_eval=True,
    overwrite_eval=False,
    skip_ref_eval=False,
    mask_batch_size=8,
    weights=(0.5, 0.3, 0.2),
    llm_sample=True,
    llm_temperature=0.7,
    llm_top_p=0.95,
    max_new_tokens=64,
    num_candidates=10,
)
```

After running `detoxify`, check:

* Per-threshold, per-run files:

  ```text
  data/model_outputs/{output_folder}/{data_type}/DecompX{t}/{run_folder}/orig.txt
  data/model_outputs/{output_folder}/{data_type}/DecompX{t}/{run_folder}/gen.txt
  data/model_outputs/{output_folder}/{data_type}/DecompX{t}/{run_folder}/gen_stats.txt
  ```

* Aggregated metrics:

  ```text
  data/model_outputs/{output_folder}/{data_type}/{data_type}.csv
  ```

In [None]:
from google.colab import drive; drive.mount('/content/drive')

import os, glob, re, sys, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE
import torch
import nltk
from typing import List

# try my drive
candidate = "/content/drive/MyDrive/w266 - Project/XDetox"
print("try mydrive:", candidate, "->", os.path.isdir(candidate))

XDETOX_DIR = candidate
print("using xdetox_dir:", XDETOX_DIR)
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"

In [None]:
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.makedirs(HF_CACHE, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE

if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("xdetox_dir:", XDETOX_DIR)
print("transformers_cache:", HF_CACHE)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("gpu:", torch.cuda.get_device_name(0))

In [None]:
for d in ["rewrite", "evaluation", "datasets"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("repo folders ok")

In [None]:
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi \
                sentencepiece
!pip -q install bert-score

In [None]:
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    GPT2LMHeadModel, GPT2TokenizerFast,
)

from rewrite.mask_orig import Masker as Masker_single
from rewrite import rewrite_example as rx
import argparse as _argparse

In [None]:
nltk.download("punkt", quiet=True)
try:
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass
print("loaded")

In [None]:
## data config

data_configs = {
    "microagressions_val": {
        "data_path": "./datasets/microagressions/val.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "microagressions_test": {
        "data_path": "./datasets/microagressions/test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "sbf_val": {
        "data_path": "./datasets/sbf/sbfdev.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "sbf_test": {
        "data_path": "./datasets/sbf/sbftst.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "dynabench_val": {
        "data_path": "./datasets/dynabench/db_dev.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "dynabench_test": {
        "data_path": "./datasets/dynabench/db_test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "jigsaw_toxic": {
        "data_path": "./datasets/jigsaw_full_30/test_10k_toxic.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "paradetox": {
        "data_path": "./datasets/paradetox/test_toxic_parallel.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_original": {
        "data_path": "./datasets/appdia/original-annotated-data/original-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_discourse": {
        "data_path": "./datasets/appdia/discourse-augmented-data/discourse-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    }
}
print("datasets:", ", ".join(data_configs.keys()))

REPO = XDETOX_DIR

In [None]:
## helpers

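def _abs_repo_path(rel: str) -> str:
    # Drop only a literal "./" prefix (lstrip("./") would strip any leading "." or "/").
    rel = rel[2:] if rel.startswith("./") else rel
    return os.path.join(REPO, rel)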

def _ensure_dir(p: str):
    Path(p).mkdir(parents=True, exist_ok=True)

def _subset_for_data_type(data_type, data_path, n, out_dir):
    """Write the first n examples to out_dir and return the new path (data_path itself if n is None or <= 0)."""
    if n is None or n <= 0:
        return data_path

    src = _abs_repo_path(data_path)
    _ensure_dir(out_dir)

    # CSV-backed datasets: keep the first n rows.
    if any(k in data_path for k in ("microagressions", "sbf", "dynabench")):
        df = pd.read_csv(src)
        sub = df.head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        sub.to_csv(out, index=False)
        return out

    if any(k in data_path for k in ["paradetox", "jigsaw"]):
        if data_path.endswith(".txt"):
            with open(src, "r") as f:
                lines = [s.rstrip("\n") for s in f.readlines()]
            out = os.path.join(out_dir, os.path.basename(src))
            with open(out, "w") as g:
                for s in lines[:n]:
                    g.write(s + "\n")
            return out
        elif data_path.endswith(".csv"):
            df = pd.read_csv(src).head(n)
            out = os.path.join(out_dir, os.path.basename(src))
            df.to_csv(out, index=False)
            return out

    if "appdia" in data_path:
        df = pd.read_csv(src, sep="\t").head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        df.to_csv(out, sep="\t", index=False)
        return out

    out = os.path.join(out_dir, os.path.basename(src))
    shutil.copy(src, out)
    return out

In [None]:
## scoring

DEVICE_SCORE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"scoring models will use: {DEVICE_SCORE}")

_TOX_MODEL_NAME = "textdetox/xlmr-large-toxicity-classifier-v2"
_TOX_TOKENIZER = None
_TOX_MODEL = None

def _lazy_load_tox():
    global _TOX_TOKENIZER, _TOX_MODEL
    if _TOX_TOKENIZER is None or _TOX_MODEL is None:
        _TOX_TOKENIZER = AutoTokenizer.from_pretrained(_TOX_MODEL_NAME)
        _TOX_MODEL = AutoModelForSequenceClassification.from_pretrained(
            _TOX_MODEL_NAME
        ).to(DEVICE_SCORE)
        _TOX_MODEL.eval()

@torch.no_grad()
def get_toxicity_scores(texts, batch_size=32):
    _lazy_load_tox()
    scores = []
    for i in tqdm(range(0, len(texts), batch_size), desc="toxicity", leave=False):
        batch = texts[i:i+batch_size]
        enc = _TOX_TOKENIZER(
            batch, return_tensors="pt",
            truncation=True, max_length=512, padding=True
        ).to(DEVICE_SCORE)
        logits = _TOX_MODEL(**enc).logits
        probs = torch.softmax(logits, dim=-1)
        scores.extend(probs[:, 1].detach().cpu().tolist())
    return scores

_LABSE_NAME = "sentence-transformers/LaBSE"
_LABSE_TOKENIZER = None
_LABSE_MODEL = None

def _lazy_load_labse():
    global _LABSE_TOKENIZER, _LABSE_MODEL
    if _LABSE_TOKENIZER is None or _LABSE_MODEL is None:
        _LABSE_TOKENIZER = AutoTokenizer.from_pretrained(_LABSE_NAME)
        _LABSE_MODEL = AutoModel.from_pretrained(_LABSE_NAME).to(DEVICE_SCORE)
        _LABSE_MODEL.eval()

@torch.no_grad()
def get_labse_embeddings(texts, batch_size=32):
    _lazy_load_labse()
    embs = []
    for i in tqdm(range(0, len(texts), batch_size), desc="labse", leave=False):
        batch = texts[i:i+batch_size]
        enc = _LABSE_TOKENIZER(
            batch, return_tensors="pt",
            truncation=True, max_length=256, padding=True
        ).to(DEVICE_SCORE)
        outputs = _LABSE_MODEL(**enc)
        hidden = outputs.last_hidden_state
        mask = enc["attention_mask"].unsqueeze(-1)
        masked = hidden * mask
        summed = masked.sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-6)
        sent_emb = (summed / counts).cpu().numpy()
        embs.append(sent_emb)
    if not embs:
        return np.zeros((0, 768), dtype=np.float32)
    return np.vstack(embs)

_GPT2_NAME = "gpt2"
_GPT2_TOKENIZER = None
_GPT2_MODEL = None

def _lazy_load_gpt2():
    global _GPT2_TOKENIZER, _GPT2_MODEL
    if _GPT2_TOKENIZER is None or _GPT2_MODEL is None:
        _GPT2_TOKENIZER = GPT2TokenizerFast.from_pretrained(_GPT2_NAME)
        _GPT2_MODEL = GPT2LMHeadModel.from_pretrained(_GPT2_NAME).to(DEVICE_SCORE)
        _GPT2_MODEL.eval()

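@torch.no_grad()
def get_gpt2_perplexities(texts, max_ppl=1e4):
    """Per-sentence GPT-2 perplexity, capped at max_ppl."""
    _lazy_load_gpt2()
    ppls = []
    for s in tqdm(texts, desc="gpt2 ppl", leave=False):
        enc = _GPT2_TOKENIZER(
            s, return_tensors="pt", truncation=True, max_length=1024
        ).to(DEVICE_SCORE)
        # The LM loss needs at least two tokens (one prediction target);
        # shorter inputs get the worst score instead of NaN or a crash.
        if enc["input_ids"].shape[1] < 2:
            ppls.append(float(max_ppl))
            continue
        out = _GPT2_MODEL(enc["input_ids"], labels=enc["input_ids"])
        ppl = math.exp(out.loss.item())
        ppls.append(float(min(ppl, max_ppl)))
    return ppls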

def perplexity_to_fluency(ppls, p_min=5.0, p_max=300.0):
    import math as _math
    ppls = np.asarray(ppls, dtype=float)
    p = np.clip(ppls, p_min, p_max)
    log_p = np.log(p)
    log_min = _math.log(p_min)
    log_max = _math.log(p_max)
    F = (log_max - log_p) / (log_max - log_min + 1e-8)
    F = np.clip(F, 0.0, 1.0)
    return F

In [None]:
## evaluation helpers

def _parse_run_folder_name(folder_name):
    """accept all subdirs"""
    return True

def _eval_with_toxicity(base_path, overwrite_eval=False, skip_ref=False,
                        tox_threshold=0.5, tox_batch_size=32):
    """Run evaluation.evaluate_all on every run folder under base_path that has orig.txt and gen.txt."""
    import sys as _sys, os as _os
    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir) or not _parse_run_folder_name(folder):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path  = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue

        env = os.environ.copy()
        env["PYTHONPATH"] = REPO + (":" + env.get("PYTHONPATH","") if env.get("PYTHONPATH") else "")
        cmd = [
            _sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path",  gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")
        print("eval:", " ".join(cmd))
        res = run(cmd, cwd=REPO, env=env, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()

def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return float('nan')

def _read_stats_file(path):
    out = {}
    with open(path, "r") as f:
        for line in f:
            if ": " not in line:
                continue
            k, v = line.strip().split(": ", 1)
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v)
    return out

def _aggregate_eval_csv(output_folder, data_type, base_out_dir):
    """aggregate metrics"""
    rows = []

    data_root = os.path.join(base_out_dir, data_type)
    if not os.path.isdir(data_root):
        print("no data_type directory found:", data_root)
        return

    for mask_dir in os.listdir(data_root):
        if not mask_dir.startswith("DecompX"):
            continue
        mask_path = os.path.join(data_root, mask_dir)
        if not os.path.isdir(mask_path):
            continue
        thr_str = mask_dir.replace("DecompX", "")
        try:
            thresh = float(thr_str)
        except ValueError:
            print("could not parse threshold from dir:", mask_dir)
            continue

        for folder in os.listdir(mask_path):
            gen_dir = os.path.join(mask_path, folder)
            stats_path = os.path.join(gen_dir, "gen_stats.txt")
            if not os.path.exists(stats_path):
                continue
            s = _read_stats_file(stats_path)
            rows.append({
                "threshold":        float(f"{thresh:.2f}"),
                "folder":           folder,
                "bertscore":        s.get("bertscore", np.nan),
                "meaningbert":      s.get("meaningbert", np.nan),
                "bleu4":            s.get("bleu4", np.nan),
                "perplexity_gen":   s.get("perplexity gen", np.nan),
                "perplexity_orig":  s.get("perplexity orig", np.nan),
                "toxicity_gen":     s.get("toxicity gen", np.nan),
                "toxicity_orig":    s.get("toxicity orig", np.nan),
            })

    if rows:
        cols = [
            "threshold", "folder",
            "bertscore", "meaningbert", "bleu4",
            "perplexity_gen", "perplexity_orig",
            "toxicity_gen", "toxicity_orig",
        ]
        df = pd.DataFrame(rows)
        df = df[cols]
        out_csv = os.path.join(base_out_dir, data_type, f"{data_type}.csv")
        df.to_csv(out_csv, index=False)
        print("wrote summary csv:", out_csv)
    else:
        print("no eval files found")

In [None]:
## llm infilling - mistral

USE_LLM_GPU = True
DEVICE_LLM = torch.device("cuda" if USE_LLM_GPU and torch.cuda.is_available() else "cpu")
print("llm device:", DEVICE_LLM)

LLM_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
_LLM_TOKENIZER = None
_LLM_MODEL = None

INFILL_SYSTEM_PROMPT = """You are a helpful assistant trained to make toxic or offensive sentences
more polite and respectful by INFILLING the special token <mask>.

You are NOT a free rewriter. You must keep all non-masked text as close as possible
to the given masked sentence.

You are given two inputs:
1) Toxic Sentence: the original toxic sentence.
2) Masked Sentence: the same sentence, where toxic spans are replaced with <mask>.

Your task:
1. For each <mask> token in the Masked Sentence, replace it with a short, non-toxic
   word or phrase that fits the context and preserves the meaning of the Toxic Sentence.
2. Do NOT modify any other words or punctuation outside the <mask> spans, unless a very
   small change is needed to fix grammar or agreement.
3. Preserve the original meaning and intent as much as possible, but make the sentence
   safe and respectful.
4. Keep the language the same as the original (do NOT translate).

Output rules (VERY STRICT):
- ONLY return the final detoxified sentence with all <mask> tokens filled.
- Wrap the final sentence in exactly ONE pair of square brackets, e.g.:
  [Detoxified sentence here.]
- Do NOT include the Toxic Sentence or Masked Sentence in your output.
- Do NOT add explanations, comments, or extra lines.
- Do NOT include any other '[' or ']' characters.
"""

INFILL_FEW_SHOT = """Toxic Sentence: You're such a stupid idiot, nobody wants to hear your crap.
Masked Sentence: You're such a <mask>, nobody wants to hear your <mask>.
Step 1 - Decide safe replacements for each <mask>: "rude person", "opinion"
Step 2 - Infill the masked sentence, keeping all other words the same:
You're such a rude person, nobody wants to hear your opinion.
Final Output: [You're such a rude person, nobody wants to hear your opinion.]"""

def _lazy_load_llm():
    global _LLM_MODEL, _LLM_TOKENIZER
    if _LLM_MODEL is not None and _LLM_TOKENIZER is not None:
        return
    print(f"loading llm: {LLM_MODEL_NAME} on {DEVICE_LLM}")
    _LLM_TOKENIZER = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    _LLM_MODEL = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        torch_dtype=torch.float16 if DEVICE_LLM.type == "cuda" else torch.float32,
        device_map=None,
    ).to(DEVICE_LLM)
    _LLM_MODEL.eval()
    print("llm loaded")

def _extract_bracket_content(text: str) -> str:
    """extract [ ... ] content"""
    text = text.strip()

    m = re.search(r"\[([^\]]*)\]", text, flags=re.DOTALL)
    if m:
        return m.group(1).strip()

    if "[" in text:
        return text.split("[", 1)[1].strip()

    return text

def _postprocess_llm_infill(text: str) -> str:
    """cleanup"""
    s = text.strip()

    if s.startswith("[") and s.endswith("]") and len(s) > 2:
        s = s[1:-1].strip()
    else:
        if s.startswith("["):
            s = s[1:].strip()
        if s.endswith("]"):
            s = s[:-1].strip()

    s = re.sub(r"\s+", " ", s).strip()

    s = s.replace("<mask>", " ")
    s = re.sub(r"\s+", " ", s).strip()

    if not s:
        return text.strip()
    return s

@torch.no_grad()
def llm_infill_candidates(
    sources: List[str],
    masked: List[str],
    num_candidates: int = 3,
    temperature: float = 0.7,
    top_p: float = 0.95,
    max_new_tokens: int = 64,
    sample: bool = True,
) -> List[List[str]]:
    """Generate num_candidates infilled rewrites per (source, masked) pair with the Mistral model."""
    assert len(sources) == len(masked), "sources and masked length mismatch"
    if num_candidates < 1:
        raise ValueError("num_candidates must be >= 1")

    _lazy_load_llm()
    all_cands: List[List[str]] = []

    for src, msk in tqdm(
        list(zip(sources, masked)),
        desc="llm infilling",
        leave=False,
    ):
        messages = [
            {
                "role": "system",
                "content": INFILL_SYSTEM_PROMPT + "\n\nHere is an example:\n" + INFILL_FEW_SHOT,
            },
            {
                "role": "user",
                "content": (
                    f"Toxic Sentence: {src}\n"
                    f"Masked Sentence: {msk}\n"
                    f"Final Output:"
                ),
            },
        ]
        try:
            prompt = _LLM_TOKENIZER.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            prompt = (
                INFILL_SYSTEM_PROMPT
                + "\n\nExample:\n"
                + INFILL_FEW_SHOT
                + "\n\nToxic Sentence: "
                + src
                + "\nMasked Sentence: "
                + msk
                + "\nFinal Output:"
            )

        inputs = _LLM_TOKENIZER(prompt, return_tensors="pt").to(DEVICE_LLM)
        input_len = inputs["input_ids"].shape[1]

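        # Sampling-only arguments are passed only when sampling. Greedy decoding
        # cannot return several sequences, so it runs once and the single output
        # is repeated to keep num_candidates entries per input.
        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=sample,
            num_return_sequences=num_candidates if sample else 1,
            pad_token_id=_LLM_TOKENIZER.eos_token_id,
        )
        if sample:
            gen_kwargs.update(temperature=float(temperature), top_p=top_p)
        gen = _LLM_MODEL.generate(**inputs, **gen_kwargs)
        if not sample:
            gen = gen.expand(num_candidates, -1)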

        cand_list = []
        for idx in range(num_candidates):
            gen_text = _LLM_TOKENIZER.decode(
                gen[idx][input_len:],
                skip_special_tokens=True,
            )
            cleaned = _extract_bracket_content(gen_text)
            cleaned = _postprocess_llm_infill(cleaned)
            if not cleaned:
                cleaned = src
            cand_list.append(cleaned)

        all_cands.append(cand_list)

    return all_cands

In [None]:
## global reranking

def rerank_candidates_global(
    sources,
    candidates,
    weights=(0.5, 0.3, 0.2),
    ppl_min=5.0,
    ppl_max=300.0,
):
    """Score all candidates on safety, similarity, and fluency; return the best index per input and the (N, C) score arrays."""
    w_T, w_S, w_F = weights
    N = len(sources)
    assert len(candidates) == N, "candidates length mismatch"

    if N == 0:
        return np.array([], dtype=int), {}

    C_list = [len(c) for c in candidates]
    assert len(set(C_list)) == 1, "All inputs must have same num_candidates"
    C = C_list[0]
    if C == 0:
        raise ValueError("num_candidates must be >= 1")

    flat_cands = []
    flat_src_idx = []
    for i, cand_list in enumerate(candidates):
        for cand in cand_list:
            flat_cands.append(cand)
            flat_src_idx.append(i)
    flat_src_idx = np.array(flat_src_idx, dtype=int)

    tox = np.array(get_toxicity_scores(flat_cands), dtype=float)
    src_embs = get_labse_embeddings(sources)
    cand_embs = get_labse_embeddings(flat_cands)

    src_embs = src_embs / np.clip(np.linalg.norm(src_embs, axis=1, keepdims=True), 1e-8, None)
    cand_embs = cand_embs / np.clip(np.linalg.norm(cand_embs, axis=1, keepdims=True), 1e-8, None)

    sims = np.sum(cand_embs * src_embs[flat_src_idx], axis=1)
    sims = (sims + 1.0) / 2.0

    ppls = np.array(get_gpt2_perplexities(flat_cands), dtype=float)
    flus = perplexity_to_fluency(ppls, p_min=ppl_min, p_max=ppl_max)

    safety = 1.0 - tox
    scores = w_T * safety + w_S * sims + w_F * flus

    tox2    = tox.reshape(N, C)
    safety2 = safety.reshape(N, C)
    sims2   = sims.reshape(N, C)
    flus2   = flus.reshape(N, C)
    scores2 = scores.reshape(N, C)

    best_idx = scores2.argmax(axis=1)
    details = {
        "tox": tox2,
        "safety": safety2,
        "sim": sims2,
        "flu": flus2,
        "score": scores2,
    }
    return best_idx, details

In [None]:
## decompx masking

def _process_in_batches(masker, inputs, batch_size, thresh: float):
    batched_inputs = [
        inputs[i : i + batch_size] for i in range(0, len(inputs), batch_size)
    ]
    results = []
    for batch in tqdm(batched_inputs, desc=f"masking t={thresh:.2f}", leave=False):
        batch_result = masker.process_text(sentence=batch, threshold=thresh)
        results.append(batch_result)
    return results
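
The slicing used in `_process_in_batches` can be tried without the masker. A minimal sketch, assuming placeholder strings instead of real inputs and a hypothetical helper name `split_into_batches`; it shows only the batching and the later flattening, not the DecompX call itself.

```python
import math


def split_into_batches(items, batch_size):
    """Slice a list into consecutive batches; the last one may be shorter."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


inputs = [f"sentence {i}" for i in range(671)]   # placeholder strings
batches = split_into_batches(inputs, batch_size=8)

print(len(batches))        # 84
print(len(batches[-1]))    # 7
assert len(batches) == math.ceil(len(inputs) / 8)

# The per-batch results are nested, so the caller flattens them again.
flattened = [item for batch in batches for item in batch]
assert flattened == inputs
```

With 671 inputs and a batch size of 8 this gives 84 batches, which agrees with the `0/84` total on the masking progress bar in the recorded run.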

In [None]:
## masking + llm + reranking per threshold

def _bool2str(x: bool) -> str:
    return "T" if x else "F"

def _build_llm_gen_folder_name(
    temperature, sample, top_p, max_new_tokens, num_candidates
):
    return (
        "llm"
        "_temp" + str(temperature) +
        "_sample" + _bool2str(sample) +
        "_topp" + str(top_p) +
        "_maxnew" + str(max_new_tokens) +
        "_ncand" + str(num_candidates)
    )
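
The folder name encodes every generation setting, so runs with different settings do not overwrite each other. A self-contained sketch of the same naming scheme, written with f-strings and without the leading underscores of the notebook's helpers:

```python
def bool2str(x: bool) -> str:
    return "T" if x else "F"


def build_llm_gen_folder_name(temperature, sample, top_p, max_new_tokens, num_candidates):
    # Same pieces, same order as _build_llm_gen_folder_name above.
    return (
        f"llm_temp{temperature}_sample{bool2str(sample)}"
        f"_topp{top_p}_maxnew{max_new_tokens}_ncand{num_candidates}"
    )


name = build_llm_gen_folder_name(0.7, True, 0.95, 64, 10)
print(name)  # llm_temp0.7_sampleT_topp0.95_maxnew64_ncand10
```

Because the values go through `str()`, `temperature=1` and `temperature=1.0` produce different folder names (`temp1` vs `temp1.0`), so passing the same numeric type on every run keeps the `gen.txt` cache check consistent.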

def _run_decompx_masking_and_llm_global_reranking_for_threshold(
    data_type,
    subset_path,
    thresh,
    base_out_rel,
    batch_size_mask,
    num_candidates,
    weights,
    llm_temperature,
    llm_top_p,
    llm_max_new_tokens,
    llm_sample,
    overwrite_gen=False,
    echo: bool = False,
    inputs=None,
):
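    """Run DecompX masking, LLM infilling, and global reranking for one threshold.

    Writes masked_inputs.txt under the threshold directory and orig.txt / gen.txt
    under the generation subfolder. Returns early if gen.txt already exists and
    overwrite_gen is False.
    """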
    if inputs is None:
        args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
        inputs = rx.get_data(args_data)
    print(f"#inputs at thresh={thresh}: {len(inputs)}")

    mask_dir = f"DecompX{abs(thresh):g}" if thresh != 0 else "DecompX0.0"
    cur_rel = os.path.join(base_out_rel, data_type, mask_dir)
    cur_abs = os.path.join(REPO, cur_rel)
    _ensure_dir(cur_abs)

    masked_file = os.path.join(cur_abs, "masked_inputs.txt")

    if not os.path.exists(masked_file):
        masker = Masker_single()
        decoded_masked_inputs_batches = _process_in_batches(
            masker, inputs, batch_size=batch_size_mask, thresh=thresh
        )
        decoded_masked_inputs = [
            item for sublist in decoded_masked_inputs_batches for item in sublist
        ]
        decoded_mask_inputs = [
            d.replace("<s>", "").replace("</s>", "") for d in decoded_masked_inputs
        ]
        with open(masked_file, "w", encoding="utf-8") as f:
            for d in decoded_mask_inputs:
                f.write(re.sub(r"\s+", " ", d).strip() + "\n")
        masker.release_model()
    else:
        with open(masked_file, "r", encoding="utf-8") as f:
            decoded_mask_inputs = [s.strip() for s in f]
        print("reusing masked_inputs.txt")

    assert len(decoded_mask_inputs) == len(inputs), (
        "masked vs inputs mismatch; a stale masked_inputs.txt may have been reused"
    )

    if echo:
        print(f"\nexample masked (t={thresh:.2f}, first 3):")
        for i, m in enumerate(decoded_mask_inputs[:3]):
            print(f"  masked[{i}]: {m}")

    gen_folder = _build_llm_gen_folder_name(
        temperature=llm_temperature,
        sample=llm_sample,
        top_p=llm_top_p,
        max_new_tokens=llm_max_new_tokens,
        num_candidates=num_candidates,
    )
    final_abs = os.path.join(cur_abs, gen_folder)
    gen_txt = os.path.join(final_abs, "gen.txt")
    orig_txt = os.path.join(final_abs, "orig.txt")

    if os.path.exists(gen_txt) and not overwrite_gen:
        print("gen already exists:", gen_txt, "— skipping")
        return

    _ensure_dir(final_abs)

    print(f"llm infilling: {num_candidates} cands/input (sampling={llm_sample})")
    all_candidates = llm_infill_candidates(
        sources=inputs,
        masked=decoded_mask_inputs,
        num_candidates=num_candidates,
        temperature=llm_temperature,
        top_p=llm_top_p,
        max_new_tokens=llm_max_new_tokens,
        sample=llm_sample,
    )

    # Release the LLM before the reranking models are loaded, so that both
    # do not have to fit in GPU memory at the same time.
    import gc
    global _LLM_MODEL, _LLM_TOKENIZER
    _LLM_MODEL = None
    _LLM_TOKENIZER = None
    gc.collect()
    if torch.cuda.is_available() and DEVICE_LLM.type == "cuda":
        torch.cuda.empty_cache()

    print("global reranking...")
    best_idx, details = rerank_candidates_global(
        sources=inputs,
        candidates=all_candidates,
        weights=weights,
    )
    best_generations = [
        all_candidates[i][best_idx[i]] for i in range(len(inputs))
    ]

    if echo:
        print(f"\nexample detoxified (t={thresh:.2f}, first 3):")
        for i, g in enumerate(best_generations[:3]):
            print(f"  detox[{i}]: {g}")

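    # Write one whitespace-normalized sentence per line so that orig.txt and
    # gen.txt stay line-aligned for evaluation.
    with open(orig_txt, "w", encoding="utf-8") as f:
        for line in inputs:
            f.write(re.sub(r"\s+", " ", line).strip() + "\n")
    with open(gen_txt, "w", encoding="utf-8") as f:
        for line in best_generations:
            f.write(re.sub(r"\s+", " ", line).strip() + "\n")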

    print("saved:", orig_txt)
    print("saved:", gen_txt)
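
Two small string conventions in the function above decide where files land and what they contain: the threshold directory name and the cleanup applied to each masked line. A minimal sketch of both, with hypothetical helper names `mask_dir_name` and `clean_line` that copy the expressions used above:

```python
import re


def mask_dir_name(thresh):
    # ":g" drops trailing zeros, so 0.20 becomes "0.2"; zero is special-cased.
    return f"DecompX{abs(thresh):g}" if thresh != 0 else "DecompX0.0"


def clean_line(text):
    # Strip RoBERTa's sentence markers, then collapse runs of whitespace.
    text = text.replace("<s>", "").replace("</s>", "")
    return re.sub(r"\s+", " ", text).strip()


for t in (0.15, 0.20, 0.25, 0):
    print(t, "->", mask_dir_name(t))

print(clean_line('<s>mandated  and " right <mask> now " would be good .</s>'))
# mandated and " right <mask> now " would be good .
```

Since the name uses `abs(thresh)`, a negative threshold maps to the same directory as its positive counterpart, and the `<mask>` placeholder survives the cleanup because only `<s>` and `</s>` are removed.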

In [None]:
## main detoxify function

def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_decompx_llm_global",
    thresholds = (0.15, 0.20, 0.25),
    echo: bool = False,
    batch_size: int = 10,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    weights = (0.5, 0.3, 0.2),
    num_candidates: int = 3,
    llm_temperature: float = 0.7,
    llm_top_p: float = 0.95,
    llm_max_new_tokens: int = 64,
    llm_sample: bool = True,
):
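    """Run the XDetox pipeline (DecompX masking, LLM infilling, global reranking).

    For each threshold, masks the inputs, generates `num_candidates` infills per
    input, keeps the highest-scoring candidate, and optionally runs evaluation.
    """
    if data_type not in data_configs:
        raise ValueError(f"unknown data_type: {data_type}")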
    cfg = data_configs[data_type].copy()

    if num_candidates < 1:
        raise ValueError("num_candidates must be >= 1")

    base_out_rel = os.path.join("data", "model_outputs", output_folder)
    base_out_abs = os.path.join(REPO, base_out_rel)
    _ensure_dir(base_out_abs)

    original_data_path = cfg["data_path"]
    subset_dir = os.path.join(REPO, "datasets", "_subsets", data_type)
    _ensure_dir(subset_dir)
    subset_path = _subset_for_data_type(
        data_type, original_data_path, num_examples, subset_dir
    )

    args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
    inputs = rx.get_data(args_data)
    num_inputs = len(inputs)

    if echo:
        print("=" * 80)
        print(f"dataset: {data_type}")
        print(f"subset path: {subset_path}")
        print(f"output base: {base_out_abs}")
        print(f"num examples: {num_inputs}")
        print(f"thresholds: {', '.join(f'{t:.2f}' for t in thresholds)}")
        print(f"weights (w_T, w_S, w_F): {weights}")
        print(f"num_candidates (llm): {num_candidates}")
        print(f"llm temp / top_p / max_new: "
              f"{llm_temperature} / {llm_top_p} / {llm_max_new_tokens}")
        print("\nexample inputs (first 3):")
        for i, s in enumerate(inputs[:3]):
            print(f"  input[{i}]: {s}")
        print("=" * 80)

    for t in thresholds:
        print("=" * 60)
        print(f"threshold = {t:.2f}")
        _run_decompx_masking_and_llm_global_reranking_for_threshold(
            data_type=data_type,
            subset_path=subset_path,
            thresh=t,
            base_out_rel=base_out_rel,
            batch_size_mask=batch_size,
            num_candidates=num_candidates,
            weights=weights,
            llm_temperature=llm_temperature,
            llm_top_p=llm_top_p,
            llm_max_new_tokens=llm_max_new_tokens,
            llm_sample=llm_sample,
            overwrite_gen=overwrite_gen,
            echo=echo,
            inputs=inputs,
        )

        if run_eval:
            mask_dir = f"DecompX{abs(t):g}" if t != 0 else "DecompX0.0"
            base_path = os.path.join(base_out_abs, data_type, mask_dir)
            _eval_with_toxicity(
                base_path,
                overwrite_eval=overwrite_eval,
                skip_ref=skip_ref_eval,
                tox_threshold=0.5,
                tox_batch_size=32,
            )

            if echo:
                gen_folder = _build_llm_gen_folder_name(
                    temperature=llm_temperature,
                    sample=llm_sample,
                    top_p=llm_top_p,
                    max_new_tokens=llm_max_new_tokens,
                    num_candidates=num_candidates,
                )
                stats_path = os.path.join(base_path, gen_folder, "gen_stats.txt")
                if os.path.exists(stats_path):
                    stats = _read_stats_file(stats_path)
                    print(f"\nmetrics for t={t:.2f}:")
                    metric_keys = [
                        ("bertscore",        "BERTScore"),
                        ("meaningbert",      "MeaningBERT"),
                        ("bleu4",            "BLEU-4"),
                        ("perplexity gen",   "ppl_gen"),
                        ("perplexity orig",  "ppl_orig"),
                        ("toxicity gen",     "tox_gen"),
                        ("toxicity orig",    "tox_orig"),
                    ]
                    for key, label in metric_keys:
                        val = stats.get(key)
                        # Skip metrics that are missing or were not computed.
                        if val is None or (isinstance(val, float) and math.isnan(val)):
                            continue
                        print(f"  {label}: {val:.4f}")
                else:
                    print(f"gen_stats.txt not found at {stats_path}")

    if run_eval:
        _aggregate_eval_csv(output_folder, data_type, base_out_abs)
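
The metric echo at the end of `detoxify` skips values that are missing or NaN. A minimal sketch of that filter in isolation, assuming `_read_stats_file` returns a dict from metric name to float (its real return shape is not shown in this section); `format_metrics` is a hypothetical helper and the numbers are made up, not results from the run.

```python
import math


def format_metrics(stats, metric_keys):
    """Return printable lines for metrics that are present and not NaN."""
    lines = []
    for key, label in metric_keys:
        val = stats.get(key)
        if val is None or (isinstance(val, float) and math.isnan(val)):
            continue
        lines.append(f"  {label}: {val:.4f}")
    return lines


# Made-up values for illustration only.
stats = {"bleu4": 0.5, "toxicity gen": float("nan"), "perplexity gen": 123.456}
metric_keys = [
    ("bertscore",      "BERTScore"),   # missing from stats, skipped
    ("bleu4",          "BLEU-4"),
    ("perplexity gen", "ppl_gen"),
    ("toxicity gen",   "tox_gen"),     # NaN, skipped
]

for line in format_metrics(stats, metric_keys):
    print(line)
```

Only `BLEU-4` and `ppl_gen` are printed here. Skipping silently keeps the echo compact, at the cost of not distinguishing a metric that was never computed from one that failed.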

In [None]:
# example run (commented out)
# detoxify(
#     data_type="paradetox",
#     output_folder="colab_run_decompx_llm_global_demo_50_examples",
#     thresholds=(0.20,),
#     echo=True,
#     batch_size=8,
#     num_examples=50,
#     run_eval=True,
#     overwrite_gen=True,
#     overwrite_eval=True,
#     skip_ref_eval=False,
#     weights=(0.5, 0.3, 0.2),
#     num_candidates=10,
#     llm_temperature=0.7,
#     llm_top_p=0.95,
#     llm_max_new_tokens=64,
#     llm_sample=True,
# )

In [17]:
detoxify(
    data_type="paradetox",
    output_folder="XDetox_w_DecompX-Masking_LLM-Infilling_Global-Reranking_Pipeline.ipynb",
    thresholds=(0.20,),
    echo=True,
    batch_size=8,              # DecompX masking batch size
    num_examples=1000,
    run_eval=True,             # BLEU/BERTScore/MeaningBERT/PPL/Toxicity
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    weights=(0.5, 0.3, 0.2),   # safety, similarity, fluency
    num_candidates=10,         # LLM candidates per input
    llm_temperature=0.7,
    llm_top_p=0.95,
    llm_max_new_tokens=64,
    llm_sample=True,
)


[echo] Dataset: paradetox
[echo] Subset path: /content/drive/MyDrive/w266 - Project/XDetox/datasets/_subsets/paradetox/test_toxic_parallel.txt
[echo] Output base: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_DecompX-Masking_LLM-Infilling_Global-Reranking_Pipeline.ipynb
[echo] Number of examples to detoxify: 671
[echo] Thresholds: 0.20
[echo] Weights (w_T, w_S, w_F): (0.5, 0.3, 0.2)
[echo] num_candidates per input (LLM): 10
[echo] LLM temperature / top_p / max_new_tokens: 0.7 / 0.95 / 64

[echo] Example inputs (first up to 3):
  input[0]: . or the loud ass one - thousand ton beast roaring towards you howling its horn .
  input[1]: mandated  and " right fucking now " would be good .
  input[2]: * neither * of my coworkers gave a shit when it came time to ditch mitch . ugh .
Threshold (DecompX masking) = 0.20
#inputs at thresh=0.2: 671


Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Masking (DecompX, thr=0.20):   0%|          | 0/84 [00:00<?, ?it/s]




[echo] Example masked inputs at threshold 0.20 (first up to 3):
  masked[0]: . or the loud <mask> one - thousand ton beast roaring towards you howling its horn .
  masked[1]: mandated  and " right <mask> now " would be good .
  masked[2]: * neither * of my coworkers gave a <mask> when it came time to ditch mitch . ugh .
LLM infilling: 10 candidates per input (sampling=True)
Loading LLM infiller: mistralai/Mistral-7B-Instruct-v0.2 on cuda ...


LLM infiller loaded.


LLM infilling (Mistral):   0%|          | 0/671 [00:00<?, ?it/s]

Global reranking (toxicity + similarity + fluency)...


Toxicity:   0%|          | 0/210 [00:00<?, ?it/s]

LaBSE embeddings:   0%|          | 0/21 [00:00<?, ?it/s]

LaBSE embeddings:   0%|          | 0/210 [00:00<?, ?it/s]

GPT-2 PPL:   0%|          | 0/6710 [00:00<?, ?it/s]


[echo] Example detoxified outputs at threshold 0.20 (first up to 3):
  detox[0]: or the loud obnoxious one - thousand ton beast roaring towards you howling its horn.
  detox[1]: mandated and " right away now " would be good.
  detox[2]: neither of my coworkers showed much concern when it came time to let Mitch go. ugh.
Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_DecompX-Masking_LLM-Infilling_Global-Reranking_Pipeline.ipynb/paradetox/DecompX0.2/llm_temp0.7_sampleT_topp0.95_maxnew64_ncand10/orig.txt
Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_DecompX-Masking_LLM-Infilling_Global-Reranking_Pipeline.ipynb/paradetox/DecompX0.2/llm_temp0.7_sampleT_topp0.95_maxnew64_ncand10/gen.txt
Eval: /usr/bin/python3 -m evaluation.evaluate_all --orig_path /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_DecompX-Masking_LLM-Infilling_Global-Reranking_Pipeline.ipynb/paradetox/DecompX0.2/llm_temp0.7_sampleT_topp0