# llm masking, infilling & global reranking

runs xdetox with llm-based masking and infilling (`Mistral-7B-Instruct-v0.2`) + global reranking.

## pipeline

1. llm masking - detects toxic spans and replaces each with `<mask>`
2. llm infilling - fills each `<mask>` with a safe alternative
3. global reranking - picks the best candidate per input using toxicity, similarity, fluency

## scoring

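For each candidate:
- toxicity score `tox` (xlm-r large classifier, class-1 probability)
- semantic similarity `sim` (labse cosine similarity to the source, rescaled to [0, 1])
- fluency `flu` (gpt-2 perplexity, clipped and log-scaled to [0, 1]; lower perplexity gives higher fluency)

combined score: `w_T * (1-tox) + w_S * sim + w_F * flu` (default weights `(0.5, 0.3, 0.2)`)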
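
The combined score can be sketched in plain Python. This is a simplified, scalar version of the notebook's numpy code, not a replacement for it: the function names mirror the notebook, the clipping bounds (5 and 300) and default weights come from the code below, and the input numbers are illustrative only.

```python
import math

def perplexity_to_fluency(ppl, p_min=5.0, p_max=300.0):
    """Map a perplexity to [0, 1] on a log scale (1 = most fluent)."""
    p = min(max(ppl, p_min), p_max)  # clip into [p_min, p_max]
    return (math.log(p_max) - math.log(p)) / (math.log(p_max) - math.log(p_min))

def combined_score(tox, sim, ppl, weights=(0.5, 0.3, 0.2)):
    """w_T * (1 - tox) + w_S * sim + w_F * flu, with sim already in [0, 1]."""
    w_t, w_s, w_f = weights
    flu = perplexity_to_fluency(ppl)
    return w_t * (1.0 - tox) + w_s * sim + w_f * flu

# Illustrative numbers only, not measured results.
safe_but_drifting = combined_score(tox=0.05, sim=0.70, ppl=40.0)
faithful_but_toxic = combined_score(tox=0.80, sim=0.95, ppl=40.0)
print(safe_but_drifting > faithful_but_toxic)
```

With the default weights, safety carries half of the score, so a candidate that stays close to the source but remains toxic tends to lose to a safer one.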

## masking

Uses mistral to identify toxic words/phrases and replace them with `<mask>`. Post-processing strips the wrapping brackets, normalizes whitespace and `<mask>` spelling, and collapses consecutive masks into one.
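
The mask normalization step can be illustrated in isolation. This is a minimal sketch that reuses the regular expressions from `_postprocess_llm_mask` below; `normalize_masks` is a hypothetical name for the example, and bracket stripping is left out.

```python
import re

def normalize_masks(text: str) -> str:
    """Normalize whitespace and mask spelling, then collapse mask runs."""
    s = re.sub(r"\s+", " ", text).strip()
    # Accept variants such as "<MASK>" or "< mask >".
    s = re.sub(r"<\s*mask\s*>", "<mask>", s, flags=re.IGNORECASE)
    # Two or more masks in a row become a single mask.
    s = re.sub(r"(?:\s*<mask>\s*){2,}", " <mask> ", s)
    return re.sub(r"\s+", " ", s).strip()

print(normalize_masks("You are a <MASK>  < mask > person."))
```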

## infilling

Mistral fills each `<mask>` with safe text while keeping the rest of the sentence unchanged. Generates `num_candidates` candidates per input (default 3).
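
Each raw generation still has to be reduced to a single sentence. The sketch below shows one way to do that, loosely following `_extract_bracket_content` and `_postprocess_llm_infill` further down; `clean_infill` and its `fallback` argument are hypothetical names used only for this example.

```python
import re

def extract_bracketed(text: str) -> str:
    """Return the content of the first [...] pair, or the stripped text."""
    m = re.search(r"\[([^\]]*)\]", text, flags=re.DOTALL)
    return m.group(1).strip() if m else text.strip()

def clean_infill(raw: str, fallback: str) -> str:
    """Extract the bracketed sentence and drop any mask left unfilled."""
    s = extract_bracketed(raw)
    s = s.replace("<mask>", " ")
    s = re.sub(r"\s+", " ", s).strip()
    # Fall back to the source sentence if nothing usable remains.
    return s or fallback

print(clean_infill("Final Output: [You're such a rude person.]\nextra text", "src"))
```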

## reranking

Scores all candidates and picks the highest-scoring one per input based on safety, similarity, and fluency.
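
The selection step is an argmax over each input's candidate scores. This is a minimal list-based sketch of that idea, assuming the scores are already computed; the notebook itself reshapes a flat numpy array to `(N, C)` and calls `argmax(axis=1)`. `pick_best` is a hypothetical helper name.

```python
def pick_best(candidates, scores):
    """For each input, return the candidate with the highest score."""
    best = []
    for cands, row in zip(candidates, scores):
        # max() returns the first maximal index, so ties keep the earliest candidate.
        idx = max(range(len(row)), key=row.__getitem__)
        best.append(cands[idx])
    return best

# Illustrative scores only.
print(pick_best([["a", "b", "c"], ["d", "e", "f"]],
                [[0.1, 0.9, 0.3], [0.7, 0.2, 0.7]]))
```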

## eval

Computes bertscore, meaningbert, bleu, toxicity, perplexity if `run_eval=True`.

outputs saved to `data/model_outputs/{folder}/{dataset}/LLM_Mask_LLM_Global/`
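
Inside that directory, each run gets its own subfolder named after the generation settings. The sketch below reproduces the naming pattern of `_build_run_folder_name` from the code further down; `run_output_dir` is a hypothetical helper and `colab_run` is a made-up folder name.

```python
from pathlib import PurePosixPath

def run_output_dir(folder, dataset, temperature=0.7, top_p=0.95,
                   sample=True, num_candidates=3, max_new_tokens=64):
    """Build the relative output directory for one run."""
    run = (
        f"llmtemp{temperature}_topp{top_p}_"
        f"sample{'T' if sample else 'F'}_"
        f"nc{num_candidates}_"
        f"maxntok{max_new_tokens}"
    )
    return PurePosixPath("data", "model_outputs", folder, dataset,
                         "LLM_Mask_LLM_Global", run)

print(run_output_dir("colab_run", "paradetox"))
```

`orig.txt` and `gen.txt` are written inside this run folder, while `masked_inputs.txt` sits one level up in `LLM_Mask_LLM_Global/`.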

In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os, glob, re, sys, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE
import torch
import nltk
from typing import List

candidate = "/content/drive/MyDrive/w266 - Project/XDetox"
print("try mydrive:", candidate, "->", os.path.isdir(candidate))

XDETOX_DIR = candidate
print("using dir:", XDETOX_DIR)
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"

In [None]:
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.makedirs(HF_CACHE, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE

if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("dir:", XDETOX_DIR)
print("cache:", HF_CACHE)
print("cuda:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("gpu:", torch.cuda.get_device_name(0))

In [None]:
for d in ["rewrite", "evaluation", "datasets"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("folders ok")

In [None]:
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi \
                sentencepiece
!pip -q install bert-score

In [None]:
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForSequenceClassification,
    AutoModelForCausalLM,
    GPT2LMHeadModel, GPT2TokenizerFast,
)
from rewrite import rewrite_example as rx
import argparse as _argparse

In [None]:
nltk.download("punkt", quiet=True)
try:
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass
print("done")

In [None]:
# data configs
data_configs = {
    "microagressions_val": {
        "data_path": "./datasets/microagressions/val.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "microagressions_test": {
        "data_path": "./datasets/microagressions/test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "sbf_val": {
        "data_path": "./datasets/sbf/sbfdev.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "sbf_test": {
        "data_path": "./datasets/sbf/sbftst.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "dynabench_val": {
        "data_path": "./datasets/dynabench/db_dev.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "dynabench_test": {
        "data_path": "./datasets/dynabench/db_test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "jigsaw_toxic": {
        "data_path": "./datasets/jigsaw_full_30/test_10k_toxic.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "paradetox": {
        "data_path": "./datasets/paradetox/test_toxic_parallel.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_original": {
        "data_path": "./datasets/appdia/original-annotated-data/original-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_discourse": {
        "data_path": "./datasets/appdia/discourse-augmented-data/discourse-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    }
}
print("datasets:", ", ".join(data_configs.keys()))

REPO = XDETOX_DIR

In [None]:
# subset helper
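def _abs_repo_path(rel: str) -> str:
    # Drop only a leading "./" prefix; str.lstrip("./") would strip any run of
    # "." and "/" characters, mangling paths such as "../x" or "/abs/x".
    if rel.startswith("./"):
        rel = rel[2:]
    return os.path.join(REPO, rel)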

def _ensure_dir(p: str):
    Path(p).mkdir(parents=True, exist_ok=True)

def _subset_for_data_type(data_type, data_path, n, out_dir):
    if n is None or n <= 0:
        return data_path

    src = _abs_repo_path(data_path)
    _ensure_dir(out_dir)

    # microagressions, sbf, and dynabench are all plain CSV files.
    if any(k in data_path for k in ["microagressions", "sbf", "dynabench"]):
        df = pd.read_csv(src)
        sub = df.head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        sub.to_csv(out, index=False)
        return out

    if any(k in data_path for k in ["paradetox", "jigsaw"]):
        if data_path.endswith(".txt"):
            with open(src, "r") as f:
                lines = [s.rstrip("\n") for s in f.readlines()]
            out = os.path.join(out_dir, os.path.basename(src))
            with open(out, "w") as g:
                for s in lines[:n]:
                    g.write(s + "\n")
            return out
        elif data_path.endswith(".csv"):
            df = pd.read_csv(src).head(n)
            out = os.path.join(out_dir, os.path.basename(src))
            df.to_csv(out, index=False)
            return out

    if "appdia" in data_path:
        df = pd.read_csv(src, sep="\t").head(n)
        out = os.path.join(out_dir, os.path.basename(src))
        df.to_csv(out, sep="\t", index=False)
        return out

    out = os.path.join(out_dir, os.path.basename(src))
    shutil.copy(src, out)
    return out

In [None]:
# scoring helpers

DEVICE_SCORE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"scoring device: {DEVICE_SCORE}")

_TOX_MODEL_NAME = "textdetox/xlmr-large-toxicity-classifier-v2"
_TOX_TOKENIZER = None
_TOX_MODEL = None

def _lazy_load_tox():
    global _TOX_TOKENIZER, _TOX_MODEL
    if _TOX_TOKENIZER is None or _TOX_MODEL is None:
        _TOX_TOKENIZER = AutoTokenizer.from_pretrained(_TOX_MODEL_NAME)
        _TOX_MODEL = AutoModelForSequenceClassification.from_pretrained(
            _TOX_MODEL_NAME
        ).to(DEVICE_SCORE)
        _TOX_MODEL.eval()

@torch.no_grad()
def get_toxicity_scores(texts, batch_size=32):
    _lazy_load_tox()
    scores = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Toxicity", leave=False):
        batch = texts[i:i+batch_size]
        enc = _TOX_TOKENIZER(
            batch, return_tensors="pt",
            truncation=True, max_length=512, padding=True
        ).to(DEVICE_SCORE)
        logits = _TOX_MODEL(**enc).logits
        probs = torch.softmax(logits, dim=-1)
        scores.extend(probs[:, 1].detach().cpu().tolist())
    return scores

_LABSE_NAME = "sentence-transformers/LaBSE"
_LABSE_TOKENIZER = None
_LABSE_MODEL = None

def _lazy_load_labse():
    global _LABSE_TOKENIZER, _LABSE_MODEL
    if _LABSE_TOKENIZER is None or _LABSE_MODEL is None:
        _LABSE_TOKENIZER = AutoTokenizer.from_pretrained(_LABSE_NAME)
        _LABSE_MODEL = AutoModel.from_pretrained(_LABSE_NAME).to(DEVICE_SCORE)
        _LABSE_MODEL.eval()

@torch.no_grad()
def get_labse_embeddings(texts, batch_size=32):
    _lazy_load_labse()
    embs = []
    for i in tqdm(range(0, len(texts), batch_size), desc="LaBSE embeddings", leave=False):
        batch = texts[i:i+batch_size]
        enc = _LABSE_TOKENIZER(
            batch, return_tensors="pt",
            truncation=True, max_length=256, padding=True
        ).to(DEVICE_SCORE)
        outputs = _LABSE_MODEL(**enc)
        hidden = outputs.last_hidden_state
        # Cast the integer attention mask to the hidden-state dtype before pooling.
        mask = enc["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        masked = hidden * mask
        summed = masked.sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-6)
        sent_emb = (summed / counts).cpu().numpy()
        embs.append(sent_emb)
    if not embs:
        return np.zeros((0, 768), dtype=np.float32)
    return np.vstack(embs)

_GPT2_NAME = "gpt2"
_GPT2_TOKENIZER = None
_GPT2_MODEL = None

def _lazy_load_gpt2():
    global _GPT2_TOKENIZER, _GPT2_MODEL
    if _GPT2_TOKENIZER is None or _GPT2_MODEL is None:
        _GPT2_TOKENIZER = GPT2TokenizerFast.from_pretrained(_GPT2_NAME)
        _GPT2_MODEL = GPT2LMHeadModel.from_pretrained(_GPT2_NAME).to(DEVICE_SCORE)
        _GPT2_MODEL.eval()

@torch.no_grad()
def get_gpt2_perplexities(texts):
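    _lazy_load_gpt2()
    ppls = []
    for s in tqdm(texts, desc="GPT-2 PPL", leave=False):
        enc = _GPT2_TOKENIZER(s, return_tensors="pt").to(DEVICE_SCORE)
        # GPT-2 needs at least two tokens for a next-token loss; shorter
        # inputs (including empty strings) get the capped perplexity.
        if enc["input_ids"].shape[1] < 2:
            ppls.append(1e4)
            continue
        out = _GPT2_MODEL(enc["input_ids"], labels=enc["input_ids"])
        ppl = math.exp(out.loss.item())
        # Cap extreme or non-finite values so they cannot distort reranking.
        if not math.isfinite(ppl) or ppl > 1e4:
            ppl = 1e4
        ppls.append(float(ppl))
    return ppls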

def perplexity_to_fluency(ppls, p_min=5.0, p_max=300.0):
    import math as _math
    ppls = np.asarray(ppls, dtype=float)
    p = np.clip(ppls, p_min, p_max)
    log_p = np.log(p)
    log_min = _math.log(p_min)
    log_max = _math.log(p_max)
    F = (log_max - log_p) / (log_max - log_min + 1e-8)
    F = np.clip(F, 0.0, 1.0)
    return F

In [None]:
# eval helpers
def _eval_with_toxicity(base_path, overwrite_eval=False, skip_ref=False,
                        tox_threshold=0.5, tox_batch_size=32):
    import sys as _sys
    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path  = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue

        env = os.environ.copy()
        env["PYTHONPATH"] = REPO + (":" + env.get("PYTHONPATH","") if env.get("PYTHONPATH") else "")
        cmd = [
            _sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path",  gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")
        print("eval:", " ".join(cmd))
        res = run(cmd, cwd=REPO, env=env, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()

def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return float('nan')

def _read_stats_file(path):
    out = {}
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            # Require the exact "key: value" separator used by split() below.
            if ": " not in line:
                continue
            k, v = line.split(": ", 1)
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v)
    return out

def _aggregate_eval_csv(output_folder, data_type, base_out_dir):
    rows = []

    mask_dir = "LLM_Mask_LLM_Global"
    base_path = os.path.join(base_out_dir, data_type, mask_dir)
    if not os.path.isdir(base_path):
        print("no eval dir found:", base_path)
        return

    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        stats_path = os.path.join(gen_dir, "gen_stats.txt")
        if not os.path.exists(stats_path):
            continue
        s = _read_stats_file(stats_path)
        rows.append({
            "threshold":        0.20,
            "folder":           folder,
            "bertscore":        s.get("bertscore", np.nan),
            "meaningbert":      s.get("meaningbert", np.nan),
            "bleu4":            s.get("bleu4", np.nan),
            "perplexity_gen":   s.get("perplexity gen", np.nan),
            "perplexity_orig":  s.get("perplexity orig", np.nan),
            "toxicity_gen":     s.get("toxicity gen", np.nan),
            "toxicity_orig":    s.get("toxicity orig", np.nan),
        })

    if rows:
        cols = [
            "threshold", "folder",
            "bertscore", "meaningbert", "bleu4",
            "perplexity_gen", "perplexity_orig",
            "toxicity_gen", "toxicity_orig",
        ]
        df = pd.DataFrame(rows)
        df = df[cols]
        out_csv = os.path.join(base_out_dir, data_type, f"{data_type}.csv")
        df.to_csv(out_csv, index=False)
        print("saved summary csv:", out_csv)
    else:
        print("no eval files found")

In [None]:
# llm loader
USE_LLM_GPU = True
DEVICE_LLM = torch.device("cuda" if USE_LLM_GPU and torch.cuda.is_available() else "cpu")
print("llm device:", DEVICE_LLM)

LLM_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
_LLM_TOKENIZER = None
_LLM_MODEL = None

def _lazy_load_llm():
    global _LLM_MODEL, _LLM_TOKENIZER
    if _LLM_MODEL is not None and _LLM_TOKENIZER is not None:
        return
    print(f"loading mistral on {DEVICE_LLM}")
    _LLM_TOKENIZER = AutoTokenizer.from_pretrained(LLM_MODEL_NAME)
    _LLM_MODEL = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_NAME,
        torch_dtype=torch.float16 if DEVICE_LLM.type == "cuda" else torch.float32,
        device_map=None,
    ).to(DEVICE_LLM)
    _LLM_MODEL.eval()
    print("loaded")


def _extract_bracket_content(text: str) -> str:
    text = text.strip()

    m = re.search(r"\[([^\]]*)\]", text, flags=re.DOTALL)
    if m:
        return m.group(1).strip()

    if "[" in text:
        return text.split("[", 1)[1].strip()

    return text

In [None]:
# llm masking

MASK_SYSTEM_PROMPT = """You are a helpful assistant trained to make toxic or offensive sentences safer
by MASKING toxic words with the special token <mask>, while keeping the original sentence
structure and wording as much as possible.

You must behave like a MASKER, not a full rewriter.

Your task:
1. Identify toxic, offensive, or profane words or short phrases.
2. For each toxic span, replace the entire span with a single <mask> token.
3. There may be multiple toxic spans in one sentence, so multiple <mask> tokens are allowed.
4. If several neighboring words are toxic, you must still use only a single <mask> token
   in that place. In other words, if you would place "<mask> <mask>" or a longer sequence
   of <mask> tokens, collapse them into a single <mask> so that there are never multiple
   <mask> tokens in a row.
5. Do NOT rewrite, paraphrase, or summarize the sentence.
6. Do NOT add, remove, or reorder non-toxic words or punctuation.
7. Keep punctuation and spacing as close to the original as possible.
8. If there is no toxic content, return the sentence unchanged.

Output rules (format is very strict):
- ONLY return the final masked sentence inside ONE pair of square brackets, like:
  [This is a <mask> example.]
- Do NOT print anything before or after the brackets.
- Do NOT add explanations, comments, or extra lines.
- Do NOT include any language tags or metadata.
- Do NOT include additional '[' or ']' characters inside the sentence.
"""

MASK_FEW_SHOT = """Toxic Sentence: You're such a stupid idiot, nobody wants to hear your crap.
Step 1 - Identify toxic words: "stupid idiot", "crap"
Step 2 - Mask toxic words (do NOT rewrite the rest):
You're such a <mask>, nobody wants to hear your <mask>.
Final Output: [You're such a <mask>, nobody wants to hear your <mask>.]"""

def _postprocess_llm_mask(masked_text: str) -> str:
    s = masked_text.strip()

    if s.startswith("[") and s.endswith("]") and len(s) > 2:
        s = s[1:-1].strip()
    else:
        if s.startswith("["):
            s = s[1:].strip()
        if s.endswith("]"):
            s = s[:-1].strip()

    s = re.sub(r"\s+", " ", s).strip()

    s = re.sub(r"<\s*mask\s*>", "<mask>", s, flags=re.IGNORECASE)

    s = re.sub(r"(?:\s*<mask>\s*){2,}", " <mask> ", s)
    s = re.sub(r"\s+", " ", s).strip()

    if not s:
        return masked_text.strip()

    return s

@torch.no_grad()
def llm_mask_sentences(sentences: List[str]) -> List[str]:
    _lazy_load_llm()
    masked = []
    for s in tqdm(sentences, desc="LLM masking", leave=False):
        messages = [
            {
                "role": "system",
                "content": MASK_SYSTEM_PROMPT + "\n\nBelow is an example:\n" + MASK_FEW_SHOT,
            },
            {
                "role": "user",
                "content": f"Toxic Sentence: {s}\nFinal Output:",
            },
        ]
        try:
            prompt = _LLM_TOKENIZER.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
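        except Exception:
            # Some Mistral-7B-Instruct chat templates reject a "system" role;
            # in that case fall back to a plain-text prompt.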
            prompt = (
                MASK_SYSTEM_PROMPT
                + "\n\nExample:\n"
                + MASK_FEW_SHOT
                + "\n\nToxic Sentence: "
                + s
                + "\nFinal Output:"
            )

        inputs = _LLM_TOKENIZER(prompt, return_tensors="pt").to(DEVICE_LLM)
        # Greedy decoding: sampling parameters such as temperature do not
        # apply when do_sample=False, so none are passed.
        gen = _LLM_MODEL.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=False,
            pad_token_id=_LLM_TOKENIZER.eos_token_id,
        )
        gen_text = _LLM_TOKENIZER.decode(
            gen[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        masked_text = _extract_bracket_content(gen_text)
        masked_text = _postprocess_llm_mask(masked_text)

        if not masked_text:
            masked_text = s
        masked.append(masked_text)

    return masked

In [None]:
# llm infilling

INFILL_SYSTEM_PROMPT = """You are a helpful assistant trained to make toxic or offensive sentences
more polite and respectful by INFILLING the special token <mask>.

You are NOT a free rewriter. You must keep all non-masked text as close as possible
to the given masked sentence.

You are given two inputs:
1) Toxic Sentence: the original toxic sentence.
2) Masked Sentence: the same sentence, where toxic spans are replaced with <mask>.

Your task:
1. For each <mask> token in the Masked Sentence, replace it with a short, non-toxic
   word or phrase that fits the context and preserves the meaning of the Toxic Sentence.
2. Do NOT modify any other words or punctuation outside the <mask> spans, unless a very
   small change is needed to fix grammar or agreement.
3. Preserve the original meaning and intent as much as possible, but make the sentence
   safe and respectful.
4. Keep the language the same as the original (do NOT translate).

Output rules (VERY STRICT):
- ONLY return the final detoxified sentence with all <mask> tokens filled.
- Wrap the final sentence in exactly ONE pair of square brackets, e.g.:
  [Detoxified sentence here.]
- Do NOT include the Toxic Sentence or Masked Sentence in your output.
- Do NOT add explanations, comments, or extra lines.
- Do NOT include any other '[' or ']' characters.
"""

INFILL_FEW_SHOT = """Toxic Sentence: You're such a stupid idiot, nobody wants to hear your crap.
Masked Sentence: You're such a <mask>, nobody wants to hear your <mask>.
Step 1 - Decide safe replacements for each <mask>: "rude person", "opinion"
Step 2 - Infill the masked sentence, keeping all other words the same:
You're such a rude person, nobody wants to hear your opinion.
Final Output: [You're such a rude person, nobody wants to hear your opinion.]"""

def _postprocess_llm_infill(text: str) -> str:
    s = text.strip()

    if s.startswith("[") and s.endswith("]") and len(s) > 2:
        s = s[1:-1].strip()
    else:
        if s.startswith("["):
            s = s[1:].strip()
        if s.endswith("]"):
            s = s[:-1].strip()

    s = re.sub(r"\s+", " ", s).strip()

    s = s.replace("<mask>", " ")
    s = re.sub(r"\s+", " ", s).strip()

    if not s:
        return text.strip()
    return s

@torch.no_grad()
def llm_infill_candidates(
    sources: List[str],
    masked: List[str],
    num_candidates: int = 3,
    temperature: float = 0.7,
    top_p: float = 0.95,
    max_new_tokens: int = 64,
    sample: bool = True,
) -> List[List[str]]:
    assert len(sources) == len(masked), "sources and masked length mismatch"
    if num_candidates < 1:
        raise ValueError("num_candidates must be >= 1")

    _lazy_load_llm()
    all_cands: List[List[str]] = []

    for src, msk in tqdm(
        list(zip(sources, masked)),
        desc="llm infilling",
        leave=False,
    ):
        messages = [
            {
                "role": "system",
                "content": INFILL_SYSTEM_PROMPT + "\n\nHere is an example:\n" + INFILL_FEW_SHOT,
            },
            {
                "role": "user",
                "content": (
                    f"Toxic Sentence: {src}\n"
                    f"Masked Sentence: {msk}\n"
                    f"Final Output:"
                ),
            },
        ]
        try:
            prompt = _LLM_TOKENIZER.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
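        except Exception:
            # Some Mistral-7B-Instruct chat templates reject a "system" role;
            # in that case fall back to a plain-text prompt.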
            prompt = (
                INFILL_SYSTEM_PROMPT
                + "\n\nExample:\n"
                + INFILL_FEW_SHOT
                + "\n\nToxic Sentence: "
                + src
                + "\nMasked Sentence: "
                + msk
                + "\nFinal Output:"
            )

        inputs = _LLM_TOKENIZER(prompt, return_tensors="pt").to(DEVICE_LLM)
        input_len = inputs["input_ids"].shape[1]

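        # Sampling parameters are only valid with do_sample=True, and greedy
        # decoding can return just one sequence, which is reused for every
        # candidate slot below.
        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=sample,
            pad_token_id=_LLM_TOKENIZER.eos_token_id,
        )
        if sample:
            gen_kwargs.update(
                temperature=float(temperature),
                top_p=top_p,
                num_return_sequences=num_candidates,
            )
        gen = _LLM_MODEL.generate(**inputs, **gen_kwargs)

        cand_list = []
        for idx in range(num_candidates):
            gen_text = _LLM_TOKENIZER.decode(
                gen[idx if sample else 0][input_len:],
                skip_special_tokens=True,
            )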
            cleaned = _extract_bracket_content(gen_text)
            cleaned = _postprocess_llm_infill(cleaned)
            if not cleaned:
                cleaned = src
            cand_list.append(cleaned)

        all_cands.append(cand_list)

    return all_cands

In [None]:
# global reranking
def rerank_candidates_global(
    sources,
    candidates,
    weights=(0.5, 0.3, 0.2),
    ppl_min=5.0,
    ppl_max=300.0,
):
    w_T, w_S, w_F = weights
    N = len(sources)
    assert len(candidates) == N, "candidates length mismatch"

    if N == 0:
        return np.array([], dtype=int), {}

    C_list = [len(c) for c in candidates]
    assert len(set(C_list)) == 1, "All inputs must have same num_candidates"
    C = C_list[0]
    if C == 0:
        raise ValueError("num_candidates must be >= 1")

    flat_cands = []
    flat_src_idx = []
    for i, cand_list in enumerate(candidates):
        for cand in cand_list:
            flat_cands.append(cand)
            flat_src_idx.append(i)
    flat_src_idx = np.array(flat_src_idx, dtype=int)

    tox = np.array(get_toxicity_scores(flat_cands), dtype=float)

    src_embs = get_labse_embeddings(sources)
    cand_embs = get_labse_embeddings(flat_cands)
    src_embs = src_embs / np.clip(np.linalg.norm(src_embs, axis=1, keepdims=True), 1e-8, None)
    cand_embs = cand_embs / np.clip(np.linalg.norm(cand_embs, axis=1, keepdims=True), 1e-8, None)
    sims = np.sum(cand_embs * src_embs[flat_src_idx], axis=1)
    sims = (sims + 1.0) / 2.0

    ppls = np.array(get_gpt2_perplexities(flat_cands), dtype=float)
    flus = perplexity_to_fluency(ppls, p_min=ppl_min, p_max=ppl_max)

    safety = 1.0 - tox

    scores = w_T * safety + w_S * sims + w_F * flus

    tox2     = tox.reshape(N, C)
    safety2  = safety.reshape(N, C)
    sims2    = sims.reshape(N, C)
    flus2    = flus.reshape(N, C)
    scores2  = scores.reshape(N, C)

    best_idx = scores2.argmax(axis=1)
    details = {
        "tox": tox2,
        "safety": safety2,
        "sim": sims2,
        "flu": flus2,
        "score": scores2,
    }
    return best_idx, details

In [None]:
# folder naming
def _bool2str(x: bool) -> str:
    return "T" if x else "F"

def _build_run_folder_name(
    llm_temperature: float,
    llm_top_p: float,
    llm_sample: bool,
    num_candidates: int,
    max_new_tokens: int,
):
    return (
        f"llmtemp{llm_temperature}_topp{llm_top_p}_"
        f"sample{_bool2str(llm_sample)}_"
        f"nc{num_candidates}_"
        f"maxntok{max_new_tokens}"
    )

In [None]:
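# decompx masking + llm global reranking (unused in this notebook)
# note: the function below calls Masker_single and _process_in_batches, which
# the cells above neither import nor define, so it raises NameError unless
# they are provided first. _bool2str is reused from the folder naming cell.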

def _build_llm_gen_folder_name(
    temperature, sample, top_p, max_new_tokens, num_candidates
):
    return (
        "llm"
        "_temp" + str(temperature) +
        "_sample" + _bool2str(sample) +
        "_topp" + str(top_p) +
        "_maxnew" + str(max_new_tokens) +
        "_ncand" + str(num_candidates)
    )

def _run_decompx_masking_and_llm_global_reranking_for_threshold(
    data_type,
    subset_path,
    thresh,
    base_out_rel,
    batch_size_mask,
    num_candidates,
    weights,
    llm_temperature,
    llm_top_p,
    llm_max_new_tokens,
    llm_sample,
    overwrite_gen=False,
    echo: bool = False,
    inputs=None,
):
    if inputs is None:
        args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
        inputs = rx.get_data(args_data)
    print(f"inputs at thresh={thresh}: {len(inputs)}")

    mask_dir = f"DecompX{abs(thresh):g}" if thresh != 0 else "DecompX0.0"
    cur_rel = os.path.join(base_out_rel, data_type, mask_dir)
    cur_abs = os.path.join(REPO, cur_rel)
    _ensure_dir(cur_abs)

    masked_file = os.path.join(cur_abs, "masked_inputs.txt")

    if not os.path.exists(masked_file):
        masker = Masker_single()
        decoded_masked_inputs_batches = _process_in_batches(
            masker, inputs, batch_size=batch_size_mask, thresh=thresh
        )
        decoded_masked_inputs = [
            item for sublist in decoded_masked_inputs_batches for item in sublist
        ]
        decoded_mask_inputs = [
            d.replace("<s>", "").replace("</s>", "") for d in decoded_masked_inputs
        ]
        with open(masked_file, "w") as f:
            for d in decoded_mask_inputs:
                f.write(re.sub(r"\s+", " ", d).strip() + "\n")
        masker.release_model()
    else:
        with open(masked_file, "r") as f:
            decoded_mask_inputs = [s.strip() for s in f.readlines()]
        print("reusing masked_inputs.txt")

    assert len(decoded_mask_inputs) == len(inputs), "Masked vs inputs mismatch"

    if echo:
        print(f"\nexample masked at thresh {thresh:.2f}:")
        for i, m in enumerate(decoded_mask_inputs[:3]):
            print(f"  masked[{i}]: {m}")

    gen_folder = _build_llm_gen_folder_name(
        temperature=llm_temperature,
        sample=llm_sample,
        top_p=llm_top_p,
        max_new_tokens=llm_max_new_tokens,
        num_candidates=num_candidates,
    )
    final_abs = os.path.join(cur_abs, gen_folder)
    gen_txt = os.path.join(final_abs, "gen.txt")
    orig_txt = os.path.join(final_abs, "orig.txt")

    if os.path.exists(gen_txt) and not overwrite_gen:
        print("gen exists, skipping:", gen_txt)
        return

    _ensure_dir(final_abs)

    print(f"llm infilling: {num_candidates} cands (sample={llm_sample})")
    all_candidates = llm_infill_candidates(
        sources=inputs,
        masked=decoded_mask_inputs,
        num_candidates=num_candidates,
        temperature=llm_temperature,
        top_p=llm_top_p,
        max_new_tokens=llm_max_new_tokens,
        sample=llm_sample,
    )

    global _LLM_MODEL, _LLM_TOKENIZER
    try:
        del _LLM_MODEL
        del _LLM_TOKENIZER
    except Exception:
        pass
    _LLM_MODEL = None
    _LLM_TOKENIZER = None
    if torch.cuda.is_available() and DEVICE_LLM.type == "cuda":
        torch.cuda.empty_cache()

    print("global reranking...")
    best_idx, details = rerank_candidates_global(
        sources=inputs,
        candidates=all_candidates,
        weights=weights,
    )
    best_generations = [
        all_candidates[i][best_idx[i]] for i in range(len(inputs))
    ]

    if echo:
        print(f"\ndetoxified at thresh {thresh:.2f}:")
        for i, g in enumerate(best_generations[:3]):
            print(f"  detox[{i}]: {g}")

    with open(orig_txt, "w") as f:
        for l in inputs:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")
    with open(gen_txt, "w") as f:
        for l in best_generations:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")

    print("saved:", orig_txt)
    print("saved:", gen_txt)

In [None]:
# main detoxify function

def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_llm_mask_infill_global",
    echo: bool = False,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    num_candidates: int = 3,
    llm_temperature: float = 0.7,
    llm_top_p: float = 0.95,
    llm_max_new_tokens: int = 64,
    llm_sample: bool = True,
    weights = (0.5, 0.3, 0.2),
):
    assert data_type in data_configs, f"Unknown data_type: {data_type}"
    cfg = data_configs[data_type].copy()

    if num_candidates < 1:
        raise ValueError("num_candidates must be >= 1")

    base_out_rel = os.path.join("data", "model_outputs", output_folder)
    base_out_abs = os.path.join(REPO, base_out_rel)
    _ensure_dir(base_out_abs)

    original_data_path = cfg["data_path"]
    subset_dir = os.path.join(REPO, "datasets", "_subsets", data_type)
    _ensure_dir(subset_dir)
    subset_path = _subset_for_data_type(
        data_type, original_data_path, num_examples, subset_dir
    )

    args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
    inputs = rx.get_data(args_data)
    num_inputs = len(inputs)

    if echo:
        print("=" * 80)
        print(f"dataset: {data_type}")
        print(f"subset: {subset_path}")
        print(f"output: {base_out_abs}")
        print(f"num examples: {num_inputs}")
        print(f"weights: {weights}")
        print(f"candidates: {num_candidates}")
        print("\nexample inputs:")
        for i, s in enumerate(inputs[:3]):
            print(f"  [{i}]: {s}")
        print("=" * 80)

    mask_dir = "LLM_Mask_LLM_Global"
    cur_rel = os.path.join(base_out_rel, data_type, mask_dir)
    cur_abs = os.path.join(REPO, cur_rel)
    _ensure_dir(cur_abs)

    masked_file = os.path.join(cur_abs, "masked_inputs.txt")

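    # Reuse cached masks only when they line up with the current inputs.
    # A cache written for a different subset size would otherwise trip the
    # length check below.
    masked_inputs = None
    if os.path.exists(masked_file):
        with open(masked_file, "r") as f:
            cached = [s.strip() for s in f.readlines()]
        if len(cached) == len(inputs):
            masked_inputs = cached
            print("reusing masked file")
        else:
            print("masked file does not match inputs, regenerating")

    if masked_inputs is None:
        print("running llm masking...")
        masked_inputs = llm_mask_sentences(inputs)
        masked_inputs = [re.sub(r"\s+", " ", d).strip() for d in masked_inputs]
        with open(masked_file, "w") as f:
            for d in masked_inputs:
                f.write(d + "\n")

    assert len(masked_inputs) == len(inputs), "Masked vs inputs mismatch"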

    if echo:
        print("\nmasked examples:")
        for i, m in enumerate(masked_inputs[:3]):
            print(f"  [{i}]: {m}")

    run_folder = _build_run_folder_name(
        llm_temperature=llm_temperature,
        llm_top_p=llm_top_p,
        llm_sample=llm_sample,
        num_candidates=num_candidates,
        max_new_tokens=llm_max_new_tokens,
    )
    final_abs = os.path.join(cur_abs, run_folder)
    _ensure_dir(final_abs)
    orig_txt = os.path.join(final_abs, "orig.txt")
    gen_txt = os.path.join(final_abs, "gen.txt")

    if os.path.exists(gen_txt) and not overwrite_gen:
        print("gen exists, skipping:", gen_txt)
        with open(gen_txt, "r") as f:
            best_generations = [s.strip() for s in f.readlines()]
        if echo:
            print("\ndetoxified outputs:")
            for i in range(min(3, len(best_generations))):
                print(f"  [{i}]: {best_generations[i]}")
    else:
        print(f"llm infilling: {num_candidates} cands (sample={llm_sample})")
        all_candidates = llm_infill_candidates(
            sources=inputs,
            masked=masked_inputs,
            num_candidates=num_candidates,
            temperature=llm_temperature,
            top_p=llm_top_p,
            max_new_tokens=llm_max_new_tokens,
            sample=llm_sample,
        )

        # Drop the LLM references and release cached GPU memory before the
        # reranking models are loaded.
        global _LLM_MODEL, _LLM_TOKENIZER
        _LLM_MODEL = None
        _LLM_TOKENIZER = None
        if torch.cuda.is_available() and DEVICE_LLM.type == "cuda":
            torch.cuda.empty_cache()

        print("global reranking...")
        best_idx, details = rerank_candidates_global(
            sources=inputs,
            candidates=all_candidates,
            weights=weights,
        )
        best_generations = [
            all_candidates[i][best_idx[i]] for i in range(len(inputs))
        ]

        with open(orig_txt, "w") as f:
            for line in inputs:
                f.write(re.sub(r"\s+", " ", line).strip() + "\n")
        with open(gen_txt, "w") as f:
            for line in best_generations:
                f.write(re.sub(r"\s+", " ", line).strip() + "\n")

        print("saved:", orig_txt)
        print("saved:", gen_txt)

        if echo:
            print("\ndetoxified outputs:")
            for i in range(min(3, len(best_generations))):
                print(f"  [{i}]: {best_generations[i]}")

    if run_eval:
        base_path = os.path.join(base_out_abs, data_type, mask_dir)
        _eval_with_toxicity(
            base_path,
            overwrite_eval=overwrite_eval,
            skip_ref=skip_ref_eval,
            tox_threshold=0.5,
            tox_batch_size=32,
        )
        _aggregate_eval_csv(
            output_folder,
            data_type,
            base_out_abs,
        )

        if echo:
            stats_path = os.path.join(final_abs, "gen_stats.txt")
            if os.path.exists(stats_path):
                stats = _read_stats_file(stats_path)
                print("\neval metrics:")
                metric_keys = [
                    ("bertscore", "bertscore"),
                    ("meaningbert", "meaningbert"),
                    ("bleu4", "bleu4"),
                    ("perplexity gen", "ppl gen"),
                    ("perplexity orig", "ppl orig"),
                    ("toxicity gen", "tox gen"),
                    ("toxicity orig", "tox orig"),
                ]
                for key, label in metric_keys:
                    val = stats.get(key)
                    # Skip metrics that are missing or were not computed (NaN).
                    if val is None or (isinstance(val, float) and math.isnan(val)):
                        continue
                    print(f"  {label}: {val:.4f}")
            else:
                print("\nno stats file found")
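
The reranking call in `detoxify` picks one candidate per input. Below is a minimal sketch of that selection, assuming the per-candidate scores are already computed and scaled to [0, 1]. It applies the combined score `w_T * (1-tox) + w_S * sim + w_F * flu` and takes the argmax per input. `pick_best_candidates` and its three score tables are hypothetical names for illustration, not the notebook's `rerank_candidates_global`. How perplexity is mapped to a fluency score is not shown in this part of the notebook, so the sketch simply assumes higher `flu` is better.

```python
from typing import List, Sequence, Tuple


def pick_best_candidates(
    tox: Sequence[Sequence[float]],
    sim: Sequence[Sequence[float]],
    flu: Sequence[Sequence[float]],
    weights: Tuple[float, float, float] = (0.5, 0.3, 0.2),
) -> List[int]:
    """Return, per input, the index of the highest-scoring candidate.

    Each table is shaped [num_inputs][num_candidates]. Scores are assumed
    (not confirmed by the notebook) to lie in [0, 1].
    """
    w_t, w_s, w_f = weights
    best = []
    for tox_row, sim_row, flu_row in zip(tox, sim, flu):
        # combined score: w_T * (1 - tox) + w_S * sim + w_F * flu
        scores = [
            w_t * (1.0 - t) + w_s * s + w_f * f
            for t, s, f in zip(tox_row, sim_row, flu_row)
        ]
        # argmax over the candidates of this input
        best.append(max(range(len(scores)), key=scores.__getitem__))
    return best


# one input, three made-up candidates
best_idx = pick_best_candidates(
    tox=[[0.9, 0.1, 0.2]],
    sim=[[0.95, 0.80, 0.60]],
    flu=[[0.7, 0.6, 0.9]],
)
print(best_idx)  # [1]
```

With the default weights the first candidate is the closest to the source but loses on toxicity, so the second one wins (0.81 vs 0.475 and 0.76).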

In [None]:
# example run (commented out)

# detoxify(
#     data_type="paradetox",
#     output_folder="colab_run_llm_mask_infill_global_demo_50_examples",
#     echo=True,
#     num_examples=50,
#     run_eval=True,
#     overwrite_gen=True,
#     overwrite_eval=True,
#     skip_ref_eval=False,
#     num_candidates=10,
#     llm_temperature=0.7,
#     llm_top_p=0.95,
#     llm_max_new_tokens=64,
#     llm_sample=True,
#     weights=(0.5, 0.3, 0.2),
# )
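
The run log further down saves outputs under `llmtemp0.7_topp0.95_sampleT_nc10_maxntok64`, so the run folder appears to encode the decoding settings. Below is a sketch of a naming helper that reproduces that folder name for the settings used in the example. `run_folder_name_sketch` is a hypothetical stand-in for `_build_run_folder_name`, and the `F` flag for `llm_sample=False` is an assumption, since only the `T` case appears in the log.

```python
def run_folder_name_sketch(
    llm_temperature: float,
    llm_top_p: float,
    llm_sample: bool,
    num_candidates: int,
    max_new_tokens: int,
) -> str:
    """Build a folder name that encodes the decoding settings of a run."""
    # "T" matches the logged folder name; "F" is an assumed counterpart.
    sample_flag = "T" if llm_sample else "F"
    return (
        f"llmtemp{llm_temperature}_topp{llm_top_p}"
        f"_sample{sample_flag}_nc{num_candidates}_maxntok{max_new_tokens}"
    )


name = run_folder_name_sketch(
    llm_temperature=0.7,
    llm_top_p=0.95,
    llm_sample=True,
    num_candidates=10,
    max_new_tokens=64,
)
print(name)  # llmtemp0.7_topp0.95_sampleT_nc10_maxntok64
```

Because the name changes with any of these settings, runs with different settings land in separate folders. That fits the `gen exists, skipping` check, which only triggers when `gen.txt` already exists for the same settings.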

In [19]:
detoxify(
    data_type="paradetox",
    output_folder="XDetox_w_LLM-Masking_LLM-Infilling_Global-Reranking_Pipeline",
    echo=True,
    num_examples=1000,
    run_eval=True,             # BLEU/BERTScore/MeaningBERT/PPL/Toxicity
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    num_candidates=10,         # LLM candidates per input
    llm_temperature=0.7,
    llm_top_p=0.95,
    llm_max_new_tokens=64,
    llm_sample=True,
    weights=(0.5, 0.3, 0.2),   # safety, similarity, fluency
)


[echo] Dataset: paradetox
[echo] Subset path: /content/drive/MyDrive/w266 - Project/XDetox/datasets/_subsets/paradetox/test_toxic_parallel.txt
[echo] Output base: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_LLM-Infilling_Global-Reranking_Pipeline
[echo] Number of examples to detoxify: 671
[echo] Weights (w_T, w_S, w_F): (0.5, 0.3, 0.2)
[echo] num_candidates per input: 10

[echo] Example inputs (first up to 3):
  input[0]: . or the loud ass one - thousand ton beast roaring towards you howling its horn .
  input[1]: mandated  and " right fucking now " would be good .
  input[2]: * neither * of my coworkers gave a shit when it came time to ditch mitch . ugh .
Reusing existing masked_inputs.txt

[echo] Example LLM-masked inputs (first up to 3):
  masked[0]: . or the loud <mask> one - thousand ton beast roaring towards you howling its horn .
  masked[1]: mandated and "<mask> right <mask> now <mask> " would be good .
  masked[2]: neither of my coworke

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

LLM loaded.


LLM infilling (Mistral):   0%|          | 0/671 [00:00<?, ?it/s]

Global reranking (toxicity + similarity + fluency)...


Toxicity:   0%|          | 0/210 [00:00<?, ?it/s]

LaBSE embeddings:   0%|          | 0/21 [00:00<?, ?it/s]

LaBSE embeddings:   0%|          | 0/210 [00:00<?, ?it/s]

GPT-2 PPL:   0%|          | 0/6710 [00:00<?, ?it/s]

Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_LLM-Infilling_Global-Reranking_Pipeline/paradetox/LLM_Mask_LLM_Global/llmtemp0.7_topp0.95_sampleT_nc10_maxntok64/orig.txt
Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_LLM-Infilling_Global-Reranking_Pipeline/paradetox/LLM_Mask_LLM_Global/llmtemp0.7_topp0.95_sampleT_nc10_maxntok64/gen.txt

[echo] Example detoxified outputs (first up to 3):
  detox[0]: or the loud obnoxious one - thousand ton beast roaring towards you howling its horn.
  detox[1]: mandated and "immediately" "right now" would be good.
  detox[2]: neither of my coworkers gave a lackadaisical response when it came time to let Mitch go. ugh.
Eval: /usr/bin/python3 -m evaluation.evaluate_all --orig_path /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_LLM-Masking_LLM-Infilling_Global-Reranking_Pipeline/paradetox/LLM_Mask_LLM_Global/llmtemp0.7_topp0.95_sampleT_nc10_ma