# Nikolas Iliopoulos 1115201800332
# AI_2 Part 2

# Install Transformers

In [None]:
!pip install transformers
!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json -O train-v2.0.json
!wget https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json -O dev-v2.0.json
from transformers import AutoTokenizer
from transformers import DistilBertForQuestionAnswering

import torch
from torch.utils.data import TensorDataset
from torch.optim import AdamW

import numpy as np
import pandas as pd

from tqdm import tqdm
import matplotlib.pyplot as plt

from sklearn.metrics import f1_score
from sklearn.metrics import recall_score
from sklearn.metrics import precision_score
from sklearn.metrics import roc_curve,roc_auc_score, auc

# Functions for computing F1

In [None]:
def compute_f1(prediction, truth):
    """Return the token-level F1 score between a predicted and a gold answer.

    Both strings are normalized with ``normalize_text`` and split on
    whitespace. The overlap is counted as a multiset intersection, so a token
    that occurs several times in both answers is credited several times
    (a plain ``set`` intersection would under-count repeated tokens).

    Args:
        prediction: answer text produced by the model.
        truth: gold answer text.

    Returns:
        ``1`` or ``0`` when either side normalizes to no tokens (1 only if
        both are empty), ``0`` when no token is shared, otherwise the
        harmonic mean of token precision and recall as a float.
    """
    from collections import Counter

    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
    if not pred_tokens or not truth_tokens:
        return int(pred_tokens == truth_tokens)

    # Multiset intersection: each shared token counts min(#pred, #truth) times.
    num_common = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())

    # if there are no common tokens then f1 = 0
    if num_common == 0:
        return 0

    prec = num_common / len(pred_tokens)
    rec = num_common / len(truth_tokens)

    return 2 * (prec * rec) / (prec + rec)

def normalize_text(s):
    """Canonicalize an answer string before token comparison.

    The steps run in this order: lower-case the text, delete every character
    listed in ``string.punctuation``, replace the standalone words "a", "an"
    and "the" with a space, then collapse all whitespace runs into single
    spaces (which also trims both ends).
    """
    import re
    import string

    # Lower-case first so the article pattern below only needs lower-case forms.
    lowered = s.lower()
    # Deleting via a translation table removes exactly the ASCII punctuation set.
    no_punctuation = lowered.translate(str.maketrans("", "", string.punctuation))
    # Articles become a space so the neighbouring words do not fuse together.
    no_articles = re.sub(r"\b(a|an|the)\b", " ", no_punctuation, flags=re.UNICODE)
    # split() without arguments drops empty pieces, normalizing the whitespace.
    return " ".join(no_articles.split())

# Define Trainer

In [None]:
# Read the SQuAD v2.0 training JSON (downloaded above as train-v2.0.json)
# with pandas; MyTrainer.processData iterates its 'data' column.
data = pd.read_json('./train-v2.0.json')
# Read the SQuAD v2.0 dev JSON (dev-v2.0.json); it is passed to
# MyTrainer.validate as the validation set.
data_Val = pd.read_json('./dev-v2.0.json')

class MyDataset(torch.utils.data.Dataset):
    """Dataset that serves one encoded example together with its answer record.

    ``encodings`` is a mapping from field name to an indexable collection of
    per-example values; ``answers`` is an indexable collection of the
    matching answer records. The dataset length is the number of answers.
    """

    def __init__(self, encodings, answers):
        self.encodings = encodings
        self.answers = answers

    def __len__(self):
        # One sample per answer record.
        return len(self.answers)

    def __getitem__(self, index):
        # Convert every encoded field of this example to a tensor ...
        sample = {
            field: torch.tensor(values[index])
            for field, values in self.encodings.items()
        }
        # ... and attach the untouched answer record under the 'answer' key.
        sample['answer'] = self.answers[index]
        return sample

class MyTrainer():
    """Fine-tune and evaluate DistilBERT for extractive question answering.

    Intended call order: ``processData`` -> ``prepareInputs`` -> ``train``
    -> ``validate`` -> ``plot``. Each step stores its results on the instance
    for the following steps.
    """

    def __init__(self, lr, epochs, batchSize):
        """Load the pretrained model and tokenizer and create the optimizer.

        Args:
            lr: learning rate given to AdamW.
            epochs: number of passes over the training data in ``train``.
            batchSize: batch size of the training DataLoader.
        """
        print("__init__()")
        self.lr = lr
        self.epochs = epochs
        self.batchSize = batchSize
        # Maximum number of tokens per (context, question) pair.
        self.maxLength = 512
        self.model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')
        self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
        self.optimizer = AdamW(self.model.parameters(),
                               lr=self.lr,
                               eps=1e-8)
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        print("maxLength=",self.maxLength)
        print("Device=",self.device)
        if self.use_cuda:
            # LOAD THE MODEL TO THE GPU
            self.model = self.model.cuda()
            print('GPU Model', torch.cuda.get_device_name(0))

    def processData(self, data):
        """Flatten SQuAD-style data into parallel context/question/answer lists.

        One entry is produced per gold answer, so questions without answers
        contribute nothing. Every stored answer is a copy of the input record
        with a corrected ``answer_start`` and an added exclusive
        ``answer_end`` character offset; the caller's data is left untouched.

        Args:
            data: mapping (or DataFrame) whose ``'data'`` entry iterates over
                articles holding ``'paragraphs'`` -> ``'qas'`` -> ``'answers'``.
        """
        print("processData()")
        self.texts, self.queries, self.answers = [], [], []

        for d in data['data']:
            for paragraph in d['paragraphs']:
                context = paragraph['context']
                for qas in paragraph['qas']:
                    question = qas['question']
                    for answer in qas['answers']:
                        self.texts.append(context)
                        self.queries.append(question)
                        # Copy so that adding 'answer_end' below does not
                        # modify the caller's dataset.
                        self.answers.append(dict(answer))

        for answer, text in zip(self.answers, self.texts):
            goldText = answer['text']
            indexStart = answer['answer_start']
            indexEnd = indexStart + len(goldText)

            # The annotated start is sometimes off by one or two characters:
            # try the offset as given, then shifted left by 1 and by 2.
            for shift in (0, 1, 2):
                start, end = indexStart - shift, indexEnd - shift
                # start >= 0 keeps a negative index from wrapping around.
                if start >= 0 and text[start:end] == goldText:
                    indexStart, indexEnd = start, end
                    break

            # When nothing matched, the unshifted offsets are kept so that
            # 'answer_end' always exists for prepareInputs.
            answer['answer_start'] = indexStart
            answer['answer_end'] = indexEnd

    def prepareInputs(self):
        """Tokenize the collected examples and build the training DataLoader.

        Character offsets of each answer are converted to token positions and
        stored in the encodings as ``'startIndex'`` / ``'endIndex'``, where
        ``endIndex`` is the token holding the LAST answer character
        (inclusive). Requires ``processData`` to have been called.
        """
        print("prepareInputs()")

        out = self.tokenizer(self.texts,
                             self.queries,
                             truncation=True,
                             max_length=self.maxLength,
                             padding=True)

        startIndex, endIndex = [], []

        for i, answer in enumerate(self.answers):
            firstChar = answer['answer_start']
            # 'answer_end' is exclusive, so the last answer character sits one
            # position before it. Asking for 'answer_end' itself would return
            # the NEXT token whenever the answer is directly followed by a
            # non-space character such as punctuation.
            lastChar = max(answer['answer_end'] - 1, firstChar)

            start = out.char_to_token(i, firstChar)
            end = out.char_to_token(i, lastChar)

            # None means the answer passage has been truncated away. Use an
            # out-of-range position as the marker; NOTE(review): the Hugging
            # Face QA head is expected to clamp such positions and exclude
            # them from the loss - confirm against the installed version.
            startIndex.append(self.maxLength if start is None else start)
            endIndex.append(self.maxLength if end is None else end)

        # Update the data in dictionary
        out.update({'startIndex': startIndex, 'endIndex': endIndex})

        self.dataloader_train = torch.utils.data.DataLoader(
            MyDataset(out, self.answers),
            batch_size=self.batchSize,
            shuffle=True)

    def train(self):
        """Run the fine-tuning loop for ``self.epochs`` epochs.

        After each epoch the mean batch loss is appended to ``self.losses``
        and the cumulative number of processed batches to ``self.iters``.
        Requires ``prepareInputs`` to have been called.
        """
        stepsDone = 0
        self.iters = []
        self.losses = []

        self.model.train()
        for _ in range(self.epochs):
            batch_losses = []

            # TRAINING
            for batch in tqdm(self.dataloader_train):
                # LOAD THE DATA TO THE GPU
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                startIndex = batch['startIndex'].to(self.device)
                endIndex = batch['endIndex'].to(self.device)

                self.model.zero_grad()

                outputs = self.model(input_ids,
                                     attention_mask=attention_mask,
                                     start_positions=startIndex,
                                     end_positions=endIndex)
                loss = outputs.loss

                # store the loss of that batch
                batch_losses.append(loss.item())
                # BackwardPropagation to calculate the gradients.
                loss.backward()
                # Clip Gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                # Update parameters
                self.optimizer.step()

                stepsDone += 1

            # Guard against an empty DataLoader instead of dividing by zero.
            epochLoss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
            self.losses.append(epochLoss)
            self.iters.append(stepsDone)

            print("Loss:    ", epochLoss)

    def validate(self, data_Val):
        """Score the model on SQuAD-style validation data.

        For every question that has gold answers, the highest-scoring start
        and end tokens are decoded into a predicted answer, and one F1 value
        per gold answer is appended to ``self.f1_Val``. ``self.val_iters``
        counts the scored (question, gold answer) pairs.

        Args:
            data_Val: data with the same layout as accepted by ``processData``.
        """
        self.model.eval()
        self.f1_Val = []
        self.val_iters = 0

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for group in tqdm(data_Val['data']):
                for passage in group['paragraphs']:
                    context = passage['context']
                    for qa in passage['qas']:
                        # Questions without gold answers are not scored, so
                        # skip the forward pass for them entirely.
                        if not qa['answers']:
                            continue

                        # The prediction depends only on (context, question),
                        # so compute it once and reuse it for every gold answer.
                        tokens = self.tokenizer(context,
                                                qa['question'],
                                                truncation=True,
                                                max_length=self.maxLength,
                                                return_tensors='pt')
                        tokens = tokens.to(self.device)
                        outputs = self.model(**tokens)
                        startIndex_pred = torch.argmax(outputs.start_logits)
                        # +1 because the predicted end token is inclusive.
                        endIndex_pred = torch.argmax(outputs.end_logits) + 1
                        answer_ids = tokens['input_ids'][0][startIndex_pred:endIndex_pred]
                        answer_pred = self.tokenizer.convert_tokens_to_string(
                            self.tokenizer.convert_ids_to_tokens(answer_ids))

                        for answer in qa['answers']:
                            self.f1_Val.append(compute_f1(answer_pred, answer['text']))
                            self.val_iters += 1

    def plot(self):
        """Plot the training loss curve and print the mean validation F1.

        Requires ``train`` and ``validate`` to have been called. Prints NaN
        for the F1 when no validation pair was scored.
        """
        plt.title("Learning Curve")
        plt.plot(self.iters, self.losses, label="Train Loss")
        plt.xlabel("Iteration(Epoch)")
        plt.ylabel("Loss")
        plt.legend(loc='best')
        plt.show()
        # Avoid ZeroDivisionError when validation scored nothing.
        meanF1 = sum(self.f1_Val) / self.val_iters if self.val_iters else float('nan')
        print('F1 ', meanF1)


# Train & Plot

In [None]:
# Hyper-parameters: learning rate 2e-5, 2 epochs, batches of 16 examples.
trainer = MyTrainer(lr=2e-5, epochs=2, batchSize=16)

# Build the training inputs from the SQuAD training split, then fine-tune.
trainer.processData(data)
trainer.prepareInputs()
trainer.train()

# Score on the dev split, then show the loss curve and the mean F1.
trainer.validate(data_Val)
trainer.plot()