<a href="https://colab.research.google.com/github/denniswillie/FiQA_GPT2/blob/main/FiQA_gpt2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers

from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoConfig, get_polynomial_decay_schedule_with_warmup

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
class LoadDataArgs:
  """Settings for the data-preparation step that writes the train/valid JSON files."""

  def __init__(self):
    # Defaults are kept in one table and copied onto the instance in order.
    defaults = {
      "data_dir": "data",
      "train_prefix": "train",
      "valid_prefix": "valid",
      "train_frac": 0.85,
      "model_type": "gpt2",
    }
    for name, value in defaults.items():
      setattr(self, name, value)

class TrainArgs:
  """Hyper-parameters, special tokens and paths used when `mode` is "train"."""

  def __init__(self):
    # Defaults are kept in one table and copied onto the instance in order.
    defaults = {
      "seed": 0,
      "mode": "train",
      "data_dir": "data",
      "train_prefix": "train",
      "valid_prefix": "valid",
      "model_type": "gpt2",
      "bos_token": "<bos>",
      "sp1_token": "<question>",
      "sp2_token": "<answer>",
      "gpu": "0",
      "lr": 2e-5,
      "warmup_ratio": 0.0,
      "batch_size": 2,
      "num_workers": 0,
      "num_epochs": 3,
      "max_len": 1024,
      "max_turns": 1,
      "ckpt_dir": "saved_models",
      "ckpt_name": None,
    }
    for name, value in defaults.items():
      setattr(self, name, value)

class InferArgs:
  """Special tokens, sampling settings and checkpoint location used when `mode` is "infer"."""

  def __init__(self):
    # Defaults are kept in one table and copied onto the instance in order.
    defaults = {
      "seed": 0,
      "mode": "infer",
      "data_dir": "data",
      "model_type": "gpt2",
      "bos_token": "<bos>",
      "sp1_token": "<question>",
      "sp2_token": "<answer>",
      "gpu": "0",
      "max_len": 1024,
      "max_turns": 1,
      "top_p": 0.8,
      "ckpt_dir": "saved_models",
      "ckpt_name": "something",
      "end_command": "Abort!",
    }
    for name, value in defaults.items():
      setattr(self, name, value)

In [None]:
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from tqdm import tqdm
from datasets import *
import json
import pandas as pd

def save_data(prefix, data_dir, dialogues, tokenizer):
    """Write the dialogues as text and as token ids under `data_dir`.

    Two JSON files are produced: ``{prefix}_utters.json`` holds `dialogues`
    unchanged, ``{prefix}_ids.json`` holds the same nesting with every utterance
    replaced by the ids of ``tokenizer.tokenize(utterance)``.

    Args:
        prefix: file-name prefix (e.g. the train or valid prefix).
        data_dir: existing directory the files are written into.
        dialogues: list of dialogues, each a list of utterance strings.
        tokenizer: object providing `tokenize` and `convert_tokens_to_ids`.
    """
    utters_path = f"{data_dir}/{prefix}_utters.json"
    ids_path = f"{data_dir}/{prefix}_ids.json"

    print(f"Saving {prefix} text file...")
    with open(utters_path, 'w') as out_file:
        json.dump(dialogues, out_file)

    print(f"Saving {prefix} idx file...")
    ids = [
        [tokenizer.convert_tokens_to_ids(tokenizer.tokenize(utter)) for utter in dialogue]
        for dialogue in tqdm(dialogues)
    ]

    assert len(ids) == len(dialogues)

    with open(ids_path, 'w') as out_file:
        json.dump(ids, out_file)

from tqdm import tqdm
from datasets import *


# For all
# Shared text-normalisation constants read by load_fiqa_dataset and process_token_list.
# Marker found at the start of tokens that follow a space (presumably the GPT-2
# byte-level BPE space symbol, since the tokens come from GPT2Tokenizer -- confirm).
space = 'Ġ'
# Typographic apostrophe; replaced by quotes[1] before an utterance is tokenized.
pre_quote = '’'
# Tokens treated as sentence/clause terminators by process_token_list.
end_marks = ['.', ',', '?', '!', '...']
# Index 0 is the double quote, index 1 the single quote / apostrophe.
quotes = ['"', '\'']
# Token texts that get their leading space marker stripped (they look like contraction suffixes, e.g. 's, 're, 'll).
abbreviations = ['s', 'd', 't', 'm', 're', 'll', 've', 'S', 'D', 'T', 'M', 'Re', 'Ll', 'Ve']

def load_fiqa_dataset(tokenizer, train_frac, dataset_dir="./drive/MyDrive/FiQA_train_task2/"):
    """Load the FiQA question/answer pairs and split them into train and valid sets.

    The three TSV files (questions, documents, question-document pairs) are read
    from `dataset_dir`, joined on ``qid`` and ``docid``, and rows with any missing
    value are dropped. Each question and answer is re-tokenized, normalized by
    `process_token_list` and converted back to text. The split is positional (no
    shuffling): the first ``int(len * train_frac)`` pairs go to train.

    Args:
        tokenizer: object providing `tokenize` and `convert_tokens_to_string`.
        train_frac: fraction of the pairs that goes into the training split.
        dataset_dir: directory holding the FiQA TSV files; defaults to the
            original hard-coded Drive location.

    Returns:
        ``(train_dialogues, valid_dialogues, train_utter_num, valid_utter_num)``
        where each dialogue is ``[question, answer]`` and the last two values
        count the utterances in each split.

    Raises:
        KeyError: if the merged frame lacks one of the columns that are dropped.
    """
    import os  # local import: this cell does not import os itself

    questions_path = os.path.join(dataset_dir, "FiQA_train_question_final.tsv")
    answers_path = os.path.join(dataset_dir, "FiQA_train_doc_final.tsv")
    pairs_path = os.path.join(dataset_dir, "FiQA_train_question_doc_final.tsv")

    q_df = pd.read_csv(questions_path, delimiter='\t')
    a_df = pd.read_csv(answers_path, delimiter='\t')
    qna_df = pd.read_csv(pairs_path, delimiter='\t')

    # Join questions and answers through the (qid, docid) pair table, then keep
    # only the text columns and the rows where both texts are present.
    unused_columns = ['Unnamed: 0', 'Unnamed: 0_x', 'Unnamed: 0_y', 'qid', 'docid', 'timestamp_x', 'timestamp_y']
    df = (
        qna_df.merge(q_df, left_on='qid', right_on='qid')
        .merge(a_df, left_on='docid', right_on='docid')
        .drop(columns=unused_columns)
        .dropna(how='any')
    )

    total_dialogues = [[q, a] for q, a in zip(df['question'], df['doc'])]

    for i, dialogue in enumerate(tqdm(total_dialogues)):
        new_dialogue = []
        for utter in dialogue:
            # Replace the typographic apostrophe before tokenizing.
            token_list = tokenizer.tokenize(utter.strip().replace(pre_quote, quotes[1]))
            token_list = process_token_list(token_list)
            new_dialogue.append(tokenizer.convert_tokens_to_string(token_list))

        total_dialogues[i] = new_dialogue

    split_idx = int(len(total_dialogues) * train_frac)
    train_dialogues = total_dialogues[:split_idx]
    valid_dialogues = total_dialogues[split_idx:]

    train_utter_num = sum(len(dialogue) for dialogue in train_dialogues)
    valid_utter_num = sum(len(dialogue) for dialogue in valid_dialogues)

    return train_dialogues, valid_dialogues, train_utter_num, valid_utter_num

def _capitalize_first(token):
    """Upper-case only the first character; unlike str.capitalize, keep the rest as is."""
    return token[:1].upper() + token[1:]


def process_token_list(token_list):
    """Normalize spacing and capitalization of a tokenized utterance.

    Working on tokens that carry the `space` marker as a prefix, this:
      * upper-cases the first character of the utterance and of every token
        that follows an end mark (and makes that token start with `space`);
      * strips the space marker from end marks, from abbreviation suffixes and
        from an apostrophe that is followed by an abbreviation suffix;
      * attaches quote marks to the text they enclose;
      * drops empty / bare-space tokens and appends ``end_marks[0]`` when the
        result does not already end with an end mark.

    Only the first character of a token is changed when capitalizing, so
    acronyms such as "IRA" or "ETF" keep their case (``str.capitalize`` would
    lower-case everything after the first character).

    Args:
        token_list: list of token strings; it is not modified.

    Returns:
        A new list of tokens. An input that is empty, or that contains no real
        tokens after filtering, yields an empty list.
    """
    # Work on a copy so the caller's list is left untouched.
    token_list = list(token_list)
    if not token_list:
        return []

    token_list[0] = _capitalize_first(token_list[0])

    quote_count = 0
    last_index = len(token_list) - 1
    for i, token in enumerate(token_list):
        has_next = i < last_index

        if space in token:
            if token[1:] in end_marks or token[1:] in abbreviations:
                token_list[i] = token[1:]

            # Apostrophe followed by an abbreviation suffix: glue it to the previous word.
            if token[1:] == quotes[1] and has_next:
                following = token_list[i+1]
                if following in abbreviations or (following.startswith(space) and following[1:] in abbreviations):
                    token_list[i] = token[1:]

        if len(token) > 1 and token[0] == space and token[1:] in quotes:
            if quote_count % 2 == 1:
                # Closing quote: attach it to the preceding text.
                token_list[i] = token[1:]
                quote_count = 0
            else:
                # Opening quote: attach the following token to it.
                if has_next and token_list[i+1].startswith(space):
                    token_list[i+1] = token_list[i+1][1:]
                quote_count += 1

        # `token` is the value seen at the start of the iteration, so both the
        # bare and the space-prefixed form of an end mark are checked.
        if (token in end_marks or token[1:] in end_marks) and has_next:
            following = token_list[i+1]
            if not following.startswith(space):
                token_list[i+1] = space + _capitalize_first(following)
            else:
                token_list[i+1] = space + _capitalize_first(following[1:])

    new_token_list = [token for token in token_list if token != space and len(token) > 0]
    if not new_token_list:
        return new_token_list
    if new_token_list[-1] not in end_marks:
        new_token_list.append(end_marks[0])

    return new_token_list

In [None]:
# Load data: normalize the FiQA pairs and dump them as text / token-id JSON files.
import os

args = LoadDataArgs()
tokenizer = GPT2Tokenizer.from_pretrained(args.model_type)

# The dumps of each model type live in their own sub-directory.
args.data_dir = f"{args.data_dir}/{args.model_type}"
os.makedirs(args.data_dir, exist_ok=True)

print("Loading & Merging all datasets...")
train_dialogues, valid_dialogues, num_train, num_valid = load_fiqa_dataset(tokenizer, args.train_frac)

print("Saving train data...")
save_data(args.train_prefix, args.data_dir, train_dialogues, tokenizer)

print("Saving validation data...")
save_data(args.valid_prefix, args.data_dir, valid_dialogues, tokenizer)

# Summary of what was written.
print("#"*50 + "Analysis on total data" + "#"*50)
print(
    f"The number of train dialogues: {len(train_dialogues)}",
    f"The number of valid dialogues: {len(valid_dialogues)}",
    f"The number of train utterances: {num_train}",
    f"The number of valid utterances: {num_valid}",
    sep="\n",
)

Loading & Merging all datasets...


100%|██████████| 17072/17072 [00:30<00:00, 564.21it/s]


Saving train data...
Saving train text file...
Saving train idx file...


100%|██████████| 14511/14511 [00:20<00:00, 698.19it/s]


Saving validation data...
Saving valid text file...
Saving valid idx file...


100%|██████████| 2561/2561 [00:03<00:00, 737.05it/s]


##################################################Analysis on total data##################################################
The number of train dialogues: 14511
The number of valid dialogues: 2561
The number of train utterances: 29022
The number of valid utterances: 5122


In [None]:
from torch.utils.data import Dataset
from tqdm import tqdm
from itertools import chain

import torch
import copy
import json


class CustomDataset(Dataset):
    """Question/answer pairs encoded as model inputs, token types and LM labels.

    Reads ``{args.data_dir}/{prefix}_ids.json`` (a list of two-utterance
    dialogues given as token-id lists) and, for every dialogue, builds:

      * ``input_ids``:  bos, sp1, question ids, sp2, answer ids, eos
      * ``token_type_ids``: sp1 for bos and the question segment, sp2 for the
        answer segment and eos
      * ``labels``: -100 for bos, the question segment and the sp2 token itself;
        the input id for the answer tokens and eos

    Dialogues whose encoded length exceeds ``args.max_len`` are left out; the
    number of such dialogues is printed.

    Args:
        prefix: must equal ``args.train_prefix`` or ``args.valid_prefix``.
        args: object providing data_dir, train_prefix, valid_prefix, max_len and
            the bos_id / eos_id / sp1_id / sp2_id token ids.
    """

    def __init__(self, prefix, args):
        assert prefix == args.train_prefix or prefix == args.valid_prefix

        # The message names the file that is actually opened below.
        print(f"Loading {prefix}_ids.json...")
        with open(f"{args.data_dir}/{prefix}_ids.json", 'r') as f:
            dials = json.load(f)

        self.input_ids = []  # (N, L)
        self.token_type_ids = []  # (N, L)
        self.labels = []  # (N, L)

        print(f"Processing {prefix} data...")

        num_skipped = 0
        # dial = 1 dialog = [question ids, answer ids].
        for dial in tqdm(dials):
            # Even positions are questions (sp1), odd positions answers (sp2).
            hists = [
                [args.sp1_id if u % 2 == 0 else args.sp2_id] + utter
                for u, utter in enumerate(dial)
            ]
            assert len(hists) == 2

            input_ids = [args.bos_id] + list(chain.from_iterable(hists)) + [args.eos_id]
            if len(input_ids) > args.max_len:
                num_skipped += 1
                continue

            token_type_ids = [args.sp1_id] * len(hists[0]) + [args.sp2_id] * len(hists[1])
            token_type_ids = [args.sp1_id] + token_type_ids + [args.sp2_id]
            assert len(input_ids) == len(token_type_ids)

            # Mask everything up to and including the sp2 token (the first
            # position whose predecessor still has the sp1 type).
            labels = [
                -100 if (type_id == args.sp1_id or token_type_ids[i - 1] == args.sp1_id) else input_ids[i]
                for i, type_id in enumerate(token_type_ids)
            ]
            assert len(input_ids) == len(labels)

            self.input_ids.append(input_ids)
            self.token_type_ids.append(token_type_ids)
            self.labels.append(labels)

        if num_skipped:
            print(f"Skipped {num_skipped} dialogue(s) longer than max_len={args.max_len}.")

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx], self.token_type_ids[idx], self.labels[idx]


class PadCollate():
    """Collate function holder that pads variable-length samples into batch tensors."""

    def __init__(self, eos_id):
        # Id used to pad input ids and token type ids.
        self.eos_id = eos_id

    def pad_collate(self, batch):
        """Turn a list of (input_ids, token_type_ids, labels) samples into padded tensors.

        Input ids and token type ids are right-padded with `eos_id`, labels with
        -100. Returns the three LongTensors, each batch-first.
        """
        pad = torch.nn.utils.rnn.pad_sequence

        ids_seqs = [torch.LongTensor(sample[0]) for sample in batch]
        type_seqs = [torch.LongTensor(sample[1]) for sample in batch]
        label_seqs = [torch.LongTensor(sample[2]) for sample in batch]

        return (
            pad(ids_seqs, batch_first=True, padding_value=self.eos_id),
            pad(type_seqs, batch_first=True, padding_value=self.eos_id),
            pad(label_seqs, batch_first=True, padding_value=-100),
        )


In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel, get_polynomial_decay_schedule_with_warmup
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
from itertools import chain

import torch
import os, sys
import numpy as np
import argparse
import copy
import math
import random


class Manager():
    def __init__(self, args):
        """Build the tokenizer and model and, in train mode, everything needed to train.

        `args` is stored on the instance and extended in place with derived
        values: device, eos_token, vocab_size, the bos/eos/sp1/sp2 token ids, a
        max_len capped by the model config and, in train mode, total_train_steps
        and warmup_steps.

        When ``args.ckpt_name`` is set, ``{ckpt_dir}/{ckpt_name}.ckpt`` is loaded
        if it exists. A missing checkpoint is tolerated in train mode (training
        starts from the pretrained weights) and terminates the process in any
        other mode.

        Args:
            args: a TrainArgs- or InferArgs-like object.

        Raises:
            SystemExit: in non-train mode when the requested checkpoint is missing.
        """
        self.args = args

        # Prefer CUDA, then Apple MPS when this torch build reports it as
        # available, and otherwise fall back to the CPU.
        mps_backend = getattr(torch.backends, "mps", None)
        if torch.cuda.is_available():
            self.args.device = torch.device(f"cuda:{self.args.gpu}")
        elif mps_backend is not None and mps_backend.is_available():
            self.args.device = torch.device("mps")
        else:
            self.args.device = torch.device("cpu")

        # Tokenizer & Vocab
        print("Loading the tokenizer...")
        self.tokenizer = GPT2Tokenizer.from_pretrained(self.args.model_type)
        special_tokens = {
            'bos_token': self.args.bos_token,
            'additional_special_tokens': [self.args.sp1_token, self.args.sp2_token]
        }
        self.args.eos_token = self.tokenizer.eos_token
        self.tokenizer.add_special_tokens(special_tokens)
        vocab = self.tokenizer.get_vocab()
        self.args.vocab_size = len(vocab)
        self.args.bos_id = vocab[self.args.bos_token]
        self.args.eos_id = vocab[self.args.eos_token]
        self.args.sp1_id = vocab[self.args.sp1_token]
        self.args.sp2_id = vocab[self.args.sp2_token]

        # Load model
        print("Loading the model...")
        self.fix_seed(self.args.seed)
        self.model = GPT2LMHeadModel.from_pretrained(self.args.model_type).to(self.args.device)
        # Make room in the embedding matrix for the special tokens added above.
        self.model.resize_token_embeddings(self.args.vocab_size)

        self.args.max_len = min(self.args.max_len, self.model.config.n_ctx)

        if self.args.mode == 'train':
            # Load optimizer
            print("Loading the optimizer...")
            self.optim = torch.optim.AdamW(self.model.parameters(), lr=self.args.lr)
            self.best_loss = sys.float_info.max
            self.last_epoch = 0

            # Load train & valid dataset
            print("Loading train & valid data...")
            train_set = CustomDataset(self.args.train_prefix, self.args)
            valid_set = CustomDataset(self.args.valid_prefix, self.args)
            ppd = PadCollate(eos_id=self.args.eos_id)

            self.train_loader = DataLoader(train_set,
                                           collate_fn=ppd.pad_collate,
                                           shuffle=True,
                                           batch_size=self.args.batch_size,
                                           num_workers=self.args.num_workers,
                                           pin_memory=True)
            self.valid_loader = DataLoader(valid_set,
                                           collate_fn=ppd.pad_collate,
                                           batch_size=self.args.batch_size,
                                           num_workers=self.args.num_workers,
                                           pin_memory=True)

            os.makedirs(self.args.ckpt_dir, exist_ok=True)

            # Calculate total training steps
            num_batches = len(self.train_loader)
            self.args.total_train_steps = self.args.num_epochs * num_batches
            self.args.warmup_steps = int(self.args.warmup_ratio * self.args.total_train_steps)

            self.sched = get_polynomial_decay_schedule_with_warmup(
                self.optim,
                num_warmup_steps=self.args.warmup_steps,
                num_training_steps=self.args.total_train_steps,
                power=2
            )

            self.writer = SummaryWriter()

        if self.args.ckpt_name is not None:
            ckpt_path = f"{self.args.ckpt_dir}/{self.args.ckpt_name}.ckpt"
            if os.path.exists(ckpt_path):
                print("Loading the trained checkpoint...")
                # NOTE: torch.load unpickles the file; only load checkpoints you trust.
                ckpt = torch.load(ckpt_path, map_location=self.args.device)
                self.model.load_state_dict(ckpt['model_state_dict'])

                if self.args.mode == 'train':
                    print(f"The training restarts with the specified checkpoint: {self.args.ckpt_name}.ckpt.")
                    self.optim.load_state_dict(ckpt['optim_state_dict'])
                    self.sched.load_state_dict(ckpt['sched_state_dict'])
                    self.best_loss = ckpt['loss']
                    self.last_epoch = ckpt['epoch']
                else:
                    print("The inference will start with the specified checkpoint.")
            else:
                print(f"Cannot find the specified checkpoint {ckpt_path}.")
                if self.args.mode == 'train':
                    print("Training will start with the initialized model.")
                else:
                    print("Cannot inference.")
                    # sys.exit instead of the interactive-only `exit` helper.
                    sys.exit()

        print("Setting finished.")

    def train(self):
        self.fix_seed(self.args.seed)  # Fix seed before training
        print("Training starts.")

        start_epoch = self.last_epoch+1
        for epoch in range(start_epoch, start_epoch+self.args.num_epochs):
            self.model.train()

            print(f"#"*50 + f"Epoch: {epoch}" + "#"*50)
            train_losses = []
            train_ppls = []
            for i, batch in enumerate(tqdm(self.train_loader)):
                input_ids, token_type_ids, labels = batch
                input_ids, token_type_ids, labels = \
                    input_ids.to(self.args.device), token_type_ids.to(self.args.device), labels.to(self.args.device)

                outputs = self.model(
                    input_ids=input_ids,
                    token_type_ids = token_type_ids,
                    labels = labels
                )

                loss, logits = outputs[0], outputs[1]

                self.optim.zero_grad()
                loss.backward()
                self.optim.step()
                self.sched.step()

                train_losses.append(loss.detach())
                ppl = torch.exp(loss.detach())
                train_ppls.append(ppl)

            train_losses = [loss.item() for loss in train_losses]
            train_ppls = [ppl.item() if not math.isinf(ppl.item()) else 1e+8 for ppl in train_ppls]
            train_loss = np.mean(train_losses)
            train_ppl = np.mean(train_ppls)
            print(f"Train loss: {train_loss} || Train perplexity: {train_ppl}")

            self.writer.add_scalar("Loss/train", train_loss, epoch)
            self.writer.add_scalar("PPL/train", train_ppl, epoch)

            self.last_epoch += 1

            valid_loss, valid_ppl = self.validation()

            if valid_loss < self.best_loss:
                self.best_loss = valid_loss
                state_dict = {
                    'model_state_dict': self.model.state_dict(),
                    'optim_state_dict': self.optim.state_dict(),
                    'sched_state_dict': self.sched.state_dict(),
                    'loss': self.best_loss,
                    'epoch': self.last_epoch
                }

                torch.save(state_dict, f"{self.args.ckpt_dir}/best_ckpt_epoch={epoch}_valid_loss={round(self.best_loss, 4)}.ckpt")
                print("*"*10 + "Current best checkpoint is saved." + "*"*10)
                print(f"{self.args.ckpt_dir}/best_ckpt_epoch={epoch}_valid_loss={round(self.best_loss, 4)}.ckpt")

            print(f"Best valid loss: {self.best_loss}")
            print(f"Valid loss: {valid_loss} || Valid perplexity: {valid_ppl}")

            self.writer.add_scalar("Loss/valid", valid_loss, epoch)
            self.writer.add_scalar("PPL/valid", valid_ppl, epoch)

            self.writer.add_scalars("Losses", {
                'train': train_loss,
                'valid': valid_loss,
            }, epoch)
            self.writer.add_scalars("PPLs", {
                'train': train_ppl,
                'valid': valid_ppl,
            }, epoch)

        print("Training finished!")

    def validation(self):
        print("Validation processing...")
        self.model.eval()

        valid_losses = []
        valid_ppls = []
        with torch.no_grad():
            for i, batch in enumerate(tqdm(self.valid_loader)):
                input_ids, token_type_ids, labels = batch
                input_ids, token_type_ids, labels = \
                    input_ids.to(self.args.device), token_type_ids.to(self.args.device), labels.to(self.args.device)

                outputs = self.model(
                    input_ids=input_ids,
                    token_type_ids = token_type_ids,
                    labels = labels
                )

                loss, logits = outputs[0], outputs[1]

                valid_losses.append(loss.detach())
                ppl = torch.exp(loss.detach())
                valid_ppls.append(ppl)

            valid_losses = [loss.item() for loss in valid_losses]
            valid_ppls = [ppl.item() if not math.isinf(ppl.item()) else 1e+8 for ppl in valid_ppls]
            valid_loss = np.mean(valid_losses)
            valid_ppl = np.mean(valid_ppls)

            if math.isnan(valid_ppl):
                valid_ppl = 1e+8

        return valid_loss, valid_ppl


    def infer(self):
        print("Let's start!")
        print(f"If you want to quit the conversation, please type \"{self.args.end_command}\".")
        self.model.eval()
        self.fix_seed(self.args.seed)

        with torch.no_grad():
            question = input("Question: ")
            input_ids = [self.args.sp1_id] + self.tokenizer.encode(question)
            input_ids = [self.args.bos_id] + input_ids + [self.args.sp2_id]
            token_type_ids = [self.args.sp1_id] * (len(input_ids) - 1) + [self.args.sp2_id]
            assert len(input_ids) == len(token_type_ids)
            input_len = len(input_ids)

            input_ids = torch.LongTensor(input_ids).unsqueeze(0).to(self.args.device)
            token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0).to(self.args.device)

            output_ids = self.nucleus_sampling(input_ids, token_type_ids, input_len)
            res = self.tokenizer.decode(output_ids, skip_special_tokens=True)

            print("Answer: {}".format(res))

    def nucleus_sampling(self, input_ids, token_type_ids, input_len):
        output_ids = []
        for pos in range(input_len, self.args.max_len):
            output = self.model(input_ids=input_ids, token_type_ids=token_type_ids)  # (1, V)
            output = output[0][:, pos-1]
            output = F.softmax(output, dim=-1)  # (1, V)

            sorted_probs, sorted_idxs = torch.sort(output, descending=True)
            cumsum_probs = torch.cumsum(sorted_probs, dim=-1)  # (1, V)
            idx_remove = cumsum_probs > self.args.top_p
            idx_remove[:, 1:] = idx_remove[:, :-1].clone()
            idx_remove[:, 0] = False
            sorted_probs[idx_remove] = 0.0
            sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True)  # (1, V)

            probs = torch.zeros(output.shape, device=self.args.device).scatter_(-1, sorted_idxs, sorted_probs)  # (1, V)
            idx = torch.multinomial(probs, 1)  # (1, 1)

            idx_item = idx.squeeze(-1).squeeze(-1).item()
            output_ids.append(idx_item)

            if idx_item == self.args.eos_id:
                break

            input_ids = torch.cat((input_ids, idx), dim=-1)
            next_type_id = torch.LongTensor([[self.args.sp2_id]]).to(self.args.device)
            token_type_ids = torch.cat((token_type_ids, next_type_id), dim=-1)
            assert input_ids.shape == token_type_ids.shape

            break

        return output_ids

    def fix_seed(self, seed):
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        random.seed(seed)

In [None]:
# Fine-tune the model on the prepared data.
args = TrainArgs()
# Data files and checkpoints are kept in a sub-directory named after the model type.
args.data_dir = f"{args.data_dir}/{args.model_type}"
args.ckpt_dir = f"{args.ckpt_dir}/{args.model_type}"
manager = Manager(args)
manager.train()

Loading the tokenizer...
Loading the model...
Loading the optimizer...
Loading train & valid data...
Loading train_id.json...
Processing train data...


100%|██████████| 14511/14511 [00:01<00:00, 9396.90it/s] 


Loading valid_id.json...
Processing valid data...


 86%|████████▌ | 2202/2561 [00:00<00:00, 11641.12it/s]


KeyboardInterrupt: ignored

In [None]:
# Run interactive inference from a saved checkpoint.
args = InferArgs()
# ckpt_dir is not extended with the model type here, so the "gpt2/" sub-directory
# is given as part of the checkpoint name (Manager appends ".ckpt").
args.ckpt_name = "gpt2/best_ckpt_epoch=3_valid_loss=3.258"
manager = Manager(args)
manager.infer()

Loading the tokenizer...
Loading the model...
Loading the trained checkpoint...
The inference will start with the specified checkpoint.
Setting finished.
Let's start!
If you want to quit the conversation, please type "Abort!".
Question: Should i have a business credit card?
tensor([[1532]], device='cuda:0')
1532
Answer: If
