# xdetox pipeline w/ global reranking

runs the full pipeline with:

1. **decompx masking** (token-level toxicity attribution on roberta).
2. **marco-style generation** (base / expert / anti-expert bart mixture).
3. **global reranking** of multiple candidates using:
   - **toxicity** (xlm-r large classifier).
   - **semantic similarity** (labse).
   - **fluency** (gpt-2 perplexity).

goal: pick the best detoxified candidate, i.e. the one that is:
- as non-toxic as possible
- semantically close to the original
- fluent

---

## scoring

For each candidate $c$ we compute:

- $T(c)$: toxicity in $[0,1]$ from `textdetox/xlmr-large-toxicity-classifier-v2`.
- $S(c)$: semantic similarity in $[0,1]$: labse cosine similarity to the source, rescaled from $[-1,1]$.
- $F(c)$: fluency in $[0,1]$: gpt-2 perplexity, clipped to $[5, 300]$ and mapped on a log scale
  (low perplexity → high fluency).
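Written out, the two mappings are as follows. They match `rerank_candidates_global` and `perplexity_to_fluency` in the code cells, with defaults $p_{\min}=5$ and $p_{\max}=300$; the code also adds a $10^{-8}$ guard to the denominator. Here $x$ is the source and $e(\cdot)$ is its mean-pooled labse embedding:

$$S(c) = \frac{1 + \cos\big(e(x), e(c)\big)}{2}$$

$$\tilde{p}(c) = \min\big(\max(\mathrm{PPL}(c),\, p_{\min}),\, p_{\max}\big), \qquad F(c) = \frac{\log p_{\max} - \log \tilde{p}(c)}{\log p_{\max} - \log p_{\min}}$$

So $\mathrm{PPL}(c) \le 5$ gives $F(c) = 1$, $\mathrm{PPL}(c) \ge 300$ gives $F(c) = 0$, and the geometric midpoint $\sqrt{5 \cdot 300} \approx 38.7$ gives $F(c) = 0.5$.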

convert toxicity to safety score:

$$T'(c) = 1 - T(c)$$

global score:

$$\text{Score}(c) = w_T \cdot T'(c) + w_S \cdot S(c) + w_F \cdot F(c)$$

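weights:

- `weights = (w_T, w_S, w_F)`, default `(0.5, 0.3, 0.2)`:
  - `w_T`: importance of **safety** (low toxicity).
  - `w_S`: importance of **semantic similarity**.
  - `w_F`: importance of **fluency**.
- the weights are not normalized. When they sum to 1 (as the default does), $\text{Score}(c)$ also stays in $[0,1]$.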

for each input:

1. generate `num_candidates` candidates.
2. score each candidate.
3. select highest-scoring candidate.
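Steps 2 and 3 amount to a weighted sum followed by a per-input argmax. Below is a minimal numpy sketch that takes already-computed $T$, $S$, $F$ values. The function name `select_best` is hypothetical, and the numbers are made up for illustration (they are not model outputs). The pipeline's own version is `rerank_candidates_global`, which also computes the metrics.

```python
import numpy as np

def select_best(tox, sim, flu, weights=(0.5, 0.3, 0.2)):
    """Score an (N inputs x C candidates) grid and return the argmax per input.

    tox, sim, flu: array-likes of shape (N, C), each already in [0, 1].
    """
    w_t, w_s, w_f = weights
    tox, sim, flu = (np.asarray(a, dtype=float) for a in (tox, sim, flu))
    safety = 1.0 - tox                          # T'(c) = 1 - T(c)
    scores = w_t * safety + w_s * sim + w_f * flu
    return scores.argmax(axis=1), scores        # best candidate index per input

# one input with three candidates (illustrative numbers):
# c0 keeps the meaning but is still toxic, c2 is safe but drifts, c1 balances both
tox = [[0.90, 0.05, 0.02]]
sim = [[0.95, 0.85, 0.40]]
flu = [[0.80, 0.70, 0.90]]
best, scores = select_best(tox, sim, flu)
print(best[0], scores.round(3))   # best[0] == 1; scores ≈ [0.495, 0.87, 0.79]
```

With the default weights, the closest-but-toxic candidate loses heavily on safety. The safest candidate loses on similarity, so the balanced one is selected.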

---

## masking and generation

for each dataset:

1. **subset selection**  
   - can run on full dataset or only first `num_examples` rows.
   - subset file written under `datasets/_subsets/{data_type}/`.

2. **decompx masking**  
   - use `rewrite.mask_orig.Masker` (roberta + decompx) to detect toxic tokens.
   - toxic tokens replaced by `<mask>`.
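   - masked inputs saved as `masked_inputs.txt` and reused on later runs at the same threshold (delete it after changing `num_examples`, otherwise the masked-vs-inputs length check fails).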

3. **marco generation (bart ensemble)**  
   - use `rewrite.generation.Infiller` with:
     - base model (bart),
     - expert model (non-toxic),
     - anti-expert model (toxic).
   - generation controlled by:
     - `alpha_a`, `alpha_e`, `alpha_b` (anti-expert / expert / base weights),
     - `temperature`, `top_k_gen`, `top_p`, `filter_p`, `rep_penalty`, `max_length`,
     - `sample` (use sampling or greedy decoding).
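   - generation is repeated `num_candidates` times, giving `num_candidates` candidates per input. These are not guaranteed to be distinct: with `sample=False`, greedy decoding makes every repeat identical, so reranking has nothing to choose from.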

4. **global reranking**  
   - all candidates scored by toxicity + similarity + fluency.
   - best candidate written to `gen.txt` (one line per input).
   - original texts written to `orig.txt`.
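The following numpy sketch of a single decoding step over a toy vocabulary makes the generation knobs in step 3 concrete. It follows the roles this document gives the parameters: `filter_p` filters the **base** logits, `top_k_gen`/`top_p` act on the ensembled logits, and `sample=False` means greedy decoding. The way the three models' logits are combined is an **assumption** (a DExperts-style linear mix) and is not taken from `rewrite.generation.Infiller`. The order of temperature and filtering is also assumed, and all function names here are hypothetical.

```python
import numpy as np

def _softmax(z):
    z = z - z.max()            # z.max() is finite: at least one token always survives
    e = np.exp(z)              # exp(-inf) == 0, so filtered tokens get zero mass
    return e / e.sum()

def nucleus_keep(logits, p):
    """Boolean mask of the smallest top set of tokens whose probability mass reaches p."""
    if p >= 1.0:
        return np.ones(logits.shape, dtype=bool)
    probs = _softmax(logits)
    order = np.argsort(-probs)
    n_keep = int(np.searchsorted(np.cumsum(probs[order]), p)) + 1
    keep = np.zeros(logits.shape, dtype=bool)
    keep[order[:n_keep]] = True
    return keep

def ensemble_step(base, expert, anti, alpha_a=1.5, alpha_e=4.75, alpha_b=1.0,
                  temperature=2.5, top_k=50, top_p=0.95, filter_p=1.0,
                  sample=True, rng=None):
    """Pick the next token id from one step of base/expert/anti-expert logits.

    ASSUMPTION: DExperts-style mix alpha_b*base + alpha_e*expert - alpha_a*anti;
    the exact formula inside rewrite.generation.Infiller may differ.
    """
    base, expert, anti = (np.asarray(x, dtype=float) for x in (base, expert, anti))
    allowed = nucleus_keep(base, filter_p)            # filter_p: nucleus filter on base logits
    z = alpha_b * base + alpha_e * expert - alpha_a * anti
    z = np.where(allowed, z, -np.inf)
    if not sample:
        return int(np.argmax(z))                      # greedy decoding
    z = z / temperature
    if 0 < top_k < z.size:                            # top-k on the ensembled logits
        kth = np.sort(z)[-top_k]
        z = np.where(z >= kth, z, -np.inf)
    z = np.where(nucleus_keep(z, top_p), z, -np.inf)  # top-p on the ensembled logits
    rng = rng if rng is not None else np.random.default_rng()
    return int(rng.choice(z.size, p=_softmax(z)))
```

Under this assumed mix, a large `alpha_e` pulls decoding toward tokens the non-toxic expert prefers. `alpha_a` pushes it away from tokens the toxic anti-expert prefers.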

---

## evaluation

if `run_eval=True`, the pipeline runs `python -m evaluation.evaluate_all` as a subprocess (toxicity threshold 0.5) to compute:

- bertscore (f1)
- meaningbert
- bleu-4
- toxicity (orig / gen)
- perplexity (orig / gen)

writes `gen_stats.txt` into each `(threshold, run)` folder.
also writes a summary csv per dataset:

- `data/model_outputs/{output_folder}/{data_type}/{data_type}.csv`

with one row per `(threshold, run)` folder, so settings can be compared side by side.
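One way to pick a setting from the summary csv is sketched below using only the standard library. It assumes the column names written by `_aggregate_eval_csv` (`toxicity_gen`, `bertscore`, `folder`, `threshold`). The ranking rule (lowest `toxicity_gen` first, ties broken by highest `bertscore`) is a judgment call. The function name `rank_runs` and the example path are hypothetical.

```python
import csv
import math
import os

def rank_runs(csv_path):
    """Sort summary-csv rows: lowest toxicity_gen first, then highest bertscore.

    Missing values (empty cells, as pandas writes NaN) sort last.
    """
    def num(row, key):
        try:
            return float(row[key])
        except (KeyError, TypeError, ValueError):
            return math.nan

    def sort_key(row):
        tox, bert = num(row, "toxicity_gen"), num(row, "bertscore")
        return (
            math.inf if math.isnan(tox) else tox,
            math.inf if math.isnan(bert) else -bert,
        )

    with open(csv_path, newline="") as f:
        return sorted(csv.DictReader(f), key=sort_key)

# hypothetical path; substitute your own output_folder / data_type
summary = "data/model_outputs/colab_run_global/paradetox/paradetox.csv"
if os.path.exists(summary):
    for row in rank_runs(summary)[:3]:
        print(row["threshold"], row["folder"], row["toxicity_gen"], row["bertscore"])
```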

---

## how to use

signature:

```python
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_global",
    thresholds = (0.15, 0.20, 0.25),
    echo: bool = False,
    batch_size: int = 10,
    sample: bool = True,
    top_k_gen: int = 50,
    top_p: float = 0.95,
    filter_p: float = 1.0,
    max_length: int = 128,
    alpha_a: float = None,
    alpha_e: float = None,
    alpha_b: float = 1.0,
    temperature: float = None,
    rep_penalty: float = None,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    weights = (0.5, 0.3, 0.2),
    num_candidates: int = 3,
)
```

### key arguments

* **core i/o**

  * `data_type`: dataset key from `data_configs`
    (e.g. `"paradetox"`, `"dynabench_val"`, `"jigsaw_toxic"`, etc.).
  * `output_folder`: folder under `data/model_outputs/` where results stored.

* **masking / thresholds**

  * `thresholds`: tuple of decompx thresholds to try
    (e.g. `(0.15, 0.20, 0.25)`). each value creates a `DecompX{thresh}` subfolder.
  * `num_examples`: if set to a positive int, only the first `num_examples` rows are used
    (default `100`, for quick runs). use `None` (or `0`) to process the full dataset.

* **generation (marco / bart)**

  * `sample`: `True` → stochastic sampling; `False` → greedy decoding.
  * `top_k_gen`: top-k for sampling on ensembled logits.
  * `top_p`: nucleus sampling on ensembled logits.
  * `filter_p`: nucleus filter on **base** logits (advanced; often `1.0`).
  * `max_length`: maximum generation length (tokens).
  * `alpha_a`, `alpha_e`, `alpha_b`: anti-expert / expert / base weights.
    `alpha_a` and `alpha_e` fall back to `data_configs[data_type]` when `None`;
    `alpha_b` has no dataset default and is `1.0` unless set.
  * `temperature`: sampling temperature; if `None`, use dataset default.
  * `rep_penalty`: repetition penalty; if `None`, use dataset default.
  * `batch_size`: batch size for both decompx masking and generation (speed vs. memory trade-off).

* **global reranking**

  * `weights=(w_T, w_S, w_F)`: weights for safety, similarity, fluency.
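  * `num_candidates`: how many candidates to generate per input. Each extra candidate costs
    one more full generation pass plus scoring, so larger values give the reranker more to
    choose from at a roughly linear cost.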

* **evaluation**

  * `run_eval`: if `True`, run evaluation and write `gen_stats.txt`.
  * `overwrite_gen`: if `True`, regenerate even if `gen.txt` exists.
  * `overwrite_eval`: if `True`, recompute evaluation even if `gen_stats.txt` exists.
  * `skip_ref_eval`: if `True`, skip perplexity on references.

* **echo**

  * `echo`: if `True`, print example inputs, masked inputs, generated outputs, and per-threshold evaluation metrics.

---

## example calls

quick sanity check on small subset:

```python
detoxify(
    data_type="paradetox",
    output_folder="colab_run_global_demo",
    thresholds=(0.20,),
    batch_size=8,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=50,
    run_eval=True,
    overwrite_gen=False,
    overwrite_eval=True,
    skip_ref_eval=False,
    weights=(0.5, 0.3, 0.2),
    num_candidates=20,
)
```

full-dataset run:

```python
detoxify(
    data_type="paradetox",
    output_folder="paradetox_full_global",
    thresholds=(0.15, 0.20, 0.25),
    batch_size=8,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=None,
    run_eval=True,
    overwrite_gen=False,
    overwrite_eval=False,
    weights=(0.5, 0.3, 0.2),
    num_candidates=10,
)
```

after running `detoxify`, check:

* `data/model_outputs/{output_folder}/{data_type}/DecompX*/.../orig.txt`
* `data/model_outputs/{output_folder}/{data_type}/DecompX*/.../gen.txt`
* `data/model_outputs/{output_folder}/{data_type}/DecompX*/.../gen_stats.txt`
* `data/model_outputs/{output_folder}/{data_type}/{data_type}.csv` (summary).
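The `DecompX*` and `...` path parts are generated names. The sketch below re-implements, as standalone helpers, the two naming rules used in the code cells: the `mask_dir` expression in `_run_global_reranking_for_threshold`, and `_build_gen_folder_name`. The helper names `mask_dir_name` and `run_folder_name` are hypothetical and exist only for this illustration. The example values are the `paradetox` defaults from `data_configs` plus the settings of the example calls above.

```python
import os

def mask_dir_name(thresh):
    # same rule as the pipeline: :g drops trailing zeros, so 0.20 -> "DecompX0.2"
    return f"DecompX{abs(thresh):g}" if thresh != 0 else "DecompX0.0"

def run_folder_name(alpha_a, alpha_e, alpha_b, temperature, sample,
                    top_k_gen, rep_penalty, filter_p, max_length, top_p,
                    base_type="base", antiexpert_type="antiexpert", expert_type="expert"):
    # mirrors _build_gen_folder_name: model types are truncated to 5 characters
    return (
        f"aa{alpha_a}_ae{alpha_e}_ab{alpha_b}"
        f"_base{base_type[:5]}_anti{antiexpert_type[:5]}_expert{expert_type[:5]}"
        f"_temp{temperature}_sample{'T' if sample else 'F'}_topk{top_k_gen}"
        f"_reppenalty{rep_penalty}_filterp{filter_p}_maxlength{max_length}_topp{top_p}"
    )

run = run_folder_name(1.5, 4.75, 1.0, 2.5, True, 50, 1.0, 1.0, 96, 0.95)
print(os.path.join("data/model_outputs/colab_run_global_demo/paradetox",
                   mask_dir_name(0.20), run, "gen.txt"))
```

This prints a path ending in `DecompX0.2/aa1.5_ae4.75_ab1.0_basebase_antiantie_expertexper_temp2.5_sampleT_topk50_reppenalty1.0_filterp1.0_maxlength96_topp0.95/gen.txt`. Any change to a generation argument therefore lands in a new run folder instead of overwriting an old one.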

In [None]:
from google.colab import drive; drive.mount('/content/drive')

import os, glob, re, sys, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE
import torch
import nltk
from typing import List

candidate = "/content/drive/MyDrive/w266 - Project/XDetox"
print("trying mydrive:", candidate, "->", os.path.isdir(candidate))

XDETOX_DIR = candidate
print("using xdetox_dir:", XDETOX_DIR)
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"

In [None]:
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.makedirs(HF_CACHE, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE

if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("xdetox_dir:", XDETOX_DIR)
print("cache:", HF_CACHE)
print("cuda available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("gpu:", torch.cuda.get_device_name(0))

In [None]:
for d in ["rewrite", "evaluation", "datasets"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("repo folders ok")

In [None]:
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi

!pip -q install bert-score

In [None]:
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForSequenceClassification,
    GPT2LMHeadModel, GPT2TokenizerFast,
)

In [None]:
from rewrite.mask_orig import Masker as Masker_single
from rewrite.generation import Infiller
from rewrite import rewrite_example as rx
import argparse as _argparse

In [None]:
nltk.download("punkt", quiet=True)
try:
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass
print("ready")

In [None]:
data_configs = {
    "microagressions_val": {
        "data_path": "./datasets/microagressions/val.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "microagressions_test": {
        "data_path": "./datasets/microagressions/test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.25,
        "temperature": 2.5,
    },
    "sbf_val": {
        "data_path": "./datasets/sbf/sbfdev.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "sbf_test": {
        "data_path": "./datasets/sbf/sbftst.csv",
        "rep_penalty": 1.5,
        "alpha_a": 1.5,
        "alpha_e": 5.0,
        "temperature": 2.9,
    },
    "dynabench_val": {
        "data_path": "./datasets/dynabench/db_dev.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "dynabench_test": {
        "data_path": "./datasets/dynabench/db_test.csv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "jigsaw_toxic": {
        "data_path": "./datasets/jigsaw_full_30/test_10k_toxic.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "paradetox": {
        "data_path": "./datasets/paradetox/test_toxic_parallel.txt",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_original": {
        "data_path": "./datasets/appdia/original-annotated-data/original-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    },
    "appdia_discourse": {
        "data_path": "./datasets/appdia/discourse-augmented-data/discourse-test.tsv",
        "rep_penalty": 1.0,
        "alpha_a": 1.5,
        "alpha_e": 4.75,
        "temperature": 2.5,
    }
}
print("datasets:", ", ".join(data_configs.keys()))

REPO = XDETOX_DIR

In [None]:
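def _abs_repo_path(rel: str) -> str:
    # absolute paths pass through; strip only a literal leading "./"
    # (rel.lstrip("./") strips any run of "." and "/" chars, including the "/" of an absolute path)
    if os.path.isabs(rel):
        return rel
    if rel.startswith("./"):
        rel = rel[2:]
    return os.path.join(REPO, rel)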

def _ensure_dir(p: str):
    Path(p).mkdir(parents=True, exist_ok=True)

def _subset_for_data_type(data_type, data_path, n, out_dir):
    """Write the first n rows of the dataset to out_dir and return the new path.

    With n None or <= 0 the full dataset is used; its absolute path is returned
    so downstream readers do not depend on the current working directory.
    """
    src = _abs_repo_path(data_path)
    if n is None or n <= 0:
        return src

    _ensure_dir(out_dir)
    out = os.path.join(out_dir, os.path.basename(src))
    is_parallel = any(k in data_path for k in ["paradetox", "jigsaw"])

    # csv datasets: microagressions / sbf / dynabench, plus csv-format paradetox / jigsaw
    if any(k in data_path for k in ["microagressions", "sbf", "dynabench"]) or (
        is_parallel and data_path.endswith(".csv")
    ):
        pd.read_csv(src).head(n).to_csv(out, index=False)
        return out

    # one-example-per-line text files (paradetox / jigsaw)
    if is_parallel and data_path.endswith(".txt"):
        with open(src, "r") as f:
            lines = [s.rstrip("\n") for s in f.readlines()[:n]]
        with open(out, "w") as g:
            for s in lines:
                g.write(s + "\n")
        return out

    # tab-separated appdia files
    if "appdia" in data_path:
        pd.read_csv(src, sep="\t").head(n).to_csv(out, sep="\t", index=False)
        return out

    # unknown format: the whole file is copied, num_examples is NOT applied
    shutil.copy(src, out)
    return out

In [None]:
DEVICE_SCORE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"scoring models will use: {DEVICE_SCORE}")

_TOX_MODEL_NAME = "textdetox/xlmr-large-toxicity-classifier-v2"
_TOX_TOKENIZER = None
_TOX_MODEL = None

def _lazy_load_tox():
    global _TOX_TOKENIZER, _TOX_MODEL
    if _TOX_TOKENIZER is None or _TOX_MODEL is None:
        _TOX_TOKENIZER = AutoTokenizer.from_pretrained(_TOX_MODEL_NAME)
        _TOX_MODEL = AutoModelForSequenceClassification.from_pretrained(
            _TOX_MODEL_NAME
        ).to(DEVICE_SCORE)
        _TOX_MODEL.eval()

@torch.no_grad()
def get_toxicity_scores(texts, batch_size=32):
    _lazy_load_tox()
    scores = []
    for i in tqdm(range(0, len(texts), batch_size), desc="toxicity", leave=False):
        batch = texts[i:i+batch_size]
        enc = _TOX_TOKENIZER(
            batch, return_tensors="pt",
            truncation=True, max_length=512, padding=True
        ).to(DEVICE_SCORE)
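        logits = _TOX_MODEL(**enc).logits
        probs = torch.softmax(logits, dim=-1)
        # probability of the "toxic" class; falls back to column 1 if the
        # model config has no label named "toxic"
        toxic_idx = _TOX_MODEL.config.label2id.get("toxic", 1)
        scores.extend(probs[:, toxic_idx].detach().cpu().tolist())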
    return scores

_LABSE_NAME = "sentence-transformers/LaBSE"
_LABSE_TOKENIZER = None
_LABSE_MODEL = None

def _lazy_load_labse():
    global _LABSE_TOKENIZER, _LABSE_MODEL
    if _LABSE_TOKENIZER is None or _LABSE_MODEL is None:
        _LABSE_TOKENIZER = AutoTokenizer.from_pretrained(_LABSE_NAME)
        _LABSE_MODEL = AutoModel.from_pretrained(_LABSE_NAME).to(DEVICE_SCORE)
        _LABSE_MODEL.eval()

@torch.no_grad()
def get_labse_embeddings(texts, batch_size=32):
    _lazy_load_labse()
    embs = []
    for i in tqdm(range(0, len(texts), batch_size), desc="labse embeddings", leave=False):
        batch = texts[i:i+batch_size]
        enc = _LABSE_TOKENIZER(
            batch, return_tensors="pt",
            truncation=True, max_length=256, padding=True
        ).to(DEVICE_SCORE)
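        outputs = _LABSE_MODEL(**enc)
        # sentence embedding = mean of token states over non-padding positions
        # (a choice of this pipeline; the sentence-transformers LaBSE recipe
        # pools the [CLS] token instead, so similarities can differ slightly)
        hidden = outputs.last_hidden_state
        mask = enc["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-6)
        sent_emb = (summed / counts).cpu().numpy()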
        embs.append(sent_emb)
    if not embs:
        return np.zeros((0, 768), dtype=np.float32)
    return np.vstack(embs)

_GPT2_NAME = "gpt2"
_GPT2_TOKENIZER = None
_GPT2_MODEL = None

def _lazy_load_gpt2():
    global _GPT2_TOKENIZER, _GPT2_MODEL
    if _GPT2_TOKENIZER is None or _GPT2_MODEL is None:
        _GPT2_TOKENIZER = GPT2TokenizerFast.from_pretrained(_GPT2_NAME)
        _GPT2_MODEL = GPT2LMHeadModel.from_pretrained(_GPT2_NAME).to(DEVICE_SCORE)
        _GPT2_MODEL.eval()

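@torch.no_grad()
def get_gpt2_perplexities(texts, max_ppl=1e4):
    """Per-text gpt-2 perplexity, capped at max_ppl.

    Texts with fewer than 2 tokens (e.g. an empty candidate) have no next-token
    loss: an empty input makes the model fail, and a single token gives a NaN loss.
    They get max_ppl instead, so NaN never reaches the reranker's argmax.
    """
    _lazy_load_gpt2()
    ppls = []
    for s in tqdm(texts, desc="gpt-2 ppl", leave=False):
        enc = _GPT2_TOKENIZER(
            s, return_tensors="pt", truncation=True,
            max_length=_GPT2_MODEL.config.n_positions,
        ).to(DEVICE_SCORE)
        ids = enc["input_ids"]
        if ids.shape[1] < 2:
            ppls.append(float(max_ppl))
            continue
        loss = _GPT2_MODEL(ids, labels=ids).loss.item()
        ppl = math.exp(min(loss, math.log(max_ppl))) if math.isfinite(loss) else max_ppl
        ppls.append(float(ppl))
    return ppls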

def perplexity_to_fluency(ppls, p_min=5.0, p_max=300.0):
    import math as _math
    ppls = np.asarray(ppls, dtype=float)
    p = np.clip(ppls, p_min, p_max)
    log_p = np.log(p)
    log_min = _math.log(p_min)
    log_max = _math.log(p_max)
    F = (log_max - log_p) / (log_max - log_min + 1e-8)
    F = np.clip(F, 0.0, 1.0)
    return F

In [None]:
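def rerank_candidates_global(
    sources,
    candidates,
    weights=(0.5, 0.3, 0.2),
    ppl_min=5.0,
    ppl_max=300.0,
):
    """For every source, pick the candidate maximizing
    w_T * (1 - toxicity) + w_S * similarity + w_F * fluency.

    sources: list of N source strings.
    candidates: list of N lists, each holding the same number C of strings.
    Returns (best_idx, details): best_idx has shape (N,); details maps
    "tox", "safety", "sim", "flu", "score" to (N, C) arrays.
    """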
    w_T, w_S, w_F = weights
    N = len(sources)
    assert len(candidates) == N, "candidates length mismatch"

    if N == 0:
        return np.array([], dtype=int), {}

    C_list = [len(c) for c in candidates]
    assert len(set(C_list)) == 1, "All inputs must have same num_candidates"
    C = C_list[0]
    if C == 0:
        raise ValueError("num_candidates must be >= 1")

    flat_cands = []
    flat_src_idx = []
    for i, cand_list in enumerate(candidates):
        for cand in cand_list:
            flat_cands.append(cand)
            flat_src_idx.append(i)
    flat_src_idx = np.array(flat_src_idx, dtype=int)

    tox = np.array(get_toxicity_scores(flat_cands), dtype=float)

    src_embs = get_labse_embeddings(sources)
    cand_embs = get_labse_embeddings(flat_cands)
    src_embs = src_embs / np.clip(np.linalg.norm(src_embs, axis=1, keepdims=True), 1e-8, None)
    cand_embs = cand_embs / np.clip(np.linalg.norm(cand_embs, axis=1, keepdims=True), 1e-8, None)
    sims = np.sum(cand_embs * src_embs[flat_src_idx], axis=1)
    sims = (sims + 1.0) / 2.0

    ppls = np.array(get_gpt2_perplexities(flat_cands), dtype=float)
    flus = perplexity_to_fluency(ppls, p_min=ppl_min, p_max=ppl_max)

    safety = 1.0 - tox

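    scores = w_T * safety + w_S * sims + w_F * flus
    # guard: np.argmax treats NaN as the maximum, so a NaN score must never win
    scores = np.where(np.isnan(scores), -np.inf, scores)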

    tox2     = tox.reshape(N, C)
    safety2  = safety.reshape(N, C)
    sims2    = sims.reshape(N, C)
    flus2    = flus.reshape(N, C)
    scores2  = scores.reshape(N, C)

    best_idx = scores2.argmax(axis=1)
    details = {
        "tox": tox2,
        "safety": safety2,
        "sim": sims2,
        "flu": flus2,
        "score": scores2,
    }
    return best_idx, details

In [None]:
def _parse_run_folder_name(folder_name):
    """True if folder_name looks like a run folder built by _build_gen_folder_name."""
    num = r"(\d+(?:\.\d+)?)"  # accept "2.5" as well as "3" (int-valued arguments)
    pattern = (
        rf"aa{num}_ae{num}_ab{num}_base(.*?)_anti(.*?)_expert(.*?)_temp{num}"
        rf"_sample(.*?)_topk(\d+)_reppenalty{num}_filterp{num}_maxlength(\d+)_topp{num}"
    )
    return re.match(pattern, folder_name) is not None

def _eval_with_toxicity(base_path, overwrite_eval=False, skip_ref=False, tox_threshold=0.5, tox_batch_size=32):
    """Run `python -m evaluation.evaluate_all` on every run folder under base_path."""
    for folder in os.listdir(base_path):
        gen_dir = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir) or not _parse_run_folder_name(folder):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path  = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue  # already evaluated

        # prepend the repo to PYTHONPATH so the subprocess can import `evaluation`
        env = os.environ.copy()
        env["PYTHONPATH"] = os.pathsep.join(p for p in (REPO, env.get("PYTHONPATH", "")) if p)
        cmd = [
            sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path",  gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")
        print("eval:", " ".join(cmd))
        res = run(cmd, cwd=REPO, env=env, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()

def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return float('nan')

def _read_stats_file(path):
    out = {}
    with open(path, "r") as f:
        for line in f:
            if ":" not in line:
                continue
            # split on the first ":" so "key:value" (no space) also parses
            k, v = line.strip().split(":", 1)
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v.strip())
    return out

def _aggregate_eval_csv(output_folder, data_type, base_out_dir):
    rows = []
    # scan every DecompX{thresh} folder that exists instead of a fixed
    # 0.15/0.20/0.25 grid, so any threshold passed to detoxify() is summarized
    for base_path in sorted(glob.glob(os.path.join(base_out_dir, data_type, "DecompX*"))):
        if not os.path.isdir(base_path):
            continue
        thresh = _safe_float(os.path.basename(base_path)[len("DecompX"):])
        for folder in os.listdir(base_path):
            gen_dir = os.path.join(base_path, folder)
            stats_path = os.path.join(gen_dir, "gen_stats.txt")
            if not os.path.exists(stats_path):
                continue
            s = _read_stats_file(stats_path)
            rows.append({
                "threshold":        float(f"{thresh:.2f}"),
                "folder":           folder,
                "bertscore":        s.get("bertscore", np.nan),
                "meaningbert":      s.get("meaningbert", np.nan),
                "bleu4":            s.get("bleu4", np.nan),
                "perplexity_gen":   s.get("perplexity gen", np.nan),
                "perplexity_orig":  s.get("perplexity orig", np.nan),
                "toxicity_gen":     s.get("toxicity gen", np.nan),
                "toxicity_orig":    s.get("toxicity orig", np.nan),
            })

    if rows:
        cols = [
            "threshold", "folder",
            "bertscore", "meaningbert", "bleu4",
            "perplexity_gen", "perplexity_orig",
            "toxicity_gen", "toxicity_orig",
        ]
        df = pd.DataFrame(rows)
        df = df[cols]
        out_csv = os.path.join(base_out_dir, data_type, f"{data_type}.csv")
        df.to_csv(out_csv, index=False)
        print("wrote summary csv:", out_csv)
    else:
        print("no evaluation files found.")

In [None]:
def _process_in_batches(masker, inputs, batch_size, thresh: float):
    batched_inputs = [
        inputs[i : i + batch_size] for i in range(0, len(inputs), batch_size)
    ]
    results = []
    for batch in tqdm(batched_inputs, desc="masking", leave=False):
        batch_result = masker.process_text(sentence=batch, threshold=thresh)
        results.append(batch_result)
    return results

def _bool2str(x: bool) -> str:
    return "T" if x else "F"

def _build_gen_folder_name(
    alpha_a, alpha_e, alpha_b,
    base_type, antiexpert_type, expert_type,
    temperature, sample, top_k_gen, rep_penalty, filter_p, max_length, top_p
):
    return (
        "aa" + str(alpha_a) +
        "_ae" + str(alpha_e) +
        "_ab" + str(alpha_b) +
        "_base" + base_type[:5] +
        "_anti" + antiexpert_type[:5] +
        "_expert" + expert_type[:5] +
        "_temp" + str(temperature) +
        "_sample" + _bool2str(sample) +
        "_topk" + str(top_k_gen) +
        "_reppenalty" + str(rep_penalty) +
        "_filterp" + str(filter_p) +
        "_maxlength" + str(max_length) +
        "_topp" + str(top_p)
    )

def _run_global_reranking_for_threshold(
    data_type,
    subset_path,
    thresh,
    base_out_rel,
    batch_size,
    alpha_a, alpha_e, alpha_b,
    temperature,
    rep_penalty,
    max_length,
    top_k_gen,
    top_p,
    filter_p,
    sample,
    num_candidates,
    weights,
    overwrite_gen=False,
    echo: bool = False,
    inputs=None,
):
    if inputs is None:
        args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
        inputs = rx.get_data(args_data)
    print(f"#inputs at thresh={thresh}: {len(inputs)}")

    mask_dir = f"DecompX{abs(thresh):g}" if thresh != 0 else "DecompX0.0"
    cur_rel = os.path.join(base_out_rel, data_type, mask_dir)
    cur_abs = os.path.join(REPO, cur_rel)
    _ensure_dir(cur_abs)

    masked_file = os.path.join(cur_abs, "masked_inputs.txt")

    if not os.path.exists(masked_file):
        masker = Masker_single()
        decoded_masked_inputs_batches = _process_in_batches(
            masker, inputs, batch_size=batch_size, thresh=thresh
        )
        decoded_masked_inputs = [
            item for sublist in decoded_masked_inputs_batches for item in sublist
        ]
        decoded_mask_inputs = [
            d.replace("<s>", "").replace("</s>", "") for d in decoded_masked_inputs
        ]
        with open(masked_file, "w") as f:
            for d in decoded_mask_inputs:
                f.write(re.sub(r"\s+", " ", d) + "\n")
        masker.release_model()
    else:
        with open(masked_file, "r") as f:
            decoded_mask_inputs = [s.strip() for s in f.readlines()]
        print("reusing existing masked_inputs.txt")

    assert len(decoded_mask_inputs) == len(inputs), "Masked vs inputs mismatch"

    if echo:
        print(f"\nexample masked inputs at threshold {thresh:.2f} (first up to 3):")
        for i, m in enumerate(decoded_mask_inputs[:3]):
            print(f"  masked[{i}]: {m}")

    base_type = "base"
    antiexpert_type = "antiexpert"
    expert_type = "expert"
    gen_folder = _build_gen_folder_name(
        alpha_a, alpha_e, alpha_b,
        base_type, antiexpert_type, expert_type,
        temperature, sample, top_k_gen, rep_penalty, filter_p, max_length, top_p
    )
    final_abs = os.path.join(cur_abs, gen_folder)
    gen_txt = os.path.join(final_abs, "gen.txt")
    orig_txt = os.path.join(final_abs, "orig.txt")

    # check for existing output BEFORE loading the three bart models,
    # so a skipped run does not pay the model-loading cost
    if os.path.exists(gen_txt) and not overwrite_gen:
        print("generation already exists at:", gen_txt)
        return

    _ensure_dir(final_abs)

    rewriter = Infiller(
        seed=0,
        base_path="facebook/bart-base",
        antiexpert_path="hallisky/bart-base-toxic-antiexpert",
        expert_path="hallisky/bart-base-nontoxic-expert",
        base_type=base_type,
        antiexpert_type=antiexpert_type,
        expert_type=expert_type,
        tokenizer="facebook/bart-base",
    )

    all_candidates: List[List[str]] = [[] for _ in range(len(inputs))]

    print(f"generating {num_candidates} candidates per input (sampling={sample})")
    for c in range(num_candidates):
        outs, decoded = rewriter.generate(
            inputs,
            decoded_mask_inputs,
            alpha_a=alpha_a,
            alpha_e=alpha_e,
            alpha_b=alpha_b,
            temperature=temperature,
            verbose=False,
            max_length=max_length,
            repetition_penalty=rep_penalty,
            p=top_p,
            filter_p=filter_p,
            k=top_k_gen,
            batch_size=batch_size,
            sample=sample,
            ranking=False,
            ranking_eval_output=0,
        )
        for i, text in enumerate(decoded):
            all_candidates[i].append(re.sub(r"\s+", " ", text).strip())

    del rewriter
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    print("global reranking (toxicity + similarity + fluency)...")
    best_idx, details = rerank_candidates_global(
        sources=inputs,
        candidates=all_candidates,
        weights=weights,
    )
    best_generations = [
        all_candidates[i][best_idx[i]] for i in range(len(inputs))
    ]

    if echo:
        print(f"\nexample detoxified outputs at threshold {thresh:.2f} (first up to 3):")
        for i, g in enumerate(best_generations[:3]):
            print(f"  detox[{i}]: {g}")

    with open(orig_txt, "w") as f:
        for l in inputs:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")
    with open(gen_txt, "w") as f:
        for l in best_generations:
            f.write(re.sub(r"\s+", " ", l).strip() + "\n")

    print("saved:", orig_txt)
    print("saved:", gen_txt)

In [None]:
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "colab_run_global",
    thresholds = (0.15, 0.20, 0.25),
    echo: bool = False,
    batch_size: int = 10,
    sample: bool = True,
    top_k_gen: int = 50,
    top_p: float = 0.95,
    filter_p: float = 1.0,
    max_length: int = 128,
    alpha_a: float = None,
    alpha_e: float = None,
    alpha_b: float = 1.0,
    temperature: float = None,
    rep_penalty: float = None,
    num_examples: int = 100,
    overwrite_gen: bool = False,
    run_eval: bool = False,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
    weights = (0.5, 0.3, 0.2),
    num_candidates: int = 3,
):
    assert data_type in data_configs, f"Unknown data_type: {data_type}"
    cfg = data_configs[data_type].copy()

    if num_candidates < 1:
        raise ValueError("num_candidates must be >= 1")

    if alpha_a is None:
        alpha_a = cfg["alpha_a"]
    if alpha_e is None:
        alpha_e = cfg["alpha_e"]
    if temperature is None:
        temperature = cfg["temperature"]
    if rep_penalty is None:
        rep_penalty = cfg["rep_penalty"]

    base_out_rel = os.path.join("data", "model_outputs", output_folder)
    base_out_abs = os.path.join(REPO, base_out_rel)
    _ensure_dir(base_out_abs)

    original_data_path = cfg["data_path"]
    subset_dir = os.path.join(REPO, "datasets", "_subsets", data_type)
    _ensure_dir(subset_dir)
    subset_path = _subset_for_data_type(
        data_type, original_data_path, num_examples, subset_dir
    )

    args_data = _argparse.Namespace(data_type=data_type, data_path=subset_path)
    inputs = rx.get_data(args_data)
    num_inputs = len(inputs)

    if echo:
        print("=" * 80)
        print(f"dataset: {data_type}")
        print(f"subset path: {subset_path}")
        print(f"output base: {base_out_abs}")
        print(f"number of examples to detoxify: {num_inputs}")
        print(f"thresholds: {', '.join(f'{t:.2f}' for t in thresholds)}")
        print(f"weights (w_T, w_S, w_F): {weights}")
        print(f"num_candidates per input: {num_candidates}")
        print("\nexample inputs (first up to 3):")
        for i, s in enumerate(inputs[:3]):
            print(f"  input[{i}]: {s}")
        print("=" * 80)

    for t in thresholds:
        print("=" * 60)
        print(f"threshold (decompx) = {t:.2f}")
        _run_global_reranking_for_threshold(
            data_type=data_type,
            subset_path=subset_path,
            thresh=t,
            base_out_rel=base_out_rel,
            batch_size=batch_size,
            alpha_a=alpha_a,
            alpha_e=alpha_e,
            alpha_b=alpha_b,
            temperature=temperature,
            rep_penalty=rep_penalty,
            max_length=max_length,
            top_k_gen=top_k_gen,
            top_p=top_p,
            filter_p=filter_p,
            sample=sample,
            num_candidates=num_candidates,
            weights=weights,
            overwrite_gen=overwrite_gen,
            echo=echo,
            inputs=inputs,
        )

        if run_eval:
            mask_dir = f"DecompX{abs(t):g}" if t != 0 else "DecompX0.0"
            base_path = os.path.join(base_out_abs, data_type, mask_dir)
            _eval_with_toxicity(
                base_path,
                overwrite_eval=overwrite_eval,
                skip_ref=skip_ref_eval,
                tox_threshold=0.5,
                tox_batch_size=32,
            )

            if echo:
                base_type = "base"
                antiexpert_type = "antiexpert"
                expert_type = "expert"
                gen_folder = _build_gen_folder_name(
                    alpha_a, alpha_e, alpha_b,
                    base_type, antiexpert_type, expert_type,
                    temperature, sample, top_k_gen, rep_penalty, filter_p, max_length, top_p
                )
                stats_path = os.path.join(base_path, gen_folder, "gen_stats.txt")
                if os.path.exists(stats_path):
                    stats = _read_stats_file(stats_path)
                    print(f"\nevaluation metrics for this run (t={t:.2f}):")
                    metric_keys = [
                        ("bertscore",        "bertscore"),
                        ("meaningbert",      "meaningbert"),
                        ("bleu4",            "bleu-4"),
                        ("perplexity gen",   "perplexity (gen)"),
                        ("perplexity orig",  "perplexity (orig)"),
                        ("toxicity gen",     "toxicity (gen)"),
                        ("toxicity orig",    "toxicity (orig)"),
                    ]
                    for key, label in metric_keys:
                        val = stats.get(key, None)
                        if isinstance(val, float) and math.isnan(val):
                            continue
                        if val is None:
                            continue
                        print(f"  {label}: {val:.4f}")
                else:
                    print(f"gen_stats.txt not found at {stats_path}")

    if run_eval:
        _aggregate_eval_csv(
            output_folder,
            data_type,
            os.path.join(REPO, "data", "model_outputs", output_folder),
        )
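The `echo` branch above reads `gen_stats.txt` for the current threshold and prints a fixed list of metrics, skipping any value that is missing or NaN. The sketch below shows that filtering on its own. It assumes, hypothetically, that the stats file holds one `key: value` pair per line. The real `_read_stats_file` may parse a different format. `parse_stats_text` and `format_metrics` are illustrative names, not functions from the repository, and the numbers are invented test values rather than results.

```python
import math
from typing import Dict, List, Optional, Tuple


def parse_stats_text(text: str) -> Dict[str, Optional[float]]:
    """Hypothetical parser: assumes one "key: value" pair per line.

    Non-numeric values are stored as None so they count as missing.
    """
    stats: Dict[str, Optional[float]] = {}
    for line in text.splitlines():
        if ":" not in line:
            continue  # skip blank or malformed lines
        key, _, raw = line.partition(":")
        try:
            stats[key.strip()] = float(raw.strip())  # "nan" parses to float nan
        except ValueError:
            stats[key.strip()] = None
    return stats


def format_metrics(stats: Dict[str, Optional[float]],
                   metric_keys: List[Tuple[str, str]]) -> List[str]:
    """Return printable lines, skipping metrics that are missing or NaN."""
    lines: List[str] = []
    for key, label in metric_keys:
        val = stats.get(key)
        if val is None or (isinstance(val, float) and math.isnan(val)):
            continue
        lines.append(f"  {label}: {val:.4f}")
    return lines


if __name__ == "__main__":
    # Invented values purely to exercise the logic.
    sample = "bleu4: 0.5123\nperplexity gen: nan\ntoxicity gen: 0.0871\n"
    keys = [("bleu4", "bleu-4"), ("perplexity gen", "perplexity (gen)"),
            ("toxicity gen", "toxicity (gen)"), ("bertscore", "bertscore")]
    for line in format_metrics(parse_stats_text(sample), keys):
        print(line)
```

Skipping NaN values matters because a metric that could not be computed, for example a reference-based score with `skip_ref_eval=True`, would otherwise print as `nan` alongside the real numbers.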

In [None]:
# detoxify(
#     data_type="paradetox",
#     output_folder="colab_run_global_demo",
#     thresholds=(0.20,),
#     batch_size=8,
#     sample=True,
#     top_k_gen=50,
#     top_p=0.95,
#     max_length=96,
#     num_examples=50,
#     run_eval=True,
#     overwrite_gen=True,
#     overwrite_eval=True,
#     skip_ref_eval=False,
#     weights=(0.5, 0.3, 0.2),
#     num_candidates=20,
# )
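The `thresholds` argument (0.20 in these runs) names the output folder (`DecompX0.2`) and controls how aggressively toxic words are replaced with `<mask>` before BART infilling. The sketch below shows one way a threshold can drive masking. It is not `rewrite.mask_orig.Masker`. It assumes that per-word DecompX attribution scores are already computed and aligned to the words, that a word is masked when its score is `>=` the threshold, and that adjacent masked words collapse into a single `<mask>`. All three are assumptions, and `mask_by_threshold` is a hypothetical name.

```python
from typing import List

MASK = "<mask>"


def mask_by_threshold(tokens: List[str], scores: List[float], threshold: float) -> str:
    """Replace words whose attribution score is >= threshold with <mask>.

    Assumption: `scores` are per-word toxicity attributions aligned to `tokens`.
    Runs of masked words collapse into one <mask>, which leaves the infiller
    free to decide how many words to generate.
    """
    if len(tokens) != len(scores):
        raise ValueError("tokens and scores must have the same length")
    out: List[str] = []
    for tok, score in zip(tokens, scores):
        if score >= threshold:
            if not out or out[-1] != MASK:  # collapse adjacent masks
                out.append(MASK)
        else:
            out.append(tok)
    return " ".join(out)


if __name__ == "__main__":
    # Invented attribution scores for a fragment of masked[2] from the log.
    words = ["gave", "a", "shit", "when"]
    print(mask_by_threshold(words, [0.01, 0.03, 0.91, 0.02], 0.20))
```

Under this rule, a lower threshold masks more words. That gives the infiller more room to remove toxicity but also more opportunity to drift from the original meaning.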


In [32]:
detoxify(
    data_type="paradetox",
    output_folder="XDetox_w_DecompX-Masking-Global-Reranking_Pipeline",
    thresholds=(0.20,),
    echo=True,
    batch_size=8,
    sample=True,
    top_k_gen=50,
    top_p=0.95,
    max_length=96,
    num_examples=1000,
    run_eval=True,
    overwrite_gen=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    weights=(0.5, 0.3, 0.2),
    num_candidates=10,
)

[echo] Dataset: paradetox
[echo] Subset path: /content/drive/MyDrive/w266 - Project/XDetox/datasets/_subsets/paradetox/test_toxic_parallel.txt
[echo] Output base: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_DecompX-Masking-Global-Reranking_Pipeline
[echo] Number of examples to detoxify: 671
[echo] Thresholds: 0.20
[echo] Weights (w_T, w_S, w_F): (0.5, 0.3, 0.2)
[echo] num_candidates per input: 10

[echo] Example inputs (first up to 3):
  input[0]: . or the loud ass one - thousand ton beast roaring towards you howling its horn .
  input[1]: mandated  and " right fucking now " would be good .
  input[2]: * neither * of my coworkers gave a shit when it came time to ditch mitch . ugh .
Threshold (DecompX) = 0.20
#inputs at thresh=0.2: 671


Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).






[echo] Example masked inputs at threshold 0.20 (first up to 3):
  masked[0]: . or the loud <mask> one - thousand ton beast roaring towards you howling its horn .
  masked[1]: mandated  and " right <mask> now " would be good .
  masked[2]: * neither * of my coworkers gave a <mask> when it came time to ditch mitch . ugh .
Found 1 GPUS!
Generating 10 candidates per input (sampling=True)


Filling in masks: 100%|██████████| 84/84 [01:30<00:00,  1.08s/it]
Filling in masks: 100%|██████████| 84/84 [01:30<00:00,  1.08s/it]
Filling in masks: 100%|██████████| 84/84 [01:36<00:00,  1.15s/it]
Filling in masks: 100%|██████████| 84/84 [01:34<00:00,  1.13s/it]
Filling in masks: 100%|██████████| 84/84 [01:37<00:00,  1.17s/it]
Filling in masks: 100%|██████████| 84/84 [01:26<00:00,  1.03s/it]
Filling in masks: 100%|██████████| 84/84 [01:37<00:00,  1.16s/it]
Filling in masks: 100%|██████████| 84/84 [01:42<00:00,  1.22s/it]
Filling in masks: 100%|██████████| 84/84 [01:34<00:00,  1.12s/it]
Filling in masks: 100%|██████████| 84/84 [01:28<00:00,  1.06s/it]


Global reranking (toxicity + similarity + fluency)...




[echo] Example detoxified outputs at threshold 0.20 (first up to 3):
  detox[0]: . or the loud sound of a one- thousand ton beast roaring towards you howling its horn.
  detox[1]: mandated and "right now" would be good.
  detox[2]: *nor* of my coworkers gave a hoot about the fact that they were all over the place when it came time to ditch mitch. ugh.
Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_DecompX-Masking-Global-Reranking_Pipeline/paradetox/DecompX0.2/aa1.5_ae4.75_ab1.0_basebase_antiantie_expertexper_temp2.5_sampleT_topk50_reppenalty1.0_filterp1.0_maxlength96_topp0.95/orig.txt
Saved: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/XDetox_w_DecompX-Masking-Global-Reranking_Pipeline/paradetox/DecompX0.2/aa1.5_ae4.75_ab1.0_basebase_antiantie_expertexper_temp2.5_sampleT_topk50_reppenalty1.0_filterp1.0_maxlength96_topp0.95/gen.txt
Eval: /usr/bin/python3 -m evaluation.evaluate_all --orig_path /content/drive/MyDrive/w266 - Project/XDet
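In this run, 671 inputs × 10 candidates = 6,710 candidates are scored at threshold 0.20. The highest-scoring candidate per input is kept according to $\text{Score}(c) = w_T (1 - T(c)) + w_S S(c) + w_F F(c)$ with weights `(0.5, 0.3, 0.2)`. The sketch below covers the selection step for a single input and assumes toxicity, LaBSE similarity, and GPT-2 perplexity are already computed. `fluency_from_ppl`, `global_score`, and `pick_best` are hypothetical names. The $1/(1 + \ln \text{ppl})$ perplexity-to-fluency mapping is an assumed stand-in, because the notebook states only that low perplexity maps to high fluency.

```python
import math
from typing import List, Sequence, Tuple


def fluency_from_ppl(ppl: float) -> float:
    """Assumed perplexity -> [0, 1] mapping (lower perplexity, higher score)."""
    return 1.0 / (1.0 + math.log(max(ppl, 1.0)))  # clamp so ln(ppl) >= 0


def global_score(tox: float, sim: float, flu: float,
                 weights: Tuple[float, float, float]) -> float:
    """Score(c) = w_T * (1 - T(c)) + w_S * S(c) + w_F * F(c)."""
    w_t, w_s, w_f = weights
    return w_t * (1.0 - tox) + w_s * sim + w_f * flu


def pick_best(candidates: Sequence[str], tox: Sequence[float],
              sim: Sequence[float], ppl: Sequence[float],
              weights: Tuple[float, float, float] = (0.5, 0.3, 0.2)) -> Tuple[str, float]:
    """Return the highest-scoring candidate for one input and its score."""
    scores: List[float] = [
        global_score(t, s, fluency_from_ppl(p), weights)
        for t, s, p in zip(tox, sim, ppl)
    ]
    best = max(range(len(scores)), key=scores.__getitem__)  # ties: first wins
    return candidates[best], scores[best]


if __name__ == "__main__":
    # Invented scores for two candidates of one input (not from this run).
    cands = ["gave a damn when it came time", "gave a hoot when it came time"]
    text, score = pick_best(cands, tox=[0.40, 0.05], sim=[0.95, 0.90], ppl=[40.0, 60.0])
    print(text, round(score, 3))
```

The weights sum to 1 and each term lies in $[0, 1]$, so the score also lies in $[0, 1]$. In the example, the second candidate wins despite slightly lower similarity and fluency, because $w_T = 0.5$ rewards its much lower toxicity. Raising `w_S` would shift that trade-off toward closer paraphrases.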