<h1>Fine-Tuning GPT-2 / DialoGPT Models for Response Generation in English</h1>

<i>vers. 10/2023</i>

## Import required libraries

In [48]:
#!pip install datasets
#!pip install transformers
#!pip install evaluate
#!pip install rouge_score bert_score sacrebleu

In [49]:
from datasets import list_datasets, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, GPT2Tokenizer, GPT2LMHeadModel, get_polynomial_decay_schedule_with_warmup
from torch.nn import functional as F
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from itertools import chain
import torch
import math
import numpy as np
import random
import datasets
import pandas as pd

In [None]:
#For networking purposes
# SECURITY NOTE(review): setting CURL_CA_BUNDLE to '' is a common workaround that
# presumably disables TLS certificate verification for HTTP clients honouring it
# (e.g. model/dataset downloads); the urllib3 call below only silences the resulting
# InsecureRequestWarning. Remove both once a proper CA bundle is available.
# `sys` is also needed later by train() (sys.float_info).
import os, sys
os.environ['CURL_CA_BUNDLE'] = ''
# Expose only physical GPU 5 to this process; it is then addressed as index 0,
# which is why the parameters cell sets gpu = 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

import urllib3
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

In [50]:
selected_model = "gpt2" # "dialoGPT" #
model_size = "large"  # "base" #
dataset_name = 'daily_dialog'
WINDOW_SIZE = 3

## Tokenizer

In [51]:
# Resolve the Hugging Face checkpoint id and the output directory from the
# `selected_model` / `model_size` switches, then load the matching tokenizer.
# NOTE(review): output paths are absolute (rooted at `/`); confirm the process
# may write there, or switch to relative directories.
if selected_model == 'dialoGPT':
    if model_size == "large":
        # "large" maps to the DialoGPT *medium* checkpoint.
        model_name = 'microsoft/DialoGPT-medium'
        output_path = '/dialoGPT_Medium'
    else:
        model_name = 'microsoft/DialoGPT-small'
        output_path = '/dialoGPT_Small'

    tokenizer = AutoTokenizer.from_pretrained(model_name)

elif selected_model == 'gpt2':
    if model_size == "large":
        model_name = 'gpt2-large'
        output_path = '/GPT2_Large'
    else:
        model_name = 'gpt2'
        output_path = '/GPT2_Small'

    tokenizer = GPT2Tokenizer.from_pretrained(model_name)

else:
    # Every later cell needs `tokenizer`, `model_name` and `output_path`: fail here
    # with a clear message instead of a NameError further down.
    raise ValueError(f"Unknown selected_model {selected_model!r}; expected 'dialoGPT' or 'gpt2'")

In [52]:
#Parameters

# Special tokens marking the start of a sequence and each speaker's turn.
bos_token = '<bos>'
sp1_token = '<sp1>'   # speaker 1
sp2_token = '<sp2>'   # speaker 2

# Sequence-length and reproducibility settings.
max_turns = 5         # NOTE(review): not referenced by any visible cell
max_len = 1024        # token budget per instance; later capped to the model's context size
seed = 0
gpu = 0               # used as the CUDA device index (cuda:{gpu})

In [None]:
#Tokeniser
# Register <bos> and the two speaker markers as special tokens so the tokenizer
# never splits them.
special_tokens = {'bos_token': bos_token,
                'additional_special_tokens': [sp1_token, sp2_token]}

eos_token = tokenizer.eos_token

#Add pad token
# Padding reuses the existing EOS token (PadCollate also pads input_ids with eos_id),
# so presumably no new vocabulary entry is created by this call.
pad_token = eos_token
tokenizer.add_special_tokens({'pad_token':pad_token})

# NOTE(review): num_new_tokens is not used afterwards; the model embeddings are
# resized from vocab_size instead.
num_new_tokens = tokenizer.add_special_tokens(special_tokens)


# The vocabulary lookups below rely on get_vocab() including the tokens just added;
# vocab_size is what the model's embedding matrix is resized to.
vocab = tokenizer.get_vocab()
vocab_size = len(vocab)
bos_id = vocab[bos_token]
eos_id = vocab[eos_token]
sp1_id = vocab[sp1_token]
sp2_id = vocab[sp2_token]

In [None]:
# Optimisation hyper-parameters.
lr = 2e-5
batch_size = 8
num_epochs = 10
warmup_ratio = 0.1      # fraction of all training steps used for LR warm-up
num_workers = 1         # NOTE(review): not passed to the DataLoaders
last_epoch = 0          # NOTE(review): train() keeps its own local last_epoch

# Generation settings.
top_p = 0.8             # nucleus-sampling threshold passed to model.generate
end_command = 'Quit!'   # NOTE(review): not referenced by any visible cell

In [53]:
def fix_seed(seed):
    """Seed every random generator used in this notebook.

    Covers Python's `random`, NumPy's legacy global generator, the PyTorch CPU
    generator and the generators of all CUDA devices.

    Args:
        seed: integer seed applied to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Seed all generators up front so the following cells start from a reproducible RNG state.
fix_seed(seed)

In [54]:
# GPT-2 byte-level BPE marks a token that follows a space with this prefix character.
space = 'Ġ'
pre_quote = '’'  # typographic apostrophe; replaced by a plain quote before tokenising
end_marks = ['.', ',', '?', '!', '...']
quotes = ['"', '\'']
abbreviations = ['s', 'd', 't', 'm', 're', 'll', 've', 'S', 'D', 'T', 'M', 'Re', 'Ll', 'Ve']

# For empathetic dialogues
# NOTE(review): the two symbols below are not referenced by any visible cell.
exclude_symbol = "_conv"
comma_symbol = "_comma_"


def process_token_list(token_list):
    """Normalise spacing and casing of a GPT-2 token list.

    - Capitalises the first token.
    - Glues space-prefixed punctuation and contraction suffixes to the previous token.
    - Removes the space after an opening quote and before a closing quote.
    - Capitalises (and space-prefixes) the token following sentence punctuation.
    - Ensures the result ends with punctuation from `end_marks` ('.' is appended).

    Args:
        token_list: list of BPE token strings (as from tokenizer.tokenize).
            It is modified in place while processing.

    Returns:
        A new list without empty or bare-space tokens. An empty input gives [].
    """
    if not token_list:
        # Nothing to normalise (e.g. an utterance that is empty after strip()).
        return []

    token_list[0] = token_list[0].capitalize()

    quote_count = 0
    for i, token in enumerate(token_list):
        if not token:
            # A bare-space token emptied by the quote handling below would raise
            # IndexError on token[0]; it is dropped by the final filter anyway.
            continue

        if space in token:
            # "Ġ." -> ".", "Ġs" -> "s": punctuation / contraction suffixes attach to the previous word.
            if token[1:] in end_marks or token[1:] in abbreviations:
                token_list[i] = token[1:]

            # A space-prefixed apostrophe directly followed by a contraction suffix is attached too.
            if token[1:] == quotes[1]:
                if i < len(token_list) - 1:
                    if token_list[i+1] in abbreviations or (token_list[i+1][0] == space and token_list[i+1][1:] in abbreviations):
                        token_list[i] = token[1:]

        # Quote pairing: a closing quote loses its own leading space, an opening
        # quote removes the space of the token after it.
        if token[0] == space and token[1:] in quotes:
            if quote_count % 2 == 1:
                token_list[i] = token[1:]
                quote_count = 0
            else:
                if i < len(token_list) - 1 and token_list[i+1][0] == space:
                    token_list[i+1] = token_list[i+1][1:]
                quote_count += 1

        # After sentence punctuation the next token gets a leading space and a capital letter.
        if token in end_marks or token[1:] in end_marks:
            if i < len(token_list) - 1:
                if token_list[i+1][0] != space:
                    token_list[i+1] = space + token_list[i+1].capitalize()
                else:
                    token_list[i+1] = space + token_list[i+1][1:].capitalize()

    new_token_list = [token for token in token_list if token != space and len(token) > 0]
    if new_token_list and new_token_list[-1] not in end_marks:
        new_token_list.append(end_marks[0])

    return new_token_list

## Load dataset

In [None]:
# Load the dataset named in the configuration cell and extract the raw dialogue
# column of each split. NOTE(review): the split names and the 'dialog' column
# match daily_dialog; other datasets may use different names.
print('Loading ', dataset_name)
dataset = load_dataset(dataset_name)
train_dialogues = dataset['train']['dialog']
valid_dialogues = dataset['validation']['dialog']
test_dialogues = dataset['test']['dialog']


In [56]:

def load_daily(dataset, tokenizer):
    """Clean every utterance of every dialogue and count the utterances.

    For each utterance, this function:
    - strips surrounding whitespace;
    - replaces typographic apostrophes with a plain quote;
    - tokenises the text;
    - normalises it with process_token_list;
    - converts it back to a string.

    Args:
        dataset: iterable of dialogues, each an iterable of utterance strings.
        tokenizer: tokenizer providing `tokenize` and `convert_tokens_to_string`.

    Returns:
        (dialogues, utter_num): a new list of cleaned dialogues in the input order,
        and the total number of utterances across them.
    """
    # Build a new list instead of assigning back into `dataset`. This leaves the
    # input untouched and also accepts inputs without item assignment
    # (NOTE(review): e.g. column objects returned by recent `datasets` versions).
    cleaned_dialogues = []
    for dialogue in tqdm(dataset):
        new_dialogue = []
        for utter in dialogue:
            token_list = tokenizer.tokenize(utter.strip().replace(pre_quote, quotes[1]))
            token_list = process_token_list(token_list)
            new_dialogue.append(tokenizer.convert_tokens_to_string(token_list))
        cleaned_dialogues.append(new_dialogue)

    utter_num = sum(len(dialogue) for dialogue in cleaned_dialogues)

    return cleaned_dialogues, utter_num

In [None]:
# Clean all three splits; each call also returns that split's utterance count.
train_dialogues, num_train = load_daily(train_dialogues, tokenizer)
valid_dialogues, num_valid = load_daily(valid_dialogues, tokenizer)
test_dialogues, num_test = load_daily(test_dialogues, tokenizer)

In [None]:
print(f"The number of train dialogues: {len(train_dialogues)}")
print(f"The number of valid dialogues: {len(valid_dialogues)}")
print(f"The number of test dialogues: {len(test_dialogues)}")

print(f"The number of train utterances: {num_train}")
print(f"The number of valid utterances: {num_valid}")
print(f"The number of test utterances: {num_test}")

In [59]:
# Exploration: training dialogues with at least 6 utterances.
all_len = [d for d in train_dialogues if len(d)>= 6]

In [None]:
# How many training dialogues have at least 6 utterances.
len(all_len)

In [None]:
# Directory where train() saves its best checkpoints.
ckpt_dir = output_path + '/saved_models'
# makedirs also creates a missing `output_path`, is a no-op when the directory
# already exists (safe to re-run), and does not go through a shell.
os.makedirs(ckpt_dir, exist_ok=True)

In [62]:
def preprocess_dialog(dialog, window_size=WINDOW_SIZE + 1):
    """Turn one dialogue into sliding-window training instances.

    Each window holds `window_size` consecutive utterances. The first
    `window_size - 1` are context; the last is the response the model learns to
    generate. By default that is WINDOW_SIZE context turns plus one response.
    Windows start at even offsets, so every window's first turn is typed as
    speaker 1.

    Sequence layout:
        input_ids      = <bos> <sp> u_1 <sp> u_2 ... <sp> response <eos>
        token_type_ids = speaker id of the turn each position belongs to
        labels         = -100 everywhere except the response tokens and <eos>

    Args:
        dialog: list of utterance strings.
        window_size: utterances per window (context + response).

    Returns:
        List of dicts with LongTensor "input_ids", "token_type_ids" and "labels".
        Windows longer than max_len tokens are skipped.
    """
    instances = []

    # +1 so the window ending on the last utterance is included.
    for i in range(0, len(dialog) - window_size + 1, 2):
        window = dialog[i:i + window_size]
        window_context = []
        for j, utterance in enumerate(window):
            sp_id = sp1_id if j % 2 == 0 else sp2_id
            window_context.append([sp_id] + tokenizer.encode(utterance))

        input_ids = [bos_id] + list(chain.from_iterable(window_context)) + [eos_id]

        if len(input_ids) > max_len:
            print(f"Skipping window at offset {i}: {len(input_ids)} tokens > max_len={max_len}")
            continue

        # Every turn is typed with its speaker id. <bos> takes the first speaker's type;
        # <eos> takes the response speaker's type (the response's first id is its speaker token).
        response_sp_id = window_context[-1][0]
        token_type_ids = [[sp1_id] * len(ctx) if c % 2 == 0 else [sp2_id] * len(ctx)
                          for c, ctx in enumerate(window_context)]
        token_type_ids = [sp1_id] + list(chain.from_iterable(token_type_ids)) + [response_sp_id]
        assert len(input_ids) == len(token_type_ids)

        # Only the response utterance (not its speaker token) and <eos> are supervised.
        # Labels stay aligned with input_ids; GPT2LMHeadModel performs the shift itself.
        last = len(window_context) - 1
        labels = [[-100] * len(ctx) if c < last else [-100] + ctx[1:]
                  for c, ctx in enumerate(window_context)]
        assert labels[-1][1:] == window_context[-1][1:]
        labels = [-100] + list(chain.from_iterable(labels)) + [eos_id]
        assert len(input_ids) == len(labels)

        instances.append({
            "input_ids": torch.LongTensor(input_ids),
            "token_type_ids": torch.LongTensor(token_type_ids),
            "labels": torch.LongTensor(labels),
        })

    return instances

In [None]:
# Expand every dialogue of the train / validation splits into its sliding-window
# instances. The test split is not windowed here; the inference cells read
# test_dialogues directly.
debug = 0  # NOTE(review): not referenced by any visible cell

train_instances = [
    instance
    for dialog in tqdm(train_dialogues)
    for instance in preprocess_dialog(dialog)
]

val_instances = [
    instance
    for dialog in tqdm(valid_dialogues)
    for instance in preprocess_dialog(dialog)
]


In [None]:
# Show the vocabulary ids of the special tokens used to build the sequences.
print('bos_id: ', bos_id, 'eos_id: ', eos_id, 'sp1_id: ', sp1_id, 'sp2_id: ', sp2_id )

In [None]:
# Length of the first training instance (1-D tensor, one id per token).
train_instances[0]['input_ids'].shape

In [None]:
# Must match the input_ids shape (asserted in preprocess_dialog).
train_instances[0]['token_type_ids'].shape

In [None]:
# Must match the input_ids shape (asserted in preprocess_dialog).
train_instances[0]['labels'].shape

In [68]:
class DialogueDataset(Dataset):
    """Map-style Dataset over an in-memory list of pre-built instance dicts."""

    def __init__(self, instances):
        self.instances = instances

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, idx):
        return self.instances[idx]


class PadCollate:
    """Batch collator that right-pads variable-length instances.

    input_ids and token_type_ids are padded with `eos_id`. labels are padded with
    -100 so the padded positions are ignored by the loss.
    """

    def __init__(self, eos_id):
        self.eos_id = eos_id

    def pad_collate(self, batch):
        """Turn a list of instance dicts into a dict of (batch, longest) LongTensors."""
        def stack(key, fill):
            sequences = [torch.LongTensor(item[key]) for item in batch]
            return torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=fill)

        return {
            "input_ids": stack("input_ids", self.eos_id),
            "token_type_ids": stack("token_type_ids", self.eos_id),
            "labels": stack("labels", -100),
        }

In [None]:
# Peek at the first cleaned training dialogue.
train_dialogues[0]

In [None]:
# Use the configured GPU when CUDA is available; otherwise fall back to the CPU.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print('Using CPU')
else:
    device = torch.device(f"cuda:{gpu}")
    print('Using GPU')

In [None]:
print("Loading the model: ", selected_model)

if selected_model == 'dialoGPT':
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
elif selected_model == 'gpt2':
    model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
else:
    print('No model')

model.resize_token_embeddings(vocab_size)
max_len = min(max_len, model.config.n_ctx)

In [72]:
#Load from checkpoint
# ckpt = torch.load("/content/saved_models/best_ckpt_epoch=5_valid_loss=2.583.ckpt", map_location=device)
# model.load_state_dict(ckpt['model_state_dict'])

In [None]:
# Load optimizer
# AdamW over all model parameters at the configured learning rate.
print("Loading the optimizer...")
optim = torch.optim.AdamW(model.parameters(), lr=lr)

In [74]:
#Create data loaders
# NOTE(review): `num_workers` from the hyper-parameter cell is not passed here,
# so batches are loaded in the main process (DataLoader default).

ppd = PadCollate(eos_id=eos_id)

train_dataset = DialogueDataset(train_instances)
val_dataset = DialogueDataset(val_instances)
# test_dataset =  DialogueDataset(test_instances)

# Training batches are shuffled; validation order is fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=ppd.pad_collate)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=ppd.pad_collate)
# test_dataloader = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=ppd.pad_collate)

In [75]:
# Calculate total training steps
# train() calls sched.step() once per batch, so the schedule spans every batch of
# every epoch. The first `warmup_ratio` of those steps are warm-up; afterwards the
# LR decays polynomially (power=2), per the transformers scheduler documentation.

num_batches = len(train_dataloader)
total_train_steps = num_epochs * num_batches
warmup_steps = int(warmup_ratio * total_train_steps)

sched = get_polynomial_decay_schedule_with_warmup(
    optim,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_train_steps,
    power=2
)

# TensorBoard logger used by train().
writer = SummaryWriter()

In [76]:
def validation():
    """Evaluate `model` on `val_dataloader` without gradient tracking.

    Returns:
        (valid_loss, valid_ppl): the mean per-batch loss, and the mean per-batch
        perplexity exp(loss). Infinite per-batch perplexities count as 1e8, and a
        NaN mean is replaced by 1e8.
    """
    print("Validation processing...")
    model.eval()

    batch_losses = []
    batch_ppls = []
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                token_type_ids=batch["token_type_ids"].to(device),
                labels=batch["labels"].to(device),
            )
            batch_loss = outputs[0].detach()
            batch_losses.append(batch_loss)
            batch_ppls.append(torch.exp(batch_loss))

        loss_values = [value.item() for value in batch_losses]
        ppl_values = []
        for ppl_tensor in batch_ppls:
            ppl_value = ppl_tensor.item()
            ppl_values.append(1e+8 if math.isinf(ppl_value) else ppl_value)

        valid_loss = np.mean(loss_values)
        valid_ppl = np.mean(ppl_values)
        if math.isnan(valid_ppl):
            valid_ppl = 1e+8

    return valid_loss, valid_ppl

In [77]:
def train():
    """Fine-tune `model` for `num_epochs` epochs, validating after each one.

    Whenever the validation loss improves, the model, optimizer and scheduler
    state are saved under `ckpt_dir`. Per-epoch loss and perplexity are logged to
    TensorBoard through `writer`. Relies on these notebook globals: model, optim,
    sched, device, train_dataloader, ckpt_dir, writer, num_epochs and seed.
    """

    fix_seed(seed)  # Fix seed before training
    print("Training starts.")

    best_loss = sys.float_info.max
    # Local counter that shadows the global `last_epoch`; it always starts at 0,
    # so resuming from a checkpoint is not handled here.
    last_epoch= 0

    start_epoch = last_epoch +1
    for epoch in range(start_epoch, start_epoch+num_epochs):
        model.train()

        print(f"#"*50 + f"Epoch: {epoch}" + "#"*50)
        train_losses = []
        train_ppls = []
        for i, batch in enumerate(tqdm(train_dataloader)):

            input_ids = batch["input_ids"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            labels = batch["labels"].to(device)

            optim.zero_grad()

            outputs = model(
                input_ids=input_ids,
                token_type_ids = token_type_ids,
                labels = labels
            )

            loss = outputs.loss
            loss.backward()
            optim.step()
            # The LR schedule advances once per batch (it was sized with total_train_steps).
            sched.step()

            train_losses.append(loss.detach())
            # Per-batch perplexity; the epoch figure below is the mean of these values.
            ppl = torch.exp(loss.detach())
            train_ppls.append(ppl)

        train_losses = [loss.item() for loss in train_losses]
        # Infinite per-batch perplexities are counted as 1e8 before averaging.
        train_ppls = [ppl.item() if not math.isinf(ppl.item()) else 1e+8 for ppl in train_ppls]
        train_loss = np.mean(train_losses)
        train_ppl = np.mean(train_ppls)
        print(f"Train loss: {train_loss} || Train perplexity: {train_ppl}")

        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("PPL/train", train_ppl, epoch)

        last_epoch += 1

        valid_loss, valid_ppl = validation()

        # Keep a checkpoint only when validation loss improves; each improvement
        # writes a new file (older ones are not deleted).
        if valid_loss < best_loss:
            best_loss = valid_loss
            state_dict = {
                'model_state_dict': model.state_dict(),
                'optim_state_dict': optim.state_dict(),
                'sched_state_dict': sched.state_dict(),
                'loss': best_loss,
                'epoch': last_epoch
            }

            torch.save(state_dict, f"{ckpt_dir}/best_ckpt_epoch={epoch}_valid_loss={round(best_loss, 4)}.ckpt")
            print("*"*10 + "Current best checkpoint is saved." + "*"*10)
            print(f"{ckpt_dir}/best_ckpt_epoch={epoch}_valid_loss={round(best_loss, 4)}.ckpt")

        print(f"Best valid loss: {best_loss}")
        print(f"Valid loss: {valid_loss} || Valid perplexity: {valid_ppl}")

        writer.add_scalar("Loss/valid", valid_loss, epoch)
        writer.add_scalar("PPL/valid", valid_ppl, epoch)

        # Combined train/valid curves on shared TensorBoard charts.
        writer.add_scalars("Losses", {
            'train': train_loss,
            'valid': valid_loss,
        }, epoch)
        writer.add_scalars("PPLs", {
            'train': train_ppl,
            'valid': valid_ppl,
        }, epoch)

    print("Training finished!")

In [None]:
# Run the full fine-tuning loop (saves the best checkpoints under ckpt_dir).
train()

<h2>Generate One Response: NO-CD</h2>

In [None]:
def infer(window_size=5):
    """Sample one response for every sliding context window of the test dialogues.

    Windows of `window_size` utterances start at even offsets (step 2), so the
    first context turn is always typed as speaker 1. The model is prompted with
    the speaker-tagged context followed by the next speaker's token. The utterance
    that actually follows the window is kept as the reference.

    Args:
        window_size: number of context utterances per prompt.

    Returns:
        (generated_responses, actual_responses, inputs): the decoded samples, the
        gold next utterances, and the context windows joined with '<SEP>'.
    """
    model.eval()
    fix_seed(seed)
    generated_responses = []
    actual_responses = []
    inputs = []

    with torch.no_grad():
        for dialog in tqdm(test_dialogues):
            # Each window needs one following utterance to serve as the reference.
            for i in range(0, len(dialog) - window_size, 2):
                window = dialog[i:i + window_size]

                # Tag the turns alternately as speaker 1 / speaker 2.
                turn_sp_ids = []
                window_context = []
                for j, utterance in enumerate(window):
                    sp_id = sp1_id if j % 2 == 0 else sp2_id
                    turn_sp_ids.append(sp_id)
                    window_context.append([sp_id] + tokenizer.encode(utterance))

                # The reply continues the alternation: after an odd number of turns
                # it belongs to speaker 2, after an even number to speaker 1.
                start_sp_id = turn_sp_ids[0]
                next_sp_id = sp2_id if len(window) % 2 == 1 else sp1_id

                input_ids = [bos_id] + list(chain.from_iterable(window_context)) + [next_sp_id]

                token_type_ids = [[sp] * len(turn) for sp, turn in zip(turn_sp_ids, window_context)]
                token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [next_sp_id]
                assert len(input_ids) == len(token_type_ids)
                input_len = len(input_ids)

                input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(device)
                token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(device)

                output_ids = model.generate(input_ids=input_ids, token_type_ids=token_type_ids, pad_token_id=eos_id,
                                            max_length=max_len, do_sample=True, top_p=top_p).squeeze(0)

                # Keep only the newly generated tokens.
                output_ids = output_ids.tolist()[input_len:]
                res = tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

                inputs.append('<SEP>'.join(window))
                generated_responses.append(res)
                actual_responses.append(dialog[i + window_size])

    return generated_responses, actual_responses, inputs

In [None]:
# Prompt with WINDOW_SIZE context turns (training windows used WINDOW_SIZE + 1 turns,
# the extra one being the response).
WINDOW_SIZE = 3
generated_responses, actual_responses, inputs= infer(window_size= WINDOW_SIZE)

In [None]:
# Every generated response must have a matching reference.
assert len(generated_responses) == len(actual_responses)
print(len(generated_responses))
print(len(actual_responses))

### Store responses

In [None]:
import pandas as pd

# File stem encodes model, size, epochs, dataset and window size; reused for the metrics .txt.
file_generated = output_path+ "/" + selected_model + "_" + model_size+ "_epochs_" + str(num_epochs) + "_generated_responses_" + dataset_name + "_window" + str(WINDOW_SIZE)

# One row per context window: joined context, gold reply, sampled reply.
df = pd.DataFrame({'input': inputs, 'actual_responses': actual_responses, 'generated_responses': generated_responses} )
df.to_csv(file_generated+'.csv', index = False, encoding = 'UTF-8')

### Compute metrics

In [None]:
#LOAD METRICS
# Metric implementations are loaded through the Hugging Face `evaluate` library.

import evaluate

sacrebleu = evaluate.load("sacrebleu")
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")
chrf = evaluate.load("chrf")

In [None]:
# The metrics expect one list of reference strings per prediction. Only bare
# strings are wrapped, so re-running this cell does not nest the lists again.
actual_responses = [res if isinstance(res, list) else [res] for res in actual_responses]

print(generated_responses[:5])
print(actual_responses[:5])

In [None]:
# Corpus-level BLEU / ROUGE / chrF plus per-pair BERTScore over all test windows.
bleu_score = sacrebleu.compute(predictions=generated_responses, references=actual_responses)

rouge_score = rouge.compute(predictions=generated_responses, references=actual_responses)

bert_score = bertscore.compute(predictions=generated_responses, references=actual_responses, lang='en')
# BERTScore returns one value per prediction; average them for a single summary number.
precision = bert_score['precision']
recall = bert_score['recall']
f1 = bert_score['f1']
avg_precision_bert = sum(precision) / len(precision)
avg_recall_bert = sum(recall) / len(recall)
avg_f1_bert = sum(f1) / len(f1)

chrf_score = chrf.compute(predictions=generated_responses, references=actual_responses)

In [None]:
print('Saving results...')
# `with` closes the file even if a write fails.
with open(file_generated + ".txt", "w", encoding="UTF-8") as fout:
    fout.write('Bleu score: {} \n'.format(bleu_score))  # Range from 0 to 100
    fout.write('Rouge score: {} \n'.format(rouge_score))
    fout.write('Bert score:  {} \n'.format(bert_score))
    fout.write('Avg precision Bert score: {} \n'.format(avg_precision_bert))
    fout.write('Avg recall Bert score: {} \n'.format(avg_recall_bert))
    fout.write('Avg f1 Bert score: {} \n'.format(avg_f1_bert))
    fout.write('chrf score: {} \n'.format(chrf_score))

<h2>Generate Multiple Candidates: CD1 / CD2 </h2>

In [None]:
def infer_gpt(window_size=5, n=1):
    """Sample `n` candidate responses for every sliding context window of the test dialogues.

    Windows of `window_size` utterances start at even offsets (step 2), so the
    first context turn is always typed as speaker 1. The model is prompted with
    the speaker-tagged context followed by the next speaker's token.

    Args:
        window_size: number of context utterances per prompt.
        n: number of sampled candidates per prompt.

    Returns:
        (generated_responses, actual_responses, inputs):
        - generated_responses: per window, a list of `n` decoded candidates when
          n > 1, or a single decoded string when n == 1;
        - actual_responses: the gold next utterances;
        - inputs: the raw context windows (lists of strings).
    """
    model.eval()
    fix_seed(seed)
    generated_responses = []
    actual_responses = []
    inputs = []

    # Extra sampling arguments are passed only when several candidates are requested.
    extra_generate_kwargs = {'temperature': 1.0, 'num_return_sequences': n} if n > 1 else {}

    with torch.no_grad():
        for dialog in tqdm(test_dialogues):
            # Each window needs one following utterance to serve as the reference.
            for i in range(0, len(dialog) - window_size, 2):
                window = dialog[i:i + window_size]
                inputs.append(window)

                # Tag the turns alternately as speaker 1 / speaker 2.
                turn_sp_ids = []
                window_context = []
                for j, utterance in enumerate(window):
                    sp_id = sp1_id if j % 2 == 0 else sp2_id
                    turn_sp_ids.append(sp_id)
                    window_context.append([sp_id] + tokenizer.encode(utterance))

                # The reply continues the alternation: after an odd number of turns
                # it belongs to speaker 2, after an even number to speaker 1.
                start_sp_id = turn_sp_ids[0]
                next_sp_id = sp2_id if len(window) % 2 == 1 else sp1_id

                input_ids = [bos_id] + list(chain.from_iterable(window_context)) + [next_sp_id]

                token_type_ids = [[sp] * len(turn) for sp, turn in zip(turn_sp_ids, window_context)]
                token_type_ids = [start_sp_id] + list(chain.from_iterable(token_type_ids)) + [next_sp_id]
                assert len(input_ids) == len(token_type_ids)
                input_len = len(input_ids)

                input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(device)
                token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(device)

                output_ids = model.generate(
                    input_ids=input_ids,
                    token_type_ids=token_type_ids,
                    pad_token_id=eos_id,
                    max_length=max_len,
                    do_sample=True,
                    top_p=top_p,
                    **extra_generate_kwargs,
                )
                # Decode only the newly generated tokens of each returned sequence.
                candidates = [
                    tokenizer.decode(seq.tolist()[input_len:], skip_special_tokens=True, clean_up_tokenization_spaces=True)
                    for seq in output_ids
                ]
                res = candidates if n > 1 else candidates[0]

                generated_responses.append(res)
                actual_responses.append(dialog[i + window_size])

    return generated_responses, actual_responses, inputs

In [None]:
# Sample N candidates per context window.
N=10
generated_responses, actual_responses, contexts = infer_gpt(window_size= WINDOW_SIZE, n=N)

In [None]:
import pandas as pd

# Name the output after the run configuration. This does not overwrite the
# global `model_name`, which holds the checkpoint id used to load the model.
run_name = selected_model + "_" + model_size

new_file_generated = output_path + "/" + run_name + "_generated_multiple_responses_" + dataset_name + '_window' + str(WINDOW_SIZE) + '_N' + str(N)
print(new_file_generated)

df = pd.DataFrame({'inputs': contexts, 'actual responses': actual_responses})

# One column per sampled candidate. infer_gpt returns a plain string per window when N == 1.
for k in range(N):
    df['generated_responses_' + str(k)] = [cands[k] if isinstance(cands, list) else cands
                                            for cands in generated_responses]

df.to_csv(new_file_generated + '.csv', index=False, encoding='UTF-8')

In [None]:
# infer_gpt returns either one string or a list of candidates per window, and
# plain-string references. The metrics need a flat list of prediction strings and
# one list of references per prediction. Every candidate is therefore paired with
# its window's reference, and the corpus-level scores cover all candidates.
flat_predictions = []
flat_references = []
for cands, ref in zip(generated_responses, actual_responses):
    cand_list = cands if isinstance(cands, list) else [cands]
    ref_list = ref if isinstance(ref, list) else [ref]
    for cand in cand_list:
        flat_predictions.append(cand)
        flat_references.append(ref_list)

bleu_score = sacrebleu.compute(predictions=flat_predictions, references=flat_references)

rouge_score = rouge.compute(predictions=flat_predictions, references=flat_references)

bert_score = bertscore.compute(predictions=flat_predictions, references=flat_references, lang='en')
precision = bert_score['precision']
recall = bert_score['recall']
f1 = bert_score['f1']
avg_precision_bert = sum(precision) / len(precision)
avg_recall_bert = sum(recall) / len(recall)
avg_f1_bert = sum(f1) / len(f1)

chrf_score = chrf.compute(predictions=flat_predictions, references=flat_references)

In [None]:
print('Saving results...')
# Write next to the multi-candidate CSV. `file_generated` belongs to the
# single-response run, and writing there would overwrite its results.
# `with` closes the file even if a write fails.
with open(new_file_generated + ".txt", "w", encoding="UTF-8") as fout:
    fout.write('Bleu score: {} \n'.format(bleu_score))  # Range from 0 to 100
    fout.write('Rouge score: {} \n'.format(rouge_score))
    fout.write('Bert score:  {} \n'.format(bert_score))
    fout.write('Avg precision Bert score: {} \n'.format(avg_precision_bert))
    fout.write('Avg recall Bert score: {} \n'.format(avg_recall_bert))
    fout.write('Avg f1 Bert score: {} \n'.format(avg_f1_bert))
    fout.write('chrf score: {} \n'.format(chrf_score))