In [6]:
# 0. Install dependencies if not already installed
# !pip install datasets transformers detoxify torch pandas tqdm

import torch
from datasets import load_dataset, concatenate_datasets
from transformers import (
    T5Tokenizer, T5ForConditionalGeneration,
    Seq2SeqTrainingArguments, Seq2SeqTrainer,
    DataCollatorForSeq2Seq, EarlyStoppingCallback
)
from torch.utils.data import DataLoader
from detoxify import Detoxify
import pandas as pd
from tqdm import tqdm

# 1. Load datasets
print("Loading datasets...")
paradetox = load_dataset("textdetox/multilingual_paradetox", split="en")
snlp = load_dataset("s-nlp/paradetox", split="train")

# 2. Preprocess to extract English input-output pairs
def preprocess_paradetox(example):
    return {
        "input_text": "detoxify: " + example["toxic_sentence"],
        "target_text": example["neutral_sentence"]
    }

def preprocess_snlp(example):
    return {
        "input_text": "detoxify: " + example["en_toxic_comment"],
        "target_text": example["en_neutral_comment"]
    }

paradetox = paradetox.map(preprocess_paradetox)
snlp = snlp.map(preprocess_snlp)

# 3. Keep only required columns
def clean_columns(ds):
    return ds.remove_columns([col for col in ds.column_names if col not in ["input_text", "target_text"]])

paradetox = clean_columns(paradetox)
snlp = clean_columns(snlp)

# 4. Combine and filter
combined = concatenate_datasets([paradetox, snlp])
combined = combined.filter(lambda x: x["input_text"] is not None and x["target_text"] is not None)

# 5. Tokenize
tokenizer = T5Tokenizer.from_pretrained("t5-small")

def tokenize(example):
    inputs = tokenizer(example["input_text"], padding="max_length", truncation=True, max_length=128)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(example["target_text"], padding="max_length", truncation=True, max_length=128)
    labels["input_ids"] = [
        [(tok if tok != tokenizer.pad_token_id else -100) for tok in seq]
        for seq in labels["input_ids"]
    ]
    inputs["labels"] = labels["input_ids"]
    return inputs

tokenized = combined.map(tokenize, batched=True)

# 6. Split into train/test
split = tokenized.train_test_split(test_size=0.1)
train_dataset = split["train"]
eval_dataset = split["test"]

# Keep original test for evaluation
original_eval = combined.train_test_split(test_size=0.1)["test"]

# 7. Load model
model = T5ForConditionalGeneration.from_pretrained("t5-small")
model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))



Loading datasets...


T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [4]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import Seq2SeqTrainer
import torch
import torch.nn.functional as F

class ToxicityPenaltyTrainer(Seq2SeqTrainer):
    def __init__(self, *args, lambda_penalty=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.lambda_penalty = lambda_penalty

        # Load fast toxicity classifier
        self.tox_tokenizer = AutoTokenizer.from_pretrained("unitary/toxic-bert")
        self.tox_model = AutoModelForSequenceClassification.from_pretrained("unitary/toxic-bert").to(self.model.device)
        self.tox_model.eval()

    def compute_loss(self, model, inputs, return_outputs=False):
        # Standard generation loss
        outputs = model(**inputs)
        generation_loss = outputs.loss

        # Decode model predictions
        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=50,
            num_beams=1,
            early_stopping=True,
            decoder_start_token_id=self.tokenizer.pad_token_id
        )
        decoded_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        # Tokenize generated text for classifier
        tox_inputs = self.tox_tokenizer(decoded_texts, return_tensors="pt", truncation=True, padding=True).to(self.model.device)

        with torch.no_grad():
            tox_logits = self.tox_model(**tox_inputs).logits
            tox_probs = torch.sigmoid(tox_logits[:, 0])  # binary: higher means more toxic

        # Average toxicity as penalty
        penalty = tox_probs.mean()

        # Combine losses
        total_loss = generation_loss + self.lambda_penalty * penalty
        return (total_loss, outputs) if return_outputs else total_loss

In [5]:
# 8. Training config
args = Seq2SeqTrainingArguments(
    output_dir="./mt5-detox",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    num_train_epochs=5,
    predict_with_generate=True,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=2,
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",
    fp16=torch.cuda.is_available()
)

trainer = ToxicityPenaltyTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2)],
    lambda_penalty=0.5  # Tune this value!
)

# 9. Train
print("Training model...")
trainer.train()

# 10. Generate detoxified outputs
def collate_fn(batch):
    texts = [ex["input_text"] for ex in batch]
    return tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128)

loader = DataLoader(original_eval, batch_size=16, collate_fn=collate_fn)
model.eval()

detoxified_outputs = []
input_texts = []
reference_texts = []

print("Generating detoxified outputs...")
for i, batch in enumerate(tqdm(loader)):
    batch = {k: v.to(model.device) for k, v in batch.items()}
    outputs = model.generate(**batch, max_length=50, num_beams=4, early_stopping=True, decoder_start_token_id=tokenizer.pad_token_id)
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    detoxified_outputs.extend(decoded)
    for j in range(len(decoded)):
        idx = i * 16 + j
        if idx < len(original_eval):
            input_texts.append(original_eval[idx]["input_text"])
            reference_texts.append(original_eval[idx]["target_text"])

# 11. Evaluate with Detoxify
print("Scoring toxicity before and after...")
tox_model = Detoxify("unbiased")
tox_before = tox_model.predict(input_texts)["toxicity"]
tox_after = tox_model.predict(detoxified_outputs)["toxicity"]

# 12. Save and summarize
df = pd.DataFrame({
    "toxic_input": input_texts,
    "reference_output": reference_texts,
    "model_output": detoxified_outputs,
    "toxicity_before": tox_before,
    "toxicity_after": tox_after
})

df.to_csv("detoxified_results.csv", index=False)
print("\n✅ Results saved to 'detoxified_results.csv'")

# Summary
print(f"\nAverage Toxicity Before: {sum(tox_before)/len(tox_before):.4f}")
print(f"Average Toxicity After:  {sum(tox_after)/len(tox_after):.4f}")
print(f"Average Reduction:       {sum(t - d for t, d in zip(tox_before, tox_after))/len(tox_after):.4f}")


ValueError: --load_best_model_at_end requires the save and eval strategy to match, but found
- Evaluation strategy: IntervalStrategy.NO
- Save strategy: SaveStrategy.EPOCH