# XDetox with LLM Masking and Global Reranking

This notebook runs the full XDetox pipeline with:

1. **LLM masking** using Mistral-7B-Instruct (`mistralai/Mistral-7B-Instruct-v0.2`), which detects toxic spans and replaces them with `<mask>`.
2. **MaRCo-style generation** (base / expert / anti-expert BART mixture).
3. **Global reranking** of multiple candidates per input using:
   - **Toxicity** (XLM-R large classifier).
   - **Semantic similarity** (LaBSE).
   - **Fluency** (GPT-2 perplexity).

The goal is to pick, for each toxic input sentence, **one best detoxified candidate** that is:

- As **non-toxic** as possible.
- As **semantically close** as possible to the original.
- As **fluent** as possible.

The main difference from the original DecompX-based pipeline is that **masking is done by an LLM instead of DecompX**, and there is **no DecompX threshold hyperparameter**. The DecompX model is not used in this notebook.

---

## Scoring: Global Reranking

For each candidate $c$, we compute:

- $T(c)$: toxicity in $[0, 1]$ from `textdetox/xlmr-large-toxicity-classifier-v2`.
- $S(c)$: semantic similarity in $[0, 1]$, from LaBSE cosine similarity.
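- $F(c)$: fluency in $[0, 1]$, from GPT-2 perplexity. By default, perplexity is clipped to $[5, 300]$ and mapped linearly on a log scale,  
  so low perplexity gives high fluency (perplexity $\le 5$ → 1, perplexity $\ge 300$ → 0).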

We convert toxicity into a **safety** score:

$$
T'(c) = 1 - T(c)
$$

Then we form a **global score**:

$$
\text{Score}(c) = w_T \cdot T'(c) + w_S \cdot S(c) + w_F \cdot F(c)
$$

You control the weights:

- `weights = (w_T, w_S, w_F)`  
  - `w_T`: importance of **safety** (low toxicity).  
  - `w_S`: importance of **semantic similarity**.  
  - `w_F`: importance of **fluency**.

For each input sentence, we:

1. Generate `num_candidates` candidates.
2. Score each candidate with the formula above.
3. Select the **highest-scoring** candidate as the final output.
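
The selection rule above can be sketched in a few lines. This is a minimal illustration rather than the notebook's own implementation: `Candidate`, `global_score`, and `pick_best` are hypothetical names, and the three scores are assumed to be precomputed values in $[0, 1]$.

```python
from dataclasses import dataclass
from typing import Sequence, Tuple


@dataclass
class Candidate:
    """Hypothetical container for one generated candidate and its scores."""
    text: str
    toxicity: float    # T(c) in [0, 1]
    similarity: float  # S(c) in [0, 1]
    fluency: float     # F(c) in [0, 1]


def global_score(c: Candidate, weights: Tuple[float, float, float]) -> float:
    """Score(c) = w_T * (1 - T(c)) + w_S * S(c) + w_F * F(c)."""
    w_t, w_s, w_f = weights
    return w_t * (1.0 - c.toxicity) + w_s * c.similarity + w_f * c.fluency


def pick_best(
    candidates: Sequence[Candidate],
    weights: Tuple[float, float, float] = (0.5, 0.3, 0.2),
) -> Candidate:
    """Return the candidate with the highest global score."""
    return max(candidates, key=lambda c: global_score(c, weights))
```

Because the score is a weighted sum, only the ratios between the weights affect the ranking: `(0.5, 0.3, 0.2)` and `(5, 3, 2)` select the same candidate.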

---

## LLM Masking (Mistral-7B-Instruct)

### Prompted masking behavior

Masking is done by a chat-style LLM:

- Model: `mistralai/Mistral-7B-Instruct-v0.2`.
- The LLM is instructed to:
  - **Identify toxic, offensive, or profane words or short phrases**.
  - Replace **each toxic span** with a **single `<mask>` token**.
  - Allow **multiple `<mask>` tokens** in one sentence (one per toxic span).
  - If multiple neighboring words are toxic, **collapse them into one `<mask>`**  
    (so it does not produce `<mask> <mask> ...` sequences).
  - Keep **all non-toxic words and punctuation in place**.
  - **Not** paraphrase, summarize, or otherwise rewrite the sentence.
  - Return the masked sentence **inside exactly one pair of brackets**:

    ```text
    [This is a <mask> example.]
    ```

### Post-processing of LLM masks

Because LLM output is noisy, the notebook cleans the raw output before feeding it into MaRCo. For each LLM output, we:

1. **Extract the bracket content**:

   - Try to read the first `[ ... ]` block.
   - If there is `[` but no `]`, take everything after the first `[` as the sentence.
   - If no brackets exist, fall back to the full string.

2. **Strip stray outer brackets** that might still remain.

3. **Normalize whitespace** (collapse multiple spaces).

4. **Normalize `<mask>` tokens**:

   - Any variant like `<Mask>`, `<MASK>`, `< mask >` is normalized to `<mask>`.

5. **Collapse runs of `<mask>`**:

   - Any sequence like `<mask> <mask> <mask>` becomes a single `<mask>`.
   - This avoids the earlier failure mode where the LLM produced many `<mask>` tokens in a row, which confused the generator.

6. If, after cleaning, the sentence becomes empty, fall back to the uncleaned LLM output; if that is also empty, fall back to the original input sentence.

All cleaned, LLM-masked sentences are saved to:

- `data/model_outputs/{output_folder}/{data_type}/LLM_Masking/masked_inputs.txt`

and reused across later runs with the same `output_folder` and `data_type`.
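
The cleaning steps above can be condensed into a single function. The following is a simplified sketch using only the standard library; `clean_llm_mask` is a hypothetical name, and the notebook's own helpers split this logic across two functions and may differ in edge cases.

```python
import re


def clean_llm_mask(raw: str) -> str:
    """Condensed sketch of the mask-cleaning steps (hypothetical helper)."""
    text = raw.strip()

    # 1. Prefer the first well-formed [ ... ] block; if only '[' is present,
    #    keep everything after it; otherwise keep the full string.
    match = re.search(r"\[([^\]]*)\]", text, flags=re.DOTALL)
    if match:
        text = match.group(1)
    elif "[" in text:
        text = text.split("[", 1)[1]

    # 2. Strip stray outer brackets that may remain.
    text = text.strip().lstrip("[").rstrip("]")

    # 4. Normalize variants such as <Mask>, <MASK>, < mask > to <mask>.
    text = re.sub(r"<\s*mask\s*>", "<mask>", text, flags=re.IGNORECASE)

    # 5. Collapse runs of <mask> into a single <mask>.
    text = re.sub(r"(?:\s*<mask>\s*){2,}", " <mask> ", text)

    # 3. Normalize whitespace (done last so the collapse leaves no double spaces).
    text = re.sub(r"\s+", " ", text).strip()

    # 6. If cleaning removed everything, fall back to the raw text.
    return text or raw.strip()


print(clean_llm_mask("Final Output: [You are a <MASK> < mask > person.]"))
```

With this input, the two adjacent mask variants are normalized and merged, so the function returns a sentence containing a single `<mask>`.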

---

## Generation (MaRCo: BART base / expert / anti-expert)

For each dataset:

1. **Subset selection**

   - The script can run on the full dataset or only on the first `num_examples` rows.
   - A subset file is written under:

     ```text
     datasets/_subsets/{data_type}/
     ```

2. **LLM masking**

   - Input toxic sentences are masked by Mistral-7B-Instruct as described above.
   - Cleaned masked sentences are written to:

     ```text
     data/model_outputs/{output_folder}/{data_type}/LLM_Masking/masked_inputs.txt
     ```

3. **MaRCo generation**

   - We use `rewrite.generation.Infiller` with:
     - Base model: `facebook/bart-base`,
     - Anti-expert (toxic) model: `hallisky/bart-base-toxic-antiexpert`,
     - Expert (non-toxic) model: `hallisky/bart-base-nontoxic-expert`.
   - Generation is controlled by:
     - `alpha_a`, `alpha_e`, `alpha_b`: anti-expert, expert, base weights,
     - `temperature`,
     - `top_k_gen`, `top_p`, `filter_p`,
     - `rep_penalty`,
     - `max_length`,
     - `sample` (sampling vs greedy decoding).
   - For each input sentence, the notebook **samples `num_candidates` candidates**.

4. **Global reranking**

   - For each input, all candidates are scored with the **global score** using:
     - XLM-R toxicity classifier,
     - LaBSE similarity,
     - GPT-2 perplexity.
   - The candidate with the **highest global score** is chosen.
   - For each run folder, we write:
     - `orig.txt` — original toxic inputs (one per line),
     - `gen.txt` — chosen detoxified outputs (one per line).

Outputs are stored under:

```text
data/model_outputs/{output_folder}/{data_type}/LLM_Masking/{run_folder}/
```

where `{run_folder}` encodes model and decoding hyperparameters (alphas, temperature, top-k, top-p, etc.).
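
The evaluation helper in this notebook recognizes run folders with a regular expression over these fields. A sketch of the same pattern with named groups is shown below, so that the hyperparameters can be read back from a folder name. `parse_run_folder` is a hypothetical helper, and the folder name in the usage line is made up for illustration (real folders contain the actual model identifiers in the `base`, `anti`, and `expert` slots).

```python
import re
from typing import Dict, Optional

# Same field order as the notebook's run-folder pattern, with named groups.
RUN_FOLDER_RE = re.compile(
    r"aa(?P<alpha_a>\d+\.\d+)_ae(?P<alpha_e>\d+\.\d+)_ab(?P<alpha_b>\d+\.\d+)"
    r"_base(?P<base>.*?)_anti(?P<anti>.*?)_expert(?P<expert>.*?)"
    r"_temp(?P<temperature>\d+\.\d+)_sample(?P<sample>.*?)_topk(?P<top_k>\d+)"
    r"_reppenalty(?P<rep_penalty>\d+\.\d+)_filterp(?P<filter_p>\d+\.\d+)"
    r"_maxlength(?P<max_length>\d+)_topp(?P<top_p>\d+\.\d+)"
)


def parse_run_folder(name: str) -> Optional[Dict[str, str]]:
    """Return the encoded hyperparameters as strings, or None if the name does not match."""
    match = RUN_FOLDER_RE.match(name)
    return match.groupdict() if match else None


# Illustrative folder name only; the model slots are placeholders.
example = (
    "aa1.50_ae4.75_ab1.00_baseB_antiA_expertE_temp2.50_sampleTrue"
    "_topk50_reppenalty1.00_filterp1.00_maxlength96_topp0.95"
)
print(parse_run_folder(example))
```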

---

## Evaluation

If `run_eval=True`, the pipeline calls `evaluation.evaluate_all` to compute:

* BERTScore (F1)
* MeaningBERT
* BLEU-4
* Toxicity (orig / gen) using XLM-R
* Perplexity (orig / gen) using GPT-2

For each run folder, it writes:

* `gen_stats.txt` under:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Masking/{run_folder}/
  ```

It also creates a **summary CSV per dataset**:

* `data/model_outputs/{output_folder}/{data_type}/{data_type}.csv`

The CSV aggregates metrics over all run folders in `LLM_Masking/`.
The `threshold` column is kept as a **fixed label (`0.20`)** for compatibility, but it does **not** control masking here (there is no DecompX threshold).
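
To compare runs, the summary CSV can be loaded and sorted by one of its metric columns. The sketch below uses only the standard library; `load_summary` is a hypothetical helper, and it assumes the column names written by the aggregation step (for example `toxicity_gen`).

```python
import csv
import math
from typing import Dict, List, Tuple


def load_summary(csv_path: str, sort_by: str = "toxicity_gen") -> List[Dict[str, str]]:
    """Read a summary CSV and return its rows sorted ascending by one metric.

    Rows whose metric is missing, empty, or NaN are placed last.
    """
    def sort_key(row: Dict[str, str]) -> Tuple[bool, float]:
        try:
            value = float(row[sort_by])
        except (KeyError, TypeError, ValueError):
            value = float("nan")
        return (math.isnan(value), value)

    with open(csv_path, newline="", encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    return sorted(rows, key=sort_key)
```

Calling `load_summary("data/model_outputs/{output_folder}/{data_type}/{data_type}.csv")` with the placeholders filled in would list the run folders from lowest to highest output toxicity.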

---

## How to Use `detoxify()`

Function signature:

```python
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_llm_mask",
    echo: bool = False,
    batch_size: int = 10,
    sample: bool = True,
    top_k_gen: int = 50,
    top_p: float = 0.95,
    filter_p: float = 1.0,
    max_length: int = 128,
    alpha_a: float = None,
    alpha_e: float = None,
    alpha_b: float = 1.0,
    temperature: float = None,
    rep_penalty: float = None,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    # global reranking:
    weights = (0.5, 0.3, 0.2),
    num_candidates: int = 3,
)
```
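
Several arguments (`alpha_a`, `alpha_e`, `temperature`, `rep_penalty`) default to `None` and are then filled from `data_configs[data_type]`. How `detoxify()` does this internally is not shown here; the sketch below is one plausible way to express it, with `resolve_param` as a hypothetical helper and the `paradetox` values copied from `data_configs`.

```python
from typing import Any, Dict, Optional


def resolve_param(name: str, value: Optional[Any], config: Dict[str, Any]) -> Any:
    """Return the explicit value if one was passed, else the dataset default."""
    return config[name] if value is None else value


# Dataset defaults for "paradetox", as listed in data_configs.
paradetox_cfg = {
    "rep_penalty": 1.0,
    "alpha_a": 1.5,
    "alpha_e": 4.75,
    "temperature": 2.5,
}

alpha_e = resolve_param("alpha_e", None, paradetox_cfg)          # uses the dataset default
temperature = resolve_param("temperature", 1.0, paradetox_cfg)   # explicit value wins
```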

### Key arguments

#### Core I/O

* `data_type`: dataset key from `data_configs`, for example:

  * `"paradetox"`, `"dynabench_val"`, `"jigsaw_toxic"`,
  * `"microagressions_val"`, `"sbf_val"`, `"appdia_original"`, etc.

* `output_folder`: top-level directory under `data/model_outputs/` where results are stored:

  ```text
  data/model_outputs/{output_folder}/{data_type}/...
  ```

* `num_examples`: if set, only the first `num_examples` examples are used (for quick tests).
  Use `None` to run on the full dataset.

#### LLM masking (Mistral)

* Masking uses **Mistral-7B-Instruct** with a fixed system prompt and one worked example.
* The LLM decides which spans to mask; there is **no numeric threshold**.
* The notebook **caches** LLM-masked sentences to `masked_inputs.txt` and reuses them if they exist.

You do not pass masking options directly into `detoxify()`.
The masking behavior is controlled by the fixed prompt and post-processing code.
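
The caching behavior can be pictured as a load-or-compute step. The sketch below is an assumption about its shape, not the notebook's code: `load_or_compute_masks` is a hypothetical helper, `mask_fn` stands in for the LLM masker, and the line-count check is an added safeguard that the notebook may not perform.

```python
import os
from typing import Callable, List


def load_or_compute_masks(
    cache_path: str,
    sentences: List[str],
    mask_fn: Callable[[List[str]], List[str]],
) -> List[str]:
    """Reuse cached masked sentences when available; otherwise compute and save them."""
    if os.path.exists(cache_path):
        with open(cache_path, "r", encoding="utf-8") as f:
            cached = [line.rstrip("\n") for line in f]
        # Assumed safeguard: only trust the cache if it has one line per input.
        if len(cached) == len(sentences):
            return cached

    masked = mask_fn(sentences)
    os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
    with open(cache_path, "w", encoding="utf-8") as f:
        for s in masked:
            # Keep the file strictly one sentence per line.
            f.write(s.replace("\n", " ") + "\n")
    return masked
```

One consequence of caching is that changing the prompt or the post-processing has no effect until the cached `masked_inputs.txt` is removed or a different `output_folder` is used.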

#### Generation (MaRCo / BART)

* `sample`:

  * `True`: stochastic sampling.
  * `False`: greedy decoding.
* `top_k_gen`: top-k on ensembled logits for sampling.
* `top_p`: nucleus sampling on ensembled logits.
* `filter_p`: nucleus filter on **base** logits (advanced; often `1.0`).
* `max_length`: maximum generation length (in tokens).
* `alpha_a`, `alpha_e`, `alpha_b`:

  * Anti-expert, expert, and base weights for MaRCo.
  * If `None`, defaults come from `data_configs[data_type]`.
* `temperature`: sampling temperature; if `None`, uses dataset default.
* `rep_penalty`: repetition penalty; if `None`, uses dataset default.
* `batch_size`: generation batch size.
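
To give a feel for what the three alpha weights do, the sketch below combines next-token logits in the product-of-experts style described for MaRCo: the expert is added and the anti-expert is subtracted. This is an assumption for illustration only. The exact formula and the order in which `Infiller` applies temperature, `top_k_gen`, `top_p`, and `filter_p` are not shown in this notebook, and `mix_logits` is a hypothetical name.

```python
import math
from typing import List, Sequence


def mix_logits(
    base: Sequence[float],
    expert: Sequence[float],
    anti: Sequence[float],
    alpha_b: float = 1.0,
    alpha_e: float = 4.75,
    alpha_a: float = 1.5,
    temperature: float = 2.5,
) -> List[float]:
    """Assumed mixture: softmax((alpha_b*base + alpha_e*expert - alpha_a*anti) / temperature)."""
    mixed = [alpha_b * b + alpha_e * e - alpha_a * a for b, e, a in zip(base, expert, anti)]
    scaled = [z / temperature for z in mixed]
    # Numerically stable softmax over the toy vocabulary.
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [x / total for x in exps]


# Toy 3-token vocabulary: the expert prefers token 0, the anti-expert prefers token 1.
probs = mix_logits(base=[0.0, 0.0, 0.0], expert=[1.0, 0.0, 0.0], anti=[0.0, 1.0, 0.0])
```

Under this assumed formula, raising `alpha_e` pushes probability toward tokens the non-toxic expert likes, while raising `alpha_a` pushes it away from tokens the toxic anti-expert likes.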

#### Global reranking

* `weights = (w_T, w_S, w_F)`:

  * `w_T`: safety (1 − toxicity).
  * `w_S`: semantic similarity.
  * `w_F`: fluency.
* `num_candidates`: how many candidates to generate per input.

  * Larger values give the reranker more candidates to choose from, at proportionally higher generation and scoring cost.

#### Evaluation

* `run_eval`: if `True`, run evaluation and write `gen_stats.txt`.
* `overwrite_gen`:

  * If `False` and `gen.txt` exists, reuse existing generations.
  * If `True`, regenerate even if `gen.txt` exists.
* `overwrite_eval`:

  * If `False` and `gen_stats.txt` exists, keep existing evaluation.
  * If `True`, recompute evaluation.
* `skip_ref_eval`: if `True`, skip some reference-based evaluation (for example, perplexity on gold references).

#### Echo / debugging

* `echo`:

  * If `True`, print:

    * Basic dataset and output information,
    * A few example inputs,
    * A few LLM-masked sentences,
    * A few final detoxified outputs,
    * And (if `run_eval=True`) evaluation metrics for the specific run.

---

## Example Calls

### Quick sanity check on a small subset

```python
detoxify(
    data_type="paradetox",
    output_folder="colab_run_llm_mask_V2_demo_50_examples",
    echo=True,
    batch_size=8,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=50,           # small subset for testing
    run_eval=True,             # BLEU / BERTScore / MeaningBERT / PPL / Toxicity
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    weights=(0.5, 0.3, 0.2),   # safety, similarity, fluency
    num_candidates=10,         # candidates per input
)
```

### Full-dataset run (more candidates, same dataset)

```python
detoxify(
    data_type="paradetox",
    output_folder="paradetox_llm_mask_global",
    echo=True,
    batch_size=8,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=None,         # full dataset
    run_eval=True,
    overwrite_gen=False,
    overwrite_eval=False,
    skip_ref_eval=False,
    weights=(0.5, 0.3, 0.2),
    num_candidates=20,
)
```

After running `detoxify`, you can inspect:

* `orig.txt` and `gen.txt` under:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Masking/{run_folder}/
  ```

* Per-run metrics:

  ```text
  data/model_outputs/{output_folder}/{data_type}/LLM_Masking/{run_folder}/gen_stats.txt
  ```

* Aggregated metrics:

  ```text
  data/model_outputs/{output_folder}/{data_type}/{data_type}.csv
  ```

This setup lets you directly compare:

* The original **DecompX-masking + global reranking pipeline**, and
* The new **LLM-masking + global reranking pipeline**

on the same datasets and with the same global scoring scheme.


In [None]:
#@title Mount Drive, Imports & locate XDetox
from google.colab import drive; drive.mount('/content/drive')

import os, glob, re, sys, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE
import torch
import nltk
from typing import List

# Try My Drive
candidate = "/content/drive/MyDrive/w266 - Project/XDetox"
print("Try MyDrive:", candidate, "->", os.path.isdir(candidate))

XDETOX_DIR = candidate
print("Using XDETOX_DIR:", XDETOX_DIR)
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"

Mounted at /content/drive
Try MyDrive: /content/drive/MyDrive/w266 - Project/XDetox -> True
Using XDETOX_DIR: /content/drive/MyDrive/w266 - Project/XDetox


In [None]:
#@title Runtime setup (paths, cache, GPU)
# HuggingFace cache inside the repo (persists on Drive)
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.makedirs(HF_CACHE, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE

# Add repo to PYTHONPATH
if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("XDETOX_DIR:", XDETOX_DIR)
print("TRANSFORMERS_CACHE:", HF_CACHE)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

XDETOX_DIR: /content/drive/MyDrive/w266 - Project/XDetox
TRANSFORMERS_CACHE: /content/drive/MyDrive/w266 - Project/XDetox/cache
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB


In [None]:
#@title Verify XDetox repo layout
for d in ["rewrite", "evaluation", "datasets"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("Repo folders OK.")


Repo folders OK.


In [None]:
#@title Install dependencies (restart runtime if major errors)
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi \
                sentencepiece
# BERTScore dependency required by evaluation/bertscore.py
!pip -q install bert-score


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.34.0 requires jedi>=0.16, which is not installed.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.3.1 which is incompatible.

In [None]:
#@title Import from 'transformers'
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    GPT2LMHeadModel, GPT2TokenizerFast,
)



In [None]:
#@title Import from 'rewrite'
from rewrite.generation import Infiller
from rewrite import rewrite_example as rx
import argparse as _argparse

In [None]:
#@title NLTK data
nltk.download("punkt", quiet=True)
try:
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass
print("NLTK ready")


NLTK ready


In [None]:
#@title Data configs
data_configs = {
    "microagressions_val": {
        "data_path": "./datasets/microagressions/val.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "microagressions_test": {
        "data_path": "./datasets/microagressions/test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "sbf_val": {
        "data_path": "./datasets/sbf/sbfdev.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "sbf_test": {
        "data_path": "./datasets/sbf/sbftst.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "dynabench_val": {
        "data_path": "./datasets/dynabench/db_dev.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "dynabench_test": {
        "data_path": "./datasets/dynabench/db_test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "jigsaw_toxic": {
        "data_path": "./datasets/jigsaw_full_30/test_10k_toxic.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "paradetox": {
        "data_path": "./datasets/paradetox/test_toxic_parallel.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_original": {
        "data_path": "./datasets/appdia/original-annotated-data/original-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_discourse": {
        "data_path": "./datasets/appdia/discourse-augmented-data/discourse-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    }
}
print("Datasets:", ", ".join(data_configs.keys()))

REPO = XDETOX_DIR


Datasets: microagressions_val, microagressions_test, sbf_val, sbf_test, dynabench_val, dynabench_test, jigsaw_toxic, paradetox, appdia_original, appdia_discourse


In [None]:
#@title Helpers: subset data
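def _abs_repo_path(rel: str) -> str:
    # Drop only a literal leading "./". str.lstrip("./") strips any run of '.' and '/'
    # characters, which would also mangle paths such as "../x" or absolute paths.
    if rel.startswith("./"):
        rel = rel[2:]
    return os.path.join(REPO, rel)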

def _ensure_dir(p: str):
    Path(p).mkdir(parents=True, exist_ok=True)

def _subset_for_data_type(data_type, data_path, n, out_dir):
    """
    Create a small subset file matching the expected format used by rewrite_example.get_data().
    Returns the path to the *new* subset file (or original path if n is None).
    """
    if n is None or n <= 0:
        return data_path

    src = _abs_repo_path(data_path)
    _ensure_dir(out_dir)

    if "microagressions" in data_path:
        df = pd.read_csv(src)
        sub = df.head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        sub.to_csv(out, index=False)
        return out

    if "sbf" in data_path:
        df = pd.read_csv(src)
        sub = df.head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        sub.to_csv(out, index=False)
        return out

    if "dynabench" in data_path:
        df = pd.read_csv(src)
        sub = df.head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        sub.to_csv(out, index=False)
        return out

    if any(k in data_path for k in ["paradetox", "jigsaw"]):
        if data_path.endswith(".txt"):
            with open(src, "r") as f:
                lines = [s.rstrip("\n") for s in f.readlines()]
            out = os.path.join(out_dir, os.path.basename(src))
            with open(out, "w") as g:
                for s in lines[:n]:
                    g.write(s + "\n")
            return out
        elif data_path.endswith(".csv"):
            df = pd.read_csv(src).head(n)
            out = os.path.join(out_dir, os.path.basename(src))
            df.to_csv(out, index=False)
            return out

    if "appdia" in data_path:
        df = pd.read_csv(src, sep="\t").head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        df.to_csv(out, sep="\t", index=False)
        return out

    out = os.path.join(out_dir, os.path.basename(src))
    shutil.copy(src, out)
    return out

In [None]:
#@title Global scoring helpers: toxicity, similarity, fluency

DEVICE_SCORE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Scoring models (tox/sim/flu) will use: {DEVICE_SCORE}")

# ---------- Toxicity model (textdetox/xlmr-large-toxicity-classifier-v2) ----------
_TOX_MODEL_NAME = "textdetox/xlmr-large-toxicity-classifier-v2"
_TOX_TOKENIZER = None
_TOX_MODEL = None

def _lazy_load_tox():
    global _TOX_TOKENIZER, _TOX_MODEL
    if _TOX_TOKENIZER is None or _TOX_MODEL is None:
        _TOX_TOKENIZER = AutoTokenizer.from_pretrained(_TOX_MODEL_NAME)
        _TOX_MODEL = AutoModelForSequenceClassification.from_pretrained(
            _TOX_MODEL_NAME
        ).to(DEVICE_SCORE)
        _TOX_MODEL.eval()

@torch.no_grad()
def get_toxicity_scores(texts, batch_size=32):
    _lazy_load_tox()
    scores = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Toxicity", leave=False):
        batch = texts[i:i+batch_size]
        enc = _TOX_TOKENIZER(
            batch, return_tensors="pt",
            truncation=True, max_length=512, padding=True
        ).to(DEVICE_SCORE)
        logits = _TOX_MODEL(**enc).logits
        probs = torch.softmax(logits, dim=-1)
        scores.extend(probs[:, 1].detach().cpu().tolist())
    return scores

# ---------- Semantic similarity (LaBSE) ----------
_LABSE_NAME = "sentence-transformers/LaBSE"
_LABSE_TOKENIZER = None
_LABSE_MODEL = None

def _lazy_load_labse():
    global _LABSE_TOKENIZER, _LABSE_MODEL
    if _LABSE_TOKENIZER is None or _LABSE_MODEL is None:
        _LABSE_TOKENIZER = AutoTokenizer.from_pretrained(_LABSE_NAME)
        _LABSE_MODEL = AutoModel.from_pretrained(_LABSE_NAME).to(DEVICE_SCORE)
        _LABSE_MODEL.eval()

@torch.no_grad()
def get_labse_embeddings(texts, batch_size=32):
    _lazy_load_labse()
    embs = []
    for i in tqdm(range(0, len(texts), batch_size), desc="LaBSE embeddings", leave=False):
        batch = texts[i:i+batch_size]
        enc = _LABSE_TOKENIZER(
            batch, return_tensors="pt",
            truncation=True, max_length=256, padding=True
        ).to(DEVICE_SCORE)
        outputs = _LABSE_MODEL(**enc)
        hidden = outputs.last_hidden_state
        mask = enc["attention_mask"].unsqueeze(-1)
        masked = hidden * mask
        summed = masked.sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-6)
        sent_emb = (summed / counts).cpu().numpy()
        embs.append(sent_emb)
    if not embs:
        return np.zeros((0, 768), dtype=np.float32)
    return np.vstack(embs)

# ---------- Fluency via GPT-2 perplexity ----------
_GPT2_NAME = "gpt2"
_GPT2_TOKENIZER = None
_GPT2_MODEL = None

def _lazy_load_gpt2():
    global _GPT2_TOKENIZER, _GPT2_MODEL
    if _GPT2_TOKENIZER is None or _GPT2_MODEL is None:
        _GPT2_TOKENIZER = GPT2TokenizerFast.from_pretrained(_GPT2_NAME)
        _GPT2_MODEL = GPT2LMHeadModel.from_pretrained(_GPT2_NAME).to(DEVICE_SCORE)
        _GPT2_MODEL.eval()

@torch.no_grad()
def get_gpt2_perplexities(texts):
    import math as _math
    _lazy_load_gpt2()
    ppls = []
    for s in tqdm(texts, desc="GPT-2 PPL", leave=False):
        enc = _GPT2_TOKENIZER(s, return_tensors="pt").to(DEVICE_SCORE)
        out = _GPT2_MODEL(enc["input_ids"], labels=enc["input_ids"])
        ppl = _math.exp(out.loss.item())
        if ppl > 1e4:
            ppl = 1e4
        ppls.append(float(ppl))
    return ppls

def perplexity_to_fluency(ppls, p_min=5.0, p_max=300.0):
    import math as _math
    ppls = np.asarray(ppls, dtype=float)
    p = np.clip(ppls, p_min, p_max)
    log_p = np.log(p)
    log_min = _math.log(p_min)
    log_max = _math.log(p_max)
    F = (log_max - log_p) / (log_max - log_min + 1e-8)
    F = np.clip(F, 0.0, 1.0)
    return F

Scoring models (tox/sim/flu) will use: cuda


In [None]:
#@title Evaluation helpers (evaluate_all.py with MeaningBERT + toxicity)
def _parse_run_folder_name(folder_name):
    pattern = r"aa(\d+\.\d+)_ae(\d+\.\d+)_ab(\d+\.\d+)_base(.*?)_anti(.*?)_expert(.*?)_temp(\d+\.\d+)_sample(.*?)_topk(\d+)_reppenalty(\d+\.\d+)_filterp(\d+\.\d+)_maxlength(\d+)_topp(\d+\.\d+)"
    m = re.match(pattern, folder_name)
    return bool(m)

def _eval_with_toxicity(base_path, overwrite_eval=False, skip_ref=False, tox_threshold=0.5, tox_batch_size=32):
    import sys as _sys, os as _os
    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir) or not _parse_run_folder_name(folder):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path  = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue

        env = os.environ.copy()
        env["PYTHONPATH"] = REPO + (":" + env.get("PYTHONPATH","") if env.get("PYTHONPATH") else "")
        cmd = [
            _sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path",  gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")
        print("Eval:", " ".join(cmd))
        res = run(cmd, cwd=REPO, env=env, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()

def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return float('nan')

def _read_stats_file(path):
    out = {}
    with open(path, "r") as f:
        for line in f:
            # Require the exact "key: value" separator so the split below cannot fail
            if ": " not in line:
                continue
            k, v = line.strip().split(": ", 1)
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v)
    return out

def _aggregate_eval_csv(output_folder, data_type, base_out_dir):
    """
    Aggregate evaluation metrics for the LLM-masking pipeline.

    Directory layout (absolute base_out_dir):
      base_out_dir/
        └── {data_type}/
            └── LLM_Masking/
                └── {run_folder}/
                    └── gen_stats.txt

    We keep a `threshold` column for compatibility, but it is fixed to 0.20.
    """
    rows = []

    mask_dir = "LLM_Masking"
    base_path = os.path.join(base_out_dir, data_type, mask_dir)
    if not os.path.isdir(base_path):
        print("No evaluation directory found:", base_path)
        return

    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        stats_path = os.path.join(gen_dir, "gen_stats.txt")
        if not os.path.exists(stats_path):
            continue
        s = _read_stats_file(stats_path)
        rows.append({
            "threshold":        0.20,  # fixed label for this LLM pipeline
            "folder":           folder,
            "bertscore":        s.get("bertscore", np.nan),
            "meaningbert":      s.get("meaningbert", np.nan),
            "bleu4":            s.get("bleu4", np.nan),
            "perplexity_gen":   s.get("perplexity gen", np.nan),
            "perplexity_orig":  s.get("perplexity orig", np.nan),
            "toxicity_gen":     s.get("toxicity gen", np.nan),
            "toxicity_orig":    s.get("toxicity orig", np.nan),
        })

    if rows:
        cols = [
            "threshold", "folder",
            "bertscore", "meaningbert", "bleu4",
            "perplexity_gen", "perplexity_orig",
            "toxicity_gen", "toxicity_orig",
        ]
        df = pd.DataFrame(rows)
        df = df[cols]
        out_csv = os.path.join(base_out_dir, data_type, f"{data_type}.csv")
        df.to_csv(out_csv, index=False)
        print("Wrote summary CSV:", out_csv)
    else:
        print("No evaluation files found to summarize.")


In [None]:
#@title  LLM Masking (MetaDetox-style) using Mistral-7B-Instruct
# Mistral runs on the GPU here (USE_LLM_GPU = True), which suits a large GPU such as an A100.
# Set USE_LLM_GPU = False to run it on CPU and avoid GPU OOM when combined with BART/LaBSE/XLM-R/GPT-2.
USE_LLM_GPU = True

DEVICE_LLM = torch.device("cuda" if USE_LLM_GPU and torch.cuda.is_available() else "cpu")
print("LLM masker device:", DEVICE_LLM)

LLM_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
_LLM_TOKENIZER = None
_LLM_MODEL = None

MASK_SYSTEM_PROMPT = """You are a helpful assistant trained to make toxic or offensive sentences safer
by MASKING toxic words with the special token <mask>, while keeping the original sentence
structure and wording as much as possible.

You must behave like a MASKER, not a full rewriter.

Your task:
1. Identify toxic, offensive, or profane words or short phrases.
2. For each toxic span, replace the entire span with a single <mask> token.
3. There may be multiple toxic spans in one sentence, so multiple <mask> tokens are allowed.
4. If several neighboring words are toxic, you must still use only a single <mask> token
   in that place. In other words, if you would place "<mask> <mask>" or a longer sequence
   of <mask> tokens, collapse them into a single <mask> so that there are never multiple
   <mask> tokens in a row.
5. Do NOT rewrite, paraphrase, or summarize the sentence.
6. Do NOT add, remove, or reorder non-toxic words or punctuation.
7. Keep punctuation and spacing as close to the original as possible.
8. If there is no toxic content, return the sentence unchanged.

Output rules (format is very strict):
- ONLY return the final masked sentence inside ONE pair of square brackets, like:
  [This is a <mask> example.]
- Do NOT print anything before or after the brackets.
- Do NOT add explanations, comments, or extra lines.
- Do NOT include any language tags or metadata.
- Do NOT include additional '[' or ']' characters inside the sentence.
"""

MASK_FEW_SHOT = """Toxic Sentence: You're such a stupid idiot, nobody wants to hear your crap.
Step 1 - Identify toxic words: "stupid idiot", "crap"
Step 2 - Mask toxic words (do NOT rewrite the rest):
You're such a <mask>, nobody wants to hear your <mask>.
Final Output: [You're such a <mask>, nobody wants to hear your <mask>.]"""


def _lazy_load_llm_masker():
    global _LLM_MODEL, _LLM_TOKENIZER
    if _LLM_MODEL is not None and _LLM_TOKENIZER is not None:
        return
    print(f"Loading LLM masker: {LLM_MODEL_NAME} on {DEVICE_LLM} ...")
    _LLM_TOKENIZER = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    _LLM_MODEL = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        torch_dtype=torch.float16 if DEVICE_LLM.type == "cuda" else torch.float32,
        device_map=None,
    ).to(DEVICE_LLM)
    _LLM_MODEL.eval()
    print("LLM masker loaded.")

def _extract_bracket_content(text: str) -> str:
    """
    Extract content inside the first [ ... ] block.
    If there is an opening '[' but no closing ']', take everything after '['.
    Otherwise, fall back to the whole string.
    """
    text = text.strip()

    # Case 1: well-formed [ ... ]
    m = re.search(r"\[([^\]]*)\]", text, flags=re.DOTALL)
    if m:
        return m.group(1).strip()

    # Case 2: has '[' but no ']' (truncate everything before '[')
    if "[" in text:
        return text.split("[", 1)[1].strip()

    # Fallback: no brackets at all
    return text

def _postprocess_llm_mask(masked_text: str) -> str:
    """
    Clean up LLM-masked sentences to be MaRCo-friendly:
      - remove stray leading/trailing brackets,
      - normalize whitespace,
      - normalize <mask> token casing,
      - collapse runs of <mask> into a single <mask>.
    """
    s = masked_text.strip()

    # Remove any leftover outer brackets if still present
    if s.startswith("[") and s.endswith("]") and len(s) > 2:
        s = s[1:-1].strip()
    else:
        if s.startswith("["):
            s = s[1:].strip()
        if s.endswith("]"):
            s = s[:-1].strip()

    # Normalize whitespace
    s = re.sub(r"\s+", " ", s).strip()

    # Normalize mask token casing and spacing: <Mask>, <MASK>, < mask > -> <mask>
    s = re.sub(r"<\s*mask\s*>", "<mask>", s, flags=re.IGNORECASE)

    # Collapse runs of <mask> (e.g., "<mask> <mask> <mask>" -> "<mask>")
    s = re.sub(r"(?:\s*<mask>\s*){2,}", " <mask> ", s)
    s = re.sub(r"\s+", " ", s).strip()

    # Simple safety: if we somehow deleted everything, fall back to the original masked_text
    if not s:
        return masked_text.strip()

    return s

@torch.no_grad()
def llm_mask_sentences(sentences: List[str]) -> List[str]:
    """
    Use Mistral-7B-Instruct as a masker:
    input: toxic sentence
    output: same sentence but toxic words replaced by <mask>.
    """
    _lazy_load_llm_masker()
    masked = []
    for s in tqdm(sentences, desc="LLM masking (Mistral)", leave=False):
        messages = [
            {
                "role": "system",
                "content": MASK_SYSTEM_PROMPT + "\n\nBelow is an example:\n" + MASK_FEW_SHOT,
            },
            {
                "role": "user",
                "content": f"Toxic Sentence: {s}\nFinal Output:",
            },
        ]
        try:
            prompt = _LLM_TOKENIZER.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            # Fallback: plain prompt if the chat template is unavailable or
            # rejects these messages (e.g., a template without system-role support)
            prompt = (
                MASK_SYSTEM_PROMPT
                + "\n\nExample:\n"
                + MASK_FEW_SHOT
                + "\n\nToxic Sentence: "
                + s
                + "\nFinal Output:"
            )

        inputs = _LLM_TOKENIZER(prompt, return_tensors="pt").to(DEVICE_LLM)
        gen = _LLM_MODEL.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,  # greedy decoding; temperature is unused here
            pad_token_id=_LLM_TOKENIZER.eos_token_id,
        )
        # Only decode newly generated tokens
        gen_text = _LLM_TOKENIZER.decode(
            gen[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        # 1) extract bracket content (robust to missing ']')
        masked_text = _extract_bracket_content(gen_text)
        # 2) normalize for MaRCo (collapse mask runs, strip stray brackets, etc.)
        masked_text = _postprocess_llm_mask(masked_text)

        # Ensure we at least return something
        if not masked_text:
            masked_text = s
        masked.append(masked_text)

    return masked

LLM masker device: cuda
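
To make the cleanup concrete, here is a minimal sketch that condenses the bracket extraction and mask normalization above into a single function. `toy_extract_and_clean` and the sample strings are illustrative assumptions, not part of the pipeline; the sketch leaves out the stray-bracket stripping and the empty-string fallback for brevity.

```python
import re

def toy_extract_and_clean(raw: str) -> str:
    """Condensed version of the bracket extraction and <mask> cleanup (illustrative)."""
    text = raw.strip()
    m = re.search(r"\[([^\]]*)\]", text, flags=re.DOTALL)
    if m:
        text = m.group(1)                 # well-formed [ ... ]
    elif "[" in text:
        text = text.split("[", 1)[1]      # '[' without a closing ']'
    text = re.sub(r"\s+", " ", text).strip()
    # <Mask>, <MASK>, < mask > -> <mask>
    text = re.sub(r"<\s*mask\s*>", "<mask>", text, flags=re.IGNORECASE)
    # Collapse runs of <mask> into a single token
    text = re.sub(r"(?:\s*<mask>\s*){2,}", " <mask> ", text)
    return re.sub(r"\s+", " ", text).strip()

# Made-up LLM outputs covering the three extraction cases
examples = [
    "[you are a <MASK> <mask> person]",      # well-formed, mixed casing, mask run
    "Final Output: [what a < mask > idea",   # missing closing bracket
    "no brackets, no masks",                 # fallback: returned as-is
]
for raw in examples:
    print(repr(toy_extract_and_clean(raw)))
# 'you are a <mask> person'
# 'what a <mask> idea'
# 'no brackets, no masks'
```

Collapsing mask runs means the infilling model sees one gap per toxic span, even when the LLM masks each word of a phrase separately.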


In [None]:
#@title Global reranking: combine toxicity, similarity, fluency
def rerank_candidates_global(
    sources,
    candidates,
    weights=(0.5, 0.3, 0.2),
    ppl_min=5.0,
    ppl_max=300.0,
):
    """
    sources: list[str], length N
    candidates: list[list[str]], shape N x C
    weights: (w_T, w_S, w_F)
    Returns:
        best_idx: np.ndarray of shape (N,), index of chosen candidate per source
        details: dict with matrices [N x C] for tox, safety, sim, flu, score
    """
    w_T, w_S, w_F = weights
    N = len(sources)
    assert len(candidates) == N, "candidates length mismatch"

    if N == 0:
        return np.array([], dtype=int), {}

    C_list = [len(c) for c in candidates]
    assert len(set(C_list)) == 1, "All inputs must have same num_candidates"
    C = C_list[0]
    if C == 0:
        raise ValueError("num_candidates must be >= 1")

    # Flatten candidates and map to source indices
    flat_cands = []
    flat_src_idx = []
    for i, cand_list in enumerate(candidates):
        for cand in cand_list:
            flat_cands.append(cand)
            flat_src_idx.append(i)
    flat_src_idx = np.array(flat_src_idx, dtype=int)

    # Toxicity
    tox = np.array(get_toxicity_scores(flat_cands), dtype=float)  # [N*C]

    # Semantic similarity (LaBSE)
    src_embs = get_labse_embeddings(sources)  # [N, D]
    cand_embs = get_labse_embeddings(flat_cands)  # [N*C, D]
    # Normalize
    src_embs = src_embs / np.clip(np.linalg.norm(src_embs, axis=1, keepdims=True), 1e-8, None)
    cand_embs = cand_embs / np.clip(np.linalg.norm(cand_embs, axis=1, keepdims=True), 1e-8, None)
    # Cosine between each candidate and its source
    sims = np.sum(cand_embs * src_embs[flat_src_idx], axis=1)  # [-1,1]
    sims = (sims + 1.0) / 2.0  # -> [0,1]

    # Fluency: GPT-2 PPL -> F in [0,1]
    ppls = np.array(get_gpt2_perplexities(flat_cands), dtype=float)
    flus = perplexity_to_fluency(ppls, p_min=ppl_min, p_max=ppl_max)

    # Safety
    safety = 1.0 - tox

    # Global score
    scores = w_T * safety + w_S * sims + w_F * flus

    # Reshape to [N, C]
    tox2     = tox.reshape(N, C)
    safety2  = safety.reshape(N, C)
    sims2    = sims.reshape(N, C)
    flus2    = flus.reshape(N, C)
    scores2  = scores.reshape(N, C)

    best_idx = scores2.argmax(axis=1)
    details = {
        "tox": tox2,
        "safety": safety2,
        "sim": sims2,
        "flu": flus2,
        "score": scores2,
    }
    return best_idx, details
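
As a worked example of the scoring step, the sketch below applies the same weighted sum to hand-picked numbers instead of model outputs. `toy_global_rerank` and all score values are assumptions made up for illustration; only the formula and the default weights come from the cell above.

```python
import numpy as np

def toy_global_rerank(tox, sims, flus, weights=(0.5, 0.3, 0.2)):
    """Score an [N, C] grid of candidates; return the best column per row and all scores."""
    w_T, w_S, w_F = weights
    safety = 1.0 - np.asarray(tox, dtype=float)
    scores = (
        w_T * safety
        + w_S * np.asarray(sims, dtype=float)
        + w_F * np.asarray(flus, dtype=float)
    )
    return scores.argmax(axis=1), scores

# One source sentence with three hypothetical candidates (made-up scores)
tox  = [[0.90, 0.10, 0.05]]   # candidate 0 is still toxic
sims = [[0.95, 0.85, 0.40]]   # candidate 2 drifts away from the source meaning
flus = [[0.80, 0.70, 0.90]]

best_idx, scores = toy_global_rerank(tox, sims, flus)
print(best_idx[0], scores.round(3))
# 1 [[0.495 0.845 0.775]]

# With all weight on safety, the least toxic candidate wins instead
safest_idx, _ = toy_global_rerank(tox, sims, flus, weights=(1.0, 0.0, 0.0))
print(safest_idx[0])
# 2
```

The second call shows why the similarity and fluency terms matter: scoring on safety alone picks the candidate that has lost most of the original meaning.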


In [None]:
#@title Masking + generation with global reranking (using LLM masks)

def _bool2str(x: bool) -> str:
    return "T" if x else "F"

def _build_gen_folder_name(
    alpha_a, alpha_e, alpha_b,
    base_type, antiexpert_type, expert_type,
    temperature, sample, top_k_gen, rep_penalty, filter_p, max_length, top_p
):
    return (
        "aa" + str(alpha_a) +
        "_ae" + str(alpha_e) +
        "_ab" + str(alpha_b) +
        "_base" + base_type[:5] +
        "_anti" + antiexpert_type[:5] +
        "_expert" + expert_type[:5] +
        "_temp" + str(temperature) +
        "_sample" + _bool2str(sample) +
        "_topk" + str(top_k_gen) +
        "_reppenalty" + str(rep_penalty) +
        "_filterp" + str(filter_p) +
        "_maxlength" + str(max_length) +
        "_topp" + str(top_p)
    )

def _run_llm_masking_and_global_reranking_for_threshold(
    data_type,
    subset_path,
    thresh,               # kept for API symmetry, not used in folder naming
    base_out_rel,
    batch_size,
    alpha_a, alpha_e, alpha_b,
    temperature,
    rep_penalty,
    max_length,
    top_k_gen,
    top_p,
    filter_p,
    sample,
    num_candidates,
    weights,
    overwrite_gen=False,
    inputs=None,
):
    """
    Run one LLM-masking + MaRCo + global-reranking pass.

    `thresh` is only a dummy numeric label kept for compatibility; it does NOT
    affect behavior. All outputs go under a single folder: LLM_Masking.
    """
    # Load inputs if not provided
    if inputs is None:
        args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
        inputs = rx.get_data(args_data)
    print(f"#inputs to detoxify: {len(inputs)}")

    # Paths: use LLM_Masking folder (no DecompX thresholds)
    mask_dir = "LLM_Masking"
    cur_rel = os.path.join(base_out_rel, data_type, mask_dir)
    cur_abs = os.path.join(REPO, cur_rel)
    _ensure_dir(cur_abs)

    masked_file = os.path.join(cur_abs, "masked_inputs.txt")

    # LLM masking (reuse cached file if available)
    if not os.path.exists(masked_file):
        print("Running LLM masking (Mistral) to create masked_inputs.txt ...")
        decoded_mask_inputs = llm_mask_sentences(inputs)
        decoded_mask_inputs = [
            re.sub(r"\s+", " ", d).strip() for d in decoded_mask_inputs
        ]
        with open(masked_file, "w") as f:
            for d in decoded_mask_inputs:
                f.write(d + "\n")

        # Optional: free LLM to release memory before loading BART
        global _LLM_MODEL, _LLM_TOKENIZER
        del _LLM_MODEL
        del _LLM_TOKENIZER
        _LLM_MODEL = None
        _LLM_TOKENIZER = None
        if torch.cuda.is_available() and DEVICE_LLM.type == "cuda":
            torch.cuda.empty_cache()
    else:
        with open(masked_file, "r") as f:
            decoded_mask_inputs = [s.strip() for s in f.readlines()]
        print("Reusing existing masked_inputs.txt")

    assert len(decoded_mask_inputs) == len(inputs), "Masked vs inputs mismatch"

    # Resolve the run folder first so cached generations can be reused
    # without loading the BART models.
    base_type = "base"
    antiexpert_type = "antiexpert"
    expert_type = "expert"
    gen_folder = _build_gen_folder_name(
        alpha_a, alpha_e, alpha_b,
        base_type, antiexpert_type, expert_type,
        temperature, sample, top_k_gen, rep_penalty, filter_p, max_length, top_p
    )
    final_abs = os.path.join(cur_abs, gen_folder)
    gen_txt = os.path.join(final_abs, "gen.txt")
    orig_txt = os.path.join(final_abs, "orig.txt")

    # If generation already exists and we are not overwriting, just load outputs
    if os.path.exists(gen_txt) and not overwrite_gen:
        print("Generation already exists at:", gen_txt, "; skipping generation.")
        if not os.path.exists(orig_txt):
            with open(orig_txt, "w") as f:
                for l in inputs:
                    f.write(re.sub(r"\s+", " ", l).strip() + "\n")

        # Masked inputs were already created or loaded above; load final generations
        with open(gen_txt, "r") as f:
            best_generations = [s.strip() for s in f.readlines()]

        return inputs, decoded_mask_inputs, best_generations, final_abs

    _ensure_dir(final_abs)

    # Initialize Infiller (MaRCo: base + toxic antiexpert + non-toxic expert)
    rewriter = Infiller(
        seed=0,
        base_path="facebook/bart-base",
        antiexpert_path="hallisky/bart-base-toxic-antiexpert",
        expert_path="hallisky/bart-base-nontoxic-expert",
        base_type=base_type,
        antiexpert_type=antiexpert_type,
        expert_type=expert_type,
        tokenizer="facebook/bart-base",
    )

    # Generate multiple candidates per input
    all_candidates: List[List[str]] = [[] for _ in range(len(inputs))]

    print(f"Generating {num_candidates} candidates per input (sampling={sample})")
    # Note: candidates only differ across passes when sample=True.
    for _ in range(num_candidates):
        _, decoded = rewriter.generate(
            inputs,
            decoded_mask_inputs,
            alpha_a=alpha_a,
            alpha_e=alpha_e,
            alpha_b=alpha_b,
            temperature=temperature,
            verbose=False,
            max_length=max_length,
            repetition_penalty=rep_penalty,
            p=top_p,
            filter_p=filter_p,
            k=top_k_gen,
            batch_size=batch_size,
            sample=sample,
            ranking=False,          # no DecompX ranking
            ranking_eval_output=0,
        )
        for i, text in enumerate(decoded):
            all_candidates[i].append(re.sub(r"\s+", " ", text).strip())

    # Free BART models before heavy scoring
    del rewriter
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Global reranking
    print("Global reranking (toxicity + similarity + fluency)...")
    best_idx, details = rerank_candidates_global(
        sources=inputs,
        candidates=all_candidates,
        weights=weights,
    )
    best_generations = [
        all_candidates[i][best_idx[i]] for i in range(len(inputs))
    ]

    # Save orig + chosen gen
    with open(orig_txt, "w") as f:
        for l in inputs:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")
    with open(gen_txt, "w") as f:
        for l in best_generations:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")

    print("Saved:", orig_txt)
    print("Saved:", gen_txt)

    return inputs, decoded_mask_inputs, best_generations, final_abs
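
The candidate loop above calls the generator once per pass and regroups the outputs so that each input ends up with its own list of `num_candidates` strings, which is the N x C layout the reranker expects. A minimal sketch of that regrouping, assuming a stand-in generator: `pool_candidates`, `fake_generate`, and the filler words are hypothetical and only mimic a sampling infiller.

```python
import random
from typing import Callable, List

def pool_candidates(
    masked_inputs: List[str],
    generate_once: Callable[[List[str]], List[str]],
    num_candidates: int,
) -> List[List[str]]:
    """Run the generator num_candidates times and group the outputs per input."""
    pooled: List[List[str]] = [[] for _ in masked_inputs]
    for _ in range(num_candidates):
        decoded = generate_once(masked_inputs)      # one output per input
        for i, text in enumerate(decoded):
            pooled[i].append(" ".join(text.split()))  # normalize whitespace
    return pooled

# Hypothetical stand-in for a sampling generator: fills each mask with a random word
rng = random.Random(0)
fillers = ["very", "really", "quite"]

def fake_generate(batch: List[str]) -> List[str]:
    return [s.replace("<mask>", rng.choice(fillers)) for s in batch]

pooled = pool_candidates(["that is <mask>  loud", "a <mask> day"], fake_generate, 3)
for i, cands in enumerate(pooled):
    print(i, cands)
```

Each row of `pooled` can then be scored and reduced to a single output with the index returned by the reranker.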


In [None]:
#@title `detoxify()` — LLM masking + global reranking + optional eval

def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_llm_mask",
    echo: bool = False,
    batch_size: int = 10,
    sample: bool = True,
    top_k_gen: int = 50,
    top_p: float = 0.95,
    filter_p: float = 1.0,
    max_length: int = 128,
    alpha_a: float = None,   # if None, take from data_configs
    alpha_e: float = None,   # if None, take from data_configs
    alpha_b: float = 1.0,
    temperature: float = None,  # if None, from data_configs
    rep_penalty: float = None,  # if None, from data_configs
    num_examples: int = 100,    # small-batch control; None = full dataset
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    # global reranking:
    weights = (0.5, 0.3, 0.2),   # (w_T, w_S, w_F)
    num_candidates: int = 3,     # candidates per input
):
    """
    Run XDetox with:
      - LLM masking (Mistral-7B-Instruct acting as a toxic-span masker),
      - MaRCo generation (BART base / antiexpert / expert),
      - global reranking based on:
          - toxicity (XLM-R),
          - semantic similarity (LaBSE),
          - fluency (GPT-2 perplexity -> [0,1]),
      - optional evaluation (run_eval=True) via evaluation/evaluate_all.py
        (BLEU/BERTScore/MeaningBERT/PPL/Toxicity).

    Notes:
    - There is no DecompX threshold here. The LLM masker decides which spans to
      replace with <mask>.
    - Outputs are stored under:
        data/model_outputs/{output_folder}/{data_type}/LLM_Masking/{run_folder}/
    - If echo=True, the function prints:
        * number of examples and dataset,
        * a few example inputs,
        * a few masked inputs,
        * a few detoxified outputs,
        * evaluation metrics for this specific run (if run_eval=True).
    """
    assert data_type in data_configs, f"Unknown data_type: {data_type}"
    cfg = data_configs[data_type].copy()

    if num_candidates < 1:
        raise ValueError("num_candidates must be >= 1")

    # fallbacks from data_configs
    if alpha_a is None:
        alpha_a = cfg["alpha_a"]
    if alpha_e is None:
        alpha_e = cfg["alpha_e"]
    if temperature is None:
        temperature = cfg["temperature"]
    if rep_penalty is None:
        rep_penalty = cfg["rep_penalty"]

    # Use model_outputs instead of dexp_outputs
    base_out_rel = os.path.join("data", "model_outputs", output_folder)
    base_out_abs = os.path.join(REPO, base_out_rel)
    _ensure_dir(base_out_abs)

    # subset path (file)
    original_data_path = cfg["data_path"]
    subset_dir = os.path.join(REPO, "datasets", "_subsets", data_type)
    _ensure_dir(subset_dir)
    subset_path = _subset_for_data_type(
        data_type, original_data_path, num_examples, subset_dir
    )

    # Load inputs once here (so we can print and also reuse inside the runner)
    args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
    inputs = rx.get_data(args_data)
    num_inputs = len(inputs)

    if echo:
        print("=" * 80)
        print(f"[echo] Dataset: {data_type}")
        print(f"[echo] Subset path: {subset_path}")
        print(f"[echo] Output base: {base_out_abs}")
        print(f"[echo] Number of examples to detoxify: {num_inputs}")
        print(f"[echo] Weights (w_T, w_S, w_F): {weights}")
        print(f"[echo] num_candidates per input: {num_candidates}")
        print("\n[echo] Example inputs (first up to 3):")
        for i, s in enumerate(inputs[:3]):
            print(f"  input[{i}]: {s}")
        print("=" * 80)

    # Dummy label kept only for compatibility with some utilities
    folder_label = 0.20

    # Run one LLM-masking + MaRCo + global-reranking pass
    inputs, masked_inputs, best_generations, run_dir = _run_llm_masking_and_global_reranking_for_threshold(
        data_type=data_type,
        subset_path=subset_path,
        thresh=folder_label,        # label only, not used for naming
        base_out_rel=base_out_rel,
        batch_size=batch_size,
        alpha_a=alpha_a,
        alpha_e=alpha_e,
        alpha_b=alpha_b,
        temperature=temperature,
        rep_penalty=rep_penalty,
        max_length=max_length,
        top_k_gen=top_k_gen,
        top_p=top_p,
        filter_p=filter_p,
        sample=sample,
        num_candidates=num_candidates,
        weights=weights,
        overwrite_gen=overwrite_gen,
        inputs=inputs,
    )

    if echo:
        print("\n[echo] Example masked inputs (first up to 3):")
        for i, m in enumerate(masked_inputs[:3]):
            print(f"  masked[{i}]: {m}")

        print("\n[echo] Example detoxified outputs (first up to 3):")
        for i in range(min(3, len(best_generations))):
            print(f"  detox[{i}]: {best_generations[i]}")

    # Optional evaluation (BLEU / BERTScore / MeaningBERT / PPL / Toxicity)
    if run_eval:
        mask_dir = "LLM_Masking"
        base_path = os.path.join(base_out_abs, data_type, mask_dir)
        _eval_with_toxicity(
            base_path,
            overwrite_eval=overwrite_eval,
            skip_ref=skip_ref_eval,
            tox_threshold=0.5,
            tox_batch_size=32,
        )
        _aggregate_eval_csv(
            output_folder,
            data_type,
            os.path.join(REPO, "data", "model_outputs", output_folder),
        )

        # If echo, print metrics for THIS run from gen_stats.txt
        if echo:
            stats_path = os.path.join(run_dir, "gen_stats.txt")
            if os.path.exists(stats_path):
                stats = _read_stats_file(stats_path)
                print("\n[echo] Evaluation metrics for this run:")
                metric_keys = [
                    ("bertscore", "BERTScore"),
                    ("meaningbert", "MeaningBERT"),
                    ("bleu4", "BLEU-4"),
                    ("perplexity gen", "Perplexity (gen)"),
                    ("perplexity orig", "Perplexity (orig)"),
                    ("toxicity gen", "Toxicity (gen)"),
                    ("toxicity orig", "Toxicity (orig)"),
                ]
                for key, label in metric_keys:
                    val = stats.get(key)
                    # Skip metrics that are missing or NaN
                    if val is None or (isinstance(val, float) and math.isnan(val)):
                        continue
                    print(f"  {label}: {val:.4f}")
            else:
                print("\n[echo] gen_stats.txt not found for this run; no metrics to print.")
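
To illustrate the output layout described in the docstring, the sketch below rebuilds a run-folder name and the directory it lives in. `toy_run_folder` is a condensed, illustrative rewrite of the naming helper, and the output folder name `my_run` is a placeholder. The hyperparameter values (`alpha_a=1.5`, `alpha_e=4.75`, `temperature=2.5`, `rep_penalty=1.0`) are read off the folder name in the logged paradetox run rather than from `data_configs` directly.

```python
import os

def toy_run_folder(alpha_a, alpha_e, alpha_b, temperature, sample,
                   top_k, rep_penalty, filter_p, max_length, top_p):
    """Mirror the naming scheme: model type names are truncated to five characters."""
    types = {"base": "base", "anti": "antiexpert", "expert": "expert"}
    parts = [f"aa{alpha_a}", f"ae{alpha_e}", f"ab{alpha_b}"]
    parts += [f"{prefix}{name[:5]}" for prefix, name in types.items()]
    parts += [
        f"temp{temperature}",
        f"sample{'T' if sample else 'F'}",
        f"topk{top_k}",
        f"reppenalty{rep_penalty}",
        f"filterp{filter_p}",
        f"maxlength{max_length}",
        f"topp{top_p}",
    ]
    return "_".join(parts)

run_folder = toy_run_folder(
    alpha_a=1.5, alpha_e=4.75, alpha_b=1.0, temperature=2.5, sample=True,
    top_k=50, rep_penalty=1.0, filter_p=1.0, max_length=96, top_p=0.95,
)
print(run_folder)
# aa1.5_ae4.75_ab1.0_basebase_antiantie_expertexper_temp2.5_sampleT_topk50_reppenalty1.0_filterp1.0_maxlength96_topp0.95

# "my_run" is a placeholder output_folder; the path is relative to the repo root
run_dir = os.path.join("data", "model_outputs", "my_run", "paradetox", "LLM_Masking", run_folder)
print(run_dir)
```

Because every generation hyperparameter is part of the folder name, changing any of them creates a new run folder, while `masked_inputs.txt` one level up is shared across runs.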

In [None]:
#@title Example run — paradetox, LLM masking + global reranking

# Example: small demo (adjust num_examples as you like)
# detoxify(
#     data_type="paradetox",
#     output_folder="colab_run_llm_mask_demo",
#     echo=True,
#     batch_size=8,
#     sample=True,
#     top_k_gen=50,
#     top_p=0.95,
#     max_length=96,
#     num_examples=50,             # small subset for testing
#     run_eval=True,               # BLEU/BERTScore/MeaningBERT/PPL/Toxicity
#     overwrite_gen=False,
#     overwrite_eval=True,
#     skip_ref_eval=False,
#     weights=(0.5, 0.3, 0.2),
#     num_candidates=10,
# )

# Minimal example call (re-run eval only, assuming generations exist)
# detoxify(
#     data_type="paradetox",
#     output_folder="colab_run_llm_mask_demo",
#     echo=True,
#     batch_size=8,
#     sample=True,
#     top_k_gen=50,
#     top_p=0.95,
#     max_length=96,
#     num_examples=50,
#     run_eval=True,
#     overwrite_gen=False,
#     overwrite_eval=True,
#     skip_ref_eval=False,
#     weights=(0.5, 0.3, 0.2),
#     num_candidates=10,
# )
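
The "re-run eval only" call relies on the caching rule in the runner: generation is skipped when `gen.txt` already exists in the run folder and `overwrite_gen` is `False`. A minimal sketch of that rule, assuming a hypothetical helper `needs_generation` and a temporary directory standing in for the run folder:

```python
import os
import tempfile

def needs_generation(run_dir: str, overwrite_gen: bool) -> bool:
    """True when generation should run: no cached gen.txt, or overwrite requested."""
    return overwrite_gen or not os.path.exists(os.path.join(run_dir, "gen.txt"))

with tempfile.TemporaryDirectory() as run_dir:
    first = needs_generation(run_dir, overwrite_gen=False)    # nothing cached yet
    with open(os.path.join(run_dir, "gen.txt"), "w") as f:
        f.write("a detoxified sentence\n")
    cached = needs_generation(run_dir, overwrite_gen=False)   # cache hit, skip
    forced = needs_generation(run_dir, overwrite_gen=True)    # forced regeneration

print(first, cached, forced)
# True False True
```

Note that the cache key is the run folder, so a call only reuses generations when its generation hyperparameters match those of the earlier run.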

In [None]:
detoxify(
    data_type="paradetox",
    output_folder="XDetox_w_LLM-Masking_Global-Reranking_Pipeline",
    echo=True,
    batch_size=8,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=1000,
    run_eval=True,             # BLEU/BERTScore/MeaningBERT/PPL/Toxicity
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    weights=(0.5, 0.3, 0.2),
    num_candidates=10,
)

[echo] Dataset: paradetox
[echo] Subset path: /content/drive/MyDrive/w266 - Project/XDetox/datasets/_subsets/paradetox/test_toxic_parallel.txt
[echo] Output base: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_Global-Reranking_Pipeline
[echo] Number of examples to detoxify: 671
[echo] Weights (w_T, w_S, w_F): (0.5, 0.3, 0.2)
[echo] num_candidates per input: 10

[echo] Example inputs (first up to 3):
  input[0]: . or the loud ass one - thousand ton beast roaring towards you howling its horn .
  input[1]: mandated  and " right fucking now " would be good .
  input[2]: * neither * of my coworkers gave a shit when it came time to ditch mitch . ugh .
#inputs to detoxify: 671
Reusing existing masked_inputs.txt
Found 1 GPUS!
Generating 10 candidates per input (sampling=True)


Filling in masks: 100%|██████████| 84/84 [01:38<00:00,  1.17s/it]
Filling in masks: 100%|██████████| 84/84 [01:38<00:00,  1.18s/it]
Filling in masks: 100%|██████████| 84/84 [01:44<00:00,  1.25s/it]
Filling in masks: 100%|██████████| 84/84 [01:28<00:00,  1.05s/it]
Filling in masks: 100%|██████████| 84/84 [01:50<00:00,  1.31s/it]
Filling in masks: 100%|██████████| 84/84 [01:41<00:00,  1.21s/it]
Filling in masks: 100%|██████████| 84/84 [01:23<00:00,  1.01it/s]
Filling in masks: 100%|██████████| 84/84 [01:42<00:00,  1.23s/it]
Filling in masks: 100%|██████████| 84/84 [01:35<00:00,  1.14s/it]
Filling in masks: 100%|██████████| 84/84 [01:33<00:00,  1.11s/it]


Global reranking (toxicity + similarity + fluency)...


Toxicity:   0%|          | 0/210 [00:00<?, ?it/s]

LaBSE embeddings:   0%|          | 0/21 [00:00<?, ?it/s]

LaBSE embeddings:   0%|          | 0/210 [00:00<?, ?it/s]

GPT-2 PPL:   0%|          | 0/6710 [00:00<?, ?it/s]

Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_Global-Reranking_Pipeline/paradetox/LLM_Masking/aa1.5_ae4.75_ab1.0_basebase_antiantie_expertexper_temp2.5_sampleT_topk50_reppenalty1.0_filterp1.0_maxlength96_topp0.95/orig.txt
Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_Global-Reranking_Pipeline/paradetox/LLM_Masking/aa1.5_ae4.75_ab1.0_basebase_antiantie_expertexper_temp2.5_sampleT_topk50_reppenalty1.0_filterp1.0_maxlength96_topp0.95/gen.txt

[echo] Example masked inputs (first up to 3):
  masked[0]: . or the loud <mask> one - thousand ton beast roaring towards you howling its horn .
  masked[1]: mandated and "<mask> right <mask> now <mask> " would be good .
  masked[2]: neither of my coworkers gave a <mask> when it came time to ditch mitch . ugh .

[echo] Example detoxified outputs (first up to 3):
  detox[0]: . or the loud sound of a one- thousand ton beast roaring towards you howling its horn.
 