# Fine tunning BART


In [1]:
from transformers import (
    DataCollatorForSeq2Seq,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    AutoTokenizer,
)
import evaluate
from datasets import Dataset
import numpy as np
import pandas as pd
import torch
from torch.utils.data import random_split

from tqdm import tqdm

import warnings

In [2]:
MANUAL_SEED = 42

warnings.filterwarnings("ignore")

## Data loading and preprocessing


In [3]:
df = pd.read_csv("../data/raw/dataset_xs.csv")
print(f"{len(df)=}")
df.head()

len(df)=9462


Unnamed: 0,toxic,nontoxic
0,I like that shit.,I love it.
1,"Now, I understand you got your grievances with...","I understand you don't have to cut your bills,..."
2,Damn It!,"oh, my God."
3,"Help me, you cunt!","Aitchi, help me!"
4,Look at that shit.,look at this.


In [13]:
train_seq, val_seq, test_seq = random_split(
    range(len(df)),  # type: ignore
    [0.85, 0.1, 0.05],
    generator=torch.Generator().manual_seed(MANUAL_SEED),
)
train_indices, val_indices, test_indices = (
    list(train_seq.indices),
    list(val_seq.indices),
    list(test_seq.indices),
)
print(f"{len(train_indices)=}")
print(f"{len(val_indices)=}")
print(f"{len(test_indices)=}")

len(train_indices)=8043
len(val_indices)=946
len(test_indices)=473


## Model definition


In [11]:
checkpoint = "eugenesiow/bart-paraphrase"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

In [9]:
BATCH_SIZE = 16

PREFIX = "paraphrase following to be nontoxic: \n"
MAX_LENGTH = 128

In [7]:
def preprocess_function(data):
    inputs = [PREFIX + data_point for data_point in data["toxic"]]
    targets = data["nontoxic"]
    model_inputs = tokenizer(
        inputs, text_target=targets, max_length=MAX_LENGTH, truncation=True
    )
    return model_inputs


def post_process_text(predictions, targets):
    predictions = [pred.strip() for pred in predictions]
    targets = [label.strip() for label in targets]
    return predictions, targets


BLEU_METRIC = evaluate.load("bleu")


def compute_metrics(batch):
    predictions, targets = batch
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    targets = np.where(targets != -100, targets, tokenizer.pad_token_id)  # type: ignore
    decoded_targets = tokenizer.batch_decode(targets, skip_special_tokens=True)

    decoded_predictions, decoded_targets = post_process_text(
        decoded_predictions, decoded_targets
    )

    result = {}
    metrics = BLEU_METRIC.compute(
        predictions=decoded_predictions, references=decoded_targets
    )
    if metrics is not None:
        result.update({"bleu": metrics["bleu"]})

    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions
    ]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

In [14]:
train_dataset = Dataset.from_pandas(df.iloc[train_indices]).map(
    preprocess_function, batched=True
)
val_dataset = Dataset.from_pandas(df.iloc[val_indices]).map(
    preprocess_function, batched=True
)

Map:   0%|          | 0/8043 [00:00<?, ? examples/s]

Map:   0%|          | 0/946 [00:00<?, ? examples/s]

## Train model


In [12]:
training_args = Seq2SeqTrainingArguments(
    output_dir="../models/train_data/bart",
    evaluation_strategy="epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=10,
    predict_with_generate=True,
    report_to="none",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,  # type: ignore
    eval_dataset=val_dataset,  # type: ignore
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss,Bleu,Gen Len
1,1.4597,1.398838,0.399,12.2008
2,0.9515,1.331711,0.3942,9.9725
3,0.6794,1.351861,0.4393,10.7505
4,0.4703,1.496082,0.4403,11.6934
5,0.3357,1.632915,0.4496,11.0307
6,0.2459,1.775123,0.4539,9.9873
7,0.1844,1.883072,0.4592,10.6353
8,0.1392,1.965904,0.4503,10.9641
9,0.1048,2.019606,0.4577,10.5159
10,0.0791,2.16027,0.4566,11.4704


TrainOutput(global_step=5030, training_loss=0.4625845614768876, metrics={'train_runtime': 3270.0551, 'train_samples_per_second': 24.596, 'train_steps_per_second': 1.538, 'total_flos': 5941157476638720.0, 'train_loss': 0.4625845614768876, 'epoch': 10.0})

In [None]:
trainer.save_model("../models/bart")

## Test model


In [15]:
model = AutoModelForSeq2SeqLM.from_pretrained("../models/bart")
model.eval()
model.config.use_cache = False

In [16]:
test_df = df.iloc[test_indices]
test_df.head()

Unnamed: 0,toxic,nontoxic
7898,You know how the Romans settled this shit?,do you know how the Romans handled this?
5283,What kind of shit do you talk?,what are you talking about?
6611,Idiot. There's a storage shed near the back.,there's a warehouse in the back.
3187,Come get her! Goddamn you!,come and get her!
8466,"You fucker Hello, post office?","hello, post office?"


In [17]:
def detoxify(model, prompt):
    inference_request = PREFIX + prompt
    input_ids = tokenizer(inference_request, return_tensors="pt").input_ids
    outputs = model.generate(input_ids=input_ids)
    return tokenizer.decode(outputs[0], skip_special_tokens=True, temperature=0)


detoxify(model, "shut up, man")

'quiet, man.'

In [None]:
model_answers = []
for i, r in tqdm(test_df.iterrows(), total=len(test_df)):
    model_answers.append(detoxify(model, r["toxic"]))


test_df["generated"] = model_answers

100%|██████████| 473/473 [18:27<00:00,  2.34s/it]


In [None]:
test_df.head()

Unnamed: 0,toxic,nontoxic,generated
7898,You know how the Romans settled this shit?,do you know how the Romans handled this?,you know how the Romans decided this?
5283,What kind of shit do you talk?,what are you talking about?,what are you talking about?
6611,Idiot. There's a storage shed near the back.,there's a warehouse in the back.,there's a storage shed in the back.
3187,Come get her! Goddamn you!,come and get her!,come get her!
8466,"You fucker Hello, post office?","hello, post office?","hello, mailman?"


## Save results


In [17]:
test_df.to_csv("../data/generated/bart.csv", index=False)