In [1]:
from datasets import load_dataset
from transformers import    BertForQuestionAnswering,\
                            BertTokenizer,\
                            get_scheduler
from torch.optim import AdamW
import transformers
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from collections import Counter
import string
import re

transformers.logging.set_verbosity_error()

In [2]:
class Data():
    def __init__(self, model_name= "bert-base-uncased", batch_size=128):
        self.model_name = model_name
        self.batch_size = batch_size
        self.raw_data = load_dataset('squad')
        self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
        self.special_tokens_dict = {'additional_special_tokens': ["[ANS_START]","[ANS_END]"]}
        self.tokenizer.add_special_tokens(self.special_tokens_dict)

    def examine_raw_dataset(self):
        print("Dataset structure:")
        print(self.raw_data)
        print("First training example:")
        print(self.raw_data["train"][0])
    
    def _add_end_idx(self, answers, contexts):
        new_answers = []

        for answer, context in zip(answers, contexts):
            # quick reformating to remove lists
            answer['text'] = answer['text'][0]
            answer['answer_start'] = answer['answer_start'][0]
            # gold_text refers to the answer we are expecting to find in context
            gold_text = answer['text']
            # we already know the start index
            start_idx = answer['answer_start']
            # and ideally this would be the end index...
            end_idx = start_idx + len(gold_text)

            # ...however, sometimes squad answers are off by a character or two
            if context[start_idx:end_idx] == gold_text:
                # if the answer is not off :)
                answer['answer_end'] = end_idx
            else:
                # this means the answer is off by 1-2 tokens
                for n in [1, 2]:
                    if context[start_idx-n:end_idx-n] == gold_text:
                        answer['answer_start'] = start_idx - n
                        answer['answer_end'] = end_idx - n
            new_answers.append(answer)
        return new_answers

    def _generate_training_lists(self, dataset):
        contexts = dataset["context"]
        questions = dataset["question"]
        answers =  self._add_end_idx(dataset["answers"], contexts)
        return contexts, questions, answers
    
    def _generate_validation_lists(self, dataset):
        contexts = dataset["context"]
        questions = dataset["question"]
        answers = [answer["text"] for answer in dataset["answers"]]
        return contexts, questions, answers

    def _add_answers_special_tokens(self, contexts, answers):
        for idx, context in enumerate(contexts):
            start_index = answers[idx]["answer_start"]
            first_part = context[:start_index]
            second_part = context[start_index:]
            context = first_part + "[ANS_START] " + second_part

            shift = 12
            while True:
                end_index = answers[idx]["answer_end"] + shift
                try:
                    if context[end_index] == " ":
                        first_part = context[:end_index]
                        second_part = context[end_index:]
                        context = context = first_part + " [ANS_END]" + second_part
                        break
                    elif context[end_index] != " ":
                        shift += 1
                except IndexError:
                    context += " [ANS_END]"
                    break

            contexts[idx] = context
    
    def _get_answer_tokens(self, dataset, ground_truths=None):
        start_positions, end_positions = [], []
        counter = 0

        device = torch.device("cuda")
        dataset['input_ids'] = dataset['input_ids'].to(device)
        dataset['token_type_ids'] = dataset['token_type_ids'].to(device)
        dataset['attention_mask'] = dataset['attention_mask'].to(device)
        zero_tensor = torch.zeros(1).to(device)

        for i in tqdm(range(len(dataset['input_ids']))):

            doc_i = i - counter

            special_token_start_pos = torch.nonzero(dataset['input_ids'][doc_i] == torch.tensor(30522)).flatten()
            special_token_end_pos = torch.nonzero(dataset['input_ids'][doc_i] == torch.tensor(30523)).flatten()

            answer_start_pos = special_token_start_pos
            answer_end_pos = special_token_end_pos - 2

            if len(special_token_start_pos) == 0 or len(special_token_end_pos) == 0:
                counter += 1
                dataset['input_ids'] = torch.cat((dataset['input_ids'][:doc_i,:],
                                                    dataset['input_ids'][doc_i+1:,:]))
                dataset['token_type_ids'] = torch.cat((dataset['token_type_ids'][:doc_i,:],
                                                    dataset['token_type_ids'][doc_i+1:,:]))
                dataset['attention_mask'] = torch.cat((dataset['attention_mask'][:doc_i,:],
                                                    dataset['attention_mask'][doc_i+1:,:]))
                if ground_truths is not None:
                    ground_truths.pop(doc_i)
            else:
                start_positions.append(answer_start_pos)
                end_positions.append(answer_end_pos)

                dataset['input_ids'][doc_i] = torch.cat((dataset['input_ids'][doc_i][:special_token_end_pos],
                                                            dataset['input_ids'][doc_i][special_token_end_pos+1:],
                                                            zero_tensor))
                dataset['input_ids'][doc_i] = torch.cat((dataset['input_ids'][doc_i][:special_token_start_pos],
                                                            dataset['input_ids'][doc_i][special_token_start_pos+1:],
                                                            zero_tensor))
                dataset['token_type_ids'][doc_i] = torch.cat((dataset['token_type_ids'][doc_i][:special_token_end_pos],
                                                            dataset['token_type_ids'][doc_i][special_token_end_pos+1:],
                                                            zero_tensor))
                dataset['token_type_ids'][doc_i] = torch.cat((dataset['token_type_ids'][doc_i][:special_token_start_pos],
                                                            dataset['token_type_ids'][doc_i][special_token_start_pos+1:],
                                                            zero_tensor))
                dataset['attention_mask'][doc_i] = torch.cat((dataset['attention_mask'][doc_i][:special_token_end_pos],
                                                            dataset['attention_mask'][doc_i][special_token_end_pos+1:],
                                                            zero_tensor))
                dataset['attention_mask'][doc_i] = torch.cat((dataset['attention_mask'][doc_i][:special_token_start_pos],
                                                            dataset['attention_mask'][doc_i][special_token_start_pos+1:],
                                                            zero_tensor))
        
        device = torch.device("cpu")
        dataset['input_ids'] = dataset['input_ids'].to(device)
        dataset['token_type_ids'] = dataset['token_type_ids'].to(device)
        dataset['attention_mask'] = dataset['attention_mask'].to(device)

        torch.cuda.empty_cache()
                
        return start_positions, end_positions

    def _create_tensor_dataset(self, dataset):
        if set(["start_positions","end_positions"]).issubset(dataset.keys()):
            return TensorDataset(dataset["input_ids"], dataset["token_type_ids"],
                                dataset["attention_mask"], dataset["start_positions"],
                                dataset["end_positions"])
        else:
            return TensorDataset(dataset["input_ids"], dataset["token_type_ids"],
                                dataset["attention_mask"])

    def create_train_dataloader(self, show_info=False):
        contexts_train, questions_train, answers_train = self._generate_training_lists(self.raw_data["train"])

        self._add_answers_special_tokens(contexts_train, answers_train)

        train_dataset = self.tokenizer(contexts_train, questions_train,
                                        truncation=True, padding='max_length',
                                        max_length=512, return_tensors='pt')
        
        if show_info:
            print("Training dataset key after tokenizing:")
            print(train_dataset.keys())
            print("First example in training dataset after tokenizing:")
            print(self.tokenizer.decode(train_dataset['input_ids'][0])[:855])

            print("[ANS_START] tokenizer token ID: ", self.tokenizer("[ANS_START]")["input_ids"][1])
            print("[ANS_END] tokenizer token ID: ", self.tokenizer("[ANS_END]")["input_ids"][1])
            print("[PAD] tokenizer token ID: ", self.tokenizer("[PAD]")["input_ids"][1])
        
        start_positions, end_positions = self._get_answer_tokens(train_dataset)
        train_dataset["start_positions"] = torch.tensor(start_positions)
        train_dataset["end_positions"] = torch.tensor(end_positions)

        if show_info:
            print("Training dataset key after adding answers start and end positions:")
            print(train_dataset.keys())
            print("Training tensor types and sizes:")
            for key, value in train_dataset.items():
                print(key, type(value), value.size())
        
        train_dataset = self._create_tensor_dataset(train_dataset)

        train_dataloader = DataLoader(train_dataset,
                                batch_size = self.batch_size,
                                shuffle = True)

        return train_dataloader

    def create_val_dataloader(self):

        contexts_validation, questions_validation, validation_ground_truths = self._generate_validation_lists(self.raw_data["validation"])

        validation_dataset = self.tokenizer(questions_validation, contexts_validation,
                                                truncation=True, padding=True,
                                                max_length=512, return_tensors='pt')
        
        validation_dataset = self._create_tensor_dataset(validation_dataset)

        validation_dataloader = DataLoader(validation_dataset,
                                        batch_size = self.batch_size,
                                        shuffle = False)

        return validation_dataloader, validation_ground_truths

In [3]:
class Model():
    def __init__(self, model_name="bert-base-uncased", num_epochs=3, length_dataloader=0):
        self.model = BertForQuestionAnswering.from_pretrained(
            model_name
        )
        self.tokenizer = BertTokenizer.from_pretrained(
            model_name
        )
        self.optimizer = AdamW(self.model.parameters(),lr = 2e-5)
        self.num_epochs = num_epochs
        self.num_training_steps = self.num_epochs * length_dataloader
        self.lr_scheduler = get_scheduler(
                                    "linear",
                                    optimizer=self.optimizer,
                                    num_warmup_steps=0,
                                    num_training_steps=self.num_training_steps
                                    )
        self.device = self._get_device()
        self.model = self.model.to(self.device)

    def _get_device(self, show_info=False):
        if torch.cuda.is_available():    
            device = torch.device("cuda")

            if show_info:
                print('There are %d GPU(s) available.' % torch.cuda.device_count())
                print('We will use the GPU:', torch.cuda.get_device_name(0))

        else:
            device = torch.device("cpu")

            if show_info:
                print('No GPU available, using the CPU instead.')

        return device

    def load_model_state_dict(self, path="models_lisandro/BERT_model_state_dict.pt"):
        self.model.load_state_dict(torch.load(path))
    
    def save_model_state_dict(self, path="models_lisandro/BERT_model_state_dict.pt"):
        torch.save(self.model.state_dict(), path)
    
    def train(self, train_dataloader):
        self.model.train()

        total_train_loss = 0

        for batch in tqdm(train_dataloader):
            self.model.zero_grad()
            parameters = {
                "input_ids" : batch[0].to(self.device),
                "token_type_ids": batch[1].to(self.device),
                "attention_mask" :  batch[2].to(self.device), 
                "start_positions" : batch[3].to(self.device),
                "end_positions" : batch[4].to(self.device),
            }
            outputs = self.model(**parameters)

            loss = outputs.loss
            total_train_loss += loss.item()
            loss.backward()
            
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

            self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.zero_grad()

        return {"training_loss" : total_train_loss/len(train_dataloader)}
    
    def _normalize_answer(self, s):
        """Lower text and remove punctuation, articles and extra whitespace."""
        def remove_articles(text):
            return re.sub(r'\b(a|an|the)\b', ' ', text)

        def white_space_fix(text):
            return ' '.join(text.split())

        def remove_punc(text):
            exclude = set(string.punctuation)
            return ''.join(ch for ch in text if ch not in exclude)

        def lower(text):
            return text.lower()

        return white_space_fix(remove_articles(remove_punc(lower(s))))

    def _f1_score(self, prediction, ground_truth):
        prediction_tokens = self._normalize_answer(prediction).split()
        ground_truth_tokens = self._normalize_answer(ground_truth).split()
        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
        num_same = sum(common.values())
        if num_same == 0:
            return 0
        precision = 1.0 * num_same / len(prediction_tokens)
        recall = 1.0 * num_same / len(ground_truth_tokens)
        f1 = (2 * precision * recall) / (precision + recall)
        return f1

    def _exact_match_score(self, prediction, ground_truth):
        return (1 if self._normalize_answer(prediction) == self._normalize_answer(ground_truth) else 0)

    def _metric_max_over_ground_truths(self, metric_fn, prediction, ground_truths):
        scores_for_ground_truths = []
        for ground_truth in ground_truths:
            score = metric_fn(prediction, ground_truth)
            scores_for_ground_truths.append(score)
        return max(scores_for_ground_truths)

    def _evaluate(self, ground_truths, predictions):
        """ Official evaluation script for v1.1 of the SQuAD dataset. """
        f1 = exact_match = total = 0

        for ground_truths_list, prediction in zip(ground_truths, predictions):
            total += 1
            exact_match += self._metric_max_over_ground_truths(
                            self._exact_match_score, prediction, ground_truths_list)
            f1 += self._metric_max_over_ground_truths(
                        self._f1_score, prediction, ground_truths_list)
        
        return {'batch_exact_match': exact_match, 'batch_f1': f1, 'batch_total': total}
    
    def predict(self, val_dataloader, val_groundtruth):
        self.model.eval()

        f1 = exact_match = total = 0

        for idx_batch, batch in tqdm(enumerate(val_dataloader)):
            parameters = {
                "input_ids" : batch[0].to(self.device),
                "token_type_ids": batch[1].to(self.device),
                "attention_mask" :  batch[2].to(self.device),
            }
            with torch.no_grad():
                outputs = self.model(**parameters)

            start_scores = outputs.start_logits
            end_scores = outputs.end_logits

            #get best answer
            answer_start = torch.argmax(start_scores, dim=-1)
            # Get the most likely end of answer with the argmax of the score
            answer_end = torch.argmax(end_scores, dim=-1) +1

            predictions_ids = (tensor[answer_start[idx_tensor]:answer_end[idx_tensor]] 
                                for idx_tensor, tensor in enumerate(parameters["input_ids"]))
            predictions = [self.tokenizer.decode(prediction_id) for prediction_id in predictions_ids]

            ground_truths = val_groundtruth[idx_batch*BATCH_SIZE:(idx_batch+1)*BATCH_SIZE]

            results = self._evaluate(ground_truths, predictions)
            f1 += results["batch_f1"]
            exact_match += results["batch_exact_match"]
            total += results["batch_total"]
        
        return { "val_f1" : 100.0 * f1 / total,
                "val_exact_match" : 100.0 * exact_match / total}

### Fine-tune BERT from scratch for QA (train)

In [4]:
MODEL_NAME = "bert-base-uncased"
NUM_EPOCHS = 1
BATCH_SIZE = 4

In [5]:
data_class = Data(MODEL_NAME, BATCH_SIZE)

train_dataloader = data_class.create_train_dataloader(show_info=True)

Reusing dataset squad (C:\Users\Lisandro\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)
100%|██████████| 2/2 [00:00<00:00, 40.00it/s]


Training dataset key after tokenizing:
dict_keys(['input_ids', 'token_type_ids', 'attention_mask'])
First example in training dataset after tokenizing:
[CLS] architecturally, the school has a catholic character. atop the main building's gold dome is a golden statue of the virgin mary. immediately in front of the main building and facing it, is a copper statue of christ with arms upraised with the legend " venite ad me omnes ". next to the main building is the basilica of the sacred heart. immediately behind the basilica is the grotto, a marian place of prayer and reflection. it is a replica of the grotto at lourdes, france where the virgin mary reputedly appeared to [ANS_START] saint bernadette soubirous [ANS_END] in 1858. at the end of the main drive ( and in a direct line that connects through 3 statues and the gold dome ), is a simple, modern stone statue of mary. [SEP] to whom did the virgin mary allegedly appear in 1858 in lourdes france? [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] 

100%|██████████| 87599/87599 [02:49<00:00, 516.42it/s]


Training dataset key after adding answers start and end positions:
dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'])
Training tensor types and sizes:
input_ids <class 'torch.Tensor'> torch.Size([87589, 512])
token_type_ids <class 'torch.Tensor'> torch.Size([87589, 512])
attention_mask <class 'torch.Tensor'> torch.Size([87589, 512])
start_positions <class 'torch.Tensor'> torch.Size([87589])
end_positions <class 'torch.Tensor'> torch.Size([87589])


In [27]:
model_class = Model(MODEL_NAME, NUM_EPOCHS, len(train_dataloader))

training_stats = []

torch.cuda.empty_cache()

for epoch in range(NUM_EPOCHS):
    metrics_dict = model_class.train(train_dataloader)

    training_stats.append({
                "epoch":epoch+1,
                "training_loss": metrics_dict["training_loss"]
                })

model_class.save_model_state_dict()

training_stats

BATCH 100/21898:	Training loss(4.972761631011963)
BATCH 200/21898:	Training loss(3.631061553955078)
BATCH 300/21898:	Training loss(3.799131155014038)
BATCH 400/21898:	Training loss(3.615633964538574)
BATCH 500/21898:	Training loss(3.315737724304199)
BATCH 600/21898:	Training loss(1.8218226432800293)
BATCH 700/21898:	Training loss(3.2214529514312744)
BATCH 800/21898:	Training loss(2.0947859287261963)
BATCH 900/21898:	Training loss(2.3019957542419434)
BATCH 1000/21898:	Training loss(4.450675964355469)
BATCH 1100/21898:	Training loss(2.5524730682373047)
BATCH 1200/21898:	Training loss(1.5645030736923218)
BATCH 1300/21898:	Training loss(1.3093781471252441)
BATCH 1400/21898:	Training loss(1.8533928394317627)
BATCH 1500/21898:	Training loss(2.514796018600464)
BATCH 1600/21898:	Training loss(1.4977751970291138)
BATCH 1700/21898:	Training loss(1.7563726902008057)
BATCH 1800/21898:	Training loss(2.258014678955078)
BATCH 1900/21898:	Training loss(4.301990509033203)
BATCH 2000/21898:	Training los

[{'epoch': 1, 'training_loss': 1.3018079750173694}]

### Evaluate fine-tuned BERT model (validation)

In [None]:
MODEL_NAME = "bert-base-uncased"
NUM_EPOCHS = 1
BATCH_SIZE = 128

In [None]:
data_class = Data(MODEL_NAME, BATCH_SIZE)

validation_dataloader, validation_ground_truths = data_class.create_val_dataloader()

Reusing dataset squad (C:\Users\Lisandro\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)
100%|██████████| 2/2 [00:00<00:00, 250.07it/s]
100%|██████████| 87599/87599 [02:57<00:00, 492.50it/s]


In [50]:
model_class = Model(MODEL_NAME)
model_class.load_model_state_dict()

torch.cuda.empty_cache()

validation_stats = model_class.predict(validation_dataloader, validation_ground_truths)

validation_stats

83it [05:16,  3.82s/it]


{'val_f1': 80.06160313355913, 'val_exact_match': 69.66887417218543}

### Evaluate pretrained BERT model for QA (from Huggingface)

In [4]:
MODEL_NAME = "bert-large-uncased-whole-word-masking-finetuned-squad"
BATCH_SIZE = 64

In [5]:
data_class = Data(MODEL_NAME, BATCH_SIZE)

validation_dataloader, validation_ground_truths = data_class.create_val_dataloader()

Reusing dataset squad (C:\Users\Lisandro\.cache\huggingface\datasets\squad\plain_text\1.0.0\d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)
100%|██████████| 2/2 [00:00<00:00, 41.71it/s]


In [6]:
model_class = Model(MODEL_NAME)

torch.cuda.empty_cache()

validation_stats = model_class.predict(validation_dataloader, validation_ground_truths)

validation_stats

166it [17:22,  6.28s/it]


{'val_f1': 86.4149533129056, 'val_exact_match': 77.918637653737}