# T5-XDetox Pipeline (DecompX Masking + T5 Ensemble + DecompX Reranking)

This notebook runs the T5-XDetox pipeline with:

1. **DecompX masking** – token-level toxicity attribution on a toxicity classifier (the helper code loads `martin-ha/toxic-comment-model`) to decide which tokens to mask with `<extra_id_0>`.
2. **T5 Ensemble generation** – T5 base + expert (non-toxic) + anti-expert (toxic) with logits combination.
3. **Optional DecompX-based reranking** – generate several candidates and pick the **least toxic** one.
4. **Evaluation** – BLEU, BERTScore, perplexity, and toxicity, plus a summary CSV per dataset.

---

## What this pipeline does

For each chosen dataset:

1. **Masking with DecompX**: Use the toxicity classifier with gradient attribution to get **per-token toxicity importance**. Tokens whose mean-normalized importance exceeds the masking threshold are **replaced with `<extra_id_0>`**.

2. **Generation with T5 Ensemble**: For each masked input, use an ensemble of three T5 models:
   - **Base** model (your trained t5-base-detox-model)
   - **Expert** model (trained on non-toxic text)
   - **Anti-expert** model (trained on toxic text)
   
   During generation, logits are combined:
   $$\text{logits}_{\text{ens}} = \alpha_b \cdot \text{logits}_{\text{base}} + \alpha_e \cdot \text{logits}_{\text{expert}} - \alpha_a \cdot \text{logits}_{\text{anti}}$$

3. **DecompX-based reranking**: When `ranking=True`, generate `num_candidates` candidates and keep the one with the **lowest summed toxicity importance**.

4. **Evaluation**: Compute BLEU, BERTScore, perplexity, toxicity for each threshold.
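
The logits combination in step 2 can be illustrated on a toy vocabulary. This is only a sketch in plain Python: the three-token vocabulary and all logit values are made up for illustration, and `combine_logits` and `softmax` are hypothetical helper names, not functions from this notebook.

```python
import math

def combine_logits(base, expert, anti, alpha_b=1.0, alpha_e=2.5, alpha_a=0.5, temperature=1.0):
    """Apply alpha_b*base + alpha_e*expert - alpha_a*anti per token, then divide by temperature."""
    return [(alpha_b * b + alpha_e * e - alpha_a * a) / temperature
            for b, e, a in zip(base, expert, anti)]

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [x / total for x in exps]

# Made-up logits for a 3-token vocabulary: index 0 = neutral word, 1 = insult, 2 = filler
base = [1.0, 1.0, 0.0]     # base model is undecided between neutral word and insult
expert = [1.0, -1.0, 0.0]  # expert prefers the neutral word
anti = [0.0, 2.0, 0.0]     # anti-expert prefers the insult

ensemble = combine_logits(base, expert, anti)
base_probs = softmax(base)
ensemble_probs = softmax(ensemble)

print("ensemble logits:", ensemble)
print("P(insult) base vs ensemble:", base_probs[1], ensemble_probs[1])
```

The expert term raises tokens the non-toxic model likes, and the subtracted anti-expert term lowers tokens the toxic model likes, so the insult loses probability mass relative to the base model alone.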

---

## Setup Requirements

- **Jigsaw dataset**: Download from [Kaggle](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge/data)
- **T5 base model**: Pre-trained at `{PROJECT_BASE}/t5-base-detox-model`
- **Expert/Anti-expert**: Train using cells below (2-3 hours each on A100)

In [1]:
#@title Mount Drive & Setup Paths
from google.colab import drive
drive.mount('/content/drive')

import os
import sys

# Base paths (matching T5-ParaDetox)
PROJECT_BASE = "/content/drive/MyDrive/ds266/w266 - Project"
HF_CACHE = os.path.join(PROJECT_BASE, "cache")
T5_BASE_CHECKPOINT = os.path.join(PROJECT_BASE, "t5-base-detox-model")
DATASET_BASE = os.path.join(PROJECT_BASE, "XDetox")

# Expert/Anti-expert paths
T5_EXPERT_CHECKPOINT = os.path.join(PROJECT_BASE, "t5-expert-nontoxic")
T5_ANTIEXPERT_CHECKPOINT = os.path.join(PROJECT_BASE, "t5-antiexpert-toxic")

os.environ["TRANSFORMERS_CACHE"] = HF_CACHE
os.environ["WANDB_DISABLED"] = "true"

print(f"PROJECT_BASE: {PROJECT_BASE}")
print(f"T5_BASE_CHECKPOINT: {T5_BASE_CHECKPOINT}")
print(f"DATASET_BASE: {DATASET_BASE}")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
PROJECT_BASE: /content/drive/MyDrive/ds266/w266 - Project
T5_BASE_CHECKPOINT: /content/drive/MyDrive/ds266/w266 - Project/t5-base-detox-model
DATASET_BASE: /content/drive/MyDrive/ds266/w266 - Project/XDetox


In [2]:
#@title Runtime setup (GPU check)
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ No GPU - training will be very slow!")

Device: cuda
GPU: NVIDIA L4
GPU Memory: 23.80 GB


In [3]:
#@title Install dependencies - RESTART RUNTIME AFTER THIS COMPLETES!

# Step 1: Completely uninstall transformers and related packages
!pip uninstall -y transformers tokenizers datasets evaluate

# Step 2: Install tokenizers SEPARATELY with --only-binary flag (to avoid build errors)
!pip install -q --no-cache-dir --only-binary=tokenizers "tokenizers==0.19.1"

# Step 3: Install core transformers packages (WITHOUT --only-binary)
!pip install -q --no-cache-dir \
    "transformers==4.41.2" "datasets==2.19.0" "evaluate==0.4.1" \
    "sacrebleu==2.4.1" sacremoses ftfy nltk matplotlib pandas

# Step 4: Install other dependencies (these won't upgrade transformers)
!pip install -q bert-score sentence-transformers accelerate scikit-learn

print("="*60)
print("✓ Package installation complete!")
print("="*60)
print("\n⚠️ CRITICAL NEXT STEPS:")
print("1. Click 'Runtime' -> 'Restart runtime' NOW!")
print("2. After restart, run the NEXT cell to verify versions")
print("3. Then continue from cell 5")

Found existing installation: transformers 4.41.2
Uninstalling transformers-4.41.2:
  Successfully uninstalled transformers-4.41.2
Found existing installation: tokenizers 0.19.1
Uninstalling tokenizers-0.19.1:
  Successfully uninstalled tokenizers-0.19.1
Found existing installation: datasets 2.19.0
Uninstalling datasets-2.19.0:
  Successfully uninstalled datasets-2.19.0
Found existing installation: evaluate 0.4.1
Uninstalling evaluate-0.4.1:
  Successfully uninstalled evaluate-0.4.1
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchtune 0.6.1 requires datasets, which is not installed.

In [4]:
#@title Verify installed versions (run AFTER restart)
import transformers, tokenizers, datasets, evaluate

print("="*60)
print("INSTALLED VERSIONS:")
print(f"✓ transformers: {transformers.__version__}")
print(f"✓ tokenizers: {tokenizers.__version__}")
print(f"✓ datasets: {datasets.__version__}")
print(f"✓ evaluate: {evaluate.__version__}")
print("="*60)

# Check if correct versions
if transformers.__version__ == "4.41.2":
    print("\n✅ Correct transformers version installed!")
else:
    print(f"\n⚠️ WARNING: Expected transformers 4.41.2, got {transformers.__version__}")



INSTALLED VERSIONS:
✓ transformers: 4.41.2
✓ tokenizers: 0.19.1
✓ datasets: 2.19.0
✓ evaluate: 0.4.1

✅ Correct transformers version installed!


In [5]:
#@title NLTK data
import nltk
nltk.download("punkt", quiet=True)
try:
    nltk.download("punkt_tab", quiet=True)
except:
    pass
print("NLTK ready")

NLTK ready


In [6]:
#@title Import libraries
import os
import sys
import json
import torch
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Core ML libraries
from transformers import (
    T5Tokenizer, T5ForConditionalGeneration,
    AutoTokenizer, AutoModelForSequenceClassification,
    GPT2Tokenizer, GPT2LMHeadModel,
    Trainer, TrainingArguments, DataCollatorForSeq2Seq
)
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from evaluate import load
import torch.nn.functional as F

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print("✓ Libraries imported")

RuntimeError: Failed to import transformers.trainer because of the following error (look up to see its traceback):
cannot import name 'EncoderDecoderCache' from 'transformers' (/usr/local/lib/python3.12/dist-packages/transformers/__init__.py)

In [None]:
#@title Data configs (matching XDetox datasets)
data_configs = {
    "paradetox": {
        "data_path": os.path.join(DATASET_BASE, "datasets/paradetox/test_toxic_parallel.txt"),
        "alpha_a": 0.5,
        "alpha_e": 2.5,
        "temperature": 1.0,
    },
    "microagressions_test": {
        "data_path": os.path.join(DATASET_BASE, "datasets/microagressions/test.csv"),
        "alpha_a": 0.5,
        "alpha_e": 2.75,
        "temperature": 1.0,
    },
    "sbf_test": {
        "data_path": os.path.join(DATASET_BASE, "datasets/sbf/sbftst.csv"),
        "alpha_a": 0.5,
        "alpha_e": 3.0,
        "temperature": 1.0,
    },
    "dynabench_test": {
        "data_path": os.path.join(DATASET_BASE, "datasets/dynabench/db_test.csv"),
        "alpha_a": 0.5,
        "alpha_e": 2.75,
        "temperature": 1.0,
    },
    "jigsaw_toxic": {
        "data_path": os.path.join(DATASET_BASE, "datasets/jigsaw_full_30/test_10k_toxic.txt"),
        "alpha_a": 0.5,
        "alpha_e": 2.75,
        "temperature": 1.0,
    }
}

print("Datasets:", ", ".join(data_configs.keys()))

## Helper Functions

Core functions for T5-XDetox pipeline:
- Model loading
- DecompX masking
- T5 ensemble generation
- DecompX reranking
- Evaluation
- Data loading
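
As a rough illustration of the masking rule used by the helpers, the sketch below divides per-token importance by the mean importance and masks every token whose normalized score exceeds the threshold. The scores are made-up numbers, and `select_mask_indices` and `apply_mask` are hypothetical names used only for this example.

```python
def select_mask_indices(importance, threshold=0.20):
    """Return one boolean per token: True where importance / mean(importance) > threshold."""
    mean = sum(importance) / len(importance) if importance else 0.0
    if mean <= 0:
        return [False] * len(importance)
    return [(score / (mean + 1e-8)) > threshold for score in importance]

def apply_mask(tokens, mask, mask_token="<extra_id_0>"):
    """Replace the selected tokens with the T5 sentinel token."""
    return " ".join(mask_token if m else t for t, m in zip(tokens, mask))

tokens = ["you", "are", "an", "idiot"]
importance = [0.1, 0.1, 0.1, 3.7]  # made-up attribution scores, mean is about 1.0

masked_default = apply_mask(tokens, select_mask_indices(importance, threshold=0.20))
masked_low = apply_mask(tokens, select_mask_indices(importance, threshold=0.05))

print(masked_default)
print(masked_low)
```

Because the threshold is relative to the mean, a lower threshold masks more tokens. At 0.05 every token in this toy example is masked, which is worth keeping in mind when sweeping thresholds.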

In [None]:
#@title Helper Functions (DecompX masking, T5 ensemble, evaluation)

# ============================================================================
# MODEL LOADING
# ============================================================================
def load_models():
    """Load all required models for T5-XDetox pipeline"""
    global t5_base, t5_expert, t5_antiexpert, t5_tokenizer
    global decompx_model, decompx_tokenizer
    global toxicity_model, toxicity_tokenizer
    global gpt2_model, gpt2_tokenizer
    global bleu_metric, bertscore_metric

    print("="*80)
    print("LOADING MODELS")
    print("="*80)

    # T5 models
    print("Loading T5 base...")
    t5_tokenizer = T5Tokenizer.from_pretrained(T5_BASE_CHECKPOINT)
    t5_base = T5ForConditionalGeneration.from_pretrained(T5_BASE_CHECKPOINT).to(device).eval()

    # Expert/Anti-expert
    if os.path.exists(T5_EXPERT_CHECKPOINT):
        print("Loading T5 expert...")
        t5_expert = T5ForConditionalGeneration.from_pretrained(T5_EXPERT_CHECKPOINT).to(device).eval()
    else:
        print("⚠️ T5 expert not found - using base model")
        t5_expert = t5_base

    if os.path.exists(T5_ANTIEXPERT_CHECKPOINT):
        print("Loading T5 anti-expert...")
        t5_antiexpert = T5ForConditionalGeneration.from_pretrained(T5_ANTIEXPERT_CHECKPOINT).to(device).eval()
    else:
        print("⚠️ T5 anti-expert not found - using base model")
        t5_antiexpert = t5_base

    # DecompX classifier
    print("Loading DecompX classifier...")
    decompx_tokenizer = AutoTokenizer.from_pretrained("martin-ha/toxic-comment-model")
    decompx_model = AutoModelForSequenceClassification.from_pretrained("martin-ha/toxic-comment-model").to(device).eval()

    # Toxicity classifier for evaluation
    print("Loading toxicity classifier...")
    toxicity_tokenizer = AutoTokenizer.from_pretrained("unitary/toxic-bert")
    toxicity_model = AutoModelForSequenceClassification.from_pretrained("unitary/toxic-bert").to(device).eval()

    # GPT-2 for perplexity
    print("Loading GPT-2 for perplexity...")
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device).eval()
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

    # Evaluation metrics
    bleu_metric = load('bleu')
    bertscore_metric = load('bertscore')

    print("="*80)
    print("✓ ALL MODELS LOADED")
    print("="*80)

# ============================================================================
# DECOMPX MASKING
# ============================================================================
def get_token_toxicity_scores(text):
    """Get per-token toxicity scores using gradient attribution"""
    inputs = decompx_tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Detach so the embeddings become a leaf tensor; .grad is only populated on leaves
    embeddings = decompx_model.get_input_embeddings()(inputs['input_ids']).detach()
    embeddings.requires_grad_(True)

    outputs = decompx_model(inputs_embeds=embeddings, attention_mask=inputs['attention_mask'])
    probabilities = F.softmax(outputs.logits, dim=-1)
    toxicity_prob = probabilities[0, 1]  # assumes label index 1 is the toxic class

    toxicity_prob.backward()
    # Sum absolute gradients over the hidden dimension: one score per token
    token_importance = embeddings.grad.abs().sum(dim=-1)[0].cpu().numpy()
    tokens = decompx_tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())

    return tokens, token_importance, toxicity_prob.item()

def mask_toxic_tokens(text, threshold=0.20):
    """Mask toxic tokens with <extra_id_0> for T5"""
    try:
        tokens, importance, _ = get_token_toxicity_scores(text)

        # Normalize by the mean so the threshold is relative to the average token
        if len(importance) > 0 and np.mean(importance) > 0:
            normalized = importance / (np.mean(importance) + 1e-8)
        else:
            normalized = importance

        mask_indices = normalized > threshold
        boundary_tokens = {'<s>', '</s>', '<pad>', '[CLS]', '[SEP]', '[PAD]'}
        masked_tokens = []
        num_masked = 0

        for token, should_mask in zip(tokens, mask_indices):
            if token in boundary_tokens:
                continue  # drop classifier boundary tokens so they never reach T5
            elif should_mask:
                masked_tokens.append('<extra_id_0>')
                num_masked += 1
            else:
                masked_tokens.append(token)

        masked_text = decompx_tokenizer.convert_tokens_to_string(masked_tokens).strip()
        return masked_text, num_masked
    except Exception as e:
        # Fall back to the unmasked text, but report the failure instead of hiding it
        print(f"⚠️ Masking failed, returning input unchanged: {e}")
        return text, 0

# ============================================================================
# T5 ENSEMBLE GENERATION
# ============================================================================
def generate_with_ensemble(masked_text, alpha_b=1.0, alpha_e=2.5, alpha_a=0.5,
                           max_length=128, temperature=1.0, top_p=0.95, top_k=50):
    """Generate using T5 ensemble with TRUE logits combination"""
    input_text = f"detoxify: {masked_text}"
    enc = t5_tokenizer(input_text, return_tensors='pt', truncation=True, max_length=512).to(device)
    # T5 is an encoder-decoder: the encoder input stays fixed, and the decoder
    # sequence starts from decoder_start_token_id and grows one token per step.
    generated_ids = torch.full((1, 1), t5_base.config.decoder_start_token_id,
                               dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_length):
            # Get logits from all three models
            step_inputs = dict(input_ids=enc['input_ids'],
                               attention_mask=enc['attention_mask'],
                               decoder_input_ids=generated_ids)
            base_out = t5_base(**step_inputs)
            expert_out = t5_expert(**step_inputs)
            anti_out = t5_antiexpert(**step_inputs)

            # Ensemble combination
            ensemble_logits = (
                alpha_b * base_out.logits[:, -1, :] +
                alpha_e * expert_out.logits[:, -1, :] -
                alpha_a * anti_out.logits[:, -1, :]
            ) / temperature

            # Top-k filtering
            if top_k > 0:
                k = min(top_k, ensemble_logits.size(-1))
                indices_to_remove = ensemble_logits < torch.topk(ensemble_logits, k)[0][..., -1, None]
                ensemble_logits[indices_to_remove] = float('-inf')

            # Top-p filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(ensemble_logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token that crosses top_p is still kept
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                ensemble_logits[indices_to_remove] = float('-inf')

            # Sample next token
            probs = F.softmax(ensemble_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token], dim=1)

            if next_token.item() == t5_tokenizer.eos_token_id:
                break

    # Only decoder tokens are decoded; strip the prompt prefix in case the model echoes it
    text = t5_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return text[9:].strip() if text.startswith('detoxify:') else text.strip()

# ============================================================================
# DECOMPX RERANKING
# ============================================================================
def get_toxicity_importance(text):
    """Get summed token-level toxicity importance for reranking"""
    try:
        _, importance, _ = get_token_toxicity_scores(text)
        return np.sum(np.abs(importance))
    except Exception:
        return float('inf')

def rerank_candidates(candidates):
    """Rerank candidates, return best (lowest toxicity)"""
    if len(candidates) == 1:
        return candidates[0], [0.0]

    scores = [get_toxicity_importance(c) for c in candidates]
    best_idx = np.argmin(scores)
    return candidates[best_idx], scores

# ============================================================================
# EVALUATION
# ============================================================================
def evaluate_toxicity(texts):
    """Evaluate average toxicity"""
    scores = []
    for text in tqdm(texts, desc="Toxicity", leave=False):
        inputs = toxicity_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = toxicity_model(**inputs)
            # unitary/toxic-bert is a multi-label classifier: use a per-label sigmoid
            # and read index 0 (the "toxic" label) instead of a softmax over labels
            probs = torch.sigmoid(outputs.logits)
            scores.append(probs[0, 0].item())
    return float(np.mean(scores)) if scores else float('nan')

def evaluate_perplexity(texts):
    """Evaluate average perplexity"""
    total_loss, total_tokens = 0.0, 0
    for text in tqdm(texts, desc="Perplexity", leave=False):
        inputs = gpt2_tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # The LM loss is averaged over n - 1 predicted tokens (labels are shifted by one)
        num_predicted = inputs['input_ids'].shape[1] - 1
        if num_predicted < 1:
            continue  # empty or single-token text has no next-token loss
        with torch.no_grad():
            outputs = gpt2_model(**inputs, labels=inputs['input_ids'])
            total_loss += outputs.loss.item() * num_predicted
            total_tokens += num_predicted
    return np.exp(total_loss / total_tokens) if total_tokens > 0 else float('inf')

def evaluate_all(orig_texts, gen_texts):
    """Run all evaluations"""
    print("  Evaluating...")
    results = {}
    results['toxicity_orig'] = evaluate_toxicity(orig_texts)
    results['toxicity_gen'] = evaluate_toxicity(gen_texts)
    results['perplexity_orig'] = evaluate_perplexity(orig_texts)
    results['perplexity_gen'] = evaluate_perplexity(gen_texts)
    results['bleu4'] = bleu_metric.compute(predictions=gen_texts, references=[[t] for t in orig_texts])['bleu']
    results['bertscore'] = np.mean(bertscore_metric.compute(predictions=gen_texts, references=orig_texts, lang='en')['f1'])
    return results

# ============================================================================
# DATA LOADING
# ============================================================================
def load_dataset(data_type, num_examples=None):
    """Load dataset from file"""
    if data_type not in data_configs:
        raise ValueError(f"Unknown data_type: {data_type}")

    data_path = data_configs[data_type]["data_path"]

    if not os.path.exists(data_path):
        print(f"⚠️ Dataset not found: {data_path}")
        return ["You are such an idiot.", "This is terrible."] * 50

    if data_path.endswith('.txt'):
        with open(data_path, 'r') as f:
            texts = [line.strip() for line in f if line.strip()]
    elif data_path.endswith('.csv'):
        df = pd.read_csv(data_path)
        col = 'text' if 'text' in df.columns else df.columns[0]
        texts = df[col].tolist()
    else:
        df = pd.read_csv(data_path, sep='\t')
        col = 'text' if 'text' in df.columns else df.columns[0]
        texts = df[col].tolist()

    texts = [str(t).strip() for t in texts if pd.notna(t)]
    return texts[:num_examples] if num_examples else texts

print("✓ Helper functions loaded")

## Training T5 Expert/Anti-Expert Models (Optional)

Run these cells to train expert and anti-expert models on Jigsaw data.

**Requirements:**
- Download `train.csv` from [Kaggle Jigsaw](https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge/data)
- Upload to: `{DATASET_BASE}/jigsaw/train.csv`

**Training time:** ~2-3 hours per model on A100 GPU

Set `SKIP_TRAINING = False` below to run training.
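
The toxic/non-toxic split used for training sums the six binary Jigsaw label columns and compares the sum with `toxic_threshold`. A minimal sketch of that rule follows; the rows are made-up examples, and `is_toxic` is a hypothetical helper name used only here.

```python
LABEL_COLUMNS = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

def is_toxic(row, toxic_threshold=0.5):
    """A comment counts as toxic when the sum of its binary labels reaches the threshold."""
    return sum(row[col] for col in LABEL_COLUMNS) >= toxic_threshold

# Made-up rows in the Jigsaw label layout (0/1 per label)
clean = {"toxic": 0, "severe_toxic": 0, "obscene": 0, "threat": 0, "insult": 0, "identity_hate": 0}
insult_only = {"toxic": 0, "severe_toxic": 0, "obscene": 0, "threat": 0, "insult": 1, "identity_hate": 0}
multi = {"toxic": 1, "severe_toxic": 0, "obscene": 1, "threat": 0, "insult": 1, "identity_hate": 0}

print(is_toxic(clean), is_toxic(insult_only), is_toxic(multi))
print(is_toxic(insult_only, toxic_threshold=2))
```

With the default of 0.5, any single positive label is enough to place a comment in the toxic split. Raising the threshold to 2 or more would keep only comments that carry several labels.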

In [None]:
#@title Jigsaw Data Preparation & Training Code
SKIP_TRAINING = True  # Set to False to actually train

# Jigsaw data paths
JIGSAW_CSV = os.path.join(DATASET_BASE, "jigsaw/train.csv")
JIGSAW_SPLITS_DIR = os.path.join(DATASET_BASE, "jigsaw_splits")

# ============================================================================
# JIGSAW DATA PREPARATION
# ============================================================================
def prepare_jigsaw_splits(jigsaw_csv_path, toxic_threshold=0.5, output_dir=None):
    """
    Split Jigsaw dataset into toxic and non-toxic subsets for training.

    Args:
        jigsaw_csv_path: Path to Jigsaw train.csv from Kaggle
        toxic_threshold: Minimum sum of the six binary label columns for a comment
            to count as toxic (the default 0.5 means any single label is enough)
        output_dir: Where to save splits (default: DATASET_BASE/jigsaw_splits)

    Returns:
        (toxic_path, nontoxic_path): Paths to created CSV files
    """
    if output_dir is None:
        output_dir = JIGSAW_SPLITS_DIR

    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading Jigsaw data from {jigsaw_csv_path}...")
    df = pd.read_csv(jigsaw_csv_path)

    # Jigsaw has columns: id, comment_text, toxic, severe_toxic, obscene, threat, insult, identity_hate
    # Combine all toxicity columns
    toxicity_cols = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    df['is_toxic'] = (df[toxicity_cols].sum(axis=1) >= toxic_threshold).astype(int)

    # Split
    toxic_df = df[df['is_toxic'] == 1][['comment_text']].copy()
    nontoxic_df = df[df['is_toxic'] == 0][['comment_text']].copy()

    # Rename column
    toxic_df.columns = ['text']
    nontoxic_df.columns = ['text']

    # Save
    toxic_path = os.path.join(output_dir, 'toxic.csv')
    nontoxic_path = os.path.join(output_dir, 'nontoxic.csv')

    toxic_df.to_csv(toxic_path, index=False)
    nontoxic_df.to_csv(nontoxic_path, index=False)

    print(f"✓ Created splits:")
    print(f"  Toxic: {len(toxic_df)} examples → {toxic_path}")
    print(f"  Non-toxic: {len(nontoxic_df)} examples → {nontoxic_path}")

    return toxic_path, nontoxic_path

# ============================================================================
# TRAINING DATASET CLASS
# ============================================================================
class T5DetoxDataset(Dataset):
    """Dataset for T5 detoxification training"""
    def __init__(self, texts, tokenizer, max_length=128):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx]).strip()

        # Input: "detoxify: <text>"
        input_text = f"detoxify: {text}"
        # Target: same text (self-supervised)
        target_text = text

        # Tokenize
        inputs = self.tokenizer(
            input_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        targets = self.tokenizer(
            target_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

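        # Replace padding in the labels with -100 so it is ignored by the loss
        labels = targets['input_ids'].squeeze(0).clone()
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'labels': labels
        }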

# ============================================================================
# TRAINING FUNCTIONS
# ============================================================================
def train_t5_model(data_csv, output_dir, base_model_path, num_train=10000,
                   num_epochs=3, batch_size=8, learning_rate=5e-5):
    """
    Train T5 model on given data.

    Args:
        data_csv: Path to CSV with 'text' column
        output_dir: Where to save trained model
        base_model_path: Path to T5 base model to start from
        num_train: Number of training examples (default 10k)
        num_epochs: Number of training epochs
        batch_size: Batch size
        learning_rate: Learning rate
    """
    print(f"\n{'='*80}")
    print(f"Training T5 model → {output_dir}")
    print(f"{'='*80}")

    # Load data
    df = pd.read_csv(data_csv)
    texts = df['text'].tolist()[:num_train]
    print(f"Loaded {len(texts)} training examples")

    # Split train/val
    train_texts, val_texts = train_test_split(texts, test_size=0.1, random_state=42)
    print(f"Train: {len(train_texts)}, Val: {len(val_texts)}")

    # Load tokenizer and model
    tokenizer = T5Tokenizer.from_pretrained(base_model_path)
    model = T5ForConditionalGeneration.from_pretrained(base_model_path)

    # Create datasets
    train_dataset = T5DetoxDataset(train_texts, tokenizer)
    val_dataset = T5DetoxDataset(val_texts, tokenizer)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=learning_rate,
        weight_decay=0.01,
        logging_steps=100,
        eval_steps=500,
        save_steps=500,
        evaluation_strategy="steps",
        save_total_limit=2,
        fp16=torch.cuda.is_available(),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        report_to="none"
    )

    # Data collator
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator
    )

    # Train
    print("\nStarting training...")
    trainer.train()

    # Save
    print(f"\nSaving to {output_dir}...")
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    print(f"✓ Training complete: {output_dir}")

# ============================================================================
# MAIN TRAINING EXECUTION
# ============================================================================
if not SKIP_TRAINING:
    # Step 1: Prepare Jigsaw splits
    if os.path.exists(JIGSAW_CSV):
        print(f"Found Jigsaw dataset at {JIGSAW_CSV}")
        toxic_path, nontoxic_path = prepare_jigsaw_splits(JIGSAW_CSV)
    else:
        print(f"⚠️ Jigsaw dataset not found at {JIGSAW_CSV}")
        print("Please download train.csv from Kaggle and upload to Google Drive")
        print("https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge/data")
        toxic_path = os.path.join(JIGSAW_SPLITS_DIR, 'toxic.csv')
        nontoxic_path = os.path.join(JIGSAW_SPLITS_DIR, 'nontoxic.csv')

    # Step 2: Train Anti-Expert (on toxic data)
    if os.path.exists(toxic_path):
        print("\n" + "="*80)
        print("TRAINING ANTI-EXPERT (on toxic data)")
        print("="*80)
        train_t5_model(
            data_csv=toxic_path,
            output_dir=T5_ANTIEXPERT_CHECKPOINT,
            base_model_path=T5_BASE_CHECKPOINT,
            num_train=10000,
            num_epochs=3,
            batch_size=8
        )
    else:
        print(f"⚠️ Toxic data not found at {toxic_path}")

    # Step 3: Train Expert (on non-toxic data)
    if os.path.exists(nontoxic_path):
        print("\n" + "="*80)
        print("TRAINING EXPERT (on non-toxic data)")
        print("="*80)
        train_t5_model(
            data_csv=nontoxic_path,
            output_dir=T5_EXPERT_CHECKPOINT,
            base_model_path=T5_BASE_CHECKPOINT,
            num_train=10000,
            num_epochs=3,
            batch_size=8
        )
    else:
        print(f"⚠️ Non-toxic data not found at {nontoxic_path}")

    print("\n" + "="*80)
    print("✓ ALL TRAINING COMPLETE")
    print("="*80)
else:
    print("SKIP_TRAINING = True, skipping training cells")
    print("Models will be loaded from existing checkpoints")


In [None]:
#@title Main T5-XDetox Pipeline Function
def t5_detoxify(data_type="paradetox",
                thresholds=(0.10, 0.15, 0.20, 0.25, 0.30),  # tuple avoids a shared mutable default
                alpha_b=1.0,
                alpha_e=None,  # Will use data_config default if None
                alpha_a=None,  # Will use data_config default if None
                temperature=None,  # Will use data_config default if None
                num_examples=None,
                ranking=True,
                num_candidates=5,
                output_folder="default",
                max_length=128,
                top_p=0.95,
                top_k=50):
    """
    Run T5-XDetox pipeline: DecompX masking → T5 ensemble → DecompX reranking

    Args:
        data_type: Dataset name from data_configs
        thresholds: List of DecompX masking thresholds
        alpha_b: Base model weight (default 1.0)
        alpha_e: Expert weight (uses data_config if None)
        alpha_a: Anti-expert weight (uses data_config if None)
        temperature: Sampling temperature (uses data_config if None)
        num_examples: Limit number of examples (None = all)
        ranking: Whether to use DecompX reranking
        num_candidates: Number of candidates for reranking
        output_folder: Output subdirectory name
        max_length: Max generation length
        top_p: Nucleus sampling parameter
        top_k: Top-k sampling parameter

    Returns:
        results_df: DataFrame with all results
    """
    print("\n" + "="*80)
    print(f"T5-XDETOX PIPELINE: {data_type}")
    print("="*80)

    # Get config
    if data_type not in data_configs:
        raise ValueError(f"Unknown data_type: {data_type}")

    config = data_configs[data_type]
    alpha_e = alpha_e if alpha_e is not None else config['alpha_e']
    alpha_a = alpha_a if alpha_a is not None else config['alpha_a']
    temperature = temperature if temperature is not None else config['temperature']

    print(f"Config: alpha_b={alpha_b}, alpha_e={alpha_e}, alpha_a={alpha_a}, temp={temperature}")
    print(f"Thresholds: {thresholds}")
    print(f"Ranking: {ranking} (candidates={num_candidates})")

    # Load data
    print(f"\nLoading dataset: {data_type}")
    texts = load_dataset(data_type, num_examples)
    print(f"Loaded {len(texts)} examples")

    # Output directory
    output_base = os.path.join(PROJECT_BASE, "data/t5_xdetox_outputs", output_folder, data_type)
    os.makedirs(output_base, exist_ok=True)

    # Results storage
    all_results = []

    # Process each threshold
    for threshold in thresholds:
        print(f"\n{'='*80}")
        print(f"Processing threshold: {threshold}")
        print(f"{'='*80}")

        # Output paths
        threshold_str = f"DecompX{threshold}"
        output_dir = os.path.join(output_base, threshold_str)
        os.makedirs(output_dir, exist_ok=True)

        masked_file = os.path.join(output_dir, "masked.txt")
        output_file = os.path.join(output_dir, "detoxified.txt")

        # Step 1: Mask toxic tokens
        print(f"\n[1/3] Masking toxic tokens (threshold={threshold})...")
        masked_texts = []
        for text in tqdm(texts, desc="Masking"):
            masked, num_masked = mask_toxic_tokens(text, threshold=threshold)
            masked_texts.append(masked)

        # Save masked
        with open(masked_file, 'w') as f:
            f.write('\n'.join(masked_texts))
        print(f"  Saved: {masked_file}")

        # Step 2: Generate with T5 ensemble
        print(f"\n[2/3] Generating with T5 ensemble...")
        generated_texts = []

        for masked_text in tqdm(masked_texts, desc="Generating"):
            if ranking:
                # Generate multiple candidates
                candidates = []
                for _ in range(num_candidates):
                    gen = generate_with_ensemble(
                        masked_text,
                        alpha_b=alpha_b,
                        alpha_e=alpha_e,
                        alpha_a=alpha_a,
                        temperature=temperature,
                        max_length=max_length,
                        top_p=top_p,
                        top_k=top_k
                    )
                    candidates.append(gen)

                # Step 3: Rerank candidates
                best, scores = rerank_candidates(candidates)
                generated_texts.append(best)
            else:
                # No reranking - single generation
                gen = generate_with_ensemble(
                    masked_text,
                    alpha_b=alpha_b,
                    alpha_e=alpha_e,
                    alpha_a=alpha_a,
                    temperature=temperature,
                    max_length=max_length,
                    top_p=top_p,
                    top_k=top_k
                )
                generated_texts.append(gen)

        # Save outputs
        with open(output_file, 'w') as f:
            f.write('\n'.join(generated_texts))
        print(f"  Saved: {output_file}")

        # Step 4: Evaluate
        metrics = evaluate_all(texts, generated_texts)

        # Store results
        result = {
            'data_type': data_type,
            'threshold': threshold,
            'alpha_b': alpha_b,
            'alpha_e': alpha_e,
            'alpha_a': alpha_a,
            'temperature': temperature,
            'ranking': ranking,
            'num_candidates': num_candidates if ranking else 1,
            **metrics
        }
        all_results.append(result)

        print(f"\n  Results:")
        print(f"    Toxicity: {metrics['toxicity_orig']:.4f} → {metrics['toxicity_gen']:.4f}")
        print(f"    Perplexity: {metrics['perplexity_orig']:.2f} → {metrics['perplexity_gen']:.2f}")
        print(f"    BLEU: {metrics['bleu4']:.4f}")
        print(f"    BERTScore: {metrics['bertscore']:.4f}")

    # Save summary
    results_df = pd.DataFrame(all_results)
    summary_file = os.path.join(output_base, "summary.csv")
    results_df.to_csv(summary_file, index=False)
    print(f"\n{'='*80}")
    print(f"✓ PIPELINE COMPLETE")
    print(f"  Summary saved: {summary_file}")
    print(f"{'='*80}")

    return results_df

print("✓ Main pipeline function loaded")


## Example Run

Below are examples of running the T5-XDetox pipeline.

- **Small test:** Quick test on 50 examples
- **Full run:** Complete evaluation on entire dataset
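
Once a run has produced its per-threshold summary, one possible way to choose a threshold is to take the lowest generated toxicity among rows that still preserve enough meaning. The sketch below assumes the summary columns `threshold`, `toxicity_gen`, and `bertscore`; the rows are made-up placeholders rather than results from this notebook, and `pick_threshold` with its `min_bertscore` floor is a hypothetical helper.

```python
def pick_threshold(rows, min_bertscore=0.85):
    """Return the row with the lowest toxicity_gen among rows meeting the BERTScore floor."""
    eligible = [r for r in rows if r["bertscore"] >= min_bertscore]
    if not eligible:
        return None  # no threshold preserves enough meaning under this floor
    return min(eligible, key=lambda r: r["toxicity_gen"])

# Placeholder rows shaped like the summary CSV (values are illustrative only)
summary_rows = [
    {"threshold": 0.10, "toxicity_gen": 0.05, "bertscore": 0.80},
    {"threshold": 0.20, "toxicity_gen": 0.10, "bertscore": 0.90},
    {"threshold": 0.30, "toxicity_gen": 0.20, "bertscore": 0.95},
]

best = pick_threshold(summary_rows)
print(best)
```

The floor makes the trade-off explicit: aggressive masking tends to lower toxicity but can also remove content, so a pure minimum over toxicity would favor outputs that drift from the input.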

In [None]:
#@title Example: Run T5-XDetox Pipeline
# Load models first
load_models()

# Quick test on small subset
print("\nRunning small test (50 examples)...")
results_test = t5_detoxify(
    data_type="paradetox",
    thresholds=[0.15, 0.20],
    num_examples=50,
    ranking=True,
    num_candidates=3,
    output_folder="test_run"
)

print("\nTest Results:")
print(results_test)

# Uncomment below for full evaluation:
# results_full = t5_detoxify(
#     data_type="paradetox",
#     thresholds=[0.10, 0.15, 0.20, 0.25, 0.30],
#     ranking=True,
#     num_candidates=5,
#     output_folder="full_eval"
# )
