# T5-ParaDetox Pipeline
This notebook mirrors the XDetox_Pipeline structure for direct comparison:

- **Small-batch runs**: choose how many examples to process
- **Dataset picker**: run a single dataset or **all**
- **Same datasets** as XDetox (paradetox, microaggressions, sbf, dynabench, jigsaw, appdia)
- **Same evaluation metrics** (BLEU-4, BERTScore, MeaningBERT, perplexity, toxicity)
- **Same output format** (CSV summaries)

> **Prerequisites**: the fine-tuned T5 checkpoint (`t5-base-detox-model`) and the XDetox repository, including its `datasets/` folder, must both be in the project folder on Google Drive.

## Setup

In [None]:
# Mount Google Drive so the project folder is visible to the Colab VM.
from google.colab import drive
drive.mount('/content/drive')

import os
import sys
import torch

# --- project locations -------------------------------------------------------
PROJECT_BASE = "/content/drive/MyDrive/w266 - Project"
XDETOX_DIR = os.path.join(PROJECT_BASE, "XDetox")
T5_CHECKPOINT = os.path.join(PROJECT_BASE, "t5-base-detox-model")

print("project:", PROJECT_BASE)
print("xdetox:", XDETOX_DIR, "->", os.path.isdir(XDETOX_DIR))
print("checkpoint:", T5_CHECKPOINT)

# Fail fast when Drive is not mounted or a path is wrong.
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"
assert os.path.isdir(T5_CHECKPOINT), f"T5_CHECKPOINT does not exist: {T5_CHECKPOINT}"

# --- environment -------------------------------------------------------------
# Keep Hugging Face downloads inside the XDetox folder on Drive and turn off
# W&B reporting.
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.makedirs(HF_CACHE, exist_ok=True)
os.environ.update({
    "TRANSFORMERS_CACHE": HF_CACHE,
    "WANDB_DISABLED": "true",
})

# Make the XDetox package (e.g. `evaluation`) importable from this notebook.
if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("cache:", HF_CACHE)
print("cuda:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("gpu:", torch.cuda.get_device_name(0))

# Outputs and datasets are both resolved relative to the XDetox checkout.
REPO = DATASET_BASE = XDETOX_DIR

In [None]:
# Pinned dependency versions for this notebook. Run this before the imports
# cell below, which is where transformers is first imported.
# sentencepiece is required by the slow T5Tokenizer used for the T5 model.
!pip -q install --upgrade pip setuptools wheel
!pip -q install "transformers==4.41.2" "tokenizers==0.19.1" \
                "datasets==2.19.0" "evaluate==0.4.1" \
                "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas jedi \
                sentencepiece
# bert-score is presumably needed by XDetox's evaluation (BERTScore) — TODO confirm.
!pip -q install bert-score

In [None]:
import nltk

# Punkt tokenizer models (presumably used by the XDetox evaluation code for
# tokenization — TODO confirm).
nltk.download("punkt", quiet=True)

# Recent NLTK releases expect the "punkt_tab" resource; older ones may not
# provide it. Keep this best-effort, but report a failure instead of silently
# swallowing it. nltk.download returns False when a download fails.
try:
    _punkt_tab_ok = nltk.download("punkt_tab", quiet=True)
except Exception as _punkt_tab_err:  # best-effort: never abort the notebook here
    _punkt_tab_ok = False
    print(f"punkt_tab download raised {type(_punkt_tab_err).__name__}: {_punkt_tab_err}")
if not _punkt_tab_ok:
    print("warning: punkt_tab not available (expected on older NLTK versions)")
print("nltk ok")

In [None]:
import glob, re, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from subprocess import run, PIPE
from typing import List

from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
)

print("imports done")

## data config

In [None]:
# Test split of every dataset this pipeline can run on.
#   key        -> dataset name passed as `data_type`
#   data_path  -> path relative to DATASET_BASE
#   format     -> which reader load_test_data() uses ("txt", "csv" or "tsv")
data_configs = {
    name: {"data_path": rel_path, "format": fmt}
    for name, rel_path, fmt in (
        ("paradetox", "./datasets/paradetox/test_toxic_parallel.txt", "txt"),
        ("microagressions_test", "./datasets/microagressions/test.csv", "csv"),
        ("sbf_test", "./datasets/sbf/sbftst.csv", "csv"),
        ("dynabench_test", "./datasets/dynabench/db_test.csv", "csv"),
        ("jigsaw_toxic", "./datasets/jigsaw_full_30/test_10k_toxic.txt", "txt"),
        ("appdia_original", "./datasets/appdia/original-annotated-data/original-test.tsv", "tsv"),
        ("appdia_discourse", "./datasets/appdia/discourse-augmented-data/discourse-test.tsv", "tsv"),
    )
}
print(f"{len(data_configs)} datasets")

## helpers

In [None]:
def _ensure_dir(p: str):
    """Create directory *p*, including missing parents; no-op if it already exists."""
    target = Path(p)
    target.mkdir(exist_ok=True, parents=True)

def load_test_data(data_type: str, num_examples: int = None) -> List[str]:
    """Load the source sentences of one configured dataset.

    Args:
        data_type: key into ``data_configs``.
        num_examples: if truthy and > 0, keep only the first ``num_examples``
            texts; ``None``, 0 or a negative value keeps all of them.

    Returns:
        Non-empty, whitespace-stripped strings in file order. Missing (NaN)
        cells of csv/tsv files are dropped.

    Raises:
        ValueError: unknown ``data_type`` or unsupported ``format``.
    """
    if data_type not in data_configs:
        raise ValueError(f"Unknown data_type: {data_type}")

    cfg = data_configs[data_type]
    fmt = cfg["format"]
    # Join and normalise rather than ``lstrip("./")``: lstrip removes any run
    # of leading "." and "/" characters (e.g. "./.x" -> "x"), not the prefix.
    data_path = os.path.normpath(os.path.join(DATASET_BASE, cfg["data_path"]))

    if fmt == "txt":
        with open(data_path, "r", encoding="utf-8") as f:
            texts = [line.strip() for line in f if line.strip()]
    elif fmt in ("csv", "tsv"):
        df = pd.read_csv(data_path, sep="\t" if fmt == "tsv" else ",")
        # Preferred text columns in priority order; otherwise use the first column.
        preferred = ("text", "toxic") if fmt == "csv" else ("text",)
        column = next((c for c in preferred if c in df.columns), None)
        series = df[column] if column is not None else df.iloc[:, 0]
        texts = series.tolist()
    else:
        # Previously an unknown format silently produced an empty dataset.
        raise ValueError(f"Unsupported format {fmt!r} for data_type {data_type}")

    stripped = (str(t).strip() for t in texts if not pd.isna(t))
    cleaned = [s for s in stripped if s]

    if num_examples and num_examples > 0:
        cleaned = cleaned[:num_examples]

    return cleaned

def _safe_float(x):
    """Convert *x* to float, returning NaN when it cannot be parsed.

    Only the errors ``float()`` raises for bad input are caught (TypeError for
    unsupported types, ValueError for unparsable strings, OverflowError for
    integers too large for a float); anything else still propagates.
    """
    try:
        return float(x)
    except (TypeError, ValueError, OverflowError):
        return float("nan")


def _read_stats_file(path: str) -> dict:
    """Parse a ``key: value`` stats file into ``{normalised_key: float}``.

    Keys have any "(skipped)" marker removed, are stripped and lower-cased.
    Values that are not numbers become NaN. Lines without a colon are ignored;
    a later duplicate key overwrites an earlier one.
    """
    out = {}
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            # Split on the first ": " (so keys may contain bare colons); fall
            # back to a bare ":" instead of crashing on e.g. "key:0.5".
            key, sep, value = line.partition(": ")
            if not sep:
                key, sep, value = line.partition(":")
            if not sep:
                continue
            key = key.replace("(skipped)", "").strip().lower()
            out[key] = _safe_float(value)
    return out


print("helpers loaded")

## t5 model

In [None]:
print(f"loading from {T5_CHECKPOINT}...")

# Decide on the device up front so the model can be moved there right away.
DEVICE_T5 = torch.device("cuda" if torch.cuda.is_available() else "cpu")

t5_tokenizer = T5Tokenizer.from_pretrained(T5_CHECKPOINT)
# eval() and to() both return the module itself, so the chain keeps the same
# object: inference mode (no dropout) on the chosen device.
t5_model = T5ForConditionalGeneration.from_pretrained(T5_CHECKPOINT).eval().to(DEVICE_T5)

print(f"loaded on {DEVICE_T5}")

## inference

In [None]:
def t5_detoxify_text(
    text: str,
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    max_length: int = 128,
    num_beams: int = 5,
    device: torch.device = DEVICE_T5,
) -> str:
    """Detoxify one sentence with the fine-tuned T5 model.

    The input gets the ``"detoxify: "`` task prefix and is truncated to
    ``max_length`` tokens. Beam search (``num_beams`` beams, no repeated
    bigrams) produces at most ``max_length`` tokens.

    Returns:
        The top beam, decoded with special tokens removed.
    """
    prompt = f"detoxify: {text}"
    token_ids = tokenizer.encode(
        prompt,
        return_tensors="pt",
        max_length=max_length,
        truncation=True,
    )
    token_ids = token_ids.to(device)

    beam_settings = {
        "max_length": max_length,
        "num_beams": num_beams,
        "early_stopping": True,
        "no_repeat_ngram_size": 2,  # never emit the same bigram twice
    }
    with torch.no_grad():
        best_sequence = model.generate(token_ids, **beam_settings)[0]

    return tokenizer.decode(best_sequence, skip_special_tokens=True)

def t5_detoxify_batch(
    texts: List[str],
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    max_length: int = 128,
    num_beams: int = 5,
    batch_size: int = 8,
    device: torch.device = DEVICE_T5,
) -> List[str]:
    """Detoxify *texts* in mini-batches of ``batch_size``.

    Each batch is prefixed with ``"detoxify: "``, padded to its longest prompt,
    truncated to ``max_length`` tokens, and decoded with the same beam-search
    settings as ``t5_detoxify_text``.

    Returns:
        One output string per input, in input order.
    """
    results: List[str] = []
    batch_starts = range(0, len(texts), batch_size)
    for start in tqdm(batch_starts, desc="T5 gen"):
        chunk = texts[start:start + batch_size]
        encoded = tokenizer(
            [f"detoxify: {t}" for t in chunk],
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding=True,
        ).to(device)

        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=True,
                no_repeat_ngram_size=2,
            )

        # batch_decode applies decode(..., skip_special_tokens=True) per row.
        results += tokenizer.batch_decode(generated, skip_special_tokens=True)

    return results

# Smoke test: push one sentence through the model before the full pipeline.
test_text = "This is a stupid idea"
detoxified = t5_detoxify_text(
    test_text,
    model=t5_model,
    tokenizer=t5_tokenizer,
    device=DEVICE_T5,
)
print(f"in: {test_text}\nout: {detoxified}")

## eval

In [None]:
def _eval_with_toxicity(
    base_path: str,
    overwrite_eval: bool = False,
    skip_ref: bool = False,
    tox_threshold: float = 0.5,
    tox_batch_size: int = 32,
):
    """Run XDetox's ``evaluation.evaluate_all`` on every run folder in *base_path*.

    A sub-folder is evaluated when it contains both ``orig.txt`` and
    ``gen.txt``. It is skipped if ``gen_stats.txt`` already exists, unless
    ``overwrite_eval`` is true. Folders are processed in sorted order so runs
    are reproducible. Each evaluation is a subprocess run with ``cwd=REPO`` and
    ``REPO`` prepended to ``PYTHONPATH``. A missing *base_path* is reported and
    treated as "nothing to evaluate".

    Raises:
        subprocess.CalledProcessError: when an evaluation subprocess exits
            non-zero (its stdout and stderr are printed first).
    """
    import shlex
    import sys as _sys

    if not os.path.isdir(base_path):
        print("no eval dir:", base_path)
        return

    for folder in sorted(os.listdir(base_path)):
        gen_dir = os.path.join(base_path, folder)
        if not os.path.isdir(gen_dir):
            continue
        orig_path = os.path.join(gen_dir, "orig.txt")
        gen_path = os.path.join(gen_dir, "gen.txt")
        out_stats = os.path.join(gen_dir, "gen_stats.txt")
        if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
            continue
        if os.path.exists(out_stats) and not overwrite_eval:
            continue

        env = os.environ.copy()
        existing = env.get("PYTHONPATH")
        env["PYTHONPATH"] = REPO + (os.pathsep + existing if existing else "")
        cmd = [
            _sys.executable, "-m", "evaluation.evaluate_all",
            "--orig_path", orig_path,
            "--gen_path", gen_path,
            "--tox_threshold", str(tox_threshold),
            "--tox_batch_size", str(tox_batch_size),
        ]
        if skip_ref:
            cmd.append("--skip_ref")
        # shlex.join quotes arguments with spaces (e.g. "w266 - Project") so the
        # logged command can be copy-pasted into a shell.
        print("eval:", shlex.join(cmd))
        res = run(cmd, cwd=REPO, env=env, stdout=PIPE, stderr=PIPE, text=True)
        if res.returncode != 0:
            print(res.stdout)
            print(res.stderr)
            res.check_returncode()

def _aggregate_eval_csv_baseline(
    output_folder: str,
    data_type: str,
    base_out_dir: str,
    model_dir: str = "T5_Baseline",
):
    """Collect every run's ``gen_stats.txt`` into one CSV per dataset.

    Reads ``<base_out_dir>/<data_type>/<model_dir>/<run>/gen_stats.txt`` for each
    run folder, in sorted order so the CSV row order is stable. Writes
    ``<base_out_dir>/<data_type>/<data_type>.csv`` with one row per run; a
    metric missing from a stats file becomes NaN.

    Args:
        output_folder: unused; kept for call-site compatibility.
        data_type: dataset name (sub-folder of ``base_out_dir``).
        base_out_dir: root of this pipeline's outputs.
        model_dir: model sub-folder holding the run folders.

    Returns:
        The DataFrame that was written, or ``None`` if there was nothing to
        aggregate.
    """
    # CSV column -> key produced by _read_stats_file (normalised stats label).
    metric_keys = {
        "bertscore": "bertscore",
        "meaningbert": "meaningbert",
        "bleu4": "bleu4",
        "perplexity_gen": "perplexity gen",
        "perplexity_orig": "perplexity orig",
        "toxicity_gen": "toxicity gen",
        "toxicity_orig": "toxicity orig",
    }

    base_path = os.path.join(base_out_dir, data_type, model_dir)
    if not os.path.isdir(base_path):
        print("no eval dir:", base_path)
        return None

    rows = []
    for folder in sorted(os.listdir(base_path)):
        stats_path = os.path.join(base_path, folder, "gen_stats.txt")
        if not os.path.exists(stats_path):
            continue
        stats = _read_stats_file(stats_path)
        row = {"folder": folder}
        row.update({col: stats.get(key, np.nan) for col, key in metric_keys.items()})
        rows.append(row)

    if not rows:
        print("no eval files")
        return None

    df = pd.DataFrame(rows, columns=["folder", *metric_keys])
    out_csv = os.path.join(base_out_dir, data_type, f"{data_type}.csv")
    _ensure_dir(os.path.dirname(out_csv))
    df.to_csv(out_csv, index=False)
    print("wrote csv:", out_csv)
    return df


print("eval helpers ok")

In [None]:
def _build_run_folder_name_t5_baseline(
    max_length: int,
    num_beams: int,
) -> str:
    """Return the run-folder name encoding the generation settings.

    Example: ``max_length=128, num_beams=5`` -> ``"t5_baseline_maxlen128_beams5"``.
    """
    return "t5_baseline_maxlen{}_beams{}".format(max_length, num_beams)


print("folder naming ok")

## detoxify

In [None]:
def _one_line(text: str) -> str:
    """Collapse every whitespace run (newlines included) to one space and trim.

    Keeps each text on a single line so orig.txt and gen.txt stay line-aligned.
    """
    return re.sub(r"\s+", " ", text).strip()


def _write_lines(path: str, texts: List[str]) -> None:
    """Write *texts* to *path* as UTF-8, one whitespace-normalised text per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(_one_line(t) + "\n" for t in texts)


def _read_lines(path: str) -> List[str]:
    """Read a UTF-8 file written by ``_write_lines`` back into stripped lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f]


def _load_cached_generations(orig_path: str, gen_path: str, expected_orig: List[str]):
    """Return cached generations if they were produced for *expected_orig*, else None.

    The run folder name encodes only max_length/num_beams, so without this
    check a cache written for a different ``num_examples`` (or different source
    data) would be reused silently.
    """
    if not (os.path.exists(orig_path) and os.path.exists(gen_path)):
        return None
    if _read_lines(orig_path) != expected_orig:
        return None
    generations = _read_lines(gen_path)
    return generations if len(generations) == len(expected_orig) else None


def _print_metrics(metrics: dict) -> None:
    """Print each metric to 4 decimals, skipping NaN values."""
    print("\nmetrics:")
    for k, v in metrics.items():
        if isinstance(v, float) and math.isnan(v):
            continue
        print(f"  {k}: {v:.4f}")


def detoxify_baseline(
    data_type: str = "paradetox",
    output_folder: str = "T5_ParaDetox_Pipeline",
    echo: bool = False,
    num_examples: int = 100,
    batch_size: int = 8,
    max_length: int = 128,
    num_beams: int = 5,
    overwrite_gen: bool = False,
    run_eval: bool = True,
    overwrite_eval: bool = False,
    skip_ref_eval: bool = False,
):
    """Generate T5 detoxifications for one dataset and optionally evaluate them.

    Files go to ``REPO/data/model_outputs/<output_folder>/<data_type>/T5_Baseline/
    <run_folder>/``: ``orig.txt`` (inputs), ``gen.txt`` (outputs) and, after
    evaluation, ``gen_stats.txt``. A per-dataset summary CSV is also written.

    Cached generations are reused unless ``overwrite_gen`` is set or the cached
    files no longer match the requested inputs (e.g. ``num_examples`` changed).
    Whenever ``gen.txt`` is rewritten, an existing ``gen_stats.txt`` is deleted
    so stale metrics are never reported for the new outputs.

    Returns:
        Metrics parsed from ``gen_stats.txt``, or ``None`` when ``run_eval`` is
        false or no stats file was produced.
    """
    assert data_type in data_configs, f"unknown: {data_type}"

    base_out_abs = os.path.join(REPO, "data", "model_outputs", output_folder)
    _ensure_dir(base_out_abs)

    print("=" * 60)
    print(f"[{data_type}] loading...")
    orig_texts = load_test_data(data_type, num_examples)
    print(f"  got {len(orig_texts)} examples")

    if echo:
        print("\ninputs (first 3):")
        for i, s in enumerate(orig_texts[:3]):
            print(f"  [{i}]: {s}")

    model_dir = "T5_Baseline"
    model_abs = os.path.join(base_out_abs, data_type, model_dir)
    run_folder = _build_run_folder_name_t5_baseline(
        max_length=max_length,
        num_beams=num_beams,
    )
    final_abs = os.path.join(model_abs, run_folder)
    _ensure_dir(final_abs)  # parents=True also creates model_abs

    orig_path = os.path.join(final_abs, "orig.txt")
    gen_path = os.path.join(final_abs, "gen.txt")
    stats_path = os.path.join(final_abs, "gen_stats.txt")

    generations = None
    if not overwrite_gen and os.path.exists(gen_path):
        generations = _load_cached_generations(
            orig_path, gen_path, [_one_line(t) for t in orig_texts]
        )
        if generations is None:
            print("  cached outputs do not match the requested inputs; regenerating")

    if generations is None:
        print("  generating...")
        generations = t5_detoxify_batch(
            texts=orig_texts,
            model=t5_model,
            tokenizer=t5_tokenizer,
            max_length=max_length,
            num_beams=num_beams,
            batch_size=batch_size,
            device=DEVICE_T5,
        )

        if echo:
            print("\noutputs (first 3):")
            for i, g in enumerate(generations[:3]):
                print(f"  [{i}]: {g}")

        _write_lines(orig_path, orig_texts)
        _write_lines(gen_path, generations)
        # New generations invalidate any earlier evaluation of this folder.
        # Otherwise _eval_with_toxicity skips it (overwrite_eval=False) and the
        # old gen_stats.txt would be returned as this run's metrics.
        if os.path.exists(stats_path):
            os.remove(stats_path)

        print("  saved to:", final_abs)
    else:
        print("  reusing:", final_abs)
        print(f"  loaded {len(generations)} gen")

    metrics = None
    if run_eval:
        # Note: this evaluates every run folder under model_abs, not only this one.
        _eval_with_toxicity(
            model_abs,
            overwrite_eval=overwrite_eval,
            skip_ref=skip_ref_eval,
            tox_threshold=0.5,
            tox_batch_size=32,
        )
        _aggregate_eval_csv_baseline(output_folder, data_type, base_out_abs)

        if os.path.exists(stats_path):
            metrics = _read_stats_file(stats_path)
            if echo:
                _print_metrics(metrics)
        else:
            print("  no stats file")

    print("=" * 60)
    return metrics


print("detoxify_baseline() ready")

## run

In [None]:
# Full ParaDetox run: regenerate and re-evaluate 1000 examples from scratch.
metrics_paradetox = detoxify_baseline(
    data_type="paradetox",
    output_folder="T5_ParaDetox_Pipeline",
    num_examples=1000,
    # generation settings
    batch_size=8,
    max_length=128,
    num_beams=5,
    overwrite_gen=True,
    # evaluation settings
    run_eval=True,
    overwrite_eval=True,
    skip_ref_eval=False,
    echo=True,
)

if metrics_paradetox:
    print("\nfinal metrics:")
    # Skip metrics whose value is NaN (unparsable or missing in gen_stats.txt).
    reportable = {
        name: value
        for name, value in metrics_paradetox.items()
        if not (isinstance(value, float) and math.isnan(value))
    }
    for name, value in reportable.items():
        print(f"  {name}: {value:.4f}")