**Fine-Tuning of BERT Models for Question Answering**
| Fine-Tune: SQuAD | Test: SQuAD 

This notebook uses pretrained Hugging Face BERT models and finetunes the models for question answering tasks. This notebook uses example code provided by Hugging Face for finetuning a model for question answering [[Hugging Face Question Answering]](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/question_answering.ipynb#scrollTo=HFASsisvIrIb). The code has been modified to finetune on the SQuAD 2.0 and Google NQ datasets.

In [1]:
!pip install datasets transformers

Collecting datasets
[?25l  Downloading https://files.pythonhosted.org/packages/54/90/43b396481a8298c6010afb93b3c1e71d4ba6f8c10797a7da8eb005e45081/datasets-1.5.0-py3-none-any.whl (192kB)
[K     |████████████████████████████████| 194kB 6.0MB/s 
[?25hCollecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/ed/d5/f4157a376b8a79489a76ce6cfe147f4f3be1e029b7144fa7b8432e8acb26/transformers-4.4.2-py3-none-any.whl (2.0MB)
[K     |████████████████████████████████| 2.0MB 7.6MB/s 
Collecting huggingface-hub<0.1.0
  Downloading https://files.pythonhosted.org/packages/af/07/bf95f398e6598202d878332280f36e589512174882536eb20d792532a57d/huggingface_hub-0.0.7-py3-none-any.whl
Collecting xxhash
[?25l  Downloading https://files.pythonhosted.org/packages/e7/27/1c0b37c53a7852f1c190ba5039404d27b3ae96a55f48203a74259f8213c9/xxhash-2.0.0-cp37-cp37m-manylinux2010_x86_64.whl (243kB)
[K     |████████████████████████████████| 245kB 21.0MB/s 
Collecting fsspec
[?25l  Downloading http

In [2]:
########### INPUT ###########
# Input the base model, batch size, and whether or not the dataset contains non-answerable questions (squad_v2 = True)
model_checkpoint = "bert-base-uncased"
batch_size = 16
squad_v2 = True

Load Data

In [3]:
# import the load_dataset and load_metric for loading and evaluation of datasets
from datasets import load_dataset, load_metric

In [4]:
########### INPUT ###########
# Load the dataset to finetune model
import json
import pandas as pd

datasets = load_dataset('json', data_files={'train': '/content/drive/MyDrive/ColabNotebooks/data/squad-reduce-8500_train.jsonl',
                                            'validation': '/content/drive/MyDrive/ColabNotebooks/data/squad-reduce_validation.jsonl'})

Using custom data configuration default-d095c4e712fef9cc


Downloading and preparing dataset json/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/json/default-d095c4e712fef9cc/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02...


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset json downloaded and prepared to /root/.cache/huggingface/datasets/json/default-d095c4e712fef9cc/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02. Subsequent calls will reuse this data.


In [5]:
# view dataset object
datasets

DatasetDict({
    train: Dataset({
        features: ['answers', 'context', 'id', 'question'],
        num_rows: 8563
    })
    validation: Dataset({
        features: ['answers', 'context', 'id', 'question'],
        num_rows: 1204
    })
})

In [6]:
# view example from dataset
datasets["train"][0]

{'answers': {'answer_start': [], 'text': []},
 'context': 'Texts on architecture have been written since ancient time. These texts provided both general advice and specific formal prescriptions or canons. Some examples of canons are found in the writings of the 1st-century BCE Roman Architect Vitruvius. Some of the most important early examples of canonic architecture are religious.',
 'id': '5acfa27f77cf76001a68565f',
 'question': 'Who wrote canons in 2nd century BCE?'}

In [7]:
# code to view random examples in pandas
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=3):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [8]:
show_random_elements(datasets["train"])

Unnamed: 0,answers,context,id,question
0,"{'answer_start': [484], 'text': ['John Wesley's brother, Charles']}","Soteriologically, most Methodists are Arminian, emphasizing that Christ accomplished salvation for every human being, and that humans must exercise an act of the will to receive it (as opposed to the traditional Calvinist doctrine of monergism). Methodism is traditionally low church in liturgy, although this varies greatly between individual congregations; the Wesleys themselves greatly valued the Anglican liturgy and tradition. Methodism is known for its rich musical tradition; John Wesley's brother, Charles, was instrumental in writing much of the hymnody of the Methodist Church, and many other eminent hymn writers come from the Methodist tradition.",5731f248e17f3d1400422566,Who wrote most of the Methodist hymns?
1,"{'answer_start': [233], 'text': ['13 million km2']}","Qing China reached its largest extent during the 18th century, when it ruled China proper (eighteen provinces) as well as the areas of present-day Northeast China, Inner Mongolia, Outer Mongolia, Xinjiang and Tibet, at approximately 13 million km2 in size. There were originally 18 provinces, all of which in China proper, but later this number was increased to 22, with Manchuria and Xinjiang being divided or turned into provinces. Taiwan, originally part of Fujian province, became a province of its own in the late 19th century, but was ceded to the Empire of Japan in 1895 following the First Sino-Japanese War. In addition, many surrounding countries, such as Korea (Joseon dynasty), Vietnam frequently paid tribute to China during much of this period. Khanate of Kokand were forced to submit as protectorate and pay tribute to the Qing dynasty in China between 1774 and 1798.",5731579e497a881900248e30,How many kilometers was Qing China at its height?
2,"{'answer_start': [121], 'text': ['John of Damascus']}","Many Muslims criticized the Umayyads for having too many non-Muslim, former Roman administrators in their government. St John of Damascus was also a high administrator in the Umayyad administration. As the Muslims took over cities, they left the peoples political representatives and the Roman tax collectors and the administrators. The taxes to the central government were calculated and negotiated by the peoples political representatives. The Central government got paid for the services it provided and the local government got the money for the services it provided. Many Christian cities also used some of the taxes on maintain their churches and run their own organizations. Later the Umayyads were criticized by some Muslims for not reducing the taxes of the people who converted to Islam. These new converts continues to pay the same taxes that were previously negotiated.",571aa92010f8ca140030529d,What Christian saint was also an Umayyad administrator?


Preprocess Training Data

In [9]:
# import the correct tokenizer for the model architecture
from transformers import AutoTokenizer    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




In [10]:
# verify that the tokenizer is a fast tokenizer
import transformers
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

In [11]:
# run a check to verify that the tokenizer is working for an example questions and answer
tokenizer("Is this tokenizer working?", "Yes, the tokenizer is working correctly")

{'input_ids': [101, 2003, 2023, 19204, 17629, 2551, 1029, 102, 2748, 1010, 1996, 19204, 17629, 2003, 2551, 11178, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [12]:
########### INPUT ###########
# Length will be truncated to handle long contexts
# Set the max length (questions and context) and stride (context overlap)
max_length = 384
doc_stride = 128 

In [13]:
# verify that the truncation is working correctly by finding an example
for i, example in enumerate(datasets["train"]):
    if len(tokenizer(example["question"], example["context"])["input_ids"]) > 384:
        break
example = datasets["train"][i]

Token indices sequence length is longer than the specified maximum sequence length for this model (705 > 512). Running this sequence through the model will result in indexing errors


In [14]:
# check to see what the length of the example is without truncation (should be greater than max_length)
len(tokenizer(example["question"], example["context"])["input_ids"])

705

In [15]:
# tokenizer should return always return the question plus truncated contexts
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    stride=doc_stride
)

In [16]:
# verify the length for the mulitple examples are provided for tokenized example
[len(x) for x in tokenized_example["input_ids"]]

[384, 384, 217]

In [17]:
# decode the tokenized example to verify that we have a question plus the truncated context
for x in tokenized_example["input_ids"][:]:
    print(tokenizer.decode(x))

[CLS] what was us's biggest selling album? [SEP] the mandolin has been used extensively in the traditional music of england and scotland for generations. simon mayor is a prominent british player who has produced six solo albums, instructional books and dvds, as well as recordings with his mandolin quartet the mandolinquents. the instrument has also found its way into british rock music. the mandolin was played by mike oldfield ( and introduced by vivian stanshall ) on oldfield's album tubular bells, as well as on a number of his subsequent albums ( particularly prominently on hergest ridge ( 1974 ) and ommadawn ( 1975 ) ). it was used extensively by the british folk - rock band lindisfarne, who featured two members on the instrument, ray jackson and simon cowe, and whose " fog on the tyne " was the biggest selling uk album of 1971 - 1972. the instrument was also used extensively in the uk folk revival of the 1960s and 1970s with bands such as fairport convention and steeleye span taki

In [18]:
# use the tokenizer to map the offset for locating the answer
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    stride=doc_stride
)
print(tokenized_example["offset_mapping"][0][:100])

[(0, 0), (0, 4), (5, 8), (9, 11), (11, 12), (12, 13), (14, 21), (22, 29), (30, 35), (35, 36), (0, 0), (0, 3), (4, 12), (13, 16), (17, 21), (22, 26), (27, 38), (39, 41), (42, 45), (46, 57), (58, 63), (64, 66), (67, 74), (75, 78), (79, 87), (88, 91), (92, 103), (103, 104), (105, 110), (111, 116), (117, 119), (120, 121), (122, 131), (132, 139), (140, 146), (147, 150), (151, 154), (155, 163), (164, 167), (168, 172), (173, 179), (179, 180), (181, 194), (195, 200), (201, 204), (205, 209), (209, 210), (211, 213), (214, 218), (219, 221), (222, 232), (233, 237), (238, 241), (242, 250), (251, 258), (259, 262), (263, 271), (271, 276), (276, 277), (277, 278), (279, 282), (283, 293), (294, 297), (298, 302), (303, 308), (309, 312), (313, 316), (317, 321), (322, 329), (330, 334), (335, 340), (340, 341), (342, 345), (346, 354), (355, 358), (359, 365), (366, 368), (369, 373), (374, 377), (377, 382), (383, 384), (384, 387), (388, 398), (399, 401), (402, 408), (409, 413), (413, 416), (416, 418), (418, 41

In [19]:
# verify that the mapping is working correctly
first_token_id = tokenized_example["input_ids"][0][1]
offsets = tokenized_example["offset_mapping"][0][1]
print(tokenizer.convert_ids_to_tokens([first_token_id])[0], example["question"][offsets[0]:offsets[1]])

what What


In [20]:
# use the tokenizer to find sequence ids for locating the position of the question and answer
sequence_ids = tokenized_example.sequence_ids()
print(sequence_ids)

[None, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [21]:
# show example data
example

{'answers': {'answer_start': [], 'text': []},
 'context': 'The mandolin has been used extensively in the traditional music of England and Scotland for generations. Simon Mayor is a prominent British player who has produced six solo albums, instructional books and DVDs, as well as recordings with his mandolin quartet the Mandolinquents. The instrument has also found its way into British rock music. The mandolin was played by Mike Oldfield (and introduced by Vivian Stanshall) on Oldfield\'s album Tubular Bells, as well as on a number of his subsequent albums (particularly prominently on Hergest Ridge (1974) and Ommadawn (1975)). It was used extensively by the British folk-rock band Lindisfarne, who featured two members on the instrument, Ray Jackson and Simon Cowe, and whose "Fog on the Tyne" was the biggest selling UK album of 1971-1972. The instrument was also used extensively in the UK folk revival of the 1960s and 1970s with bands such as Fairport Convention and Steeleye Span taking 

In [22]:
# identify the first and last token of the answer in the context or return no answer

# locate the start and end character of answer
answers = example["answers"]
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])

# Start token index of the current span in the text.
token_start_index = 0
while sequence_ids[token_start_index] != 1:
    token_start_index += 1

# End token index of the current span in the text.
token_end_index = len(tokenized_example["input_ids"][0]) - 1
while sequence_ids[token_end_index] != 1:
    token_end_index -= 1

# Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
offsets = tokenized_example["offset_mapping"][0]
if (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
    # Move the token_start_index and token_end_index to the two ends of the answer.
    # Note: we could go after the last offset if the answer is the last word (edge case).
    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
        token_start_index += 1
    start_position = token_start_index - 1
    while offsets[token_end_index][1] >= end_char:
        token_end_index -= 1
    end_position = token_end_index + 1
    print(start_position, end_position)
else:
    print("The answer is not in this feature.")

IndexError: ignored

In [23]:
# verify that the start and end tokens produced are the correct answer
print(tokenizer.decode(tokenized_example["input_ids"][0][start_position: end_position+1]))
print(answers["text"][0])

NameError: ignored

In [24]:
# To make this notebook generalizable to any model, we account for the special case where the model expects padding on the left
pad_on_right = tokenizer.padding_side == "right"

In [25]:
# This function combines the above methods by tokenizing each example with truncation and padding

def prepare_train_features(examples):
    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [26]:
# the function can work on multiple features. Verify that the tokenization is working correctly
features = prepare_train_features(datasets['train'][:2])
features

{'input_ids': [[101, 2040, 2626, 21975, 1999, 3416, 2301, 10705, 1029, 102, 6981, 2006, 4294, 2031, 2042, 2517, 2144, 3418, 2051, 1012, 2122, 6981, 3024, 2119, 2236, 6040, 1998, 3563, 5337, 20422, 2015, 2030, 21975, 1012, 2070, 4973, 1997, 21975, 2024, 2179, 1999, 1996, 7896, 1997, 1996, 3083, 1011, 2301, 10705, 3142, 4944, 6819, 16344, 2226, 26055, 1012, 2070, 1997, 1996, 2087, 2590, 2220, 4973, 1997, 9330, 2594, 4294, 2024, 3412, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [27]:
# apply the function on all elements of all the splits in the dataset including training, validation, and testing data
# remove the old columns since the preprocessing changes the number of samples
# results are cached. Pass "load_from_cache_file=False" to force the preprocessing to be applied again
tokenized_datasets = datasets.map(
    prepare_train_features, 
    batched=True, 
    remove_columns=datasets["train"].column_names
)

HBox(children=(FloatProgress(value=0.0, max=9.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




Fine-Tune the Model

In [28]:
# import Pytorch pretrained model for question answering
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer

# from_pretrained method downloads and caches the model
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

# warning regarding not using weights and layers is normal. we are removing the 
# masked language modeling head to pretrain the model on the QA task for which
# we do not have pretrained weights and requires fine-tuning

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

In [29]:
########### INPUT ###########
# training arguments is a class that contains the attributes to customize training
# set the folder name f"model-dataset", which will be used to save checkpoints
# set the learning_rate, number of epochs, and weight_decay
# batch_size has been set at the beginning of the notebook 
args = TrainingArguments(
    f"bert-squad-squad",
    evaluation_strategy = "epoch",
    learning_rate = 2e-5,
    per_device_train_batch_size = batch_size,
    per_device_eval_batch_size = batch_size,
    num_train_epochs = 3,
    weight_decay = 0.01,
)

In [30]:
# import a default data collator
from transformers import default_data_collator

# set the data_collator to the default data collator
data_collator = default_data_collator

In [31]:
# pass all of the training arguments and datasets to the trainer
trainer = Trainer(
    model,
    args,
    train_dataset = tokenized_datasets["train"],
    eval_dataset = tokenized_datasets["validation"],
    data_collator = data_collator,
    tokenizer = tokenizer,
)

In [32]:
# finetune the model by calling train method
# running this cell will take time.
trainer.train()

Epoch,Training Loss,Validation Loss,Runtime,Samples Per Second
1,2.5376,1.751771,15.9756,76.992
2,1.4838,1.604503,15.9757,76.992
3,1.1421,1.623546,15.9727,77.006


TrainOutput(global_step=1623, training_loss=1.6699216775665, metrics={'train_runtime': 1132.618, 'train_samples_per_second': 1.433, 'total_flos': 6506076900907008.0, 'epoch': 3.0, 'init_mem_cpu_alloc_delta': 339949, 'init_mem_gpu_alloc_delta': 436709888, 'init_mem_cpu_peaked_delta': 18306, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 701144, 'train_mem_gpu_alloc_delta': 1308331008, 'train_mem_cpu_peaked_delta': 95141232, 'train_mem_gpu_peaked_delta': 8290012160})

In [33]:
########### INPUT ###########
# save the model. input the model name ("model-dataset-trained")
trainer.save_model("bert-squad-squad-trained")

Evaluation

In [34]:
# the validation features will need to be re-processed similar to the training features
# the processing will also need to check if the output span is inside the context (and not in the question)
# it will also need to retrieve the text inside

def prepare_validation_features(examples):
    # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    # We keep the example_id that gave us this feature and we will store the offset mappings.
    tokenized_examples["example_id"] = []

    for i in range(len(tokenized_examples["input_ids"])):
        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)
        context_index = 1 if pad_on_right else 0

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["id"][sample_index])

        # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
        # position is part of the context or not.
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == context_index else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]

    return tokenized_examples

In [35]:
# apply the function to validation set 
# remove the old columns since the preprocessing changes the number of samples
validation_features = datasets["validation"].map(
    prepare_validation_features,
    batched=True,
    remove_columns=datasets["validation"].column_names
)

HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




In [36]:
# extract the predictions for all features using method trainer.predict
raw_predictions = trainer.predict(validation_features)

In [37]:
# trainer hides columns not used by the model. the columns needed for post-processing are set back 
validation_features.set_format(type=validation_features.format["type"], columns=list(validation_features.features.keys()))

In [38]:
########### INPUT ###########
# to classify answers, we use the score obtained by adding the start and end logits
# limit the number of possible answers by setting n_best_size
# limit the length of the answer by setting max_answer_length
n_best_size = 20
max_answer_length = 30

In [39]:
# get the output logits from trainer
import torch

for batch in trainer.get_eval_dataloader():
    break
batch = {k: v.to(trainer.args.device) for k, v in batch.items()}
with torch.no_grad():
    output = trainer.model(**batch)
output.keys()

odict_keys(['loss', 'start_logits', 'end_logits'])

In [40]:
# code to verify the score and corresponding text are working correctly
import numpy as np

start_logits = output.start_logits[0].cpu().numpy()
end_logits = output.end_logits[0].cpu().numpy()
offset_mapping = validation_features[0]["offset_mapping"]
# The first feature comes from the first example. For the more general case, we will need to be match the example_id to
# an example index
context = datasets["validation"][0]["context"]

# Gather the indices the best start/end logits:
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
valid_answers = []
for start_index in start_indexes:
    for end_index in end_indexes:
        # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
        # to part of the input_ids that are not in the context.
        if (
            start_index >= len(offset_mapping)
            or end_index >= len(offset_mapping)
            or offset_mapping[start_index] is None
            or offset_mapping[end_index] is None
        ):
            continue
        # Don't consider answers with a length that is either < 0 or > max_answer_length.
        if end_index < start_index or end_index - start_index + 1 > max_answer_length:
            continue
        if start_index <= end_index: # We need to refine that test to check the answer is inside the context
            start_char = offset_mapping[start_index][0]
            end_char = offset_mapping[end_index][1]
            valid_answers.append(
                {
                    "score": start_logits[start_index] + end_logits[end_index],
                    "text": context[start_char: end_char]
                }
            )

valid_answers = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[:n_best_size]
valid_answers

[{'score': 5.5691347, 'text': 'southern California'},
 {'score': 3.6082978, 'text': 'Another definition for southern California'},
 {'score': 2.17913,
  'text': 'southern California uses Point Conception and the Tehachapi Mountains as the northern boundary'},
 {'score': 1.8905233, 'text': 'definition for southern California'},
 {'score': 1.4352517,
  'text': 'southern California uses Point Conception and the Tehachapi Mountains as the northern boundary.'},
 {'score': 1.1835575,
  'text': 'San Luis Obispo, Kern, and San Bernardino counties. Another definition for southern California'},
 {'score': 1.1483896, 'text': 'California'},
 {'score': 0.6638926,
  'text': 'southern California uses Point Conception and the Tehachapi Mountains'},
 {'score': 0.21829328,
  'text': 'Another definition for southern California uses Point Conception and the Tehachapi Mountains as the northern boundary'},
 {'score': 0.19717634, 'text': 'northern boundary'},
 {'score': 0.038207054, 'text': 'southern Califor

In [41]:
# view the actual answer
datasets["validation"][0]["answers"]

{'answer_start': [], 'text': []}

In [42]:
# apply the process above to all features by mapping between examples and their corresponding features
import collections

examples = datasets["validation"]
features = validation_features

example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
features_per_example = collections.defaultdict(list)
for i, feature in enumerate(features):
    features_per_example[example_id_to_index[feature["example_id"]]].append(i)

In [43]:
# to handle the non-answerable questions, we need to extract the score for the impossible answer
# the score is collected from minimum of the scores from the CLS token for each feature generated by the example
# the question is not answerable when that score is greater than the highest answerable score
from tqdm.auto import tqdm
import numpy as np

def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):
    all_start_logits, all_end_logits = raw_predictions
    # Build a map example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill.
    predictions = collections.OrderedDict()

    # Logging.
    print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    # Let's loop over all the examples!
    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated to the current example.
        feature_indices = features_per_example[example_index]

        min_null_score = None # Only used if squad_v2 is True.
        valid_answers = []
        
        context = example["context"]
        # Looping through all the features associated to the current example.
        for feature_index in feature_indices:
            # We grab the predictions of the model for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # This is what will allow us to map some the positions in our logits to span of texts in the original
            # context.
            offset_mapping = features[feature_index]["offset_mapping"]

            # Update minimum null prediction.
            cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or min_null_score < feature_null_score:
                min_null_score = feature_null_score

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char: end_char]
                        }
                    )
        
        if len(valid_answers) > 0:
            best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
        else:
            # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
            # failure.
            best_answer = {"text": "", "score": 0.0}
        
        # Let's pick our final answer: the best one or the null answer (only for squad_v2)
        if not squad_v2:
            predictions[example["id"]] = best_answer["text"]
        else:
            answer = best_answer["text"] if best_answer["score"] > min_null_score else ""
            predictions[example["id"]] = answer

    return predictions

In [44]:
# apply the postprocessing function to the raw predictions
final_predictions = postprocess_qa_predictions(datasets["validation"], validation_features, raw_predictions.predictions)

Post-processing 1204 example predictions split into 1230 features.


HBox(children=(FloatProgress(value=0.0, max=1204.0), HTML(value='')))




In [45]:
########### INPUT ###########
# load the metric from the datasets library
metric = load_metric("squad_v2")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2264.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=3182.0, style=ProgressStyle(description…




In [46]:
# format predictions and labels
if squad_v2:
    formatted_predictions = [{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in final_predictions.items()]
else:
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
references = [{"id": ex["id"], "answers": ex["answers"]} for ex in datasets["validation"]]
metric.compute(predictions=formatted_predictions, references=references)

{'HasAns_exact': 57.4468085106383,
 'HasAns_f1': 65.8527993602288,
 'HasAns_total': 611,
 'NoAns_exact': 43.17032040472176,
 'NoAns_f1': 43.17032040472176,
 'NoAns_total': 593,
 'best_exact': 50.91362126245847,
 'best_exact_thresh': 0.0,
 'best_f1': 54.87491174620679,
 'best_f1_thresh': 0.0,
 'exact': 50.415282392026576,
 'f1': 54.68111329659447,
 'total': 1204}