<H1>Fine-Tuning BART Model on Response Generation in English</h1>

<i>vers. 10/2023</i>

In [None]:
import math
import os
import random
import sys

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW, get_polynomial_decay_schedule_with_warmup

In [None]:
# Only physical GPU 5 is made visible to this process, so index 0 below refers to it.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

# Fall back to the CPU when no CUDA device can be used.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print('Using CPU')
else:
    device = torch.device(0)
    print('Using GPU')

In [None]:
# Name of the Hugging Face dataset; it is reused later when the result files are named.
dataset_name = 'daily_dialog'

print('Loading ', dataset_name)
# Load through the variable so the log line, the data and the output file names
# cannot drift apart when the dataset is changed.
dataset = load_dataset(dataset_name)

# Each split's 'dialog' column is consumed below as a sequence of dialogues,
# where a dialogue is a sequence of utterance strings.
train_dialogues = dataset['train']['dialog']
valid_dialogues = dataset['validation']['dialog']
test_dialogues = dataset['test']['dialog']


In [None]:
# BART variant to fine-tune: "large" selects facebook/bart-large; any other value (e.g. 'base') selects facebook/bart-base.
model_size = "large" # 'base'

In [None]:
# Resolve the pretrained checkpoint and the output directory from the requested
# size; anything other than 'large' falls back to the base model.
model_name = "facebook/bart-large" if model_size == 'large' else "facebook/bart-base"
output_path = "/BART_Large" if model_size == 'large' else "/BART_Base"

# Tokenizer that matches the selected checkpoint.
tokenizer = BartTokenizer.from_pretrained(model_name)

In [5]:
# Marker character that prefixes a token which was preceded by a space.
space = 'Ġ'
# Typographic apostrophe; load_daily replaces it with the ASCII one before tokenizing.
pre_quote = '’'
# Tokens treated as "end marks"; note that the comma is part of this list.
end_marks = ['.', ',', '?', '!', '...']
quotes = ['"', '\'']
# Contraction suffixes that should be glued to the preceding apostrophe.
abbreviations = ['s', 'd', 't', 'm', 're', 'll', 've', 'S', 'D', 'T', 'M', 'Re', 'Ll', 'Ve']

# For empathetic dialogues
exclude_symbol = "_conv"
comma_symbol = "_comma_"

def process_token_list(token_list):
    """Normalize spacing and capitalization of a tokenized utterance.

    The list is modified in place while it is scanned, and a filtered copy is
    returned. The steps are:

    * capitalize the first token;
    * drop the leading space marker from end marks and contraction suffixes,
      and from an apostrophe that is followed by a contraction suffix;
    * attach opening quotes to the following token and closing quotes to the
      preceding one (tracked by a simple open/close counter);
    * capitalize the token that follows any end mark;
    * remove tokens that are empty or consist only of the space marker;
    * append '.' when the last remaining token is not an end mark.

    Args:
        token_list: list of token strings that use ``space`` as the
            leading-space marker.

    Returns:
        The cleaned list of tokens. An empty list is returned when there is
        nothing to process or when every token was filtered out.
    """
    # An empty utterance has no first token to capitalize.
    if not token_list:
        return []

    token_list[0] = token_list[0].capitalize()

    quote_count = 0
    last_index = len(token_list) - 1
    for i, token in enumerate(token_list):
        if space in token:
            if token[1:] in end_marks or token[1:] in abbreviations:
                token_list[i] = token[1:]

            if token[1:] == quotes[1]:
                if i < last_index:
                    following = token_list[i+1]
                    # startswith() instead of [0]: an earlier iteration may have
                    # reduced the following token to an empty string.
                    if following in abbreviations or (following.startswith(space) and following[1:] in abbreviations):
                        token_list[i] = token[1:]

        if token.startswith(space) and token[1:] in quotes:
            if quote_count % 2 == 1:
                # Closing quote: attach it to the preceding token.
                token_list[i] = token[1:]
                quote_count = 0
            else:
                # Opening quote: attach the following token to it.
                if i < last_index and token_list[i+1].startswith(space):
                    token_list[i+1] = token_list[i+1][1:]
                quote_count += 1

        if token in end_marks or token[1:] in end_marks:
            if i < last_index:
                # The token after an end mark always gets the space marker and a capital.
                if not token_list[i+1].startswith(space):
                    token_list[i+1] = space + token_list[i+1].capitalize()
                else:
                    token_list[i+1] = space + token_list[i+1][1:].capitalize()

    new_token_list = [token for token in token_list if token != space and len(token)>0]
    # Guard against the case where nothing survived the filter.
    if new_token_list and new_token_list[-1] not in end_marks:
        new_token_list.append(end_marks[0])

    return new_token_list

In [6]:
#CODE TO LOAD THE DATASET

def load_daily(dataset, tokenizer):
    """Clean every utterance of every dialogue and count the utterances.

    Each utterance is stripped, its typographic apostrophes are replaced by
    ASCII ones, it is tokenized, passed through ``process_token_list`` and
    converted back to a string.

    The input is not modified: the cleaned dialogues are collected in a new
    list, so ``dataset`` only has to be iterable and does not need to support
    item assignment.

    Args:
        dataset: iterable of dialogues, each an iterable of utterance strings.
        tokenizer: object providing ``tokenize`` and ``convert_tokens_to_string``.

    Returns:
        Tuple ``(dialogues, utter_num)``: the list of cleaned dialogues (each a
        list of strings) and the total number of utterances.
    """
    processed_dialogues = []
    utter_num = 0

    for dialogue in tqdm(dataset):
        new_dialogue = []
        for utter in dialogue:
            token_list = tokenizer.tokenize(utter.strip().replace(pre_quote, quotes[1]))
            token_list = process_token_list(token_list)
            new_dialogue.append(tokenizer.convert_tokens_to_string(token_list))

        processed_dialogues.append(new_dialogue)
        # Count while building instead of walking the data a second time.
        utter_num += len(new_dialogue)

    return processed_dialogues, utter_num

In [None]:
# Clean each split and record how many utterances it contains.
train_dialogues, num_train = load_daily(dataset=train_dialogues, tokenizer=tokenizer)
valid_dialogues, num_valid = load_daily(dataset=valid_dialogues, tokenizer=tokenizer)
test_dialogues, num_test = load_daily(dataset=test_dialogues, tokenizer=tokenizer)

In [None]:
# Report dialogue and utterance counts per split, one figure per line.
print(
    f"The number of train dialogues: {len(train_dialogues)}",
    f"The number of valid dialogues: {len(valid_dialogues)}",
    f"The number of test dialogues: {len(test_dialogues)}",
    f"The number of train utterances: {num_train}",
    f"The number of valid utterances: {num_valid}",
    f"The number of test utterances: {num_test}",
    sep="\n",
)

In [None]:
# Speaker markers placed in front of the utterances.
sp1_token, sp2_token = '<sp1>', '<sp2>'
max_len = 1024  # length budget used when filtering instances and as generation max_length
seed = 0
gpu = 0

# Register the speaker markers as additional special tokens of the tokenizer.
special_tokens = dict(additional_special_tokens=[sp1_token, sp2_token])
num_new_tokens = tokenizer.add_special_tokens(special_tokens)

# Vocabulary including the tokens added above; its size is used to resize the model embeddings.
vocab = tokenizer.get_vocab()
vocab_size = len(vocab)
sp1_id, sp2_id = vocab[sp1_token], vocab[sp2_token]

In [None]:
# Training / generation hyperparameters.
lr = 2e-5  # learning rate given to the AdamW optimizer
batch_size = 32  # batch size of the train / validation / test DataLoaders
num_workers = 0  # presumably the DataLoader worker count (0 = load in the main process)
num_epochs = 10  # number of epochs run by train(); also determines the scheduler length
warmup_ratio = 0.1  # fraction of the total training steps used as scheduler warm-up
last_epoch = 0  # train() defines its own local last_epoch, also starting at 0
end_command = 'Quit!'  # not referenced in the visible code; presumably for an interactive chat loop
top_p = 0.8  # nucleus-sampling threshold passed to model.generate

In [None]:
def fix_seed(seed):
    """Seed NumPy, PyTorch (CPU and all CUDA devices) and ``random`` with ``seed``."""
    for seeder in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all, random.seed):
        seeder(seed)

# Seed all RNGs up front; train() and the inference functions call fix_seed(seed) again before they run.
fix_seed(seed)

In [None]:
# Directory that train() writes its checkpoints into.
ckpt_dir = output_path + '/saved_models'
# os.makedirs creates missing parent directories, tolerates re-running the cell
# and raises if the directory cannot be created. The previous shell `mkdir`
# only printed an error, which then surfaced much later as a failing torch.save.
os.makedirs(ckpt_dir, exist_ok=True)

In [None]:
def preprocess_dialog(dialog, window_size=5+1):
    """Turn one dialogue into (context -> response) training instances.

    A sliding window of ``window_size`` consecutive utterances forms the
    context and the utterance right after the window is the target response.
    Windows start at every second utterance. Inside a window the utterances
    are prefixed alternately with the ``<sp1>`` and ``<sp2>`` markers.

    Args:
        dialog: sequence of utterance strings.
        window_size: number of context utterances per instance.

    Returns:
        List of dicts with 1-D tensors ``input_ids`` and ``attention_mask``
        (encoded context) and ``labels`` (encoded response). Instances whose
        combined token count exceeds ``max_len`` are skipped.
    """
    instances = []

    for i in range(0, len(dialog) - window_size, 2):
        window = dialog[i:i+window_size]
        window_context = ' '.join(
            (sp1_token if j % 2 == 0 else sp2_token) + " " + utterance
            for j, utterance in enumerate(window)
        )

        # NOTE(review): the target is always tagged <sp2> and has no space after
        # the marker, whereas context turns are built as "<spN> text". With an
        # even window_size (the default) the last context turn is <sp2> as well.
        # Kept unchanged; confirm that this is the intended format.
        response = sp2_token + dialog[i+window_size]

        input_ids = tokenizer.encode_plus(window_context, add_special_tokens=True, return_tensors="pt")
        decoder_input_ids = tokenizer.encode_plus(response, add_special_tokens=True, return_tensors="pt")

        # return_tensors="pt" yields tensors of shape (1, seq_len), so len() on
        # them is always 1 and the length filter could never trigger. Count the
        # tokens along the last dimension instead.
        context_len = input_ids['input_ids'].size(-1)
        response_len = decoder_input_ids['input_ids'].size(-1)
        if (context_len + response_len - 2) > max_len:  # 2 to ignore eos and bos tokens of decoder
            continue

        instances.append({
            "input_ids": input_ids["input_ids"].squeeze(0),
            "attention_mask": input_ids["attention_mask"].squeeze(0),
            # The encoded response is used directly as labels.
            "labels": decoder_input_ids['input_ids'].squeeze(0),
        })

    return instances

In [None]:
# Inspect the first cleaned training dialogue.
train_dialogues[0]

In [None]:
# Flatten the per-dialogue instance lists into one list per split.
train_instances = [
    instance
    for dialog in tqdm(train_dialogues)
    for instance in preprocess_dialog(dialog)
]

val_instances = [
    instance
    for dialog in tqdm(valid_dialogues)
    for instance in preprocess_dialog(dialog)
]

#dummy
test_instances = [
    instance
    for dialog in tqdm(test_dialogues)
    for instance in preprocess_dialog(dialog)
]

In [None]:
# Encoded context (encoder input ids) of the first training instance.
train_instances[0]['input_ids']

In [None]:
# Attention mask that belongs to the encoded context of the first training instance.
train_instances[0]['attention_mask']

In [None]:
# Encoded response of the first training instance, used as labels.
train_instances[0]['labels']

In [None]:
class DialogueDataset(Dataset):
    """Map-style dataset that serves items from a pre-built list of instances."""

    def __init__(self, instances):
        # Kept by reference; items are returned exactly as stored.
        self.instances = instances

    def __getitem__(self, idx):
        """Return the instance stored at position ``idx``."""
        return self.instances[idx]

    def __len__(self):
        """Return the number of stored instances."""
        return len(self.instances)

class PadCollate():
    """Collate helper that pads every field of a batch to the batch's longest sequence."""

    def __init__(self, pad_id):
        # Fill value used for the padded positions of ``input_ids``.
        self.pad_id = pad_id

    def pad_collate(self, batch):
        """Stack a list of instances into padded batch tensors.

        ``input_ids`` is padded with ``pad_id``, ``attention_mask`` with 0 and
        ``labels`` with -100. The result is a dict with those three keys, each
        holding a batch-first tensor.
        """
        pad = torch.nn.utils.rnn.pad_sequence
        fill_values = (
            ("input_ids", self.pad_id),
            ("attention_mask", 0),
            ("labels", -100),
        )
        return {
            key: pad(
                [torch.LongTensor(seqs[key]) for seqs in batch],
                batch_first=True,
                padding_value=fill,
            )
            for key, fill in fill_values
        }

In [None]:
#Create data

# Collator that pads each batch to its longest sequence, using the tokenizer's <pad> id for input_ids.
ppd = PadCollate(pad_id=vocab['<pad>'])

train_dataset = DialogueDataset(train_instances)
val_dataset = DialogueDataset(val_instances)
test_dataset =  DialogueDataset(test_instances)

# num_workers comes from the hyperparameter cell; it was defined there but
# never handed to the loaders. Only the training data is shuffled.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=ppd.pad_collate, num_workers=num_workers)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=ppd.pad_collate, num_workers=num_workers)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=ppd.pad_collate, num_workers=num_workers)

In [None]:
# Shapes of the encoded context and its mask for the first training instance.
print(
    train_dataset[0]['input_ids'].shape,
    train_dataset[0]['attention_mask'].shape,
    sep="\n",
)


In [None]:
# Same check for the second training instance.
print(
    train_dataset[1]['input_ids'].shape,
    train_dataset[1]['attention_mask'].shape,
    sep="\n",
)

In [None]:
# Select the device again right before the model is created (same logic as at the top of the notebook).
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print('Using CPU')
else:
    device = torch.device(0)
    print('Using GPU')

In [None]:
# Load the pretrained weights and move them to the selected device.
model = BartForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)
# Grow the embedding matrix to the tokenizer's vocabulary size, which includes the added speaker tokens.
model.resize_token_embeddings(vocab_size)

In [None]:
print("Loading the optimizer...")
# AdamW over all model parameters with the configured learning rate.
optim = torch.optim.AdamW(params=model.parameters(), lr=lr)

In [None]:
# The scheduler is stepped once per batch, so it spans epochs x batches-per-epoch steps.
num_batches = len(train_dataloader)
total_train_steps = num_epochs * num_batches
# The first warmup_ratio share of those steps is used for warm-up.
warmup_steps = int(warmup_ratio * total_train_steps)

# Polynomial (power 2) decay schedule with warm-up.
sched = get_polynomial_decay_schedule_with_warmup(
    optim,
    power=2,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_train_steps,
)

# TensorBoard writer used by train() for the loss / perplexity curves.
writer = SummaryWriter()


In [None]:
def validation():
    """Evaluate the global ``model`` on ``val_dataloader``.

    Returns:
        Tuple ``(valid_loss, valid_ppl)``: the mean of the per-batch losses and
        the mean of the per-batch perplexities (``exp`` of each batch loss).
        Infinite batch perplexities and a NaN mean are replaced by ``1e+8``.
    """
    print("Validation processing...")
    model.eval()

    batch_losses, batch_ppls = [], []
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["labels"].to(device),
            )

            batch_loss = outputs.loss.detach()
            batch_losses.append(batch_loss.item())

            # Cap infinite perplexities so they do not turn the mean into inf.
            batch_ppl = torch.exp(batch_loss).item()
            batch_ppls.append(1e+8 if math.isinf(batch_ppl) else batch_ppl)

    valid_loss = np.mean(batch_losses)
    valid_ppl = np.mean(batch_ppls)
    if math.isnan(valid_ppl):
        valid_ppl = 1e+8

    return valid_loss, valid_ppl

In [None]:
def train():
    """Fine-tune the global ``model`` for ``num_epochs`` epochs.

    Every epoch runs one pass over ``train_dataloader`` (optimizer and
    scheduler are stepped once per batch), logs the mean batch loss and mean
    batch perplexity to TensorBoard, then runs ``validation()``. Whenever the
    validation loss improves, a checkpoint with the model, optimizer and
    scheduler state is written to ``ckpt_dir``.

    Uses the notebook globals ``model``, ``optim``, ``sched``, ``writer``,
    ``train_dataloader``, ``device``, ``num_epochs``, ``seed`` and ``ckpt_dir``.
    """
    print('Number of epochs: ', num_epochs)
    fix_seed(seed)  # Fix seed before training

    # Fail fast: make sure checkpoints can be written before spending a whole
    # epoch of training, instead of crashing at the first torch.save.
    os.makedirs(ckpt_dir, exist_ok=True)

    print("Training starts.")

    best_loss = sys.float_info.max
    last_epoch = 0

    start_epoch = last_epoch + 1

    for epoch in range(start_epoch, start_epoch+num_epochs):
        model.train()

        print("#"*50 + f"Epoch: {epoch}" + "#"*50)
        train_losses = []
        train_ppls = []

        for batch in tqdm(train_dataloader):

            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optim.zero_grad()

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels = labels
            )

            loss = outputs.loss
            loss.backward()
            optim.step()
            sched.step()

            train_losses.append(loss.detach())
            ppl = torch.exp(loss.detach())
            train_ppls.append(ppl)

        # Convert to Python floats only after the epoch; infinite perplexities are capped.
        train_losses = [loss.item() for loss in train_losses]
        train_ppls = [ppl.item() if not math.isinf(ppl.item()) else 1e+8 for ppl in train_ppls]
        train_loss = np.mean(train_losses)
        train_ppl = np.mean(train_ppls)
        print(f"Train loss: {train_loss} || Train perplexity: {train_ppl}")

        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("PPL/train", train_ppl, epoch)

        last_epoch += 1

        valid_loss, valid_ppl = validation()

        if valid_loss < best_loss:
            best_loss = valid_loss
            state_dict = {
                'model_state_dict': model.state_dict(),
                'optim_state_dict': optim.state_dict(),
                'sched_state_dict': sched.state_dict(),
                'loss': best_loss,
                'epoch': last_epoch
            }

            # Build the path once so the saved file and the printed path cannot differ.
            ckpt_path = f"{ckpt_dir}/best_ckpt_epoch={epoch}_valid_loss={round(best_loss, 4)}.ckpt"
            torch.save(state_dict, ckpt_path)
            print("*"*10 + "Current best checkpoint is saved." + "*"*10)
            print(ckpt_path)

        print(f"Best valid loss: {best_loss}")
        print(f"Valid loss: {valid_loss} || Valid perplexity: {valid_ppl}")

        writer.add_scalar("Loss/valid", valid_loss, epoch)
        writer.add_scalar("PPL/valid", valid_ppl, epoch)

        writer.add_scalars("Losses", {
            'train': train_loss,
            'valid': valid_loss,
        }, epoch)
        writer.add_scalars("PPLs", {
            'train': train_ppl,
            'valid': valid_ppl,
        }, epoch)

    # Push any buffered TensorBoard events to disk now that training is done.
    writer.flush()
    print("Training finished!")

In [None]:
train()

<h2>Generate One Response: NO-CD</h2>

In [None]:
def infer(window_size=5):
    """Sample one response per context window of every test dialogue.

    For each dialogue in ``test_dialogues`` a window of ``window_size``
    utterances (starting at every second utterance) is joined into a context
    string with alternating ``<sp1>`` / ``<sp2>`` markers, and one response is
    sampled with nucleus sampling (``top_p``).

    Args:
        window_size: number of context utterances per window.

    Returns:
        Tuple ``(generated_responses, actual_responses, inputs)`` of equally
        long lists: the decoded generations, the reference utterance that
        follows each window, and the context strings.
    """
    model.eval()
    fix_seed(seed)

    generated_responses = []
    actual_responses = []
    inputs = []

    with torch.no_grad():

        for dialog in tqdm(test_dialogues):

            for i in range(0, len(dialog) - window_size, 2): #In steps of 2

                window = dialog[i:i+window_size]
                window_context = ' '.join(
                    (sp1_token if j % 2 == 0 else sp2_token) + " " + utterance
                    for j, utterance in enumerate(window)
                )

                # A single sequence needs no padding: the former
                # padding='max_length' made the encoder process 512 positions
                # per call, all but the real tokens masked out. Truncation
                # still caps the context at 512 tokens. Because generation is
                # sampled, outputs should match up to floating-point effects.
                encodings = tokenizer.encode_plus(window_context, add_special_tokens=True, max_length=512, truncation=True, return_tensors="pt")
                input_ids = encodings['input_ids'].to(device)
                attention_mask = encodings['attention_mask'].to(device)

                output_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=max_len, do_sample=True, top_p=top_p).squeeze(0)

                response = tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

                # The reference is the utterance that directly follows the window.
                actual_response = dialog[i+window_size]

                generated_responses.append(response)
                actual_responses.append(actual_response)
                inputs.append(window_context)

    return generated_responses, actual_responses, inputs


In [None]:
#GENERATE RESPONSES

# Number of context utterances given to the model at inference time.
# NOTE(review): preprocess_dialog is called with its default window_size (5+1)
# when the training instances are built, while inference uses 3 — confirm that
# this mismatch is intended.
WINDOW_SIZE = 3

generated_responses, actual_responses, inputs = infer(WINDOW_SIZE)

In [None]:
# Every generated response must have exactly one reference.
assert len(actual_responses) == len(generated_responses)
print(len(generated_responses), len(actual_responses), sep="\n")

In [None]:
#SAVE RESPONSES

selected_model = 'bart'

# The CSV is written below output_path, so make sure that directory exists.
os.makedirs(output_path, exist_ok=True)

# Base name (without extension) shared by the CSV written here and the metrics report written later.
file_generated = f"{output_path}/{selected_model}_epochs_{num_epochs}_generated_responses_{dataset_name}_window{WINDOW_SIZE}_new"

# One row per context window: context string, reference response, generated response.
# pandas (pd) is imported in the import cell at the top of the notebook.
df = pd.DataFrame({'input': inputs, 'actual_responses': actual_responses, 'generated_responses': generated_responses} )
df.to_csv(file_generated+'.csv', index = False, encoding = 'UTF-8')

In [None]:
#LOAD METRICS

# Load the four evaluation metrics in one go; the order of names matches the order of targets.
sacrebleu, rouge, bertscore, chrf = (
    evaluate.load(metric_name)
    for metric_name in ("sacrebleu", "rouge", "bertscore", "chrf")
)

In [None]:
# Refs must be in a list of list of str.
# Only wrap entries that are not lists yet, so re-running this cell does not nest the references a second time.
actual_responses = [res if isinstance(res, list) else [res] for res in actual_responses]

print(generated_responses[:5])
print(actual_responses[:5])

In [None]:
# Score the generated responses against the references with each metric.
bleu_score = sacrebleu.compute(references=actual_responses, predictions=generated_responses)

rouge_score = rouge.compute(references=actual_responses, predictions=generated_responses)

bert_score = bertscore.compute(references=actual_responses, predictions=generated_responses, lang='en')
# BERTScore returns one value per example; average each of the three lists.
precision, recall, f1 = (bert_score[key] for key in ('precision', 'recall', 'f1'))
avg_precision_bert, avg_recall_bert, avg_f1_bert = (
    sum(scores) / len(scores) for scores in (precision, recall, f1)
)

chrf_score = chrf.compute(references=actual_responses, predictions=generated_responses)

In [None]:
print('Saving results...')
# The context manager closes the file even if one of the writes raises.
# NOTE(review): the first format string ends in '\n ' so the next line of the
# report starts with a space; kept as-is to leave the file content unchanged.
with open(file_generated+".txt", "w") as fout:
    fout.write('Bleu score: {} \n '.format(bleu_score)) #Range from 0 to 100
    fout.write('Rouge score: {} \n'.format(rouge_score))
    fout.write('Bert score:  {} \n'.format(bert_score))
    fout.write('Avg precision Bert score: {} \n'.format(avg_precision_bert))
    fout.write('Avg recall Bert score: {} \n'.format(avg_recall_bert))
    fout.write('Avg f1 Bert score: {} \n'.format(avg_f1_bert))
    fout.write('chrf score: {} \n'.format(chrf_score))


<h2>Generate Multiple Candidates: CD1 / CD2 </h2>

In [None]:
def infer_bart(window_size=WINDOW_SIZE, n = 1):
    """Sample ``n`` candidate responses per context window of every test dialogue.

    Windows are built as in training-time preprocessing: ``window_size``
    utterances starting at every second utterance, joined with alternating
    ``<sp1>`` / ``<sp2>`` markers. Candidates are drawn with nucleus sampling
    (``top_p``).

    Args:
        window_size: number of context utterances per window.
        n: number of candidates to sample per window.

    Returns:
        Tuple ``(generated_responses, actual_responses, inputs)`` of equally
        long lists. Each generated entry is a list of ``n`` strings when
        ``n > 1`` and a single string otherwise; ``actual_responses`` holds the
        utterance following each window and ``inputs`` the raw windows (lists
        of utterances).
    """
    model.eval()
    fix_seed(seed)

    generated_responses = []
    actual_responses = []
    inputs = []

    with torch.no_grad():

        for dialog in tqdm(test_dialogues):

            for i in range(0, len(dialog) - window_size, 2): #In steps of 2

                window = dialog[i:i+window_size]
                inputs.append(window)

                window_context = ' '.join(
                    (sp1_token if j % 2 == 0 else sp2_token) + " " + utterance
                    for j, utterance in enumerate(window)
                )

                # A single sequence needs no padding: the former
                # padding='max_length' made the encoder process 512 positions
                # per call, all but the real tokens masked out. Truncation
                # still caps the context at 512 tokens. Because generation is
                # sampled, outputs should match up to floating-point effects.
                encodings = tokenizer.encode_plus(window_context, add_special_tokens=True, max_length=512, truncation=True, return_tensors="pt")
                input_ids = encodings['input_ids'].to(device)
                attention_mask = encodings['attention_mask'].to(device)

                if n > 1:
                    # Several candidates: decode each returned sequence separately.
                    output_ids = model.generate(
                        input_ids=input_ids,
                        temperature = 1.0,
                        attention_mask=attention_mask,
                        max_length=max_len,
                        do_sample=True,
                        top_p=top_p,
                        num_return_sequences=n,
                        )
                    response = [tokenizer.decode(r, skip_special_tokens=True, clean_up_tokenization_spaces=True) for r in output_ids]

                else:
                    # Single candidate: drop the batch dimension and decode to one string.
                    output_ids = model.generate(
                        input_ids=input_ids,
                        attention_mask=attention_mask,
                        max_length=max_len,
                        do_sample=True,
                        top_p=top_p).squeeze(0)
                    response = tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

                # The reference is the utterance that directly follows the window.
                actual_response = dialog[i+window_size]

                generated_responses.append(response)
                actual_responses.append(actual_response)

    return generated_responses, actual_responses, inputs

In [None]:
#GENERATE MULTIPLE CANDIDATES
# Number of candidate responses sampled per context window.
N = 10
generated_responses, actual_responses, contexts = infer_bart(n=N, window_size=WINDOW_SIZE)

In [None]:
#SAVE RESULTS

import pandas as pd
output_path = 'path/to/results'
# NOTE(review): this rebinds model_name, which held the pretrained checkpoint
# name used to load the model; re-running earlier cells afterwards will see 'bart'.
model_name = 'bart'

# The results are written below output_path, so make sure that directory exists.
os.makedirs(output_path, exist_ok=True)

# os.path.join inserts the path separator that plain concatenation dropped
# (which produced 'path/to/resultsbart_...').
new_file_generated = os.path.join(
    output_path,
    model_name + "_generated_multiple_responses_" + dataset_name + '_window' + str(WINDOW_SIZE) + '_N' + str(N),
)
print(new_file_generated)

# One row per context window, with one column per sampled candidate.
df = pd.DataFrame({'inputs': contexts,'actual responses':actual_responses})
for res in range(N):
    df['generated_responses_'+str(res)] = [x[res] for x in generated_responses]

df.to_csv(new_file_generated+'.csv', index = False, encoding = 'UTF-8')

In [None]:
print('Saving results...')
# NOTE(review): this writes the metric values computed for the single-response
# run to file_generated again; no metrics are recomputed for the multiple
# candidates in the visible code — confirm this is intended.
# The context manager closes the file even if one of the writes raises.
with open(file_generated+".txt", "w") as fout:
    fout.write('Bleu score: {} \n '.format(bleu_score)) #Range from 0 to 100
    fout.write('Rouge score: {} \n'.format(rouge_score))
    fout.write('Bert score:  {} \n'.format(bert_score))
    fout.write('Avg precision Bert score: {} \n'.format(avg_precision_bert))
    fout.write('Avg recall Bert score: {} \n'.format(avg_recall_bert))
    fout.write('Avg f1 Bert score: {} \n'.format(avg_f1_bert))
    fout.write('chrf score: {} \n'.format(chrf_score))