# **Open Domain Question Answering with BERT**
### Author: Juan J. Roesel

# Introduction

This notebook contains code that fine-tunes a DistilBERT model using a subset of the SQuAD dataset. The dataset comes in three different splits - `train`, `dev` and `test`, with each split containing a `tsv` file with questions, contexts, answers, and spans.

For instance, the following is an example from the first input of the SQuAD `train` dataset:

**question:** 
```
what percentage of imperial 's staff was classified as world leading in 2008 ?
```

**context:** 

```
the 2008 research assessment exercise returned 26 % of the 1225 staff submitted as being world-leading ( 4* ) and a further 47 % as being internationally excellent ( 3* ) . the 2008 research assessment exercise also showed five subjects – pure mathematics , epidemiology and public health , chemical engineering , civil engineering , and mechanical , aeronautical and manufacturing engineering – were assessed to be the best [ clarification needed ] in terms of the proportion of internationally recognised research quality .
```

**answer:** 
```
26 %
```
**span:** 
```
6 7
```
The following is the data distribution per split:
* `train`: 91,412 examples
* `dev`: 5,854 examples
* `test`: 8,000 examples

After three training epochs, the model achieves the following metrics on the `test` set:
* Exact Matches: 0.87
* F1 Score: 0.91

# Contents

* [Initial Set-up](#1)
* [Data processing](#2)
* [Model training](#3)
* [Model evaluation and inference](#4)
* [System limitations](#5)

<a name="1"></a>
# Initial Set-up

In [5]:
# Initial set up
!pip install transformers
from google.colab import drive
drive.mount('/content/gdrive')

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)
[K     |████████████████████████████████| 4.2 MB 8.1 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.7.0-py3-none-any.whl (86 kB)
[K     |████████████████████████████████| 86 kB 4.9 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 61.9 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 72.8 MB/s 
Installing collected packages: pyyaml, tokenizers, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    Found existing installation: PyYAML 3.13
    Uninstallin

In [6]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import transformers
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
from transformers import get_scheduler
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, f1_score
from tqdm.auto import tqdm
import collections
import os
import random
import time
from datetime import timedelta
import string, re

transformers.logging.set_verbosity_error()  # output only ERROR level logs

In [18]:
# Set seed and working device
def set_seed(seed_value=40):
    """
    Set seed for reproducibility.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)


set_seed()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
n_gpu = torch.cuda.device_count()
print(torch.cuda.get_device_name())

cuda
Tesla P100-PCIE-16GB


<a name="2"></a>
# Data processing

Here we are leveraging BERT's built-in [Tokenizer](https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__) to take our text inputs from the SQuAD dataset (i.e., a list of questions and a list of corresponding contexts) and turn them into tensors. The function takes the following arguments:
- `padding`: Adds padding when the question and/or context is too short.
- `truncation`: Limits the question and/or context to 512 characters (required by BERT). The truncation strategy used is `longest_first` by default.
- `max_length`: Controls the maximum length used by `truncation` and `padding` parameters.
- `return_tensors`: Returns PyTorch tensors objects, from which `input_ids` and `attention_mask` can be derived.

In [19]:
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

def convert_to_BERT_tensors(questions, contexts):
    """
    Takes a parallel list of question strings and context strings and converts them into BERT tensors.
    """
    tokenized = tokenizer(questions, contexts, padding=True, truncation=True, max_length=512, return_tensors="pt")
    return tokenized["input_ids"], tokenized["attention_mask"]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

In [48]:
test_questions = ["Why?", "How?", "When?"]
test_contexts = ["Because we can", 
                 "Working hard while having fun!", 
                 "Starting today!" + "".join(["!"] * 512) + "Let's go!"]

ids, mask = convert_to_BERT_tensors(test_questions,test_contexts)
print(f"ids shape: {ids.shape}")
print(f"mask shape: {mask.shape}")

ids shape: torch.Size([3, 512])
mask shape: torch.Size([3, 512])


In [49]:
tokenizer.tokenize("[CLS]" + test_questions[0] + "[SEP]" + test_contexts[0])

['[CLS]', 'why', '?', '[SEP]', 'because', 'we', 'can']

In [50]:
# First row: '[CLS]', 'why', '?', '[SEP]', 'because', ´we´, 'can' => (With padding)
ids[0][0:100]

tensor([ 101, 2339, 1029,  102, 2138, 2057, 2064,  102,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0])

In [51]:
tokenizer.tokenize("[CLS]" + test_questions[2] + "[SEP]" + test_contexts[2])[:10]

['[CLS]', 'when', '?', '[SEP]', 'starting', 'today', '!', '!', '!', '!']

In [52]:
# Second row: '[CLS]', 'when', '?', '[SEP]', 'starting', 'today' => (Without padding)
ids[2][0:100]

tensor([ 101, 2043, 1029,  102, 3225, 2651,  999,  999,  999,  999,  999,  999,
         999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,
         999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,
         999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,
         999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,
         999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,
         999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,
         999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,  999,
         999,  999,  999,  999])

With an approach in place to convert questions and contexts into tensors, we need to turn our attention now to the answers provided in the SQuAD dataset, which come in the form of a string (e.g., "26 %") and a span (e.g., [6, 7]) representing token indices in the context tensor.

Since BERT Tokenizer object merges each question and context into a single input, the token indices provided in the span need to be recalculated. This work will be conducted by the following function.

In [55]:
def get_answer_span_tensor(question, context, answer):
    """
    Recomputes the answer span by combining the question and context into an input and
    identifying the correct answer span inside of it.
    If the answer doesn't appear in the input, it will return [0, 0].
    """
    input_str = "[CLS]" + question + "[SEP]" + context
    input_tokens = tokenizer.tokenize(input_str)
    answer_tokens = tokenizer.tokenize(answer)
    span_len = len(answer_tokens)
    # setting a pointer to identify exact answer span
    for i in range(min(len(input_tokens) - span_len + 1, 512 - span_len - 1)):
        if input_tokens[i:i + span_len] == answer_tokens:
            answer_span = torch.tensor([i, i + span_len - 1])
            break
    else:
        answer_span = torch.tensor([0, 0])
        
    return answer_span

In [56]:
test_answer = "Having fun!"
test_answer_span = get_answer_span_tensor(test_questions[1], test_contexts[1], test_answer)
test_answer_span

tensor([7, 9])

In [None]:
input_str = "[CLS]" + test_questions[1] + "[SEP]" + test_contexts[1]
[(i, t) for i, t in enumerate(tokenizer.tokenize(input_str))]

[(0, '[CLS]'),
 (1, 'how'),
 (2, '?'),
 (3, '[SEP]'),
 (4, 'working'),
 (5, 'hard'),
 (6, 'while'),
 (7, 'having'),
 (8, 'fun'),
 (9, '!')]

Lastly, we will proceed to create the data structure and dataloaders to prepare our training, dev, and test datasets.

In [22]:
BATCH_SIZE=16

class QAdataset(Dataset):
    """
    A custom dataset for housing QA data, including input_data, output_data, and padding mask.
    """
    def __init__(self, input_data, output_data, mask):
        self.input_data = input_data
        self.output_data = output_data
        self.mask = mask

    def __len__(self):
        return len(self.input_data)

    def __getitem__(self, index):
        data_val = self.input_data[index]
        target = self.output_data[index]
        mask = self.mask[index]
        return data_val, target, mask


def read_files(path, split):
    """
    Reads the SQuAD files and converts them into tensors.
    """
    for entry in os.listdir(path):
      if split in entry:
          with open(path + entry, "r", encoding="utf-8") as f:
            if "question" in entry:
              questions = f.readlines()
            elif "context" in entry:
              contexts = f.readlines()
            elif "answer" in entry:
              answers = f.readlines()
            else:
              continue
    return questions, contexts, answers


def prepare_QA_dataset(split, path):
    """
    Prepares the PyTorch dataset for the train, dev, and test splits.
    """
    questions, contexts, answers = read_files(path, split)
    ids, mask = convert_to_BERT_tensors(questions, contexts)
    spans = []
    for question, context, answer in zip(questions, contexts, answers):
        spans.append(get_answer_span_tensor(question, context, answer))
    return QAdataset(ids, spans, mask)


def prepare_dataloaders(split, squad_path, output_path, dtl_fn, 
                        batch_size=BATCH_SIZE, shuffle=False):
    """
    Helper function to generate and persist Dataloaders for each split.
    """
    data = prepare_QA_dataset(split, squad_path)
    dataloader = DataLoader(data, 
                            batch_size=batch_size, 
                            shuffle=shuffle)
    torch.save(dataloader, output_path + f"/{dtl_fn}")
    print(f"Generated and saved {dtl_fn}")

In [None]:
# define relevant dir paths
squad_path = '/content/gdrive/MyDrive/Colab Notebooks/open_domain_QA_BERT/data/'
small_squad_path = "/content/gdrive/MyDrive/Colab Notebooks/data/small/"
artifacts_dirpath = "/content/gdrive/MyDrive/Colab Notebooks/open_domain_QA_BERT/artifacts/"

# prepare Dataloaders
prepare_dataloaders("train", squad_path, artifacts_dirpath, "train.dtl")
prepare_dataloaders("dev", squad_path, artifacts_dirpath, "dev.dtl")
prepare_dataloaders("test", squad_path, artifacts_dirpath, "test.dtl")

---
<a name="3"></a>
# Model training

In order to train our BERT model to be suitable for the QA task at hand, we will load the `DistilBertForQuestionAnswering` module, which comes with a regular DistilBert pre-trained language model `distilbert-base-uncased`.

DistilBert is a "lightweight" version of BERT thanks to the [knowledge distillation](https://medium.com/huggingface/distilbert-8cf3380435b5) technique, which makes it much faster to train while keeping most of its performance.

THe model will be fine-tuned for three epochs, using a linear scheduler with the `AdamW` optimizer. 

In [23]:
# load Dataloaders
artifacts_dirpath = "/content/gdrive/MyDrive/Colab Notebooks/open_domain_QA_BERT/artifacts/"

train_iter = torch.load(artifacts_dirpath + "/train.dtl")
dev_iter = torch.load(artifacts_dirpath + "/dev.dtl")
test_iter = torch.load(artifacts_dirpath + "/test.dtl")

In [24]:
# load dev and test questions, contexts, and answers
squad_path = '/content/gdrive/MyDrive/Colab Notebooks/open_domain_QA_BERT/data/'
dev_questions, dev_contexts, dev_gold_answers = read_files(squad_path, "dev")
test_questions, test_contexts, test_gold_answers = read_files(squad_path, "test")

In [25]:
# Parameters
LR = 3e-5
MAX_GRAD_NORM = 1.0
EPOCHS = 3
WARMUP_PROPORTION = 0.1
NUM_TRAINING_STEPS = len(train_iter) * EPOCHS
NUM_WARMUP_STEPS = NUM_TRAINING_STEPS * WARMUP_PROPORTION

In [26]:
ckpt_path = "/content/gdrive/MyDrive/Colab Notebooks/open_domain_QA_BERT/ckpt/"
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased").to(device)
optimizer = optim.AdamW(model.parameters(), lr=LR)
lr_scheduler = get_scheduler(name="linear", 
                             optimizer=optimizer, 
                             num_warmup_steps=NUM_WARMUP_STEPS,
                             num_training_steps=NUM_TRAINING_STEPS)
criterion = nn.CrossEntropyLoss()

Downloading:   0%|          | 0.00/256M [00:00<?, ?B/s]

In [27]:
# model parameter count
def count_parameters(model):
    """
    Counts number of trainable model parameters.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"The model has {count_parameters(model):,} trainable parameters")

The model has 66,364,418 trainable parameters


In [28]:
def train(model, optimizer, scheduler, criterion, iterator, t_batch):
    """
    Runs batch training on the model. Returns the accumulated loss at each epoch.
    """
    epoch_loss = 0
    model.train()
    for i, (ids, spans, mask) in enumerate(iterator):
        # initialize gradients
        model.zero_grad()
        # loads inputs into GPU
        ids = ids.to(device)
        spans = spans.to(device)
        mask = mask.to(device)
        # calls the model with inputs
        outputs = model(ids, mask)
        start_loss = criterion(outputs.start_logits, spans[:, 0])
        end_loss = criterion(outputs.end_logits, spans[:, 1])
        total_loss = start_loss + end_loss
        # performs backward pass and updates optimizer
        total_loss.backward()
        optimizer.step()
        lr_scheduler.step()
        epoch_loss += total_loss
        # delete used variables to free GPU memory
        del ids, spans, mask
        # clears cache
        torch.cuda.empty_cache()
        # print status after 20 steps
        if i % 20 == 0:
          time_elapsed = time.time() - t_batch
          print(f"Processed {i * BATCH_SIZE} QA pairs of {len(train_iter.dataset)}")
          print(f"Last loss: {total_loss.item()}")
          print(f"Time_elapsed: {str(timedelta(seconds=time_elapsed))}")
        epoch_loss += total_loss.item()
        progress_bar.update(1)
    return epoch_loss

In [29]:
def evaluate(model, tokenizer, iterator, gold_answers):
    """
    Evaluates the model using Exact Match (EM) and F1 score metrics.
    """
    pred_starts, pred_ends = [], []
    gold_starts, gold_ends = [], []
    pred_answers = []
    pred_f1_score = 0
    pred_em_score = 0
    progress_bar = tqdm(range(len(iterator)))
    model.eval()

    with torch.no_grad():
        for ids, spans, mask in iterator:
            # loads inputs into GPU
            ids, spans, mask = ids.to(device), spans.to(device), mask.to(device)
            output = model(ids, mask)
            # compute start and end span probabilities
            start_probs = output.start_logits.to("cpu").detach()
            pred_starts.extend(list(np.argmax(start_probs.numpy(), axis=1)))
            end_probs = output.end_logits.to("cpu").detach()
            pred_ends.extend(list(np.argmax(end_probs.numpy(), axis=1)))
            # prepare ground truth data
            gold_targets = spans.to("cpu").detach().numpy()
            gold_starts.extend(list(gold_targets[:, 0]))
            gold_ends.extend(list(gold_targets[:, 1]))
            # gathers pred and gold text answers for F1 score
            pred_answers.extend(get_pred_answers(ids, spans, tokenizer))
            progress_bar.update(1)

    # compute evaluation metrics
    pred_acc = accuracy_score(gold_starts, pred_starts)

    assert len(pred_answers) == len(gold_answers)
    for pred_answer, gold_answer in zip(pred_answers, gold_answers):
        pred_em_score += compute_exact_match(pred_answer, gold_answer)
        pred_f1_score += compute_f1(pred_answer, gold_answer)
    
    results = {
        "pred_acc": pred_acc,
        "pred_em": pred_em_score / len(pred_answers),
        "pred_f1": pred_f1_score / len(pred_answers),
    }

    return results

In [30]:
# functions adapted from https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html
def to_list(tensor):
    """
    Helper function to convert tensors into lists.
    """
    return tensor.detach().cpu().tolist()


def get_pred_answers(batch_ids, spans, tokenizer):
    """
    Takes token text ids and spans from a given batch and outputs text answers.
    """
    answers = []
    for i in range(len(spans)):
        answers.append(tokenizer.decode(batch_ids[i, spans[i][0]: spans[i][1] + 1]))
    return answers


def normalize_text(s):
    """
    Applies regular text processing techniques on a string.
    Pre-requisite for computing F1 scores between true and pred answers. 
    """
    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    output_text = white_space_fix(remove_articles(remove_punc(lower(s))))
    
    return output_text


def compute_exact_match(prediction, gold):
    """
    Evaluates whether pred and gold answers are an exact match. 
    Returns 1 if True, 0 otherwise.
    """
    return int(normalize_text(prediction) == normalize_text(gold))


def compute_f1(prediction, gold):
    """
    Computes F1 Score on top of the words included in pred and gold answers
      - Precision: Proportion of common tokens over total predicted tokens
      - Recall: Proportion of common tokens over total gold tokens
    """
    pred_tokens = normalize_text(prediction).split()
    gold_tokens = normalize_text(gold).split()
    
    # if either the prediction or the gold is no-answer then f1 = 1 if they agree, 0 otherwise
    if len(pred_tokens) == 0 or len(gold_tokens) == 0:
        return int(pred_tokens == gold_tokens)
    
    common_tokens = set(pred_tokens) & set(gold_tokens)
    
    # if there are no common tokens then f1 = 0
    if len(common_tokens) == 0:
        return 0
    
    prec = len(common_tokens) / len(pred_tokens)
    rec = len(common_tokens) / len(gold_tokens)
    
    return 2 * (prec * rec) / (prec + rec)

In [31]:
# Executes training
progress_bar = tqdm(range(NUM_TRAINING_STEPS))
t_epoch, t_batch = time.time(), time.time()

for epoch in range(EPOCHS):
    print(f"Epoch: {epoch + 1}/{EPOCHS}")
    epoch_loss = train(model, optimizer, lr_scheduler, criterion, train_iter, t_batch)
    dev_results = evaluate(model, tokenizer, dev_iter, dev_gold_answers)
    print(f"After Epoch: {epoch}")
    print(f"Loss: {epoch_loss}")
    print(f"Pred acc: {dev_results['pred_acc']}")
    print(f"Pred EM: {dev_results['pred_em']}")
    print(f"Pred F1: {dev_results['pred_f1']}")
    
    # stores model checkpoint
    torch.save({
              'epoch': epoch,
              'model_state_dict': model.state_dict(),
              'optimizer_state_dict': optimizer.state_dict(),
              'lr_scheduler_state_dict': lr_scheduler.state_dict(),
              'loss': epoch_loss,
              }, ckpt_path + f"ckpt_{epoch}.pt")
    
    print(f"Model state saved - epoch: {epoch}")
    print(f"{'=' * 20}")

total_time = time.time() - t_epoch
print(f"Training finished! Total time: {str(timedelta(seconds=total_time))}")

  0%|          | 0/14544 [00:00<?, ?it/s]

Epoch: 1/3
Processed 0 QA pairs of 77558
Last loss: 12.41524887084961
Time_elapsed: 0:00:00.647514
Processed 320 QA pairs of 77558
Last loss: 12.380169868469238
Time_elapsed: 0:00:12.542874
Processed 640 QA pairs of 77558
Last loss: 12.342016220092773
Time_elapsed: 0:00:24.442527
Processed 960 QA pairs of 77558
Last loss: 12.260942459106445
Time_elapsed: 0:00:36.321777
Processed 1280 QA pairs of 77558
Last loss: 12.179329872131348
Time_elapsed: 0:00:48.221089
Processed 1600 QA pairs of 77558
Last loss: 11.81229019165039
Time_elapsed: 0:01:00.130959
Processed 1920 QA pairs of 77558
Last loss: 11.63782787322998
Time_elapsed: 0:01:12.003403
Processed 2240 QA pairs of 77558
Last loss: 11.114632606506348
Time_elapsed: 0:01:23.903088
Processed 2560 QA pairs of 77558
Last loss: 10.81821060180664
Time_elapsed: 0:01:35.785419
Processed 2880 QA pairs of 77558
Last loss: 10.155652046203613
Time_elapsed: 0:01:47.673671
Processed 3200 QA pairs of 77558
Last loss: 9.397003173828125
Time_elapsed: 0:0

  0%|          | 0/366 [00:00<?, ?it/s]

After Epoch: 0
Loss: 35097.50390625
Pred acc: 0.6219678852066963
Pred EM: 0.8795695251110351
Pred F1: 0.9231976335339813
Model state saved - epoch: 0
Epoch: 2/3
Processed 0 QA pairs of 77558
Last loss: 2.675173282623291
Time_elapsed: 0:49:06.853865
Processed 320 QA pairs of 77558
Last loss: 2.295081377029419
Time_elapsed: 0:49:18.789929
Processed 640 QA pairs of 77558
Last loss: 1.9046525955200195
Time_elapsed: 0:49:30.700718
Processed 960 QA pairs of 77558
Last loss: 2.859529972076416
Time_elapsed: 0:49:42.611381
Processed 1280 QA pairs of 77558
Last loss: 2.0963034629821777
Time_elapsed: 0:49:54.532248
Processed 1600 QA pairs of 77558
Last loss: 1.9848079681396484
Time_elapsed: 0:50:06.432890
Processed 1920 QA pairs of 77558
Last loss: 2.6130688190460205
Time_elapsed: 0:50:18.353493
Processed 2240 QA pairs of 77558
Last loss: 1.8378264904022217
Time_elapsed: 0:50:30.276696
Processed 2560 QA pairs of 77558
Last loss: 3.491802215576172
Time_elapsed: 0:50:42.184447
Processed 2880 QA pai

  0%|          | 0/366 [00:00<?, ?it/s]

After Epoch: 1
Loss: 18445.638671875
Pred acc: 0.6327297574308165
Pred EM: 0.8795695251110351
Pred F1: 0.9231976335339813
Model state saved - epoch: 1
Epoch: 3/3
Processed 0 QA pairs of 77558
Last loss: 2.221468687057495
Time_elapsed: 1:38:21.044384
Processed 320 QA pairs of 77558
Last loss: 1.7286014556884766
Time_elapsed: 1:38:32.978105
Processed 640 QA pairs of 77558
Last loss: 1.0070310831069946
Time_elapsed: 1:38:44.920441
Processed 960 QA pairs of 77558
Last loss: 1.7972159385681152
Time_elapsed: 1:38:56.830754
Processed 1280 QA pairs of 77558
Last loss: 1.274364948272705
Time_elapsed: 1:39:08.738296
Processed 1600 QA pairs of 77558
Last loss: 1.6326091289520264
Time_elapsed: 1:39:20.664204
Processed 1920 QA pairs of 77558
Last loss: 1.8386147022247314
Time_elapsed: 1:39:32.583274
Processed 2240 QA pairs of 77558
Last loss: 0.9599722623825073
Time_elapsed: 1:39:44.487503
Processed 2560 QA pairs of 77558
Last loss: 1.9109464883804321
Time_elapsed: 1:39:56.399564
Processed 2880 QA 

  0%|          | 0/366 [00:00<?, ?it/s]

After Epoch: 2
Loss: 12898.7373046875
Pred acc: 0.6477622138708575
Pred EM: 0.8795695251110351
Pred F1: 0.9231976335339813
Model state saved - epoch: 2
Training finished! Total time: 2:27:38.149691


<a name="4"></a>
# Model Evaluation and Inference

In [32]:
model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased").to(device)
ckpt_path = "/content/gdrive/MyDrive/Colab Notebooks/open_domain_QA_BERT/ckpt/"
checkpoint = torch.load(ckpt_path + "ckpt_2.pt")
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [33]:
squad_path = '/content/gdrive/MyDrive/Colab Notebooks/open_domain_QA_BERT/data/'
test_questions, test_contexts, test_gold_answers = read_files(squad_path, "test")
test_results = evaluate(model, tokenizer, test_iter, test_gold_answers)

  0%|          | 0/500 [00:00<?, ?it/s]

In [58]:
print(f"Test pred acc: {test_results['pred_acc']}")
print(f"Test pred EM: {test_results['pred_em']}")
print(f"Test pred F1: {test_results['pred_f1']}")

Test pred acc: 0.673875
Test pred EM: 0.8675
Test pred F1: 0.91526472890887


In [59]:
from transformers import pipeline

model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")
ckpt_path = "/content/gdrive/MyDrive/Colab Notebooks/open_domain_QA_BERT/ckpt/"
checkpoint = torch.load(ckpt_path + "ckpt_2.pt")
model.load_state_dict(checkpoint['model_state_dict'])
QA_BERT_model = pipeline("question-answering", model=model, tokenizer=tokenizer)

In [60]:
test_question = "What is the difference between Work 4.0 and the other phases of work relations?"
test_context = """
Conceptually, Work 4.0 reflects the current fourth phase of work relations, having been preceded by 
the birth of industrial society and the first workers' organizations in the late 18th century (Work 1.0), 
the beginning of mass production and of the welfare state in the late 19th century (Work 2.0), 
and the advent of globalization, digitalization and the transformation of the social market economy since 
the 1970s (Work 3.0). By contrast, Work 4.0 is characterized by a high degree of integration and cooperation, 
the use of digital technologies (e.g. the internet), and a rise in flexible work arrangements. 
Its drivers include digitalization, globalization, demographic change (ageing, migration), and cultural change.
"""

QA_BERT_model(question=test_question, context=test_context)

{'answer': 'high degree of integration and cooperation,',
 'end': 522,
 'score': 0.37090301513671875,
 'start': 478}

In [38]:
test_question2 = "How is COVID-19 transmitted?"
test_context2 = """
COVID‑19 transmits when people breathe in air contaminated by droplets and small airborne particles containing 
the virus. The risk of breathing these in is highest when people are in close proximity, but they can be i
nhaled over longer distances, particularly indoors. Transmission can also occur if splashed or sprayed with 
contaminated fluids in the eyes, nose or mouth, and, rarely, via contaminated surfaces. 
People remain contagious for up to 20 days, and can spread the virus even if they do not develop symptoms.
"""

QA_BERT_model(question=test_question2, context=test_context2)

{'answer': 'when people breathe in air contaminated by droplets and small airborne particles containing the virus.',
 'end': 123,
 'score': 0.3649998903274536,
 'start': 20}

In [39]:
test_question3 = "Why is the sky blue?"
test_context3 = """
The Earth's atmosphere scatters short-wavelength light more efficiently than that of longer wavelengths. 
Because its wavelengths are shorter, blue light is more strongly scattered than the longer-wavelength lights, 
red or green. Hence the result that when looking at the sky away from the direct incident sunlight, 
the human eye perceives the sky to be blue.
"""
QA_BERT_model(question=test_question3, context=test_context3)

{'answer': 'Because its wavelengths are shorter,',
 'end': 143,
 'score': 0.35249170660972595,
 'start': 107}

In [40]:
test_question4 = "Can the sky ever be green?"
QA_BERT_model(question=test_question4, context=test_context3)

{'answer': 'the human eye perceives the sky to be blue.',
 'end': 362,
 'score': 0.04962889477610588,
 'start': 319}

In [41]:
test_question5 = "What is our projected CAGR for the next 5 years?"
test_context5 = """
Our CFO estimates strong revenue growth forcasts for the next five years. 
According to our calculations, and assuming market conditions hold, we can expect to achieve a 25% CAGR
for this period."""

QA_BERT_model(question=test_question5, context=test_context5)

{'answer': '25%', 'end': 174, 'score': 0.7485947012901306, 'start': 171}

<a name="5"></a>
# System limitations

* Assumes that each question only has one possible answer.
* Doesn't select the most optimal start/end span.
* Doesn't provide empty answers when answer is not contained in the context.
* System interactions are stateless (i.e., model processes each query independently).
* Requires context to provide an answer (in other words, is not a generative question answering model)
* Can only model one single document (context) at a time. 