# Text detoxification using fine-tuned BART


In [1]:
import warnings

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from torch.utils.data import random_split
from tqdm import tqdm
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

In [2]:
# Suppress all warnings so that cell outputs stay readable.
warnings.filterwarnings(action="ignore")

# Single seed reused by every seeded step of the notebook.
MANUAL_SEED = 42

## Data loading and preprocessing


In [3]:
# Paired corpus (columns "toxic" / "nontoxic") used for fine-tuning.
df = pd.read_csv(filepath_or_buffer="../data/raw/dataset_xs.csv")
print(f"len(df)={len(df)}")
df.head()

len(df)=9441


Unnamed: 0,toxic,nontoxic
0,I like that shit.,I love it.
1,"Now, I understand you got your grievances with...","I understand you don't have to cut your bills,..."
2,Damn It!,"oh, my God."
3,"Help me, you cunt!","Aitchi, help me!"
4,Look at that shit.,look at this.


In [4]:
# Held-out pairs; the model's outputs on these are written to disk at the end.
test_df = pd.read_csv(filepath_or_buffer="../data/raw/test.csv")
print(f"len(test_df)={len(test_df)}")
test_df.head()

len(test_df)=500


Unnamed: 0,toxic,nontoxic
0,It's feeding time at the fucking zoo!,it's time to eat at the zoo!
1,Everyone here bet on the hero and lost their a...,they all took a hero and lost everything.
2,Then I got to come home to Melvin and his bull...,then I'm going home and Melvin's there.
3,Sara here was hoping to pick your brains.,Sara was hoping you could handle her.
4,"Oh, that's stupid. If anyone wants to tell me ...","if anyone wants to tell me what's going on, I'..."


In [5]:
# Reproducible 90/10 split over the row positions of `df`.
split_generator = torch.Generator().manual_seed(MANUAL_SEED)
train_seq, val_seq = random_split(
    range(len(df)),  # type: ignore
    [0.9, 0.1],
    generator=split_generator,
)
train_indices = list(train_seq.indices)
val_indices = list(val_seq.indices)
for split_name, split_indices in (
    ("train_indices", train_indices),
    ("val_indices", val_indices),
):
    print(f"len({split_name})={len(split_indices)}")

len(train_indices)=8497
len(val_indices)=944


## Model definition


In [8]:
checkpoint = "eugenesiow/bart-paraphrase"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
# The collator expects the model *object* (not the checkpoint name): when it has
# `prepare_decoder_input_ids_from_labels`, the collator uses it to build the
# decoder inputs from the padded labels. Passing the checkpoint string made the
# collator silently skip that step.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

Downloading (…)okenizer_config.json:   0%|          | 0.00/332 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.69k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

In [6]:
# Tokenization settings: every toxic input gets PREFIX prepended and is
# truncated to MAX_LENGTH tokens during preprocessing.
MAX_LENGTH = 128
PREFIX = "paraphrase following to be nontoxic: \n"

# Per-device batch size for both training and evaluation.
BATCH_SIZE = 16

In [7]:
def preprocess_function(data):
    """Tokenize a batch of toxic/non-toxic pairs for seq2seq fine-tuning.

    Each toxic sentence is prefixed with PREFIX and used as the source; the
    matching non-toxic sentence is passed as the tokenizer's `text_target`.
    Both sides are truncated to MAX_LENGTH tokens.

    Args:
        data: batch mapping with "toxic" and "nontoxic" string lists, as
            passed by `Dataset.map(..., batched=True)`.

    Returns:
        The tokenizer's encoding of sources and targets.
    """
    prefixed_sources = list(map(lambda text: PREFIX + text, data["toxic"]))
    return tokenizer(
        prefixed_sources,
        text_target=data["nontoxic"],
        max_length=MAX_LENGTH,
        truncation=True,
    )


def post_process_text(predictions, targets):
    """Strip surrounding whitespace from decoded predictions and references.

    Args:
        predictions: iterable of decoded model outputs.
        targets: iterable of decoded reference texts.

    Returns:
        A ``(predictions, targets)`` tuple of new, stripped lists.
    """

    def _strip_all(texts):
        return [text.strip() for text in texts]

    return _strip_all(predictions), _strip_all(targets)


BLEU_METRIC = evaluate.load("bleu")


def compute_metrics(batch):
    """Compute BLEU and mean generated length for one evaluation pass.

    Args:
        batch: ``(predictions, targets)`` pair of token-id arrays supplied by
            the Trainer. ``predictions`` may be a tuple whose first element
            holds the generated ids.

    Returns:
        dict with ``"bleu"`` (when the metric returns a result) and
        ``"gen_len"`` (average number of non-pad tokens per prediction),
        each rounded to 4 decimals.
    """
    predictions, targets = batch
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # NOTE(review): depending on the transformers version, predictions gathered
    # across eval batches may be padded with -100 (like the labels). -100 is not
    # a valid token id, so map it to the pad id before decoding/counting.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    # Label positions ignored by the loss are marked with -100.
    targets = np.where(targets != -100, targets, tokenizer.pad_token_id)  # type: ignore
    decoded_targets = tokenizer.batch_decode(targets, skip_special_tokens=True)

    decoded_predictions, decoded_targets = post_process_text(
        decoded_predictions, decoded_targets
    )

    result = {}
    metrics = BLEU_METRIC.compute(
        predictions=decoded_predictions, references=decoded_targets
    )
    if metrics is not None:
        result["bleu"] = metrics["bleu"]

    # After the remap above, all padding is pad_token_id, so this counts only
    # real (including special) tokens.
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [12]:
def _tokenize_split(row_positions):
    """Select rows of `df` by position and tokenize them with `preprocess_function`."""
    split_frame = df.iloc[row_positions]
    return Dataset.from_pandas(split_frame).map(preprocess_function, batched=True)


train_dataset = _tokenize_split(train_indices)
val_dataset = _tokenize_split(val_indices)

  0%|          | 0/9 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

## Train model


In [14]:
training_args = Seq2SeqTrainingArguments(
    output_dir="../models/train_data/bart",
    # NOTE(review): newer transformers releases rename this to `eval_strategy`;
    # keep it in sync with the installed version.
    evaluation_strategy="epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=10,
    # Run `generate` during evaluation so compute_metrics sees real outputs.
    predict_with_generate=True,
    # Tie training randomness to the notebook-wide seed instead of relying on
    # the library default happening to coincide with MANUAL_SEED.
    seed=MANUAL_SEED,
    report_to="none",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # type: ignore
    eval_dataset=val_dataset,  # type: ignore
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [15]:
# Fine-tune; the returned TrainOutput is displayed as the cell result.
train_output = trainer.train()
train_output

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,No log,1.230303,0.4144,10.4544
2,1.207600,1.233634,0.4246,10.8294
3,1.207600,1.27108,0.4617,10.8093
4,0.595100,1.374489,0.4549,10.1621
5,0.595100,1.461647,0.4666,10.5869
6,0.330200,1.599013,0.4655,10.1578
7,0.330200,1.717306,0.4812,10.0371
8,0.197900,1.757943,0.4766,10.1949
9,0.197900,1.839077,0.4769,10.3581
10,0.124500,1.887974,0.484,10.0763


TrainOutput(global_step=2660, training_loss=0.4673067569732666, metrics={'train_runtime': 2831.7826, 'train_samples_per_second': 30.006, 'train_steps_per_second': 0.939, 'total_flos': 6903838671986688.0, 'train_loss': 0.4673067569732666, 'epoch': 10.0})

In [16]:
trainer.save_model(output_dir="../models/bart")

## Test model


I push the fine-tuned model to the Hugging Face Hub because the model files are quite large (about 1.6 GB).

In [30]:
def save_to_hub(model, tokenizer, name: str = "pmldl1-bart", token: "str | None" = None):
    """Upload the tokenizer and model to the Hugging Face Hub under `name`.

    No secret is hard-coded. The token comes from the `token` argument, else
    from the HF_TOKEN environment variable. If neither is set,
    `huggingface_hub.login` prompts for it interactively.

    Args:
        model: object exposing `push_to_hub(name)`.
        tokenizer: object exposing `push_to_hub(name)`.
        name: target repository name on the Hub.
        token: optional Hub access token.
    """
    import os

    from huggingface_hub import login

    login(token=token or os.environ.get("HF_TOKEN"))
    tokenizer.push_to_hub(name)
    model.push_to_hub(name)


# save_to_hub(AutoModelForSeq2SeqLM.from_pretrained("../models/bart"), AutoTokenizer.from_pretrained("eugenesiow/bart-paraphrase"))

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


pytorch_model.bin:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

In [8]:
# Fine-tuned weights hosted on the Hugging Face Hub.
my_model = "dsomni/pmldl1-bart"

tokenizer = AutoTokenizer.from_pretrained(my_model)
model = AutoModelForSeq2SeqLM.from_pretrained(my_model)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.69k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/234 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/540 [00:00<?, ?B/s]

In [14]:
# Inference mode (disables dropout).
model.eval()
# Keep the decoder key/value cache enabled for `generate`. Turning it off is a
# training-time setting (e.g. for gradient checkpointing). At inference it does
# not change outputs and, depending on the transformers version, can make
# generation much slower.
model.config.use_cache = True

In [15]:
def detoxify(model, tokenizer, prompt: str, **generate_kwargs) -> str:
    """Rewrite a single toxic sentence as a non-toxic paraphrase.

    The prompt is prefixed with PREFIX and truncated to MAX_LENGTH tokens, the
    same way `preprocess_function` prepares training inputs.

    Args:
        model: seq2seq model exposing `generate`.
        tokenizer: tokenizer matching `model`.
        prompt: toxic input text.
        **generate_kwargs: extra options forwarded to `model.generate`
            (e.g. `num_beams`, `max_new_tokens`).

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    inference_request = PREFIX + prompt
    encoded = tokenizer(
        inference_request,
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    outputs = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        **generate_kwargs,
    )
    # Decoding takes no sampling options; the previous `temperature=0` passed
    # here was silently ignored. Sampling options belong in `generate_kwargs`.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


detoxify(model, tokenizer, "shut up, man")

'quiet, man.'

In [35]:
# Detoxify each test sentence, one prompt at a time.
model_answers = [
    detoxify(model, tokenizer, toxic_text)
    for toxic_text in tqdm(test_df["toxic"], total=len(test_df))
]

test_df["generated"] = model_answers

100%|██████████| 500/500 [14:22<00:00,  1.73s/it]


In [40]:
test_df.head(n=5)

Unnamed: 0,toxic,nontoxic,generated
0,It's feeding time at the fucking zoo!,it's time to eat at the zoo!,it's feeding time at the zoo.
1,Everyone here bet on the hero and lost their a...,they all took a hero and lost everything.,everyone here bet on him and they lost.
2,Then I got to come home to Melvin and his bull...,then I'm going home and Melvin's there.,then I have to come home to talk to Melvin.
3,Sara here was hoping to pick your brains.,Sara was hoping you could handle her.,Sara here was hoping to pick your brains.
4,"Oh, that's stupid. If anyone wants to tell me ...","if anyone wants to tell me what's going on, I'...","if anyone wants to talk to me, I'll be in the ..."


## Save results


In [41]:
test_df.to_csv(path_or_buf="../data/generated/bart.csv", index=False)