# Fine-tune BERT on SQuAD

Source: https://github.com/dpoulopoulos/bert-qa-finetuning

## BERT for Question-Answering

In this notebook, we fine-tune [BERT (Bidirectional Encoder Representations from Transformers)](https://arxiv.org/abs/1810.04805) for Question Answering (Q&A) tasks using the [SQuAD (Stanford Question Answering Dataset)](https://rajpurkar.github.io/SQuAD-explorer/explore/1.1/dev/Super_Bowl_50.html) dataset. Developed by Google in 2018, BERT set new state-of-the-art results across a range of Natural Language Processing (NLP) tasks.

BERT is pre-trained on a massive corpus, allowing it to grasp language structure and context. This pre-trained model can then be fine-tuned for specific tasks such as sentiment analysis or question answering. Fine-tuning BERT for Q&A tasks involves adjusting the model to predict the start and end positions of the answer in a given passage for a provided question (extractive question answering). The following steps outline the process of fine-tuning BERT for these tasks:

1. **🌱 Dataset Preparation**:
    - Define each dataset item with a question, a passage (or context), and the start and end positions of the answer within the passage as the label.
    - Tokenize both the question and passage into subwords using BERT's tokenizer. Separate the question from the passage using the `[SEP]` token and start the input sequence with the `[CLS]`
      token.
    - Mark the question as segment `A` (or `0`) and the context as segment `B` (or `1`). Use this information to learn different embeddings for each segment, which are added to the word
      embeddings.
1. **🪡 Model Modification**:
    - Extract embeddings for each token in the sequence from the pre-trained BERT model.
    - Add a dense (fully connected) layer on top of BERT, with two output nodes: one for predicting the start position and one for predicting the end position of the answer in the passage (see
      [code](https://github.com/huggingface/transformers/blob/c385de24414e4ec6125ee14c46c128bfe70ecb66/src/transformers/models/bert/modeling_bert.py#L1803)).
1. **🎯 Training Objective**:
    - Output a score for each token in the passage, indicating how likely that token is the start of the answer, and another score for the end.
    - Apply a SoftMax function over the sequence to get a probability distribution for the start and end positions.
    - Use the sum of the negative log likelihood of the correct start and end positions as the loss function.
1. **🚀 Training**:
    - Initialize training with pre-trained BERT weights.
    - Apply a smaller learning rate (e.g., 2e-5 or 3e-5) since BERT is already pre-trained. Avoid using a larger learning rate, as it may cause the model to diverge.
    - Fine-tune the model on the Q&A dataset for several epochs, stopping when validation performance plateaus or decreases.
1. **✨ Inference**:
    - Tokenize a new question and passage, and add the special `[CLS]` and `[SEP]` tokens.
    - Feed the tokens into the fine-tuned BERT model to get scores for the start and end positions of the answer.
    - Select the tokens between the predicted start and end positions as the final answer.
    - Apply constraints such as ensuring the end position is after the start and limiting the maximum answer length.
1. **🎉 Evaluation**:
    - Measure performance using two common metrics: Exact Match (EM), the percentage of predictions that exactly match the ground truth, and the F1 score, which accounts for partial
      matches by considering the overlapping words between the prediction and the ground truth.
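The training objective from step 3 can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions: the scores below are made-up values for a 5-token sequence, and `log_softmax` and `span_loss` are hypothetical helper names, not part of any library. Some implementations average the two terms instead of summing them, which only rescales the loss.

```python
import math

def log_softmax(scores):
    """Return log-probabilities for a list of raw scores (numerically stable)."""
    peak = max(scores)
    log_norm = peak + math.log(sum(math.exp(s - peak) for s in scores))
    return [s - log_norm for s in scores]

def span_loss(start_scores, end_scores, start_pos, end_pos):
    """Sum of the negative log-likelihoods of the true start and end positions."""
    start_nll = -log_softmax(start_scores)[start_pos]
    end_nll = -log_softmax(end_scores)[end_pos]
    return start_nll + end_nll

# Hypothetical scores for a 5-token sequence whose answer spans tokens 2..3
start_scores = [0.1, 0.2, 3.0, 0.3, 0.1]
end_scores = [0.1, 0.1, 0.4, 2.5, 0.2]
loss = span_loss(start_scores, end_scores, start_pos=2, end_pos=3)
print(f"loss = {loss:.4f}")
```

The loss shrinks as the scores at the correct positions grow relative to the others, which is what fine-tuning pushes the model towards.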

Let's begin by implementing each step, one by one, using the Hugging Face 🤗 ecosystem. First, the imports:

In [2]:
import collections
from functools import partial

import evaluate
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import pipeline
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForQuestionAnswering



### Approach

- **Model Selection:** We use Hugging Face's ecosystem to fine-tune the pre-trained BERT model (`bert-base-uncased`) on the SQuAD dataset.

- **Evaluation:** We evaluate the model with the Exact Match (EM) and F1 metrics to assess how well it identifies the correct answer span in the context.

- **SQuAD:** 

    - Stanford Question Answering Dataset (SQuAD) is a reading comprehension dataset consisting of questions posed by crowdworkers on a set of Wikipedia articles, where the answer to every question is a segment of text, or span, from the corresponding reading passage. In SQuAD 2.0, a question might also be unanswerable.

    - SQuAD 1.1, the version used in this notebook, contains 100,000+ question-answer pairs on 500+ articles.

### Access SQuAD Dataset via Hugging Face Hub

Let's set the IDs for the model and the dataset. We will download both of them from the Hugging Face Hub, using the `datasets` and `transformers` libraries:

In [3]:
DATASET_ID = "rajpurkar/squad"
MODEL_ID = "google-bert/bert-base-uncased"

### Data Processing

In this section, we will download the dataset, cache it locally, and preprocess it into the format described in the introduction. Our goal is to produce examples that contain:
- The tokenized question and context.
- The start position of the answer within the context.
- The end position of the answer within the context.
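As a rough illustration of this target format, the sketch below lays out a toy question/context pair by hand. It assumes whitespace "tokens" instead of real WordPiece subwords, and `build_input` is a hypothetical helper written for this example; the actual preprocessing later in the notebook relies on the BERT tokenizer.

```python
def build_input(question_tokens, context_tokens):
    """Lay out a question/context pair the way BERT expects it."""
    tokens = ["[CLS]"] + question_tokens + ["[SEP]"] + context_tokens + ["[SEP]"]
    # Segment 0 covers [CLS], the question and the first [SEP];
    # segment 1 covers the context and the final [SEP].
    segment_ids = [0] * (len(question_tokens) + 2) + [1] * (len(context_tokens) + 1)
    return tokens, segment_ids

# Toy example with whitespace "tokens" (the real tokenizer produces subwords)
question = "who visited jacksonville ?".split()
context = "grover cleveland visited jacksonville in 1888 .".split()
tokens, segment_ids = build_input(question, context)

# Token-level labels: the answer "grover cleveland" is the first two context tokens
context_offset = len(question) + 2  # skip [CLS], the question and the first [SEP]
start_position = context_offset
end_position = context_offset + 1

print(tokens)
print(segment_ids)
print(tokens[start_position : end_position + 1])
```

The labels are token indices into the full input sequence, not character offsets into the context, which is why the preprocessing has to translate between the two.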

First, let's download and examine the dataset:

In [4]:
# load the SQuAD dataset
data = load_dataset(DATASET_ID)
data

Generating train split: 100%|██████████| 87599/87599 [00:00<00:00, 576945.59 examples/s]
Generating validation split: 100%|██████████| 10570/10570 [00:00<00:00, 423282.79 examples/s]


DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

The dataset has two splits: a `train` split with `87,599` rows, and a `validation` split with `10,570` rows. Each row includes a unique identifier, a title, the context, the question, and one or more possible answers. We will handle each split slightly differently. You’ll see why later, but for now, let's focus on processing the `train` split.

Since we need to tokenize the sequences, let's begin by loading the BERT tokenizer:

In [5]:
# load the BERT tokenizer
# set `clean_up_tokenization_spaces` to False to keep the tokenization spaces
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID, clean_up_tokenization_spaces=False)

Next, let's structure the examples in the desired format. First, we tokenize the questions and the context. Then, using the answer and sequence IDs (or segment IDs, as mentioned in the introduction), we identify the start and end positions of the answer within the context. Finally, we will apply the preprocessing function to each row in the `train` set and discard the columns we don't need:

In [6]:
def preprocess_train_examples(examples, tokenizer, max_length, stride):
    """Process the training split of the SQuAD dataset.

    Process the training split of the SQuAD dataset to include tokenized questions
    and context, as well as the start and end positions of the answer within the context.

    Args:
        examples: A batch of rows from the dataset (the function is applied with `batched=True`).
        tokenizer: The BERT tokenizer to be used.
        max_length: The maximum length of the input sequence. If exceeded, truncate the second
            sentence of a pair (or a batch of pairs) to fit within the limit.
        stride: The number of tokens to retain from the end of a truncated sequence, allowing
            for overlap between truncated and overflowing sequences.

    Returns:
        The processed example.
    """
    # Tokenize the questions and context sequences
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
      questions,
      examples["context"],
      truncation="only_second",
      padding="max_length",
      stride=stride,
      max_length=max_length,
      return_offsets_mapping=True,
      return_overflowing_tokens=True,
    )

    answers = examples["answers"]
    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")

    start_positions = []
    end_positions = []

    # find the start and end positions of the answer within the context
    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # if the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions

    return inputs
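The position search inside the function can be tried in isolation. The sketch below is a simplified stand-in, not the notebook's code: it builds word-level offsets with a regular expression instead of the tokenizer's offset mapping, and `char_span_to_token_span` is a hypothetical helper that returns `None` unless the answer lies fully inside the tokens.

```python
import re

def char_span_to_token_span(offsets, start_char, end_char):
    """Map a character span to (first_token, last_token) using token offsets.

    Returns None if the span is not fully covered by the tokens.
    """
    if offsets[0][0] > start_char or offsets[-1][1] < end_char:
        return None
    # First token: the last one that starts at or before start_char
    idx = 0
    while idx < len(offsets) and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1
    # Last token: the first one (from the right) that ends at or after end_char
    idx = len(offsets) - 1
    while idx >= 0 and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return start_token, end_token

context = "President Grover Cleveland attended the exposition in 1888."
answer_text = "Grover Cleveland"
start_char = context.index(answer_text)
end_char = start_char + len(answer_text)

# Stand-in for the tokenizer's offset mapping: one (start, end) pair per word
offsets = [m.span() for m in re.finditer(r"\w+|[^\w\s]", context)]
start_token, end_token = char_span_to_token_span(offsets, start_char, end_char)
print(start_token, end_token)
print(context[offsets[start_token][0] : offsets[end_token][1]])
```

Slicing the context with the offsets of the two tokens recovers the answer text, which is the same trick the evaluation code uses later to turn predicted positions back into a string.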

In [7]:
preprocess_train_data = partial(
    preprocess_train_examples, tokenizer=tokenizer, max_length=384, stride=128)
processed_train_data = data["train"].map(preprocess_train_data, batched=True, remove_columns=data["train"].column_names)
processed_train_data

Map: 100%|██████████| 87599/87599 [00:25<00:00, 3378.06 examples/s]


Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
    num_rows: 88524
})
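Note that the processed split has more rows (88,524) than the raw split (87,599): with `return_overflowing_tokens=True`, a context that does not fit into `max_length` is split into several overlapping windows, each of which becomes its own row. The sketch below illustrates the idea on a plain list. It is a simplification under stated assumptions: `split_into_windows` and `max_context_len` are hypothetical names, and the budget here applies to the context only, whereas the tokenizer's `max_length` also counts the question and special tokens.

```python
def split_into_windows(context_tokens, max_context_len, stride):
    """Split a token list into overlapping windows.

    Consecutive windows share `stride` tokens, so an answer cut at the edge of
    one window can still appear whole in the next.
    """
    if stride >= max_context_len:
        raise ValueError("stride must be smaller than max_context_len")
    windows = []
    start = 0
    while True:
        windows.append(context_tokens[start : start + max_context_len])
        if start + max_context_len >= len(context_tokens):
            break
        start += max_context_len - stride
    return windows

tokens = list(range(10))  # stand-in for 10 context tokens
windows = split_into_windows(tokens, max_context_len=4, stride=2)
for window in windows:
    print(window)
```

Because each window is labeled independently, a window that does not contain the answer receives the `(0, 0)` label pointing at the `[CLS]` token.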

The processed `train` split is now ready for fine-tuning BERT. Moving on to model evaluation, the preprocessing step for the `validation` split is almost identical. However, we also need to retain the ID of each row so that we can later evaluate the model's performance by reconstructing the actual answer text and computing the Exact Match (EM) and F1 scores:

In [8]:
def preprocess_valid_examples(examples, tokenizer, max_length, stride):
    """Process the validation split of the SQuAD dataset.

    Process the validation split of the SQuAD dataset to include the unique ID of each row,
    the tokenized questions and context, as well as the start and end positions of the answer
    within the context.

    Args:
        examples: A batch of rows from the dataset (the function is applied with `batched=True`).
        tokenizer: The BERT tokenizer to be used.
        max_length: The maximum length of the input sequence. If exceeded, truncate the second
            sentence of a pair (or a batch of pairs) to fit within the limit.
        stride: The number of tokens to retain from the end of a truncated sequence, allowing
            for overlap between truncated and overflowing sequences.

    Returns:
        The processed example.
    """
    # Tokenize the questions and context sequences
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
      questions,
      examples["context"],
      truncation="only_second",
      padding="max_length",
      stride=stride,
      max_length=max_length,
      return_offsets_mapping=True,
      return_overflowing_tokens=True,
    )

    example_ids = []
    answers = examples["answers"]
    offset_mapping = inputs["offset_mapping"]
    sample_map = inputs.pop("overflow_to_sample_mapping")

    start_positions = []
    end_positions = []

    # find the start and end positions of the answer within the context
    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # if the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["example_id"] = example_ids  # keep the unique ID of the example
    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions

    return inputs

In [9]:
preprocess_valid_data = partial(
    preprocess_valid_examples, tokenizer=tokenizer, max_length=384, stride=128)
processed_valid_data = data["validation"].map(preprocess_valid_data, batched=True, remove_columns=data["validation"].column_names)
processed_valid_data

Map: 100%|██████████| 10570/10570 [00:04<00:00, 2425.76 examples/s]


Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'example_id', 'start_positions', 'end_positions'],
    num_rows: 10784
})

### Model Fine-Tuning

We are now ready to fine-tune BERT for Question Answering. First, let's load the model and set the training arguments. Specifically:

- We will save a model checkpoint every 2000 steps.
- We will log the training process every 500 steps, allowing us to visualize it with TensorBoard and evaluate the experiment's performance.
- We will use mixed precision training with `bf16` (bfloat16, "Brain Floating Point") to accelerate training and reduce memory usage on the GPU.

In [10]:
model = AutoModelForQuestionAnswering.from_pretrained(MODEL_ID)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
training_args = TrainingArguments(
    output_dir='./checkpoints',
    logging_dir='./logs',
    eval_strategy="steps",
    logging_steps=500,
    logging_strategy="steps",
    save_steps=2000,
    save_strategy="steps",
    learning_rate=3e-5,
    num_train_epochs=2,
    weight_decay=0.01,
    bf16=True,
    per_device_train_batch_size=48,
    per_device_eval_batch_size=96,
)

In [12]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=processed_train_data,
    eval_dataset=processed_valid_data,
    tokenizer=tokenizer,
)



However, before we start fine-tuning the model, let's assess its performance using the Exact Match (EM) and F1 scores, as well as by answering a few sample questions from the validation split of the dataset. To evaluate the model, we need to perform some post-processing to reconstruct the actual answer from the model’s predictions and compare it to the ground truth answers. The following function handles that process:

In [13]:
def compute_metrics(start_logits, end_logits, features, examples, n_best=20, max_answer_length=50):
    """Compute the Exact Match (EM) and F1 score for the model's predictions.

    Reconstruct the actual text of the answer from the model's predictions and compare
    it to the ground truth for the validation dataset.

    Args:
        start_logits: Logits predicting the start position of the answer.
        end_logits: Logits predicting the end position of the answer.
        features: The processed validation dataset.
        examples: The raw validation dataset.
        n_best: The top-k answers to consider.
        max_answer_length: The maximum length of an answer to consider.

    Returns:
        The Exact Match (EM) and F1 score for the validation dataset.
    """

    metric = evaluate.load("squad")

    # keep a dictionary that maps examples to predictions through unique IDs
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            # keep a list of the top-k best predictions for the start and end position indexes
            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    # reconstruct the answer considering each prediction for the start and end positions
                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # select the answer with the best score based on the logit scores
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            # create a list with the predictions that contains the IDs and actual text
            # see: https://huggingface.co/spaces/evaluate-metric/squad
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    # create a list with the labels that contains the IDs and actual text
    # see: https://huggingface.co/spaces/evaluate-metric/squad
    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)
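The span selection at the core of this function can be looked at separately. The following is a minimal sketch on made-up logits, using a hypothetical `best_span` helper and plain Python lists rather than NumPy arrays; it applies the same two constraints (end not before start, bounded answer length) and the same score (start logit plus end logit).

```python
def best_span(start_logits, end_logits, n_best=20, max_answer_length=50):
    """Return (start, end, score) of the highest-scoring valid span, or None."""
    def top_k(scores):
        # indexes of the n_best highest scores, best first
        order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        return order[:n_best]

    best = None
    for start in top_k(start_logits):
        for end in top_k(end_logits):
            # skip spans that end before they start or that are too long
            if end < start or end - start + 1 > max_answer_length:
                continue
            score = start_logits[start] + end_logits[end]
            if best is None or score > best[2]:
                best = (start, end, score)
    return best

# Hypothetical logits: the best end token (index 0) precedes the best start token (index 2)
start_logits = [0.1, 0.3, 2.0, 0.2, 0.1]
end_logits = [2.5, 0.1, 0.4, 1.8, 0.2]
print(best_span(start_logits, end_logits))
```

Taking the two argmax positions independently would produce an invalid span here, which is why the function scores pairs of candidates jointly.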

Next, let's get the predictions of the model before fine-tuning (its question-answering head is still randomly initialized) and evaluate its performance:

In [14]:
predictions, _, _ = trainer.predict(processed_valid_data)
start_logits, end_logits = predictions
compute_metrics(start_logits, end_logits, processed_valid_data, data["validation"])

Downloading builder script: 100%|██████████| 4.53k/4.53k [00:00<00:00, 9.11MB/s]
Downloading extra modules: 100%|██████████| 3.32k/3.32k [00:00<00:00, 8.46MB/s]
100%|██████████| 10570/10570 [00:09<00:00, 1116.14it/s]


{'exact_match': 0.12298959318826869, 'f1': 7.388432885067443}

As expected, the computed scores (reported on a 0-100 scale) show that the model is extracting essentially random text spans from the context. Let’s observe this behavior in action by answering a few random questions from the dataset:

In [15]:
random_indexes = np.random.randint(0, len(data["validation"]), 3)
subdataset = data["validation"].select(random_indexes)
qa_pipe_untrained = pipeline("question-answering", model=model, tokenizer=tokenizer, device='cuda')

for row in subdataset:
    context = row["context"]
    question = row["question"]
    answer = qa_pipe_untrained(question=question, context=context)

    print(f"Context: \n\n {context} \n")
    print(f"Question: \n\n {question} \n")
    print(f"Answer: \n\n {answer['answer']} \n")
    print("--- \n")

Device set to use cuda


Context: 

 During Reconstruction and the Gilded Age, Jacksonville and nearby St. Augustine became popular winter resorts for the rich and famous. Visitors arrived by steamboat and later by railroad. President Grover Cleveland attended the Sub-Tropical Exposition in the city on February 22, 1888 during his trip to Florida. This highlighted the visibility of the state as a worthy place for tourism. The city's tourism, however, was dealt major blows in the late 19th century by yellow fever outbreaks. In addition, extension of the Florida East Coast Railway further south drew visitors to other areas. From 1893 to 1938 Jacksonville was the site of the Florida Old Confederate Soldiers and Sailors Home with a nearby cemetery. 

Question: 

 Which US President visited Jacksonville in 1888? 

Answer: 

 city's tourism, however, was dealt major blows in the 

--- 

Context: 

 Several commemorative events take place every year. Gatherings of thousands of people on the banks of the Vistula on Mi

Let's fine-tune BERT so it gives better answers:

In [16]:
trainer.train()

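| Step | Training Loss | Validation Loss | Model Preparation Time |
|------|---------------|-----------------|------------------------|
| 500  | 1.9015 | 1.236201 | 0.0085 |
| 1000 | 1.2463 | 1.10129  | 0.0085 |
| 1500 | 1.1357 | 1.049315 | 0.0085 |
| 2000 | 1.0039 | 1.049722 | 0.0085 |
| 2500 | 0.8424 | 1.028441 | 0.0085 |
| 3000 | 0.8303 | 1.024863 | 0.0085 |
| 3500 | 0.8254 | 1.01614  | 0.0085 |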


TrainOutput(global_step=3690, training_loss=1.0966898848370807, metrics={'train_runtime': 844.5176, 'train_samples_per_second': 209.644, 'train_steps_per_second': 4.369, 'total_flos': 3.4696551139946496e+16, 'train_loss': 1.0966898848370807, 'epoch': 2.0})
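The reported `global_step=3690` can be checked against the settings used above. The arithmetic below assumes a single device and no gradient accumulation, which the notebook does not state explicitly but which is consistent with the numbers:

```python
import math

num_features = 88524  # rows in processed_train_data
batch_size = 48       # per_device_train_batch_size
num_epochs = 2        # num_train_epochs

# Assumes one device and no gradient accumulation
steps_per_epoch = math.ceil(num_features / batch_size)
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)  # 1845 3690
```

With evaluation every 500 steps, this also explains why the last logged row is step 3500.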

### Model Evaluation

Finally, we need to evaluate the model on the `validation` split of the dataset. We will use two metrics to systematically assess its performance:
- **Exact Match (EM)**: Calculate the percentage of predictions that exactly match the ground truth.
- **F1 Score**: Measure partial matches by considering overlapping words between the prediction and the ground truth.
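To make the two metrics concrete, here is a simplified sketch for a single prediction and a single reference answer. It is not the `evaluate` implementation used in this notebook: the helper names are hypothetical, and the official SQuAD metric additionally takes the best score over all reference answers of a question and reports the average as a percentage.

```python
import collections
import re
import string

def normalize(text):
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction, truth):
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize(prediction) == normalize(truth))

def f1_score(prediction, truth):
    """Harmonic mean of token-level precision and recall."""
    pred_tokens = normalize(prediction).split()
    truth_tokens = normalize(truth).split()
    common = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)

print(exact_match("Grover Cleveland", "grover cleveland"))
print(f1_score("President Grover Cleveland", "Grover Cleveland"))
```

In the second call the prediction contains one extra word, so it is not an exact match but still earns partial credit: precision is 2/3, recall is 1, and F1 is 0.8.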

In [17]:
predictions, _, _ = trainer.predict(processed_valid_data)
start_logits, end_logits = predictions
compute_metrics(start_logits, end_logits, processed_valid_data, data["validation"])

100%|██████████| 10570/10570 [00:09<00:00, 1071.24it/s]


{'exact_match': 79.80132450331126, 'f1': 87.53663335317574}

Let's also provide an answer for the same random samples:

In [18]:
qa_pipe = pipeline("question-answering", model=model, tokenizer=tokenizer, device='cuda')

for row in subdataset:
    context = row["context"]
    question = row["question"]
    answer = qa_pipe(question=question, context=context)

    print(f"Context: \n\n {context} \n")
    print(f"Question: \n\n {question} \n")
    print(f"Answer: \n\n {answer['answer']} \n")
    print("--- \n")

Device set to use cuda


Context: 

 During Reconstruction and the Gilded Age, Jacksonville and nearby St. Augustine became popular winter resorts for the rich and famous. Visitors arrived by steamboat and later by railroad. President Grover Cleveland attended the Sub-Tropical Exposition in the city on February 22, 1888 during his trip to Florida. This highlighted the visibility of the state as a worthy place for tourism. The city's tourism, however, was dealt major blows in the late 19th century by yellow fever outbreaks. In addition, extension of the Florida East Coast Railway further south drew visitors to other areas. From 1893 to 1938 Jacksonville was the site of the Florida Old Confederate Soldiers and Sailors Home with a nearby cemetery. 

Question: 

 Which US President visited Jacksonville in 1888? 

Answer: 

 Grover Cleveland 

--- 

Context: 

 Several commemorative events take place every year. Gatherings of thousands of people on the banks of the Vistula on Midsummer’s Night for a festival called