# T5-ParaDetox Pipeline with DecompX Reranking

This notebook combines:
- **T5-base** fine-tuned on ParaDetox for detoxification
- **DecompX reranking** to select the least toxic candidate from multiple generations

## Pipeline

1. Generate `num_candidates` detoxified texts per input using T5 sampling
2. Score each candidate using DecompX toxicity attribution (RoBERTa-based)
3. Select the candidate with the lowest toxicity score
4. Evaluate with BLEU, BERTScore, MeaningBERT (approximated here by sentence-embedding cosine similarity), Perplexity, Toxicity

---

## `detoxify()` API

```python
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "T5_w_DecompX-Reranking",
    batch_size: int = 8,
    max_length: int = 128,
    num_examples: int = 100,
    num_candidates: int = 10,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    overwrite_gen: bool = False,
    run_eval: bool = True,
    overwrite_eval: bool = False,
    echo: bool = False,
)
```

### Key Arguments

- `data_type`: Dataset key (paradetox, microagressions_test, sbf_test, dynabench_test, jigsaw_toxic, appdia_original, appdia_discourse)
- `output_folder`: Folder under `data/model_outputs/` for results
- `num_candidates`: Number of candidates to generate per input for reranking
- `temperature`: Sampling temperature for diversity (higher = more diverse)
- `echo`: If True, print example inputs, candidates, and outputs

## Setup

In [1]:
#@title Mount Drive & locate project
# Mount Google Drive at /content/drive so the project folder below is reachable.
from google.colab import drive
drive.mount('/content/drive')

import os, sys, torch

# Set your project base path; the other locations in this cell are derived from it.
PROJECT_BASE = "/content/drive/MyDrive/ds266/w266 - Project"
XDETOX_DIR = os.path.join(PROJECT_BASE, "XDetox")
T5_CHECKPOINT = os.path.join(PROJECT_BASE, "t5-base-detox-model")
# Dataset files are looked up relative to the XDetox directory.
DATASET_BASE = XDETOX_DIR

# Add XDetox to path for DecompX imports
if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("PROJECT_BASE:", PROJECT_BASE)
print("XDETOX_DIR:", XDETOX_DIR)
print("T5_CHECKPOINT:", T5_CHECKPOINT)

# Fail fast when the Drive layout does not match the paths above.
# Only PROJECT_BASE and XDETOX_DIR are checked; T5_CHECKPOINT is not verified here.
assert os.path.isdir(PROJECT_BASE), f"PROJECT_BASE does not exist: {PROJECT_BASE}"
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"

Mounted at /content/drive
PROJECT_BASE: /content/drive/MyDrive/ds266/w266 - Project
XDETOX_DIR: /content/drive/MyDrive/ds266/w266 - Project/XDetox
T5_CHECKPOINT: /content/drive/MyDrive/ds266/w266 - Project/t5-base-detox-model


AssertionError: PROJECT_BASE does not exist: /content/drive/MyDrive/ds266/w266 - Project

In [None]:
#@title Runtime setup (cache, GPU)
# Point the TRANSFORMERS_CACHE variable at a folder under the project and set
# WANDB_DISABLED before any model library reads the environment.
HF_CACHE = os.path.join(PROJECT_BASE, "cache")
os.environ.update({
    "TRANSFORMERS_CACHE": HF_CACHE,
    "WANDB_DISABLED": "true",
})

print("TRANSFORMERS_CACHE:", HF_CACHE)
_cuda_ready = torch.cuda.is_available()
print("CUDA available:", _cuda_ready)
if _cuda_ready:
    # Report the name of the first visible CUDA device.
    print("GPU:", torch.cuda.get_device_name(0))

TRANSFORMERS_CACHE: /content/drive/MyDrive/ds266/w266 - Project/cache
CUDA available: True
GPU: Tesla T4


In [None]:
#@title Install dependencies
# Shell-escape installs into the running Colab environment. No versions are
# pinned, so the resolved packages can differ between sessions.
# NOTE(review): `torch` is already imported by an earlier cell; if pip replaces
# it here a runtime restart is presumably required — confirm.
!pip install -q transformers torch datasets
!pip install -q evaluate sacrebleu bert-score
!pip install -q sentence-transformers accelerate -U
!pip install -q rouge_score pandas numpy scikit-learn matplotlib nltk

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone


In [None]:
#@title NLTK data
import nltk

# `nltk.download` signals failure through its boolean return value, so the
# result is checked instead of assuming the download succeeded.
_punkt_ok = nltk.download("punkt", quiet=True)

# "punkt_tab" stays best-effort (as in the original cell), but a failure is
# now reported rather than silently discarded.
try:
    _punkt_tab_ok = nltk.download("punkt_tab", quiet=True)
except Exception as exc:
    _punkt_tab_ok = False
    print(f"Warning: punkt_tab download raised {type(exc).__name__}: {exc}")

_nltk_missing = [
    resource
    for resource, ok in (("punkt", _punkt_ok), ("punkt_tab", _punkt_tab_ok))
    if not ok
]
if _nltk_missing:
    print(f"Warning: NLTK resources not downloaded: {', '.join(_nltk_missing)}")
print("NLTK ready")

In [None]:
#@title Import libraries
import glob, re, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from typing import List

from transformers import (
    T5Tokenizer, T5ForConditionalGeneration,
    AutoTokenizer, AutoModelForSequenceClassification,
    GPT2Tokenizer, GPT2LMHeadModel
)
from sentence_transformers import SentenceTransformer
from evaluate import load

# DecompX compatibility fix: apply_chunking_to_forward moved from
# transformers.modeling_utils to transformers.pytorch_utils in newer versions.
# When the old import location is gone, the function is re-attached to
# transformers.modeling_utils so code importing it from there keeps working.
# NOTE(review): presumably the DecompX modules imported below rely on the old
# location, which is why this patch runs before that import — confirm.
import transformers.modeling_utils as modeling_utils
try:
    from transformers.modeling_utils import apply_chunking_to_forward
except ImportError:
    from transformers.pytorch_utils import apply_chunking_to_forward
    modeling_utils.apply_chunking_to_forward = apply_chunking_to_forward

# DecompX imports (for reranking); `rewrite` is expected to resolve through the
# XDetox directory that an earlier cell appended to sys.path.
from rewrite.mask_orig import Masker

print("Libraries imported")

## Dataset Configuration

In [None]:
#@title Data configs (matching XDetox datasets)
# One row per dataset: (key, path relative to the dataset base, file format).
_DATASET_SPECS = (
    ("paradetox", "./datasets/paradetox/test_toxic_parallel.txt", "txt"),
    ("microagressions_test", "./datasets/microagressions/test.csv", "csv"),
    ("sbf_test", "./datasets/sbf/sbftst.csv", "csv"),
    ("dynabench_test", "./datasets/dynabench/db_test.csv", "csv"),
    ("jigsaw_toxic", "./datasets/jigsaw_full_30/test_10k_toxic.txt", "txt"),
    ("appdia_original", "./datasets/appdia/original-annotated-data/original-test.tsv", "tsv"),
    ("appdia_discourse", "./datasets/appdia/discourse-augmented-data/discourse-test.tsv", "tsv"),
)

# Expand the table into the mapping the loaders read: key -> {"data_path", "format"}.
data_configs = {
    key: {"data_path": rel_path, "format": file_format}
    for key, rel_path, file_format in _DATASET_SPECS
}

print(f"{len(data_configs)} datasets configured:")
for name in data_configs:
    print(f"  - {name}")

## Helper Functions

In [None]:
#@title Helper functions

def _ensure_dir(p):
    """Create directory ``p`` together with any missing parents; an existing directory is fine."""
    target_dir = Path(p)
    target_dir.mkdir(parents=True, exist_ok=True)

def load_test_data(data_type, num_examples=None):
    """
    Load the source texts of one configured dataset (.txt, .csv or .tsv).

    Args:
        data_type: Key into ``data_configs``.
        num_examples: When a positive integer, keep only the first that many
            cleaned texts; ``None`` or ``0`` keeps all of them.

    Returns:
        List of non-empty, whitespace-stripped strings.

    Raises:
        ValueError: If ``data_type`` is not configured, or its ``format`` is
            not one of ``txt``, ``csv``, ``tsv``.
    """
    if data_type not in data_configs:
        raise ValueError(f"Unknown data_type: {data_type}")

    cfg = data_configs[data_type]
    # normpath folds the leading "./" of the configured path. The previous
    # ``lstrip("./")`` removed a *set of characters*, so it also ate the dots
    # of paths such as "../x" or "./.hidden/x".
    data_path = os.path.normpath(os.path.join(DATASET_BASE, cfg["data_path"]))
    file_format = cfg["format"]

    if file_format == "txt":
        # One example per line; blank lines are dropped.
        with open(data_path, 'r', encoding='utf-8') as f:
            texts = [line.strip() for line in f if line.strip()]

    elif file_format == "csv":
        df = pd.read_csv(data_path)
        # Prefer a 'text' column, then 'toxic', otherwise the first column.
        if 'text' in df.columns:
            texts = df['text'].tolist()
        elif 'toxic' in df.columns:
            texts = df['toxic'].tolist()
        else:
            texts = df.iloc[:, 0].tolist()

    elif file_format == "tsv":
        df = pd.read_csv(data_path, sep='\t')
        # Prefer a 'text' column, otherwise the first column.
        if 'text' in df.columns:
            texts = df['text'].tolist()
        else:
            texts = df.iloc[:, 0].tolist()

    else:
        # Previously an unknown format silently produced an empty dataset.
        raise ValueError(
            f"Unsupported format '{file_format}' for data_type: {data_type}"
        )

    # Clean and convert to strings: skip missing cells and whitespace-only rows.
    cleaned_texts = []
    for text in texts:
        if pd.isna(text):
            continue
        text_str = str(text).strip()
        if text_str:
            cleaned_texts.append(text_str)

    if num_examples and num_examples > 0:
        cleaned_texts = cleaned_texts[:num_examples]

    return cleaned_texts

def _safe_float(x):
    """Convert ``x`` with ``float()``; any conversion failure yields NaN instead of raising."""
    try:
        converted = float(x)
    except Exception:
        converted = float('nan')
    return converted

def _read_stats_file(path):
    """
    Read a ``key: value`` stats file (gen_stats.txt) into a dict.

    Lines without a colon are ignored. Keys are lower-cased and any
    ``"(skipped)"`` marker is removed; values go through ``_safe_float``, so
    an unparsable or missing value becomes NaN.

    Args:
        path: Path of the stats file.

    Returns:
        Dict mapping normalised metric names to floats.
    """
    out = {}
    with open(path, "r") as f:
        for line in f:
            entry = line.strip()
            if ":" not in entry:
                continue
            # Split on ": " when present (keeps keys that contain a bare colon
            # intact); otherwise fall back to ":" so that "key:" or "k:v" no
            # longer breaks the two-value unpacking.
            separator = ": " if ": " in entry else ":"
            k, _, v = entry.partition(separator)
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v.strip())
    return out

print("Helper functions loaded")

## T5 Model Loading

In [None]:
#@title Load T5 model
print(f"Loading T5 model from {T5_CHECKPOINT}...")

# Choose the target device up front so the model can be placed right after loading.
device = "cuda" if torch.cuda.is_available() else "cpu"

t5_tokenizer = T5Tokenizer.from_pretrained(T5_CHECKPOINT)
t5_model = T5ForConditionalGeneration.from_pretrained(T5_CHECKPOINT).to(device)
t5_model.eval()  # evaluation mode for generation

print(f"T5 model loaded on {device}")

## T5 Multi-Candidate Generation

In [None]:
#@title T5 multi-candidate generation functions

def t5_generate_candidates(text, model, tokenizer, num_candidates,
                           temperature=1.0, top_k=50, top_p=0.95,
                           max_length=128, device="cuda"):
    """
    Sample ``num_candidates`` rewrites of ``text`` from the model.

    The text is prefixed with ``"detoxify: "``, encoded with truncation at
    ``max_length`` and moved to ``device``. Generation uses sampling with the
    given ``temperature`` / ``top_k`` / ``top_p``, the same ``max_length`` as
    the output limit, and ``no_repeat_ngram_size=2``.

    Returns:
        List of decoded strings with special tokens removed.
    """
    prompt = f"detoxify: {text}"
    encoded_prompt = tokenizer.encode(
        prompt, return_tensors='pt', max_length=max_length, truncation=True
    ).to(device)

    sampling_options = dict(
        max_length=max_length,
        num_return_sequences=num_candidates,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        no_repeat_ngram_size=2,
    )

    # Inference only: no gradients are needed while sampling.
    with torch.no_grad():
        sampled_sequences = model.generate(encoded_prompt, **sampling_options)

    return [
        tokenizer.decode(sequence, skip_special_tokens=True)
        for sequence in sampled_sequences
    ]

def t5_generate_candidates_batch(texts, model, tokenizer, num_candidates,
                                  temperature=1.0, top_k=50, top_p=0.95,
                                  max_length=128, device="cuda"):
    """
    Generate candidates for every text, one input at a time, with a progress bar.

    Returns:
        A list with one candidate list per input, in input order.
    """
    return [
        t5_generate_candidates(
            source_text, model, tokenizer, num_candidates,
            temperature, top_k, top_p, max_length, device
        )
        for source_text in tqdm(texts, desc="T5 Generation")
    ]

# Quick smoke test: sample three rewrites of one sentence and list them.
test_text = "This is a stupid idea"
candidates = t5_generate_candidates(
    test_text,
    t5_model,
    t5_tokenizer,
    num_candidates=3,
    device=device,
)
print(f"Input: {test_text}")
print(f"Candidates:")
for i, c in enumerate(candidates):
    line = f"  [{i}]: {c}"
    print(line)

## DecompX Reranking

In [None]:
#@title DecompX reranking functions

# Module-level cache for the Masker; it is created on first use.
_decompx_masker = None

def _get_decompx_masker():
    """Return the shared Masker, constructing it the first time it is requested."""
    global _decompx_masker
    if _decompx_masker is not None:
        return _decompx_masker
    _decompx_masker = Masker()
    return _decompx_masker

def decompx_toxicity_score(texts: List[str]) -> List[float]:
    """
    Score texts by the sum of their positive DecompX attributions.

    Each text is passed through the shared masker; the attributions it exposes
    afterwards are read and their positive entries summed. Lower score = less
    toxic. A text for which no attributions are available scores ``0.0``.

    Args:
        texts: Texts to score.

    Returns:
        One plain Python ``float`` per input text, in input order.
    """
    masker = _get_decompx_masker()
    scores = []

    for text in texts:
        # The call is made for its side effect of refreshing the masker's last
        # attributions; its return value is not needed. threshold=1.0 follows
        # the original author's note (high threshold: no masking, just
        # attribution).
        masker.process_text(sentence=[text], threshold=1.0)
        attr = masker.get_last_attributions()
        if attr is not None and len(attr) > 0:
            # Coerce to float so the result honours the List[float] contract
            # even when nothing is positive (sum() == int 0) or when the
            # attributions are array/tensor scalars.
            scores.append(float(sum(a for a in attr[0] if a > 0)))
        else:
            scores.append(0.0)

    return scores

def rerank_with_decompx(sources: List[str], candidates: List[List[str]]) -> List[str]:
    """
    Pick, for every source, the candidate with the lowest DecompX toxicity score.

    Args:
        sources: Original toxic texts (N); only their count is checked.
        candidates: One list of candidate rewrites per source (N x C).

    Returns:
        The selected candidate for each source (N). A single-candidate list is
        returned as-is without scoring; ties go to the earliest candidate.
    """
    assert len(candidates) == len(sources)

    winners = []
    for options in tqdm(candidates, desc="DecompX Reranking"):
        if len(options) != 1:
            # Score every option and keep the least toxic one.
            toxicity = decompx_toxicity_score(options)
            winners.append(options[np.argmin(toxicity)])
        else:
            winners.append(options[0])

    return winners

def release_decompx():
    """Drop the cached masker (after asking it to release its model) and clear the CUDA cache."""
    global _decompx_masker
    active_masker = _decompx_masker
    if active_masker is not None:
        active_masker.release_model()
        # Forget the instance so the next request builds a new one.
        _decompx_masker = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

print("DecompX reranking functions loaded")

## Evaluation Functions

In [None]:
#@title Load evaluation models
print("Loading evaluation models...")

# Toxicity classifier
_TOX_CHECKPOINT = "s-nlp/roberta_toxicity_classifier"
tox_tokenizer = AutoTokenizer.from_pretrained(_TOX_CHECKPOINT)
tox_model = AutoModelForSequenceClassification.from_pretrained(_TOX_CHECKPOINT).to(device)
tox_model.eval()

# Perplexity model (GPT-2)
_PPL_CHECKPOINT = "gpt2-medium"
ppl_tokenizer = GPT2Tokenizer.from_pretrained(_PPL_CHECKPOINT)
if ppl_tokenizer.pad_token is None:
    # Reuse the end-of-sequence token when the tokenizer defines no pad token.
    ppl_tokenizer.pad_token = ppl_tokenizer.eos_token
ppl_model = GPT2LMHeadModel.from_pretrained(_PPL_CHECKPOINT).to(device)
ppl_model.eval()

# Sentence embeddings for MeaningBERT
sim_model = SentenceTransformer('all-MiniLM-L6-v2')

# Metrics
bleu_metric, bertscore_metric = load("sacrebleu"), load("bertscore")

print("Evaluation models loaded")

In [None]:
#@title Evaluation functions

def compute_toxicity(texts, tokenizer, model, device="cuda", batch_size=32):
    """
    Mean probability of class index 1 (treated as "toxic") over ``texts``.

    Texts are classified in chunks of ``batch_size``, padded and truncated to
    512 tokens, and the softmax probability of index 1 is averaged over all
    texts.
    """
    toxic_probabilities = []

    for start in range(0, len(texts), batch_size):
        chunk = texts[start:start + batch_size]
        encoded = tokenizer(chunk, return_tensors="pt", padding=True,
                            truncation=True, max_length=512)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = model(**encoded).logits
            class_probs = logits.softmax(dim=-1)
            # Label 1 = toxic
            toxic_probabilities.extend(class_probs[:, 1].cpu().tolist())

    return np.mean(toxic_probabilities)

def compute_perplexity(texts, tokenizer, model, device="cuda"):
    """
    Mean per-text perplexity under a language model.

    Each text is tokenized (truncated to 512 tokens) and scored with the
    model's loss against its own input ids; the perplexity is ``exp(loss)``.

    Texts that tokenize to fewer than two tokens are skipped: with labels equal
    to the inputs there is no next-token target to score, and a single such
    text previously turned the whole average into NaN (or failed outright for
    an empty string).

    Returns:
        The mean perplexity of the scored texts, or NaN when no text could be
        scored (including an empty ``texts``).
    """
    perplexities = []

    for text in texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        # Need at least one (context, next-token) pair for a defined loss.
        if inputs["input_ids"].shape[-1] < 2:
            continue
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
            perplexities.append(torch.exp(outputs.loss).item())

    if not perplexities:
        return float('nan')
    return np.mean(perplexities)

def compute_bertscore(predictions, references):
    """Mean BERTScore F1 of ``predictions`` against ``references`` (English)."""
    scores = bertscore_metric.compute(
        predictions=predictions,
        references=references,
        lang="en",
    )
    f1_values = scores['f1']
    return np.mean(f1_values)

def compute_bleu(predictions, references):
    """Corpus BLEU of ``predictions``, using one reference per prediction."""
    # The metric takes a list of references per prediction, so wrap each one.
    wrapped_references = list(map(lambda reference: [reference], references))
    outcome = bleu_metric.compute(predictions=predictions, references=wrapped_references)
    return outcome['score']

def compute_meaningbert(predictions, references):
    """
    Meaning-preservation score reported as "MeaningBERT".

    Both lists are embedded with ``sim_model``; the pairwise cosine similarity
    between each prediction and its reference is averaged and multiplied by 100.
    """
    prediction_vectors = sim_model.encode(predictions, convert_to_tensor=True)
    reference_vectors = sim_model.encode(references, convert_to_tensor=True)
    similarities = torch.nn.functional.cosine_similarity(prediction_vectors, reference_vectors)
    mean_similarity = similarities.mean().item()
    return mean_similarity * 100

def evaluate_all(orig_texts, gen_texts, device="cuda"):
    """
    Compute every metric for one set of generated texts.

    Toxicity and perplexity are computed for both the generated and the
    original texts; BERTScore, MeaningBERT and BLEU compare the generated
    texts against the originals.

    Returns:
        Dict with keys ``toxicity_gen``, ``toxicity_orig``, ``perplexity_gen``,
        ``perplexity_orig``, ``bertscore``, ``meaningbert``, ``bleu4``
        (inserted in that order).
    """
    metrics = {}

    print("  Computing toxicity scores...")
    metrics.update(
        toxicity_gen=compute_toxicity(gen_texts, tox_tokenizer, tox_model, device),
        toxicity_orig=compute_toxicity(orig_texts, tox_tokenizer, tox_model, device),
    )

    print("  Computing perplexity...")
    metrics.update(
        perplexity_gen=compute_perplexity(gen_texts, ppl_tokenizer, ppl_model, device),
        perplexity_orig=compute_perplexity(orig_texts, ppl_tokenizer, ppl_model, device),
    )

    # The three similarity metrics share the (generated, original) argument order.
    print("  Computing BERTScore...")
    metrics['bertscore'] = compute_bertscore(gen_texts, orig_texts)

    print("  Computing MeaningBERT...")
    metrics['meaningbert'] = compute_meaningbert(gen_texts, orig_texts)

    print("  Computing BLEU...")
    metrics['bleu4'] = compute_bleu(gen_texts, orig_texts)

    return metrics

print("Evaluation functions defined")

## Main Pipeline Function

In [None]:
#@title detoxify() - T5 + DecompX reranking pipeline

def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "T5_w_DecompX-Reranking",
    batch_size: int = 8,
    max_length: int = 128,
    num_examples: int = 100,
    num_candidates: int = 10,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    overwrite_gen: bool = False,
    run_eval: bool = True,
    overwrite_eval: bool = False,
    echo: bool = False,
):
    """
    T5-ParaDetox pipeline with DecompX reranking.

    1. Generate num_candidates detoxified texts per input using T5 sampling
    2. Score each candidate using DecompX toxicity attribution
    3. Select candidate with lowest toxicity score
    4. Evaluate with BLEU, BERTScore, MeaningBERT, Perplexity, Toxicity

    Results are cached under
    ``<XDETOX_DIR>/data/model_outputs/<output_folder>/<data_type>/`` as
    ``orig.txt``, ``gen.txt`` and ``gen_stats.txt``.

    Args:
        data_type: Key into ``data_configs``.
        output_folder: Folder name under ``data/model_outputs``.
        batch_size: Accepted for API compatibility; not used by this function.
        max_length: Token limit passed to candidate generation.
        num_examples: Number of inputs to load (forwarded to ``load_test_data``).
        num_candidates: Candidates sampled per input.
        temperature, top_k, top_p: Sampling parameters forwarded to generation.
        overwrite_gen: Regenerate even when cached outputs exist.
        run_eval: Compute (or load) evaluation metrics.
        overwrite_eval: Recompute metrics even when a stats file exists.
        echo: Print example inputs, candidates, outputs and metrics.

    Returns:
        Dict of metrics when ``run_eval`` is True, otherwise ``None``.
    """
    assert data_type in data_configs, f"Unknown data_type: {data_type}"

    # Output paths
    base_out_dir = os.path.join(XDETOX_DIR, "data", "model_outputs", output_folder)
    data_out_dir = os.path.join(base_out_dir, data_type)
    _ensure_dir(data_out_dir)

    orig_path = os.path.join(data_out_dir, "orig.txt")
    gen_path = os.path.join(data_out_dir, "gen.txt")
    stats_path = os.path.join(data_out_dir, "gen_stats.txt")

    # Load data
    print(f"\n[{data_type}] Loading data...")
    orig_texts = load_test_data(data_type, num_examples)
    print(f"  Loaded {len(orig_texts)} examples")

    if echo:
        print(f"\n[echo] Example inputs (first 3):")
        for i, s in enumerate(orig_texts[:3]):
            print(f"  input[{i}]: {s}")

    # Generate or load. The cache is only usable when both files exist;
    # a lone gen.txt used to crash the loading branch.
    have_cached_outputs = os.path.exists(gen_path) and os.path.exists(orig_path)
    if overwrite_gen or not have_cached_outputs:
        print(f"  Generating {num_candidates} candidates per input...")
        all_candidates = t5_generate_candidates_batch(
            orig_texts, t5_model, t5_tokenizer, num_candidates,
            temperature, top_k, top_p, max_length, device
        )

        # Guarded so an empty dataset does not raise IndexError on [0].
        if echo and all_candidates:
            print(f"\n[echo] Example candidates for input[0]:")
            for j, c in enumerate(all_candidates[0][:3]):
                print(f"    candidate[{j}]: {c}")

        print(f"  Reranking with DecompX...")
        try:
            gen_texts = rerank_with_decompx(orig_texts, all_candidates)
        finally:
            # Free GPU memory even when reranking fails part-way.
            release_decompx()

        if echo:
            print(f"\n[echo] Selected outputs (first 3):")
            for i, g in enumerate(gen_texts[:3]):
                print(f"  output[{i}]: {g}")

        # Save outputs, one text per line with whitespace collapsed. UTF-8 is
        # explicit so non-ASCII text does not depend on the platform default.
        with open(orig_path, 'w', encoding='utf-8') as f:
            for t in orig_texts:
                f.write(re.sub(r"\s+", " ", t).strip() + '\n')
        with open(gen_path, 'w', encoding='utf-8') as f:
            for t in gen_texts:
                f.write(re.sub(r"\s+", " ", t).strip() + '\n')

        print(f"  Saved outputs to {data_out_dir}")
    else:
        print(f"  Loading existing outputs...")
        # The cached originals replace the freshly loaded ones so that both
        # lists come from the same generation run.
        with open(orig_path, 'r', encoding='utf-8') as f:
            orig_texts = [l.strip() for l in f]
        with open(gen_path, 'r', encoding='utf-8') as f:
            gen_texts = [l.strip() for l in f]
        print(f"  Loaded {len(gen_texts)} examples")
        if len(orig_texts) != len(gen_texts):
            print(
                f"  Warning: cached orig.txt has {len(orig_texts)} lines but gen.txt "
                f"has {len(gen_texts)}; rerun with overwrite_gen=True to rebuild them"
            )

    # Evaluate
    if run_eval and (overwrite_eval or not os.path.exists(stats_path)):
        print(f"  Running evaluation...")
        results = evaluate_all(orig_texts, gen_texts, device)

        with open(stats_path, 'w') as f:
            for k, v in results.items():
                f.write(f"{k}: {v}\n")

        if echo:
            print(f"\n[echo] Evaluation metrics:")
            for k, v in results.items():
                print(f"  {k}: {v:.4f}")

        print(f"  Saved stats to {stats_path}")
        return results

    elif run_eval:
        print(f"  Loading existing stats...")
        results = _read_stats_file(stats_path)
        return results

    return None

print("detoxify() function defined")

## Run Evaluation

In [None]:
#@title Example run on paradetox
# Full regeneration and re-evaluation on 1000 paradetox examples, with echo output.
paradetox_run_args = dict(
    data_type="paradetox",
    output_folder="T5_w_DecompX-Reranking",
    num_examples=1000,
    num_candidates=10,
    temperature=1.0,
    run_eval=True,
    overwrite_gen=True,
    overwrite_eval=True,
    echo=True,
)
detoxify(**paradetox_run_args)

In [None]:
#@title Run on multiple datasets
import traceback

datasets_to_eval = ["paradetox", "microagressions_test", "sbf_test", "dynabench_test"]
num_examples = 200
output_folder = "T5_w_DecompX-Reranking"

# Settings shared by every dataset in this sweep; cached outputs and stats are reused.
shared_run_args = dict(
    output_folder=output_folder,
    batch_size=8,
    max_length=128,
    num_examples=num_examples,
    num_candidates=10,
    temperature=1.0,
    overwrite_gen=False,
    run_eval=True,
    overwrite_eval=False,
    echo=False,
)

all_results = {}

banner = "=" * 80
print(banner)
print("T5-PARADETOX + DECOMPX RERANKING PIPELINE")
print(banner)

for dataset_name in datasets_to_eval:
    # A failure on one dataset is reported with its traceback and the sweep moves on.
    try:
        results = detoxify(data_type=dataset_name, **shared_run_args)

        if results:
            all_results[dataset_name] = results
            print(f"  {dataset_name} complete!")

    except Exception as e:
        print(f"  Error on {dataset_name}: {e}")
        traceback.print_exc()

print("\n" + banner)

## Results Summary

In [None]:
#@title Display results table

if not all_results:
    print("No results available.")
else:
    # One row per dataset: the dataset name followed by its metrics.
    rows = [
        {'dataset': evaluated_name, **metrics}
        for evaluated_name, metrics in all_results.items()
    ]

    col_order = [
        'dataset',
        'bertscore',
        'meaningbert',
        'bleu4',
        'perplexity_gen',
        'perplexity_orig',
        'toxicity_gen',
        'toxicity_orig',
    ]
    df = pd.DataFrame(rows)
    # Keep only the known columns that are present, in the preferred order.
    present_columns = [col for col in col_order if col in df.columns]
    df = df[present_columns]

    # Save to CSV
    summary_csv = os.path.join(XDETOX_DIR, "data", "model_outputs", output_folder, "t5_decompx_summary.csv")
    df.to_csv(summary_csv, index=False)
    print(f"Saved summary to {summary_csv}\n")

    # Display
    rule = "=" * 80
    print(rule)
    print("T5-PARADETOX + DECOMPX RERANKING RESULTS")
    print(rule)
    print(df.to_string(index=False))
    print(rule)