The following code implements two different models for the SQuAD 2.0 dataset. The dataset consists of several Wikipedia articles, each of which is further divided into context paragraphs. For each context paragraph, the dataset contains several questions, only some of which are answerable. For each answerable question, the answer is a text span from the context paragraph. A model trained on SQuAD 2.0 should predict whether a given question is answerable and, if it is, predict the span of the context paragraph that answers it.
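As an illustration of the dataset layout described above, here is a minimal sketch that walks a SQuAD 2.0-style JSON structure and yields one example per question. The sample record is hand-made (not real dataset content), and the helper names `load_squad` and `iter_examples` are hypothetical; they are not taken from this repository.

```python
import json

# Tiny hand-made record in the SQuAD 2.0 JSON layout (illustrative only).
SAMPLE = {
    "version": "v2.0",
    "data": [
        {
            "title": "Example_Article",
            "paragraphs": [
                {
                    "context": "The Eiffel Tower is in Paris.",
                    "qas": [
                        {
                            "id": "q1",
                            "question": "Where is the Eiffel Tower?",
                            "is_impossible": False,
                            "answers": [{"text": "Paris", "answer_start": 23}],
                        },
                        {
                            "id": "q2",
                            "question": "Who designed the Louvre?",
                            "is_impossible": True,
                            "answers": [],
                        },
                    ],
                }
            ],
        }
    ],
}


def load_squad(path):
    """Read a SQuAD-style JSON file from disk (hypothetical helper)."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def iter_examples(squad):
    """Yield (id, question, context, span) for every question.

    span is a (start, end) pair of character offsets into the context,
    or None when the question is marked unanswerable.
    """
    for article in squad["data"]:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                if qa.get("is_impossible", False) or not qa["answers"]:
                    yield qa["id"], qa["question"], context, None
                else:
                    answer = qa["answers"][0]
                    start = answer["answer_start"]
                    end = start + len(answer["text"])
                    yield qa["id"], qa["question"], context, (start, end)


examples = list(iter_examples(SAMPLE))
```

With the real data, the same loop could be run on `load_squad("datafiles/train-v2.0.json")` in place of `SAMPLE`.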
This repository contains two different models trained on SQuAD 2.0. The first is based on the DrQA model, which was built for the original SQuAD dataset. The first version of SQuAD did not contain unanswerable questions, so this implementation modifies the DrQA model to handle them. In short, this is accomplished by adding a learned threshold parameter that indicates when the model should predict that the question is answerable, following the motivation from this paper.
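The thresholding idea can be sketched as follows. This is only an illustration under simplifying assumptions, not the code in `DrQA/main.py`: the threshold is passed in as a plain number rather than learned, a span's score is assumed to be the sum of its start and end scores, and the names `best_span`, `predict`, and `max_len` are hypothetical.

```python
def best_span(start_scores, end_scores, max_len=15):
    """Return (score, start, end) for the highest-scoring span with start <= end."""
    best = (float("-inf"), 0, 0)
    for i, start_score in enumerate(start_scores):
        # Only consider spans of at most max_len tokens.
        for j in range(i, min(i + max_len, len(end_scores))):
            score = start_score + end_scores[j]
            if score > best[0]:
                best = (score, i, j)
    return best


def predict(start_scores, end_scores, threshold):
    """Return the best (start, end) token span, or None if it scores below the threshold."""
    score, i, j = best_span(start_scores, end_scores)
    if score < threshold:
        return None  # treated as "unanswerable"
    return (i, j)


start_scores = [0.1, 2.0, 0.3]
end_scores = [0.2, 0.1, 1.5]
answerable = predict(start_scores, end_scores, threshold=1.0)
unanswerable = predict(start_scores, end_scores, threshold=5.0)
```

Here the best span covers tokens 1 through 2 with a score of 3.5, so it is returned for the lower threshold and rejected for the higher one.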
The second model fine-tunes a pretrained BERT model obtained from the huggingface transformers library. Given a context and question pair, the pretrained BERT model outputs contextualized embeddings for both the context and the question, as well as for a CLS (classification) token prepended to the pair. We then apply two different linear layers to the context embeddings and the CLS embedding to obtain the probabilities that each context token is the start and the end of the answer. The probability assigned to the CLS token represents the probability that the question is unanswerable.
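The output head described above can be sketched in plain Python, without BERT itself. This is a toy illustration, not the code in `BERT/main.py`: the embeddings and weights are made-up two-dimensional values, position 0 is assumed to hold the CLS embedding followed by the context tokens, and the helper names are hypothetical.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def linear(vectors, weights, bias=0.0):
    """Apply one linear layer with a single output unit to every vector."""
    return [sum(w * x for w, x in zip(weights, vec)) + bias for vec in vectors]


def span_probabilities(embeddings, w_start, w_end):
    """Return (p_start, p_end) over [CLS] + context tokens.

    Index 0 corresponds to the CLS token, so p_start[0] and p_end[0]
    are the probabilities assigned to "no answer".
    """
    p_start = softmax(linear(embeddings, w_start))
    p_end = softmax(linear(embeddings, w_end))
    return p_start, p_end


# Toy embeddings: CLS first, then three context tokens.
embeddings = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
p_start, p_end = span_probabilities(embeddings, w_start=[1.0, 0.0], w_end=[0.0, 1.0])
```

In a real model the embeddings would come from BERT's final hidden states and the two weight vectors would be trained jointly with it.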
First, install the necessary packages using pip: `pip install -r requirements.txt`. Additionally, the spaCy `en_core_web_sm` tokenizer is used, which must be downloaded using `python -m spacy download en_core_web_sm`. Example usage for each of these models is given in the respective main files: `DrQA/main.py` for the DrQA-based model and `BERT/main.py` for the BERT model. Both assume that the SQuAD 2.0 data file can be found at `datafiles/train-v2.0.json`.