# XDetox Pipeline with Global Reranking

This notebook mirrors `XDetox_Pipeline.ipynb` but replaces the **DecompX-based reranking** with a **global score**:

$$
T'(c) = 1 - T(c) \quad (\text{safety})
$$
$$
\text{Score}(c) = w_T \cdot T'(c) + w_S \cdot S(c) + w_F \cdot F(c)
$$

where:
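- $T(c)$ = toxicity probability in [0,1] from `textdetox/xlmr-large-toxicity-classifier-v2`
- $S(c)$ = semantic similarity: LaBSE cosine similarity between the candidate and its source, mapped from [-1,1] to [0,1] via $(\cos + 1)/2$
- $F(c)$ = fluency: GPT-2 perplexity, clipped to [5, 300] by default and log-scaled to [0,1] (higher = more fluent)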

You control:
- **`weights=(w_T, w_S, w_F)`**
- **`num_candidates`** per input

Masking (DecompX) and evaluation (BLEU/BERTScore/MeaningBERT/PPL/Toxicity) remain as before.


In [1]:
#@title Mount Drive, Imports & locate XDetox
from google.colab import drive; drive.mount('/content/drive')

import os, glob, re, sys, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE
import torch
import nltk
from typing import List

# Try My Drive
# Hard-coded location of the XDetox checkout on the mounted Google Drive.
candidate = "/content/drive/MyDrive/w266 - Project/XDetox"
print("Try MyDrive:", candidate, "->", os.path.isdir(candidate))

# The single candidate is used unconditionally (no fallback search); the
# assert below fails fast when the folder is missing.
XDETOX_DIR = candidate
print("Using XDETOX_DIR:", XDETOX_DIR)
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"

Mounted at /content/drive
Try MyDrive: /content/drive/MyDrive/w266 - Project/XDetox -> True
Using XDETOX_DIR: /content/drive/MyDrive/w266 - Project/XDetox


In [2]:
#@title Runtime setup (paths, cache, GPU)
# HuggingFace cache inside the repo (persists on Drive)
# The environment variable is set here, before the cell that imports
# `transformers`, so model downloads land in this folder.
# NOTE(review): newer transformers releases prefer HF_HOME over
# TRANSFORMERS_CACHE — confirm behavior for the pinned version.
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.makedirs(HF_CACHE, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE

# Add repo to PYTHONPATH
# Needed so `rewrite` and `evaluation` can be imported as top-level packages.
if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

# Echo the effective configuration and GPU status for the run log.
print("XDETOX_DIR:", XDETOX_DIR)
print("TRANSFORMERS_CACHE:", HF_CACHE)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))


XDETOX_DIR: /content/drive/MyDrive/w266 - Project/XDetox
TRANSFORMERS_CACHE: /content/drive/MyDrive/w266 - Project/XDetox/cache
CUDA available: True
GPU: Tesla T4


In [6]:
#@title Verify XDetox repo layout
# Fail fast if any of the repo sub-folders this notebook relies on is missing.
for d in ["rewrite", "evaluation", "datasets"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("Repo folders OK.")


Repo folders OK.


In [3]:
#@title Install dependencies (restart runtime if major errors)
# Upgrade the packaging toolchain first, then install the libraries used by the
# rewrite and evaluation code (core HF/eval packages are version-pinned).
# NOTE(review): sacremoses, ftfy, nltk, matplotlib, pandas, jedi and bert-score
# are unpinned — consider pinning them for reproducible runs.
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi

# BERTScore dependency required by evaluation/bertscore.py
!pip -q install bert-score


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m82.5 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m36.0 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.34.0 requires jedi>=0.16, which is not installed.[0m[31m
[0m[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.3.1 which is incompatible.[0m[

In [4]:
#@title Import from 'transformers'
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForSequenceClassification,
    GPT2LMHeadModel, GPT2TokenizerFast,
)



In [9]:
#@title Import from 'rewrite'
from rewrite.mask_orig import Masker as Masker_single
from rewrite.generation import Infiller
from rewrite import rewrite_example as rx
import argparse as _argparse

In [5]:
#@title NLTK data
# Download the tokenizer data; `quiet=True` suppresses the download log.
nltk.download("punkt", quiet=True)
# "punkt_tab" is attempted on a best-effort basis: any exception raised by the
# download call is deliberately ignored so the notebook keeps running.
try:
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass
print("NLTK ready")


NLTK ready


In [10]:
#@title Data configs
def _dataset_cfg(data_path, alpha_e=4.75, rep_penalty=1.0, temperature=2.5):
    """Build one dataset entry; `alpha_a` is 1.5 for every dataset."""
    return {
        "data_path": data_path,
        "rep_penalty": rep_penalty,
        "alpha_a": 1.5,
        "alpha_e": alpha_e,
        "temperature": temperature,
    }


# Per-dataset input file (relative to the repo root) and generation
# hyperparameters. Only the values that differ from the factory defaults are
# spelled out.
data_configs = {
    "microagressions_val": _dataset_cfg(
        "./datasets/microagressions/val.csv", alpha_e=4.25
    ),
    "microagressions_test": _dataset_cfg(
        "./datasets/microagressions/test.csv", alpha_e=4.25
    ),
    "sbf_val": _dataset_cfg(
        "./datasets/sbf/sbfdev.csv", alpha_e=5.0, rep_penalty=1.5, temperature=2.9
    ),
    "sbf_test": _dataset_cfg(
        "./datasets/sbf/sbftst.csv", alpha_e=5.0, rep_penalty=1.5, temperature=2.9
    ),
    "dynabench_val": _dataset_cfg("./datasets/dynabench/db_dev.csv"),
    "dynabench_test": _dataset_cfg("./datasets/dynabench/db_test.csv"),
    "jigsaw_toxic": _dataset_cfg("./datasets/jigsaw_full_30/test_10k_toxic.txt"),
    "paradetox": _dataset_cfg("./datasets/paradetox/test_toxic_parallel.txt"),
    "appdia_original": _dataset_cfg(
        "./datasets/appdia/original-annotated-data/original-test.tsv"
    ),
    "appdia_discourse": _dataset_cfg(
        "./datasets/appdia/discourse-augmented-data/discourse-test.tsv"
    ),
}
print("Datasets:", ", ".join(data_configs.keys()))

# Alias for the repo root; the helper cells below build paths from REPO.
REPO = XDETOX_DIR


Datasets: microagressions_val, microagressions_test, sbf_val, sbf_test, dynabench_val, dynabench_test, jigsaw_toxic, paradetox, appdia_original, appdia_discourse


In [11]:
#@title Helpers: subset data
def _abs_repo_path(rel: str) -> str:
    """
    Resolve a repo-relative path (e.g. "./datasets/x.csv") against REPO.

    Only a literal leading "./" prefix is removed. The previous
    `rel.lstrip("./")` stripped any run of '.' and '/' characters, which
    corrupted names such as "./.cache/x" or "../x" and turned absolute paths
    into repo-relative ones. Absolute paths are now returned unchanged.
    """
    if os.path.isabs(rel):
        return rel
    while rel.startswith("./"):
        rel = rel[2:]
    return os.path.join(REPO, rel)

def _ensure_dir(p: str):
    """Create directory `p` (and missing parents); do nothing if it exists."""
    target = Path(p)
    target.mkdir(parents=True, exist_ok=True)

def _subset_for_data_type(data_type, data_path, n, out_dir):
    """
    Create a small subset file matching the expected format used by
    rewrite_example.get_data().

    Args:
        data_type: dataset key (currently unused; the format is inferred from
            substrings of `data_path`).
        data_path: repo-relative path of the full dataset file.
        n: number of leading examples to keep; None or <= 0 means "no subset".
        out_dir: directory that receives the subset file (created if missing).

    Returns:
        Absolute path of the subset file, or the absolute path of the original
        file when no subset is requested. Returning the repo-resolved path in
        both cases keeps the result independent of the current working
        directory (previously the no-subset case returned the raw relative
        path).
    """
    src = _abs_repo_path(data_path)
    if n is None or n <= 0:
        return src  # no subset

    _ensure_dir(out_dir)
    out = os.path.join(out_dir, os.path.basename(src))

    csv_family = any(k in data_path for k in ("microagressions", "sbf", "dynabench"))
    txt_family = any(k in data_path for k in ("paradetox", "jigsaw"))

    # CSV datasets: keep the header plus the first n rows.
    if csv_family or (txt_family and data_path.endswith(".csv")):
        pd.read_csv(src).head(n).to_csv(out, index=False)
        return out

    # Plain-text datasets: one example per line.
    if txt_family and data_path.endswith(".txt"):
        with open(src, "r") as f:
            lines = [s.rstrip("\n") for s in f.readlines()]
        with open(out, "w") as g:
            for s in lines[:n]:
                g.write(s + "\n")
        return out

    # TSV datasets.
    if "appdia" in data_path:
        pd.read_csv(src, sep="\t").head(n).to_csv(out, sep="\t", index=False)
        return out

    # Fallback: unknown format, copy the original file unchanged.
    shutil.copy(src, out)
    return out


In [12]:
#@title Global scoring helpers: toxicity, similarity, fluency

# Devices: Use GPU for scoring if available (much faster)
# Single device shared by all three scoring models (toxicity, LaBSE, GPT-2):
# CUDA when available, otherwise CPU.
DEVICE_SCORE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Scoring models will use: {DEVICE_SCORE}")

# ---------- Toxicity model (textdetox/xlmr-large-toxicity-classifier-v2) ----------
_TOX_MODEL_NAME = "textdetox/xlmr-large-toxicity-classifier-v2"
# Populated on first use by _lazy_load_tox().
_TOX_TOKENIZER = None
_TOX_MODEL = None

def _lazy_load_tox():
    """Load the toxicity tokenizer/classifier once and cache them in module globals."""
    global _TOX_TOKENIZER, _TOX_MODEL
    already_loaded = _TOX_TOKENIZER is not None and _TOX_MODEL is not None
    if already_loaded:
        return
    _TOX_TOKENIZER = AutoTokenizer.from_pretrained(_TOX_MODEL_NAME)
    classifier = AutoModelForSequenceClassification.from_pretrained(_TOX_MODEL_NAME)
    _TOX_MODEL = classifier.to(DEVICE_SCORE)
    _TOX_MODEL.eval()

@torch.no_grad()
def get_toxicity_scores(texts, batch_size=32):
    """
    Score each text with the toxicity classifier.

    Returns a list with one probability in [0,1] per input text, taken from
    class index 1 of the softmaxed logits (0 = non-toxic, 1 = very toxic).
    """
    _lazy_load_tox()
    toxic_probs = []
    batch_starts = range(0, len(texts), batch_size)
    for start in tqdm(batch_starts, desc="Toxicity", leave=False):
        encoded = _TOX_TOKENIZER(
            texts[start:start + batch_size],
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        ).to(DEVICE_SCORE)
        class_probs = _TOX_MODEL(**encoded).logits.softmax(dim=-1)  # [..., 2]
        toxic_probs += class_probs[:, 1].detach().cpu().tolist()  # toxic prob
    return toxic_probs

# ---------- Semantic similarity (LaBSE) ----------
_LABSE_NAME = "sentence-transformers/LaBSE"
# Populated on first use by _lazy_load_labse().
_LABSE_TOKENIZER = None
_LABSE_MODEL = None

def _lazy_load_labse():
    """Load the LaBSE tokenizer/encoder once and cache them in module globals."""
    global _LABSE_TOKENIZER, _LABSE_MODEL
    already_loaded = _LABSE_TOKENIZER is not None and _LABSE_MODEL is not None
    if already_loaded:
        return
    _LABSE_TOKENIZER = AutoTokenizer.from_pretrained(_LABSE_NAME)
    encoder = AutoModel.from_pretrained(_LABSE_NAME)
    _LABSE_MODEL = encoder.to(DEVICE_SCORE)
    _LABSE_MODEL.eval()

@torch.no_grad()
def get_labse_embeddings(texts, batch_size=32):
    """
    Embed texts with LaBSE using attention-masked mean pooling.

    Returns a numpy array of shape (len(texts), hidden_dim); an empty input
    yields a (0, 768) float32 array.
    """
    _lazy_load_labse()
    pooled_batches = []
    batch_starts = range(0, len(texts), batch_size)
    for start in tqdm(batch_starts, desc="LaBSE embeddings", leave=False):
        encoded = _LABSE_TOKENIZER(
            texts[start:start + batch_size],
            return_tensors="pt",
            truncation=True,
            max_length=256,
            padding=True,
        ).to(DEVICE_SCORE)
        token_states = _LABSE_MODEL(**encoded).last_hidden_state  # [B, L, H]
        keep = encoded["attention_mask"].unsqueeze(-1)  # [B, L, 1]
        # Sum only over real (non-padding) tokens, then divide by their count.
        token_sum = (token_states * keep).sum(dim=1)  # [B, H]
        token_count = keep.sum(dim=1).clamp(min=1e-6)  # [B, 1]
        pooled_batches.append((token_sum / token_count).cpu().numpy())
    if len(pooled_batches) == 0:
        return np.zeros((0, 768), dtype=np.float32)
    return np.vstack(pooled_batches)

# ---------- Fluency via GPT-2 perplexity ----------
_GPT2_NAME = "gpt2"  # small model for speed; can switch to gpt2-medium if needed
# Populated on first use by _lazy_load_gpt2().
_GPT2_TOKENIZER = None
_GPT2_MODEL = None

def _lazy_load_gpt2():
    """Load the GPT-2 tokenizer/LM once and cache them in module globals."""
    global _GPT2_TOKENIZER, _GPT2_MODEL
    already_loaded = _GPT2_TOKENIZER is not None and _GPT2_MODEL is not None
    if already_loaded:
        return
    _GPT2_TOKENIZER = GPT2TokenizerFast.from_pretrained(_GPT2_NAME)
    language_model = GPT2LMHeadModel.from_pretrained(_GPT2_NAME)
    _GPT2_MODEL = language_model.to(DEVICE_SCORE)
    _GPT2_MODEL.eval()

@torch.no_grad()
def get_gpt2_perplexities(texts):
    """
    Simple sentence-level perplexity using GPT-2.

    Returns a list of floats (one per text), each capped at 1e4.

    Degenerate inputs get the cap value (i.e. "worst fluency") instead of
    breaking the pipeline:
      - texts that tokenize to fewer than 2 tokens (empty string, single
        token) have no next-token target, so no loss can be computed for them;
      - a NaN loss, which would otherwise slip past the `> 1e4` clip and
        propagate as NaN into the global score.
    Inputs longer than the tokenizer's maximum length are truncated.
    """
    max_ppl = 1e4
    log_max_ppl = math.log(max_ppl)
    _lazy_load_gpt2()
    ppls = []
    for s in tqdm(texts, desc="GPT-2 PPL", leave=False):
        enc = _GPT2_TOKENIZER(s, return_tensors="pt", truncation=True).to(DEVICE_SCORE)
        input_ids = enc["input_ids"]
        if input_ids.shape[-1] < 2:
            ppls.append(max_ppl)
            continue
        loss = _GPT2_MODEL(input_ids, labels=input_ids).loss.item()
        # Compare in log space so math.exp can never overflow.
        if math.isnan(loss) or loss >= log_max_ppl:
            ppls.append(max_ppl)  # clip extreme
        else:
            ppls.append(float(math.exp(loss)))
    return ppls

def perplexity_to_fluency(ppls, p_min=5.0, p_max=300.0):
    """
    Map perplexities to [0,1] fluency scores on a log scale.

    Perplexities are clipped to [p_min, p_max]; p_min maps to ~1.0 and p_max
    to 0.0 (low perplexity -> high fluency).

    NaN perplexities map to 0.0 (least fluent) instead of propagating: a NaN
    fluency would make the weighted global score NaN.
    """
    ppls = np.asarray(ppls, dtype=float)
    log_min = math.log(p_min)
    log_max = math.log(p_max)
    log_p = np.log(np.clip(ppls, p_min, p_max))
    # The epsilon guards against division by zero when p_min == p_max.
    fluency = (log_max - log_p) / (log_max - log_min + 1e-8)
    fluency = np.clip(fluency, 0.0, 1.0)
    return np.nan_to_num(fluency, nan=0.0)

Scoring models will use: cuda


In [13]:
#@title Global reranking: combine toxicity, similarity, fluency
def rerank_candidates_global(
    sources,
    candidates,
    weights=(0.5, 0.3, 0.2),
    ppl_min=5.0,
    ppl_max=300.0,
):
    """
    Pick the best candidate per source by a weighted global score:
        score = w_T * (1 - toxicity) + w_S * similarity + w_F * fluency

    Args:
        sources: list[str], length N
        candidates: list[list[str]], shape N x C (same C for every source)
        weights: (w_T, w_S, w_F)
        ppl_min, ppl_max: perplexity clip range forwarded to
            perplexity_to_fluency.

    Returns:
        best_idx: np.ndarray of shape (N,), index of chosen candidate per source
        details: dict with matrices [N x C] for tox, safety, sim, flu, score
            (an empty dict when N == 0).

    Raises:
        AssertionError: on length mismatch or ragged candidate lists.
        ValueError: when there are zero candidates per source.

    Candidates whose score is NaN are never selected unless every candidate of
    that source is NaN (then index 0 is returned). A plain argmax would treat
    NaN as the maximum and always pick the broken candidate. The raw scores,
    including any NaN, are still reported in `details["score"]`.
    """
    w_T, w_S, w_F = weights
    N = len(sources)
    assert len(candidates) == N, "candidates length mismatch"

    if N == 0:
        return np.array([], dtype=int), {}

    C_list = [len(c) for c in candidates]
    assert len(set(C_list)) == 1, "All inputs must have same num_candidates"
    C = C_list[0]
    if C == 0:
        raise ValueError("num_candidates must be >= 1")

    # Flatten candidates row-major; entry k belongs to source k // C.
    flat_cands = [cand for cand_list in candidates for cand in cand_list]
    flat_src_idx = np.repeat(np.arange(N), C)

    # Toxicity
    tox = np.array(get_toxicity_scores(flat_cands), dtype=float)  # [N*C]

    # Semantic similarity (LaBSE)
    src_embs = get_labse_embeddings(sources)  # [N, D]
    cand_embs = get_labse_embeddings(flat_cands)  # [N*C, D]
    # L2-normalize; the clip avoids division by zero for all-zero embeddings.
    src_embs = src_embs / np.clip(np.linalg.norm(src_embs, axis=1, keepdims=True), 1e-8, None)
    cand_embs = cand_embs / np.clip(np.linalg.norm(cand_embs, axis=1, keepdims=True), 1e-8, None)
    # Cosine between each candidate and its source
    sims = np.sum(cand_embs * src_embs[flat_src_idx], axis=1)  # [-1,1]
    sims = (sims + 1.0) / 2.0  # -> [0,1]

    # Fluency: GPT-2 PPL -> F in [0,1]
    ppls = np.array(get_gpt2_perplexities(flat_cands), dtype=float)
    flus = perplexity_to_fluency(ppls, p_min=ppl_min, p_max=ppl_max)

    # Safety
    safety = 1.0 - tox

    # Global score
    scores = w_T * safety + w_S * sims + w_F * flus

    # Reshape to [N, C]
    details = {
        "tox": tox.reshape(N, C),
        "safety": safety.reshape(N, C),
        "sim": sims.reshape(N, C),
        "flu": flus.reshape(N, C),
        "score": scores.reshape(N, C),
    }

    # Rank on a NaN-free copy so a NaN score can never win the argmax.
    rankable = np.where(np.isnan(details["score"]), -np.inf, details["score"])
    best_idx = rankable.argmax(axis=1)
    return best_idx, details


In [14]:
#@title Evaluation helpers (reuse evaluate_all.py with toxicity)
def _parse_run_folder_name(folder_name):
    """
    Return True when `folder_name` looks like a generation-run folder
    (aa…_ae…_ab…_base…_anti…_expert…_temp…_sample…_topk…_reppenalty…_filterp…_maxlength…_topp…).

    Numeric fields accept both "1.5" and "2": folder names are built with
    `str(value)`, so an int-valued hyperparameter (e.g. temperature=3) yields
    a name without a decimal point. The previous pattern rejected those
    folders, so they were silently skipped during evaluation.
    """
    num = r"(\d+(?:\.\d+)?)"
    pattern = (
        rf"aa{num}_ae{num}_ab{num}"
        r"_base(.*?)_anti(.*?)_expert(.*?)"
        rf"_temp{num}_sample(.*?)_topk(\d+)"
        rf"_reppenalty{num}_filterp{num}_maxlength(\d+)_topp{num}"
    )
    return bool(re.match(pattern, folder_name))

def _eval_with_toxicity(base_path, overwrite_eval=False, skip_ref=False, tox_threshold=0.5, tox_batch_size=32):
    """
    Call evaluation.evaluate_all (the new one with local toxicity) on each gen folder.

    Every sub-folder of `base_path` whose name matches the generation-run
    naming scheme and that contains both orig.txt and gen.txt is evaluated in
    a subprocess started from the repo root. Folders that already contain
    gen_stats.txt are skipped unless `overwrite_eval` is True.

    Raises:
        subprocess.CalledProcessError: if an evaluation subprocess exits with
            a non-zero status (its stdout/stderr are printed first).
    """
    # The child process must be able to import the repo's `evaluation` package.
    # os.pathsep keeps the PYTHONPATH join correct on every platform.
    env = os.environ.copy()
    inherited_pythonpath = env.get("PYTHONPATH")
    env["PYTHONPATH"] = REPO + (os.pathsep + inherited_pythonpath if inherited_pythonpath else "")

    # sorted() makes the evaluation order deterministic.
    for folder in sorted(os.listdir(base_path)):
        gen_dir = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir) or not _parse_run_folder_name(folder):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path  = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue

        cmd = [
            sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path",  gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")
        print("Eval:", " ".join(cmd))
        res = run(cmd, cwd=REPO, env=env, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            # Surface the child's output before raising so the failure is debuggable.
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()

def _safe_float(x):
    """Convert `x` to float; return NaN when the conversion fails for any reason."""
    try:
        value = float(x)
    except Exception:
        value = float('nan')
    return value

def _read_stats_file(path):
    """
    Read gen_stats.txt into a dict of floats; tolerate '(skipped): None'.

    Each "key: value" line becomes one entry. Keys are lower-cased and any
    "(skipped)" marker is removed; values that cannot be parsed as a float
    become NaN. Lines without a colon are ignored.

    A line that has a colon but no ": " separator (e.g. "key:0.5" or "key:")
    is split on the first ":" instead of raising ValueError.
    """
    out = {}
    with open(path, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if ":" not in line:
                continue
            key, sep, value = line.partition(": ")
            if not sep:
                # No ": " separator present: fall back to the bare colon.
                key, _, value = line.partition(":")
            key = key.replace("(skipped)", "").strip().lower()
            out[key] = _safe_float(value)
    return out

def _aggregate_eval_csv(output_folder, data_type, base_out_dir):
    """
    Collect every gen_stats.txt under <base_out_dir>/<data_type>/DecompX*/<run>/
    into one summary CSV at <base_out_dir>/<data_type>/<data_type>.csv.

    Args:
        output_folder: unused; kept so existing callers keep working.
        data_type: dataset key, used as sub-folder and CSV file name.
        base_out_dir: absolute output root of the run.

    The mask folders are discovered on disk instead of being hard-coded to
    0.15 / 0.20 / 0.25, so results for any threshold passed to the pipeline
    show up in the summary. Rows are ordered by threshold, then folder name.
    """
    data_dir = os.path.join(base_out_dir, data_type)
    mask_dir_re = re.compile(r"^DecompX(\d+(?:\.\d+)?)$")

    # (threshold, folder name) for every DecompX<threshold> directory present.
    mask_dirs = []
    if os.path.isdir(data_dir):
        for name in os.listdir(data_dir):
            m = mask_dir_re.match(name)
            if m and os.path.isdir(os.path.join(data_dir, name)):
                mask_dirs.append((float(m.group(1)), name))
    mask_dirs.sort()

    rows = []
    for thresh, mask_dir in mask_dirs:
        base_path = os.path.join(data_dir, mask_dir)
        for folder in sorted(os.listdir(base_path)):
            stats_path = os.path.join(base_path, folder, "gen_stats.txt")
            if not os.path.exists(stats_path):
                continue
            s = _read_stats_file(stats_path)
            rows.append({
                "threshold":        thresh,
                "folder":           folder,
                "bertscore":        s.get("bertscore", np.nan),
                "meaningbert":      s.get("meaningbert", np.nan),
                "bleu4":            s.get("bleu4", np.nan),
                "perplexity_gen":   s.get("perplexity gen", np.nan),
                "perplexity_orig":  s.get("perplexity orig", np.nan),
                "toxicity_gen":     s.get("toxicity gen", np.nan),
                "toxicity_orig":    s.get("toxicity orig", np.nan),
            })

    if rows:
        cols = [
            "threshold", "folder",
            "bertscore", "meaningbert", "bleu4",   # MeaningBERT sits between bertscore & bleu4
            "perplexity_gen", "perplexity_orig",
            "toxicity_gen", "toxicity_orig",
        ]
        df = pd.DataFrame(rows)[cols]
        out_csv = os.path.join(data_dir, f"{data_type}.csv")
        df.to_csv(out_csv, index=False)
        print("Wrote summary CSV:", out_csv)
    else:
        print("No evaluation files found to summarize.")


In [15]:
#@title Masking + generation with global reranking


def _process_in_batches(masker, inputs, batch_size, thresh: float):
    """
    Run `masker.process_text` over `inputs` in chunks of `batch_size`.

    Returns a list with one entry per chunk, each being whatever
    `process_text` returned for that chunk (results are not flattened).
    """
    chunk_starts = range(0, len(inputs), batch_size)
    chunks = [inputs[start : start + batch_size] for start in chunk_starts]
    return [
        masker.process_text(sentence=chunk, threshold=thresh)
        for chunk in tqdm(chunks, desc="Masking (DecompX)", leave=False)
    ]

def _bool2str(x: bool) -> str:
    """Encode a flag as the single letter used in run-folder names: "T" or "F"."""
    return ("F", "T")[bool(x)]

def _build_gen_folder_name(
    alpha_a, alpha_e, alpha_b,
    base_type, antiexpert_type, expert_type,
    temperature, sample, top_k_gen, rep_penalty, filter_p, max_length, top_p
):
    """
    Build the run-folder name that encodes the generation hyperparameters.

    Model type names are truncated to their first five characters; every
    other value is rendered with `str()`. Fields are joined with "_".
    """
    fields = [
        ("aa", alpha_a),
        ("ae", alpha_e),
        ("ab", alpha_b),
        ("base", base_type[:5]),
        ("anti", antiexpert_type[:5]),
        ("expert", expert_type[:5]),
        ("temp", temperature),
        ("sample", _bool2str(sample)),
        ("topk", top_k_gen),
        ("reppenalty", rep_penalty),
        ("filterp", filter_p),
        ("maxlength", max_length),
        ("topp", top_p),
    ]
    return "_".join(tag + str(value) for tag, value in fields)

def _run_global_reranking_for_threshold(
    data_type,
    subset_path,
    thresh,
    base_out_rel,
    batch_size,
    alpha_a, alpha_e, alpha_b,
    temperature,
    rep_penalty,
    max_length,
    top_k_gen,
    top_p,
    filter_p,
    sample,
    num_candidates,
    weights,
    overwrite_gen=False,
):
    """
    For one threshold value:
      - load inputs using rewrite_example.get_data
      - mask with DecompX (Masker_single), reusing masked_inputs.txt when it
        matches the current inputs
      - generate num_candidates samples per input with Infiller
      - rerank with global score
      - save orig.txt / gen.txt under the usual folder

    Returns None. When gen.txt already exists and `overwrite_gen` is False the
    function returns before any generation model is loaded.

    A cached masked_inputs.txt whose line count differs from the number of
    inputs (e.g. after changing the subset size) is treated as stale and
    regenerated instead of aborting the run.
    """
    # Load inputs using original get_data logic
    args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
    inputs = rx.get_data(args_data)
    print(f"#inputs at thresh={thresh}: {len(inputs)}")

    # Paths
    mask_dir = f"DecompX{abs(thresh):g}" if thresh != 0 else "DecompX0.0"
    cur_rel = os.path.join(base_out_rel, data_type, mask_dir)
    cur_abs = os.path.join(REPO, cur_rel)
    _ensure_dir(cur_abs)

    masked_file = os.path.join(cur_abs, "masked_inputs.txt")

    # Masking: reuse the cached file only when it lines up with the inputs.
    decoded_mask_inputs = None
    if os.path.exists(masked_file):
        with open(masked_file, "r") as f:
            cached_masks = [s.strip() for s in f.readlines()]
        if len(cached_masks) == len(inputs):
            decoded_mask_inputs = cached_masks
            print("Reusing existing masked_inputs.txt")
        else:
            print(
                f"Stale masked_inputs.txt ({len(cached_masks)} lines for "
                f"{len(inputs)} inputs) — re-masking."
            )

    if decoded_mask_inputs is None:
        masker = Masker_single()
        try:
            decoded_masked_inputs_batches = _process_in_batches(
                masker, inputs, batch_size=batch_size, thresh=thresh
            )
        finally:
            # Release the masking model even if masking fails part-way.
            masker.release_model()
        decoded_mask_inputs = [
            item.replace("<s>", "").replace("</s>", "")
            for sublist in decoded_masked_inputs_batches
            for item in sublist
        ]
        # NOTE(review): the file stores whitespace-collapsed text and is read
        # back stripped, while a fresh run feeds the un-normalized strings to
        # the Infiller — confirm whether that difference matters.
        with open(masked_file, "w") as f:
            for d in decoded_mask_inputs:
                f.write(re.sub(r"\s+", " ", d) + "\n")

    assert len(decoded_mask_inputs) == len(inputs), "Masked vs inputs mismatch"

    # Build generation folder name
    base_type = "base"
    antiexpert_type = "antiexpert"
    expert_type = "expert"
    gen_folder = _build_gen_folder_name(
        alpha_a, alpha_e, alpha_b,
        base_type, antiexpert_type, expert_type,
        temperature, sample, top_k_gen, rep_penalty, filter_p, max_length, top_p
    )
    final_abs = os.path.join(cur_abs, gen_folder)
    gen_txt = os.path.join(final_abs, "gen.txt")
    orig_txt = os.path.join(final_abs, "orig.txt")

    # Check for existing output BEFORE constructing the Infiller so a skipped
    # run does not load the three BART models for nothing.
    if os.path.exists(gen_txt) and not overwrite_gen:
        print("Generation already exists at:", gen_txt, "— skipping generation.")
        return

    _ensure_dir(final_abs)

    # Initialize Infiller (same as in rewrite_example)
    rewriter = Infiller(
        seed=0,
        base_path="facebook/bart-base",
        antiexpert_path="hallisky/bart-base-toxic-antiexpert",
        expert_path="hallisky/bart-base-nontoxic-expert",
        base_type=base_type,
        antiexpert_type=antiexpert_type,
        expert_type=expert_type,
        tokenizer="facebook/bart-base",
    )

    # Generate multiple candidates per input
    all_candidates: List[List[str]] = [[] for _ in range(len(inputs))]

    print(f"Generating {num_candidates} candidates per input (sampling={sample})")
    # NOTE(review): candidate diversity relies on generate() producing different
    # samples on repeated calls (the Infiller is seeded once with seed=0) and on
    # sample=True — confirm against rewrite.generation.
    for _ in range(num_candidates):
        _, decoded = rewriter.generate(
            inputs,
            decoded_mask_inputs,
            alpha_a=alpha_a,
            alpha_e=alpha_e,
            alpha_b=alpha_b,
            temperature=temperature,
            verbose=False,
            max_length=max_length,
            repetition_penalty=rep_penalty,
            p=top_p,
            filter_p=filter_p,
            k=top_k_gen,
            batch_size=batch_size,
            sample=sample,
            ranking=False,          # <-- no DecompX reranking
            ranking_eval_output=0,
        )
        for i, text in enumerate(decoded):
            all_candidates[i].append(re.sub(r"\s+", " ", text).strip())

    # --- Memory Cleanup before Scoring ---
    # We are done with the BART generation model.
    # Delete it and clear CUDA cache to make room for XLM-R (toxicity) and LaBSE (similarity).
    del rewriter
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # -------------------------------------

    # Global reranking
    print("Global reranking (toxicity + similarity + fluency)...")
    best_idx, _details = rerank_candidates_global(
        sources=inputs,
        candidates=all_candidates,
        weights=weights,
    )
    best_generations = [
        all_candidates[i][best_idx[i]] for i in range(len(inputs))
    ]

    # Save orig + chosen gen (one whitespace-normalized example per line)
    with open(orig_txt, "w") as f:
        for l in inputs:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")
    with open(gen_txt, "w") as f:
        for l in best_generations:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")

    print("Saved:", orig_txt)
    print("Saved:", gen_txt)

In [16]:
#@title `detoxify()` — masking + global reranking + optional eval

def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_global",
    thresholds = (0.15, 0.20, 0.25),
    batch_size: int = 10,
    ranking: bool = True,   # kept for API symmetry; if False, you could adapt to skip reranking
    sample: bool = True,
    top_k_gen: int = 50,
    top_p: float = 0.95,
    filter_p: float = 1.0,
    max_length: int = 128,
    alpha_a: float = None,   # if None, take from data_configs
    alpha_e: float = None,   # if None, take from data_configs
    alpha_b: float = 1.0,
    temperature: float = None,  # if None, from data_configs
    rep_penalty: float = None,  # if None, from data_configs
    num_examples: int = 100,    # small-batch control; set None to use full dataset
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    # NEW:
    weights = (0.5, 0.3, 0.2),   # (w_T, w_S, w_F)
    num_candidates: int = 3,     # candidates per input
):
    """
    Entry point: DecompX masking, multi-candidate generation and global
    reranking for every value in `thresholds`, with optional evaluation.

    The global score combines:
      - toxicity (XLM-R large)
      - semantic similarity (LaBSE)
      - fluency (GPT-2 perplexity -> [0,1])

    Parameters match XDetox_Pipeline.ipynb plus:
      - weights: (w_T, w_S, w_F)
      - num_candidates: int, candidates per input

    `alpha_a`, `alpha_e`, `temperature` and `rep_penalty` default to the
    values in `data_configs[data_type]` when left as None. `ranking` is
    accepted but not used. When `run_eval` is True, each threshold is
    evaluated right after generation and a summary CSV is written at the end.

    Raises:
        AssertionError: unknown `data_type`.
        ValueError: `num_candidates` < 1.
    """
    assert data_type in data_configs, f"Unknown data_type: {data_type}"
    dataset_defaults = data_configs[data_type].copy()

    if num_candidates < 1:
        raise ValueError("num_candidates must be >= 1")

    # Per-dataset defaults apply only where the caller passed None.
    alpha_a = dataset_defaults["alpha_a"] if alpha_a is None else alpha_a
    alpha_e = dataset_defaults["alpha_e"] if alpha_e is None else alpha_e
    temperature = dataset_defaults["temperature"] if temperature is None else temperature
    rep_penalty = dataset_defaults["rep_penalty"] if rep_penalty is None else rep_penalty

    # Output root, both relative to the repo and absolute.
    out_rel = os.path.join("data", "dexp_outputs", output_folder)
    out_abs = os.path.join(REPO, out_rel)
    _ensure_dir(out_abs)

    # Input file (possibly a truncated copy of the dataset).
    subset_dir = os.path.join(REPO, "datasets", "_subsets", data_type)
    _ensure_dir(subset_dir)
    subset_path = _subset_for_data_type(
        data_type, dataset_defaults["data_path"], num_examples, subset_dir
    )

    print(f"Data type: {data_type}")
    print(f"Subset path: {subset_path}")
    print(f"Output base: {out_abs}")
    print(f"Weights (w_T, w_S, w_F): {weights}")
    print(f"num_candidates per input: {num_candidates}")

    # Everything except the threshold is identical across iterations.
    shared_run_kwargs = dict(
        data_type=data_type,
        subset_path=subset_path,
        base_out_rel=out_rel,
        batch_size=batch_size,
        alpha_a=alpha_a,
        alpha_e=alpha_e,
        alpha_b=alpha_b,
        temperature=temperature,
        rep_penalty=rep_penalty,
        max_length=max_length,
        top_k_gen=top_k_gen,
        top_p=top_p,
        filter_p=filter_p,
        sample=sample,
        num_candidates=num_candidates,
        weights=weights,
        overwrite_gen=overwrite_gen,
    )

    for t in thresholds:
        print("=" * 60)
        print(f"Threshold (DecompX) = {t:.2f}")
        _run_global_reranking_for_threshold(thresh=t, **shared_run_kwargs)

        if not run_eval:
            continue
        mask_dir = "DecompX0.0" if t == 0 else f"DecompX{abs(t):g}"
        _eval_with_toxicity(
            os.path.join(out_abs, data_type, mask_dir),
            overwrite_eval=overwrite_eval,
            skip_ref=skip_ref_eval,
            tox_threshold=0.5,
            tox_batch_size=32,
        )

    # Summarize metrics across thresholds
    if run_eval:
        _aggregate_eval_csv(output_folder, data_type, out_abs)


In [17]:
#@title Example run — paradetox, small subset with global reranking

# detoxify(
#     data_type="paradetox",
#     output_folder="colab_run_global_demo",
#     thresholds=(0.20,),          # single threshold for quick test
#     batch_size=8,                # T4-friendly
#     ranking=True,                # flag kept for symmetry, global reranking is always used here
#     sample=True,
#     top_k_gen=50,
#     top_p=0.95,
#     max_length=96,
#     num_examples=50,             # small subset
#     run_eval=True,               # BLEU/BERTScore/PPL/Toxicity via evaluate_all.py
#     overwrite_gen=False,
#     overwrite_eval=True,
#     skip_ref_eval=False,
#     weights=(0.5, 0.3, 0.2),     # (w_T, w_S, w_F): safety, similarity, fluency
#     num_candidates=20             # candidates per input for reranking
# )


In [20]:
# Re-evaluation run on the small paradetox demo: gen.txt from a previous run is
# kept (overwrite_gen=False) and only gen_stats.txt is recomputed
# (overwrite_eval=True). The generation-related arguments still matter because
# they determine the run-folder name that is looked up.
detoxify(
    data_type="paradetox",
    output_folder="colab_run_global_demo",
    thresholds=(0.20,),
    batch_size=8,
    ranking=True,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=50,
    run_eval=True,        # <-- keep this True
    overwrite_gen=False,  # <-- do NOT touch generation
    overwrite_eval=True,  # <-- force recompute gen_stats.txt
    skip_ref_eval=False,
    weights=(0.5, 0.3, 0.2),
    num_candidates=20,
)


Data type: paradetox
Subset path: /content/drive/MyDrive/w266 - Project/XDetox/datasets/_subsets/paradetox/test_toxic_parallel.txt
Output base: /content/drive/MyDrive/w266 - Project/XDetox/data/dexp_outputs/colab_run_global_demo
Weights (w_T, w_S, w_F): (0.5, 0.3, 0.2)
num_candidates per input: 20
Threshold (DecompX) = 0.20
#inputs at thresh=0.2: 50
Reusing existing masked_inputs.txt
Found 1 GPUS!
Generation already exists at: /content/drive/MyDrive/w266 - Project/XDetox/data/dexp_outputs/colab_run_global_demo/paradetox/DecompX0.2/aa1.5_ae4.75_ab1.0_basebase_antiantie_expertexper_temp2.5_sampleT_topk50_reppenalty1.0_filterp1.0_maxlength96_topp0.95/gen.txt — skipping generation.
Eval: /usr/bin/python3 -m evaluation.evaluate_all --orig_path /content/drive/MyDrive/w266 - Project/XDetox/data/dexp_outputs/colab_run_global_demo/paradetox/DecompX0.2/aa1.5_ae4.75_ab1.0_basebase_antiantie_expertexper_temp2.5_sampleT_topk50_reppenalty1.0_filterp1.0_maxlength96_topp0.95/orig.txt --gen_path /conte