In [None]:
# Necessary imports
import warnings

from datasets import load_dataset, load_metric, concatenate_datasets
import transformers
import datasets
import random
import pandas as pd
from IPython.display import display, HTML

# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings raised by the libraries
# used below — consider narrowing the filter.
warnings.filterwarnings("ignore")

In [None]:
# selecting model checkpoint
# The same identifier is later passed to AutoTokenizer.from_pretrained and
# AutoModelForSeq2SeqLM.from_pretrained, and is used to name the output directory.
model_checkpoint = "t5-small"

In [None]:
# Fix the random seed through the transformers helper before any data/model work.
transformers.set_seed(42)

# Load the dataset stored under ../data/interim/dataset; later cells read its
# "train", "validation" and "test" splits.
raw_datasets = load_dataset("../data/interim/dataset")
# SacreBLEU metric object; compute_metrics calls metric.compute(...) on it.
# NOTE(review): load_metric is deprecated in recent `datasets` releases in
# favour of the separate `evaluate` package — confirm the pinned version.
metric = load_metric("sacrebleu")

In [None]:
# Second dataset, loaded under the name "synonyms"; its "train" and
# "validation" splits are concatenated with the main data in the next cell.
# NOTE(review): presumably a local dataset directory/script — confirm where
# "synonyms" resolves to.
synonym_dataset = load_dataset("synonyms")  # add synonyms dataset

In [None]:
def _merge_split(split_name):
    """Concatenate one split of the main dataset with the same split of the synonym dataset."""
    return concatenate_datasets(
        [raw_datasets[split_name], synonym_dataset[split_name]]
    )


# Main data first, synonym data second, for both splits used in training.
merged_train_dataset = _merge_split("train")
merged_val_dataset = _merge_split("validation")

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer that belongs to `model_checkpoint` via AutoTokenizer.
# The same tokenizer object is reused for preprocessing, for decoding inside
# compute_metrics and for inference in test().
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
# prefix for model input
# Concatenated directly (no separator is inserted) in front of every source
# sentence, both in preprocess_function and in test().
prefix = "make sentence non-toxic:"

In [None]:
# Token limits passed as max_length (with truncation=True) when tokenizing
# the inputs and the targets respectively.
max_input_length = 256
max_target_length = 256
# Dataset column names: `toxic` is the column fed to the model,
# `non_toxic` is the column used as the label text.
toxic = "source"
non_toxic = "target"


def preprocess_function(examples):
    """Tokenize one batch of examples for seq2seq training.

    Args:
        examples: mapping with a list of source texts under the `toxic` key
            and a list of target texts under the `non_toxic` key.

    Returns:
        The tokenizer output for the prefixed source texts, with an extra
        "labels" entry holding the token ids of the target texts.
    """
    source_texts = []
    for text in examples[toxic]:
        # Falsy entries (empty string / None) become a single space and do
        # not receive the prefix.
        source_texts.append(prefix + text if text else " ")

    # Same fallback for the targets: a falsy value is replaced by " ".
    target_texts = [text or " " for text in examples[non_toxic]]

    encoded = tokenizer(source_texts, max_length=max_input_length, truncation=True)

    # Targets go through the same tokenizer; only their ids are kept.
    encoded_targets = tokenizer(
        target_texts, max_length=max_target_length, truncation=True
    )
    encoded["labels"] = encoded_targets["input_ids"]
    return encoded

In [None]:
# Tokenize the merged training split; with batched=True preprocess_function
# receives a batch of examples (lists per column) on each call.
tokenized_train = merged_train_dataset.map(preprocess_function, batched=True)

In [None]:
# Same preprocessing for the merged validation split, used as eval_dataset by the trainer.
tokenized_validation = merged_val_dataset.map(preprocess_function, batched=True)

# Fine-tuning the model


In [None]:
from transformers import (
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

# Instantiate the seq2seq model from the pretrained checkpoint; this is the
# instance handed to the data collator and fine-tuned by the trainer.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [None]:
# defining the parameters for training
batch_size = 32  # shared by the per-device train and eval batch sizes below
# Keep only the part after the last "/" of the checkpoint id so it can be
# used in the output directory name.
model_name = model_checkpoint.split("/")[-1]
args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned-detoxify",  # first positional argument: output directory
    evaluation_strategy="epoch",  # evaluate once per epoch (see the results table)
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,  # upper bound on checkpoints kept on disk
    num_train_epochs=10,
    predict_with_generate=True,  # evaluation generates sequences, so compute_metrics gets token ids
    fp16=True,  # NOTE(review): mixed precision presumably needs a CUDA device — confirm the runtime
    report_to="tensorboard",
)

In [None]:
# Batch collator built from the tokenizer and the model.
# NOTE(review): presumably responsible for padding inputs and labels per
# batch (the tokenizer calls in preprocess_function do not pad) — confirm.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
import numpy as np


# simple postprocessing for text
def postprocess_text(preds, labels):
    """Strip surrounding whitespace from decoded predictions and references.

    Args:
        preds: iterable of decoded prediction strings.
        labels: iterable of decoded reference strings.

    Returns:
        A tuple ``(preds, labels)`` where ``preds`` is a list of stripped
        strings and ``labels`` is a list of one-element lists, each holding
        the stripped reference for the prediction at the same position.
    """
    cleaned_preds = []
    for text in preds:
        cleaned_preds.append(text.strip())

    # Each reference is wrapped in its own list (one reference per prediction).
    wrapped_labels = []
    for text in labels:
        wrapped_labels.append([text.strip()])

    return cleaned_preds, wrapped_labels


# compute metrics function to pass to trainer
def compute_metrics(eval_preds):
    """Compute BLEU and mean generated length for one evaluation pass.

    Args:
        eval_preds: pair ``(preds, labels)`` of token-id arrays. ``preds``
            may also be a tuple, in which case its first element is used.

    Returns:
        Dict with ``"bleu"`` (the metric's ``"score"``) and ``"gen_len"``
        (mean number of non-pad tokens per prediction), both rounded to
        4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 cannot be decoded. Labels always contain it as the ignore index,
    # and predictions can contain it too when generated sequences are padded
    # with -100, so both arrays are mapped back to the pad token id.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}

    # Length is measured on the cleaned ids, so former -100 positions count
    # as padding rather than as generated tokens.
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [None]:
# instead of writing train loop we will use Seq2SeqTrainer
# Wires together the pretrained model, the training arguments, the tokenized
# train/validation splits, the batch collator and the metric callback.
trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_validation,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,  # called with (predictions, labels) at evaluation time
)

In [None]:
# Run fine-tuning with the configuration in `args`; the per-epoch loss, BLEU
# and generated length are reported in the table below.
trainer.train()

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,No log,2.485242,21.8933,10.7129
2,2.665900,2.382524,23.3744,10.9589
3,2.665900,2.337265,23.6036,10.9794
4,2.337600,2.312074,23.9011,11.0286
5,2.337600,2.292755,23.842,11.0626
6,2.259800,2.283076,23.9714,11.0528
7,2.228900,2.275079,23.8868,11.0259
8,2.228900,2.269132,23.8734,11.0206
9,2.188100,2.265964,23.8728,11.0161
10,2.188100,2.265641,23.872,11.0188


TrainOutput(global_step=2960, training_loss=2.3142215007060285, metrics={'train_runtime': 446.5402, 'train_samples_per_second': 212.12, 'train_steps_per_second': 6.629, 'total_flos': 1106994429689856.0, 'train_loss': 2.3142215007060285, 'epoch': 10.0})

In [None]:
# saving model
# Written to ../models/t5-small-ft2, the same path the next cell reloads from.
trainer.save_model("../models/t5-small-ft2")

In [None]:
# Reload the fine-tuned model from disk and prepare it for inference.
model = AutoModelForSeq2SeqLM.from_pretrained("../models/t5-small-ft2")
model.eval()  # inference mode: disables dropout
# Keep the decoder key/value cache enabled. Turning it off (as was done here
# before) makes generate() recompute attention over the whole prefix at every
# decoding step, which only slows inference and does not change the output.
model.config.use_cache = True

In [None]:
import torch
import pandas
from tqdm import tqdm

In [None]:
def test(model, tokenizer=tokenizer, batch_size=100, max_length=None):
    """Generate detoxified text for every example of the test split.

    Args:
        model: seq2seq model exposing ``generate`` and ``device``.
        tokenizer: tokenizer used to encode the inputs and decode the outputs.
        batch_size: number of test examples processed per ``generate`` call.
        max_length: truncation limit (in tokens) for the inputs. Defaults to
            ``max_input_length`` so inference truncates the same way as the
            training preprocessing.

    Returns:
        ``pandas.DataFrame`` with a "source" column (the original test texts)
        and a "target" column (the decoded model outputs), in test-set order.

    NOTE(review): ``generate`` is called without a length argument, so the
    model's own generation defaults bound the output length — confirm they
    are large enough for this data.
    """
    if max_length is None:
        max_length = max_input_length

    test_data = raw_datasets["test"]
    model_res = []

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for i in tqdm(range(0, len(test_data), batch_size)):
            batch = test_data[i : i + batch_size]
            input_texts = [prefix + line for line in batch[toxic]]

            # Move the encoded batch to wherever the model lives so the
            # function also works when the model is on an accelerator.
            encoded = tokenizer(
                input_texts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=max_length,
            ).to(model.device)

            # The batch is padded, so pass the attention mask explicitly
            # instead of relying on generate() to infer it from pad tokens.
            outputs = model.generate(
                input_ids=encoded.input_ids,
                attention_mask=encoded.attention_mask,
            )

            model_res.extend(
                tokenizer.batch_decode(outputs, skip_special_tokens=True)
            )

    res = pd.DataFrame({"source": test_data[toxic]})
    res["target"] = model_res
    return res

In [None]:
# Run inference over the whole test split with the reloaded model and
# preview the first rows of the resulting source/target table.
res = test(model, tokenizer)
res.head()

100%|██████████| 30/30 [04:21<00:00,  8.72s/it]


Unnamed: 0,source,target
0,and you think grandpa is gonna protect us from...,and you think grandpa is gonna protect us from...
1,might i add very clever assholes,i m sure i m going to add some clever tricks
2,i hate dickheads,i hate dickheads
3,jason put down that stupid camera and come hel...,jason put down that camera and come help me
4,what a scumbag,what a scumbag


In [None]:
# Persist the source/target table to a CSV file in the current working directory.
# NOTE(review): the DataFrame index is written as an extra unnamed first
# column; pass index=False if downstream readers do not expect it — confirm.
res.to_csv("t5-ft2-detoxify.csv")