# BERT for Question-Answering

In this notebook, we fine-tune [BERT (Bidirectional Encoder Representations from Transformers)](https://arxiv.org/abs/1810.04805) for Question Answering (Q&A) using [SQuAD (Stanford Question Answering Dataset)](https://rajpurkar.github.io/SQuAD-explorer/explore/1.1/dev/Super_Bowl_50.html). Developed by Google in 2018, BERT set new state-of-the-art results across a range of Natural Language Processing (NLP) tasks.

BERT is pre-trained on a large corpus of unlabeled text (BooksCorpus and English Wikipedia), which lets it capture language structure and context. This pre-trained model can then be fine-tuned for specific tasks such as sentiment analysis or question answering. Fine-tuning BERT for Q&A means training the model to predict the start and end positions of the answer within a given passage for a provided question (extractive question answering). The following steps outline the process:

1. **🌱 Dataset Preparation**:
    - Define each dataset item with a question, a passage (or context), and the start and end positions of the answer within the passage as the label.
    - Tokenize both the question and passage into subwords using BERT's tokenizer. Separate the question from the passage using the `[SEP]` token and start the input sequence with the `[CLS]`
      token.
    - Mark the question as segment `A` (or `0`) and the context as segment `B` (or `1`). BERT learns a separate embedding for each segment, which is added to the token
      embeddings.
1. **🪡 Model Modification**:
    - Extract embeddings for each token in the sequence from the pre-trained BERT model.
    - Add a dense (fully connected) layer on top of BERT that maps every token embedding to two outputs: a score for being the start of the answer and a score for being the end (see
      [code](https://github.com/huggingface/transformers/blob/c385de24414e4ec6125ee14c46c128bfe70ecb66/src/transformers/models/bert/modeling_bert.py#L1803)).
1. **🎯 Training Objective**:
    - Output a score for each token in the passage, indicating how likely that token is the start of the answer, and another score for the end.
    - Apply a softmax over the sequence, separately to the start scores and to the end scores, to get one probability distribution over positions for the start and one for the end.
    - Use the sum of the negative log likelihood of the correct start and end positions as the loss function.
1. **🚀 Training**:
    - Initialize training with pre-trained BERT weights.
    - Use a small learning rate (e.g., 2e-5 or 3e-5) since BERT is already pre-trained. A larger learning rate can wipe out what the model learned during pre-training or cause training to diverge.
    - Fine-tune the model on the Q&A dataset for several epochs, stopping when validation performance plateaus or decreases.
1. **✨ Inference**:
    - Tokenize a new question and passage, and add the special `[CLS]` and `[SEP]` tokens.
    - Feed the tokens into the fine-tuned BERT model to get scores for the start and end positions of the answer.
    - Select the tokens between the predicted start and end positions as the final answer.
    - Apply constraints such as ensuring the end position is after the start and limiting the maximum answer length.
1. **🎉 Evaluation**:
    - Measure performance using common metrics: Exact Match (EM), the percentage of predictions that exactly match the ground truth, and the F1 score, which accounts for partial
      matches by considering overlapping words between the prediction and ground truth.
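
To make the training objective and the inference constraints from the list above more concrete, here is a minimal NumPy sketch. The logits are made-up numbers rather than model output, and the helper names (`log_softmax`, `span_loss`, `best_span`) and the `max_answer_len` parameter are illustrative choices of ours, not part of any library:

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over a 1-D array of scores."""
    shifted = logits - np.max(logits)
    return shifted - np.log(np.sum(np.exp(shifted)))


def span_loss(start_logits, end_logits, start_pos, end_pos):
    """Sum of the negative log-likelihoods of the true start and end tokens."""
    start_nll = -log_softmax(start_logits)[start_pos]
    end_nll = -log_softmax(end_logits)[end_pos]
    return start_nll + end_nll


def best_span(start_logits, end_logits, max_answer_len=30):
    """Pick the best-scoring (start, end) pair with start <= end and a bounded length."""
    best, best_score = (0, 0), -np.inf
    for s in range(len(start_logits)):
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = start_logits[s] + end_logits[e]
            if score > best_score:
                best, best_score = (s, e), score
    return best


# Toy scores for a 6-token sequence (hypothetical values, not model output)
start_logits = np.array([0.1, 0.2, 3.0, 0.3, 0.1, 0.0])
end_logits = np.array([0.0, 2.5, 0.2, 0.4, 2.0, 0.1])

loss = span_loss(start_logits, end_logits, start_pos=2, end_pos=4)
span = best_span(start_logits, end_logits, max_answer_len=3)
print(f"loss={loss:.3f}, predicted span={span}")
```

Taking the argmax of each distribution independently would give start `2` and end `1` here, which is not a valid span; searching over pairs with `start <= end` avoids that. Note that, at the time of writing, Hugging Face's `BertForQuestionAnswering` averages the start and end cross-entropy terms instead of summing them, which only rescales the loss by a constant factor.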

Let's begin by implementing each step, one by one, using the Hugging Face 🤗 ecosystem. First, the imports:

In [None]:
import collections
from functools import partial

import evaluate
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import pipeline
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

Let's set the IDs for the model and the dataset. We will download both of them from the Hugging Face Hub, using the `datasets` and `transformers` libraries:

In [None]:
DATASET_ID = "rajpurkar/squad"
MODEL_ID = "google-bert/bert-base-uncased"

# Data Processing

In this section, we will download the dataset, cache it locally, and preprocess it into the format described in the introduction. Our goal is to produce examples that contain:
- The tokenized question and context.
- The start position of the answer within the context.
- The end position of the answer within the context.

First, let's download and examine the dataset:

In [None]:
# load the SQuAD dataset
data = load_dataset(DATASET_ID)
data

The dataset has two splits: a `train` split with `87,599` rows, and a `validation` split with `10,570` rows. Each row includes a unique identifier, a title, the context, the question, and one or more possible answers. We will handle each split slightly differently. You’ll see why later, but for now, let's focus on processing the `train` split.
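
To get a feel for the row layout without downloading anything, the sketch below builds a hand-written row that mimics the SQuAD schema. The ID, title, and text are invented for illustration; only the field names follow the dataset. It shows how `answer_start` is a character index into `context`:

```python
# A hand-written row that mimics the SQuAD schema (values are invented, not from the dataset)
context = "BERT was introduced by researchers at Google in 2018."
answer_text = "researchers at Google"

row = {
    "id": "example-0001",
    "title": "Toy_Article",
    "context": context,
    "question": "Who introduced BERT?",
    # `answers` holds parallel lists: the answer strings and their character offsets
    "answers": {"text": [answer_text], "answer_start": [context.index(answer_text)]},
}

# Recover the answer from the context using the character offset
start = row["answers"]["answer_start"][0]
end = start + len(row["answers"]["text"][0])
print(start, end, repr(row["context"][start:end]))
```

Because the labels are character offsets, they have to be converted into token positions after tokenization, which is what the preprocessing in this section does.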

Since we need to tokenize the sequences, let's begin by loading the BERT tokenizer:

In [None]:
# load the BERT tokenizer
# set `clean_up_tokenization_spaces` to False to keep the tokenization spaces
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID, clean_up_tokenization_spaces=False)

Next, let's structure the examples in the desired format. First, we tokenize the questions and the context. Then, using the answer and sequence IDs (or segment IDs, as mentioned in the introduction), we identify the start and end positions of the answer within the context. Finally, we will apply the preprocessing function to each row in the `train` set and discard the columns we don't need:

In [None]:
def preprocess_train_examples(examples, tokenizer, max_length, stride):
    """Process the training split of the SQuAD dataset.
    
    Process the training split of the SQuAD dataset to include tokenized questions
    and context, as well as the start and end positions of the answer within the context.
    
    Args:
        examples: A batch of rows from the dataset (the function is applied with `batched=True`).
        tokenizer: The BERT tokenizer to be used.
        max_length: The maximum length of the input sequence. If exceeded, truncate the second
            sentence of a pair (or a batch of pairs) to fit within the limit.
        stride: The number of tokens to retain from the end of a truncated sequence, allowing
            for overlap between truncated and overflowing sequences.
    
    Returns:
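        The tokenized features with `start_positions` and `end_positions` added. A long context
        is split into several features, so there may be more features than input rows.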
    """
    # Tokenize the questions and context sequences
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        truncation="only_second",
        padding="max_length",
        stride=stride,
        max_length=max_length,
        return_offsets_mapping=True,
        return_overflowing_tokens=True,
    )

    answers = examples["answers"]
    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")

    start_positions = []
    end_positions = []

    # find the start and end positions of the answer within the context
    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # if the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions

    return inputs
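
The label-finding loops above can be tried out in isolation. The following sketch uses a toy whitespace tokenizer as a stand-in for BERT's WordPiece tokenizer, so the offsets are only an approximation of what the real tokenizer returns; `toy_offsets` and `char_span_to_token_span` are hypothetical helper names:

```python
import re


def toy_offsets(text):
    """Return (start_char, end_char) for each whitespace-delimited token."""
    return [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]


def char_span_to_token_span(offsets, start_char, end_char):
    """Map a character span to inclusive token indices, or (0, 0) if it is cut off."""
    first, last = 0, len(offsets) - 1
    # The answer must lie entirely within the tokens of this window
    if offsets[first][0] > start_char or offsets[last][1] < end_char:
        return 0, 0
    # Walk forward to the last token that starts at or before the answer start
    idx = first
    while idx <= last and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1
    # Walk backward to the first token that ends at or after the answer end
    idx = last
    while idx >= first and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return start_token, end_token


context = "BERT was introduced by researchers at Google in 2018."
answer = "researchers at Google"
start_char = context.index(answer)
end_char = start_char + len(answer)

offsets = toy_offsets(context)
full = char_span_to_token_span(offsets, start_char, end_char)
# Simulate a window that was truncated before the answer ends
cut = char_span_to_token_span(offsets[:6], start_char, end_char)
print(full, cut)
```

In the toy sentence the answer covers tokens 4 to 6, while the truncated window yields `(0, 0)`, the same fallback label the preprocessing function assigns when a window does not contain the whole answer.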

In [None]:
preprocess_train_data = partial(
    preprocess_train_examples, tokenizer=tokenizer, max_length=384, stride=128)
processed_train_data = data["train"].map(preprocess_train_data, batched=True, remove_columns=data["train"].column_names)
processed_train_data
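
The processed split has more rows than the original because long contexts are split into overlapping windows. The sketch below imitates that behavior on plain token lists, under the simplifying assumption that `window_size` is the number of context tokens that fit in one feature (the real tokenizer also has to reserve room for the question and special tokens). `sliding_windows` is a hypothetical helper, not a `transformers` function:

```python
def sliding_windows(context_tokens, window_size, stride):
    """Split tokens into windows of at most `window_size`, overlapping by `stride` tokens."""
    if stride >= window_size:
        raise ValueError("stride must be smaller than window_size")
    windows, start = [], 0
    while True:
        windows.append(context_tokens[start:start + window_size])
        if start + window_size >= len(context_tokens):
            break
        # Step forward so that the next window repeats the last `stride` tokens
        start += window_size - stride
    return windows


long_context = [f"t{i}" for i in range(10)]
short_context = ["t0", "t1", "t2"]

windows = sliding_windows(long_context, window_size=4, stride=2)
for window in windows:
    print(window)

# Track which original sample each feature came from,
# similar in spirit to `overflow_to_sample_mapping`
features, sample_map = [], []
for sample_idx, tokens in enumerate([long_context, short_context]):
    for window in sliding_windows(tokens, window_size=4, stride=2):
        features.append(window)
        sample_map.append(sample_idx)
print(sample_map)
```

The overlap makes it likely that an answer cut by one window boundary appears in full in a neighboring window, which is why windows that contain only part of the answer can safely be labeled `(0, 0)`.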

The processed `train` split is now ready for fine-tuning BERT. Moving on to model evaluation, the preprocessing step for the `validation` split is almost identical. However, we also need to retain the ID of each row so that we can later evaluate the model's performance by reconstructing the actual answer text and computing the Exact Match (EM) and F1 scores:

In [None]:
def preprocess_valid_examples(examples, tokenizer, max_length, stride):
    """Process the validation split of the SQuAD dataset.
    
    Process the validation split of the SQuAD dataset to include the unique ID of each row,
    the tokenized questions and context, the offset mapping restricted to context tokens,
    as well as the start and end positions of the answer within the context.
    
    Args:
        examples: A batch of rows from the dataset (the function is applied with `batched=True`).
        tokenizer: The BERT tokenizer to be used.
        max_length: The maximum length of the input sequence. If exceeded, truncate the second
            sentence of a pair (or a batch of pairs) to fit within the limit.
        stride: The number of tokens to retain from the end of a truncated sequence, allowing
            for overlap between truncated and overflowing sequences.
    
    Returns:
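        The tokenized features with `example_id`, `start_positions`, and `end_positions` added.
        A long context is split into several features that share the same `example_id`.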
    """
    # Tokenize the questions and context sequences
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        truncation="only_second",
        padding="max_length",
        stride=stride,
        max_length=max_length,
        return_offsets_mapping=True,
        return_overflowing_tokens=True,
    )
    
    example_ids = []
    answers = examples["answers"]
    offset_mapping = inputs["offset_mapping"]
    sample_map = inputs.pop("overflow_to_sample_mapping")

    start_positions = []
    end_positions = []

    # find the start and end positions of the answer within the context
    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)
        
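        # Replace the offsets of non-context tokens (question, special, padding) with None
        # so that only context tokens can later be mapped back to answer text.
        # `offset` still refers to the unmasked list and is used for the labels below.
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]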

        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # if the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["example_id"] = example_ids  # keep the unique ID of the example
    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions

    return inputs
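
The reason for masking the offsets of non-context tokens becomes clearer with a small example. The sketch below again uses a toy whitespace tokenizer instead of BERT's tokenizer, and lays the tokens out as `[CLS] question [SEP] context [SEP]`; `toy_offsets` and `span_to_text` are hypothetical helpers:

```python
import re


def toy_offsets(text):
    """Return (start_char, end_char) for each whitespace-delimited token."""
    return [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]


def span_to_text(context, offsets, start_token, end_token):
    """Turn a predicted token span into answer text, rejecting non-context tokens."""
    if offsets[start_token] is None or offsets[end_token] is None:
        return ""
    return context[offsets[start_token][0]:offsets[end_token][1]]


question = "Who introduced BERT?"
context = "BERT was introduced by researchers at Google in 2018."

q_off, c_off = toy_offsets(question), toy_offsets(context)
# Special tokens get a dummy (0, 0) offset and a sequence ID of None
offsets = [(0, 0)] + q_off + [(0, 0)] + c_off + [(0, 0)]
sequence_ids = [None] + [0] * len(q_off) + [None] + [1] * len(c_off) + [None]

# Same masking as in the validation preprocessing: keep context offsets only
masked_offsets = [o if sequence_ids[k] == 1 else None for k, o in enumerate(offsets)]

valid = span_to_text(context, masked_offsets, 9, 11)    # span inside the context
invalid = span_to_text(context, masked_offsets, 2, 11)  # span starts in the question
print(repr(valid), repr(invalid))
```

Question offsets are relative to the question string, so slicing the context with them would return unrelated text. With the mask in place, such spans can be detected and skipped when answers are reconstructed.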

In [None]:
preprocess_valid_data = partial(
    preprocess_valid_examples, tokenizer=tokenizer, max_length=384, stride=128)
processed_valid_data = data["validation"].map(preprocess_valid_data, batched=True, remove_columns=data["validation"].column_names)
processed_valid_data
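
The Exact Match and F1 scores mentioned above compare answer strings rather than token positions. As a rough illustration of what they measure, here is a plain-Python sketch that follows the usual SQuAD-style normalization (lowercasing, removing punctuation and English articles). It is a simplified stand-in written for this explanation, not the implementation from the `evaluate` library, and it compares against a single reference answer:

```python
import collections
import re
import string


def normalize(text):
    """Lowercase, drop punctuation and English articles, and collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, truth):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize(prediction) == normalize(truth))


def f1_score(prediction, truth):
    """Token-level F1 between the normalized prediction and reference."""
    pred_tokens = normalize(prediction).split()
    truth_tokens = normalize(truth).split()
    common = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


# Hypothetical predictions scored against one reference answer
reference = "Denver Broncos"
for prediction in ["the Denver Broncos", "Broncos", "Carolina Panthers"]:
    em = exact_match(prediction, reference)
    f1 = f1_score(prediction, reference)
    print(f"{prediction!r}: EM={em:.0f}, F1={f1:.2f}")
```

When a question has several reference answers, as in the `validation` split, the usual convention is to score the prediction against each reference and keep the best result.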