Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

# Question Answering on the SQuAD Dataset using BERT


# Before You Start

The running times shown in this notebook were measured on a Standard_NC24s_v3 Azure Deep Learning Virtual Machine with 4 NVIDIA Tesla V100 GPUs. 
> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. 

The table below provides reference running times on different machine configurations.  

|QUICK_RUN|Machine Configurations|Running time|
|:---------|:----------------------|:------------|
|True|4 **CPU**s, 14GB memory| ~ 10 minutes |
|True|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| ~ 3 minutes |
|False|4 NVIDIA Tesla K80 GPUs, 48GB GPU memory| ~ 18 hours |
|False|4 NVIDIA Tesla V100 GPUs, 64GB GPU memory| ~ 7 hours|

If you run into a CUDA out-of-memory error, try reducing `BATCH_SIZE` and `MAX_SEQ_LENGTH`, but note that model performance may degrade as a result. 

In [1]:
## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.
QUICK_RUN = False

## Summary
This notebook demonstrates how to fine-tune a [pretrained BERT model](https://github.com/huggingface/pytorch-transformers) for the extractive question answering task. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, result postprocessing, and model evaluation. 

BERT[\[1\]](#References) is a powerful pre-trained language model that can be used for multiple NLP tasks, including text classification, question answering, and named entity recognition. It can achieve state-of-the-art performance with only a few epochs of fine-tuning on task-specific datasets.  
The figure below illustrates how BERT can be fine-tuned for the extractive question answering task. The question and paragraph tokens are concatenated into a single input token sequence with a special token [SEP] between them. For each paragraph token, BERT predicts the probability of that token being the start of the answer span and the probability of it being the end. The start and end tokens with the highest sum of start probability and end probability define the span of the predicted answer.

<img src="https://nlpbp.blob.core.windows.net/images/bert_qa.PNG">
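The span selection described above can be illustrated with a minimal sketch. This is not the code used by `BERTQAExtractor` or `postprocess_answer`; the function name `best_span`, the `doc_start` and `max_answer_length` parameters, and the toy probabilities are all hypothetical and chosen only for illustration.

```python
import numpy as np


def best_span(start_probs, end_probs, doc_start=0, max_answer_length=30):
    """Return (start, end, score) maximizing start_probs[s] + end_probs[e].

    Only spans with doc_start <= s <= e and at most max_answer_length
    tokens are considered. All parameter names here are illustrative.
    """
    best = (doc_start, doc_start, -np.inf)
    n = len(start_probs)
    for s in range(doc_start, n):
        # The end token must not precede the start token.
        last = min(s + max_answer_length, n)
        for e in range(s, last):
            score = float(start_probs[s] + end_probs[e])
            if score > best[2]:
                best = (s, e, score)
    return best


# Toy example: question tokens, [SEP], paragraph tokens, [SEP].
tokens = ["[CLS]", "who", "wrote", "it", "?", "[SEP]",
          "jane", "austen", "wrote", "it", "[SEP]"]
start_probs = np.array([0.01, 0.0, 0.0, 0.0, 0.0, 0.0, 0.90, 0.05, 0.02, 0.02, 0.0])
end_probs = np.array([0.01, 0.0, 0.0, 0.0, 0.0, 0.0, 0.05, 0.85, 0.02, 0.07, 0.0])

# Paragraph tokens begin right after the first [SEP], at index 6.
start, end, score = best_span(start_probs, end_probs, doc_start=6)
answer = " ".join(tokens[start:end + 1])
print(start, end, answer)
```

In this toy input the highest-scoring pair is the start at "jane" and the end at "austen". A real postprocessing step typically also keeps the top N candidate spans and maps WordPiece tokens back to the original text, which this sketch omits.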

In [3]:
import os
import sys

import torch
import numpy as np

nlp_path = os.path.abspath('../../')
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.dataset.squad import load_pandas_df
from utils_nlp.models.bert.common import Language, Tokenizer
from utils_nlp.models.bert.question_answering import BERTQAExtractor
from utils_nlp.models.bert.qa_utils import postprocess_answer, evaluate_qa
from utils_nlp.common.timer import Timer

## Configurations

In [4]:
TRAIN_DATA_USED_PERCENT = 1
DEV_DATA_USED_PERCENT = 1
NUM_EPOCHS = 2

if QUICK_RUN:
    TRAIN_DATA_USED_PERCENT = 0.001
    DEV_DATA_USED_PERCENT = 0.01
    NUM_EPOCHS = 1

if torch.cuda.is_available() and torch.cuda.device_count() >= 4:
    MAX_SEQ_LENGTH = 384
    DOC_STRIDE = 128
    BATCH_SIZE = 8
else:
    MAX_SEQ_LENGTH = 128
    DOC_STRIDE = 64
    BATCH_SIZE = 4

print("Max sequence length: {}".format(MAX_SEQ_LENGTH))
print("Document stride: {}".format(DOC_STRIDE))
print("Batch size: {}".format(BATCH_SIZE))
    
SQUAD_VERSION = "v1.1" 
CACHE_DIR = "./temp"

LANGUAGE = Language.ENGLISHLARGEWWM
DO_LOWER_CASE = True

MAX_QUESTION_LENGTH = 64
LEARNING_RATE = 3e-5

DOC_TEXT_COL = "doc_text"
QUESTION_TEXT_COL = "question_text"
ANSWER_START_COL = "answer_start"
ANSWER_TEXT_COL = "answer_text"
QA_ID_COL = "qa_id"
IS_IMPOSSIBLE_COL = "is_impossible"

RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.device_count() > 0:
    torch.cuda.manual_seed_all(RANDOM_SEED)

Max sequence length: 384
Document stride: 128
Batch size: 8


## Load Data

### The SQuAD Dataset
Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset, consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage, or the question might be unanswerable. [\[2, 3\]](#References)

<img src="https://nlpbp.blob.core.windows.net/images/squad.png">

There are two versions of the SQuAD dataset. SQuAD 1.1 contains 100,000+ question-answer pairs on 500+ articles. SQuAD 2.0 adds 50,000 new, unanswerable questions written adversarially by crowdworkers to look similar to answerable ones. These datasets are available at [https://rajpurkar.github.io/SQuAD-explorer/](https://rajpurkar.github.io/SQuAD-explorer/). Each version comes with a training set and a development set. 


The utility function `load_pandas_df` downloads the dataset specified by `squad_version` and `file_split` to `local_cache_path` if it doesn't exist already.

In [5]:
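# Use the SQUAD_VERSION constant from the Configurations cell instead of a hard-coded string.
train_df = load_pandas_df(local_cache_path=".", squad_version=SQUAD_VERSION, file_split="train")
dev_df = load_pandas_df(local_cache_path=".", squad_version=SQUAD_VERSION, file_split="dev")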

100%|██████████| 8.90k/8.90k [00:00<00:00, 19.8kKB/s]
100%|██████████| 1.16k/1.16k [00:00<00:00, 11.3kKB/s]


In [6]:
train_df.head()

Unnamed: 0,doc_text,question_text,answer_start,answer_text,qa_id,is_impossible
0,"Architecturally, the school has a Catholic cha...",To whom did the Virgin Mary allegedly appear i...,515,Saint Bernadette Soubirous,5733be284776f41900661182,False
1,"Architecturally, the school has a Catholic cha...",What is in front of the Notre Dame Main Building?,188,a copper statue of Christ,5733be284776f4190066117f,False
2,"Architecturally, the school has a Catholic cha...",The Basilica of the Sacred heart at Notre Dame...,279,the Main Building,5733be284776f41900661180,False
3,"Architecturally, the school has a Catholic cha...",What is the Grotto at Notre Dame?,381,a Marian place of prayer and reflection,5733be284776f41900661181,False
4,"Architecturally, the school has a Catholic cha...",What sits on top of the Main Building at Notre...,92,a golden statue of the Virgin Mary,5733be284776f4190066117e,False


In [7]:
dev_df.head()

Unnamed: 0,doc_text,question_text,answer_start,answer_text,qa_id,is_impossible
0,Super Bowl 50 was an American football game to...,Which NFL team represented the AFC at Super Bo...,"[177, 177, 177]","[Denver Broncos, Denver Broncos, Denver Broncos]",56be4db0acb8001400a502ec,False
1,Super Bowl 50 was an American football game to...,Which NFL team represented the NFC at Super Bo...,"[249, 249, 249]","[Carolina Panthers, Carolina Panthers, Carolin...",56be4db0acb8001400a502ed,False
2,Super Bowl 50 was an American football game to...,Where did Super Bowl 50 take place?,"[403, 355, 355]","[Santa Clara, California, Levi's Stadium, Levi...",56be4db0acb8001400a502ee,False
3,Super Bowl 50 was an American football game to...,Which NFL team won Super Bowl 50?,"[177, 177, 177]","[Denver Broncos, Denver Broncos, Denver Broncos]",56be4db0acb8001400a502ef,False
4,Super Bowl 50 was an American football game to...,What color was used to emphasize the 50th anni...,"[488, 488, 521]","[gold, gold, gold]",56be4db0acb8001400a502f0,False


In [8]:
train_df = train_df.sample(frac=TRAIN_DATA_USED_PERCENT).reset_index(drop=True)
dev_df = dev_df.sample(frac=DEV_DATA_USED_PERCENT).reset_index(drop=True)

## Tokenize and Preprocess Data

In [9]:
tokenizer = Tokenizer(language=LANGUAGE, to_lower=DO_LOWER_CASE, cache_dir=CACHE_DIR)

100%|██████████| 231508/231508 [00:00<00:00, 2199359.30B/s]


The `tokenize_qa` method of `Tokenizer` tokenizes the input paragraph, question, and answer texts and converts them into the format required by the pre-trained BERT model. This involves the following steps:
* WordPiece tokenization.
* Convert character-based answer span indices to token-based indices.
* Truncate the question token list if it's longer than `max_question_length`.
* Split the paragraph into multiple segments if it's longer than `max_len` - `max_question_length` - 3. (The "-3" is for the special [CLS] token and two [SEP] tokens.)
* Add the special tokens [CLS] and [SEP].
* Pad the concatenated token sequence to `max_len` if it's shorter.
* Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary.

In addition to the features required by BERT, `tokenize_qa` outputs a few additional fields needed by postprocessing. See the `QAFeatures` class in [qa_utils.py](../../utils_nlp/models/bert/qa_utils.py) for more details.
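The splitting, special-token, and padding steps listed above can be sketched as follows. This is a simplified illustration and not the implementation inside `tokenize_qa`: the helper names `split_into_windows` and `build_input` are hypothetical, the token budget follows the `max_len` - `max_question_length` - 3 rule stated in the list, and the stride behaviour (each window starts `doc_stride` tokens after the previous one) is an assumption about how `doc_stride` is applied.

```python
def split_into_windows(doc_tokens, max_len, max_question_length, doc_stride):
    """Return (start, length) pairs of overlapping paragraph windows."""
    # Reserve room for the question plus [CLS] and two [SEP] tokens.
    max_doc_tokens = max_len - max_question_length - 3
    windows = []
    start = 0
    while start < len(doc_tokens):
        length = min(max_doc_tokens, len(doc_tokens) - start)
        windows.append((start, length))
        if start + length == len(doc_tokens):
            break  # The last window reached the end of the paragraph.
        start += min(length, doc_stride)
    return windows


def build_input(question_tokens, doc_tokens, window, max_len,
                max_question_length, pad_token="[PAD]"):
    """Assemble [CLS] question [SEP] paragraph-window [SEP], then pad."""
    start, length = window
    question = question_tokens[:max_question_length]  # Truncate long questions.
    tokens = (["[CLS]"] + question + ["[SEP]"]
              + doc_tokens[start:start + length] + ["[SEP]"])
    return tokens + [pad_token] * (max_len - len(tokens))


# Toy example with tiny limits so the windows are easy to inspect.
doc_tokens = ["t{}".format(i) for i in range(10)]
question_tokens = ["q0", "q1"]
windows = split_into_windows(doc_tokens, max_len=12,
                             max_question_length=4, doc_stride=3)
inputs = [build_input(question_tokens, doc_tokens, w, 12, 4) for w in windows]
for w, seq in zip(windows, inputs):
    print(w, seq)
```

With these toy settings the paragraph budget is 5 tokens per window, so the 10-token paragraph is covered by three overlapping windows. The overlap helps ensure that an answer near a window boundary appears in full in at least one window.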

In [10]:
train_features, qa_examples = tokenizer.tokenize_qa(
    doc_text=train_df[DOC_TEXT_COL], 
    question_text=train_df[QUESTION_TEXT_COL], 
    answer_start=train_df[ANSWER_START_COL], 
    answer_text=train_df[ANSWER_TEXT_COL],
    qa_id=train_df[QA_ID_COL],
    is_impossible=train_df[IS_IMPOSSIBLE_COL],
    is_training=True,
    max_len=MAX_SEQ_LENGTH,
    max_question_length=MAX_QUESTION_LENGTH,
    doc_stride=DOC_STRIDE,
    cache_results=True)

In [11]:
dev_features, dev_examples = tokenizer.tokenize_qa(
    doc_text=dev_df[DOC_TEXT_COL], 
    question_text=dev_df[QUESTION_TEXT_COL], 
    answer_start=dev_df[ANSWER_START_COL], 
    answer_text=dev_df[ANSWER_TEXT_COL],
    qa_id=dev_df[QA_ID_COL],
    is_impossible=dev_df[IS_IMPOSSIBLE_COL],
    is_training=False,
    max_len=MAX_SEQ_LENGTH,
    max_question_length=MAX_QUESTION_LENGTH,
    doc_stride=DOC_STRIDE,
    cache_results=True)

In [12]:
sample_feature = train_features[0]
for f in type(sample_feature)._fields:
    print(f)
    print(getattr(sample_feature, f))
    print()

unique_id
1000000000

qa_id
56de4d9ecffd8e1900b4b7e2

tokens
['[CLS]', 'what', 'year', 'was', 'the', 'ban', '##ska', 'aka', '##de', '##mia', 'founded', '?', '[SEP]', 'the', 'world', "'", 's', 'first', 'institution', 'of', 'technology', 'or', 'technical', 'university', 'with', 'tertiary', 'technical', 'education', 'is', 'the', 'ban', '##ska', 'aka', '##de', '##mia', 'in', 'ban', '##ska', 'st', '##ia', '##vn', '##ica', ',', 'slovakia', ',', 'founded', 'in', '1735', ',', 'academy', 'since', 'december', '13', ',', '1762', 'established', 'by', 'queen', 'maria', 'theresa', 'in', 'order', 'to', 'train', 'specialists', 'of', 'silver', 'and', 'gold', 'mining', 'and', 'metal', '##lu', '##rgy', 'in', 'neighbourhood', '.', 'teaching', 'started', 'in', '1764', '.', 'later', 'the', 'department', 'of', 'mathematics', ',', 'mechanics', 'and', 'hydraulic', '##s', 'and', 'department', 'of', 'forestry', 'were', 'settled', '.', 'university', 'buildings', 'are', 'still', 'at', 'their', 'place', 'today', 'a

## Train BERTQAExtractor

In [13]:
qa_extractor = BERTQAExtractor(language=LANGUAGE, cache_dir=CACHE_DIR)

100%|██████████| 314/314 [00:00<00:00, 114632.38B/s]
100%|██████████| 1345000548/1345000548 [00:25<00:00, 53229616.42B/s]


In [14]:
with Timer() as t:
    qa_extractor.fit(train_features,
                     num_epochs=NUM_EPOCHS,
                     batch_size=BATCH_SIZE,
                     learning_rate=LEARNING_RATE,
                     cache_model=True)
print("Training time : {:.3f} hrs".format(t.interval / 3600))

Training time : 6.659 hrs


## Predict
Note that `BERTQAExtractor.predict` only outputs the probabilities of each token being the start and end of the answer span. The `postprocess_answer` function takes these probabilities and generates the final answers. 

In [15]:
qa_results = qa_extractor.predict(dev_features)

Evaluating: 100%|██████████| 339/339 [03:49<00:00,  1.11s/it]


## Postprocess and Generate the Final Answers

In [16]:
final_answers, answer_probs, nbest_answers = postprocess_answer(qa_results,
                                                                dev_examples, 
                                                                dev_features, 
                                                                do_lower_case=DO_LOWER_CASE)

In [17]:
for i in [0, 10, 100]:
    print('Paragraph:')
    print(dev_df.iloc[i]['doc_text'])
    print()
    print('Question:')
    print(dev_df.iloc[i]['question_text'])
    print()
    print('Ground truth answers:')
    print(dev_df.iloc[i]['answer_text'])
    print()
    print('Predicted answer:')
    print(final_answers[dev_df.iloc[i]['qa_id']])
    print()
    print('Top N best answers')
    print(nbest_answers[dev_df.iloc[i]['qa_id']])
    print('-------------------------------------------------------------------------------------------------------------------')

Paragraph:
Immunology is strongly experimental in everyday practice but is also characterized by an ongoing theoretical attitude. Many theories have been suggested in immunology from the end of the nineteenth century up to the present time. The end of the 19th century and the beginning of the 20th century saw a battle between "cellular" and "humoral" theories of immunity. According to the cellular theory of immunity, represented in particular by Elie Metchnikoff, it was cells – more precisely, phagocytes – that were responsible for immune responses. In contrast, the humoral theory of immunity, held, among others, by Robert Koch and Emil von Behring, stated that the active immune agents were soluble components (molecules) found in the organism’s “humors” rather than its cells.

Question:
What two scientists were proponents of the humoral theory of immunity?

Ground truth answers:
['Robert Koch and Emil von Behring', 'Robert Koch and Emil von Behring', 'Robert Koch and Emil von Behring,'

## Evaluate

Question answering tasks are usually evaluated on two metrics: exact match (EM) and F1 score.   
Exact match is computed by first applying some simple normalization (e.g. removing punctuation and converting to lower case) to the ground truth and predicted answers, and then checking whether they match exactly.   
The F1 score is computed from token-level precision and recall by comparing the ground truth and predicted answers. 
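As a rough illustration of these two metrics, here is a minimal sketch modeled on the commonly used SQuAD evaluation procedure. It is not the code behind `evaluate_qa`; the function names are hypothetical, and the article-removal step in the normalization is an assumption that goes beyond the two normalization examples mentioned above.

```python
import re
import string
from collections import Counter


def normalize_answer(text):
    """Lower-case, strip punctuation and articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)  # Assumed normalization step.
    return " ".join(text.split())


def exact_match(prediction, truth):
    return float(normalize_answer(prediction) == normalize_answer(truth))


def f1_score(prediction, truth):
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(truth).split()
    if not pred_tokens or not truth_tokens:
        # If either side is empty, F1 is 1 only when both are empty.
        return float(pred_tokens == truth_tokens)
    # Count tokens shared by prediction and ground truth (with multiplicity).
    num_same = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


def best_over_truths(metric, prediction, truths):
    """Score against every ground truth answer and keep the best value."""
    return max(metric(prediction, truth) for truth in truths)


truths = ["Robert Koch and Emil von Behring", "Robert Koch and Emil von Behring,"]
em_full = best_over_truths(exact_match, "robert koch and emil von behring.", truths)
em_partial = best_over_truths(exact_match, "Robert Koch", truths)
f1_partial = best_over_truths(f1_score, "Robert Koch", truths)
print(em_full, em_partial, f1_partial)
```

The partial prediction "Robert Koch" gets no credit under exact match but partial credit under F1, since both of its tokens appear in the six-token ground truth answer. The development set provides several ground truth answers per question, which is why the sketch takes the maximum over them.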

In [18]:
evaluation_result = evaluate_qa(qa_ids=dev_df['qa_id'], 
                                actuals=dev_df['answer_text'], 
                                preds=final_answers)

{
  "exact": 86.08325449385052,
  "f1": 92.55496396943134,
  "total": 10570,
  "HasAns_exact": 86.08325449385052,
  "HasAns_f1": 92.55496396943134,
  "HasAns_total": 10570
}


## References

1. Jacob Devlin, Ming-Wei Chang, Kenton Lee, Kristina Toutanova, [*BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding*](https://arxiv.org/abs/1810.04805), NAACL, 2019.
2. Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang, [*SQuAD: 100,000+ Questions for Machine Comprehension of Text*](https://arxiv.org/abs/1606.05250), EMNLP, 2016.
3. Pranav Rajpurkar, Robin Jia, Percy Liang, [*Know What You Don't Know: Unanswerable Questions for SQuAD*](https://arxiv.org/abs/1806.03822), ACL, 2018.