# XDetox with LLM Masking, LLM Infilling, and Global Reranking

This notebook runs an XDetox pipeline with:

1. **LLM masking** using Mistral-7B-Instruct (`mistralai/Mistral-7B-Instruct-v0.2`), which detects toxic spans and replaces them with `<mask>`.
2. **LLM infilling** (same Mistral model) that fills the `<mask>` tokens with safe alternatives while keeping the rest of the sentence almost unchanged.
3. **Global reranking** of multiple LLM candidates per input using:

   * **Toxicity** (XLM-R large classifier).
   * **Semantic similarity** (LaBSE).
   * **Fluency** (GPT-2 perplexity).

The goal is to pick, for each toxic input sentence, **one best detoxified candidate** that is:

* As **non-toxic** as possible.
* As **semantically close** as possible to the original.
* As **fluent and natural** as possible.

The main differences from the previous pipelines are:

* The **masking** step and the **infilling** step are **both done by an LLM (Mistral-7B-Instruct)**.
* There is **no MaRCo / BART generation** in this notebook.
* **DecompX is not used at all** (neither for masking nor for reranking).

---

## Scoring: Global Reranking

For each candidate $c$, we compute:

* $T(c)$: toxicity in $[0, 1]$ from `textdetox/xlmr-large-toxicity-classifier-v2`
  (higher = more toxic).
* $S(c)$: semantic similarity in $[0, 1]$ from LaBSE cosine similarity to the **original toxic sentence**.
* $F(c)$: fluency in $[0, 1]$ from GPT-2 perplexity mapped to a score
  (low perplexity $\Rightarrow$ high fluency).
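For reference, the two normalizations can be written out explicitly. This is a sketch. The similarity mapping assumes a linear rescaling of the LaBSE cosine. The fluency mapping mirrors `perplexity_to_fluency` in the scoring helpers, which additionally adds $10^{-8}$ to the denominator. Here $x$ is the original toxic sentence, $\mathbf{e}(\cdot)$ is a mean-pooled LaBSE embedding, and $p_{\min}, p_{\max}$ are `ppl_min` and `ppl_max`:

$$
S(c) = \frac{1 + \cos\big(\mathbf{e}(c), \mathbf{e}(x)\big)}{2},
\qquad
F(c) = \frac{\log p_{\max} - \log \operatorname{clip}\big(\mathrm{PPL}(c),\, p_{\min},\, p_{\max}\big)}{\log p_{\max} - \log p_{\min}}
$$

so $F(c) = 1$ whenever $\mathrm{PPL}(c) \le p_{\min}$, and $F(c) = 0$ whenever $\mathrm{PPL}(c) \ge p_{\max}$.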

We convert toxicity into a **safety** score:

$$
T'(c) = 1 - T(c)
$$

Then we form a **global score**:

$$
\text{Score}(c) = w_T \cdot T'(c) + w_S \cdot S(c) + w_F \cdot F(c)
$$
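The formula translates directly into a vectorized NumPy function. This is a minimal sketch that assumes $T$, $S$, and $F$ are already computed per candidate. The name `global_score` is illustrative and does not come from the notebook.

```python
import numpy as np

def global_score(tox, sim, flu, weights=(0.5, 0.3, 0.2)):
    """Score(c) = w_T * (1 - T(c)) + w_S * S(c) + w_F * F(c), one value per candidate."""
    w_t, w_s, w_f = weights
    safety = 1.0 - np.asarray(tox, dtype=float)  # T'(c) = 1 - T(c)
    return w_t * safety + w_s * np.asarray(sim, dtype=float) + w_f * np.asarray(flu, dtype=float)

# Candidate 0 is safe and faithful; candidate 1 is toxic but slightly more similar and fluent.
scores = global_score(tox=[0.1, 0.9], sim=[0.8, 0.9], flu=[0.6, 0.9])
print(scores)  # approximately [0.81, 0.50]
```

With the default weights, the safe candidate wins (0.45 + 0.24 + 0.12 = 0.81 versus 0.05 + 0.27 + 0.18 = 0.50), because safety carries half of the total weight.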

You control the weights:

* `weights = (w_T, w_S, w_F)`

  * `w_T`: importance of **safety** (low toxicity).
  * `w_S`: importance of **semantic similarity**.
  * `w_F`: importance of **fluency**.

For each input sentence, we:

1. Generate `num_candidates` LLM infilling candidates.
2. Score each candidate using the formula above.
3. Select the **highest-scoring** candidate as the final output.

---

## LLM Masking (Mistral-7B-Instruct)

### Prompted masking behavior

Masking is done by a chat-style LLM:

* Model: `mistralai/Mistral-7B-Instruct-v0.2`.
* The LLM is instructed to:

  * **Identify toxic, offensive, or profane words or short phrases**.
  * Replace **each toxic span** with a **single `<mask>` token**.
  * Allow **multiple `<mask>` tokens** in one sentence (one per toxic span).
  * If multiple neighboring words are toxic, **collapse them into one `<mask>`**
    (never output `<mask> <mask> ...`).
  * Keep **all non-toxic words and punctuation unchanged**.
  * **Not** paraphrase, summarize, or reorder the sentence.
  * Return the masked sentence **inside exactly one pair of brackets**:

    ```text
    [This is a <mask> example.]
    ```

This step turns a raw toxic input into a **masked template** that marks exactly where detoxification must happen.

### Post-processing of LLM masks

Because LLM output can be noisy, the notebook cleans the raw masked output before using it:

1. **Extract the bracket content**:

   * Try to read the first `[ ... ]` block.
   * If there is `[` but no `]`, take everything after `[` as the sentence.
   * If there are no brackets, fall back to the full string.

2. **Strip stray outer brackets** that may still remain.

3. **Normalize whitespace** (collapse multiple spaces into one).

4. **Normalize `<mask>` tokens**:

   * Any variant like `<Mask>`, `<MASK>`, `< mask >` is normalized to `<mask>`.

5. **Collapse runs of `<mask>`**:

   * Any sequence like `<mask> <mask> <mask>` becomes a single `<mask>`.

6. If, after cleaning, the sentence becomes empty, fall back to the original masked text.
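The six steps can be condensed into one self-contained function. This is an illustrative sketch, and `clean_llm_mask` is a hypothetical name. The notebook itself splits this logic between `_extract_bracket_content` and `_postprocess_llm_mask`. In this sketch, step 6 falls back to the raw LLM output.

```python
import re

def clean_llm_mask(raw: str) -> str:
    """Hypothetical sketch of mask post-processing steps 1-6."""
    text = raw.strip()
    # 1. First [ ... ] block; if only '[' is present, take what follows it; else the whole string.
    m = re.search(r"\[([^\]]*)\]", text)
    if m:
        s = m.group(1)
    elif "[" in text:
        s = text.split("[", 1)[1]
    else:
        s = text
    # 2. Strip stray outer brackets.
    s = s.strip().strip("[]").strip()
    # 3. Collapse runs of whitespace into one space.
    s = re.sub(r"\s+", " ", s).strip()
    # 4. Normalize <Mask>, <MASK>, < mask > to <mask>.
    s = re.sub(r"<\s*mask\s*>", "<mask>", s, flags=re.IGNORECASE)
    # 5. Collapse adjacent <mask> tokens into a single <mask>.
    s = re.sub(r"(?:\s*<mask>\s*){2,}", " <mask> ", s)
    s = re.sub(r"\s+", " ", s).strip()
    # 6. If cleaning emptied the string, fall back to the raw output.
    return s or raw.strip()

print(clean_llm_mask("Sure! [You are a <MASK> < mask >  and a <Mask>.] Hope this helps."))
# You are a <mask> and a <mask>.
```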

All cleaned, LLM-masked sentences are saved to:

* `data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_Global/masked_inputs.txt`

and reused across later runs with the same `output_folder` and `data_type`.

---

## LLM Infilling (Mistral-7B-Instruct)

After masking, detoxification is also performed by Mistral-7B-Instruct via **infilling**:

* For each example we pass two inputs to the LLM:

  1. **Toxic Sentence**: the raw toxic input.
  2. **Masked Sentence**: the same sentence, but with toxic spans replaced by `<mask>`.

The system prompt instructs the LLM to:

* Treat the **Masked Sentence** as a template.
* For each `<mask>` token, insert a **short, non-toxic phrase** that:

  * Fits the local context, and
  * Preserves the meaning and intent of the original **Toxic Sentence**.
* Keep **all other tokens outside `<mask>` spans unchanged**, except for small grammar fixes.
* Keep the **language the same** (no translation).
* Output the final filled sentence **inside one pair of brackets**:

  ```text
  [Detoxified sentence here.]
  ```

For each `(source, masked)` pair we:

1. Build a prompt that includes both **Toxic Sentence** and **Masked Sentence**.
2. Call `generate(...)` with:

   * `num_return_sequences = num_candidates`,
   * `do_sample = llm_sample`,
   * `temperature = llm_temperature` (if sampling),
   * `top_p = llm_top_p`,
   * `max_new_tokens = llm_max_new_tokens`.
3. Decode the new tokens, extract bracket content, and clean it.
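The sketch below shows how the decoding arguments can map onto keyword arguments for Hugging Face `generate(...)`. The helper name `build_generate_kwargs` is hypothetical. It also departs from the list above in two ways, both assumptions rather than confirmed notebook behavior:

* It passes `top_p` only when sampling.
* It requests a single sequence when sampling is off. Greedy decoding returns identical text on every draw, and recent `transformers` versions reject `num_return_sequences > 1` for greedy search.

```python
def build_generate_kwargs(num_candidates: int = 3,
                          llm_sample: bool = True,
                          llm_temperature: float = 0.7,
                          llm_top_p: float = 0.95,
                          llm_max_new_tokens: int = 64) -> dict:
    """Hypothetical helper: turn detoxify()'s LLM arguments into generate() kwargs."""
    kwargs = {"do_sample": llm_sample, "max_new_tokens": llm_max_new_tokens}
    if llm_sample:
        # Each returned sequence is an independent sample, so candidates can differ.
        kwargs.update(num_return_sequences=num_candidates,
                      temperature=llm_temperature,
                      top_p=llm_top_p)
    else:
        # Deterministic decoding: extra copies would be identical, so ask for one.
        kwargs["num_return_sequences"] = 1
    return kwargs

print(build_generate_kwargs(num_candidates=10))
```

The result would typically be unpacked as `model.generate(**inputs, **build_generate_kwargs(...))`. With sampling off, this sketch yields one candidate per input rather than `num_candidates`.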

### Post-processing of LLM infilling

For each generated candidate we:

1. **Extract** the text inside the first `[ ... ]`.
2. **Remove stray outer brackets**, if any remain.
3. **Normalize whitespace**.
4. **Remove leftover `<mask>` tokens**, if the model failed to fill them.

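If cleaning produces an empty string, we fall back to the original toxic input so that every input still has a candidate. Because that fallback is the toxic sentence itself, it will usually receive a low safety score and lose to any genuine rewrite during reranking.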

The result is a list of candidate sentences:

* `candidates[i]` is a list of length `num_candidates` for the $i$-th input.
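Per-candidate clean-up with the fallback can be sketched as follows. `clean_infill` is a hypothetical name. This sketch removes leftover masks before collapsing whitespace, and it adds a small punctuation tidy-up that is not part of the list above.

```python
import re

def clean_infill(raw: str, source: str) -> str:
    """Hypothetical sketch of infilling post-processing plus the empty-string fallback."""
    text = raw.strip()
    m = re.search(r"\[([^\]]*)\]", text)                      # 1. first [ ... ] block
    s = m.group(1) if m else text
    s = s.strip().lstrip("[").rstrip("]").strip()             # 2. stray outer brackets
    s = re.sub(r"<\s*mask\s*>", " ", s, flags=re.IGNORECASE)  # 4. drop unfilled <mask>
    s = re.sub(r"\s+", " ", s).strip()                        # 3. normalize whitespace
    s = re.sub(r"\s+([.,!?;:])", r"\1", s)                    # extra: no space before punctuation
    return s if s else source                                 # fallback: original toxic input

print(clean_infill("[You are a <mask> person.]", "You are a stupid person."))
# You are a person.
```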

---

## Global Reranking (toxicity + similarity + fluency)

After we have LLM candidates, we apply **global reranking**:

1. **Flatten** all candidates into a single list and keep track of which input each candidate belongs to.

2. Compute:

   * **Toxicity** with XLM-R large:

     * `get_toxicity_scores(...)` returns $T(c) \in [0, 1]$.
     * We define **safety** as $T'(c) = 1 - T(c)$.

   * **Semantic similarity** with LaBSE:

     * Encode all sources and candidates with LaBSE.
     * Compute cosine similarity between each candidate and its source.
     * Map cosine similarity from $[-1, 1]$ to $[0, 1]$.

   * **Fluency** with GPT-2:

     * Compute sentence-level perplexity for each candidate.
     * Map perplexity to a **fluency score** $F(c) \in [0, 1]$ using a log-scale mapping
       with clipping between `ppl_min` and `ppl_max`.

3. Combine these into a **global score**:

   $$
   \text{Score}(c) = w_T \cdot T'(c) + w_S \cdot S(c) + w_F \cdot F(c)
   $$

4. For each input we:

   * Reshape the scores into an $N \times C$ matrix (inputs $\times$ candidates, with $C$ = `num_candidates`).
   * Select the candidate with the **highest score**.
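Steps 1 and 4 amount to a flatten/reshape round trip. The sketch below is a minimal NumPy version. It assumes the following:

* Every input has exactly `num_candidates` candidates.
* The three scores were computed on the flattened list, in order.
* The cosine is rescaled linearly as $(1 + \cos)/2$.

`select_best` is a hypothetical name, and the score values in the example are made up for illustration.

```python
import numpy as np

def select_best(candidates, tox, sim_cos, flu, weights=(0.5, 0.3, 0.2)):
    """Pick one candidate per input from flat per-candidate scores (length N * C)."""
    n, c = len(candidates), len(candidates[0])
    if any(len(row) != c for row in candidates):
        raise ValueError("every input must have the same number of candidates")
    w_t, w_s, w_f = weights
    sim = (np.asarray(sim_cos, dtype=float) + 1.0) / 2.0  # cosine [-1, 1] -> [0, 1]
    score = (w_t * (1.0 - np.asarray(tox, dtype=float))   # safety T'(c)
             + w_s * sim
             + w_f * np.asarray(flu, dtype=float))
    score = score.reshape(n, c)                            # N x C: row i = input i
    best = score.argmax(axis=1)                            # best column per row
    return [candidates[i][j] for i, j in enumerate(best)]

candidates = [["you are wrong", "you are an idiot"],
              ["go away please", "get lost"]]
flat = [cand for row in candidates for cand in row]  # step 1: scorers run on this list
# Made-up scores aligned with `flat` (not real model outputs):
tox, sim_cos, flu = [0.05, 0.95, 0.02, 0.30], [0.7, 0.9, 0.6, 0.8], [0.8, 0.8, 0.7, 0.7]
print(select_best(candidates, tox, sim_cos, flu))  # ['you are wrong', 'go away please']
```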

For each run folder, we write:

* `orig.txt` — original toxic inputs (one per line).
* `gen.txt` — chosen detoxified outputs (one per line).

Outputs are stored under:

```text
data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_Global/{run_folder}/
```

where `{run_folder}` encodes LLM hyperparameters (temperature, top-p, sampling flag, number of candidates, max new tokens, etc.).

---

## Evaluation

If `run_eval=True`, the pipeline calls `evaluation.evaluate_all` to compute:

* BERTScore (F1)
* MeaningBERT
* BLEU-4
* Toxicity (orig / gen) using XLM-R
* Perplexity (orig / gen) using GPT-2

For each run folder, it writes:

* `gen_stats.txt` under:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_Global/{run_folder}/
  ```

It also creates a **summary CSV per dataset**:

* `data/model_outputs/{output_folder}/{data_type}/{data_type}.csv`

The CSV aggregates metrics over all run folders in `LLM_Mask_LLM_Global/`.
The `threshold` column is a **constant label** (always `0.20`) kept for compatibility with other XDetox pipelines. It does **not** control masking or reranking here.
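Because `threshold` is constant, runs in the summary CSV are compared on the metric columns. The sketch below filters and ranks runs with pandas. `rank_runs` and the `max_tox` cap are hypothetical choices: the cap should be read on whatever scale `evaluate_all` reports for `toxicity_gen`. The toy rows are made-up values shaped like the CSV.

```python
import pandas as pd

def rank_runs(df: pd.DataFrame, max_tox: float) -> pd.DataFrame:
    """Hypothetical helper: drop runs whose generated toxicity exceeds max_tox,
    then sort the remaining runs by BERTScore (best first)."""
    kept = df[df["toxicity_gen"] <= max_tox]
    cols = ["folder", "bertscore", "meaningbert", "toxicity_gen"]
    return kept.sort_values("bertscore", ascending=False)[cols].reset_index(drop=True)

# Toy rows with made-up values; in practice: pd.read_csv(".../{data_type}/{data_type}.csv")
toy = pd.DataFrame({
    "threshold": [0.20, 0.20, 0.20],
    "folder": ["run_a", "run_b", "run_c"],
    "bertscore": [0.91, 0.94, 0.89],
    "meaningbert": [0.70, 0.75, 0.68],
    "toxicity_gen": [0.05, 0.30, 0.04],
})
print(rank_runs(toy, max_tox=0.1))
```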

---

## How to Use `detoxify()`

Function signature (conceptual):

```python
from typing import Optional, Tuple

def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_llm_mask_infill_global",
    echo: bool = False,
    num_examples: Optional[int] = 100,  # None = full dataset
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    # LLM infilling:
    num_candidates: int = 3,
    llm_temperature: float = 0.7,
    llm_top_p: float = 0.95,
    llm_max_new_tokens: int = 64,
    llm_sample: bool = True,
    # Global reranking:
    weights: Tuple[float, float, float] = (0.5, 0.3, 0.2),  # (w_T, w_S, w_F)
    ppl_min: float = 5.0,
    ppl_max: float = 300.0,
):
    ...  # body omitted
```

### Key arguments

#### Core I/O

* `data_type`: dataset key from `data_configs`, such as:

  * `"paradetox"`, `"dynabench_val"`, `"dynabench_test"`,
  * `"jigsaw_toxic"`, `"microagressions_val"`, `"sbf_val"`,
  * `"appdia_original"`, `"appdia_discourse"`, etc.

* `output_folder`: top-level directory under `data/model_outputs/`:

  ```text
  data/model_outputs/{output_folder}/{data_type}/...
  ```

* `num_examples`: if set to a positive integer, only the first `num_examples` examples are used (useful for quick tests).
  Use `None` (or a value `<= 0`) to run on the full dataset.

* `overwrite_gen`:

  * If `False` and `gen.txt` already exists for a given run folder, skip generation and reuse previous outputs.
  * If `True`, regenerate outputs even if `gen.txt` exists.

* `echo`:

  * If `True`, print:

    * Dataset and output paths,
    * A few example inputs,
    * A few masked sentences,
    * A few final detoxified outputs,
    * And, if `run_eval=True`, evaluation metrics.

#### LLM masking (Mistral)

* Masking uses **Mistral-7B-Instruct** with a fixed **masking prompt** and a single **few-shot example**.
* Masking is not controlled by any numeric argument of `detoxify()`; its behavior is fixed by the prompt and the few-shot example.
* Masked sentences are cached in:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_Global/masked_inputs.txt
  ```

  and are reused if the file already exists.
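If reuse depends only on the file existing, as described above, a cache written for one `num_examples` could be reused for a run with a different `num_examples`. The sketch below shows the cache-or-compute pattern with an added length check as a safeguard. `load_or_build_masks` and `mask_fn` are hypothetical names, and the length check is an assumption rather than confirmed notebook behavior.

```python
from pathlib import Path
from typing import Callable, List

def load_or_build_masks(cache_path: str, sentences: List[str],
                        mask_fn: Callable[[List[str]], List[str]]) -> List[str]:
    """Hypothetical helper: reuse a masked_inputs.txt cache, or build and save it."""
    path = Path(cache_path)
    if path.exists():
        cached = path.read_text(encoding="utf-8").splitlines()
        # Added safeguard: only trust a cache that lines up 1:1 with the inputs.
        if len(cached) == len(sentences):
            return cached
    # One masked sentence per line, so flatten any newlines the masker produced.
    masked = [m.replace("\n", " ") for m in mask_fn(sentences)]
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text("\n".join(masked) + "\n", encoding="utf-8")
    return masked
```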

#### LLM infilling (Mistral)

* `num_candidates`: how many detoxified candidates to generate per input.
* `llm_temperature`:

  * Controls randomness when `llm_sample=True`.
* `llm_top_p`:

  * Nucleus sampling parameter.
* `llm_max_new_tokens`:

  * Maximum number of new tokens generated beyond the prompt.
* `llm_sample`:

  * `True`: sampling with `temperature` and `top_p`.
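  * `False`: deterministic decoding (greedy-like). Deterministic decoding returns the same text on every draw, so reranking has nothing to choose between; this mode is best paired with `num_candidates=1`.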

The infilling step uses both **Toxic Sentence** and **Masked Sentence** in the prompt, and is instructed to **only modify `<mask>` tokens** plus small grammar fixes.

#### Global reranking

* `weights = (w_T, w_S, w_F)`:

  * `w_T`: weight on **safety** (1 − toxicity).
  * `w_S`: weight on **semantic similarity**.
  * `w_F`: weight on **fluency**.

* `ppl_min`, `ppl_max`:

  * Bounds for mapping GPT-2 perplexity into a $[0, 1]$ fluency score.
  * Lower perplexity (between `ppl_min` and `ppl_max`) gives higher fluency.

Larger values of `num_candidates` usually improve reranking quality. The cost is roughly proportional extra time for LLM generation and for toxicity, LaBSE, and GPT-2 scoring.

#### Evaluation

* `run_eval`: if `True`, run `evaluation.evaluate_all` after generation and write `gen_stats.txt`.
* `overwrite_eval`:

  * If `False`, keep existing `gen_stats.txt` when present.
  * If `True`, recompute evaluation.
* `skip_ref_eval`: if `True`, skip some reference-based metrics (for example, reference perplexity).

---

## Example Calls

### Quick sanity check on a small subset

```python
detoxify(
    data_type="paradetox",
    output_folder="colab_run_llm_mask_infill_global_demo_50_examples",
    echo=True,
    num_examples=50,           # small subset for testing
    run_eval=True,             # BLEU / BERTScore / MeaningBERT / PPL / Toxicity
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    num_candidates=10,         # LLM candidates per input
    llm_temperature=0.7,
    llm_top_p=0.95,
    llm_max_new_tokens=64,
    llm_sample=True,
    weights=(0.5, 0.3, 0.2),   # safety, similarity, fluency
)
```

### Larger run (more candidates, full dataset)

```python
detoxify(
    data_type="paradetox",
    output_folder="paradetox_llm_mask_infill_global_full",
    echo=True,
    num_examples=None,         # full dataset
    run_eval=True,
    overwrite_gen=False,
    overwrite_eval=False,
    skip_ref_eval=False,
    num_candidates=20,
    llm_temperature=0.7,
    llm_top_p=0.95,
    llm_max_new_tokens=64,
    llm_sample=True,
    weights=(0.5, 0.3, 0.2),
)
```

After running `detoxify`, you can inspect:

* `orig.txt` and `gen.txt` under:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_Global/{run_folder}/
  ```

* Per-run metrics:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_Global/{run_folder}/gen_stats.txt
  ```

* Aggregated metrics:

  ```text
  data/model_outputs/{output_folder}/{data_type}/{data_type}.csv
  ```
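For a quick qualitative check, the two aligned files can be read side by side. This is a minimal sketch. `show_pairs` is a hypothetical helper, and it assumes `orig.txt` and `gen.txt` hold one sentence per line in the same order.

```python
from pathlib import Path

def show_pairs(run_dir: str, k: int = 5) -> list:
    """Hypothetical helper: pair orig.txt and gen.txt line by line and print the first k."""
    run = Path(run_dir)
    orig = (run / "orig.txt").read_text(encoding="utf-8").splitlines()
    gen = (run / "gen.txt").read_text(encoding="utf-8").splitlines()
    if len(orig) != len(gen):
        raise ValueError(f"line count mismatch: {len(orig)} orig vs {len(gen)} gen")
    pairs = list(zip(orig, gen))[:k]
    for o, g in pairs:
        print(f"ORIG: {o}\nGEN:  {g}\n")
    return pairs
```

For example: `show_pairs("data/model_outputs/{output_folder}/{data_type}/LLM_Mask_LLM_Global/{run_folder}")`, with the placeholders filled in.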

This setup lets you directly compare **LLM-masking + LLM-infilling + global reranking** against other XDetox pipelines, such as:

* **DecompX-masking + global reranking**,
* **LLM-masking + MaRCo + global reranking**, and
* **LLM-masking + LLM-infilling + DecompX reranking**.


In [1]:
#@title Mount Drive, Imports & locate XDetox
from google.colab import drive; drive.mount('/content/drive')

import os, glob, re, sys, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE
import torch
import nltk
from typing import List

# Try My Drive
candidate = "/content/drive/MyDrive/w266 - Project/XDetox"
print("Try MyDrive:", candidate, "->", os.path.isdir(candidate))

XDETOX_DIR = candidate
print("Using XDETOX_DIR:", XDETOX_DIR)
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"

Mounted at /content/drive
Try MyDrive: /content/drive/MyDrive/w266 - Project/XDetox -> True
Using XDETOX_DIR: /content/drive/MyDrive/w266 - Project/XDetox


In [2]:
#@title Runtime setup (paths, cache, GPU)
# HuggingFace cache inside the repo (persists on Drive)
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.makedirs(HF_CACHE, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE

# Add repo to PYTHONPATH
if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("XDETOX_DIR:", XDETOX_DIR)
print("TRANSFORMERS_CACHE:", HF_CACHE)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

XDETOX_DIR: /content/drive/MyDrive/w266 - Project/XDetox
TRANSFORMERS_CACHE: /content/drive/MyDrive/w266 - Project/XDetox/cache
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB


In [3]:
#@title Verify XDetox repo layout
for d in ["rewrite", "evaluation", "datasets"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("Repo folders OK.")

Repo folders OK.


In [4]:
#@title Install dependencies (restart runtime if major errors)
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi \
                sentencepiece
!pip -q install bert-score


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.34.0 requires jedi>=0.16, which is not installed.

In [5]:
#@title Imports from transformers / rewrite
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    GPT2LMHeadModel, GPT2TokenizerFast,
)
from rewrite import rewrite_example as rx
import argparse as _argparse



In [6]:
#@title NLTK data
nltk.download("punkt", quiet=True)
try:
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass
print("NLTK ready")

NLTK ready


In [7]:
#@title Data configs
data_configs = {
    "microagressions_val": {
        "data_path": "./datasets/microagressions/val.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "microagressions_test": {
        "data_path": "./datasets/microagressions/test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "sbf_val": {
        "data_path": "./datasets/sbf/sbfdev.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "sbf_test": {
        "data_path": "./datasets/sbf/sbftst.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "dynabench_val": {
        "data_path": "./datasets/dynabench/db_dev.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "dynabench_test": {
        "data_path": "./datasets/dynabench/db_test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "jigsaw_toxic": {
        "data_path": "./datasets/jigsaw_full_30/test_10k_toxic.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "paradetox": {
        "data_path": "./datasets/paradetox/test_toxic_parallel.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_original": {
        "data_path": "./datasets/appdia/original-annotated-data/original-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_discourse": {
        "data_path": "./datasets/appdia/discourse-augmented-data/discourse-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    }
}
print("Datasets:", ", ".join(data_configs.keys()))

REPO = XDETOX_DIR

Datasets: microagressions_val, microagressions_test, sbf_val, sbf_test, dynabench_val, dynabench_test, jigsaw_toxic, paradetox, appdia_original, appdia_discourse


In [8]:
#@title Helpers: subset data
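def _abs_repo_path(rel: str) -> str:
    # Drop a leading "./" only; str.lstrip("./") would strip any run of '.' and '/' characters.
    if rel.startswith("./"):
        rel = rel[2:]
    return os.path.join(REPO, rel)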

def _ensure_dir(p: str):
    Path(p).mkdir(parents=True, exist_ok=True)

def _subset_for_data_type(data_type, data_path, n, out_dir):
    """
    Create a small subset file matching rewrite_example.get_data().
    Returns the new subset path (or original path if n is None/<=0).
    """
    if n is None or n <= 0:
        return data_path

    src = _abs_repo_path(data_path)
    _ensure_dir(out_dir)

    if "microagressions" in data_path:
        df = pd.read_csv(src)
        sub = df.head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        sub.to_csv(out, index=False)
        return out

    if "sbf" in data_path:
        df = pd.read_csv(src)
        sub = df.head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        sub.to_csv(out, index=False)
        return out

    if "dynabench" in data_path:
        df = pd.read_csv(src)
        sub = df.head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        sub.to_csv(out, index=False)
        return out

    if any(k in data_path for k in ["paradetox", "jigsaw"]):
        if data_path.endswith(".txt"):
            with open(src, "r") as f:
                lines = [s.rstrip("\n") for s in f.readlines()]
            out = os.path.join(out_dir, os.path.basename(src))
            with open(out, "w") as g:
                for s in lines[:n]:
                    g.write(s + "\n")
            return out
        elif data_path.endswith(".csv"):
            df = pd.read_csv(src).head(n)
            out = os.path.join(out_dir, os.path.basename(src))
            df.to_csv(out, index=False)
            return out

    if "appdia" in data_path:
        df = pd.read_csv(src, sep="\t").head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        df.to_csv(out, sep="\t", index=False)
        return out

    out = os.path.join(out_dir, os.path.basename(src))
    shutil.copy(src, out)
    return out

In [9]:
#@title Global scoring helpers: toxicity, similarity, fluency

DEVICE_SCORE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Scoring models will use: {DEVICE_SCORE}")

# ---------- Toxicity model (textdetox/xlmr-large-toxicity-classifier-v2) ----------
_TOX_MODEL_NAME = "textdetox/xlmr-large-toxicity-classifier-v2"
_TOX_TOKENIZER = None
_TOX_MODEL = None

def _lazy_load_tox():
    global _TOX_TOKENIZER, _TOX_MODEL
    if _TOX_TOKENIZER is None or _TOX_MODEL is None:
        _TOX_TOKENIZER = AutoTokenizer.from_pretrained(_TOX_MODEL_NAME)
        _TOX_MODEL = AutoModelForSequenceClassification.from_pretrained(
            _TOX_MODEL_NAME
        ).to(DEVICE_SCORE)
        _TOX_MODEL.eval()

@torch.no_grad()
def get_toxicity_scores(texts, batch_size=32):
    """
    Returns toxicity probabilities in [0,1] for each input text.
    (0 = non-toxic, 1 = very toxic)
    """
    _lazy_load_tox()
    scores = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Toxicity", leave=False):
        batch = texts[i:i+batch_size]
        enc = _TOX_TOKENIZER(
            batch, return_tensors="pt",
            truncation=True, max_length=512, padding=True
        ).to(DEVICE_SCORE)
        logits = _TOX_MODEL(**enc).logits
        probs = torch.softmax(logits, dim=-1)  # [..., 2]
        scores.extend(probs[:, 1].detach().cpu().tolist())  # toxic prob
    return scores

# ---------- Semantic similarity (LaBSE) ----------
_LABSE_NAME = "sentence-transformers/LaBSE"
_LABSE_TOKENIZER = None
_LABSE_MODEL = None

def _lazy_load_labse():
    global _LABSE_TOKENIZER, _LABSE_MODEL
    if _LABSE_TOKENIZER is None or _LABSE_MODEL is None:
        _LABSE_TOKENIZER = AutoTokenizer.from_pretrained(_LABSE_NAME)
        _LABSE_MODEL = AutoModel.from_pretrained(_LABSE_NAME).to(DEVICE_SCORE)
        _LABSE_MODEL.eval()

@torch.no_grad()
def get_labse_embeddings(texts, batch_size=32):
    """
    Returns a numpy array of shape (len(texts), hidden_dim).
    Mean-pooled LaBSE sentence embeddings.
    """
    _lazy_load_labse()
    embs = []
    for i in tqdm(range(0, len(texts), batch_size), desc="LaBSE embeddings", leave=False):
        batch = texts[i:i+batch_size]
        enc = _LABSE_TOKENIZER(
            batch, return_tensors="pt",
            truncation=True, max_length=256, padding=True
        ).to(DEVICE_SCORE)
        outputs = _LABSE_MODEL(**enc)
        hidden = outputs.last_hidden_state  # [B, L, H]
        mask = enc["attention_mask"].unsqueeze(-1)  # [B, L, 1]
        masked = hidden * mask
        summed = masked.sum(dim=1)  # [B, H]
        counts = mask.sum(dim=1).clamp(min=1e-6)  # [B,1]
        sent_emb = (summed / counts).cpu().numpy()
        embs.append(sent_emb)
    if not embs:
        return np.zeros((0, 768), dtype=np.float32)
    return np.vstack(embs)

# ---------- Fluency via GPT-2 perplexity ----------
_GPT2_NAME = "gpt2"
_GPT2_TOKENIZER = None
_GPT2_MODEL = None

def _lazy_load_gpt2():
    global _GPT2_TOKENIZER, _GPT2_MODEL
    if _GPT2_TOKENIZER is None or _GPT2_MODEL is None:
        _GPT2_TOKENIZER = GPT2TokenizerFast.from_pretrained(_GPT2_NAME)
        _GPT2_MODEL = GPT2LMHeadModel.from_pretrained(_GPT2_NAME).to(DEVICE_SCORE)
        _GPT2_MODEL.eval()

@torch.no_grad()
def get_gpt2_perplexities(texts):
    """
    Simple sentence-level perplexity using GPT-2.
    Returns a list of floats (one per text).
    """
    import math as _math
    _lazy_load_gpt2()
    ppls = []
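    for s in tqdm(texts, desc="GPT-2 PPL", leave=False):
        # Truncate to GPT-2's 1024-token context window.
        enc = _GPT2_TOKENIZER(s, return_tensors="pt",
                              truncation=True, max_length=1024).to(DEVICE_SCORE)
        if enc["input_ids"].size(1) < 2:
            # Next-token loss needs >= 2 tokens (empty input errors, one token gives NaN).
            ppls.append(1e4)  # treat as maximally disfluent (same cap as below)
            continue
        out = _GPT2_MODEL(enc["input_ids"], labels=enc["input_ids"])
        ppl = _math.exp(out.loss.item())
        if ppl > 1e4:
            ppl = 1e4  # clip extreme
        ppls.append(float(ppl))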
    return ppls

def perplexity_to_fluency(ppls, p_min=5.0, p_max=300.0):
    """
    Map perplexities to [0,1] fluency scores.
    Low perplexity -> high fluency.
    """
    import math as _math
    ppls = np.asarray(ppls, dtype=float)
    p = np.clip(ppls, p_min, p_max)
    log_p = np.log(p)
    log_min = _math.log(p_min)
    log_max = _math.log(p_max)
    F = (log_max - log_p) / (log_max - log_min + 1e-8)
    F = np.clip(F, 0.0, 1.0)
    return F


Scoring models will use: cuda


In [10]:
#@title Evaluation helpers (evaluate_all.py with MeaningBERT + toxicity)
def _eval_with_toxicity(base_path, overwrite_eval=False, skip_ref=False,
                        tox_threshold=0.5, tox_batch_size=32):
    """
    Call evaluation.evaluate_all on each gen folder.
    """
    import sys as _sys
    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path  = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue

        env = os.environ.copy()
        env["PYTHONPATH"] = REPO + (":" + env.get("PYTHONPATH","") if env.get("PYTHONPATH") else "")
        cmd = [
            _sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path",  gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")
        print("Eval:", " ".join(cmd))
        res = run(cmd, cwd=REPO, env=env, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()

def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return float('nan')

def _read_stats_file(path):
    out = {}
    with open(path, "r") as f:
        for line in f:
            if ":" not in line:
                continue
            k, v = line.strip().split(":", 1)  # also tolerates "key:value" without a space
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v)
    return out

def _aggregate_eval_csv(output_folder, data_type, base_out_dir):
    """
    Aggregate eval metrics for LLM-masking + LLM-infilling + global reranking.

    Layout (absolute base_out_dir):
      base_out_dir/
        └── {data_type}/
            └── LLM_Mask_LLM_Global/
                └── {run_folder}/
                    └── gen_stats.txt

    threshold column kept as fixed label (=0.20) for compatibility.
    """
    rows = []

    mask_dir = "LLM_Mask_LLM_Global"
    base_path = os.path.join(base_out_dir, data_type, mask_dir)
    if not os.path.isdir(base_path):
        print("No evaluation directory found:", base_path)
        return

    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        stats_path = os.path.join(gen_dir, "gen_stats.txt")
        if not os.path.exists(stats_path):
            continue
        s = _read_stats_file(stats_path)
        rows.append({
            "threshold":        0.20,  # label only, not used by this pipeline
            "folder":           folder,
            "bertscore":        s.get("bertscore", np.nan),
            "meaningbert":      s.get("meaningbert", np.nan),
            "bleu4":            s.get("bleu4", np.nan),
            "perplexity_gen":   s.get("perplexity gen", np.nan),
            "perplexity_orig":  s.get("perplexity orig", np.nan),
            "toxicity_gen":     s.get("toxicity gen", np.nan),
            "toxicity_orig":    s.get("toxicity orig", np.nan),
        })

    if rows:
        cols = [
            "threshold", "folder",
            "bertscore", "meaningbert", "bleu4",
            "perplexity_gen", "perplexity_orig",
            "toxicity_gen", "toxicity_orig",
        ]
        df = pd.DataFrame(rows)
        df = df[cols]
        out_csv = os.path.join(base_out_dir, data_type, f"{data_type}.csv")
        df.to_csv(out_csv, index=False)
        print("Wrote summary CSV:", out_csv)
    else:
        print("No evaluation files found to summarize.")

In [11]:
#@title Shared LLM loader (Mistral-7B-Instruct) for masking + infilling
USE_LLM_GPU = True
DEVICE_LLM = torch.device("cuda" if USE_LLM_GPU and torch.cuda.is_available() else "cpu")
print("LLM device (mask + infill):", DEVICE_LLM)

LLM_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
_LLM_TOKENIZER = None
_LLM_MODEL = None

def _lazy_load_llm():
    global _LLM_MODEL, _LLM_TOKENIZER
    if _LLM_MODEL is not None and _LLM_TOKENIZER is not None:
        return
    print(f"Loading LLM (Mistral-7B-Instruct): {LLM_MODEL_NAME} on {DEVICE_LLM} ...")
    _LLM_TOKENIZER = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    _LLM_MODEL = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        torch_dtype=torch.float16 if DEVICE_LLM.type == "cuda" else torch.float32,
        device_map=None,
    ).to(DEVICE_LLM)
    _LLM_MODEL.eval()
    print("LLM loaded.")


def _extract_bracket_content(text: str) -> str:
    """
    Extract content inside the first [ ... ] block.
    If there is an opening '[' but no closing ']', take everything after '['.
    Otherwise, fall back to the whole string.
    """
    text = text.strip()

    m = re.search(r"\[([^\]]*)\]", text, flags=re.DOTALL)
    if m:
        return m.group(1).strip()

    if "[" in text:
        return text.split("[", 1)[1].strip()

    return text

LLM device (mask + infill): cuda


In [12]:
#@title LLM Masking (MetaDetox-style) using Mistral-7B-Instruct

MASK_SYSTEM_PROMPT = """You are a helpful assistant trained to make toxic or offensive sentences safer
by MASKING toxic words with the special token <mask>, while keeping the original sentence
structure and wording as much as possible.

You must behave like a MASKER, not a full rewriter.

Your task:
1. Identify toxic, offensive, or profane words or short phrases.
2. For each toxic span, replace the entire span with a single <mask> token.
3. There may be multiple toxic spans in one sentence, so multiple <mask> tokens are allowed.
4. If several neighboring words are toxic, you must still use only a single <mask> token
   in that place. In other words, if you would place "<mask> <mask>" or a longer sequence
   of <mask> tokens, collapse them into a single <mask> so that there are never multiple
   <mask> tokens in a row.
5. Do NOT rewrite, paraphrase, or summarize the sentence.
6. Do NOT add, remove, or reorder non-toxic words or punctuation.
7. Keep punctuation and spacing as close to the original as possible.
8. If there is no toxic content, return the sentence unchanged.

Output rules (format is very strict):
- ONLY return the final masked sentence inside ONE pair of square brackets, like:
  [This is a <mask> example.]
- Do NOT print anything before or after the brackets.
- Do NOT add explanations, comments, or extra lines.
- Do NOT include any language tags or metadata.
- Do NOT include additional '[' or ']' characters inside the sentence.
"""

MASK_FEW_SHOT = """Toxic Sentence: You're such a stupid idiot, nobody wants to hear your crap.
Step 1 - Identify toxic words: "stupid idiot", "crap"
Step 2 - Mask toxic words (do NOT rewrite the rest):
You're such a <mask>, nobody wants to hear your <mask>.
Final Output: [You're such a <mask>, nobody wants to hear your <mask>.]"""

def _postprocess_llm_mask(masked_text: str) -> str:
    """
    Clean up LLM-masked sentences:
      - remove stray leading/trailing brackets,
      - normalize whitespace,
      - normalize <mask> token casing,
      - collapse runs of <mask> into a single <mask>.
    """
    s = masked_text.strip()

    # Remove any leftover outer brackets if still present
    if s.startswith("[") and s.endswith("]") and len(s) > 2:
        s = s[1:-1].strip()
    else:
        if s.startswith("["):
            s = s[1:].strip()
        if s.endswith("]"):
            s = s[:-1].strip()

    # Normalize whitespace
    s = re.sub(r"\s+", " ", s).strip()

    # Normalize mask token casing and spacing: <Mask>, <MASK>, < mask > -> <mask>
    s = re.sub(r"<\s*mask\s*>", "<mask>", s, flags=re.IGNORECASE)

    # Collapse runs of <mask> (e.g., "<mask> <mask> <mask>" -> "<mask>")
    s = re.sub(r"(?:\s*<mask>\s*){2,}", " <mask> ", s)
    s = re.sub(r"\s+", " ", s).strip()

    # Simple safety: if we somehow deleted everything, fall back to the original masked_text
    if not s:
        return masked_text.strip()

    return s

@torch.no_grad()
def llm_mask_sentences(sentences: List[str]) -> List[str]:
    """
    Use Mistral-7B-Instruct as a masker:
    input: toxic sentence
    output: same sentence but toxic words replaced by <mask>.
    """
    _lazy_load_llm()
    masked = []
    for s in tqdm(sentences, desc="LLM masking (Mistral)", leave=False):
        messages = [
            {
                "role": "system",
                "content": MASK_SYSTEM_PROMPT + "\n\nBelow is an example:\n" + MASK_FEW_SHOT,
            },
            {
                "role": "user",
                "content": f"Toxic Sentence: {s}\nFinal Output:",
            },
        ]
        try:
            prompt = _LLM_TOKENIZER.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            # Fallback: plain prompt if the chat template is unavailable or rejects
            # these messages (e.g. templates that only accept user/assistant roles)
            prompt = (
                MASK_SYSTEM_PROMPT
                + "\n\nExample:\n"
                + MASK_FEW_SHOT
                + "\n\nToxic Sentence: "
                + s
                + "\nFinal Output:"
            )

        inputs = _LLM_TOKENIZER(prompt, return_tensors="pt").to(DEVICE_LLM)
        gen = _LLM_MODEL.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,  # greedy decoding; temperature is unused, so it is not passed
            pad_token_id=_LLM_TOKENIZER.eos_token_id,
        )
        # Only decode newly generated tokens
        gen_text = _LLM_TOKENIZER.decode(
            gen[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        # 1) extract bracket content (robust to missing ']')
        masked_text = _extract_bracket_content(gen_text)
        # 2) normalize for downstream processing
        masked_text = _postprocess_llm_mask(masked_text)

        # Ensure we at least return something
        if not masked_text:
            masked_text = s
        masked.append(masked_text)

    return masked
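The normalization in `_postprocess_llm_mask` comes down to a few `re.sub` calls. The sketch below is a standalone copy of just those regex steps, with bracket stripping omitted; `normalize_masks` is a hypothetical name. It shows that casing and spacing variants such as `<MASK>` or `< mask >` are unified and that adjacent masks collapse into one, while non-adjacent masks are kept. There is one side effect: because a collapsed run is replaced by `" <mask> "`, a run directly followed by punctuation gains a space before it. ParaDetox inputs already separate punctuation with spaces (see the logged examples), so this may rarely matter there.

```python
import re

def normalize_masks(s: str) -> str:
    """Regex steps from _postprocess_llm_mask, without bracket stripping."""
    s = re.sub(r"\s+", " ", s).strip()
    # Unify casing/spacing variants: <MASK>, <Mask>, < mask > -> <mask>
    s = re.sub(r"<\s*mask\s*>", "<mask>", s, flags=re.IGNORECASE)
    # Collapse two or more consecutive masks (with any spacing) into one.
    s = re.sub(r"(?:\s*<mask>\s*){2,}", " <mask> ", s)
    return re.sub(r"\s+", " ", s).strip()

print(normalize_masks("You are a <MASK> < mask >   person"))
print(normalize_masks("You are a <mask><Mask>, go away"))   # gains " ," after collapsing
print(normalize_masks("shut your <mask> mouth , <mask> ."))  # separate masks are kept
```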


In [13]:
#@title LLM infilling (Mistral-7B-Instruct) — fills <mask> tokens

INFILL_SYSTEM_PROMPT = """You are a helpful assistant trained to make toxic or offensive sentences
more polite and respectful by INFILLING the special token <mask>.

You are NOT a free rewriter. You must keep all non-masked text as close as possible
to the given masked sentence.

You are given two inputs:
1) Toxic Sentence: the original toxic sentence.
2) Masked Sentence: the same sentence, where toxic spans are replaced with <mask>.

Your task:
1. For each <mask> token in the Masked Sentence, replace it with a short, non-toxic
   word or phrase that fits the context and preserves the meaning of the Toxic Sentence.
2. Do NOT modify any other words or punctuation outside the <mask> spans, unless a very
   small change is needed to fix grammar or agreement.
3. Preserve the original meaning and intent as much as possible, but make the sentence
   safe and respectful.
4. Keep the language the same as the original (do NOT translate).

Output rules (VERY STRICT):
- ONLY return the final detoxified sentence with all <mask> tokens filled.
- Wrap the final sentence in exactly ONE pair of square brackets, e.g.:
  [Detoxified sentence here.]
- Do NOT include the Toxic Sentence or Masked Sentence in your output.
- Do NOT add explanations, comments, or extra lines.
- Do NOT include any other '[' or ']' characters.
"""

INFILL_FEW_SHOT = """Toxic Sentence: You're such a stupid idiot, nobody wants to hear your crap.
Masked Sentence: You're such a <mask>, nobody wants to hear your <mask>.
Step 1 - Decide safe replacements for each <mask>: "rude person", "opinion"
Step 2 - Infill the masked sentence, keeping all other words the same:
You're such a rude person, nobody wants to hear your opinion.
Final Output: [You're such a rude person, nobody wants to hear your opinion.]"""

def _postprocess_llm_infill(text: str) -> str:
    """
    Post-process LLM-infilling output:
      - remove stray outer brackets if still present,
      - normalize whitespace,
      - remove any leftover <mask> tokens.
    """
    s = text.strip()

    if s.startswith("[") and s.endswith("]") and len(s) > 2:
        s = s[1:-1].strip()
    else:
        if s.startswith("["):
            s = s[1:].strip()
        if s.endswith("]"):
            s = s[:-1].strip()

    s = re.sub(r"\s+", " ", s).strip()

    # Remove any leftover <mask> tokens (any casing/spacing) if the model did not fill them
    s = re.sub(r"<\s*mask\s*>", " ", s, flags=re.IGNORECASE)
    s = re.sub(r"\s+", " ", s).strip()

    if not s:
        return text.strip()
    return s

@torch.no_grad()
def llm_infill_candidates(
    sources: List[str],
    masked: List[str],
    num_candidates: int = 3,
    temperature: float = 0.7,
    top_p: float = 0.95,
    max_new_tokens: int = 64,
    sample: bool = True,
) -> List[List[str]]:
    """
    For each (source, masked) pair, generate `num_candidates` detoxified sentences
    by infilling the <mask> tokens using Mistral-7B-Instruct.

    Returns:
        candidates: list[list[str]] with shape [N][num_candidates]
    """
    assert len(sources) == len(masked), "sources and masked length mismatch"
    if num_candidates < 1:
        raise ValueError("num_candidates must be >= 1")

    _lazy_load_llm()
    all_cands: List[List[str]] = []

    for src, msk in tqdm(
        list(zip(sources, masked)),
        desc="LLM infilling (Mistral)",
        leave=False,
    ):
        messages = [
            {
                "role": "system",
                "content": INFILL_SYSTEM_PROMPT + "\n\nHere is an example:\n" + INFILL_FEW_SHOT,
            },
            {
                "role": "user",
                "content": (
                    f"Toxic Sentence: {src}\n"
                    f"Masked Sentence: {msk}\n"
                    f"Final Output:"
                ),
            },
        ]
        try:
            prompt = _LLM_TOKENIZER.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            prompt = (
                INFILL_SYSTEM_PROMPT
                + "\n\nExample:\n"
                + INFILL_FEW_SHOT
                + "\n\nToxic Sentence: "
                + src
                + "\nMasked Sentence: "
                + msk
                + "\nFinal Output:"
            )

        inputs = _LLM_TOKENIZER(prompt, return_tensors="pt").to(DEVICE_LLM)
        input_len = inputs["input_ids"].shape[1]

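        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            pad_token_id=_LLM_TOKENIZER.eos_token_id,
        )
        if sample:
            gen_kwargs.update(
                do_sample=True,
                temperature=float(temperature),
                top_p=top_p,
                num_return_sequences=num_candidates,
            )
        else:
            # Greedy decoding is deterministic and rejects num_return_sequences > 1,
            # so generate a single sequence and reuse it for every candidate slot.
            gen_kwargs.update(do_sample=False, num_return_sequences=1)

        gen = _LLM_MODEL.generate(**inputs, **gen_kwargs)

        cand_list = []
        for idx in range(num_candidates):
            seq = gen[idx] if sample else gen[0]
            gen_text = _LLM_TOKENIZER.decode(
                seq[input_len:],
                skip_special_tokens=True,
            )
            cleaned = _extract_bracket_content(gen_text)
            cleaned = _postprocess_llm_infill(cleaned)
            if not cleaned:
                cleaned = src  # fall back to original if something goes wrong
            cand_list.append(cleaned)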

        all_cands.append(cand_list)

    return all_cands
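The infilling prompt asks the model to leave non-masked text unchanged, but nothing in `llm_infill_candidates` enforces this. One way to check it, sketched below as a hypothetical helper (`extract_fills` is not part of the pipeline), is to turn the masked sentence into a regex: literal segments are escaped and each `<mask>` becomes a lazy capture group. A candidate that matches yields the text chosen for each mask; a candidate that rewrote other words yields `None`. This is stricter than the prompt, which allows small grammar fixes, and it requires every mask to be filled with at least one character, so it is better suited to diagnostics (e.g. measuring how often candidates drift) than to hard filtering.

```python
import re
from typing import List, Optional

def extract_fills(masked: str, candidate: str) -> Optional[List[str]]:
    """Return the span that replaced each <mask>, or None if any non-masked
    text differs (hypothetical consistency check, exact match on literals)."""
    literals = masked.split("<mask>")
    # Escaped literal segments joined by one non-empty lazy group per mask.
    pattern = "(.+?)".join(re.escape(part) for part in literals)
    match = re.fullmatch(pattern, candidate, flags=re.DOTALL)
    return list(match.groups()) if match else None

masked = "You're such a <mask>, nobody wants to hear your <mask>."
print(extract_fills(masked, "You're such a rude person, nobody wants to hear your opinion."))
print(extract_fills(masked, "You are rude and nobody listens."))
```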

In [14]:
#@title Global reranking: combine toxicity, similarity, fluency
def rerank_candidates_global(
    sources,
    candidates,
    weights=(0.5, 0.3, 0.2),
    ppl_min=5.0,
    ppl_max=300.0,
):
    """
    sources: list[str], length N
    candidates: list[list[str]], shape N x C
    weights: (w_T, w_S, w_F)

    Returns:
        best_idx: np.ndarray (N,), index of chosen candidate per source
        details: dict with matrices [N x C] for tox, safety, sim, flu, score
    """
    w_T, w_S, w_F = weights
    N = len(sources)
    assert len(candidates) == N, "candidates length mismatch"

    if N == 0:
        return np.array([], dtype=int), {}

    C_list = [len(c) for c in candidates]
    assert len(set(C_list)) == 1, "All inputs must have same num_candidates"
    C = C_list[0]
    if C == 0:
        raise ValueError("num_candidates must be >= 1")

    # Flatten candidates and map to source indices
    flat_cands = []
    flat_src_idx = []
    for i, cand_list in enumerate(candidates):
        for cand in cand_list:
            flat_cands.append(cand)
            flat_src_idx.append(i)
    flat_src_idx = np.array(flat_src_idx, dtype=int)

    # Toxicity
    tox = np.array(get_toxicity_scores(flat_cands), dtype=float)  # [N*C]

    # Semantic similarity (LaBSE)
    src_embs = get_labse_embeddings(sources)      # [N, D]
    cand_embs = get_labse_embeddings(flat_cands)  # [N*C, D]
    # Normalize
    src_embs = src_embs / np.clip(np.linalg.norm(src_embs, axis=1, keepdims=True), 1e-8, None)
    cand_embs = cand_embs / np.clip(np.linalg.norm(cand_embs, axis=1, keepdims=True), 1e-8, None)
    # Cosine between each candidate and its source
    sims = np.sum(cand_embs * src_embs[flat_src_idx], axis=1)  # [-1,1]
    sims = (sims + 1.0) / 2.0  # -> [0,1]

    # Fluency: GPT-2 PPL -> F in [0,1]
    ppls = np.array(get_gpt2_perplexities(flat_cands), dtype=float)
    flus = perplexity_to_fluency(ppls, p_min=ppl_min, p_max=ppl_max)

    # Safety
    safety = 1.0 - tox

    # Global score
    scores = w_T * safety + w_S * sims + w_F * flus

    # Reshape to [N, C]
    tox2     = tox.reshape(N, C)
    safety2  = safety.reshape(N, C)
    sims2    = sims.reshape(N, C)
    flus2    = flus.reshape(N, C)
    scores2  = scores.reshape(N, C)

    best_idx = scores2.argmax(axis=1)
    details = {
        "tox": tox2,
        "safety": safety2,
        "sim": sims2,
        "flu": flus2,
        "score": scores2,
    }
    return best_idx, details
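To make the scoring concrete, the pure-Python sketch below mirrors `rerank_candidates_global` on made-up metric values. No models are loaded: `tox`, `cos`, and `flu` stand in for the XLM-R toxicity, raw LaBSE cosine, and fluency outputs, and `rerank_flat` is a hypothetical name. It uses the same flattening order (all candidates of source 0, then source 1, ...), the same cosine-to-$[0, 1]$ map, and takes the first maximum per row, as `np.argmax` does.

```python
from typing import List, Sequence, Tuple

def rerank_flat(
    tox: Sequence[float],
    cos: Sequence[float],
    flu: Sequence[float],
    n_sources: int,
    weights: Tuple[float, float, float] = (0.5, 0.3, 0.2),
) -> Tuple[List[int], List[List[float]]]:
    """Score flat N*C candidates and return (best index per source, N x C scores)."""
    w_T, w_S, w_F = weights
    flat = [
        w_T * (1.0 - t)            # safety T' = 1 - T
        + w_S * (c + 1.0) / 2.0    # cosine in [-1, 1] mapped to [0, 1]
        + w_F * f                  # fluency already in [0, 1]
        for t, c, f in zip(tox, cos, flu)
    ]
    n_cands = len(flat) // n_sources
    rows = [flat[i * n_cands:(i + 1) * n_cands] for i in range(n_sources)]
    # max() returns the first index on ties, like np.argmax.
    best = [max(range(n_cands), key=row.__getitem__) for row in rows]
    return best, rows

# Two sources x three candidates, flattened source by source (made-up values).
tox = [0.90, 0.05, 0.02,   0.10, 0.10, 0.10]
cos = [0.95, 0.80, 0.20,   0.50, 0.90, 0.70]
flu = [0.80, 0.70, 0.90,   0.50, 0.50, 0.50]
best, rows = rerank_flat(tox, cos, flu, n_sources=2)
print(best)
print([[round(s, 4) for s in row] for row in rows])
```

In the first row, the most similar candidate loses because it is highly toxic, and the safest candidate loses to one that is slightly more toxic but much closer in meaning. Note also that the map $(\cos + 1)/2$ halves differences in raw cosine, so a 0.1 cosine gap moves the score by only $0.3 \times 0.05 = 0.015$ at $w_S = 0.3$.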

In [15]:
#@title Helpers for folder naming
def _bool2str(x: bool) -> str:
    return "T" if x else "F"

def _build_run_folder_name(
    llm_temperature: float,
    llm_top_p: float,
    llm_sample: bool,
    num_candidates: int,
    max_new_tokens: int,
):
    """
    Build a folder name encoding LLM hyperparameters.
    """
    return (
        f"llmtemp{llm_temperature}_topp{llm_top_p}_"
        f"sample{_bool2str(llm_sample)}_"
        f"nc{num_candidates}_"
        f"maxntok{max_new_tokens}"
    )
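The run folder name is built with plain f-string formatting. It therefore reproduces the folder seen in the logged run, but numerically equal settings can map to different folders (`1` vs `1.0`). The sketch restates the format string from `_build_run_folder_name` as a hypothetical standalone function (`run_folder_name`) and checks both points. Formatting floats with `:g` would be one way to normalize them, at the cost of renaming some existing folders.

```python
def run_folder_name(temp, top_p, sample, num_candidates, max_new_tokens) -> str:
    """Same format string as _build_run_folder_name (hypothetical standalone copy)."""
    return (
        f"llmtemp{temp}_topp{top_p}_"
        f"sample{'T' if sample else 'F'}_"
        f"nc{num_candidates}_"
        f"maxntok{max_new_tokens}"
    )

name = run_folder_name(0.7, 0.95, True, 10, 64)
print(name)

# Numerically equal settings give different folders: str(1) is "1", str(1.0) is "1.0".
print(run_folder_name(1, 0.95, True, 10, 64))
print(run_folder_name(1.0, 0.95, True, 10, 64))

# ":g" formatting maps both to "1".
print(f"{1:g}", f"{1.0:g}", f"{0.7:g}")
```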



In [16]:
#@title Legacy: DecompX masking + LLM infilling + global reranking (per threshold), not called by detoxify()

def _build_llm_gen_folder_name(
    temperature, sample, top_p, max_new_tokens, num_candidates
):
    return (
        "llm"
        "_temp" + str(temperature) +
        "_sample" + _bool2str(sample) +
        "_topp" + str(top_p) +
        "_maxnew" + str(max_new_tokens) +
        "_ncand" + str(num_candidates)
    )

def _run_decompx_masking_and_llm_global_reranking_for_threshold(
    data_type,
    subset_path,
    thresh,
    base_out_rel,
    batch_size_mask,
    num_candidates,
    weights,
    llm_temperature,
    llm_top_p,
    llm_max_new_tokens,
    llm_sample,
    overwrite_gen=False,
    echo: bool = False,
    inputs=None,
):
    """
    For one DecompX threshold:
      - load inputs using rewrite_example.get_data (or reuse pre-loaded `inputs`)
      - mask with DecompX (Masker_single) -> masked_inputs.txt
      - LLM infilling (Mistral) -> num_candidates candidates per input
      - global reranking (toxicity + similarity + fluency)
      - save orig.txt / gen.txt under DecompX{thr}/{run_folder}
    """
    # Load inputs if not provided
    if inputs is None:
        args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
        inputs = rx.get_data(args_data)
    print(f"#inputs at thresh={thresh}: {len(inputs)}")

    # Paths
    mask_dir = f"DecompX{abs(thresh):g}" if thresh != 0 else "DecompX0.0"
    cur_rel = os.path.join(base_out_rel, data_type, mask_dir)
    cur_abs = os.path.join(REPO, cur_rel)
    _ensure_dir(cur_abs)

    masked_file = os.path.join(cur_abs, "masked_inputs.txt")

    # DecompX masking (reuse if exists)
    if not os.path.exists(masked_file):
        masker = Masker_single()
        decoded_masked_inputs_batches = _process_in_batches(
            masker, inputs, batch_size=batch_size_mask, thresh=thresh
        )
        decoded_masked_inputs = [
            item for sublist in decoded_masked_inputs_batches for item in sublist
        ]
        decoded_mask_inputs = [
            d.replace("<s>", "").replace("</s>", "") for d in decoded_masked_inputs
        ]
        with open(masked_file, "w") as f:
            for d in decoded_mask_inputs:
                f.write(re.sub(r"\s+", " ", d).strip() + "\n")
        masker.release_model()
    else:
        with open(masked_file, "r") as f:
            decoded_mask_inputs = [s.strip() for s in f.readlines()]
        print("Reusing existing masked_inputs.txt")

    assert len(decoded_mask_inputs) == len(inputs), "Masked vs inputs mismatch"

    if echo:
        print("\n[echo] Example masked inputs at threshold "
              f"{thresh:.2f} (first up to 3):")
        for i, m in enumerate(decoded_mask_inputs[:3]):
            print(f"  masked[{i}]: {m}")

    # Build generation folder name for LLM
    gen_folder = _build_llm_gen_folder_name(
        temperature=llm_temperature,
        sample=llm_sample,
        top_p=llm_top_p,
        max_new_tokens=llm_max_new_tokens,
        num_candidates=num_candidates,
    )
    final_abs = os.path.join(cur_abs, gen_folder)
    gen_txt = os.path.join(final_abs, "gen.txt")
    orig_txt = os.path.join(final_abs, "orig.txt")

    if os.path.exists(gen_txt) and not overwrite_gen:
        print("Generation already exists at:", gen_txt, "— skipping generation.")
        return

    _ensure_dir(final_abs)

    # LLM infilling: generate num_candidates per input
    print(f"LLM infilling: {num_candidates} candidates per input (sampling={llm_sample})")
    all_candidates = llm_infill_candidates(
        sources=inputs,
        masked=decoded_mask_inputs,
        num_candidates=num_candidates,
        temperature=llm_temperature,
        top_p=llm_top_p,
        max_new_tokens=llm_max_new_tokens,
        sample=llm_sample,
    )

    # Free LLM model before heavy scoring
    global _LLM_MODEL, _LLM_TOKENIZER
    try:
        del _LLM_MODEL
        del _LLM_TOKENIZER
    except Exception:
        pass
    _LLM_MODEL = None
    _LLM_TOKENIZER = None
    if torch.cuda.is_available() and DEVICE_LLM.type == "cuda":
        torch.cuda.empty_cache()

    # Global reranking
    print("Global reranking (toxicity + similarity + fluency)...")
    best_idx, details = rerank_candidates_global(
        sources=inputs,
        candidates=all_candidates,
        weights=weights,
    )
    best_generations = [
        all_candidates[i][best_idx[i]] for i in range(len(inputs))
    ]

    if echo:
        print("\n[echo] Example detoxified outputs at threshold "
              f"{thresh:.2f} (first up to 3):")
        for i, g in enumerate(best_generations[:3]):
            print(f"  detox[{i}]: {g}")

    # Save orig + chosen gen
    with open(orig_txt, "w") as f:
        for l in inputs:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")
    with open(gen_txt, "w") as f:
        for l in best_generations:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")

    print("Saved:", orig_txt)
    print("Saved:", gen_txt)
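In the legacy per-threshold helper, the mask directory name uses `abs(thresh)` with `:g` formatting. The sketch below restates that expression as a hypothetical standalone function (`decompx_dir`). It shows that a threshold and its negation share one directory, and that zero is spelled `DecompX0.0` while other values drop trailing zeros. This only matters if that helper is ever called with signed thresholds.

```python
def decompx_dir(thresh: float) -> str:
    """Same expression as mask_dir in the per-threshold helper (standalone copy)."""
    return f"DecompX{abs(thresh):g}" if thresh != 0 else "DecompX0.0"

for t in (0.2, -0.2, 0.25, 0, 1.0):
    print(t, decompx_dir(t))
```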


In [17]:
#@title `detoxify()` — LLM masking + LLM infilling + global reranking + optional eval

def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_llm_mask_infill_global",
    echo: bool = False,
    num_examples: int = 100,       # None = full dataset
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    # LLM infilling hyperparameters
    num_candidates: int = 3,
    llm_temperature: float = 0.7,
    llm_top_p: float = 0.95,
    llm_max_new_tokens: int = 64,
    llm_sample: bool = True,
    # global reranking
    weights=(0.5, 0.3, 0.2),       # (w_T, w_S, w_F)
):
    """
    Run XDetox with:
      - LLM masking (Mistral-7B-Instruct),
      - LLM infilling (Mistral-7B-Instruct),
      - global reranking based on:
          - toxicity (XLM-R),
          - semantic similarity (LaBSE),
          - fluency (GPT-2 perplexity -> [0,1]),
      - optional evaluation via evaluation/evaluate_all.py (when run_eval=True).

    DecompX is not used by this function.
    """
    assert data_type in data_configs, f"Unknown data_type: {data_type}"
    cfg = data_configs[data_type].copy()

    if num_candidates < 1:
        raise ValueError("num_candidates must be >= 1")

    base_out_rel = os.path.join("data", "model_outputs", output_folder)
    base_out_abs = os.path.join(REPO, base_out_rel)
    _ensure_dir(base_out_abs)

    # subset path (file)
    original_data_path = cfg["data_path"]
    subset_dir = os.path.join(REPO, "datasets", "_subsets", data_type)
    _ensure_dir(subset_dir)
    subset_path = _subset_for_data_type(
        data_type, original_data_path, num_examples, subset_dir
    )

    # Load inputs
    args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
    inputs = rx.get_data(args_data)
    num_inputs = len(inputs)

    if echo:
        print("=" * 80)
        print(f"[echo] Dataset: {data_type}")
        print(f"[echo] Subset path: {subset_path}")
        print(f"[echo] Output base: {base_out_abs}")
        print(f"[echo] Number of examples to detoxify: {num_inputs}")
        print(f"[echo] Weights (w_T, w_S, w_F): {weights}")
        print(f"[echo] num_candidates per input: {num_candidates}")
        print("\n[echo] Example inputs (first up to 3):")
        for i, s in enumerate(inputs[:3]):
            print(f"  input[{i}]: {s}")
        print("=" * 80)

    # Directory for this pipeline
    mask_dir = "LLM_Mask_LLM_Global"
    cur_rel = os.path.join(base_out_rel, data_type, mask_dir)
    cur_abs = os.path.join(REPO, cur_rel)
    _ensure_dir(cur_abs)

    masked_file = os.path.join(cur_abs, "masked_inputs.txt")

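    # Step 1: LLM masking (reuse cached file if available). The cache is not
    # invalidated by num_examples or overwrite_gen; delete masked_inputs.txt to re-mask.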
    if not os.path.exists(masked_file):
        print("Running LLM masking (Mistral) to create masked_inputs.txt ...")
        masked_inputs = llm_mask_sentences(inputs)
        masked_inputs = [re.sub(r"\s+", " ", d).strip() for d in masked_inputs]
        with open(masked_file, "w") as f:
            for d in masked_inputs:
                f.write(d + "\n")
    else:
        with open(masked_file, "r") as f:
            masked_inputs = [s.strip() for s in f.readlines()]
        print("Reusing existing masked_inputs.txt")

    assert len(masked_inputs) == len(inputs), "Masked vs inputs mismatch"

    if echo:
        print("\n[echo] Example LLM-masked inputs (first up to 3):")
        for i, m in enumerate(masked_inputs[:3]):
            print(f"  masked[{i}]: {m}")

    # Build run folder name for this LLM config
    run_folder = _build_run_folder_name(
        llm_temperature=llm_temperature,
        llm_top_p=llm_top_p,
        llm_sample=llm_sample,
        num_candidates=num_candidates,
        max_new_tokens=llm_max_new_tokens,
    )
    final_abs = os.path.join(cur_abs, run_folder)
    _ensure_dir(final_abs)
    orig_txt = os.path.join(final_abs, "orig.txt")
    gen_txt = os.path.join(final_abs, "gen.txt")

    # If outputs already exist and we are not overwriting, just load them
    if os.path.exists(gen_txt) and not overwrite_gen:
        print("Generation already exists at:", gen_txt, "— skipping generation.")
        with open(gen_txt, "r") as f:
            best_generations = [s.strip() for s in f.readlines()]
        if echo:
            print("\n[echo] Example detoxified outputs (first up to 3):")
            for i in range(min(3, len(best_generations))):
                print(f"  detox[{i}]: {best_generations[i]}")
    else:
        # Step 2: LLM infilling
        print(f"LLM infilling: generating {num_candidates} candidates per input (sampling={llm_sample})")
        all_candidates = llm_infill_candidates(
            sources=inputs,
            masked=masked_inputs,
            num_candidates=num_candidates,
            temperature=llm_temperature,
            top_p=llm_top_p,
            max_new_tokens=llm_max_new_tokens,
            sample=llm_sample,
        )

        # Free LLM before loading heavy scoring models
        global _LLM_MODEL, _LLM_TOKENIZER
        try:
            del _LLM_MODEL
            del _LLM_TOKENIZER
        except Exception:
            pass
        _LLM_MODEL = None
        _LLM_TOKENIZER = None
        if torch.cuda.is_available() and DEVICE_LLM.type == "cuda":
            torch.cuda.empty_cache()

        # Step 3: global reranking
        print("Global reranking (toxicity + similarity + fluency)...")
        best_idx, details = rerank_candidates_global(
            sources=inputs,
            candidates=all_candidates,
            weights=weights,
        )
        best_generations = [
            all_candidates[i][best_idx[i]] for i in range(len(inputs))
        ]

        # Save orig + chosen gen
        with open(orig_txt, "w") as f:
            for l in inputs:
                f.write(re.sub(r"\s+", " ", l).strip() + "\n")
        with open(gen_txt, "w") as f:
            for l in best_generations:
                f.write(re.sub(r"\s+", " ", l).strip() + "\n")

        print("Saved:", orig_txt)
        print("Saved:", gen_txt)

        if echo:
            print("\n[echo] Example detoxified outputs (first up to 3):")
            for i in range(min(3, len(best_generations))):
                print(f"  detox[{i}]: {best_generations[i]}")

    # Optional evaluation
    if run_eval:
        base_path = os.path.join(base_out_abs, data_type, mask_dir)
        _eval_with_toxicity(
            base_path,
            overwrite_eval=overwrite_eval,
            skip_ref=skip_ref_eval,
            tox_threshold=0.5,
            tox_batch_size=32,
        )
        _aggregate_eval_csv(
            output_folder,
            data_type,
            os.path.join(REPO, "data", "model_outputs", output_folder),
        )

        # If echo, print metrics for THIS run from gen_stats.txt
        if echo:
            stats_path = os.path.join(final_abs, "gen_stats.txt")
            if os.path.exists(stats_path):
                stats = _read_stats_file(stats_path)
                print("\n[echo] Evaluation metrics for this run:")
                metric_keys = [
                    ("bertscore", "BERTScore"),
                    ("meaningbert", "MeaningBERT"),
                    ("bleu4", "BLEU-4"),
                    ("perplexity gen", "Perplexity (gen)"),
                    ("perplexity orig", "Perplexity (orig)"),
                    ("toxicity gen", "Toxicity (gen)"),
                    ("toxicity orig", "Toxicity (orig)"),
                ]
                for key, label in metric_keys:
                    val = stats.get(key, None)
                    if isinstance(val, float) and math.isnan(val):
                        continue
                    if val is None:
                        continue
                    print(f"  {label}: {val:.4f}")
            else:
                print("\n[echo] gen_stats.txt not found for this run; no metrics to print.")
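`detoxify()` reuses `masked_inputs.txt` whenever the file exists. Its path depends only on `output_folder` and `data_type`, not on `num_examples`, the input text, or `overwrite_gen` (the logged run reuses it even with `overwrite_gen=True`). A changed subset is therefore caught only by the length assertion, and a same-length change is not caught at all. Below is a minimal sketch of a content-keyed cache, assuming a sidecar file `<path>.sha256` is acceptable. `load_or_compute_masks`, the sidecar, and `toy_masker` are hypothetical and not part of the notebook.

```python
import hashlib
import json
import os
import tempfile
from typing import Callable, List

def load_or_compute_masks(
    path: str,
    inputs: List[str],
    mask_fn: Callable[[List[str]], List[str]],
) -> List[str]:
    """Reuse cached masks only if they were computed for exactly these inputs."""
    # JSON-encode the list so the hash is unambiguous, then fingerprint it.
    digest = hashlib.sha256(json.dumps(inputs).encode("utf-8")).hexdigest()
    key_path = path + ".sha256"  # hypothetical sidecar file holding the input hash
    if os.path.exists(path) and os.path.exists(key_path):
        with open(key_path, encoding="utf-8") as f:
            cached_digest = f.read().strip()
        if cached_digest == digest:
            with open(path, encoding="utf-8") as f:
                return [line.rstrip("\n") for line in f]
    # Missing or stale cache: recompute, collapse whitespace, rewrite both files.
    masked = [" ".join(m.split()) for m in mask_fn(inputs)]
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(m + "\n" for m in masked)
    with open(key_path, "w", encoding="utf-8") as f:
        f.write(digest)
    return masked

# Usage with a toy masker standing in for llm_mask_sentences.
calls = []

def toy_masker(sents: List[str]) -> List[str]:
    calls.append(len(sents))
    return [s.replace("idiot", "<mask>") for s in sents]

with tempfile.TemporaryDirectory() as d:
    p = os.path.join(d, "masked_inputs.txt")
    a = load_or_compute_masks(p, ["you idiot"], toy_masker)        # computed
    b = load_or_compute_masks(p, ["you idiot"], toy_masker)        # cache hit
    c = load_or_compute_masks(p, ["you idiot", "hi"], toy_masker)  # inputs changed
print(a, b, c, calls)
```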


In [18]:
#@title Example run — paradetox, LLM masking + LLM infilling + global reranking

# Example: small demo (adjust num_examples as you like)
# detoxify(
#     data_type="paradetox",
#     output_folder="colab_run_llm_mask_infill_global_demo_50_examples",
#     echo=True,
#     num_examples=50,           # small subset for testing
#     run_eval=True,             # BLEU/BERTScore/MeaningBERT/PPL/Toxicity
#     overwrite_gen=True,
#     overwrite_eval=True,
#     skip_ref_eval=False,
#     num_candidates=10,         # LLM candidates per input
#     llm_temperature=0.7,
#     llm_top_p=0.95,
#     llm_max_new_tokens=64,
#     llm_sample=True,
#     weights=(0.5, 0.3, 0.2),   # safety, similarity, fluency
# )
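Before launching a full run, it can help to estimate the scoring workload. The reranker scores every candidate, so N inputs with `num_candidates=C` produce N·C texts for the toxicity, LaBSE, and GPT-2 passes. For the ParaDetox run in the next cell (671 inputs, C = 10), that is 6,710 candidates, matching the `GPT-2 PPL` progress total in the log. The logged batch counts (21 source and 210 candidate LaBSE batches, 210 toxicity batches) are consistent only with a batch size of 32. That value is inferred from the log, since the scoring helpers are defined outside this section. `scoring_workload` is a hypothetical helper.

```python
import math

def scoring_workload(n_inputs: int, num_candidates: int, batch_size: int) -> dict:
    """Texts and batches processed per run (hypothetical helper, not in the notebook)."""
    n_cands = n_inputs * num_candidates
    return {
        "infill_generate_calls": n_inputs,                     # one generate() per input
        "candidates": n_cands,                                 # scored by toxicity, LaBSE, GPT-2
        "source_batches": math.ceil(n_inputs / batch_size),    # LaBSE on the sources
        "candidate_batches": math.ceil(n_cands / batch_size),  # toxicity and LaBSE on candidates
    }

print(scoring_workload(671, 10, 32))
```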

In [19]:
detoxify(
    data_type="paradetox",
    output_folder="XDetox_w_LLM-Masking_LLM-Infilling_Global-Reranking_Pipeline",
    echo=True,
    num_examples=1000,
    run_eval=True,             # BLEU/BERTScore/MeaningBERT/PPL/Toxicity
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    num_candidates=10,         # LLM candidates per input
    llm_temperature=0.7,
    llm_top_p=0.95,
    llm_max_new_tokens=64,
    llm_sample=True,
    weights=(0.5, 0.3, 0.2),   # safety, similarity, fluency
)


[echo] Dataset: paradetox
[echo] Subset path: /content/drive/MyDrive/w266 - Project/XDetox/datasets/_subsets/paradetox/test_toxic_parallel.txt
[echo] Output base: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_LLM-Infilling_Global-Reranking_Pipeline
[echo] Number of examples to detoxify: 671
[echo] Weights (w_T, w_S, w_F): (0.5, 0.3, 0.2)
[echo] num_candidates per input: 10

[echo] Example inputs (first up to 3):
  input[0]: . or the loud ass one - thousand ton beast roaring towards you howling its horn .
  input[1]: mandated  and " right fucking now " would be good .
  input[2]: * neither * of my coworkers gave a shit when it came time to ditch mitch . ugh .
Reusing existing masked_inputs.txt

[echo] Example LLM-masked inputs (first up to 3):
  masked[0]: . or the loud <mask> one - thousand ton beast roaring towards you howling its horn .
  masked[1]: mandated and "<mask> right <mask> now <mask> " would be good .
  masked[2]: neither of my coworke

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

LLM loaded.


LLM infilling (Mistral):   0%|          | 0/671 [00:00<?, ?it/s]

Global reranking (toxicity + similarity + fluency)...


Toxicity:   0%|          | 0/210 [00:00<?, ?it/s]

LaBSE embeddings:   0%|          | 0/21 [00:00<?, ?it/s]

LaBSE embeddings:   0%|          | 0/210 [00:00<?, ?it/s]

GPT-2 PPL:   0%|          | 0/6710 [00:00<?, ?it/s]

Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_LLM-Infilling_Global-Reranking_Pipeline/paradetox/LLM_Mask_LLM_Global/llmtemp0.7_topp0.95_sampleT_nc10_maxntok64/orig.txt
Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_LLM-Infilling_Global-Reranking_Pipeline/paradetox/LLM_Mask_LLM_Global/llmtemp0.7_topp0.95_sampleT_nc10_maxntok64/gen.txt

[echo] Example detoxified outputs (first up to 3):
  detox[0]: or the loud obnoxious one - thousand ton beast roaring towards you howling its horn.
  detox[1]: mandated and "immediately" "right now" would be good.
  detox[2]: neither of my coworkers gave a lackadaisical response when it came time to let Mitch go. ugh.
Eval: /usr/bin/python3 -m evaluation.evaluate_all --orig_path /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_LLM-Infilling_Global-Reranking_Pipeline/paradetox/LLM_Mask_LLM_Global/llmtemp0.7_topp0.95_sampleT_nc10_ma