**Fine-Tuning of BERT Models for Question Answering**
| Fine-Tune: SQuAD Reduced | Test: SQuAD 

This notebook uses pretrained Hugging Face BERT models and finetunes the models for question answering tasks. This notebook uses example code provided by Hugging Face for finetuning a model for question answering [[Hugging Face Question Answering]](https://colab.research.google.com/github/huggingface/notebooks/blob/master/examples/question_answering.ipynb#scrollTo=HFASsisvIrIb). The code has been modified to finetune on the SQuAD 2.0 and Google NQ datasets.

In [1]:
!pip install datasets transformers

Collecting datasets
[?25l  Downloading https://files.pythonhosted.org/packages/54/90/43b396481a8298c6010afb93b3c1e71d4ba6f8c10797a7da8eb005e45081/datasets-1.5.0-py3-none-any.whl (192kB)
[K     |████████████████████████████████| 194kB 5.8MB/s 
[?25hCollecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/ed/d5/f4157a376b8a79489a76ce6cfe147f4f3be1e029b7144fa7b8432e8acb26/transformers-4.4.2-py3-none-any.whl (2.0MB)
[K     |████████████████████████████████| 2.0MB 53.3MB/s 
[?25hCollecting xxhash
[?25l  Downloading https://files.pythonhosted.org/packages/e7/27/1c0b37c53a7852f1c190ba5039404d27b3ae96a55f48203a74259f8213c9/xxhash-2.0.0-cp37-cp37m-manylinux2010_x86_64.whl (243kB)
[K     |████████████████████████████████| 245kB 59.1MB/s 
Collecting fsspec
[?25l  Downloading https://files.pythonhosted.org/packages/91/0d/a6bfee0ddf47b254286b9bd574e6f50978c69897647ae15b14230711806e/fsspec-0.8.7-py3-none-any.whl (103kB)
[K     |████████████████████████████████| 11

In [2]:
########### INPUT ###########
# Input the base model, batch size, and whether or not the dataset contains non-answerable questions (squad_v2 = True)
model_checkpoint = "bert-base-uncased"
batch_size = 16
squad_v2 = True

Load Data

In [3]:
# import the load_dataset and load_metric for loading and evaluation of datasets
from datasets import load_dataset, load_metric

In [4]:
########### INPUT ###########
# Load the dataset to finetune model
import json
import pandas as pd

datasets = load_dataset('json', data_files={'train': '/content/drive/MyDrive/ColabNotebooks/data/squad-reduce_train.jsonl',
                                            'validation': '/content/drive/MyDrive/ColabNotebooks/data/squad-reduce_validation.jsonl'})

Using custom data configuration default-5cc7df159dae7e29


Downloading and preparing dataset json/default (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/json/default-5cc7df159dae7e29/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02...


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset json downloaded and prepared to /root/.cache/huggingface/datasets/json/default-5cc7df159dae7e29/0.0.0/83d5b3a2f62630efc6b5315f00f20209b4ad91a00ac586597caee3a4da0bef02. Subsequent calls will reuse this data.


In [5]:
# view dataset object
datasets

DatasetDict({
    train: Dataset({
        features: ['answers', 'context', 'id', 'question'],
        num_rows: 19029
    })
    validation: Dataset({
        features: ['answers', 'context', 'id', 'question'],
        num_rows: 1204
    })
})

In [6]:
# view example from dataset
datasets["train"][0]

{'answers': {'answer_start': [343], 'text': ['thirteenth']},
 'context': '"4 Minutes" was released as the album\'s lead single and peaked at number three on the Billboard Hot 100. It was Madonna\'s 37th top-ten hit on the chart—it pushed Madonna past Elvis Presley as the artist with the most top-ten hits. In the UK she retained her record for the most number-one singles for a female artist; "4 Minutes" becoming her thirteenth. At the 23rd Japan Gold Disc Awards, Madonna received her fifth Artist of the Year trophy from Recording Industry Association of Japan, the most for any artist. To further promote the album, Madonna embarked on the Sticky & Sweet Tour; her first major venture with Live Nation. With a gross of $280 million, it became the highest-grossing tour by a solo artist then, surpassing the previous record Madonna set with the Confessions Tour; it was later surpassed by Roger Waters\' The Wall Live. It was extended to the next year, adding new European dates, and after it end

In [7]:
# code to view random examples in pandas
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=3):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [8]:
show_random_elements(datasets["train"])

Unnamed: 0,answers,context,id,question
0,"{'answer_start': [1098], 'text': ['that such malpractices be denounced and subject to sanctions and penalties under national legislation,']}","In 2012, the Pan American Health Organization (the North and South American branch of the World Health Organization) released a statement cautioning against services that purport to ""cure"" people with non-heterosexual sexual orientations as they lack medical justification and represent a serious threat to the health and well-being of affected people, and noted that the global scientific and professional consensus is that homosexuality is a normal and natural variation of human sexuality and cannot be regarded as a pathological condition. The Pan American Health Organization further called on governments, academic institutions, professional associations and the media to expose these practices and to promote respect for diversity. The World Health Organization affiliate further noted that gay minors have sometimes been forced to attend these ""therapies"" involuntarily, being deprived of their liberty and sometimes kept in isolation for several months, and that these findings were reported by several United Nations bodies. Additionally, the Pan American Health Organization recommended that such malpractices be denounced and subject to sanctions and penalties under national legislation, as they constitute a violation of the ethical principles of health care and violate human rights that are protected by international and regional agreements.",57100800b654c5140001f764,What did the Pan American Health Organization reccomend for those who used treatment to stop homosexuality?
1,"{'answer_start': [441], 'text': ['Transit']}","Other major employers in the city include Ordnance Survey, the UK's national mapping agency, whose headquarters is located in a new building on the outskirts of the city, opened in February 2011. The Lloyd's Register Group has announced plans to move its London marine operations to a specially developed site at the University of Southampton. The area of Swaythling is home to Ford's Southampton Assembly Plant, where the majority of their Transit models are manufactured. Closure of the plant in 2013 was announced in 2012, with the loss of hundreds of jobs.",56f8aea09e9bad19000a0303,Which model does Ford manufacture in their Southampton plant?
2,"{'answer_start': [0], 'text': ['Bal-musette']}","Bal-musette is a style of French music and dance that first became popular in Paris in the 1870s and 1880s; by 1880 Paris had some 150 dance halls in the working-class neighbourhoods of the city. Patrons danced the bourrée to the accompaniment of the cabrette (a bellows-blown bagpipe locally called a ""musette"") and often the vielle à roue (hurdy-gurdy) in the cafés and bars of the city. Parisian and Italian musicians who played the accordion adopted the style and established themselves in Auvergnat bars especially in the 19th arrondissement, and the romantic sounds of the accordion has since become one of the musical icons of the city. Paris became a major centre for jazz and still attracts jazz musicians from all around the world to its clubs and cafés.",5728fb086aef051400154934,What style of french music became populars in the 1870sto 1880s?


Preprocess Training Data

In [9]:
# import the correct tokenizer for the model architecture
from transformers import AutoTokenizer    
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=466062.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=28.0, style=ProgressStyle(description_w…




In [10]:
# verify that the tokenizer is a fast tokenizer
import transformers
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

In [11]:
# run a check to verify that the tokenizer is working for an example questions and answer
tokenizer("Is this tokenizer working?", "Yes, the tokenizer is working correctly")

{'input_ids': [101, 2003, 2023, 19204, 17629, 2551, 1029, 102, 2748, 1010, 1996, 19204, 17629, 2003, 2551, 11178, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [12]:
########### INPUT ###########
# Length will be truncated to handle long contexts
# Set the max length (questions and context) and stride (context overlap)
max_length = 384
doc_stride = 128 

In [13]:
# verify that the truncation is working correctly by finding an example
for i, example in enumerate(datasets["train"]):
    if len(tokenizer(example["question"], example["context"])["input_ids"]) > 384:
        break
example = datasets["train"][i]

In [14]:
# check to see what the length of the example is without truncation (should be greater than max_length)
len(tokenizer(example["question"], example["context"])["input_ids"])

477

In [15]:
# tokenizer should return always return the question plus truncated contexts
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    stride=doc_stride
)

In [16]:
# verify the length for the mulitple examples are provided for tokenized example
[len(x) for x in tokenized_example["input_ids"]]

[384, 233]

In [17]:
# decode the tokenized example to verify that we have a question plus the truncated context
for x in tokenized_example["input_ids"][:]:
    print(tokenizer.decode(x))

[CLS] when was the first ottoman madrasa built? [SEP] " the first ottoman medrese was created in iznik in 1331 and most ottoman medreses followed the traditions of sunni islam. " " when an ottoman sultan established a new medrese, he would invite scholars from the islamic world — for example, murad ii brought scholars from persia, such as ʻalaʼ al - din and fakhr al - din who helped enhance the reputation of the ottoman medrese ". this reveals that the islamic world was interconnected in the early modern period as they travelled around to other islamic states exchanging knowledge. this sense that the ottoman empire was becoming modernised through globalization is also recognised by hamadeh who says : " change in the eighteenth century as the beginning of a long and unilinear march toward westernisation reflects the two centuries of reformation in sovereign identity. " inalcık also mentions that while scholars from for example persia travelled to the ottomans in order to share their kno

In [18]:
# use the tokenizer to map the offset for locating the answer
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    stride=doc_stride
)
print(tokenized_example["offset_mapping"][0][:100])

[(0, 0), (0, 4), (5, 8), (9, 12), (13, 18), (19, 26), (27, 33), (33, 34), (35, 40), (40, 41), (0, 0), (0, 1), (1, 4), (5, 10), (11, 18), (19, 22), (22, 25), (25, 26), (27, 30), (31, 38), (39, 41), (42, 43), (43, 44), (44, 47), (48, 50), (51, 54), (54, 55), (56, 59), (60, 64), (65, 72), (73, 76), (76, 79), (79, 81), (82, 90), (91, 94), (95, 105), (106, 108), (109, 114), (115, 120), (120, 121), (121, 122), (123, 124), (124, 128), (129, 131), (132, 139), (140, 146), (147, 158), (159, 160), (161, 164), (165, 168), (168, 171), (171, 172), (172, 173), (174, 176), (177, 182), (183, 189), (190, 198), (199, 203), (204, 207), (208, 215), (216, 221), (221, 222), (222, 225), (226, 233), (233, 234), (235, 237), (237, 240), (241, 243), (244, 251), (252, 260), (261, 265), (266, 272), (272, 273), (274, 278), (279, 281), (282, 283), (283, 286), (286, 287), (288, 290), (290, 291), (291, 294), (295, 298), (299, 301), (301, 303), (303, 304), (305, 307), (307, 308), (308, 311), (312, 315), (316, 322), (323

In [19]:
# verify that the mapping is working correctly
first_token_id = tokenized_example["input_ids"][0][1]
offsets = tokenized_example["offset_mapping"][0][1]
print(tokenizer.convert_ids_to_tokens([first_token_id])[0], example["question"][offsets[0]:offsets[1]])

when When


In [20]:
# use the tokenizer to find sequence ids for locating the position of the question and answer
sequence_ids = tokenized_example.sequence_ids()
print(sequence_ids)

[None, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [22]:
# show example data
example

{'answers': {'answer_start': [51], 'text': ['1331']},
 'context': '"The first Ottoman Medrese was created in İznik in 1331 and most Ottoman medreses followed the traditions of Sunni Islam." "When an Ottoman sultan established a new medrese, he would invite scholars from the Islamic world—for example, Murad II brought scholars from Persia, such as ʻAlāʼ al-Dīn and Fakhr al-Dīn who helped enhance the reputation of the Ottoman medrese". This reveals that the Islamic world was interconnected in the early modern period as they travelled around to other Islamic states exchanging knowledge. This sense that the Ottoman Empire was becoming modernised through globalization is also recognised by Hamadeh who says: "Change in the eighteenth century as the beginning of a long and unilinear march toward westernisation reflects the two centuries of reformation in sovereign identity." İnalcık also mentions that while scholars from for example Persia travelled to the Ottomans in order to share their kno

In [23]:
# identify the first and last token of the answer in the context or return no answer

# locate the start and end character of answer
answers = example["answers"]
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])

# Start token index of the current span in the text.
token_start_index = 0
while sequence_ids[token_start_index] != 1:
    token_start_index += 1

# End token index of the current span in the text.
token_end_index = len(tokenized_example["input_ids"][0]) - 1
while sequence_ids[token_end_index] != 1:
    token_end_index -= 1

# Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
offsets = tokenized_example["offset_mapping"][0]
if (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
    # Move the token_start_index and token_end_index to the two ends of the answer.
    # Note: we could go after the last offset if the answer is the last word (edge case).
    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
        token_start_index += 1
    start_position = token_start_index - 1
    while offsets[token_end_index][1] >= end_char:
        token_end_index -= 1
    end_position = token_end_index + 1
    print(start_position, end_position)
else:
    print("The answer is not in this feature.")

25 26


In [24]:
# verify that the start and end tokens produced are the correct answer
print(tokenizer.decode(tokenized_example["input_ids"][0][start_position: end_position+1]))
print(answers["text"][0])

1331
1331


In [25]:
# To make this notebook generalizable to any model, we account for the special case where the model expects padding on the left
pad_on_right = tokenizer.padding_side == "right"

In [26]:
# This function combines the above methods by tokenizing each example with truncation and padding

def prepare_train_features(examples):
    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [27]:
# the function can work on multiple features. Verify that the tokenization is working correctly
features = prepare_train_features(datasets['train'][:2])
features

{'input_ids': [[101, 1018, 2781, 2150, 11284, 1005, 1055, 2029, 2193, 2028, 2309, 1999, 1996, 2866, 1029, 102, 1000, 1018, 2781, 1000, 2001, 2207, 2004, 1996, 2201, 1005, 1055, 2599, 2309, 1998, 6601, 2012, 2193, 2093, 2006, 1996, 4908, 2980, 2531, 1012, 2009, 2001, 11284, 1005, 1055, 23027, 2327, 1011, 2702, 2718, 2006, 1996, 3673, 1517, 2009, 3724, 11284, 2627, 12280, 17229, 2004, 1996, 3063, 2007, 1996, 2087, 2327, 1011, 2702, 4978, 1012, 1999, 1996, 2866, 2016, 6025, 2014, 2501, 2005, 1996, 2087, 2193, 1011, 2028, 3895, 2005, 1037, 2931, 3063, 1025, 1000, 1018, 2781, 1000, 3352, 2014, 14725, 1012, 2012, 1996, 13928, 2900, 2751, 5860, 2982, 1010, 11284, 2363, 2014, 3587, 3063, 1997, 1996, 2095, 5384, 2013, 3405, 3068, 2523, 1997, 2900, 1010, 1996, 2087, 2005, 2151, 3063, 1012, 2000, 2582, 5326, 1996, 2201, 1010, 11284, 11299, 2006, 1996, 15875, 1004, 4086, 2778, 1025, 2014, 2034, 2350, 6957, 2007, 2444, 3842, 1012, 2007, 1037, 7977, 1997, 1002, 13427, 2454, 1010, 2009, 2150, 1996, 3

In [28]:
# apply the function on all elements of all the splits in the dataset including training, validation, and testing data
# remove the old columns since the preprocessing changes the number of samples
# results are cached. Pass "load_from_cache_file=False" to force the preprocessing to be applied again
tokenized_datasets = datasets.map(
    prepare_train_features, 
    batched=True, 
    remove_columns=datasets["train"].column_names
)

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




Fine-Tune the Model

In [29]:
# import Pytorch pretrained model for question answering
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer

# from_pretrained method downloads and caches the model
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

# warning regarding not using weights and layers is normal. we are removing the 
# masked language modeling head to pretrain the model on the QA task for which
# we do not have pretrained weights and requires fine-tuning

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased a

In [31]:
########### INPUT ###########
# training arguments is a class that contains the attributes to customize training
# set the folder name f"model-dataset", which will be used to save checkpoints
# set the learning_rate, number of epochs, and weight_decay
# batch_size has been set at the beginning of the notebook 
args = TrainingArguments(
    f"bert-squad-squad",
    evaluation_strategy = "epoch",
    learning_rate = 2e-5,
    per_device_train_batch_size = batch_size,
    per_device_eval_batch_size = batch_size,
    num_train_epochs = 3,
    weight_decay = 0.01,
)

In [32]:
# import a default data collator
from transformers import default_data_collator

# set the data_collator to the default data collator
data_collator = default_data_collator

In [33]:
# pass all of the training arguments and datasets to the trainer
trainer = Trainer(
    model,
    args,
    train_dataset = tokenized_datasets["train"],
    eval_dataset = tokenized_datasets["validation"],
    data_collator = data_collator,
    tokenizer = tokenizer,
)

In [34]:
# finetune the model by calling train method
# running this cell will take time.
trainer.train()

Epoch,Training Loss,Validation Loss,Runtime,Samples Per Second
1,1.7061,1.631146,8.5536,143.799
2,1.2206,1.308223,8.566,143.591
3,0.8834,1.479719,8.5678,143.561


TrainOutput(global_step=3603, training_loss=1.3697598192912157, metrics={'train_runtime': 1413.4466, 'train_samples_per_second': 2.549, 'total_flos': 1.4461795647157248e+16, 'epoch': 3.0, 'init_mem_cpu_alloc_delta': 335594, 'init_mem_gpu_alloc_delta': 436709888, 'init_mem_cpu_peaked_delta': 18306, 'init_mem_gpu_peaked_delta': 0, 'train_mem_cpu_alloc_delta': 1026859, 'train_mem_gpu_alloc_delta': 1308423168, 'train_mem_cpu_peaked_delta': 95566772, 'train_mem_gpu_peaked_delta': 8290968576})

In [35]:
########### INPUT ###########
# save the model. input the model name ("model-dataset-trained")
trainer.save_model("bert-squad-squad-trained")

Evaluation

In [36]:
# the validation features will need to be re-processed similar to the training features
# the processing will also need to check if the output span is inside the context (and not in the question)
# it will also need to retrieve the text inside

def prepare_validation_features(examples):
    # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    # We keep the example_id that gave us this feature and we will store the offset mappings.
    tokenized_examples["example_id"] = []

    for i in range(len(tokenized_examples["input_ids"])):
        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)
        context_index = 1 if pad_on_right else 0

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["id"][sample_index])

        # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
        # position is part of the context or not.
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == context_index else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]

    return tokenized_examples

In [37]:
# apply the function to validation set 
# remove the old columns since the preprocessing changes the number of samples
validation_features = datasets["validation"].map(
    prepare_validation_features,
    batched=True,
    remove_columns=datasets["validation"].column_names
)

HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))




In [38]:
# extract the predictions for all features using method trainer.predict
raw_predictions = trainer.predict(validation_features)

In [39]:
# trainer hides columns not used by the model. the columns needed for post-processing are set back 
validation_features.set_format(type=validation_features.format["type"], columns=list(validation_features.features.keys()))

In [40]:
########### INPUT ###########
# to classify answers, we use the score obtained by adding the start and end logits
# limit the number of possible answers by setting n_best_size
# limit the length of the answer by setting max_answer_length
n_best_size = 20
max_answer_length = 30

In [41]:
# get the output logits from trainer
import torch

for batch in trainer.get_eval_dataloader():
    break
batch = {k: v.to(trainer.args.device) for k, v in batch.items()}
with torch.no_grad():
    output = trainer.model(**batch)
output.keys()

odict_keys(['loss', 'start_logits', 'end_logits'])

In [42]:
# code to verify the score and corresponding text are working correctly
import numpy as np

start_logits = output.start_logits[0].cpu().numpy()
end_logits = output.end_logits[0].cpu().numpy()
offset_mapping = validation_features[0]["offset_mapping"]
# The first feature comes from the first example. For the more general case, we will need to be match the example_id to
# an example index
context = datasets["validation"][0]["context"]

# Gather the indices the best start/end logits:
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
valid_answers = []
for start_index in start_indexes:
    for end_index in end_indexes:
        # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
        # to part of the input_ids that are not in the context.
        if (
            start_index >= len(offset_mapping)
            or end_index >= len(offset_mapping)
            or offset_mapping[start_index] is None
            or offset_mapping[end_index] is None
        ):
            continue
        # Don't consider answers with a length that is either < 0 or > max_answer_length.
        if end_index < start_index or end_index - start_index + 1 > max_answer_length:
            continue
        if start_index <= end_index: # We need to refine that test to check the answer is inside the context
            start_char = offset_mapping[start_index][0]
            end_char = offset_mapping[end_index][1]
            valid_answers.append(
                {
                    "score": start_logits[start_index] + end_logits[end_index],
                    "text": context[start_char: end_char]
                }
            )

valid_answers = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[:n_best_size]
valid_answers

[{'score': 2.1693544, 'text': 'southern California'},
 {'score': 0.43769825, 'text': 'Another definition for southern California'},
 {'score': -0.4170426,
  'text': 'southern California uses Point Conception and the Tehachapi Mountains as the northern boundary.'},
 {'score': -0.45477343,
  'text': 'San Luis Obispo, Kern, and San Bernardino counties. Another definition for southern California'},
 {'score': -2.0959477, 'text': 'definition for southern California'},
 {'score': -2.1486988,
  'text': 'Another definition for southern California uses Point Conception and the Tehachapi Mountains as the northern boundary.'},
 {'score': -2.3444123, 'text': 'California'},
 {'score': -2.6859019,
  'text': 'southern California uses Point Conception and the Tehachapi Mountains as the northern boundary'},
 {'score': -2.9122484, 'text': 'southern'},
 {'score': -2.986081, 'text': 'southern California uses Point Conception'},
 {'score': -3.0515785,
  'text': 'Point Conception and the Tehachapi Mountains

In [43]:
# view the actual answer
datasets["validation"][0]["answers"]

{'answer_start': [], 'text': []}

In [44]:
# apply the process above to all features by mapping between examples and their corresponding features
import collections

examples = datasets["validation"]
features = validation_features

example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
features_per_example = collections.defaultdict(list)
for i, feature in enumerate(features):
    features_per_example[example_id_to_index[feature["example_id"]]].append(i)

In [45]:
# to handle the non-answerable questions, we need to extract the score for the impossible answer
# the score is collected from minimum of the scores from the CLS token for each feature generated by the example
# the question is not answerable when that score is greater than the highest answerable score
from tqdm.auto import tqdm
import numpy as np

def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):
    all_start_logits, all_end_logits = raw_predictions
    # Build a map example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill.
    predictions = collections.OrderedDict()

    # Logging.
    print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    # Let's loop over all the examples!
    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated to the current example.
        feature_indices = features_per_example[example_index]

        min_null_score = None # Only used if squad_v2 is True.
        valid_answers = []
        
        context = example["context"]
        # Looping through all the features associated to the current example.
        for feature_index in feature_indices:
            # We grab the predictions of the model for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # This is what will allow us to map some the positions in our logits to span of texts in the original
            # context.
            offset_mapping = features[feature_index]["offset_mapping"]

            # Update minimum null prediction.
            cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or min_null_score < feature_null_score:
                min_null_score = feature_null_score

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char: end_char]
                        }
                    )
        
        if len(valid_answers) > 0:
            best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
        else:
            # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
            # failure.
            best_answer = {"text": "", "score": 0.0}
        
        # Let's pick our final answer: the best one or the null answer (only for squad_v2)
        if not squad_v2:
            predictions[example["id"]] = best_answer["text"]
        else:
            answer = best_answer["text"] if best_answer["score"] > min_null_score else ""
            predictions[example["id"]] = answer

    return predictions

In [46]:
# apply the postprocessing function to the raw predictions
final_predictions = postprocess_qa_predictions(datasets["validation"], validation_features, raw_predictions.predictions)

Post-processing 1204 example predictions split into 1230 features.


HBox(children=(FloatProgress(value=0.0, max=1204.0), HTML(value='')))




In [47]:
########### INPUT ###########
# load the metric from the datasets library
metric = load_metric("squad_v2")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2264.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=3182.0, style=ProgressStyle(description…




In [48]:
# format predictions and labels
if squad_v2:
    formatted_predictions = [{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in final_predictions.items()]
else:
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
references = [{"id": ex["id"], "answers": ex["answers"]} for ex in datasets["validation"]]
metric.compute(predictions=formatted_predictions, references=references)

{'HasAns_exact': 64.32078559738135,
 'HasAns_f1': 72.07416699105619,
 'HasAns_total': 611,
 'NoAns_exact': 51.93929173693086,
 'NoAns_f1': 51.93929173693086,
 'NoAns_total': 593,
 'best_exact': 58.222591362126245,
 'best_exact_thresh': 0.0,
 'best_f1': 62.157239228849825,
 'best_f1_thresh': 0.0,
 'exact': 58.222591362126245,
 'f1': 62.15723922884988,
 'total': 1204}