Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

# Question Answering on the SQuAD Dataset using BERT
## Summary
This notebook demonstrates how to fine-tune a [pretrained BERT model](https://github.com/huggingface/pytorch-transformers) for the extractive question answering task. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, result postprocessing, and model evaluation.

BERT[\[1\]](#References) is a powerful pre-trained language model that can be used for multiple NLP tasks, including text classification, question answering, and named entity recognition. It can achieve state-of-the-art performance with only a few epochs of fine-tuning on task-specific datasets.  
The figure below illustrates how BERT can be fine-tuned for the extractive question answering task. The question and paragraph tokens are concatenated into a single input token sequence with a special [SEP] token between them. For each paragraph token, BERT predicts the probability of that token being the start of the answer span and the probability of it being the end. The pair of tokens with the highest sum of start probability and end probability defines the span of the predicted answer.

<img src="https://nlpbp.blob.core.windows.net/images/bert_qa.PNG">
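
The span selection described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the code used by `BERTQAExtractor` or `postprocess_answers`: the function names, the toy logits, and the `max_answer_length` cap are all hypothetical.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over a 1-D array."""
    shifted = logits - np.max(logits)
    exp = np.exp(shifted)
    return exp / exp.sum()


def best_span(start_logits, end_logits, paragraph_start, max_answer_length=30):
    """Return (start, end, score) of the highest-scoring span in the paragraph.

    paragraph_start is the index of the first paragraph token, so question
    tokens and the leading [CLS]/[SEP] tokens are never selected.
    max_answer_length is a hypothetical cap on the span length in tokens.
    """
    start_probs = softmax(np.asarray(start_logits, dtype=float))
    end_probs = softmax(np.asarray(end_logits, dtype=float))
    best = (None, None, -np.inf)
    for start in range(paragraph_start, len(start_probs)):
        # The end token may not come before the start token.
        last = min(start + max_answer_length, len(end_probs))
        for end in range(start, last):
            score = start_probs[start] + end_probs[end]
            if score > best[2]:
                best = (start, end, score)
    return best


# Toy example: made-up logits for a nine-token input sequence.
tokens = ["[CLS]", "who", "won", "?", "[SEP]", "denver", "broncos", "won", "[SEP]"]
start_logits = [0.1, 0.0, 0.0, 0.0, 0.0, 5.0, 1.0, 0.2, 0.0]
end_logits = [0.1, 0.0, 0.0, 0.0, 0.0, 1.0, 6.0, 0.3, 0.0]
start, end, score = best_span(start_logits, end_logits, paragraph_start=5)
print(tokens[start:end + 1])  # ['denver', 'broncos']
```

The exhaustive double loop keeps the sketch easy to read; it is quadratic in the paragraph length, which is acceptable for sequences of a few hundred tokens.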

In [1]:
from datetime import datetime
startTime = datetime.now()

In [2]:
import os
import sys

import torch
import numpy as np

nlp_path = os.path.abspath('../../')
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.dataset.squad import load_pandas_df
from utils_nlp.models.bert.common import Language, Tokenizer
from utils_nlp.models.bert.question_answering import BERTQAExtractor
from utils_nlp.models.bert.qa_utils import postprocess_answers, evaluate_qa
from utils_nlp.common.timer import Timer

## Configurations

In [3]:
SQUAD_VERSION = "v1.1" 
CACHE_DIR = "./temp"

LANGUAGE = Language.ENGLISHLARGEWWM
DO_LOWER_CASE = True

MAX_SEQ_LENGTH = 384
NUM_EPOCHS = 2
BATCH_SIZE = 8
LEARNING_RATE = 3e-5
WARMUP = 0.1

DOC_TEXT_COL = "doc_text"
QUESTION_TEXT_COL = "question_text"
ANSWER_START_COL = "answer_start"
ANSWER_TEXT_COL = "answer_text"
QA_ID_COL = "qa_id"
IS_IMPOSSIBLE_COL = "is_impossible"

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.device_count() > 0:
    torch.cuda.manual_seed_all(RANDOM_SEED)

## Load Data

### The SQuAD Dataset
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable. [\[2, 3\]](#References)

<img src="https://nlpbp.blob.core.windows.net/images/squad.png">

There are two versions of the SQuAD dataset. SQuAD 1.1 contains 100,000+ question-answer pairs on 500+ articles. SQuAD 2.0 adds 50,000 new, unanswerable questions written adversarially by crowdworkers to look similar to answerable ones. Both versions are available at [https://rajpurkar.github.io/SQuAD-explorer/](https://rajpurkar.github.io/SQuAD-explorer/). Each version comes with a training set and a development set.


The utility function `load_pandas_df` downloads the dataset specified by `squad_version` and `file_split` to `local_cache_path` if it doesn't exist already.
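
For orientation, the SQuAD files are nested JSON: articles contain paragraphs, and paragraphs contain question-answer pairs. The sketch below flattens that structure into one row per question, using the column names from the Configurations section. It is an illustration only and not the implementation of `load_pandas_df`; the function name `flatten_squad` and the sample record (including its `id`) are made up.

```python
def flatten_squad(squad_json):
    """Turn nested SQuAD JSON into a flat list of row dictionaries."""
    rows = []
    for article in squad_json["data"]:
        for paragraph in article["paragraphs"]:
            for qa in paragraph["qas"]:
                answers = qa.get("answers", [])
                rows.append({
                    "doc_text": paragraph["context"],
                    "question_text": qa["question"],
                    "answer_start": [a["answer_start"] for a in answers],
                    "answer_text": [a["text"] for a in answers],
                    "qa_id": qa["id"],
                    # SQuAD v1.1 has no is_impossible field, so default to False.
                    "is_impossible": qa.get("is_impossible", False),
                })
    return rows


# A tiny hand-written sample in the SQuAD layout (not real data).
sample = {
    "version": "1.1",
    "data": [{
        "title": "Super_Bowl_50",
        "paragraphs": [{
            "context": "The Denver Broncos defeated the Carolina Panthers.",
            "qas": [{
                "id": "example-0",
                "question": "Which team won?",
                "answers": [{"text": "Denver Broncos", "answer_start": 4}],
            }],
        }],
    }],
}
rows = flatten_squad(sample)
print(rows[0]["answer_text"], rows[0]["answer_start"])  # ['Denver Broncos'] [4]
```

This sketch always stores answers as lists. In the outputs shown in this notebook, the training split holds a single answer per question while the development split holds a list of reference answers.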

In [4]:
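# Use the SQUAD_VERSION constant from the Configurations section instead of repeating the literal.
train_df = load_pandas_df(local_cache_path=".", squad_version=SQUAD_VERSION, file_split="train")
dev_df = load_pandas_df(local_cache_path=".", squad_version=SQUAD_VERSION, file_split="dev")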

In [4]:
train_df.head()

| | doc_text | question_text | answer_start | answer_text | qa_id | is_impossible |
|---|---|---|---|---|---|---|
| 0 | Architecturally, the school has a Catholic cha... | To whom did the Virgin Mary allegedly appear i... | 515 | Saint Bernadette Soubirous | 5733be284776f41900661182 | False |
| 1 | Architecturally, the school has a Catholic cha... | What is in front of the Notre Dame Main Building? | 188 | a copper statue of Christ | 5733be284776f4190066117f | False |
| 2 | Architecturally, the school has a Catholic cha... | The Basilica of the Sacred heart at Notre Dame... | 279 | the Main Building | 5733be284776f41900661180 | False |
| 3 | Architecturally, the school has a Catholic cha... | What is the Grotto at Notre Dame? | 381 | a Marian place of prayer and reflection | 5733be284776f41900661181 | False |
| 4 | Architecturally, the school has a Catholic cha... | What sits on top of the Main Building at Notre... | 92 | a golden statue of the Virgin Mary | 5733be284776f4190066117e | False |


In [6]:
dev_df.head()

| | doc_text | question_text | answer_start | answer_text | qa_id | is_impossible |
|---|---|---|---|---|---|---|
| 0 | Super Bowl 50 was an American football game to... | Which NFL team represented the AFC at Super Bo... | [177, 177, 177] | [Denver Broncos, Denver Broncos, Denver Broncos] | 56be4db0acb8001400a502ec | False |
| 1 | Super Bowl 50 was an American football game to... | Which NFL team represented the NFC at Super Bo... | [249, 249, 249] | [Carolina Panthers, Carolina Panthers, Carolin... | 56be4db0acb8001400a502ed | False |
| 2 | Super Bowl 50 was an American football game to... | Where did Super Bowl 50 take place? | [403, 355, 355] | [Santa Clara, California, Levi's Stadium, Levi... | 56be4db0acb8001400a502ee | False |
| 3 | Super Bowl 50 was an American football game to... | Which NFL team won Super Bowl 50? | [177, 177, 177] | [Denver Broncos, Denver Broncos, Denver Broncos] | 56be4db0acb8001400a502ef | False |
| 4 | Super Bowl 50 was an American football game to... | What color was used to emphasize the 50th anni... | [488, 488, 521] | [gold, gold, gold] | 56be4db0acb8001400a502f0 | False |


## Tokenize and Preprocess Data

In [25]:
tokenizer = Tokenizer(language=LANGUAGE, to_lower=DO_LOWER_CASE, cache_dir=CACHE_DIR)

The `tokenize_qa` method of `Tokenizer` tokenizes the input paragraph, question, and answer texts and converts them into the format required by the pre-trained BERT model. This involves the following steps:
* WordPiece tokenization.
* Convert character-based answer span indices to token-based indices.
* Truncate the question token list if it's longer than `max_query_length`.
* Split the paragraph into multiple segments if it's longer than `MAX_SEQ_LENGTH` - `max_query_length` - 3. (The "-3" is for the special [CLS] token and two [SEP] tokens.)
* Add the special tokens [CLS] and [SEP].
* Pad the concatenated token sequence to `MAX_SEQ_LENGTH` if it's shorter.
* Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary.
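
The truncation, splitting, special-token, and padding steps above can be sketched as follows. This is a simplified illustration, not the code inside `tokenize_qa`: the function name `build_inputs`, the `doc_stride` parameter (how far the window moves between segments), and the `[PAD]` placeholder are assumptions, and the final conversion of tokens to vocabulary indices is omitted.

```python
def build_inputs(query_tokens, doc_tokens, max_len=384, max_query_length=64,
                 doc_stride=128, pad_token="[PAD]"):
    """Return a list of padded token sequences, one per paragraph segment.

    doc_stride is a hypothetical parameter; the step list above does not
    say how far consecutive segments are shifted.
    """
    # Truncate the question if it is too long.
    query_tokens = query_tokens[:max_query_length]
    # Paragraph token budget: 3 slots are reserved for [CLS] and two [SEP].
    max_doc_tokens = max_len - max_query_length - 3
    sequences = []
    start = 0
    while True:
        segment = doc_tokens[start:start + max_doc_tokens]
        tokens = ["[CLS]"] + query_tokens + ["[SEP]"] + segment + ["[SEP]"]
        # Pad on the right up to the fixed sequence length.
        tokens += [pad_token] * (max_len - len(tokens))
        sequences.append(tokens)
        if start + max_doc_tokens >= len(doc_tokens):
            break
        start += min(doc_stride, max_doc_tokens)
    return sequences


# Toy example with small limits so the result is easy to inspect.
query = ["q1", "q2", "q3"]
doc = ["d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8"]
sequences = build_inputs(query, doc, max_len=12, max_query_length=2, doc_stride=4)
print(len(sequences), [len(s) for s in sequences])  # 2 [12, 12]
print(sequences[1])
```

Overlapping windows mean that an answer cut off at the edge of one segment is likely to appear whole in a neighbouring one, which is why a stride smaller than the segment length is assumed here.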

In addition to the features required by BERT, `tokenize_qa` outputs a few additional fields needed for postprocessing. See the `QAFeatures` class in [qa_utils.py](../../utils_nlp/models/bert/qa_utils.py) for more details.
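
The conversion from character-based to token-based answer indices, listed among the steps above, can be illustrated with a simplified sketch. It assumes whitespace-separated tokens rather than WordPiece tokens, and the function name `char_span_to_token_span` is hypothetical; the actual `tokenize_qa` logic may differ.

```python
def char_span_to_token_span(doc_text, answer_start, answer_text):
    """Map a character-based answer span to inclusive token indices.

    Tokens are whitespace-separated words here, which is a simplification
    of the WordPiece tokenization used by BERT.
    """
    tokens = []
    char_to_token = []
    in_token = False
    for ch in doc_text:
        if ch.isspace():
            in_token = False
        elif in_token:
            tokens[-1] += ch
        else:
            tokens.append(ch)
            in_token = True
        # Whitespace is assigned to the token that precedes it
        # (or to -1 when the text starts with whitespace).
        char_to_token.append(len(tokens) - 1)
    answer_end = answer_start + len(answer_text) - 1
    return tokens, char_to_token[answer_start], char_to_token[answer_end]


doc_text = "The Denver Broncos defeated the Carolina Panthers."
tokens, start_token, end_token = char_span_to_token_span(
    doc_text, answer_start=4, answer_text="Denver Broncos")
print(start_token, end_token, tokens[start_token:end_token + 1])
# 1 2 ['Denver', 'Broncos']
```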

In [8]:
train_features, qa_examples = tokenizer.tokenize_qa(
    doc_text=train_df[DOC_TEXT_COL], 
    question_text=train_df[QUESTION_TEXT_COL], 
    answer_start=train_df[ANSWER_START_COL], 
    answer_text=train_df[ANSWER_TEXT_COL],
    qa_id=train_df[QA_ID_COL],
    is_impossible=train_df[IS_IMPOSSIBLE_COL],
    is_training=True,
    max_len=MAX_SEQ_LENGTH,
    max_query_length=64,
    cache_results=True)

In [9]:
dev_features, dev_examples = tokenizer.tokenize_qa(
    doc_text=dev_df[DOC_TEXT_COL], 
    question_text=dev_df[QUESTION_TEXT_COL], 
    answer_start=dev_df[ANSWER_START_COL], 
    answer_text=dev_df[ANSWER_TEXT_COL],
    qa_id=dev_df[QA_ID_COL],
    is_impossible=dev_df[IS_IMPOSSIBLE_COL],
    is_training=False,
    max_len=MAX_SEQ_LENGTH,
    max_query_length=64,
    cache_results=True)

In [10]:
sample_feature = train_features[0]
for f in type(sample_feature)._fields:
    print(f)
    print(getattr(sample_feature, f))
    print()

unique_id
1000000000

qa_id
5733be284776f41900661182

tokens
['[CLS]', 'to', 'whom', 'did', 'the', 'virgin', 'mary', 'allegedly', 'appear', 'in', '1858', 'in', 'lou', '##rdes', 'france', '?', '[SEP]', 'architectural', '##ly', ',', 'the', 'school', 'has', 'a', 'catholic', 'character', '.', 'atop', 'the', 'main', 'building', "'", 's', 'gold', 'dome', 'is', 'a', 'golden', 'statue', 'of', 'the', 'virgin', 'mary', '.', 'immediately', 'in', 'front', 'of', 'the', 'main', 'building', 'and', 'facing', 'it', ',', 'is', 'a', 'copper', 'statue', 'of', 'christ', 'with', 'arms', 'up', '##rai', '##sed', 'with', 'the', 'legend', '"', 've', '##ni', '##te', 'ad', 'me', 'om', '##nes', '"', '.', 'next', 'to', 'the', 'main', 'building', 'is', 'the', 'basilica', 'of', 'the', 'sacred', 'heart', '.', 'immediately', 'behind', 'the', 'basilica', 'is', 'the', 'gr', '##otto', ',', 'a', 'marian', 'place', 'of', 'prayer', 'and', 'reflection', '.', 'it', 'is', 'a', 'replica', 'of', 'the', 'gr', '##otto', 'at', 'lou'

## Train BERTQAExtractor

In [36]:
qa_extractor = BERTQAExtractor(language=LANGUAGE, cache_dir=CACHE_DIR)

In [37]:
with Timer() as t:
    qa_extractor.fit(train_features,
                     num_epochs=NUM_EPOCHS,
                     batch_size=BATCH_SIZE,
                     lr=LEARNING_RATE,
                     cache_model=True)
print("Training time : {:.3f} hrs".format(t.interval / 3600))

Epoch:   0%|          | 0/2 [00:00<?, ?it/s]
Iteration:   0%|          | 0/11081 [00:00<?, ?it/s]
Iteration:   0%|          | 1/11081 [00:02<8:40:15,  2.82s/it]
Iteration:   0%|          | 2/11081 [00:05<8:43:26,  2.83s/it]
Iteration:   0%|          | 3/11081 [00:08<8:26:28,  2.74s/it]
Iteration:   0%|          | 4/11081 [00:10<8:15:06,  2.68s/it]
Iteration:   0%|          | 5/11081 [00:13<8:12:04,  2.67s/it]
Iteration:   0%|          | 6/11081 [00:16<8:09:11,  2.65s/it]

KeyboardInterrupt: 

## Predict
Note that `BERTQAExtractor.predict` only outputs the probabilities of each token being the start and end of the answer span. The `postprocess_answers` function takes these probabilities and generates the final answers.
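
Because a long paragraph is split into several segments during tokenization, a single question can produce span candidates from more than one segment. The sketch below shows one plausible way to combine them: group candidates by `qa_id`, rank them by score, and keep the best one plus an n-best list. The input format, a list of `(qa_id, answer_text, score)` tuples, and the function name are hypothetical and are not the actual interface of `postprocess_answers`.

```python
from collections import defaultdict


def pick_final_answers(candidates, n_best=3):
    """Group span candidates by qa_id and rank them by score.

    candidates is a list of (qa_id, answer_text, score) tuples, a
    hypothetical format chosen for this sketch.
    """
    grouped = defaultdict(list)
    for qa_id, answer_text, score in candidates:
        grouped[qa_id].append((score, answer_text))
    final_answers, nbest_answers = {}, {}
    for qa_id, scored in grouped.items():
        # Highest score first; keep at most n_best candidates.
        ranked = sorted(scored, key=lambda item: item[0], reverse=True)[:n_best]
        nbest_answers[qa_id] = [text for _, text in ranked]
        final_answers[qa_id] = ranked[0][1]
    return final_answers, nbest_answers


# Made-up candidates: question "q1" has spans from two different segments.
candidates = [
    ("q1", "Broncos", 1.2),
    ("q1", "Denver Broncos", 1.9),
    ("q1", "Carolina Panthers", 0.4),
    ("q2", "Santa Clara", 1.5),
]
final_answers, nbest_answers = pick_final_answers(candidates)
print(final_answers["q1"])  # Denver Broncos
print(nbest_answers["q1"])  # ['Denver Broncos', 'Broncos', 'Carolina Panthers']
```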

In [13]:
qa_results = qa_extractor.predict(dev_features)

Evaluating: 100%|██████████| 339/339 [12:09<00:00,  1.96s/it]


## Postprocess and Generate the Final Answers

In [14]:
final_answers, answer_probs, nbest_answers = postprocess_answers(qa_results,
                                                                 dev_examples, 
                                                                 dev_features, 
                                                                 do_lower_case=DO_LOWER_CASE)

In [15]:
for i in [0, 10, 100]:
    print('Paragraph:')
    print(dev_df.iloc[i]['doc_text'])
    print()
    print('Question:')
    print(dev_df.iloc[i]['question_text'])
    print()
    print('Ground truth answers:')
    print(dev_df.iloc[i]['answer_text'])
    print()
    print('Predicted answer:')
    print(final_answers[dev_df.iloc[i]['qa_id']])
    print()
    print('Top N best answers')
    print(nbest_answers[dev_df.iloc[i]['qa_id']])
    print('-------------------------------------------------------------------------------------------------------------------')

Paragraph:
Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.

Question:
Which NFL team represented the AFC at Super Bowl 50?

Ground truth answers:
['Denver Broncos', 'Denver Broncos', 'Denver Broncos']

Predicted answer:
Denver Broncos

Top N best answers
[OrderedDict([('t

## Evaluate

The question answering task is usually evaluated with two metrics: exact match (EM) and F1 score.  
Exact match is computed by first applying some simple normalization (e.g., removing punctuation and converting to lower case) to the ground truth and predicted answers, and then checking whether they match exactly.  
The F1 score is computed from the token-level precision and recall between the ground truth and predicted answers.
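
As a concrete illustration of the two metrics, the sketch below follows the normalization used by the official SQuAD evaluation script (lower-casing, removing punctuation and the articles "a", "an", "the", and collapsing whitespace). It is a standalone example and may differ in detail from `evaluate_qa`; the helper name `best_over_ground_truths` is made up.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(text):
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, ground_truth):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def f1_score(prediction, ground_truth):
    """Harmonic mean of token-level precision and recall."""
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(ground_truth).split()
    if not pred_tokens or not truth_tokens:
        # Both empty counts as a match; one empty counts as a miss.
        return float(pred_tokens == truth_tokens)
    common = Counter(pred_tokens) & Counter(truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


def best_over_ground_truths(metric, prediction, ground_truths):
    """Score against every reference answer and keep the best score."""
    return max(metric(prediction, truth) for truth in ground_truths)


truths = ["Denver Broncos", "Denver Broncos", "Denver Broncos"]
print(best_over_ground_truths(exact_match, "the Denver Broncos", truths))  # 1.0
print(round(best_over_ground_truths(f1_score, "Broncos", truths), 4))      # 0.6667
```

Taking the maximum over the reference answers matters for the development set, where each question has several ground truth answers.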

In [16]:
evaluation_result = evaluate_qa(qa_ids=dev_df['qa_id'], 
                                actuals=dev_df['answer_text'], 
                                preds=final_answers)

{
  "exact": 86.07379375591296,
  "f1": 92.45589503088394,
  "total": 10570,
  "HasAns_exact": 86.07379375591296,
  "HasAns_f1": 92.45589503088394,
  "HasAns_total": 10570
}


OrderedDict([('exact', 86.07379375591296),
             ('f1', 92.45589503088394),
             ('total', 10570),
             ('HasAns_exact', 86.07379375591296),
             ('HasAns_f1', 92.45589503088394),
             ('HasAns_total', 10570)])

In [17]:
print(datetime.now() - startTime)

17:27:07.310707


In [None]:
# qa_extractor_cached = BERTQAExtractor(language=LANGUAGE, cache_dir='/home/hlu/notebooks/NLP/scenarios/question_answering/temp/', load_from_cache=True)
# dev_features_cached = torch.load('/home/hlu/notebooks/NLP/scenarios/question_answering/temp/cached_features')
# dev_examples_cached = torch.load('/home/hlu/notebooks/NLP/scenarios/question_answering/temp/cached_examples')
# qa_results_new = qa_extractor_cached.predict(dev_features_cached)
# final_answers_new, _, _ = postprocess_answers(qa_results_new,
#                                               dev_examples_cached, 
#                                               dev_features_cached, 
#                                               do_lower_case=DO_LOWER_CASE)
# evaluate_qa(qa_ids=dev_df['qa_id'], 
#             actuals=dev_df['answer_text'], 
#             preds=final_answers_new)

In [25]:
# dev_features_short = []
# qa_id_short = []
# for example in dev_examples_cached[:10]:
#     qa_id = example.qa_id
#     qa_id_short.append(qa_id)
#     for f in dev_features_cached:
#         if f.qa_id == qa_id:
#             dev_features_short.append(f)
# answer_text_short = dev_df.loc[dev_df["qa_id"].isin(qa_id_short), 'answer_text']
# qa_results_new_short = qa_extractor_cached.predict(dev_features_short)
# final_answers_new_short, _, _ = postprocess_answers(qa_results_new_short,
#                                               dev_examples_cached[:10], 
#                                               dev_features_short, 
#                                               do_lower_case=DO_LOWER_CASE)
# evaluate_qa(qa_ids=qa_id_short, 
#             actuals=answer_text_short, 
#             preds=final_answers_new_short)

In [45]:
# final_answers_new, final_probs, nbest_answers = postprocess_answers(qa_results_new,
#                                                 dev_examples_cached, 
#                                                 dev_features_cached, 
#                                                 do_lower_case=DO_LOWER_CASE)
# for dev_e in dev_examples_cached:
#     qa_id = dev_e.qa_id
#     count = 0
#     dev_f_list = []
#     for dev_f in dev_features_cached:
#         if dev_f.qa_id == qa_id:
#             count += 1
#             dev_f_list.append(dev_f)
#     if count == 2:
#         break

## References

1. Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova, [*BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding*](https://arxiv.org/abs/1810.04805), NAACL, 2019.
2. Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang, [*SQuAD: 100,000+ Questions for Machine Comprehension of Text*](https://arxiv.org/abs/1606.05250), EMNLP, 2016.
3. Pranav Rajpurkar, Robin Jia, Percy Liang, [*Know What You Don't Know: Unanswerable Questions for SQuAD*](https://arxiv.org/abs/1806.03822), ACL, 2018.