An excellent Pytorch implementation of the BERT model was done by [HuggingFace](https://github.com/huggingface/pytorch-transformers). They also included supplementary code to build upon pre-trained BERT models for a wide range of NLP tasks. Below is a detailed walkthrough of their code that is needed to fine-tune a trained BERT model for a Question Answering task (to be run on  [the SQuAD dataset](https://rajpurkar.github.io/SQuAD-explorer/)).

Except for the *Demo* section of this notebook, and a few minor edits and comments added here and there, the code was pasted directly from various files in HuggingFace's [pytorch-transformers](https://github.com/huggingface/pytorch-transformers) GitHub repository. I focused on the part of the code that is task-specific (in this case, that means question answering). Below I assume that you have already performed

```
! pip install pytorch-transformers
``` 

and copied the SQuAD 1.0 [train](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json) and [dev](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json) datasets, as well as the [evaluation script](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py), to a directory called `/root/BERT/SQuAD1`. I also copied the BERT vocab file from [here](https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt) to the same directory as the Notebook.

This is an FP16 (half-precision) version of the model which was trained in **mixed precision** using [NVIDIA's apex](https://github.com/NVIDIA/apex). [More on mixed precision training here.](https://www.linkedin.com/pulse/joys-mixed-precision-training-ml-olga-petrova/)


# 1. Data preprocessing

The SQuAD dataset contains contexts (passages of text containing multiple sentences), questions (sentences), and answers (uninterrupted sections of a single sentence from a context answering a particular question). These are organized via multiple layers of ID numbers (grouped by multiple-contexts-with-a-common-theme / context-with-multiple-corresponding-QAs / question with acceptable answers: single possible answer in `train-v1.1.json` and generally multiple correct answers in `dev-v1.1.json`).

### 1A. Load the SQuAD data:

First, we need to be able to load the training/test examples from the `json` files:

In [1]:
""" Load SQuAD dataset. """

from __future__ import absolute_import, division, print_function

import json
import logging
import math
import collections
from io import open

from pytorch_transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize

class SquadExample(object):
    """
    A single training/test example for the Squad dataset.
    For examples without an answer, the start and end position are -1.
    """

    def __init__(self,
                 qas_id,
                 question_text,
                 doc_tokens,
                 orig_answer_text=None,
                 start_position=None,
                 end_position=None,
                 is_impossible=None):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.orig_answer_text = orig_answer_text
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        s = ""
        s += "qas_id: %s" % (self.qas_id)
        s += ", question_text: %s" % (
            self.question_text)
        s += ", doc_tokens: [%s]" % (" ".join(self.doc_tokens))
        if self.orig_answer_text:
            s += ", orig_answer_text: %s" % (self.orig_answer_text)
        if self.start_position:
            s += ", start_position: %d" % (self.start_position)
        if self.end_position:
            s += ", end_position: %d" % (self.end_position)
        if self.is_impossible:
            s += ", is_impossible: %r" % (self.is_impossible)
        return s
    
def read_squad_examples(input_file, is_training, version_2_with_negative):
    """Read a SQuAD json file into a list of SquadExample."""
    with open(input_file, "r", encoding='utf-8') as reader:
        input_data = json.load(reader)["data"]

    def is_whitespace(c):
        if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
            return True
        return False

    examples = []
    for entry in input_data:
        for paragraph in entry["paragraphs"]:
            paragraph_text = paragraph["context"]
            doc_tokens = []
            char_to_word_offset = []
            prev_is_whitespace = True
            for c in paragraph_text:
                if is_whitespace(c):
                    prev_is_whitespace = True
                else:
                    if prev_is_whitespace:
                        doc_tokens.append(c)
                    else:
                        doc_tokens[-1] += c
                    prev_is_whitespace = False
                char_to_word_offset.append(len(doc_tokens) - 1)

            for qa in paragraph["qas"]:
                qas_id = qa["id"]
                question_text = qa["question"]
                start_position = None
                end_position = None
                orig_answer_text = None
                is_impossible = False
                if is_training:
                    if version_2_with_negative:
                        is_impossible = qa["is_impossible"]
                    if (len(qa["answers"]) != 1) and (not is_impossible):
                        raise ValueError(
                            "For training, each question should have exactly 1 answer.")
                    if not is_impossible:
                        answer = qa["answers"][0]
                        orig_answer_text = answer["text"]
                        answer_offset = answer["answer_start"]
                        answer_length = len(orig_answer_text)
                        start_position = char_to_word_offset[answer_offset]
                        end_position = char_to_word_offset[answer_offset + answer_length - 1]
                        # Only add answers where the text can be exactly recovered from the
                        # document. If this CAN'T happen it's likely due to weird Unicode
                        # stuff so we will just skip the example.
                        #
                        # Note that this means for training mode, every example is NOT
                        # guaranteed to be preserved.
                        actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
                        cleaned_answer_text = " ".join(
                            whitespace_tokenize(orig_answer_text))
                        if actual_text.find(cleaned_answer_text) == -1:
                            logger.warning("Could not find answer: '%s' vs. '%s'",
                                           actual_text, cleaned_answer_text)
                            continue
                    else:
                        start_position = -1
                        end_position = -1
                        orig_answer_text = ""

                example = SquadExample(
                    qas_id=qas_id,
                    question_text=question_text,
                    doc_tokens=doc_tokens,
                    orig_answer_text=orig_answer_text,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=is_impossible)
                examples.append(example)
    return examples

Test:

In [2]:
train_file = "/root/BERT/SQuAD1/train-v1.1.json"
train_examples = read_squad_examples(train_file, True, False)
# Note that the dev_file can only be read via read_squad_examples(train_file, False, False)
# due to there being more than one possible answer
train_examples[2]

qas_id: 5733be284776f41900661180, question_text: The Basilica of the Sacred heart at Notre Dame is beside to which structure?, doc_tokens: [Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.], orig_answer_text: the Main Building, start_position: 49, end_position: 51

### 1B. Tokenize the text:

The contexts, questions, and answers need to be **tokenized** (projected onto the basis used by BERT to represent words). BERT makes use of [wordpiece tokenization](https://arxiv.org/pdf/1609.08144.pdf), which minimizes the number of *unknown* words by breaking them further into subparts. They also use a longest-match-first algorithm to perform tokenization, such that words contained in the vocabularly will not be broken down, whereas e.g. *unaffable*, which is not part of the vocabularly, will return the following three word pieces that are in it: \[*un*, *##aff*, *##able*\].

In [3]:
from pytorch_transformers.tokenization_bert import BertTokenizer

# You can get the vocab_file from https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt

tokenizer = BertTokenizer(vocab_file="bert-base-uncased-vocab.txt")

tokens_test = tokenizer.tokenize("Tokenization is the act of breaking up a sequence of strings into pieces such as words, keywords, phrases, symbols and other elements called tokens.")
print(tokens_test)

# The tokens can be converted to IDs: indices assigned to the tokens in the vocab_file

ids_test = tokenizer.convert_tokens_to_ids(tokens_test)
print(ids_test)

['token', '##ization', 'is', 'the', 'act', 'of', 'breaking', 'up', 'a', 'sequence', 'of', 'strings', 'into', 'pieces', 'such', 'as', 'words', ',', 'key', '##words', ',', 'phrases', ',', 'symbols', 'and', 'other', 'elements', 'called', 'token', '##s', '.']
[19204, 3989, 2003, 1996, 2552, 1997, 4911, 2039, 1037, 5537, 1997, 7817, 2046, 4109, 2107, 2004, 2616, 1010, 3145, 22104, 1010, 15672, 1010, 9255, 1998, 2060, 3787, 2170, 19204, 2015, 1012]


We can now use the tokenizer to convert the textual examples to features that can be read as input into the BERT model. This part of the code is quite lengthy for the following reasons, among others:
* to keep the correspondence between the wordpiece-tokenized words and the whitespace-tokenized words (to be able to put the split words together later on)
* to split the context paragraphs into multiple parts in they exceed the `max_query_length`
* if the answer ends up being split, figure out which part contains most of the answer
* and other book-keeping tasks.

In [4]:
class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self,
                 unique_id,
                 example_index,
                 doc_span_index,
                 tokens,
                 token_to_orig_map,
                 token_is_max_context,
                 input_ids,
                 input_mask,
                 segment_ids,
                 start_position=None,
                 end_position=None,
                 is_impossible=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.doc_span_index = doc_span_index
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.start_position = start_position
        self.end_position = end_position
        self.is_impossible = is_impossible
        
        
def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer,
                         orig_answer_text):
    """Returns tokenized answer spans that better match the annotated answer."""

    # The SQuAD annotations are character based. We first project them to
    # whitespace-tokenized words. But then after WordPiece tokenization, we can
    # often find a "better match". For example:
    #
    #   Question: What year was John Smith born?
    #   Context: The leader was John Smith (1895-1943).
    #   Answer: 1895
    #
    # The original whitespace-tokenized answer will be "(1895-1943).". However
    # after tokenization, our tokens will be "( 1895 - 1943 ) .". So we can match
    # the exact answer, 1895.
    #
    # However, this is not always possible. Consider the following:
    #
    #   Question: What country is the top exporter of electornics?
    #   Context: The Japanese electronics industry is the lagest in the world.
    #   Answer: Japan
    #
    # In this case, the annotator chose "Japan" as a character sub-span of
    # the word "Japanese". Since our WordPiece tokenizer does not split
    # "Japanese", we just use "Japanese" as the annotation. This is fairly rare
    # in SQuAD, but does happen.
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
            if text_span == tok_answer_text:
                return (new_start, new_end)

    return (input_start, input_end)

def _check_is_max_context(doc_spans, cur_span_index, position):
    """Check if this is the 'max context' doc span for the token."""

    # Because of the sliding window approach taken to scoring documents, a single
    # token can appear in multiple documents. E.g.
    #  Doc: the man went to the store and bought a gallon of milk
    #  Span A: the man went to the
    #  Span B: to the store and bought
    #  Span C: and bought a gallon of
    #  ...
    #
    # Now the word 'bought' will have two scores from spans B and C. We only
    # want to consider the score with "maximum context", which we define as
    # the *minimum* of its left and right context (the *sum* of left and
    # right context will always be the same, of course).
    #
    # In the example the maximum context for 'bought' would be span C since
    # it has 1 left context and 3 right context, while span B has 4 left context
    # and 0 right context.
    best_score = None
    best_span_index = None
    for (span_index, doc_span) in enumerate(doc_spans):
        end = doc_span.start + doc_span.length - 1
        if position < doc_span.start:
            continue
        if position > end:
            continue
        num_left_context = position - doc_span.start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * doc_span.length
        if best_score is None or score > best_score:
            best_score = score
            best_span_index = span_index

    return cur_span_index == best_span_index


def convert_examples_to_features(examples, tokenizer, max_seq_length,
                                 doc_stride, max_query_length, is_training):
    """Loads a data file into a list of `InputBatch`s."""

    unique_id = 1000000000

    features = []
    for (example_index, example) in enumerate(examples):
        query_tokens = tokenizer.tokenize(example.question_text)

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = tokenizer.tokenize(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

        tok_start_position = None
        tok_end_position = None
        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens, tok_start_position, tok_end_position, tokenizer,
                example.orig_answer_text)

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of the up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append("[CLS]")
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append("[SEP]")
            segment_ids.append(0)

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

                is_max_context = _check_is_max_context(doc_spans, doc_span_index,
                                                       split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            tokens.append("[SEP]")
            segment_ids.append(1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)

            # The mask has 1 for real tokens and 0 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(0)
                segment_ids.append(0)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length

            start_position = None
            end_position = None
            if is_training and not example.is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start and
                        tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                else:
                    doc_offset = len(query_tokens) + 2
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset
            if is_training and example.is_impossible:
                start_position = 0
                end_position = 0


            features.append(
                InputFeatures(
                    unique_id=unique_id,
                    example_index=example_index,
                    doc_span_index=doc_span_index,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    input_mask=input_mask,
                    segment_ids=segment_ids,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=example.is_impossible))
            unique_id += 1

    return features

Test:

In [5]:
train_features = convert_examples_to_features(train_examples, tokenizer, max_seq_length=384,
                                 doc_stride=128, max_query_length=64, is_training=True)

ind = 20

# A consequtively assigned ID number that starts with 1000000000
print("unique_id: ", train_features[ind].unique_id)

# Looks to be just = unique_id - 1000000000
print("example_index: ",train_features[ind].example_index)

# Documents longer than max_seq_length are split into chunks indexed by doc_span_index with a stride of doc_stride
print("doc_span_index: ",train_features[ind].doc_span_index)

# Tokens for the Question starting with '[CLS]', followed by '[SEP]' and then by the tokens for the Context
print("tokens: ",train_features[ind].tokens)

# { <token # from features[ind].tokens> : <# of a word in the original Context> } 
# NOTE: When the original word has been split into multiple tokens, those multiple 
# <token # from features[ind].tokens> correspond to the same <# of a word in the original Context>
print("token_to_orig_map: ",train_features[ind].token_to_orig_map)

# Does this document span contain "maximum content" of the answer
print("token_is_max_context: ",train_features[ind].token_is_max_context)

# IDs that the tokens in features[ind].tokens have in the BERT vocabulary 
# (Additionally, 0s for padding entries that go up to len = max_seq_length)
print("input_ids: ",train_features[ind].input_ids)

# 1 for each non-trivial token, 0 for padding entries that go up to len = max_seq_length
print("input_mask: ",train_features[ind].input_mask)

# 1 for each token that is part of the Context, 0 otherwise
print("segment_ids: ",train_features[ind].segment_ids)

# Start and end positions of the answer
print("start_position: ",train_features[ind].start_position)
print("end_position: ",train_features[ind].end_position)

# SQuAD 2.0 includes questions that cannot be answered from the context provided
# This is the flag for those "impossible" questions
print("is_impossible: ",train_features[ind].is_impossible)

unique_id:  1000000020
example_index:  20
doc_span_index:  0
tokens:  ['[CLS]', 'what', 'entity', 'provides', 'help', 'with', 'the', 'management', 'of', 'time', 'for', 'new', 'students', 'at', 'notre', 'dame', '?', '[SEP]', 'all', 'of', 'notre', 'dame', "'", 's', 'undergraduate', 'students', 'are', 'a', 'part', 'of', 'one', 'of', 'the', 'five', 'undergraduate', 'colleges', 'at', 'the', 'school', 'or', 'are', 'in', 'the', 'first', 'year', 'of', 'studies', 'program', '.', 'the', 'first', 'year', 'of', 'studies', 'program', 'was', 'established', 'in', '1962', 'to', 'guide', 'incoming', 'freshmen', 'in', 'their', 'first', 'year', 'at', 'the', 'school', 'before', 'they', 'have', 'declared', 'a', 'major', '.', 'each', 'student', 'is', 'given', 'an', 'academic', 'advisor', 'from', 'the', 'program', 'who', 'helps', 'them', 'to', 'choose', 'classes', 'that', 'give', 'them', 'exposure', 'to', 'any', 'major', 'in', 'which', 'they', 'are', 'interested', '.', 'the', 'program', 'also', 'includes', '

# 2. The model

The way BERT's creators have set it up for the SQuAD task is the following: the tokenized Question and Context are put together into a single input sequence (the two separated by a special token) and a linear layer is added on top of the Encoder (i.e. BERT) which is then used to predict the start and end of the Answer inside the Context.

![](squadbert.jpeg)
(Fig. from the original [BERT paper](https://arxiv.org/pdf/1810.04805.pdf))

### 2A. The pre-trained BERT model for Question Answering:

In [6]:
import copy
import random
import json
import math
import os
import sys
from io import open

from IPython.utils import io

import numpy as np

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

from pytorch_transformers.modeling_bert import BertPreTrainedModel, BertModel

class BertForQuestionAnswering(BertPreTrainedModel):
    r"""
        **start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        **end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
        **start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-start scores (before SoftMax).
        **end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
    Examples::
        >>> config = BertConfig.from_pretrained('bert-base-uncased')
        >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        >>> 
        >>> model = BertForQuestionAnswering(config)
        >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        >>> start_positions = torch.tensor([1])
        >>> end_positions = torch.tensor([3])
        >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
        >>> loss, start_scores, end_scores = outputs[:2]
    """
    def __init__(self, config):
        super(BertForQuestionAnswering, self).__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self.init_weights)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None,
                end_positions=None, position_ids=None, head_mask=None):
        outputs = self.bert(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
                            attention_mask=attention_mask, head_mask=head_mask)
        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        outputs = (start_logits, end_logits,) + outputs[2:]
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions.clamp_(0, ignored_index)
            end_positions.clamp_(0, ignored_index)

            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2
            outputs = (total_loss,) + outputs

        return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
    
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')
model.half()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with io.capture_output() as captured:    
    model.to(device)

### 2B. Train the Question Answering model:

In [7]:
import time

from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from apex import amp
from apex.optimizers import FP16_Optimizer, FusedAdam

from pytorch_transformers import AdamW, WarmupLinearSchedule

num_train_epochs = 1
train_batch_size = 32

all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                   all_start_positions, all_end_positions)

train_sampler = RandomSampler(train_data)

train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=train_batch_size)

# Prepare optimizer
param_optimizer = list(model.named_parameters())

# hack to remove pooler, which is not used
# thus it produce None grad that break apex
param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
                            ]
num_train_optimization_steps = len(train_dataloader) *  num_train_epochs

optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=3e-5,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)

global_step = 0

print("  Num orig examples = ", len(train_examples))
print("  Num split examples = ", len(train_features))
print("  Batch size = ", train_batch_size)
print("  Num steps = ", num_train_optimization_steps)

model.train()

start_time = time.time()

for epoch in range(num_train_epochs):
    for step, batch in enumerate(train_dataloader):
        batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
        input_ids, input_mask, segment_ids, start_positions, end_positions = batch
        outputs = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
        loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        global_step += 1
     
    if epoch==0:
        print("Time it took to complete the epoch: ", (time.time()-start_time))
    print("Loss after epoch ", epoch, ": ", loss.item())

  Num orig examples =  87599
  Num split examples =  88641
  Batch size =  32
  Num steps =  2771
Time it took to complete the epoch:  3215.1631803512573
Loss after epoch  0 :  0.07489013671875


### 2C. Save the trained model:

In [8]:
output_dir = "output/"

if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Save a trained model, configuration and tokenizer
model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

('output/vocab.txt',
 'output/special_tokens_map.json',
 'output/added_tokens.json')

### 2D. Test the model on the SQuAD dev set:

In [9]:
import os

from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)

from run_squad_dataset_utils import read_squad_examples, convert_examples_to_features, RawResult, write_predictions

dev_file = "/root/BERT/SQuAD1/dev-v1.1.json"
predict_batch_size = 8

eval_examples = read_squad_examples(input_file=dev_file, is_training=False, version_2_with_negative=False)
eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=384,
            doc_stride=128,
            max_query_length=64,
            is_training=False)

print("***** Running predictions *****")
print("  Num orig examples = ", len(eval_examples))
print("  Num split examples = ", len(eval_features))
print("  Batch size = ", predict_batch_size)

all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)

# Run prediction for full data
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=predict_batch_size)

model.eval()
all_results = []
print("Start evaluating")
for input_ids, input_mask, segment_ids, example_indices in eval_dataloader:
    if len(all_results) % 1000 == 0:
        print("Processing example: %d" % (len(all_results)))
    input_ids = input_ids.to(device)
    input_mask = input_mask.to(device)
    segment_ids = segment_ids.to(device)
    with torch.no_grad():
        batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
    for i, example_index in enumerate(example_indices):
        start_logits = batch_start_logits[i].detach().cpu().tolist()
        end_logits = batch_end_logits[i].detach().cpu().tolist()
        eval_feature = eval_features[example_index.item()]
        unique_id = int(eval_feature.unique_id)
        all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))
        
output_prediction_file = os.path.join(output_dir, "predictions.json")
output_nbest_file = os.path.join(output_dir, "nbest_predictions.json")
output_null_log_odds_file = os.path.join(output_dir, "null_odds.json")

with io.capture_output() as captured:
    write_predictions(eval_examples, eval_features, all_results, 20,
                      30, True, output_prediction_file,
                      output_nbest_file, output_null_log_odds_file, True,
                      False, 0.0)
print("Done.")

***** Running predictions *****
  Num orig examples =  10570
  Num split examples =  10833
  Batch size =  8
Start evaluating
Processing example: 0
Processing example: 1000
Processing example: 2000
Processing example: 3000
Processing example: 4000
Processing example: 5000
Processing example: 6000
Processing example: 7000
Processing example: 8000
Processing example: 9000
Processing example: 10000
Done.


In [10]:
! python SQuAD1/evaluate-v1.1.py SQuAD1/dev-v1.1.json /root/BERT/output/predictions.json

{"exact_match": 77.48344370860927, "f1": 85.67522993087587}


Fine-tuning BERT for the Question-Answering task for a **single training epoch** already puts us in the top 40 entries on the SQuAD 1.1 Leaderboard.

# 3. The Demo

### 3A. Load a saved trained model:

In [11]:
import torch

from IPython.utils import io

from pytorch_transformers.modeling_bert import BertForQuestionAnswering
from pytorch_transformers.tokenization_bert import BertTokenizer

output_dir = "old_output/"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load a trained model and vocabulary that you have fine-tuned
model = BertForQuestionAnswering.from_pretrained(output_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir)

with io.capture_output() as captured:    
    model.to(device)

### 3B. Run inference on arbitrary Question + Context:

In [12]:
import torch
import numpy as np

from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.tokenization import BertTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

max_seq_length = 384

context = "Tennis is a racket sport that can be played individually against a single opponent (singles) or between two teams of two players each (doubles). Each player uses a tennis racket that is strung with cord to strike a hollow rubber ball covered with felt over or around a net and into the opponent's court. The object of the game is to maneuver the ball in such a way that the opponent is not able to play a valid return. The player who is unable to return the ball will not gain a point, while the opposite player will."
#input("Insert a paragraph: ")

In [13]:
question = "What is the goal of a tennis match?"
#input("Type in a question that can be answered by a span from a paragraph: ")

In [14]:
# TOKENIZE QUESTION AND CONTEXT

tokenizer = BertTokenizer(vocab_file="bert-base-uncased-vocab.txt")
query_tokens = tokenizer.tokenize(question)
segment_tokens = tokenizer.tokenize(context)

tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
for token in query_tokens:
    tokens.append(token)
    segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)

for token in segment_tokens:
    tokens.append(token)
    segment_ids.append(1)

tokens.append("[SEP]")
segment_ids.append(1)

input_ids = tokenizer.convert_tokens_to_ids(tokens)

# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)

# Zero-pad up to the sequence length.
while len(input_ids) < max_seq_length:
    input_ids.append(0)
    input_mask.append(0)
    segment_ids.append(0)

input_ids = torch.tensor(input_ids).unsqueeze(0)
input_mask = torch.tensor(input_mask).unsqueeze(0)
segment_ids = torch.tensor(segment_ids).unsqueeze(0)

input_ids = input_ids.to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)

def _get_best_indexes(logits, n_best_size):
    """Get the n-best logits from a list."""
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes

with torch.no_grad():
    start_logits, end_logits = model(input_ids, segment_ids, input_mask)

    start_logits = start_logits.detach().cpu().numpy()
    end_logits = end_logits.detach().cpu().numpy()
    
start = np.argmax(start_logits)
end = np.argmax(end_logits)

answer = ""

for i in range(start, end+1):
    answer = answer + tokens[i] + " "
print(answer)

to maneuver the ball in such a way that the opponent is not able to play a valid return 
