# T5-ParaDetox Pipeline with Global Reranking

This notebook combines:
- **T5-base** fine-tuned on ParaDetox for detoxification
- **Global reranking** using toxicity, semantic similarity, and fluency

## Pipeline

1. Generate `num_candidates` detoxified texts per input using T5 sampling
2. Score each candidate using:
   - **Toxicity** (XLM-R large classifier)
   - **Semantic Similarity** (LaBSE embeddings)
   - **Fluency** (GPT-2 perplexity)
3. Select the candidate with the highest weighted score
4. Evaluate with BLEU, BERTScore, MeaningBERT, Perplexity, Toxicity

---

## Global Reranking Formula

For each candidate $c$:

$$\text{Score}(c) = w_T \cdot (1 - \text{Toxicity}(c)) + w_S \cdot \text{Similarity}(c) + w_F \cdot \text{Fluency}(c)$$

Default weights: $(w_T, w_S, w_F) = (0.5, 0.3, 0.2)$

---

## `detoxify()` API

```python
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "T5_w_Global-Reranking",
    echo: bool = False,
    num_examples: int = 1000,
    num_candidates: int = 10,
    max_length: int = 128,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    weights: tuple = (0.5, 0.3, 0.2),  # (toxicity, similarity, fluency)
    overwrite_gen: bool = False,
    run_eval: bool = True,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
)
```

### Key Arguments

- `data_type`: Dataset key (paradetox, microagressions_test, sbf_test, dynabench_test, jigsaw_toxic, appdia_original, appdia_discourse)
- `output_folder`: Folder under `data/model_outputs/` for results
- `num_candidates`: Number of candidates to generate per input for reranking
- `weights`: Tuple of (toxicity_weight, similarity_weight, fluency_weight)
- `echo`: If True, print example inputs, candidates, and outputs

## setup

In [None]:
# Mount Google Drive so the project folder below is reachable from this runtime.
from google.colab import drive; drive.mount('/content/drive')

import os, sys, torch

# Adjust this if your project lives somewhere else
PROJECT_BASE = "/content/drive/MyDrive/w266 - Project"
XDETOX_DIR   = os.path.join(PROJECT_BASE, "XDetox")
T5_CHECKPOINT = os.path.join(PROJECT_BASE, "t5-base-detox-model")

# Put the XDetox checkout on sys.path (once) so it can be imported from.
if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

# Echo the resolved locations and whether they exist on disk.
print("PROJECT_BASE:", PROJECT_BASE)
print("XDETOX_DIR:", XDETOX_DIR)
print("T5_CHECKPOINT:", T5_CHECKPOINT)
print("XDETOX exists:", os.path.isdir(XDETOX_DIR))
print("T5 checkpoint exists:", os.path.isdir(T5_CHECKPOINT) or os.path.isfile(T5_CHECKPOINT))

# Only the XDetox checkout is enforced here; the checkpoint check above is
# informational and does not stop the notebook.
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"

In [None]:
# Keep the Transformers cache inside the XDetox folder (created if missing).
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.makedirs(HF_CACHE, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE
# NOTE(review): presumably turns off Weights & Biases logging in libraries
# that honour this variable — confirm it is still needed by this notebook.
os.environ["WANDB_DISABLED"] = "true"

# Report the cache location and which accelerator (if any) is visible.
print("TRANSFORMERS_CACHE:", HF_CACHE)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

In [None]:
# Fail fast if the XDetox checkout is missing any of the expected sub-folders.
for d in ["rewrite", "evaluation", "datasets", "data"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("ok")

In [None]:
# Install the notebook's dependencies. transformers, tokenizers, datasets,
# evaluate and sacrebleu are pinned to exact versions; the rest are unpinned.
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi \
                sentencepiece
!pip -q install bert-score

In [None]:
import nltk

# Fetch the NLTK "punkt" tokenizer data.
nltk.download("punkt", quiet=True)

# "punkt_tab" is fetched on a best-effort basis: a failure must not stop the
# notebook, but it is reported instead of being silently discarded.
try:
    if not nltk.download("punkt_tab", quiet=True):
        print("warning: nltk resource 'punkt_tab' was not downloaded")
except Exception as exc:
    print(f"warning: nltk 'punkt_tab' download failed: {exc!r}")
print("ok")

In [None]:
import glob, re, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE
from typing import List, Tuple

from transformers import (
    T5Tokenizer, T5ForConditionalGeneration,
    AutoTokenizer, AutoModel, AutoModelForSequenceClassification,
    GPT2TokenizerFast, GPT2LMHeadModel,
)
print("loaded")

# Repository root used by the helpers below; same path as XDETOX_DIR.
REPO = XDETOX_DIR

## data config

In [None]:
# Registry of datasets accepted as `data_type`.
# Each entry gives the file location ("data_path", joined onto REPO by
# load_test_data) and its "format" ("txt", "csv" or "tsv"), which selects the
# reader used by load_test_data.
# NOTE(review): the "microagressions" spelling is part of the key and the path;
# it presumably matches the folder name on disk — do not "fix" it here alone.
data_configs = {
    "paradetox": {
        "data_path": "./datasets/paradetox/test_toxic_parallel.txt",
        "format": "txt",
    },
    "microagressions_test": {
        "data_path": "./datasets/microagressions/test.csv",
        "format": "csv",
    },
    "sbf_test": {
        "data_path": "./datasets/sbf/sbftst.csv",
        "format": "csv",
    },
    "dynabench_test": {
        "data_path": "./datasets/dynabench/db_test.csv",
        "format": "csv",
    },
    "jigsaw_toxic": {
        "data_path": "./datasets/jigsaw_full_30/test_10k_toxic.txt",
        "format": "txt",
    },
    "appdia_original": {
        "data_path": "./datasets/appdia/original-annotated-data/original-test.tsv",
        "format": "tsv",
    },
    "appdia_discourse": {
        "data_path": "./datasets/appdia/discourse-augmented-data/discourse-test.tsv",
        "format": "tsv",
    },
}
print(f"{len(data_configs)} datasets:", ", ".join(data_configs.keys()))

## helpers

In [None]:
def _ensure_dir(p: str):
    """Create directory ``p`` with any missing parents; no-op if it exists."""
    directory = Path(p)
    directory.mkdir(exist_ok=True, parents=True)

def load_test_data(data_type: str, num_examples: int = None) -> List[str]:
    """Load the input sentences of one configured dataset.

    Args:
        data_type: Key into ``data_configs``.
        num_examples: When a positive number, keep only the first
            ``num_examples`` cleaned rows; ``None`` or ``0`` keeps all rows.

    Returns:
        The non-empty, whitespace-stripped texts in file order.

    Raises:
        ValueError: If ``data_type`` is not a key of ``data_configs``.
    """
    if data_type not in data_configs:
        raise ValueError(f"Unknown data_type: {data_type}")

    cfg = data_configs[data_type]
    # normpath collapses the leading "./" of the configured path. The previous
    # ``lstrip("./")`` removed *characters*, not a prefix, and would have
    # mangled paths such as "../x" or ".hidden/x".
    data_path = os.path.normpath(os.path.join(REPO, cfg["data_path"]))

    fmt = cfg["format"]
    texts = []

    if fmt == "txt":
        with open(data_path, "r", encoding="utf-8") as f:
            texts = [line.strip() for line in f if line.strip()]

    elif fmt in ("csv", "tsv"):
        df = pd.read_csv(data_path, sep="\t" if fmt == "tsv" else ",")
        # Column preference per format; fall back to the first column.
        preferred = ("text", "toxic") if fmt == "csv" else ("text",)
        column = next((c for c in preferred if c in df.columns), None)
        series = df[column] if column is not None else df.iloc[:, 0]
        texts = series.tolist()

    # Drop missing cells and strings that are blank after stripping.
    cleaned = []
    for t in texts:
        if pd.isna(t):
            continue
        s = str(t).strip()
        if s:
            cleaned.append(s)

    if num_examples and num_examples > 0:
        cleaned = cleaned[:num_examples]

    return cleaned

def _safe_float(x):
    """Convert ``x`` to ``float``; return NaN when the conversion fails."""
    try:
        value = float(x)
    except Exception:
        value = float("nan")
    return value

def _read_stats_file(path: str):
    """Parse a ``key: value`` stats file into a dict of floats.

    Lines without a colon are ignored. Keys are lower-cased and have any
    "(skipped)" marker removed; values that are not numeric become NaN.

    Args:
        path: Location of the stats file.

    Returns:
        Dict mapping the normalised key to its float value.
    """
    out = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            # Prefer the "key: value" form; fall back to a bare colon so that
            # lines such as "key:" or "key:value" no longer raise on unpacking.
            sep = ": " if ": " in line else ":"
            k, found, v = line.partition(sep)
            if not found:
                continue
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v)
    return out

print("ok")


## t5 model

In [None]:
# Load the fine-tuned T5 tokenizer and model from T5_CHECKPOINT and switch the
# model to eval mode.
print(f"loading t5 from {T5_CHECKPOINT}...")
t5_tokenizer = T5Tokenizer.from_pretrained(T5_CHECKPOINT)
t5_model = T5ForConditionalGeneration.from_pretrained(T5_CHECKPOINT)
t5_model.eval()

# Use the GPU when available. DEVICE_T5 is also the default `device` argument
# of the generation helpers, so it must be defined before they are.
DEVICE_T5 = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t5_model.to(DEVICE_T5)

print(f"loaded on {DEVICE_T5}")

## inference

In [None]:
def t5_generate_candidates(
    text: str,
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    num_candidates: int,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    max_length: int = 128,
    device: torch.device = DEVICE_T5,
) -> List[str]:
    """Sample ``num_candidates`` rewrites of ``text`` from the T5 model.

    The input is prefixed with "detoxify: ", truncated to ``max_length``
    tokens, and decoded with top-k / nucleus sampling at ``temperature``.

    Returns:
        The decoded candidates, special tokens removed.
    """
    prompt = f"detoxify: {text}"
    prompt_ids = tokenizer.encode(
        prompt,
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    prompt_ids = prompt_ids.to(device)

    sampling_args = dict(
        max_length=max_length,
        num_return_sequences=num_candidates,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        no_repeat_ngram_size=2,
    )
    with torch.no_grad():
        generated = model.generate(prompt_ids, **sampling_args)

    decoded: List[str] = []
    for sequence in generated:
        decoded.append(tokenizer.decode(sequence, skip_special_tokens=True))
    return decoded

def t5_generate_candidates_batch(
    texts: List[str],
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    num_candidates: int,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    max_length: int = 128,
    device: torch.device = DEVICE_T5,
) -> List[List[str]]:
    """Generate candidates for every text, one input at a time.

    Returns:
        One list of ``num_candidates`` strings per input, in input order.
    """
    shared_args = dict(
        num_candidates=num_candidates,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        max_length=max_length,
        device=device,
    )
    return [
        t5_generate_candidates(source, model, tokenizer, **shared_args)
        for source in tqdm(texts, desc="T5 Generation")
    ]

# Smoke test: sample three candidates for one sentence and print them.
# Sampling is random (do_sample=True), so the printed output varies per run.
test_text = "This is a stupid idea"
candidates = t5_generate_candidates(
    test_text,
    t5_model,
    t5_tokenizer,
    num_candidates=3,
    device=DEVICE_T5,
)
print(f"Input: {test_text}")
for i, c in enumerate(candidates):
    print(f"  cand[{i}]: {c}")

## reranking

In [None]:
# Device shared by all three scoring models (toxicity, LaBSE, GPT-2).
DEVICE_SCORE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("scoring device:", DEVICE_SCORE)

# Toxicity classifier: checkpoint name plus module-level cache, filled on
# first use by _lazy_load_tox().
_TOX_MODEL_NAME = "textdetox/xlmr-large-toxicity-classifier-v2"
_TOX_TOKENIZER = None
_TOX_MODEL = None

def _lazy_load_tox():
    """Load the toxicity tokenizer and model once; later calls do nothing."""
    global _TOX_TOKENIZER, _TOX_MODEL
    if _TOX_MODEL is not None:
        return
    print("loading toxicity model...")
    _TOX_TOKENIZER = AutoTokenizer.from_pretrained(_TOX_MODEL_NAME)
    _TOX_MODEL = AutoModelForSequenceClassification.from_pretrained(_TOX_MODEL_NAME)
    _TOX_MODEL.to(DEVICE_SCORE).eval()

# LaBSE sentence encoder: checkpoint name plus module-level cache, filled on
# first use by _lazy_load_labse().
_LABSE_NAME = "sentence-transformers/LaBSE"
_LABSE_TOKENIZER = None
_LABSE_MODEL = None

def _lazy_load_labse():
    """Load the LaBSE tokenizer and model once; later calls do nothing."""
    global _LABSE_TOKENIZER, _LABSE_MODEL
    if _LABSE_MODEL is not None:
        return
    print("loading labse...")
    _LABSE_TOKENIZER = AutoTokenizer.from_pretrained(_LABSE_NAME)
    _LABSE_MODEL = AutoModel.from_pretrained(_LABSE_NAME).to(DEVICE_SCORE).eval()

# GPT-2 language model used for perplexity: checkpoint name plus module-level
# cache, filled on first use by _lazy_load_gpt2_scorer().
_GPT2_NAME = "gpt2"
_GPT2_TOK = None
_GPT2_MOD = None

def _lazy_load_gpt2_scorer():
    """Load the GPT-2 tokenizer and model once; later calls do nothing."""
    global _GPT2_TOK, _GPT2_MOD
    if _GPT2_MOD is not None:
        return
    print("loading gpt2...")
    _GPT2_TOK = GPT2TokenizerFast.from_pretrained(_GPT2_NAME)
    _GPT2_MOD = GPT2LMHeadModel.from_pretrained(_GPT2_NAME).to(DEVICE_SCORE).eval()

print("ok")

In [None]:
@torch.no_grad()
def get_toxicity_scores(texts: List[str], batch_size: int = 32) -> List[float]:
    """Return one toxicity probability per text.

    Texts are classified in batches of ``batch_size`` (truncated to 512
    tokens); the softmax probability at label index 1 is reported.
    NOTE(review): assumes index 1 is the "toxic" class — confirm against the
    classifier's label mapping.
    """
    _lazy_load_tox()
    toxic_probs: List[float] = []
    batch_starts = range(0, len(texts), batch_size)
    for start in tqdm(batch_starts, desc="Toxicity", leave=False):
        chunk = texts[start:start + batch_size]
        inputs = _TOX_TOKENIZER(
            chunk,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        ).to(DEVICE_SCORE)
        class_probs = _TOX_MODEL(**inputs).logits.softmax(dim=-1)
        toxic_probs += class_probs[:, 1].cpu().tolist()
    return toxic_probs

@torch.no_grad()
def get_labse_embeddings(texts: List[str], batch_size: int = 32) -> np.ndarray:
    """Embed texts with LaBSE using attention-masked mean pooling.

    Returns:
        Array with one row per text (rows are not normalised here); a
        ``(0, 768)`` float32 array when ``texts`` is empty.
    """
    _lazy_load_labse()
    pooled_batches: List[np.ndarray] = []
    batch_starts = range(0, len(texts), batch_size)
    for start in tqdm(batch_starts, desc="LaBSE embeddings", leave=False):
        inputs = _LABSE_TOKENIZER(
            texts[start:start + batch_size],
            padding=True,
            truncation=True,
            max_length=256,
            return_tensors="pt",
        ).to(DEVICE_SCORE)
        token_states = _LABSE_MODEL(**inputs).last_hidden_state
        # Zero out padded positions, then average over the real tokens only.
        keep = inputs["attention_mask"].unsqueeze(-1)
        token_sum = (token_states * keep).sum(dim=1)
        token_count = keep.sum(dim=1).clamp(min=1e-6)
        pooled_batches.append((token_sum / token_count).cpu().numpy())
    if pooled_batches:
        return np.vstack(pooled_batches)
    return np.zeros((0, 768), dtype=np.float32)

@torch.no_grad()
def get_gpt2_perplexities(texts: List[str], max_ppl: float = 1e4) -> List[float]:
    """Return the GPT-2 perplexity of each text, capped at ``max_ppl``.

    Args:
        texts: Sentences to score, one forward pass each.
        max_ppl: Upper bound on the reported perplexity. It is also the value
            assigned to texts that cannot be scored.

    Returns:
        One finite float per input text, in input order.
    """
    _lazy_load_gpt2_scorer()
    ppls: List[float] = []
    for s in tqdm(texts, desc="GPT-2 PPL", leave=False):
        # truncation=True keeps over-long inputs within the tokenizer's limit.
        enc = _GPT2_TOK(s, return_tensors="pt", truncation=True).to(DEVICE_SCORE)
        input_ids = enc["input_ids"]
        # With fewer than two tokens there is no next-token target: an empty
        # input crashes the model and a single token gives a NaN loss that
        # would slip past the cap. Treat both as worst-case perplexity.
        if input_ids.shape[-1] < 2:
            ppls.append(float(max_ppl))
            continue
        loss = _GPT2_MOD(input_ids, labels=input_ids).loss.item()
        if not math.isfinite(loss):
            ppl = max_ppl
        else:
            try:
                ppl = min(math.exp(loss), max_ppl)
            except OverflowError:
                # exp() overflows for very large losses; that is "above cap".
                ppl = max_ppl
        ppls.append(float(ppl))
    return ppls

def perplexity_to_fluency(ppls: List[float],
                          p_min: float = 5.0,
                          p_max: float = 300.0) -> np.ndarray:
    """Map perplexities to fluency scores in [0, 1] on a log scale.

    Perplexities are clipped to ``[p_min, p_max]``; ``p_min`` maps to roughly
    1.0 (most fluent) and ``p_max`` to 0.0 (least fluent).
    """
    log_lo = math.log(p_min)
    log_hi = math.log(p_max)
    clipped = np.clip(np.asarray(ppls, dtype=float), p_min, p_max)
    # The small epsilon keeps the division defined when p_min == p_max.
    fluency = (log_hi - np.log(clipped)) / (log_hi - log_lo + 1e-8)
    return np.clip(fluency, 0.0, 1.0)

print("ok")

In [None]:
def rerank_candidates_global(
    sources: List[str],
    candidates: List[List[str]],
    weights: Tuple[float, float, float] = (0.5, 0.3, 0.2),
) -> List[str]:
    """Pick the best candidate per source by a weighted score.

    Score = w_T * (1 - toxicity) + w_S * similarity + w_F * fluency, where
    similarity is the LaBSE cosine similarity to the source rescaled to
    [0, 1] and fluency comes from GPT-2 perplexity.

    Args:
        sources: Original texts.
        candidates: One list of candidates per source; all lists must have
            the same, non-zero length.
        weights: ``(w_T, w_S, w_F)``.

    Returns:
        The highest-scoring candidate for each source, in source order.

    Raises:
        AssertionError: If the list lengths are inconsistent.
        ValueError: If the candidate lists are empty.
    """
    w_T, w_S, w_F = weights
    N = len(sources)
    assert len(candidates) == N, "candidates length mismatch"

    if N == 0:
        return []

    C = len(candidates[0])
    assert all(len(c) == C for c in candidates), "All inputs must have same num_candidates"
    if C == 0:
        # Previously surfaced as an opaque argmax-on-empty ValueError after
        # all scoring models had already been loaded.
        raise ValueError("each input needs at least one candidate to rerank")

    # Score all candidates in one flat pass; flat_idx maps each flat
    # candidate back to the row of its source.
    flat = [cand for clist in candidates for cand in clist]
    flat_idx = np.repeat(np.arange(N), C)

    print("  computing toxicity...")
    tox = np.array(get_toxicity_scores(flat))
    safety = 1.0 - tox

    print("  computing similarity...")
    src_embs = get_labse_embeddings(sources)
    cand_embs = get_labse_embeddings(flat)

    # L2-normalise so the row-wise dot product below is a cosine similarity.
    src_embs = src_embs / np.linalg.norm(src_embs, axis=1, keepdims=True).clip(1e-8)
    cand_embs = cand_embs / np.linalg.norm(cand_embs, axis=1, keepdims=True).clip(1e-8)

    sims = np.sum(cand_embs * src_embs[flat_idx], axis=1)
    sims = (sims + 1.0) / 2.0  # cosine in [-1, 1] -> [0, 1]

    print("  computing fluency...")
    ppls = get_gpt2_perplexities(flat)
    flus = perplexity_to_fluency(ppls)

    scores = w_T * safety + w_S * sims + w_F * flus
    # argmax would select a NaN score, so a candidate that could not be
    # scored would always win. Rank non-finite scores last instead.
    scores = np.where(np.isfinite(scores), scores, -np.inf)
    scores = scores.reshape(N, C)

    best_idx = scores.argmax(axis=1)
    return [clist[j] for clist, j in zip(candidates, best_idx)]

print("ok")

## eval

In [None]:
def _eval_with_toxicity(base_path: str,
                        overwrite_eval: bool = False,
                        skip_ref: bool = False,
                        tox_threshold: float = 0.5,
                        tox_batch_size: int = 32):
    """Run ``evaluation.evaluate_all`` for every run folder under ``base_path``.

    A folder is evaluated when it contains both ``orig.txt`` and ``gen.txt``
    and either has no ``gen_stats.txt`` yet or ``overwrite_eval`` is set.
    The evaluator runs as a subprocess with ``REPO`` as working directory.

    Args:
        base_path: Directory whose sub-folders are individual runs.
        overwrite_eval: Re-evaluate folders that already have stats.
        skip_ref: Forwarded to the evaluator as ``--skip_ref``.
        tox_threshold: Forwarded as ``--tox_threshold``.
        tox_batch_size: Forwarded as ``--tox_batch_size``.

    Raises:
        subprocess.CalledProcessError: If the evaluator exits non-zero; its
            stdout and stderr are printed first.
    """
    import sys as _sys
    # Sorted so the evaluation order does not depend on directory listing order.
    for folder in sorted(os.listdir(base_path)):
        gen_dir  = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path  = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue

        # Prepend REPO to PYTHONPATH with the platform's separator so that
        # `-m evaluation.evaluate_all` resolves on any OS.
        env = os.environ.copy()
        inherited = env.get("PYTHONPATH")
        env["PYTHONPATH"] = os.pathsep.join([REPO, inherited]) if inherited else REPO

        cmd = [
            _sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path",  gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")

        print("eval:", " ".join(cmd))
        res = run(
            cmd,
            cwd=REPO,
            env=env,
            stdout=PIPE,
            stderr=PIPE,
            text=True,
        )
        if res.returncode != 0:
            # Show the captured output before raising.
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()

def _aggregate_eval_csv(output_folder: str,
                        data_type: str,
                        base_out_dir: str):
    """Collect every run's ``gen_stats.txt`` into ``<data_type>.csv``.

    Looks under ``base_out_dir/data_type/T5_Global_Rerank`` and writes one
    row per run folder that has a stats file; missing metrics become NaN.
    ``output_folder`` is accepted but not used by this function.
    """
    base_path = os.path.join(base_out_dir, data_type, "T5_Global_Rerank")
    if not os.path.isdir(base_path):
        print("no eval dir:", base_path)
        return

    # CSV column -> key as it appears in the parsed stats file.
    metric_keys = {
        "bertscore": "bertscore",
        "meaningbert": "meaningbert",
        "bleu4": "bleu4",
        "perplexity_gen": "perplexity gen",
        "perplexity_orig": "perplexity orig",
        "toxicity_gen": "toxicity gen",
        "toxicity_orig": "toxicity orig",
    }

    rows = []
    for folder in os.listdir(base_path):
        stats_path = os.path.join(base_path, folder, "gen_stats.txt")
        if not os.path.exists(stats_path):
            continue
        stats = _read_stats_file(stats_path)
        row = {"config": folder}
        for column, key in metric_keys.items():
            row[column] = stats.get(key, np.nan)
        rows.append(row)

    if not rows:
        print("no eval files found")
        return

    frame = pd.DataFrame(rows)[["config", *metric_keys]]
    out_csv = os.path.join(base_out_dir, data_type, f"{data_type}.csv")
    frame.to_csv(out_csv, index=False)
    print("wrote csv:", out_csv)

print("ok")

In [None]:
def _bool2str(x: bool) -> str:
    """Encode a flag as one letter: "T" when truthy, "F" otherwise."""
    if x:
        return "T"
    return "F"

def _build_run_folder_name_t5_global(
    num_candidates: int,
    max_length: int,
    temperature: float,
    top_k: int,
    top_p: float,
    w_T: float,
    w_S: float,
    w_F: float,
) -> str:
    """Build the run folder name that encodes the sampling and weight settings."""
    parts = (
        f"t5_nc{num_candidates}",
        f"maxlen{max_length}",
        f"temp{temperature}",
        f"topk{top_k}",
        f"topp{top_p}",
        f"wT{w_T}",
        f"wS{w_S}",
        f"wF{w_F}",
    )
    return "_".join(parts)

## run

In [None]:
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "T5_w_Global-Reranking",
    echo: bool = False,
    num_examples: int = 1000,
    num_candidates: int = 10,
    max_length: int = 128,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    weights: Tuple[float, float, float] = (0.5, 0.3, 0.2),
    overwrite_gen: bool = False,
    run_eval: bool = True,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
):
    """Run T5 generation plus global reranking on one dataset, then evaluate.

    Outputs go to
    ``REPO/data/model_outputs/<output_folder>/<data_type>/T5_Global_Rerank/<run>/``
    where ``<run>`` encodes the sampling settings and weights. ``orig.txt``
    and ``gen.txt`` hold one whitespace-normalised text per line.

    Args:
        data_type: Key into ``data_configs``.
        output_folder: Folder name under ``data/model_outputs``.
        echo: Print example inputs, candidates, outputs and metrics.
        num_examples: Maximum number of inputs to process.
        num_candidates: Candidates sampled per input.
        max_length: Token limit for T5 input truncation and generation.
        temperature, top_k, top_p: Sampling settings for T5.
        weights: ``(toxicity, similarity, fluency)`` reranking weights.
        overwrite_gen: Regenerate even if cached outputs exist.
        run_eval: Run the evaluation step and aggregate the CSV.
        overwrite_eval: Re-evaluate runs that already have stats.
        skip_ref_eval: Forwarded to the evaluator as ``--skip_ref``.

    Returns:
        Dict of metric name -> float parsed from this run's
        ``gen_stats.txt``, or ``None`` when ``run_eval`` is False or the
        stats file was not produced.
    """
    assert data_type in data_configs, f"Unknown data_type: {data_type}"

    base_out_abs = os.path.join(REPO, "data", "model_outputs", output_folder)
    _ensure_dir(base_out_abs)

    print("=" * 80)
    print(f"[{data_type}] loading data...")
    orig_texts = load_test_data(data_type, num_examples)
    print(f"  loaded {len(orig_texts)} examples")

    if echo:
        print("\nexample inputs (first up to 3):")
        for i, s in enumerate(orig_texts[:3]):
            print(f"  input[{i}]: {s}")
        print(f"\nweights (tox, sim, flu): {weights}")

    rerank_dir = "T5_Global_Rerank"
    cur_abs = os.path.join(base_out_abs, data_type, rerank_dir)
    _ensure_dir(cur_abs)

    w_T, w_S, w_F = weights
    run_folder = _build_run_folder_name_t5_global(
        num_candidates=num_candidates,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        w_T=w_T,
        w_S=w_S,
        w_F=w_F,
    )
    final_abs = os.path.join(cur_abs, run_folder)
    _ensure_dir(final_abs)

    orig_path = os.path.join(final_abs, "orig.txt")
    gen_path  = os.path.join(final_abs, "gen.txt")
    stats_path = os.path.join(final_abs, "gen_stats.txt")

    # Reuse cached outputs only when BOTH files exist; a lone gen.txt used to
    # crash the reuse branch when it tried to read the missing orig.txt.
    have_cached = os.path.exists(gen_path) and os.path.exists(orig_path)

    if overwrite_gen or not have_cached:
        print("  generating candidates with t5...")
        all_candidates = t5_generate_candidates_batch(
            orig_texts,
            t5_model,
            t5_tokenizer,
            num_candidates=num_candidates,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_length=max_length,
            device=DEVICE_T5,
        )

        if echo and all_candidates:
            print("\nexample candidates for input[0]:")
            for j, c in enumerate(all_candidates[0][:3]):
                print(f"    cand[{j}]: {c}")

        print("  global reranking...")
        gen_texts = rerank_candidates_global(
            sources=orig_texts,
            candidates=all_candidates,
            weights=weights,
        )

        if echo:
            print("\nselected outputs (first up to 3):")
            for i, g in enumerate(gen_texts[:3]):
                print(f"  output[{i}]: {g}")

        # Collapse internal whitespace so every text occupies exactly one line
        # and orig.txt / gen.txt stay line-aligned. UTF-8 is explicit to match
        # how the input data is read.
        with open(orig_path, "w", encoding="utf-8") as f:
            for t in orig_texts:
                f.write(re.sub(r"\s+", " ", t).strip() + "\n")

        with open(gen_path, "w", encoding="utf-8") as f:
            for t in gen_texts:
                f.write(re.sub(r"\s+", " ", t).strip() + "\n")

        print("  saved to:", final_abs)
    else:
        print("  reusing existing from:", final_abs)
        with open(orig_path, "r", encoding="utf-8") as f:
            orig_texts = [l.strip() for l in f]
        with open(gen_path, "r", encoding="utf-8") as f:
            gen_texts = [l.strip() for l in f]
        print(f"  loaded {len(gen_texts)} examples")

    metrics = None
    if run_eval:
        # Evaluates every run folder of this dataset that still needs it,
        # not only the run produced by this call.
        _eval_with_toxicity(
            cur_abs,
            overwrite_eval=overwrite_eval,
            skip_ref=skip_ref_eval,
            tox_threshold=0.5,
            tox_batch_size=32,
        )
        _aggregate_eval_csv(
            output_folder,
            data_type,
            base_out_abs,
        )

        if os.path.exists(stats_path):
            metrics = _read_stats_file(stats_path)
            if echo:
                print("\nmetrics:")
                for k, v in metrics.items():
                    if isinstance(v, float) and math.isnan(v):
                        continue
                    print(f"  {k}: {v:.4f}")
        else:
            print("  gen_stats.txt not found")

    print("=" * 80)
    return metrics

print("ok")

## run eval

In [None]:
# Full ParaDetox run: regenerate outputs and re-run evaluation even if cached
# results exist (overwrite_gen / overwrite_eval are both True).
metrics_paradetox = detoxify(
    data_type="paradetox",
    output_folder="T5_w_Global-Reranking - KB",
    echo=True,
    num_examples=1000,
    num_candidates=10,
    max_length=128,
    temperature=1.0,
    top_k=50,
    top_p=0.95,
    weights=(0.5, 0.3, 0.2),
    overwrite_gen=True,
    run_eval=True,
    overwrite_eval=True,
    skip_ref_eval=False,
)

# Print the returned metrics, skipping any that parsed as NaN.
# detoxify() returns None when no stats were produced, hence the guard.
print("\nparadetox metrics:")
if metrics_paradetox:
    for k, v in metrics_paradetox.items():
        if isinstance(v, float) and math.isnan(v):
            continue
        print(f"  {k}: {v:.4f}")
