# T5-ParaDetox Pipeline with Global Reranking

This notebook combines:
- **T5-base** fine-tuned on ParaDetox for detoxification
- **Global reranking** using toxicity, semantic similarity, and fluency

## Pipeline

1. Generate `num_candidates` detoxified texts per input using T5 sampling
2. Score each candidate using:
   - **Toxicity** (XLM-R large classifier)
   - **Semantic Similarity** (LaBSE embeddings)
   - **Fluency** (GPT-2 perplexity)
3. Select the candidate with the highest weighted score
4. Evaluate with BLEU, BERTScore, MeaningBERT, Perplexity, Toxicity

---

## Global Reranking Formula

For each candidate $c$:

$$\text{Score}(c) = w_T \cdot (1 - \text{Toxicity}(c)) + w_S \cdot \text{Similarity}(c) + w_F \cdot \text{Fluency}(c)$$

Default weights: $(w_T, w_S, w_F) = (0.5, 0.3, 0.2)$
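
As a rough illustration of the formula, the sketch below scores three made-up candidates for a single input and picks the highest-scoring one. The component values are invented for the example and are assumed to already lie in [0, 1]; `rerank_score` is a hypothetical helper name, not a function from this notebook.

```python
from typing import Tuple

def rerank_score(
    toxicity: float,
    similarity: float,
    fluency: float,
    weights: Tuple[float, float, float] = (0.5, 0.3, 0.2),
) -> float:
    """Weighted score; all three inputs are assumed to be in [0, 1]."""
    w_t, w_s, w_f = weights
    return w_t * (1.0 - toxicity) + w_s * similarity + w_f * fluency

# (candidate, toxicity, similarity, fluency) -- illustrative values only
candidates = [
    ("This is a bad idea", 0.10, 0.90, 0.80),
    ("This is a stupid idea", 0.95, 1.00, 0.85),
    ("Bad.", 0.05, 0.40, 0.30),
]

scores = [rerank_score(t, s, f) for _, t, s, f in candidates]
best_index = max(range(len(scores)), key=scores.__getitem__)
best_text = candidates[best_index][0]
print(best_text)
```

With non-negative weights that sum to 1 and components in [0, 1], the score itself stays in [0, 1], which makes runs with different weights easier to compare.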

---

## `detoxify()` API

```python
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "T5_w_Global-Reranking",
    echo: bool = False,
    num_examples: int = 1000,
    num_candidates: int = 10,
    max_length: int = 128,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    weights: tuple = (0.5, 0.3, 0.2),  # (toxicity, similarity, fluency)
    overwrite_gen: bool = False,
    run_eval: bool = True,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
)
```

### Key Arguments

- `data_type`: Dataset key (`paradetox`, `microagressions_test`, `sbf_test`, `dynabench_test`, `jigsaw_toxic`, `appdia_original`, `appdia_discourse`)
- `output_folder`: Folder under `data/model_outputs/` for results
- `num_candidates`: Number of candidates to generate per input for reranking
- `weights`: Tuple of (toxicity_weight, similarity_weight, fluency_weight)
- `echo`: If True, print example inputs, candidates, and outputs

## Setup

In [1]:
#@title Mount Drive & locate project
from google.colab import drive; drive.mount('/content/drive')

import os, sys, torch

# Adjust this if your project lives somewhere else
PROJECT_BASE = "/content/drive/MyDrive/w266 - Project"
XDETOX_DIR   = os.path.join(PROJECT_BASE, "XDetox")
T5_CHECKPOINT = os.path.join(PROJECT_BASE, "t5-base-detox-model")

if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("PROJECT_BASE:", PROJECT_BASE)
print("XDETOX_DIR:", XDETOX_DIR)
print("T5_CHECKPOINT:", T5_CHECKPOINT)
print("XDETOX exists:", os.path.isdir(XDETOX_DIR))
print("T5 checkpoint exists:", os.path.isdir(T5_CHECKPOINT) or os.path.isfile(T5_CHECKPOINT))

assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"


Mounted at /content/drive
PROJECT_BASE: /content/drive/MyDrive/w266 - Project
XDETOX_DIR: /content/drive/MyDrive/w266 - Project/XDetox
T5_CHECKPOINT: /content/drive/MyDrive/w266 - Project/t5-base-detox-model
XDETOX exists: True
T5 checkpoint exists: True


In [2]:
#@title Runtime setup (cache, GPU)
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.makedirs(HF_CACHE, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE
os.environ["WANDB_DISABLED"] = "true"

print("TRANSFORMERS_CACHE:", HF_CACHE)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

TRANSFORMERS_CACHE: /content/drive/MyDrive/w266 - Project/XDetox/cache
CUDA available: True
GPU: Tesla T4


In [3]:
#@title Verify XDetox repo layout
for d in ["rewrite", "evaluation", "datasets", "data"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("Repo folders OK.")

Repo folders OK.


In [4]:
#@title Install dependencies (aligned with LLM pipeline)
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi \
                sentencepiece
!pip -q install bert-score

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
ipython 7.34.0 requires jedi>=0.16, which is not installed.

In [5]:
#@title NLTK data
import nltk
nltk.download("punkt", quiet=True)
try:
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass
print("NLTK ready")

NLTK ready


In [6]:
#@title Imports
import glob, re, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE
from typing import List, Tuple

from transformers import (
    T5Tokenizer, T5ForConditionalGeneration,
    AutoTokenizer, AutoModel, AutoModelForSequenceClassification,
    GPT2TokenizerFast, GPT2LMHeadModel,
)
print("Libraries imported")

REPO = XDETOX_DIR



Libraries imported


## Dataset Configuration

In [7]:
#@title Data configs (matching XDetox paths + formats)
data_configs = {
    "paradetox": {
        "data_path": "./datasets/paradetox/test_toxic_parallel.txt",
        "format": "txt",
    },
    "microagressions_test": {
        "data_path": "./datasets/microagressions/test.csv",
        "format": "csv",
    },
    "sbf_test": {
        "data_path": "./datasets/sbf/sbftst.csv",
        "format": "csv",
    },
    "dynabench_test": {
        "data_path": "./datasets/dynabench/db_test.csv",
        "format": "csv",
    },
    "jigsaw_toxic": {
        "data_path": "./datasets/jigsaw_full_30/test_10k_toxic.txt",
        "format": "txt",
    },
    "appdia_original": {
        "data_path": "./datasets/appdia/original-annotated-data/original-test.tsv",
        "format": "tsv",
    },
    "appdia_discourse": {
        "data_path": "./datasets/appdia/discourse-augmented-data/discourse-test.tsv",
        "format": "tsv",
    },
}
print(f"{len(data_configs)} datasets configured:", ", ".join(data_configs.keys()))

7 datasets configured: paradetox, microagressions_test, sbf_test, dynabench_test, jigsaw_toxic, appdia_original, appdia_discourse
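
Each `data_path` is relative to the XDetox repo root. A minimal sketch of resolving such a path follows; `resolve_data_path` and the `/repo` root are hypothetical. It uses `posixpath` so the behaviour is the same on any platform, and it avoids `str.lstrip("./")`, which strips a set of characters rather than a literal prefix.

```python
import posixpath

def resolve_data_path(repo_root: str, rel_path: str) -> str:
    """Join a repo-relative path onto repo_root and normalize './' segments."""
    return posixpath.normpath(posixpath.join(repo_root, rel_path))

resolved = resolve_data_path("/repo", "./datasets/paradetox/test_toxic_parallel.txt")
print(resolved)

# lstrip removes any leading '.' or '/' characters, not the prefix "./"
stripped = "./.hidden/file.txt".lstrip("./")
print(stripped)
```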


## Helper Functions

In [8]:
#@title Helper functions (I/O + stats)

def _ensure_dir(p: str):
    Path(p).mkdir(parents=True, exist_ok=True)

def load_test_data(data_type: str, num_examples: int = None) -> List[str]:
    """
    Load toxic texts for a given data_type.
    Uses XDetox's datasets folder and simple heuristics for text column.
    """
    if data_type not in data_configs:
        raise ValueError(f"Unknown data_type: {data_type}")

    cfg = data_configs[data_type]
    data_path = os.path.normpath(os.path.join(REPO, cfg["data_path"]))

    texts = []

    if cfg["format"] == "txt":
        with open(data_path, "r", encoding="utf-8") as f:
            texts = [line.strip() for line in f if line.strip()]

    elif cfg["format"] == "csv":
        df = pd.read_csv(data_path)
        if "text" in df.columns:
            texts = df["text"].tolist()
        elif "toxic" in df.columns:
            texts = df["toxic"].tolist()
        else:
            texts = df.iloc[:, 0].tolist()

    elif cfg["format"] == "tsv":
        df = pd.read_csv(data_path, sep="\t")
        if "text" in df.columns:
            texts = df["text"].tolist()
        else:
            texts = df.iloc[:, 0].tolist()

    cleaned = []
    for t in texts:
        if pd.isna(t):
            continue
        s = str(t).strip()
        if s:
            cleaned.append(s)

    if num_examples and num_examples > 0:
        cleaned = cleaned[:num_examples]

    return cleaned

def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return float("nan")

def _read_stats_file(path: str):
    out = {}
    with open(path, "r") as f:
        for line in f:
            # Split on the first ": " only; skip lines without that separator
            k, sep, v = line.strip().partition(": ")
            if not sep:
                continue
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v)
    return out

print("Helper functions loaded")


Helper functions loaded
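
`_read_stats_file` expects one `name: value` pair per line. The sketch below parses an in-memory example in the same spirit. The sample lines and values are invented, `parse_stats_lines` is a hypothetical name, and the real `gen_stats.txt` written by `evaluation.evaluate_all` may contain other keys.

```python
import math
from typing import Dict, Iterable

def parse_stats_lines(lines: Iterable[str]) -> Dict[str, float]:
    """Parse 'key: value' lines into a dict; unparsable values become NaN."""
    out: Dict[str, float] = {}
    for line in lines:
        key, sep, value = line.strip().partition(": ")
        if not sep:
            continue  # no "key: value" separator on this line
        key = key.replace("(skipped)", "").strip().lower()
        try:
            out[key] = float(value)
        except ValueError:
            out[key] = float("nan")
    return out

# Invented sample content, for illustration only
sample = [
    "BLEU4: 0.5",
    "toxicity gen: 0.125",
    "header without separator",
    "bertscore: n/a",
]
stats = parse_stats_lines(sample)
print(stats)
```

Keys are lower-cased, which is why the aggregation step looks up names such as `bleu4` and `toxicity gen`.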


## T5 Model Loading

In [9]:
#@title Load T5 model (ParaDetox)

print(f"Loading T5 model from {T5_CHECKPOINT}...")
t5_tokenizer = T5Tokenizer.from_pretrained(T5_CHECKPOINT)
t5_model = T5ForConditionalGeneration.from_pretrained(T5_CHECKPOINT)
t5_model.eval()

DEVICE_T5 = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t5_model.to(DEVICE_T5)

print(f"T5 model loaded on {DEVICE_T5}")

Loading T5 model from /content/drive/MyDrive/w266 - Project/t5-base-detox-model...


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


T5 model loaded on cuda


## T5 Multi-Candidate Generation

In [10]:
#@title T5 multi-candidate generation

def t5_generate_candidates(
    text: str,
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    num_candidates: int,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    max_length: int = 128,
    device: torch.device = DEVICE_T5,
) -> List[str]:
    """
    Generate num_candidates candidate rewrites via sampling.
    """
    input_text = f"detoxify: {text}"
    input_ids = tokenizer.encode(
        input_text,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=max_length,
            num_return_sequences=num_candidates,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            no_repeat_ngram_size=2,
        )

    return [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]

def t5_generate_candidates_batch(
    texts: List[str],
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    num_candidates: int,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    max_length: int = 128,
    device: torch.device = DEVICE_T5,
) -> List[List[str]]:
    """
    Generate candidates for many inputs, one `generate` call per input
    (inputs are not batched together on the device).
    """
    all_candidates: List[List[str]] = []
    for text in tqdm(texts, desc="T5 Generation"):
        cands = t5_generate_candidates(
            text,
            model,
            tokenizer,
            num_candidates=num_candidates,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_length=max_length,
            device=device,
        )
        all_candidates.append(cands)
    return all_candidates

# Quick sanity check
test_text = "This is a stupid idea"
candidates = t5_generate_candidates(
    test_text,
    t5_model,
    t5_tokenizer,
    num_candidates=3,
    device=DEVICE_T5,
)
print(f"Input: {test_text}")
for i, c in enumerate(candidates):
    print(f"  cand[{i}]: {c}")

Input: This is a stupid idea
  cand[0]: This is not a good idea.
  cand[1]: This is not good idea.
  cand[2]: This is a bad idea


## Global Reranking Functions

In [11]:
#@title Global reranking models (toxicity, similarity, fluency)

DEVICE_SCORE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Scoring device:", DEVICE_SCORE)

# Toxicity model (XLM-R large, same as before)
_TOX_MODEL_NAME = "textdetox/xlmr-large-toxicity-classifier-v2"
_TOX_TOKENIZER = None
_TOX_MODEL = None

def _lazy_load_tox():
    global _TOX_TOKENIZER, _TOX_MODEL
    if _TOX_MODEL is None:
        print("Loading toxicity model...")
        _TOX_TOKENIZER = AutoTokenizer.from_pretrained(_TOX_MODEL_NAME)
        _TOX_MODEL = AutoModelForSequenceClassification.from_pretrained(_TOX_MODEL_NAME)
        _TOX_MODEL.to(DEVICE_SCORE).eval()

# Similarity model (LaBSE via AutoModel)
_LABSE_NAME = "sentence-transformers/LaBSE"
_LABSE_TOKENIZER = None
_LABSE_MODEL = None

def _lazy_load_labse():
    global _LABSE_TOKENIZER, _LABSE_MODEL
    if _LABSE_MODEL is None:
        print("Loading LaBSE model...")
        _LABSE_TOKENIZER = AutoTokenizer.from_pretrained(_LABSE_NAME)
        _LABSE_MODEL = AutoModel.from_pretrained(_LABSE_NAME).to(DEVICE_SCORE).eval()

# Fluency model (GPT-2 small)
_GPT2_NAME = "gpt2"
_GPT2_TOK = None
_GPT2_MOD = None

def _lazy_load_gpt2_scorer():
    global _GPT2_TOK, _GPT2_MOD
    if _GPT2_MOD is None:
        print("Loading GPT-2 model...")
        _GPT2_TOK = GPT2TokenizerFast.from_pretrained(_GPT2_NAME)
        _GPT2_MOD = GPT2LMHeadModel.from_pretrained(_GPT2_NAME).to(DEVICE_SCORE).eval()

print("Scoring model loaders ready")

Scoring device: cuda
Scoring model loaders ready
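
The loaders above defer model loading until first use and cache the result in module-level globals. A comparable pattern using `functools.lru_cache` is sketched below. `load_resource` and its dictionary return value are stand-ins for a real `from_pretrained` call, used here so the example runs without downloading anything.

```python
from functools import lru_cache

LOAD_CALLS = []  # records each real load, for demonstration

@lru_cache(maxsize=None)
def load_resource(name: str) -> dict:
    """Pretend to load an expensive resource; runs once per distinct name."""
    LOAD_CALLS.append(name)
    return {"name": name}

first = load_resource("gpt2")
second = load_resource("gpt2")   # served from the cache, no second load
print(first is second, LOAD_CALLS)
```

Either approach keeps notebook startup fast and only pays the loading cost for the scorers that are actually used.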


In [12]:
#@title Global scoring functions (toxicity, similarity, fluency)

@torch.no_grad()
def get_toxicity_scores(texts: List[str], batch_size: int = 32) -> List[float]:
    """
    Toxicity probabilities in [0,1] (higher = more toxic).
    """
    _lazy_load_tox()
    scores: List[float] = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Toxicity", leave=False):
        batch = texts[i:i + batch_size]
        enc = _TOX_TOKENIZER(
            batch,
            return_tensors="pt",
            truncation=True,
            max_length=512,
            padding=True,
        ).to(DEVICE_SCORE)
        logits = _TOX_MODEL(**enc).logits
        probs = torch.softmax(logits, dim=-1)
        scores.extend(probs[:, 1].cpu().tolist())
    return scores

@torch.no_grad()
def get_labse_embeddings(texts: List[str], batch_size: int = 32) -> np.ndarray:
    """
    Mean-pooled LaBSE sentence embeddings.
    """
    _lazy_load_labse()
    embs: List[np.ndarray] = []
    for i in tqdm(range(0, len(texts), batch_size), desc="LaBSE embeddings", leave=False):
        batch = texts[i:i + batch_size]
        enc = _LABSE_TOKENIZER(
            batch,
            return_tensors="pt",
            truncation=True,
            max_length=256,
            padding=True,
        ).to(DEVICE_SCORE)
        outputs = _LABSE_MODEL(**enc)
        hidden = outputs.last_hidden_state      # [B, L, H]
        mask = enc["attention_mask"].unsqueeze(-1).to(hidden.dtype)  # [B, L, 1]
        masked = hidden * mask
        summed = masked.sum(dim=1)             # [B, H]
        counts = mask.sum(dim=1).clamp(min=1e-6)
        sent_emb = (summed / counts).cpu().numpy()
        embs.append(sent_emb)
    if not embs:
        return np.zeros((0, 768), dtype=np.float32)
    return np.vstack(embs)

@torch.no_grad()
def get_gpt2_perplexities(texts: List[str]) -> List[float]:
    """
    Approximate sentence-level perplexity with GPT-2.
    """
    _lazy_load_gpt2_scorer()
    ppls: List[float] = []
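    for s in tqdm(texts, desc="GPT-2 PPL", leave=False):
        enc = _GPT2_TOK(s, return_tensors="pt").to(DEVICE_SCORE)
        # GPT-2 needs at least two tokens for a next-token loss; treat shorter
        # (or empty) inputs as maximally disfluent instead of producing NaN.
        if enc["input_ids"].shape[1] < 2:
            ppls.append(1e4)
            continue
        out = _GPT2_MOD(enc["input_ids"], labels=enc["input_ids"])
        ppl = min(math.exp(out.loss.item()), 1e4)  # cap extreme values
        ppls.append(float(ppl))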
    return ppls

def perplexity_to_fluency(ppls: List[float],
                          p_min: float = 5.0,
                          p_max: float = 300.0) -> np.ndarray:
    """
    Map perplexities to [0,1] fluency scores.
    Low perplexity -> high fluency.
    """
    p = np.asarray(ppls, dtype=float)
    p = np.clip(p, p_min, p_max)
    log_p = np.log(p)
    log_min = math.log(p_min)
    log_max = math.log(p_max)
    F = (log_max - log_p) / (log_max - log_min + 1e-8)
    F = np.clip(F, 0.0, 1.0)
    return F

print("Global scoring functions defined")

Global scoring functions defined
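
To make the perplexity-to-fluency mapping concrete, the pure-Python sketch below applies the same log-scale interpolation to a few perplexities, without the small epsilon used in `perplexity_to_fluency`. `fluency_from_ppl` is a hypothetical name; the bounds 5 and 300 are the defaults of the function above.

```python
import math

def fluency_from_ppl(ppl: float, p_min: float = 5.0, p_max: float = 300.0) -> float:
    """Map a perplexity to [0, 1] on a log scale (low perplexity -> high fluency)."""
    clipped = min(max(ppl, p_min), p_max)
    return (math.log(p_max) - math.log(clipped)) / (math.log(p_max) - math.log(p_min))

geo_mean = math.sqrt(5.0 * 300.0)  # geometric midpoint of the range
for ppl in (2.0, 5.0, geo_mean, 300.0, 10_000.0):
    print(f"{ppl:10.2f} -> {fluency_from_ppl(ppl):.3f}")
```

Because the interpolation is done in log space, the midpoint fluency of 0.5 falls at the geometric mean of the bounds rather than the arithmetic mean, and anything at or above `p_max` scores 0.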


In [13]:
#@title Global reranking

def rerank_candidates_global(
    sources: List[str],
    candidates: List[List[str]],
    weights: Tuple[float, float, float] = (0.5, 0.3, 0.2),
) -> List[str]:
    """
    Global reranking:
      Score = w_T * (1 - toxicity) + w_S * similarity + w_F * fluency.

    Args:
        sources:   list of input toxic sentences, length N
        candidates: list of list of candidates, shape [N][C]
        weights:   (w_toxicity, w_similarity, w_fluency)

    Returns:
        best_candidates: list[str] length N
    """
    w_T, w_S, w_F = weights
    N = len(sources)
    assert len(candidates) == N, "candidates length mismatch"

    if N == 0:
        return []

    C_list = [len(c) for c in candidates]
    C = C_list[0]
    assert all(c == C for c in C_list), "All inputs must have same num_candidates"

    # Flatten candidates
    flat = [cand for clist in candidates for cand in clist]
    flat_idx = np.repeat(np.arange(N), C)

    print("  Computing toxicity scores...")
    tox = np.array(get_toxicity_scores(flat))   # higher = more toxic
    safety = 1.0 - tox

    print("  Computing similarity scores (LaBSE)...")
    src_embs = get_labse_embeddings(sources)
    cand_embs = get_labse_embeddings(flat)

    # Normalize
    src_embs = src_embs / np.linalg.norm(src_embs, axis=1, keepdims=True).clip(1e-8)
    cand_embs = cand_embs / np.linalg.norm(cand_embs, axis=1, keepdims=True).clip(1e-8)

    sims = np.sum(cand_embs * src_embs[flat_idx], axis=1)  # cosine
    sims = (sims + 1.0) / 2.0  # to [0,1]

    print("  Computing fluency scores (GPT-2)...")
    ppls = get_gpt2_perplexities(flat)
    flus = perplexity_to_fluency(ppls)

    scores = w_T * safety + w_S * sims + w_F * flus
    scores = scores.reshape(N, C)

    best_idx = scores.argmax(axis=1)
    best = [candidates[i][best_idx[i]] for i in range(N)]
    return best

print("Global reranking function defined")

Global reranking function defined
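
The reshape step relies on candidates being flattened in row-major order, so that flat position `i * C + j` holds candidate `j` of input `i`. The following pure-Python sketch shows that bookkeeping with made-up scores; `select_best` is a hypothetical helper that stands in for the NumPy `reshape` and `argmax` calls.

```python
from typing import List

def select_best(candidates: List[List[str]], flat_scores: List[float]) -> List[str]:
    """Pick the top-scoring candidate per input from row-major flat scores."""
    n = len(candidates)
    c = len(candidates[0]) if n else 0
    assert len(flat_scores) == n * c, "score count must equal N * C"
    best = []
    for i in range(n):
        row = flat_scores[i * c:(i + 1) * c]      # scores for input i
        j = max(range(c), key=row.__getitem__)    # argmax within the row
        best.append(candidates[i][j])
    return best

cands = [["a0", "a1", "a2"], ["b0", "b1", "b2"]]
flat = [0.2, 0.9, 0.4, 0.7, 0.1, 0.3]  # illustrative scores only
picked = select_best(cands, flat)
print(picked)
```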


## Evaluation Functions

In [14]:
#@title Evaluation helpers — call evaluation/evaluate_all.py (XDetox)

def _eval_with_toxicity(base_path: str,
                        overwrite_eval: bool = False,
                        skip_ref: bool = False,
                        tox_threshold: float = 0.5,
                        tox_batch_size: int = 32):
    """
    Call evaluation.evaluate_all on each run folder in base_path.
    This matches the LLM pipeline behaviour.
    """
    import sys as _sys
    for folder in os.listdir(base_path):
        gen_dir  = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path  = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue

        env = os.environ.copy()
        env["PYTHONPATH"] = REPO + (
            ":" + env.get("PYTHONPATH", "") if env.get("PYTHONPATH") else ""
        )

        cmd = [
            _sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path",  gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")

        print("Eval:", " ".join(cmd))
        res = run(
            cmd,
            cwd=REPO,
            env=env,
            stdout=PIPE,
            stderr=PIPE,
            text=True,
        )
        if res.returncode != 0:
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()

def _aggregate_eval_csv(output_folder: str,
                        data_type: str,
                        base_out_dir: str):
    """
    Aggregate eval metrics for T5 + Global reranking.

    Layout:
      base_out_dir/
        └── {data_type}/
            └── T5_Global_Rerank/
                └── {run_folder}/
                    └── gen_stats.txt

    Writes:
      base_out_dir/{data_type}/{data_type}.csv
    """
    rows = []

    rerank_dir = "T5_Global_Rerank"
    base_path  = os.path.join(base_out_dir, data_type, rerank_dir)
    if not os.path.isdir(base_path):
        print("No evaluation directory found:", base_path)
        return

    for folder in os.listdir(base_path):
        gen_dir   = os.path.join(base_path, folder)
        stats_path = os.path.join(gen_dir, "gen_stats.txt")
        if not os.path.exists(stats_path):
            continue
        s = _read_stats_file(stats_path)
        rows.append({
            # label column; not a real threshold here, but kept for compatibility
            "config":          folder,
            "bertscore":       s.get("bertscore", np.nan),
            "meaningbert":     s.get("meaningbert", np.nan),
            "bleu4":           s.get("bleu4", np.nan),
            "perplexity_gen":  s.get("perplexity gen", np.nan),
            "perplexity_orig": s.get("perplexity orig", np.nan),
            "toxicity_gen":    s.get("toxicity gen", np.nan),
            "toxicity_orig":   s.get("toxicity orig", np.nan),
        })

    if rows:
        cols = [
            "config",
            "bertscore", "meaningbert", "bleu4",
            "perplexity_gen", "perplexity_orig",
            "toxicity_gen", "toxicity_orig",
        ]
        df = pd.DataFrame(rows)
        df = df[cols]
        out_csv = os.path.join(base_out_dir, data_type, f"{data_type}.csv")
        df.to_csv(out_csv, index=False)
        print("Wrote summary CSV:", out_csv)
    else:
        print("No evaluation files found to summarize.")

print("Evaluation helpers (evaluate_all) defined")

Evaluation helpers (evaluate_all) defined
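
Because the command is passed to `run` as a list, paths containing spaces (such as `w266 - Project`) reach the subprocess intact. The printed `Eval:` line, however, joins arguments with plain spaces, so it cannot be pasted into a shell as-is. One possible way to log a copy-pasteable command is `shlex.join`, sketched below with a made-up path.

```python
import shlex

cmd = [
    "python", "-m", "evaluation.evaluate_all",
    "--orig_path", "/content/my project/orig.txt",  # hypothetical path with a space
]

plain = " ".join(cmd)      # ambiguous: the path splits into two shell words
quoted = shlex.join(cmd)   # quotes only the arguments that need it
print(quoted)

# Round trip: splitting the quoted form recovers the original argument list
recovered = shlex.split(quoted)
```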


In [15]:
#@title Helpers for folder naming

def _bool2str(x: bool) -> str:
    return "T" if x else "F"

def _build_run_folder_name_t5_global(
    num_candidates: int,
    max_length: int,
    temperature: float,
    top_k: int,
    top_p: float,
    w_T: float,
    w_S: float,
    w_F: float,
) -> str:
    """
    Encode T5 + Global hyperparameters into a folder name.
    """
    return (
        f"t5_nc{num_candidates}_maxlen{max_length}_"
        f"temp{temperature}_topk{top_k}_topp{top_p}_"
        f"wT{w_T}_wS{w_S}_wF{w_F}"
    )
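
The run folder name doubles as a record of the hyperparameters. As a sketch, the name can be parsed back with a regular expression; `parse_run_folder` is a hypothetical helper that assumes the exact naming scheme produced by `_build_run_folder_name_t5_global` with plain decimal values (no scientific notation).

```python
import re
from typing import Dict, Optional

_RUN_RE = re.compile(
    r"t5_nc(?P<num_candidates>\d+)_maxlen(?P<max_length>\d+)_"
    r"temp(?P<temperature>[\d.]+)_topk(?P<top_k>\d+)_topp(?P<top_p>[\d.]+)_"
    r"wT(?P<w_T>[\d.]+)_wS(?P<w_S>[\d.]+)_wF(?P<w_F>[\d.]+)"
)

def parse_run_folder(name: str) -> Optional[Dict[str, float]]:
    """Return hyperparameters encoded in a run folder name, or None if no match."""
    m = _RUN_RE.fullmatch(name)
    if m is None:
        return None
    return {k: float(v) for k, v in m.groupdict().items()}

params = parse_run_folder(
    "t5_nc10_maxlen128_temp1.0_topk50_topp0.95_wT0.5_wS0.3_wF0.2"
)
print(params)
```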

## Main Pipeline Function

In [16]:
#@title detoxify() — T5 + Global reranking + evaluate_all

def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "T5_w_Global-Reranking",
    echo: bool = False,
    num_examples: int = 1000,
    num_candidates: int = 10,
    max_length: int = 128,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    weights: Tuple[float, float, float] = (0.5, 0.3, 0.2),
    overwrite_gen: bool = False,
    run_eval: bool = True,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
):
    """
    T5-ParaDetox pipeline:

      1. Generate num_candidates detoxified texts per input using T5 sampling.
      2. Global reranking with toxicity + similarity + fluency.
      3. Save orig/gen under XDetox/data/model_outputs/{output_folder}/{data_type}/T5_Global_Rerank/{run_folder}.
      4. Run evaluation via evaluation.evaluate_all (BLEU, BERTScore, MeaningBERT, PPL, Toxicity).
      5. Aggregate into per-dataset CSV (same style as LLM pipelines).

    Returns:
        If run_eval: dict of metrics for this run (parsed from gen_stats.txt).
        Else: None.
    """
    assert data_type in data_configs, f"Unknown data_type: {data_type}"

    # Output base (relative to repo, as in LLM pipeline)
    base_out_rel = os.path.join("data", "model_outputs", output_folder)
    base_out_abs = os.path.join(REPO, base_out_rel)
    _ensure_dir(base_out_abs)

    # Load inputs
    print("=" * 80)
    print(f"[{data_type}] Loading data...")
    orig_texts = load_test_data(data_type, num_examples)
    print(f"  Loaded {len(orig_texts)} examples")

    if echo:
        print("\n[echo] Example inputs (first up to 3):")
        for i, s in enumerate(orig_texts[:3]):
            print(f"  input[{i}]: {s}")
        print(f"\n[echo] Global weights (tox, sim, flu): {weights}")

    # Current pipeline directory (T5 + Global)
    rerank_dir = "T5_Global_Rerank"
    cur_rel = os.path.join(base_out_rel, data_type, rerank_dir)
    cur_abs = os.path.join(REPO, cur_rel)
    _ensure_dir(cur_abs)

    # Folder name for this configuration
    w_T, w_S, w_F = weights
    run_folder = _build_run_folder_name_t5_global(
        num_candidates=num_candidates,
        max_length=max_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        w_T=w_T,
        w_S=w_S,
        w_F=w_F,
    )
    final_abs = os.path.join(cur_abs, run_folder)
    _ensure_dir(final_abs)

    orig_path = os.path.join(final_abs, "orig.txt")
    gen_path  = os.path.join(final_abs, "gen.txt")
    stats_path = os.path.join(final_abs, "gen_stats.txt")

    # Generate or load gen.txt
    if overwrite_gen or not os.path.exists(gen_path):
        print("  Generating candidates with T5...")
        all_candidates = t5_generate_candidates_batch(
            orig_texts,
            t5_model,
            t5_tokenizer,
            num_candidates=num_candidates,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_length=max_length,
            device=DEVICE_T5,
        )

        if echo and all_candidates:
            print("\n[echo] Example candidates for input[0]:")
            for j, c in enumerate(all_candidates[0][:3]):
                print(f"    cand[{j}]: {c}")

        print("  Global reranking (toxicity + similarity + fluency)...")
        gen_texts = rerank_candidates_global(
            sources=orig_texts,
            candidates=all_candidates,
            weights=weights,
        )

        if echo:
            print("\n[echo] Selected outputs (first up to 3):")
            for i, g in enumerate(gen_texts[:3]):
                print(f"  output[{i}]: {g}")

        # Save orig and gen
        with open(orig_path, "w", encoding="utf-8") as f:
            for t in orig_texts:
                f.write(re.sub(r"\s+", " ", t).strip() + "\n")

        with open(gen_path, "w", encoding="utf-8") as f:
            for t in gen_texts:
                f.write(re.sub(r"\s+", " ", t).strip() + "\n")

        print("  Saved orig/gen to:", final_abs)
    else:
        print("  Reusing existing orig/gen from:", final_abs)
        with open(orig_path, "r", encoding="utf-8") as f:
            orig_texts = [l.strip() for l in f]
        with open(gen_path, "r", encoding="utf-8") as f:
            gen_texts = [l.strip() for l in f]
        print(f"  Loaded {len(gen_texts)} generated examples")

    # Evaluation via evaluation/evaluate_all.py
    metrics = None
    if run_eval:
        base_path = os.path.join(base_out_abs, data_type, rerank_dir)
        _eval_with_toxicity(
            base_path,
            overwrite_eval=overwrite_eval,
            skip_ref=skip_ref_eval,
            tox_threshold=0.5,
            tox_batch_size=32,
        )
        _aggregate_eval_csv(
            output_folder,
            data_type,
            os.path.join(REPO, "data", "model_outputs", output_folder),
        )

        if os.path.exists(stats_path):
            metrics = _read_stats_file(stats_path)
            if echo:
                print("\n[echo] Evaluation metrics for this run:")
                for k, v in metrics.items():
                    if isinstance(v, float) and math.isnan(v):
                        continue
                    print(f"  {k}: {v:.4f}")
        else:
            print("  gen_stats.txt not found for this run; no metrics to print.")

    print("=" * 80)
    return metrics

print("detoxify() defined")

detoxify() defined
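
The generate-or-reuse logic keyed on `gen.txt` can be illustrated in isolation. The sketch below uses a temporary directory and a stand-in `fake_generate` callable; `load_or_generate` is a hypothetical helper and not part of the notebook.

```python
import os
import tempfile
from typing import Callable, List

def load_or_generate(
    path: str,
    generate: Callable[[], List[str]],
    overwrite: bool = False,
) -> List[str]:
    """Reuse lines cached at `path` unless the file is missing or `overwrite` is set."""
    if overwrite or not os.path.exists(path):
        lines = generate()
        with open(path, "w", encoding="utf-8") as f:
            f.writelines(line + "\n" for line in lines)
        return lines
    with open(path, "r", encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]

calls = []

def fake_generate() -> List[str]:
    calls.append(1)  # count how often generation actually runs
    return ["first output", "second output"]

with tempfile.TemporaryDirectory() as tmp:
    gen_path = os.path.join(tmp, "gen.txt")
    a = load_or_generate(gen_path, fake_generate)                  # generates
    b = load_or_generate(gen_path, fake_generate)                  # reuses the file
    c = load_or_generate(gen_path, fake_generate, overwrite=True)  # regenerates

print(len(calls))
```

In `detoxify()` the cache location is the run folder, whose name does not encode `num_examples`, so changing only `num_examples` reuses the previous outputs unless `overwrite_gen=True`.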


## Run Evaluation

In [17]:
#@title Example run — paradetox

metrics_paradetox = detoxify(
    data_type="paradetox",
    output_folder="T5_w_Global-Reranking - KB",
    echo=True,
    num_examples=1000,
    num_candidates=10,
    max_length=128,
    temperature=1.0,
    top_k=50,
    top_p=0.95,
    weights=(0.5, 0.3, 0.2),
    overwrite_gen=True,
    run_eval=True,
    overwrite_eval=True,
    skip_ref_eval=False,
)

print("\nParadetox metrics for this run:")
if metrics_paradetox:
    for k, v in metrics_paradetox.items():
        if isinstance(v, float) and math.isnan(v):
            continue
        print(f"  {k}: {v:.4f}")


[paradetox] Loading data...
  Loaded 671 examples

[echo] Example inputs (first up to 3):
  input[0]: . or the loud ass one - thousand ton beast roaring towards you howling its horn .
  input[1]: mandated  and " right fucking now " would be good .
  input[2]: * neither * of my coworkers gave a shit when it came time to ditch mitch . ugh .

[echo] Global weights (tox, sim, flu): (0.5, 0.3, 0.2)
  Generating candidates with T5...


T5 Generation:   0%|          | 0/671 [00:00<?, ?it/s]


[echo] Example candidates for input[0]:
    cand[0]: or the one thousand ton beast roaring towards you howling its horn
    cand[1]: Or the one thousand ton beast roaring towards you
    cand[2]: or the one thousand ton beast roaring towards you howling its horn .
  Global reranking (toxicity + similarity + fluency)...
  Computing toxicity scores...
Loading toxicity model...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


Toxicity:   0%|          | 0/210 [00:00<?, ?it/s]

  Computing similarity scores (LaBSE)...
Loading LaBSE model...


LaBSE embeddings:   0%|          | 0/21 [00:00<?, ?it/s]

LaBSE embeddings:   0%|          | 0/210 [00:00<?, ?it/s]

  Computing fluency scores (GPT-2)...
Loading GPT-2 model...


GPT-2 PPL:   0%|          | 0/6710 [00:00<?, ?it/s]


[echo] Selected outputs (first up to 3):
  output[0]: or the loud one- thousand ton beast roaring toward you howling its horn .
  output[1]: mandated and "right now" would be good.
  output[2]: neither one of my coworkers cared when it came time to ditch mitch
  Saved orig/gen to: /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/T5_w_Global-Reranking - KB/paradetox/T5_Global_Rerank/t5_nc10_maxlen128_temp1.0_topk50_topp0.95_wT0.5_wS0.3_wF0.2
Eval: /usr/bin/python3 -m evaluation.evaluate_all --orig_path /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/T5_w_Global-Reranking - KB/paradetox/T5_Global_Rerank/t5_nc10_maxlen128_temp1.0_topk50_topp0.95_wT0.5_wS0.3_wF0.2/orig.txt --gen_path /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/T5_w_Global-Reranking - KB/paradetox/T5_Global_Rerank/t5_nc10_maxlen128_temp1.0_topk50_topp0.95_wT0.5_wS0.3_wF0.2/gen.txt --tox_threshold 0.5 --tox_batch_size 32
Wrote summary CSV: /content/drive/MyDrive/w266 - Proj