# BERT for Low-Resource Question-Answering

Pre-trained language representations have been shown to improve many downstream NLP tasks such as question answering and natural language inference (NLI). Devlin et al. proposed `BERT` [1] (Bidirectional Encoder Representations from Transformers), which pre-trains deep bidirectional representations and then fine-tunes them on a wide range of tasks with minimal task-specific parameters, obtaining state-of-the-art results.

In this tutorial, you will adapt the `BERT` model to question answering on the `SQuAD` dataset, using only a small fraction of the training data. Specifically, you will:

- Pre-process the `SQuAD` dataset so that the model can leverage the representations learned by `BERT`.
- Adapt the `BERT` model to the question-answering task.
- Fine-tune the model on a downsampled training set and use it to perform inference on the `SQuAD` dev set.


In [None]:
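import collections
import resource
import numpy as np
import mxnet as mx
from mxnet import gluon, metric, autograd
from mxnet.gluon.contrib import estimator
import gluonnlp as nlp
ctx = mx.gpu(0)

# Raise the open-file limit of this process so multi-worker DataLoaders do not exhaust it.
# (`!ulimit -n` runs in a subshell and would not affect the Python kernel itself.)
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
target = 16384 if hard == resource.RLIM_INFINITY else min(16384, hard)
if soft != resource.RLIM_INFINITY and soft < target:
    resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))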

The `bert` module can be downloaded from the [GluonNLP BERT model zoo](https://gluon-nlp.mxnet.io/v0.9.x/model_zoo/bert/index.html).

In [None]:
import bert
from bert import bert_qa_evaluate

## Prepare the `SQuAD` dataset

In [None]:
squad_train_original = nlp.data.SQuAD(segment='train', version='1.1')
squad_dev = nlp.data.SQuAD(segment='dev', version='1.1')

Each record in the dataset has the following fields:

- `record_index`: index of the record, generated on the fly (from 0 to the number of questions minus 1).
- `question_id`: question ID, a string taken as-is from the original JSON file.
- `question`: question text, taken as-is from the original JSON file.
- `context`: context text; identical for all questions that share a context.
- `answer_list`: all answers to this question, stored as a Python list.
- `start_indices`: the starting character index of each answer in `context`, stored as a Python list.
  The i-th entry corresponds to the i-th answer in `answer_list`.

In [None]:
squad_train_original[0]
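Because the start indices are character offsets into `context`, a record can be sanity-checked by slicing. A minimal sketch, assuming each record is a 6-element sequence in the order listed above; `SquadRecord` and `answer_spans_match` are hypothetical helpers, and the record below is made up.

```python
from collections import namedtuple

# Field order follows the list above; this namedtuple is only an illustration.
SquadRecord = namedtuple(
    "SquadRecord",
    ["record_index", "question_id", "question", "context", "answer_list", "start_indices"])


def answer_spans_match(record):
    """Check that each start index points at its answer text inside the context."""
    r = SquadRecord(*record)
    return all(r.context[start:start + len(answer)] == answer
               for answer, start in zip(r.answer_list, r.start_indices))


# A made-up record, not taken from SQuAD.
toy = (0, "q0", "Who won?", "The Broncos won.", ["Broncos"], [4])
print(answer_spans_match(toy))  # True
```

Calling `answer_spans_match(squad_train_original[0])` should likewise return `True` if the offsets are character-based as described.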

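### Downsample for a Low-resource Dataset

To simulate a low-resource setting, keep a random 10% of the training questions. The sampler below draws a new random subset of indices each time it is iterated.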

In [None]:
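class RandomDownSampler(gluon.data.Sampler):
    """Sample a random subset containing `ratio` of the indices in [0, length)."""
    def __init__(self, length, ratio):
        self._length = length
        self._ratio = ratio
        self._count = int(np.round(length * ratio))  # number of samples to keep

    def __iter__(self):
        # Shuffle all indices, then keep the first `_count` of them (no repeats).
        indices = np.arange(self._length)
        np.random.shuffle(indices)
        indices = indices[:self._count]
        return iter(indices)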

    def __len__(self):
        return self._count

In [None]:
squad_train = squad_train_original.sample(RandomDownSampler(len(squad_train_original), 0.1))

In [None]:
print('Original # samples: {}, downsampled to # samples: {}'.format(len(squad_train_original),
                                                                    len(squad_train)))
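`RandomDownSampler` uses NumPy's global random state, so the 10% subset changes from run to run unless a seed is set. A seeded variant of the same index logic, using NumPy's `Generator` API (`np.random.default_rng`, available since NumPy 1.17); the function name `downsample_indices` is hypothetical:

```python
import numpy as np


def downsample_indices(length, ratio, seed=None):
    """Pick round(length * ratio) distinct indices uniformly at random, like RandomDownSampler."""
    rng = np.random.default_rng(seed)          # independent, seedable random generator
    count = int(np.round(length * ratio))      # same count as RandomDownSampler
    return rng.permutation(length)[:count]     # shuffled indices, truncated: no repeats


subset = downsample_indices(1000, 0.1, seed=0)
print(len(subset))  # 100
```

Using this function inside `RandomDownSampler.__iter__` would make the downsampled training set reproducible across runs.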

### Data pre-processing for QA with `BERT`

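![qa](img/qa.png)

Load the pre-trained BERT-base encoder (12 layers, 768 hidden units, 12 attention heads) together with its vocabulary. The classifier, decoder, and pooler are disabled because question answering only needs the encoder's per-token outputs.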

In [None]:
bert_encoder, vocab = nlp.model.get_model('bert_12_768_12',
                                          dataset_name='openwebtext_book_corpus_wiki_en_uncased',
                                          use_classifier=False,
                                          use_decoder=False,
                                          use_pooler=False,
                                          pretrained=True,
                                          ctx=ctx)

In [None]:
print(vocab)

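### Subword Tokenization

BERT uses WordPiece subword tokenization: the uncased tokenizer lowercases the text, splits it on whitespace and punctuation, and breaks each word into pieces from the vocabulary. Pieces that continue a word are prefixed with `##`.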

In [None]:
tokenizer = nlp.data.BERTTokenizer(vocab=vocab, lower=True)

tokenizer("as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals")
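The word-splitting step relies on a greedy longest-match-first search over the vocabulary. A simplified sketch of that step for a single, already lowercased word; the toy vocabulary is made up, and the `max_chars` limit is an assumed default rather than a value read from `BERTTokenizer`.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]", max_chars=100):
    """Greedy longest-match-first WordPiece split of a single word."""
    if len(word) > max_chars:
        return [unk_token]
    pieces, start = [], 0
    while start < len(word):
        end, match = len(word), None
        # Shrink the candidate from the right until it is found in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry the ## prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # no piece matches: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"super", "bowl", "roman", "numer", "##al", "##als"}
print(wordpiece_tokenize("numerals", toy_vocab))  # ['numer', '##als']
```

Because the longest match wins, `"##als"` is preferred over `"##al"` for the remainder of `"numerals"`.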

### QA Transformation for BERT

The transformation consists of the following steps:
- Tokenize the question text of the example.
- If the document is too long, split it into multiple features with a sliding window,
  and record for each token whether the current window gives it its maximum context.
- Tokenize the resulting document chunks.
- Concatenate the question tokens with the document tokens,
  inserting the [CLS] and [SEP] tokens.
- Compute the start and end positions of the answer.
- Compute the valid length.
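The windowing, max-context, packing, and answer-position steps can be sketched in plain Python. This is a simplified illustration modeled on the approach of Google's reference BERT SQuAD code, not the source of `SQuADTransform`; all function names are hypothetical, and the max-context score (`min(left, right) + 0.01 * length`) and the convention of labeling out-of-window answers with position 0 are assumptions taken from that reference approach.

```python
from collections import namedtuple

DocSpan = namedtuple("DocSpan", ["start", "length"])


def make_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
    """Split a tokenized document into overlapping windows (start, length)."""
    spans, start = [], 0
    while start < num_doc_tokens:
        length = min(num_doc_tokens - start, max_tokens_for_doc)
        spans.append(DocSpan(start, length))
        if start + length == num_doc_tokens:
            break  # this window already reaches the end of the document
        start += min(length, doc_stride)
    return spans


def is_max_context(spans, span_index, position):
    """True if spans[span_index] gives the token at `position` the most context on both sides."""
    best_score, best_index = None, None
    for i, span in enumerate(spans):
        end = span.start + span.length - 1
        if position < span.start or position > end:
            continue
        score = min(position - span.start, end - position) + 0.01 * span.length
        if best_score is None or score > best_score:
            best_score, best_index = score, i
    return best_index == span_index


def pack_features(question_tokens, doc_tokens, span):
    """Build [CLS] question [SEP] window [SEP], its token types, and its valid length."""
    window = doc_tokens[span.start:span.start + span.length]
    tokens = ["[CLS]"] + question_tokens + ["[SEP]"] + window + ["[SEP]"]
    # Type 0: [CLS], question, first [SEP]; type 1: document window and final [SEP].
    token_types = [0] * (len(question_tokens) + 2) + [1] * (len(window) + 1)
    return tokens, token_types, len(tokens)  # valid length = unpadded length


def answer_positions(doc_start, doc_end, span, num_question_tokens):
    """Map answer token indices in the full document to indices in the packed sequence."""
    if doc_start < span.start or doc_end > span.start + span.length - 1:
        return 0, 0  # answer not inside this window: point both labels at [CLS]
    offset = num_question_tokens + 2 - span.start  # skip [CLS], question and [SEP]
    return doc_start + offset, doc_end + offset


spans = make_doc_spans(10, max_tokens_for_doc=4, doc_stride=2)
print(spans)  # four windows starting at 0, 2, 4 and 6
```

A token near a window's edge has little context on one side, which is why the token at position 3 is "owned" by the window starting at 2 rather than the one starting at 0.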

This functionality is available through the `SQuADTransform` API in the BERT model zoo.

In [None]:
bert_qa_transform = bert.data.qa.SQuADTransform(tokenizer)

In [None]:
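def flatten_dataset(dataset):
    """Flatten per-example lists of features into one feature per sample.

    A long context can yield several features, one per sliding window.
    """
    return gluon.data.SimpleDataset([x for xs in dataset for x in xs])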

In [None]:
processed_train = flatten_dataset(squad_train.transform(bert_qa_transform))
processed_dev = flatten_dataset(squad_dev.transform(bert_qa_transform))

In [None]:
batchify_fn = nlp.data.batchify.Tuple(
    nlp.data.batchify.Stack(), # example ID
    nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token], dtype='float32'), # tokens
    nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token], dtype='float32'), # token types
    nlp.data.batchify.Stack('float32'), # actual sample lengths without padding
    nlp.data.batchify.Stack('float32'), # start positions
    nlp.data.batchify.Stack('float32'), # end positions
    nlp.data.batchify.Stack('float32')) # batch length
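`Pad` makes each batch rectangular by padding every variable-length sequence to the longest one in that batch, while `Stack` simply stacks fixed-size fields. A NumPy sketch of the padding step (the helper name `pad_batch` is hypothetical), which also shows why the unpadded `valid_length` is kept: it tells the encoder which positions are padding.

```python
import numpy as np


def pad_batch(sequences, pad_val):
    """Right-pad variable-length 1-D sequences to the longest length in the batch."""
    max_len = max(len(s) for s in sequences)
    out = np.full((len(sequences), max_len), pad_val, dtype="float32")
    for i, seq in enumerate(sequences):
        out[i, :len(seq)] = seq  # copy real tokens; the rest stays pad_val
    valid_length = np.array([len(s) for s in sequences], dtype="float32")
    return out, valid_length


batch, lengths = pad_batch([[5, 6, 7], [8]], pad_val=1)
print(batch.shape)  # (2, 3)
```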

In [None]:
train_dataloader = mx.gluon.data.DataLoader(
    processed_train, batchify_fn=batchify_fn,
    batch_size=8, num_workers=4, shuffle=True)
dev_dataloader = mx.gluon.data.DataLoader(
    processed_dev, batchify_fn=batchify_fn,
    batch_size=8, num_workers=0, shuffle=False)

## Defining the model

After the data is processed, you can define the model that uses the representation produced by BERT for predicting the starting and ending positions of the answer span.

In [None]:
class BertForQA(mx.gluon.HybridBlock):
    def __init__(self, bert_encoder, prefix=None, params=None):
        super(BertForQA, self).__init__(prefix=prefix, params=params)
        self.bert = bert_encoder
        with self.name_scope():
            self.span_classifier = mx.gluon.nn.Dense(units=2, flatten=False)

    def hybrid_forward(self, F, inputs, token_types, valid_length=None):
        # Use self.bert to get the representation for each token.
        bert_output = self.bert(inputs, token_types, valid_length)
        
        # Use self.span_classifier to predict the start and end spans
        return self.span_classifier(bert_output)
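`Dense(units=2, flatten=False)` applies the same linear map to every token's hidden vector, producing one start logit and one end logit per token. A NumPy sketch of that shape logic with random weights and made-up sizes (not the trained parameters; `span_head` is a hypothetical name):

```python
import numpy as np


def span_head(sequence_output, weight, bias):
    """Project each token's hidden vector to a (start, end) logit pair."""
    logits = sequence_output @ weight + bias                    # (batch, seq_len, 2)
    start_logits, end_logits = logits[..., 0], logits[..., 1]  # each (batch, seq_len)
    return start_logits, end_logits


rng = np.random.default_rng(0)
batch, seq_len, hidden = 2, 5, 8
seq_out = rng.standard_normal((batch, seq_len, hidden))  # stands in for BERT's output
w = rng.standard_normal((hidden, 2))
b = np.zeros(2)
start, end = span_head(seq_out, w, b)
print(start.shape, end.shape)  # (2, 5) (2, 5)
```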

Now wrap the pre-trained encoder in the QA model. The BERT weights are already loaded, so only the newly added span classifier needs to be initialized.

In [None]:
net = BertForQA(bert_encoder)
net.span_classifier.initialize(ctx=ctx)

In [None]:
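# Exclude LayerNorm scale/shift (gamma/beta) and biases from weight decay.
# This only has an effect if the optimizer is configured with a nonzero `wd`.
for p in net.collect_params('.*beta|.*gamma|.*bias').values():
    p.wd_mult = 0.0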

In [None]:
learnable_params = [p for p in net.collect_params().values() if p.grad_req != 'null']

## Training

In [None]:
epochs = 2
warmup_ratio = 0.1
num_train_steps = epochs * len(train_dataloader)
num_warmup_steps = int(num_train_steps * warmup_ratio)
lr = 3e-5

In [None]:
loss = gluon.loss.SoftmaxCrossEntropyLoss()
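The model outputs logits of shape (batch, seq_len, 2), while the label holds one gold start and one gold end index per example. Transposing to (batch, 2, seq_len), as `fit_batch` below does, lets `SoftmaxCrossEntropyLoss` apply a softmax over token positions separately for the start and the end, and its default reduction averages the two terms for each example. A NumPy sketch of that computation (helper names are hypothetical):

```python
import numpy as np


def log_softmax(x, axis=-1):
    """Numerically stable log-softmax."""
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def span_loss(pred, label):
    """pred: (batch, seq_len, 2) logits; label: (batch, 2) gold start/end indices."""
    pred = pred.transpose(0, 2, 1)                # (batch, 2, seq_len)
    logp = log_softmax(pred, axis=-1)             # softmax over token positions
    batch_idx = np.arange(pred.shape[0])[:, None]
    picked = logp[batch_idx, np.arange(2)[None, :], label.astype(int)]  # (batch, 2)
    return -picked.mean(axis=1)                   # mean of start and end losses per example


uniform = span_loss(np.zeros((1, 4, 2)), np.array([[0, 3]]))
print(uniform)  # equals log(4) for uniform logits over 4 positions
```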

In [None]:
trainer = gluon.Trainer(net.collect_params(), 'adam',
                        {'learning_rate': lr})

In [None]:
import utils

lr_handler = utils.MyLearningRateHandler(trainer, num_warmup_steps, num_train_steps, lr)
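`utils.MyLearningRateHandler` is not defined in this notebook. Judging only from its arguments (`num_warmup_steps`, `num_train_steps`, `lr`), it presumably implements the linear-warmup-then-linear-decay schedule common in BERT fine-tuning; that is an assumption. A sketch of that schedule as a pure function (the name `linear_warmup_decay` is hypothetical):

```python
def linear_warmup_decay(step, num_warmup_steps, num_train_steps, base_lr):
    """Learning rate at `step`: linear warmup from 0 to base_lr, then linear decay to 0."""
    if step < num_warmup_steps:
        return base_lr * step / num_warmup_steps
    # Fraction of the post-warmup phase completed so far.
    progress = (step - num_warmup_steps) / max(1, num_train_steps - num_warmup_steps)
    return base_lr * max(0.0, 1.0 - progress)  # clamp at 0 after the last step


for s in (0, 10, 55, 100):
    print(s, linear_warmup_decay(s, num_warmup_steps=10, num_train_steps=100, base_lr=3e-5))
```

Warmup keeps early updates small while the randomly initialized span classifier is still far from useful, which helps avoid disrupting the pre-trained encoder weights.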

In [None]:
metrics = [metric.Loss()]

In [None]:
class QAEstimator(estimator.Estimator):
    def fit_batch(self, train_batch, batch_axis=0):
        _, data, token_types, valid_length, start_label, end_label, _ = train_batch
        label = mx.nd.stack(start_label, end_label, axis=1)

        with autograd.record():
            pred = self.net(data.as_in_context(ctx),
                            token_types.as_in_context(ctx),
                            valid_length.as_in_context(ctx))
            pred = pred.transpose((0, 2, 1))
            loss = self.loss(pred, label.as_in_context(ctx))

        loss.backward()
        nlp.utils.clip_grad_global_norm(learnable_params, 1)

        self.trainer.step(1)

        return data, label, pred, loss
    
    def evaluate_batch(self,
                       val_batch,
                       val_metrics,
                       batch_axis=0):
        _, data, token_types, valid_length, start_label, end_label, _ = val_batch
        label = mx.nd.stack(start_label, end_label, axis=1)
        pred = self.net(data.as_in_context(ctx),
                        token_types.as_in_context(ctx),
                        valid_length.as_in_context(ctx))
        pred = pred.transpose((0, 2, 1))
        loss = self.loss(pred, label.as_in_context(ctx))
        # update metrics
        for m in val_metrics:
            if isinstance(m, metric.Loss):
                m.update(0, loss)
            else:
                m.update(label, pred)

In [None]:
est = QAEstimator(net=net, loss=loss,
                  metrics=metrics,
                  trainer=trainer,
                  context=ctx)

In [None]:
est.fit(train_data=train_dataloader,
        epochs=epochs,
        event_handlers=[lr_handler])  # apply the learning-rate schedule created above

In [None]:
val_metrics = [metric.Loss()]
est.evaluate(val_data=dev_dataloader,
             val_metrics=val_metrics)

In [None]:
val_metrics[0].get()

Finally, take a look at the predictions the model makes on the dev set.

In [None]:
def predict(net, dataset, dev_dataloader, vocab):
    tokenizer = nlp.data.BERTTokenizer(vocab=vocab, lower=True)
    transform = bert.data.qa.SQuADTransform(tokenizer, is_pad=False,
                                            is_training=False, do_lookup=False,
                                            return_fields=False)
    dev_dataset = dataset.transform(transform)
    
    all_results = []

    for data in dev_dataloader:
        example_ids, inputs, token_types, valid_length, _, _, _ = data
        output = net(inputs.as_in_context(ctx),
                     token_types.as_in_context(ctx),
                     valid_length.as_in_context(ctx))
        pred_start, pred_end = mx.nd.split(output, axis=2, num_outputs=2)

        batch_size = example_ids.shape[0]
        all_results.append((example_ids.asnumpy().tolist(),
                            pred_start.reshape(batch_size, -1).asnumpy(),
                            pred_end.reshape(batch_size, -1).asnumpy()))

    all_results_np = collections.defaultdict(list)
    for example_ids, pred_start, pred_end in all_results:
        for example_id, start, end in zip(example_ids, pred_start, pred_end):
            all_results_np[example_id].append(
                bert_qa_evaluate.PredResult(start=start, end=end))

    all_predictions = collections.OrderedDict()
    top_results = []
    for features in dev_dataset:
        results = all_results_np[features[0].example_id]

        prediction, nbest = bert_qa_evaluate.predict(
            features=features,
            results=results,
            tokenizer=nlp.data.BERTBasicTokenizer(lower=True))
        qas_id = features[0].qas_id
        all_predictions[qas_id] = prediction
        curr_result = {}
        sep_index = features[0].input_ids.index('[SEP]')  # the first [SEP] ends the question
        curr_result['context'] = features[0].doc_tokens
        curr_result['question'] = features[0].input_ids[1:sep_index]  # skip the leading [CLS]
        curr_result['prediction'] = nbest[0]
        top_results.append(curr_result)
    return top_results, all_predictions
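`bert_qa_evaluate.predict` turns the per-token start and end scores into an answer. A simplified sketch of the core idea, not that function's actual logic: choose the pair (start, end) with start <= end and a bounded answer length that maximizes the sum of the two logits. A full decoder typically also keeps an n-best list, restricts candidates to context tokens, and uses the max-context flags; `best_span` and `max_answer_len` are hypothetical names.

```python
import numpy as np


def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end, score) maximizing start_logits[s] + end_logits[e], s <= e < s + max_answer_len."""
    best = (0, 0, -np.inf)
    for s in range(len(start_logits)):
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = start_logits[s] + end_logits[e]
            if score > best[2]:
                best = (s, e, score)
    return best


print(best_span([0, 5, 1, 0], [0, 0, 1, 6])[:2])  # (1, 3)
```

Scoring start and end independently and then enforcing `start <= end` is what allows the model to use a single linear layer for both boundaries.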

In [None]:
top_results, all_predictions = predict(net, squad_dev, dev_dataloader, vocab)
first_sample_result = top_results[0]
print('Question: %s\n'%(' '.join((first_sample_result['question']))))
print('Top prediction: %.2f%% \t %s'%(first_sample_result['prediction'][1] * 100, first_sample_result['prediction'][0]))
print('\nContext: %s\n'%(' '.join(first_sample_result['context'])))

In [None]:
bert_qa_evaluate.get_F1_EM(squad_dev, all_predictions)
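`get_F1_EM` comes from the model-zoo module, and its source is not shown here. For reference, a sketch of the standard SQuAD v1.1 metrics as defined by the official evaluation script: answers are normalized (lowercased, punctuation and the articles a/an/the removed, whitespace collapsed), EM checks string equality, F1 measures token overlap, and each prediction is scored against its best-matching gold answer. Whether `get_F1_EM` matches this exactly is an assumption.

```python
import re
import string
from collections import Counter


def normalize_answer(s):
    """Lowercase, drop punctuation and articles, and collapse whitespace."""
    s = s.lower()
    s = "".join(ch for ch in s if ch not in set(string.punctuation))
    s = re.sub(r"\b(a|an|the)\b", " ", s)
    return " ".join(s.split())


def f1_score(prediction, ground_truth):
    """Harmonic mean of token-level precision and recall after normalization."""
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    num_same = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def exact_match(prediction, ground_truth):
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def best_over_golds(metric_fn, prediction, golds):
    """Score a prediction against its best-matching gold answer."""
    return max(metric_fn(prediction, g) for g in golds)


print(exact_match("The Denver Broncos!", "denver broncos"))  # 1.0
```

The reported scores are these per-question values averaged over the dev set and multiplied by 100.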

## Exercise 1: even lower resources

With just 1/10 of the SQuAD training set, the fine-tuned BERT model already performs reasonably well. Here's the challenge: can you devise a way to use even less data and still achieve `f1 > 80`?

Implement your idea in the form of a dataset sampler below.

In [None]:
class MySampler(gluon.data.Sampler):
    def __init__(self):
        pass

    def __iter__(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError
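One possible idea, offered as an untested sketch rather than a known solution: SQuAD attaches several questions to each context, so keeping at most `per_context` questions per distinct context trades question count for context diversity. The class name and its parameters are hypothetical. It is written as a plain class so the logic runs without MXNet; to pass it to `squad_train_original.sample`, declare it as a subclass of `gluon.data.Sampler`.

```python
import random
from collections import defaultdict


class ContextStratifiedSampler:
    """Hypothetical sampler: keep at most `per_context` questions per distinct context."""

    def __init__(self, contexts, per_context=1, seed=0):
        groups = defaultdict(list)
        for idx, context_text in enumerate(contexts):
            groups[context_text].append(idx)  # group record indices by shared context
        rng = random.Random(seed)
        self._indices = []
        for idxs in groups.values():
            rng.shuffle(idxs)
            self._indices.extend(idxs[:per_context])
        self._indices.sort()

    def __iter__(self):
        return iter(self._indices)

    def __len__(self):
        return len(self._indices)


toy_contexts = ["a", "a", "b", "c", "c", "c"]
print(len(ContextStratifiedSampler(toy_contexts, per_context=1)))  # 3
```

In the notebook it could be built from the `context` field (index 3 of each record), e.g. `ContextStratifiedSampler([r[3] for r in squad_train_original])`.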

In [None]:
my_sampler = MySampler()

In [None]:
new_squad_train = squad_train_original.sample(my_sampler)

In [None]:
new_processed_train = flatten_dataset(new_squad_train.transform(bert_qa_transform))

In [None]:
new_train_dataloader = mx.gluon.data.DataLoader(
    new_processed_train, batchify_fn=batchify_fn,
    batch_size=8, num_workers=4, shuffle=True)

In [None]:
est.fit(train_data=new_train_dataloader,
        epochs=epochs)

In [None]:
est.evaluate(val_data=dev_dataloader,
             val_metrics=val_metrics)
val_metrics[0].get()

In [None]:
_, all_predictions = predict(net, squad_dev, dev_dataloader, vocab)
bert_qa_evaluate.get_F1_EM(squad_dev, all_predictions)

## Exercise 2: out of domain QA

Being able to train a reasonable QA model from so little data suggests that the pre-trained representations generalize well. A natural follow-up question: how well does the model perform on a dataset that is out of domain?

In this exercise, you will implement the logic to load a dataset from the MRQA 2019 Shared Task. These are extractive QA datasets in a SQuAD-like format. You will then evaluate the model trained on 1/10 of SQuAD on one such dataset to answer the question above.

The dataset should return the following fields, the same as the SQuAD 1.1 dataset:

- `record_index`: index of the record, generated on the fly (from 0 to the number of questions minus 1).
- `question_id`: question ID, a string taken as-is from the original JSON file.
- `question`: question text, taken as-is from the original JSON file.
- `context`: context text; identical for all questions that share a context.
- `answer_list`: all answers to this question, stored as a Python list.
- `start_indices`: the starting character index of each answer in `context`, stored as a Python list.
  The i-th entry corresponds to the i-th answer in `answer_list`.
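A possible starting point for `MyDataset`, assuming the MRQA release layout: a JSON Lines file whose first line is a `{"header": ...}` object, followed by one object per context with a `context` string and a `qas` list whose entries have `qid`, `question`, and `detected_answers` with inclusive `char_spans`. These field names and the inclusive-end convention are assumptions to verify against the downloaded files; `parse_mrqa_jsonl` is a hypothetical helper.

```python
import json


def parse_mrqa_jsonl(lines):
    """Convert MRQA-style JSONL lines into SQuAD-1.1-style 6-field records (assumed layout)."""
    records = []
    for line in lines:
        obj = json.loads(line)
        if "header" in obj:
            continue  # assumed: the first line only holds dataset metadata
        context = obj["context"]
        for qa in obj["qas"]:
            answers, starts = [], []
            for detected in qa["detected_answers"]:
                for start, end in detected["char_spans"]:  # assumed inclusive [start, end]
                    answers.append(context[start:end + 1])
                    starts.append(start)
            records.append((len(records), qa["qid"], qa["question"], context, answers, starts))
    return records


sample_lines = [
    '{"header": {"dataset": "Toy"}}',
    '{"context": "Paris is in France.", "qas": [{"qid": "q1", "question": "Where is Paris?",'
    ' "detected_answers": [{"text": "France", "char_spans": [[12, 17]]}]}]}',
]
print(parse_mrqa_jsonl(sample_lines)[0][4])  # ['France']
```

`MyDataset` could then wrap the result, for example as `gluon.data.SimpleDataset(parse_mrqa_jsonl(open(path)))` for a hypothetical local file `path`.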

In [None]:
class MyDataset(gluon.data.Dataset):
    def __getitem__(self, idx):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

In [None]:
new_dev = MyDataset()

In [None]:
new_processed_dev = flatten_dataset(new_dev.transform(bert_qa_transform))

In [None]:
new_dev_dataloader = mx.gluon.data.DataLoader(
    new_processed_dev, batchify_fn=batchify_fn,
    batch_size=8, num_workers=4, shuffle=False)  # no need to shuffle for evaluation

In [None]:
est.evaluate(val_data=new_dev_dataloader,
             val_metrics=val_metrics)
val_metrics[0].get()

In [None]:
# Pass the raw dataset: `predict` applies its own transform, and F1/EM need the gold answers.
_, all_predictions = predict(net, new_dev, new_dev_dataloader, vocab)
bert_qa_evaluate.get_F1_EM(new_dev, all_predictions)