In [None]:
# Install the third-party packages used below: transformers, sentencepiece and sacrebleu.
!pip install transformers sentencepiece sacrebleu

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


Import necessary libraries

In [None]:
import os
from tqdm.auto import tqdm

import pandas as pd
import torch
from tqdm import tqdm
from transformers import (
    BartForConditionalGeneration,
    BartTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

We are using the [ParaDetox](https://github.com/s-nlp/paradetox) dataset.

In [None]:
# Download the ParaDetox CSV from Google Drive (saved as paradetox.csv).
!gdown --id 16GHCKuILS6bj4h2jSlnsH9jw2501IulX

Downloading...
From: https://drive.google.com/uc?id=16GHCKuILS6bj4h2jSlnsH9jw2501IulX
To: /content/paradetox.csv
100% 2.06M/2.06M [00:00<00:00, 139MB/s]


In [None]:
# Load the parallel toxic/neutral sentence pairs into a DataFrame.
dataset = pd.read_csv("paradetox.csv")

# Preview the first five rows.
dataset.head(n=5)

Unnamed: 0,en_toxic_comment,en_neutral_comment
0,he had steel balls too !,he was brave too!
1,"dude should have been taken to api , he would ...",It would have been good if he went to api. He ...
2,"im not gonna sell the fucking picture , i just...","I'm not gonna sell the picture, i just want to..."
3,the garbage that is being created by cnn and o...,the news that is being created by cnn and othe...
4,the reason they dont exist is because neither ...,The reason they don't exist is because neither...


Error: Runtime no longer has a reference to this dataframe, please re-run this cell and try again.
Error: Runtime no longer has a reference to this dataframe, please re-run this cell and try again.


In [None]:
class DetoxDataset(torch.utils.data.Dataset):
    """Map-style dataset of (toxic, neutral) sentence pairs for seq2seq training.

    Each item is a dict of 1-D tensors: every field the tokenizer returns for
    the toxic sentence (e.g. ``input_ids``) plus ``labels`` holding the token
    ids of the neutral sentence. Padding positions in ``labels`` are replaced
    with ``-100`` so that the training loss skips them.

    Args:
        data: DataFrame with ``en_toxic_comment`` and ``en_neutral_comment``
            columns.
        tokenizer: Callable tokenizer returning a mapping of tensors when
            called with ``return_tensors="pt"``.
        max_length: Fixed length every sequence is padded/truncated to.
    """

    # Label value used to mark positions that must not contribute to the loss.
    IGNORE_INDEX = -100

    def __init__(self, data, tokenizer, max_length=150):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        # NOTE(review): src_lang/tgt_lang look like leftovers from a multilingual
        # tokenizer setup; kept so the tokenizer object is mutated exactly as before.
        self.tokenizer.src_lang = "en_XX"
        self.tokenizer.tgt_lang = "en_XX"

    def _encode(self, text):
        """Tokenize one sentence to a fixed-length batch of size 1."""
        return self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        source = self._encode(row.en_toxic_comment)
        target = self._encode(row.en_neutral_comment)

        labels = target["input_ids"]
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is not None:
            # masked_fill returns a new tensor, so the tokenizer output is not mutated.
            labels = labels.masked_fill(labels == pad_id, self.IGNORE_INDEX)
        source["labels"] = labels

        # Drop the leading batch dimension of size 1 added by return_tensors="pt".
        return {k: v.squeeze(0) for k, v in source.items()}

    def __len__(self):
        return len(self.data)

Utility function for detoxification generation.

In [None]:
def paraphrase(
    text,
    model,
    tokenizer,
    n=None,
    max_length="auto",
    beams=5,
):
    """Generate paraphrases of ``text`` with beam search.

    Args:
        text: A single string or an iterable of strings.
        model: Object exposing ``device`` and ``generate(...)``.
        tokenizer: Callable tokenizer that also provides ``decode``.
        n: Number of sequences to return per input; ``None`` means one.
        max_length: Maximum generated length, or ``"auto"`` to use the padded
            input length plus 10.
        beams: Beam width. Raised to ``n`` when ``n`` is larger, because beam
            search cannot return more sequences than it has beams.

    Returns:
        A single string when ``text`` is a string and ``n`` is falsy;
        otherwise a flat list of decoded strings (``n`` consecutive entries
        per input, in input order).
    """
    texts = [text] if isinstance(text, str) else list(text)
    if not texts:
        return []

    encoded = tokenizer(texts, return_tensors="pt", padding=True)
    inputs = encoded["input_ids"].to(model.device)

    # Inputs are padded to a common length, so forward the mask (when the
    # tokenizer provides one) to keep the model from attending to padding.
    extra = {}
    if "attention_mask" in encoded:
        extra["attention_mask"] = encoded["attention_mask"].to(model.device)

    if max_length == "auto":
        max_length = inputs.shape[1] + 10

    num_sequences = n or 1
    result = model.generate(
        inputs,
        num_return_sequences=num_sequences,
        do_sample=False,
        temperature=1.0,
        repetition_penalty=10.0,
        max_length=max_length,
        min_length=int(0.5 * max_length),
        num_beams=max(beams, num_sequences),
        **extra,
    )
    decoded = [tokenizer.decode(r, skip_special_tokens=True) for r in result]

    if not n and isinstance(text, str):
        return decoded[0]
    # Batched input or an explicit n: return every generated sequence.
    return decoded

Load the FLAN-T5 model (`google/flan-t5-small`) and its tokenizer.

In [None]:
# The tokenizer and the model weights are loaded from the same "google/flan-t5-small" checkpoint.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 1% of the pairs for validation; the fixed seed keeps the split reproducible.
train, val = train_test_split(dataset, test_size=0.01, random_state=42)

# Wrap both splits in DetoxDataset, sharing the one tokenizer.
trainset, valset = (DetoxDataset(split, tokenizer) for split in (train, val))

Define training arguments.

In [None]:
training_args = Seq2SeqTrainingArguments(
    # Where run artifacts go; checkpoint saving itself is disabled below.
    output_dir="bart_detox",
    save_strategy="no",
    save_total_limit=1,
    # Run both training and evaluation, evaluating on a step schedule.
    do_train=True,
    do_eval=True,
    evaluation_strategy="steps",
    logging_steps=500,
    # Batch sizes (original author's note on the eval size: "8 is too much").
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=1,
    # Optimisation settings; a single epoch keeps the run short,
    # the original note suggests 3 or 5 epochs.
    num_train_epochs=1,
    learning_rate=1e-5,
    weight_decay=1e-5,
)

Define the Trainer from Hugging Face Transformers. Since this is a seq2seq task, we use the corresponding `Seq2SeqTrainer`.

In [None]:
# Bundle the model, its tokenizer, both dataset splits and the training arguments.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=trainset,
    eval_dataset=valset,
)

In [None]:
# Start fine-tuning with the trainer configured above.
trainer.train()



Step,Training Loss,Validation Loss
500,14.2885,3.570108


KeyboardInterrupt: ignored

Let's get test data and generate detoxifications. 

In [None]:
# Download the parallel test files from Google Drive:
# test_toxic_parallel.txt (inputs) and test_neutral_parallel.txt (references).
!gdown --id 16VDvvra8joR3MLcx9Om05ET_qzq4mbt-
!gdown --id 1Wp0O3YzeXrGHKkznDzpcuIgQXrxrBBNK

Downloading...
From: https://drive.google.com/uc?id=16VDvvra8joR3MLcx9Om05ET_qzq4mbt-
To: /content/test_toxic_parallel.txt
100% 41.0k/41.0k [00:00<00:00, 121MB/s]
Downloading...
From: https://drive.google.com/uc?id=1Wp0O3YzeXrGHKkznDzpcuIgQXrxrBBNK
To: /content/test_neutral_parallel.txt
100% 35.2k/35.2k [00:00<00:00, 104MB/s]


Generate detoxifications and save them in a .txt file

In [None]:
# Read the first 20 toxic test sentences; the context manager closes the
# file handle, which the bare open(...).read() call never did.
with open("test_toxic_parallel.txt", "r", encoding="utf-8") as f:
    test_inputs = f.read().split("\n")[:20]

# Detoxify one sentence at a time.
preds = []
for text in tqdm(test_inputs):
    preds.append(paraphrase(text, model, tokenizer))

# Write one prediction per line, in the same order as the inputs.
with open("predictions.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(preds))

100%|██████████| 20/20 [00:09<00:00,  2.17it/s]


In [None]:
# Show each toxic input next to its generated detoxification.
pd.DataFrame({'text': test_inputs, 'preds': preds})

Unnamed: 0,text,preds
0,. or the loud ass one - thousand ton beast roa...,"ing your lungs.' is the best thing you can do,..."
1,"mandated and "" right fucking now "" would be ...",DOESN'T GO HERE. DO SHOULDN'T BETWIRE TOMORROW
2,* neither * of my coworkers gave a shit when i...,Clickfunnelsity is the best way to do it. If y...
3,"* well shit , cunt shot himself .","Whist, cunt shot himself."
4,&gt i wouldn 't care how ignorant you are in y...,s aren't what you have in your life. They do t...
5,&gt today was one of the most fucked up days o...,'stucky day of my life is the best way to do i...
6,"<number> , people will freak the fuck out .",hide it.'s the people will freak them out.
7,3 . you don 't fuck with a mommas boy,You don't fuck with a mommas.
8,6. are you fucking insane ?,I am you fucking insane.
9,abortion is not a fundamental right it is a pr...,abortion is not a fundamentally right. It is a...
