# T5-ParaDetox Pipeline with DecompX Reranking

This notebook combines:
- **T5-base** fine-tuned on ParaDetox for detoxification
- **DecompX reranking** to select the least toxic candidate from multiple generations

## Pipeline

1. Generate `num_candidates` detoxified texts per input using T5 sampling
2. Score each candidate using DecompX toxicity attribution (RoBERTa-based)
3. Select candidate with lowest toxicity score
4. Evaluate with BLEU, BERTScore, MeaningBERT, Perplexity, Toxicity

---

## `detoxify()` API

```python
def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "T5_w_DecompX-Reranking",
    batch_size: int = 8,
    max_length: int = 128,
    num_examples: int = 100,
    num_candidates: int = 10,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    overwrite_gen: bool = False,
    run_eval: bool = True,
    overwrite_eval: bool = False,
    echo: bool = False,
)
```

### Key Arguments

- `data_type`: Dataset key (paradetox, microagressions_test, sbf_test, dynabench_test, jigsaw_toxic, appdia_original, appdia_discourse)
- `output_folder`: Folder under `data/model_outputs/` for results
- `num_candidates`: Number of candidates to generate per input for reranking
- `temperature`: Sampling temperature for diversity (higher = more diverse)
- `echo`: If True, print example inputs, candidates, and outputs

## Setup

In [2]:
#@title Mount Drive, Imports & locate XDetox
from google.colab import drive; drive.mount('/content/drive')

import os, sys, torch

# Use the same layout as your working XDetox LLM pipeline
candidate = "/content/drive/MyDrive/w266 - Project/XDetox"
print("Try MyDrive:", candidate, "->", os.path.isdir(candidate))

XDETOX_DIR = candidate
print("Using XDETOX_DIR:", XDETOX_DIR)
assert os.path.isdir(XDETOX_DIR), f"XDETOX_DIR does not exist: {XDETOX_DIR}"

# Project root = parent of XDETOX_DIR
PROJECT_BASE = os.path.dirname(XDETOX_DIR)

# T5 checkpoint lives next to XDetox repo
T5_CHECKPOINT = os.path.join(PROJECT_BASE, "t5-base-detox-model")

print("PROJECT_BASE:", PROJECT_BASE)
print("XDETOX_DIR:", XDETOX_DIR)
print("T5_CHECKPOINT:", T5_CHECKPOINT)

# Runtime setup (paths, cache, GPU)
HF_CACHE = os.path.join(XDETOX_DIR, "cache")
os.makedirs(HF_CACHE, exist_ok=True)
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE
os.environ["WANDB_DISABLED"] = "true"

if XDETOX_DIR not in sys.path:
    sys.path.append(XDETOX_DIR)

print("TRANSFORMERS_CACHE:", HF_CACHE)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

# Optional: verify XDetox layout like your LLM pipeline
for d in ["rewrite", "evaluation", "datasets"]:
    assert os.path.isdir(os.path.join(XDETOX_DIR, d)), f"Missing folder: {d}"
print("Repo folders OK.")

# This is what your T5 data loader uses
DATASET_BASE = XDETOX_DIR


Mounted at /content/drive
Try MyDrive: /content/drive/MyDrive/w266 - Project/XDetox -> True
Using XDETOX_DIR: /content/drive/MyDrive/w266 - Project/XDetox
PROJECT_BASE: /content/drive/MyDrive/w266 - Project
XDETOX_DIR: /content/drive/MyDrive/w266 - Project/XDetox
T5_CHECKPOINT: /content/drive/MyDrive/w266 - Project/t5-base-detox-model
TRANSFORMERS_CACHE: /content/drive/MyDrive/w266 - Project/XDetox/cache
CUDA available: True
GPU: Tesla T4
Repo folders OK.


In [3]:
#@title Install dependencies
!pip install -q transformers torch datasets
!pip install -q evaluate sacrebleu bert-score
!pip install -q sentence-transformers accelerate -U
!pip install -q rouge_score pandas numpy scikit-learn matplotlib nltk

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m10.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone


In [4]:
#@title NLTK data
import nltk
nltk.download("punkt", quiet=True)
try:
    nltk.download("punkt_tab", quiet=True)
except Exception:
    pass
print("NLTK ready")

NLTK ready


In [6]:
#@title Imports
import glob, re, json, shutil, math
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from pathlib import Path
from typing import List, Tuple

from transformers import (
    T5Tokenizer, T5ForConditionalGeneration,
    AutoTokenizer, AutoModelForSequenceClassification,
    GPT2Tokenizer, GPT2LMHeadModel, GPT2TokenizerFast
)
from sentence_transformers import SentenceTransformer
from evaluate import load

# --- DecompX compatibility fixes for newer Transformers versions ---
import transformers.modeling_utils as modeling_utils

# 1) apply_chunking_to_forward moved to transformers.pytorch_utils
try:
    from transformers.modeling_utils import apply_chunking_to_forward
except ImportError:
    from transformers.pytorch_utils import apply_chunking_to_forward
    modeling_utils.apply_chunking_to_forward = apply_chunking_to_forward

# 2) find_pruneable_heads_and_indices moved out of modeling_utils
try:
    from transformers.modeling_utils import find_pruneable_heads_and_indices
except ImportError:
    # In recent versions it still exists in the BERT modeling file
    try:
        from transformers.models.bert.modeling_bert import find_pruneable_heads_and_indices
        modeling_utils.find_pruneable_heads_and_indices = find_pruneable_heads_and_indices
    except ImportError:
        # Fallback stub (should not normally be hit)
        def find_pruneable_heads_and_indices(*args, **kwargs):
            raise NotImplementedError("find_pruneable_heads_and_indices is not available in this Transformers version")
        modeling_utils.find_pruneable_heads_and_indices = find_pruneable_heads_and_indices

# 3) prune_linear_layer may also have moved
try:
    from transformers.modeling_utils import prune_linear_layer
except ImportError:
    try:
        from transformers.models.bert.modeling_bert import prune_linear_layer
        modeling_utils.prune_linear_layer = prune_linear_layer
    except ImportError:
        # Fallback stub
        def prune_linear_layer(*args, **kwargs):
            raise NotImplementedError("prune_linear_layer is not available in this Transformers version")
        modeling_utils.prune_linear_layer = prune_linear_layer

# DecompX (Masker) from XDetox — same alias as in your LLM pipeline
from rewrite.mask_orig import Masker as Masker_single

print("Libraries imported")


Libraries imported


## Dataset Configuration

In [7]:
#@title Data configs (matching XDetox datasets)
data_configs = {
    "paradetox": {
        "data_path": "./datasets/paradetox/test_toxic_parallel.txt",
        "format": "txt",
    },
    "microagressions_test": {
        "data_path": "./datasets/microagressions/test.csv",
        "format": "csv",
    },
    "sbf_test": {
        "data_path": "./datasets/sbf/sbftst.csv",
        "format": "csv",
    },
    "dynabench_test": {
        "data_path": "./datasets/dynabench/db_test.csv",
        "format": "csv",
    },
    "jigsaw_toxic": {
        "data_path": "./datasets/jigsaw_full_30/test_10k_toxic.txt",
        "format": "txt",
    },
    "appdia_original": {
        "data_path": "./datasets/appdia/original-annotated-data/original-test.tsv",
        "format": "tsv",
    },
    "appdia_discourse": {
        "data_path": "./datasets/appdia/discourse-augmented-data/discourse-test.tsv",
        "format": "tsv",
    },
}

print(f"{len(data_configs)} datasets configured:")
for name in data_configs.keys():
    print(f"  - {name}")

7 datasets configured:
  - paradetox
  - microagressions_test
  - sbf_test
  - dynabench_test
  - jigsaw_toxic
  - appdia_original
  - appdia_discourse


## Helper Functions

In [8]:
#@title Helper functions

def _ensure_dir(p: str):
    Path(p).mkdir(parents=True, exist_ok=True)

def load_test_data(data_type: str, num_examples: int = None) -> List[str]:
    """
    Load test data from various formats (.txt, .csv, .tsv).
    Returns a list of toxic texts as strings.
    """
    if data_type not in data_configs:
        raise ValueError(f"Unknown data_type: {data_type}")

    cfg = data_configs[data_type]
    data_path = os.path.join(DATASET_BASE, cfg["data_path"].lstrip("./"))

    texts = []

    if cfg["format"] == "txt":
        with open(data_path, 'r', encoding='utf-8') as f:
            texts = [line.strip() for line in f if line.strip()]

    elif cfg["format"] == "csv":
        df = pd.read_csv(data_path)
        if 'text' in df.columns:
            texts = df['text'].tolist()
        elif 'toxic' in df.columns:
            texts = df['toxic'].tolist()
        else:
            texts = df.iloc[:, 0].tolist()

    elif cfg["format"] == "tsv":
        df = pd.read_csv(data_path, sep='\t')
        if 'text' in df.columns:
            texts = df['text'].tolist()
        else:
            texts = df.iloc[:, 0].tolist()

    # Clean and convert to strings
    cleaned_texts = []
    for text in texts:
        if pd.isna(text):
            continue
        text_str = str(text).strip()
        if text_str:
            cleaned_texts.append(text_str)

    if num_examples and num_examples > 0:
        cleaned_texts = cleaned_texts[:num_examples]

    return cleaned_texts

def _safe_float(x):
    try:
        return float(x)
    except Exception:
        return float('nan')

def _read_stats_file(path: str) -> dict:
    """Read gen_stats.txt into dict."""
    out = {}
    with open(path, "r") as f:
        for line in f:
            if ":" not in line:
                continue
            k, v = line.strip().split(": ", 1)
            k = k.replace("(skipped)", "").strip().lower()
            out[k] = _safe_float(v)
    return out

print("Helper functions loaded")

Helper functions loaded


## T5 Model Loading

In [9]:
#@title Load T5 model
print(f"Loading T5 model from {T5_CHECKPOINT}...")

t5_tokenizer = T5Tokenizer.from_pretrained(T5_CHECKPOINT)
t5_model = T5ForConditionalGeneration.from_pretrained(T5_CHECKPOINT)
t5_model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
t5_model = t5_model.to(device)

print(f"T5 model loaded on {device}")




Loading T5 model from /content/drive/MyDrive/w266 - Project/t5-base-detox-model...
T5 model loaded on cuda


In [10]:
#@title T5 multi-candidate generation (with batching)

def t5_generate_candidates(
    text: str,
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    num_candidates: int,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    max_length: int = 128,
    device: str = "cuda",
) -> List[str]:
    """
    Generate num_candidates different outputs via sampling for a single input.
    """
    input_text = f"detoxify: {text}"
    input_ids = tokenizer.encode(
        input_text,
        return_tensors='pt',
        max_length=max_length,
        truncation=True,
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=max_length,
            num_return_sequences=num_candidates,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            no_repeat_ngram_size=2,
        )

    candidates = [tokenizer.decode(out, skip_special_tokens=True) for out in outputs]
    return candidates

def t5_generate_candidates_batch(
    texts: List[str],
    model: T5ForConditionalGeneration,
    tokenizer: T5Tokenizer,
    num_candidates: int,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    max_length: int = 128,
    batch_size: int = 8,
    device: str = "cuda",
) -> List[List[str]]:
    """
    Batch generation of candidates for multiple inputs.

    Uses a real batch dimension for speed:
      - Each batch has size B.
      - We ask for num_candidates sequences per input.
      - Outputs are reshaped into [N][num_candidates].
    """
    all_candidates: List[List[str]] = []
    for i in tqdm(range(0, len(texts), batch_size), desc="T5 Generation"):
        batch_texts = texts[i:i + batch_size]
        prompts = [f"detoxify: {t}" for t in batch_texts]

        enc = tokenizer(
            prompts,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding=True,
        ).to(device)

        with torch.no_grad():
            outputs = model.generate(
                **enc,
                max_length=max_length,
                num_return_sequences=num_candidates,
                do_sample=True,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                no_repeat_ngram_size=2,
            )

        # outputs shape: (B * num_candidates, seq_len)
        decoded = [
            tokenizer.decode(o, skip_special_tokens=True)
            for o in outputs
        ]

        B = len(batch_texts)
        for b in range(B):
            start = b * num_candidates
            end = (b + 1) * num_candidates
            all_candidates.append(decoded[start:end])

    return all_candidates

# Quick sanity test
test_text = "This is a stupid idea"
test_candidates = t5_generate_candidates(
    test_text, t5_model, t5_tokenizer, num_candidates=3, device=device
)
print(f"Input: {test_text}")
for i, c in enumerate(test_candidates):
    print(f"  [{i}]: {c}")

Input: This is a stupid idea
  [0]: This is a bad idea.
  [1]: This is a bad idea
  [2]: This is a bad idea.


## DecompX reranking (using Masker, mask-count-based score)

In [11]:
#@title DecompX reranking helpers

def _decompx_mask_texts(texts: List[str],
                        threshold: float = 0.20,
                        batch_size: int = 16) -> List[str]:
    if not texts:
        return []

    masker = Masker_single()
    masked_all = []
    for i in tqdm(range(0, len(texts), batch_size),
                  desc="DecompX masking for reranking", leave=False):
        batch = texts[i:i + batch_size]
        batch_out = masker.process_text(sentence=batch, threshold=threshold)
        masked_all.extend(batch_out)

    cleaned = [
        m.replace("<s>", "").replace("</s>", "").strip()
        for m in masked_all
    ]
    masker.release_model()
    return cleaned


def _decompx_toxicity_scores(
    texts: List[str],
    threshold: float = 0.20,
    batch_size: int = 16,
) -> np.ndarray:
    """
    Score texts by DecompX toxicity:

      score = (# of <mask> tokens DecompX inserts) / (# tokens)

    Lower score => less toxic.
    """
    if not texts:
        return np.zeros((0,), dtype=float)

    masked = _decompx_mask_texts(
        texts,
        threshold=threshold,
        batch_size=batch_size,
    )

    scores = []
    for m in masked:
        num_masks = len(re.findall(r"<mask>", m))
        tokens = m.split()
        length = max(len(tokens), 1)
        scores.append(num_masks / length)

    return np.asarray(scores, dtype=float)

def rerank_candidates_decompx(
    sources: List[str],
    candidates: List[List[str]],
    threshold: float = 0.20,
    batch_size_mask: int = 16,
) -> List[str]:
    """
    DecompX-based reranking:

    1. Flatten all candidates.
    2. Score each candidate with DecompX.
    3. For each source, choose the candidate with lowest score.

    Returns:
        best_outputs: List of chosen candidates (len = N)
    """
    N = len(sources)
    assert len(candidates) == N, "candidates length mismatch"

    if N == 0:
        return []

    C_list = [len(c) for c in candidates]
    if len(set(C_list)) != 1:
        raise ValueError("All inputs must have same num_candidates")
    C = C_list[0]
    if C == 0:
        raise ValueError("num_candidates must be >= 1")

    flat_cands: List[str] = []
    for cand_list in candidates:
        flat_cands.extend(cand_list)

    scores = _decompx_toxicity_scores(
        flat_cands,
        threshold=threshold,
        batch_size=batch_size_mask,
    )  # shape [N*C]

    scores_matrix = scores.reshape(N, C)
    best_idx = np.argmin(scores_matrix, axis=1)

    best_outputs = [
        candidates[i][best_idx[i]] for i in range(N)
    ]
    return best_outputs

print("DecompX reranking functions loaded")

DecompX reranking functions loaded


## Evaluation models and metrics

In [12]:
#@title Load evaluation models
print("Loading evaluation models...")

# Toxicity classifier (RoBERTa)
eval_tox_tokenizer = AutoTokenizer.from_pretrained("s-nlp/roberta_toxicity_classifier")
eval_tox_model = AutoModelForSequenceClassification.from_pretrained("s-nlp/roberta_toxicity_classifier")
eval_tox_model.eval()
eval_tox_model = eval_tox_model.to(device)

# Perplexity model (GPT-2 medium)
eval_ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
eval_ppl_model = GPT2LMHeadModel.from_pretrained("gpt2-medium")
eval_ppl_model.eval()
eval_ppl_model = eval_ppl_model.to(device)
if eval_ppl_tokenizer.pad_token is None:
    eval_ppl_tokenizer.pad_token = eval_ppl_tokenizer.eos_token

# Sentence embeddings for MeaningBERT-style similarity
sim_model = SentenceTransformer('all-MiniLM-L6-v2')

# Metrics
bleu_metric = load("sacrebleu")
bertscore_metric = load("bertscore")

print("Evaluation models loaded")

Loading evaluation models...


Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Evaluation models loaded


## Evaluation Functions

In [13]:
#@title Load evaluation models
print("Loading evaluation models...")

# Toxicity classifier
tox_tokenizer = AutoTokenizer.from_pretrained("s-nlp/roberta_toxicity_classifier")
tox_model = AutoModelForSequenceClassification.from_pretrained("s-nlp/roberta_toxicity_classifier")
tox_model.eval()
tox_model = tox_model.to(device)

# Perplexity model (GPT-2)
ppl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
ppl_model = GPT2LMHeadModel.from_pretrained("gpt2-medium")
ppl_model.eval()
ppl_model = ppl_model.to(device)
if ppl_tokenizer.pad_token is None:
    ppl_tokenizer.pad_token = ppl_tokenizer.eos_token

# Sentence embeddings for MeaningBERT
sim_model = SentenceTransformer('all-MiniLM-L6-v2')

# Metrics
bleu_metric = load("sacrebleu")
bertscore_metric = load("bertscore")

print("Evaluation models loaded")

Loading evaluation models...


Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

Evaluation models loaded


In [14]:
#@title Evaluation functions

def compute_toxicity(
    texts: List[str],
    tokenizer,
    model,
    device: str = "cuda",
    batch_size: int = 32,
) -> float:
    """
    Compute average toxicity score (probability of 'toxic' class).
    """
    all_scores: List[float] = []

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i + batch_size]
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
            toxic_probs = predictions[:, 1]  # label 1 = toxic
            all_scores.extend(toxic_probs.cpu().tolist())

    return float(np.mean(all_scores)) if all_scores else float("nan")

def compute_perplexity(
    texts: List[str],
    tokenizer,
    model,
    device: str = "cuda",
) -> float:
    """
    Compute average perplexity using GPT-2-medium.
    """
    perplexities: List[float] = []

    for text in texts:
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs, labels=inputs["input_ids"])
            loss = outputs.loss
            ppl = torch.exp(loss).item()
            perplexities.append(ppl)

    return float(np.mean(perplexities)) if perplexities else float("nan")

def compute_bertscore(predictions: List[str], references: List[str]) -> float:
    """
    Compute BERTScore F1.
    """
    result = bertscore_metric.compute(
        predictions=predictions,
        references=references,
        lang="en",
    )
    return float(np.mean(result['f1']))

def compute_bleu(predictions: List[str], references: List[str]) -> float:
    """
    Compute BLEU score (sacrebleu).
    """
    formatted_refs = [[ref] for ref in references]
    result = bleu_metric.compute(
        predictions=predictions,
        references=formatted_refs,
    )
    return float(result['score'])

def compute_meaningbert(predictions: List[str], references: List[str]) -> float:
    """
    Compute MeaningBERT-style score using sentence-transformers cosine similarity.
    Returns cosine similarity scaled to 0-100.
    """
    pred_embs = sim_model.encode(predictions, convert_to_tensor=True)
    ref_embs = sim_model.encode(references, convert_to_tensor=True)
    cosine_scores = torch.nn.functional.cosine_similarity(pred_embs, ref_embs)
    return float(cosine_scores.mean().item() * 100)

def evaluate_all(
    orig_texts: List[str],
    gen_texts: List[str],
    device: str = "cuda",
) -> dict:
    """
    Run all evaluations: toxicity, perplexity, BERTScore, MeaningBERT, BLEU.
    """
    results = {}

    print("  Computing toxicity scores...")
    results['toxicity_gen'] = compute_toxicity(gen_texts, eval_tox_tokenizer, eval_tox_model, device)
    results['toxicity_orig'] = compute_toxicity(orig_texts, eval_tox_tokenizer, eval_tox_model, device)

    print("  Computing perplexity...")
    results['perplexity_gen'] = compute_perplexity(gen_texts, eval_ppl_tokenizer, eval_ppl_model, device)
    results['perplexity_orig'] = compute_perplexity(orig_texts, eval_ppl_tokenizer, eval_ppl_model, device)

    print("  Computing BERTScore...")
    results['bertscore'] = compute_bertscore(gen_texts, orig_texts)

    print("  Computing MeaningBERT...")
    results['meaningbert'] = compute_meaningbert(gen_texts, orig_texts)

    print("  Computing BLEU...")
    results['bleu4'] = compute_bleu(gen_texts, orig_texts)

    return results

print("Evaluation functions defined")

Evaluation functions defined


## Main Pipeline Function

In [15]:
#@title detoxify() — T5 + DecompX Reranking pipeline

def detoxify(
    data_type: str = "paradetox",
    output_folder: str = "T5_w_DecompX-Reranking",
    batch_size: int = 8,
    max_length: int = 128,
    num_examples: int = 100,
    num_candidates: int = 10,
    temperature: float = 1.0,
    top_k: int = 50,
    top_p: float = 0.95,
    overwrite_gen: bool = False,
    run_eval: bool = True,
    overwrite_eval: bool = False,
    echo: bool = False,
    # DecompX reranking parameters
    decompx_threshold: float = 0.20,
    decompx_batch_size: int = 16,
):
    """
    T5-ParaDetox pipeline with DecompX reranking.

    Steps:
      1. Generate `num_candidates` detoxified texts per input using T5 sampling.
      2. Score each candidate with DecompX (mask-count-based toxicity).
      3. Select the candidate with lowest DecompX score.
      4. Evaluate with BLEU, BERTScore, MeaningBERT, Perplexity, Toxicity.
    """
    assert data_type in data_configs, f"Unknown data_type: {data_type}"

    # Output paths
    base_out_dir = os.path.join(XDETOX_DIR, "data", "model_outputs", output_folder)
    data_out_dir = os.path.join(base_out_dir, data_type)
    _ensure_dir(data_out_dir)

    orig_path = os.path.join(data_out_dir, "orig.txt")
    gen_path = os.path.join(data_out_dir, "gen.txt")
    stats_path = os.path.join(data_out_dir, "gen_stats.txt")

    # Load data
    print(f"\n[{data_type}] Loading data...")
    orig_texts = load_test_data(data_type, num_examples)
    print(f"  Loaded {len(orig_texts)} examples")

    if echo:
        print(f"\n[echo] Example inputs (first 3):")
        for i, s in enumerate(orig_texts[:3]):
            print(f"  input[{i}]: {s}")
        print(f"\n[echo] DecompX threshold: {decompx_threshold}")
        print(f"[echo] num_candidates per input: {num_candidates}")

    # Generate or load detoxified outputs
    if overwrite_gen or not os.path.exists(gen_path):
        print(f"  Generating {num_candidates} candidates per input with T5...")
        all_candidates = t5_generate_candidates_batch(
            orig_texts,
            t5_model,
            t5_tokenizer,
            num_candidates=num_candidates,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            max_length=max_length,
            batch_size=batch_size,
            device=device,
        )

        if echo:
            print(f"\n[echo] Example candidates for input[0]:")
            for j, c in enumerate(all_candidates[0][:min(3, len(all_candidates[0]))]):
                print(f"    candidate[{j}]: {c}")

        print(f"  Reranking with DecompX (threshold={decompx_threshold})...")
        gen_texts = rerank_candidates_decompx(
            orig_texts,
            all_candidates,
            threshold=decompx_threshold,
            batch_size_mask=decompx_batch_size,
        )

        if echo:
            print(f"\n[echo] Selected outputs (first 3):")
            for i, g in enumerate(gen_texts[:3]):
                print(f"  output[{i}]: {g}")

        # Save outputs
        with open(orig_path, 'w') as f:
            for t in orig_texts:
                f.write(re.sub(r"\s+", " ", t).strip() + '\n')
        with open(gen_path, 'w') as f:
            for t in gen_texts:
                f.write(re.sub(r"\s+", " ", t).strip() + '\n')

        print(f"  Saved outputs to {data_out_dir}")
    else:
        print(f"  Loading existing outputs...")
        with open(orig_path, 'r') as f:
            orig_texts = [l.strip() for l in f]
        with open(gen_path, 'r') as f:
            gen_texts = [l.strip() for l in f]
        print(f"  Loaded {len(gen_texts)} examples")

    # Evaluate
    if run_eval and (overwrite_eval or not os.path.exists(stats_path)):
        print(f"  Running evaluation...")
        results = evaluate_all(orig_texts, gen_texts, device)

        with open(stats_path, 'w') as f:
            for k, v in results.items():
                f.write(f"{k}: {v}\n")

        if echo:
            print(f"\n[echo] Evaluation metrics:")
            for k, v in results.items():
                print(f"  {k}: {v:.4f}" if isinstance(v, (int, float)) else f"  {k}: {v}")

        print(f"  Saved stats to {stats_path}")
        return results

    elif run_eval:
        print(f"  Loading existing stats...")
        results = _read_stats_file(stats_path)
        return results

    return None

print("detoxify() function defined")

detoxify() function defined


## Run Evaluation

In [16]:
#@title Example run on ParaDetox
detoxify(
    data_type="paradetox",
    output_folder="T5_w_DecompX-Reranking",
    batch_size=8,
    max_length=128,
    num_examples=1000,
    num_candidates=10,
    temperature=1.0,
    top_k=50,
    top_p=0.95,
    overwrite_gen=True,
    run_eval=True,
    overwrite_eval=True,
    echo=True,
    decompx_threshold=0.20,
    decompx_batch_size=16,
    )


[paradetox] Loading data...
  Loaded 671 examples

[echo] Example inputs (first 3):
  input[0]: . or the loud ass one - thousand ton beast roaring towards you howling its horn .
  input[1]: mandated  and " right fucking now " would be good .
  input[2]: * neither * of my coworkers gave a shit when it came time to ditch mitch . ugh .

[echo] DecompX threshold: 0.2
[echo] num_candidates per input: 10
  Generating 10 candidates per input with T5...


T5 Generation:   0%|          | 0/84 [00:00<?, ?it/s]


[echo] Example candidates for input[0]:
    candidate[0]: or the one - thousand ton beast roaring towards you howling its horn
    candidate[1]: Or a loud one - thousand ton beast roaring towards you howling its horn.
    candidate[2]: or the one thousand ton beast roaring towards you howling its horn .
  Reranking with DecompX (threshold=0.2)...


Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


DecompX masking for reranking:   0%|          | 0/420 [00:00<?, ?it/s]




[echo] Selected outputs (first 3):
  output[0]: . or the loud one - thousand ton beast roaring toward you howling its horn
  output[1]: Mandated and " right now " would be good .
  output[2]: * neither * of my coworkers cared when it came time to ditch mitch.
  Saved outputs to /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/T5_w_DecompX-Reranking/paradetox
  Running evaluation...
  Computing toxicity scores...


`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


  Computing perplexity...
  Computing BERTScore...


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  Computing MeaningBERT...




  Computing BLEU...

[echo] Evaluation metrics:
  toxicity_gen: 0.1689
  toxicity_orig: 0.9771
  perplexity_gen: 323.4124
  perplexity_orig: 354.8185
  bertscore: 0.9475
  meaningbert: 83.7222
  bleu4: 58.0672
  Saved stats to /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/T5_w_DecompX-Reranking/paradetox/gen_stats.txt


{'toxicity_gen': 0.16889041005790892,
 'toxicity_orig': 0.9771397517339622,
 'perplexity_gen': 323.4124468762189,
 'perplexity_orig': 354.81846606642586,
 'bertscore': 0.9475076117742968,
 'meaningbert': 83.72224569320679,
 'bleu4': 58.06722042930158}

In [17]:
#@title Run on multiple datasets

datasets_to_eval = ["paradetox", "microagressions_test", "sbf_test", "dynabench_test"]
num_examples = 200
output_folder = "T5_w_DecompX-Reranking"

all_results = {}

print("=" * 80)
print("T5-PARADETOX + DECOMPX RERANKING PIPELINE")
print("=" * 80)

for dataset_name in datasets_to_eval:
    try:
        results = detoxify(
            data_type=dataset_name,
            output_folder=output_folder,
            batch_size=8,
            max_length=128,
            num_examples=num_examples,
            num_candidates=10,
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            overwrite_gen=False,
            run_eval=True,
            overwrite_eval=False,
            echo=False,
            decompx_threshold=0.20,
            decompx_batch_size=16,
        )
        if results:
            all_results[dataset_name] = results
            print(f"  {dataset_name} complete!")
    except Exception as e:
        print(f"  Error on {dataset_name}: {e}")
        import traceback
        traceback.print_exc()
        continue

print("\n" + "=" * 80)

T5-PARADETOX + DECOMPX RERANKING PIPELINE

[paradetox] Loading data...
  Loaded 200 examples
  Loading existing outputs...
  Loaded 671 examples
  Loading existing stats...
  paradetox complete!

[microagressions_test] Loading data...
  Loaded 200 examples
  Generating 10 candidates per input with T5...


T5 Generation:   0%|          | 0/25 [00:00<?, ?it/s]

  Reranking with DecompX (threshold=0.2)...


Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


DecompX masking for reranking:   0%|          | 0/125 [00:00<?, ?it/s]



  Error on microagressions_test: CUDA out of memory. Tried to allocate 1.45 GiB. GPU 0 has a total capacity of 14.74 GiB of which 1.01 GiB is free. Process 2808 has 13.72 GiB memory in use. Of the allocated memory 13.38 GiB is allocated by PyTorch, and 228.25 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

[sbf_test] Loading data...


Traceback (most recent call last):
  File "/tmp/ipython-input-1125702773.py", line 15, in <cell line: 0>
    results = detoxify(
              ^^^^^^^^^
  File "/tmp/ipython-input-1125045290.py", line 75, in detoxify
    gen_texts = rerank_candidates_decompx(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-1593715700.py", line 88, in rerank_candidates_decompx
    scores = _decompx_toxicity_scores(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-1593715700.py", line 40, in _decompx_toxicity_scores
    masked = _decompx_mask_texts(
             ^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipython-input-1593715700.py", line 14, in _decompx_mask_texts
    batch_out = masker.process_text(sentence=batch, threshold=threshold)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/content/drive/MyDrive/w266 - Project/XDetox/rewrite/mask_orig.py", line 192, in process_text
    df = self.classify_sentence(sentence)
         ^^^^^^^^^^^^^^^^

  Loaded 200 examples
  Generating 10 candidates per input with T5...


T5 Generation:   0%|          | 0/25 [00:00<?, ?it/s]

  Reranking with DecompX (threshold=0.2)...


Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


DecompX masking for reranking:   0%|          | 0/125 [00:00<?, ?it/s]



  Saved outputs to /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/T5_w_DecompX-Reranking/sbf_test
  Running evaluation...
  Computing toxicity scores...
  Computing perplexity...
  Computing BERTScore...
  Computing MeaningBERT...
  Computing BLEU...
  Saved stats to /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/T5_w_DecompX-Reranking/sbf_test/gen_stats.txt
  sbf_test complete!

[dynabench_test] Loading data...
  Loaded 200 examples
  Generating 10 candidates per input with T5...


T5 Generation:   0%|          | 0/25 [00:00<?, ?it/s]

  Reranking with DecompX (threshold=0.2)...


Some weights of the model checkpoint at s-nlp/roberta_toxicity_classifier were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


DecompX masking for reranking:   0%|          | 0/125 [00:00<?, ?it/s]



  Saved outputs to /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/T5_w_DecompX-Reranking/dynabench_test
  Running evaluation...
  Computing toxicity scores...
  Computing perplexity...
  Computing BERTScore...
  Computing MeaningBERT...
  Computing BLEU...
  Saved stats to /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/T5_w_DecompX-Reranking/dynabench_test/gen_stats.txt
  dynabench_test complete!



## Results Summary

In [18]:
#@title Results summary table

if all_results:
    rows = []
    for dataset_name, results in all_results.items():
        row = {'dataset': dataset_name}
        row.update(results)
        rows.append(row)

    df = pd.DataFrame(rows)

    col_order = [
        'dataset',
        'bertscore',
        'meaningbert',
        'bleu4',
        'perplexity_gen',
        'perplexity_orig',
        'toxicity_gen',
        'toxicity_orig',
    ]
    df = df[[col for col in col_order if col in df.columns]]

    summary_csv = os.path.join(
        XDETOX_DIR, "data", "model_outputs", output_folder, "t5_decompx_summary.csv"
    )
    _ensure_dir(os.path.dirname(summary_csv))
    df.to_csv(summary_csv, index=False)
    print(f"Saved summary to {summary_csv}\n")

    print("=" * 80)
    print("T5-PARADETOX + DECOMPX RERANKING RESULTS")
    print("=" * 80)
    print(df.to_string(index=False))
    print("=" * 80)
else:
    print("No results available.")

Saved summary to /content/drive/MyDrive/w266 - Project/XDetox/data/model_outputs/T5_w_DecompX-Reranking/t5_decompx_summary.csv

T5-PARADETOX + DECOMPX RERANKING RESULTS
       dataset  bertscore  meaningbert     bleu4  perplexity_gen  perplexity_orig  toxicity_gen  toxicity_orig
     paradetox   0.947508    83.722246 58.067220      323.412447       354.818466      0.168890       0.977140
      sbf_test   0.958866    89.372975  1.210058             NaN              NaN      0.000046       0.000045
dynabench_test   0.967334    90.359968 74.822102      200.258896       312.836657      0.373739       0.499826
