# Fine-tuning BERT on SQUAD (question answering)

In [2]:
# !pip install datasets transformers
# !pip install accelerate -U

In [None]:
import transformers

print(transformers.__version__)

4.33.1


In [None]:
# This flag is the difference between SQUAD v1 or 2 (if you're using another dataset, it indicates if impossible
# answers are allowed or not).
squad_v2 = False
model_checkpoint = "prajjwal1/bert-tiny"
batch_size = 16

## Loading the dataset

In [None]:
from datasets import load_dataset, load_metric

In [None]:
datasets = load_dataset("squad_v2" if squad_v2 else "squad")

Downloading builder script:   0%|          | 0.00/5.27k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.36k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.67k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/8.12M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.05M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

In [None]:
datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [None]:
datasets["train"][0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}

In [None]:
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [None]:
show_random_elements(datasets["train"])

Unnamed: 0,id,title,context,question,answers
0,57278d02f1498d1400e8fbca,Carnival,"Tarragona has one of the region's most complete ritual sequences. The events start with the building of a huge barrel and ends with its burning with the effigies of the King and Queen. On Saturday, the main parade takes place with masked groups, zoomorphic figures, music and percussion bands, and groups with fireworks (the devils, the dragon, the ox, the female dragon). Carnival groups stand out for their clothes full of elegance, showing brilliant examples of fabric crafts, at the Saturday and Sunday parades. About 5,000 people are members of the parade groups.",About how many people are members of the various parade groups?,"{'text': ['5,000'], 'answer_start': [522]}"
1,57318d40a5e9cc1400cdc053,Steven_Spielberg,"His first professional TV job came when he was hired to direct one of the segments for the 1969 pilot episode of Night Gallery. The segment, ""Eyes,"" starred Joan Crawford; she and Spielberg were reportedly close friends until her death. The episode is unusual in his body of work, in that the camerawork is more highly stylized than his later, more ""mature"" films. After this, and an episode of Marcus Welby, M.D., Spielberg got his first feature-length assignment: an episode of The Name of the Game called ""L.A. 2017"". This futuristic science fiction episode impressed Universal Studios and they signed him to a short contract. He did another segment on Night Gallery and did some work for shows such as Owen Marshall: Counselor at Law and The Psychiatrist, before landing the first series episode of Columbo (previous episodes were actually TV films).",What law show did Spielberg work on?,"{'text': ['Owen Marshall: Counselor at Law'], 'answer_start': [706]}"
2,5728b2d64b864d1900164c48,Estonia,"In the face of the country being re-occupied by the Red Army, tens of thousands of Estonians (including a majority of the education, culture, science, political and social specialists) chose to either retreat with the Germans or flee to Finland or Sweden where they sought refuge in other western countries, often by refugee ships such as the SS Walnut. On 12 January 1949, the Soviet Council of Ministers issued a decree ""on the expulsion and deportation"" from Baltic states of ""all kulaks and their families, the families of bandits and nationalists"", and others.",How many Estonians chose to retreat or flee when in anticipation of another Soviet invasion?,"{'text': ['tens of thousands'], 'answer_start': [62]}"
3,5726be6b708984140094d015,Pope_Paul_VI,"Pope Paul VI knew the Roman Curia well, having worked there for a generation from 1922 to 1954. He implemented his reforms in stages, rather than in one fell swoop. On 1 March 1968, he issued a regulation, a process that had been initiated by Pius XII and continued by John XXIII. On 28 March, with Pontificalis Domus, and in several additional Apostolic Constitutions in the following years, he revamped the entire Curia, which included reduction of bureaucracy, streamlining of existing congregations and a broader representation of non-Italians in the curial positions.",Whose representation was enlarged through reforms in the Curia?,"{'text': ['non-Italians'], 'answer_start': [535]}"
4,57302968a23a5019007fcec9,"Santa_Monica,_California","As of the census of 2000, there are 84,084 people, 44,497 households, and 16,775 families in the city. The population density is 10,178.7 inhabitants per square mile (3,930.4/km²). There are 47,863 housing units at an average density of 5,794.0 per square mile (2,237.3/km²). The racial makeup of the city is 78.29% White, 7.25% Asian, 3.78% African American, 0.47% Native American, 0.10% Pacific Islander, 5.97% from other races, and 4.13% from two or more races. 13.44% of the population are Hispanic or Latino of any race. There are 44,497 households, out of which 15.8% have children under the age of 18, 27.5% are married couples living together, 7.5% have a female householder with no husband present, and 62.3% are non-families. 51.2% of all households are made up of individuals and 10.6% have someone living alone who is 65 years of age or older. The average household size is 1.83 and the average family size is 2.80.",From 2000 what was the average family size?,"{'text': ['2.80'], 'answer_start': [922]}"
5,5730b3dd069b531400832286,Super_Nintendo_Entertainment_System,"During the SNES's life, Nintendo contracted with two different companies to develop a CD-ROM-based peripheral for the console to compete with Sega's CD-ROM based addon, Mega-CD. Ultimately, deals with both Sony and Philips fell through, (although a prototype console was produced by Sony) with Philips gaining the right to release a series of titles based on Nintendo franchises for its CD-i multimedia player and Sony going on to develop its own console based on its initial dealings with Nintendo (the PlayStation).",What was Philips' multimedia system?,"{'text': ['CD-i'], 'answer_start': [387]}"
6,572fa6e204bcaa1900d76b44,Hyderabad,"Hyderabad has continued with these traditions in its annual Hyderabad Literary Festival, held since 2010, showcasing the city's literary and cultural creativity. Organisations engaged in the advancement of literature include the Sahitya Akademi, the Urdu Academy, the Telugu Academy, the National Council for Promotion of Urdu Language, the Comparative Literature Association of India, and the Andhra Saraswata Parishad. Literary development is further aided by state institutions such as the State Central Library, the largest public library in the state which was established in 1891, and other major libraries including the Sri Krishna Devaraya Andhra Bhasha Nilayam, the British Library and the Sundarayya Vignana Kendram.",What is one of the activities Telugu Academy is credited with encouraging?,"{'text': ['the advancement of literature'], 'answer_start': [187]}"
7,56e07007231d4119001ac12b,Saint_Helena,"St Helena has long been known for its high proportion of endemic birds and vascular plants. The highland areas contain most of the 400 endemic species recognised to date. Much of the island has been identified by BirdLife International as being important for bird conservation, especially the endemic Saint Helena plover or wirebird, and for seabirds breeding on the offshore islets and stacks, in the north-east and the south-west Important Bird Areas. On the basis of these endemics and an exceptional range of habitats, Saint Helena is on the United Kingdom's tentative list for future UNESCO World Heritage Sites.",Saint Helena is on the United Kingdom's list for future what?,"{'text': ['UNESCO World Heritage Sites'], 'answer_start': [589]}"
8,5728ebb7ff5b5019007da951,Apollo,"In the next century which is the beginning of the Classical period, it was considered that beauty in visible things as in everything else, consisted of symmetry and proportions. The artists tried also to represent motion in a specific moment (Myron), which may be considered as the reappearance of the dormant Minoan element. Anatomy and geometry are fused in one, and each does something to the other. The Greek sculptors tried to clarify it by looking for mathematical proportions, just as they sought some reality behind appearances. Polykleitos in his Canon wrote that beauty consists in the proportion not of the elements (materials), but of the parts, that is the interrelation of parts with one another and with the whole. It seems that he was influenced by the theories of Pythagoras. The famous Apollo of Mantua and its variants are early forms of the Apollo Citharoedus statue type, in which the god holds the cithara in his left arm. The type is represented by neo-Attic Imperial Roman copies of the late 1st or early 2nd century, modelled upon a supposed Greek bronze original made in the second quarter of the 5th century BCE, in a style similar to works of Polykleitos but more archaic. The Apollo held the cythara against his extended left arm, of which in the Louvre example, a fragment of one twisting scrolling horn upright remains against his biceps.",In what type of art does the god hold the cithara in his left arm?,"{'text': ['Apollo Citharoedus statue type'], 'answer_start': [861]}"
9,5730e8d2aca1c71400fe5b51,Great_power,"All states have a geographic scope of interests, actions, or projected power. This is a crucial factor in distinguishing a great power from a regional power; by definition the scope of a regional power is restricted to its region. It has been suggested that a great power should be possessed of actual influence throughout the scope of the prevailing international system. Arnold J. Toynbee, for example, observes that ""Great power may be defined as a political force exerting an effect co-extensive with the widest range of the society in which it operates. The Great powers of 1914 were 'world-powers' because Western society had recently become 'world-wide'.""",What type of powers should have actual influence throughout the scope of the prevailing international system?,"{'text': ['great power'], 'answer_start': [260]}"


## Preprocessing the training data

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Downloading (…)lve/main/config.json:   0%|          | 0.00/285 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

In [None]:
import transformers
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

You can check which type of models have a fast tokenizer available and which don't on the [big table of models](https://huggingface.co/transformers/index.html#bigtable).

In [None]:
tokenizer(
    "Hey you know what I did last night?",
    "You'd better not bring my mother into this!"
)

{'input_ids': [101, 4931, 2017, 2113, 2054, 1045, 2106, 2197, 2305, 1029, 102, 2017, 1005, 1040, 2488, 2025, 3288, 2026, 2388, 2046, 2023, 999, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [None]:
max_length = 384 # The maximum length of a feature (question and context)
doc_stride = 128 # The authorized overlap between two part of the context when splitting it is needed.

Let's find one long example in our dataset:

In [None]:
for i, example in enumerate(datasets["train"]):
    if len(tokenizer(example["question"], example["context"])["input_ids"]) > 384:
        break

example = datasets["train"][i]

Without any truncation, we get the following length for the input IDs:

In [None]:
len(tokenizer(example["question"], example["context"])["input_ids"])

396

In [None]:
len(tokenizer(example["question"], example["context"], max_length=max_length, truncation="only_second")["input_ids"])

384

In [None]:
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    stride=doc_stride
)

In [None]:
[len(x) for x in tokenized_example["input_ids"]]

[384, 157]

In [None]:
for x in tokenized_example["input_ids"][:2]:
    print(tokenizer.decode(x))

[CLS] how many wins does the notre dame men's basketball team have? [SEP] the men's basketball team has over 1, 600 wins, one of only 12 schools who have reached that mark, and have appeared in 28 ncaa tournaments. former player austin carr holds the record for most points scored in a single game of the tournament with 61. although the team has never won the ncaa tournament, they were named by the helms athletic foundation as national champions twice. the team has orchestrated a number of upsets of number one ranked teams, the most notable of which was ending ucla's record 88 - game winning streak in 1974. the team has beaten an additional eight number - one teams, and those nine wins rank second, to ucla's 10, all - time in wins against the top team. the team plays in newly renovated purcell pavilion ( within the edmund p. joyce center ), which reopened for the beginning of the 2009 – 2010 season. the team is coached by mike brey, who, as of the 2014 – 15 season, his fifteenth at notr

In [None]:
tokenized_example = tokenizer(
    example["question"],
    example["context"],
    max_length=max_length,
    truncation="only_second",
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    stride=doc_stride
)
print(tokenized_example["offset_mapping"][0][:100])

[(0, 0), (0, 3), (4, 8), (9, 13), (14, 18), (19, 22), (23, 28), (29, 33), (34, 37), (37, 38), (38, 39), (40, 50), (51, 55), (56, 60), (60, 61), (0, 0), (0, 3), (4, 7), (7, 8), (8, 9), (10, 20), (21, 25), (26, 29), (30, 34), (35, 36), (36, 37), (37, 40), (41, 45), (45, 46), (47, 50), (51, 53), (54, 58), (59, 61), (62, 69), (70, 73), (74, 78), (79, 86), (87, 91), (92, 96), (96, 97), (98, 101), (102, 106), (107, 115), (116, 118), (119, 121), (122, 126), (127, 138), (138, 139), (140, 146), (147, 153), (154, 160), (161, 165), (166, 171), (172, 175), (176, 182), (183, 186), (187, 191), (192, 198), (199, 205), (206, 208), (209, 210), (211, 217), (218, 222), (223, 225), (226, 229), (230, 240), (241, 245), (246, 248), (248, 249), (250, 258), (259, 262), (263, 267), (268, 271), (272, 277), (278, 281), (282, 285), (286, 290), (291, 301), (301, 302), (303, 307), (308, 312), (313, 318), (319, 321), (322, 325), (326, 330), (330, 331), (332, 340), (341, 351), (352, 354), (355, 363), (364, 373), (374,

In [None]:
first_token_id = tokenized_example["input_ids"][0][1]
offsets = tokenized_example["offset_mapping"][0][1]
print(tokenizer.convert_ids_to_tokens([first_token_id])[0], example["question"][offsets[0]:offsets[1]])

how How


In [None]:
sequence_ids = tokenized_example.sequence_ids()
print(sequence_ids)

[None, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [None]:
answers = example["answers"]
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])

# Start token index of the current span in the text.
token_start_index = 0
while sequence_ids[token_start_index] != 1:
    token_start_index += 1

# End token index of the current span in the text.
token_end_index = len(tokenized_example["input_ids"][0]) - 1
while sequence_ids[token_end_index] != 1:
    token_end_index -= 1

# Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
offsets = tokenized_example["offset_mapping"][0]
if (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
    # Move the token_start_index and token_end_index to the two ends of the answer.
    # Note: we could go after the last offset if the answer is the last word (edge case).
    while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
        token_start_index += 1
    start_position = token_start_index - 1
    while offsets[token_end_index][1] >= end_char:
        token_end_index -= 1
    end_position = token_end_index + 1
    print(start_position, end_position)
else:
    print("The answer is not in this feature.")

23 26


In [None]:
print(tokenizer.decode(tokenized_example["input_ids"][0][start_position: end_position+1]))
print(answers["text"][0])

over 1, 600
over 1,600


For this notebook to work with any kind of models, we need to account for the special case where the model expects padding on the left (in which case we switch the order of the question and the context):

In [None]:
pad_on_right = tokenizer.padding_side == "right"

In [None]:
def prepare_train_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
    # left whitespace
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize our examples with truncation and padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    # The offset mappings will give us a map from token to character position in the original context. This will
    # help us compute the start_positions and end_positions.
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Let's label those examples!
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        cls_index = input_ids.index(tokenizer.cls_token_id)

        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples["answers"][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 1)

    return tokenized_examples

In [None]:
features = prepare_train_features(datasets['train'][:5])

In [None]:
tokenized_datasets = datasets.map(prepare_train_features, batched=True, remove_columns=datasets["train"].column_names)

Map:   0%|          | 0/87599 [00:00<?, ? examples/s]

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

## Fine-tuning the model

In [None]:
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer

model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
model_name = model_checkpoint.split("/")[-1]
args = TrainingArguments(
    f"{model_name}-finetuned-squad",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
)

In [None]:
from transformers import default_data_collator

data_collator = default_data_collator

In [None]:
trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
)

In [None]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,2.7886,2.597881
2,2.7213,2.543613
3,2.6549,2.528444


TrainOutput(global_step=16599, training_loss=2.7392811506611374, metrics={'train_runtime': 492.394, 'train_samples_per_second': 539.349, 'train_steps_per_second': 33.711, 'total_flos': 242951010453504.0, 'train_loss': 2.7392811506611374, 'epoch': 3.0})

In [None]:
trainer.save_model("test-squad-trained")

## Evaluation

In [None]:
import torch

for batch in trainer.get_eval_dataloader():
    break
batch = {k: v.to(trainer.args.device) for k, v in batch.items()}
with torch.no_grad():
    output = trainer.model(**batch)
output.keys()

odict_keys(['loss', 'start_logits', 'end_logits'])

In [None]:
output.start_logits.shape, output.end_logits.shape

(torch.Size([16, 384]), torch.Size([16, 384]))

In [None]:
output.start_logits.argmax(dim=-1), output.end_logits.argmax(dim=-1)

(tensor([ 46,  57, 161,  43, 113, 162,  72,  41, 162,  40,  73,  60, 163, 163,
         170,  46], device='cuda:0'),
 tensor([167,  52, 164,  44,  98, 165,  75,  37, 165,  36,  76, 162, 166, 166,
         170,  14], device='cuda:0'))

In [None]:
n_best_size = 20

In [None]:
import numpy as np

start_logits = output.start_logits[0].cpu().numpy()
end_logits = output.end_logits[0].cpu().numpy()
# Gather the indices the best start/end logits:
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
valid_answers = []
for start_index in start_indexes:
    for end_index in end_indexes:
        if start_index <= end_index: # We need to refine that test to check the answer is inside the context
            valid_answers.append(
                {
                    "score": start_logits[start_index] + end_logits[end_index],
                    "text": "" # We need to find a way to get back the original substring corresponding to the answer in the context
                }
            )

In [None]:
def prepare_validation_features(examples):
    # Some of the questions have lots of whitespace on the left, which is not useful and will make the
    # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
    # left whitespace
    examples["question"] = [q.lstrip() for q in examples["question"]]

    # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
    # in one example possible giving several features when a context is long, each of those features having a
    # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
        examples["question" if pad_on_right else "context"],
        examples["context" if pad_on_right else "question"],
        truncation="only_second" if pad_on_right else "only_first",
        max_length=max_length,
        stride=doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context, we need a map from a feature to
    # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

    # We keep the example_id that gave us this feature and we will store the offset mappings.
    tokenized_examples["example_id"] = []

    for i in range(len(tokenized_examples["input_ids"])):
        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)
        context_index = 1 if pad_on_right else 0

        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["id"][sample_index])

        # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
        # position is part of the context or not.
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == context_index else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]

    return tokenized_examples

In [None]:
validation_features = datasets["validation"].map(
    prepare_validation_features,
    batched=True,
    remove_columns=datasets["validation"].column_names
)

In [None]:
raw_predictions = trainer.predict(validation_features)
raw_predictions

PredictionOutput(predictions=(array([[-0.84487706, -6.5158668 , -6.699659  , ..., -7.086204  ,
        -7.1218915 , -7.122274  ],
       [-0.6327973 , -6.5313115 , -6.730243  , ..., -7.085315  ,
        -7.1207337 , -7.1197276 ],
       [-0.59188217, -6.7405915 , -6.7658906 , ..., -7.2617173 ,
        -7.281343  , -7.2830105 ],
       ...,
       [ 0.04040329, -6.221375  , -6.6429706 , ..., -7.1646485 ,
        -7.266314  , -7.240399  ],
       [ 0.4619447 , -6.178924  , -6.403531  , ..., -7.161273  ,
        -7.2025027 , -7.205956  ],
       [ 0.53225005, -6.3233047 , -6.721399  , ..., -7.175015  ,
        -7.268193  , -7.251528  ]], dtype=float32), array([[-2.1922884, -7.358166 , -6.723622 , ..., -7.025218 , -7.0092382,
        -6.98412  ],
       [-2.0894186, -7.3664923, -6.735838 , ..., -7.0252223, -7.00969  ,
        -6.985519 ],
       [-2.0088491, -7.179729 , -7.1436296, ..., -7.0466237, -7.054239 ,
        -7.05232  ],
       ...,
       [-1.8530481, -7.3034873, -7.0524487, ...

In [None]:
raw_predictions.predictions[1].shape

(10784, 384)

In [None]:
validation_features.set_format(type=validation_features.format["type"], columns=list(validation_features.features.keys()))

In [None]:
max_answer_length = 30

In [None]:
start_logits = output.start_logits[0].cpu().numpy()
end_logits = output.end_logits[0].cpu().numpy()
offset_mapping = validation_features[0]["offset_mapping"]
# The first feature comes from the first example. For the more general case, we will need to be match the example_id to
# an example index
context = datasets["validation"][0]["context"]

# Gather the indices the best start/end logits:
start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
valid_answers = []
for start_index in start_indexes:
    for end_index in end_indexes:
        # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
        # to part of the input_ids that are not in the context.
        if (
            start_index >= len(offset_mapping)
            or end_index >= len(offset_mapping)
            or offset_mapping[start_index] is None
            or offset_mapping[end_index] is None
        ):
            continue
        # Don't consider answers with a length that is either < 0 or > max_answer_length.
        if end_index < start_index or end_index - start_index + 1 > max_answer_length:
            continue
        if start_index <= end_index: # We need to refine that test to check the answer is inside the context
            start_char = offset_mapping[start_index][0]
            end_char = offset_mapping[end_index][1]
            valid_answers.append(
                {
                    "score": start_logits[start_index] + end_logits[end_index],
                    "text": context[start_char: end_char]
                }
            )

valid_answers = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[:n_best_size]
valid_answers

[{'score': 7.3279552, 'text': 'Denver Broncos'},
 {'score': 7.3080225,
  'text': 'Denver Broncos defeated the National Football Conference'},
 {'score': 7.3068123,
  'text': 'Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers'},
 {'score': 6.8693733,
  'text': 'Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24'},
 {'score': 6.5827866, 'text': 'Carolina Panthers'},
 {'score': 6.3987026,
  'text': 'Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third'},
 {'score': 6.3633423, 'text': 'Arabic numerals'},
 {'score': 6.1453476, 'text': 'Carolina Panthers 24'},
 {'score': 5.927129,
  'text': 'Denver Broncos defeated the National Football Conference (NFC'},
 {'score': 5.789795, 'text': 'champion Denver Broncos'},
 {'score': 5.769862,
  'text': 'champion Denver Broncos defeated the National Football Conference'},
 {'score': 5.768652,
  'text': 'champio

In [None]:
datasets["validation"][0]["answers"]

{'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'],
 'answer_start': [177, 177, 177]}

In [None]:
import collections

examples = datasets["validation"]
features = validation_features

example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
features_per_example = collections.defaultdict(list)
for i, feature in enumerate(features):
    features_per_example[example_id_to_index[feature["example_id"]]].append(i)

In [None]:
from tqdm.auto import tqdm

def postprocess_qa_predictions(examples, features, raw_predictions, n_best_size = 20, max_answer_length = 30):
    all_start_logits, all_end_logits = raw_predictions
    # Build a map example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill.
    predictions = collections.OrderedDict()

    # Logging.
    print(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    # Let's loop over all the examples!
    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated to the current example.
        feature_indices = features_per_example[example_index]

        min_null_score = None # Only used if squad_v2 is True.
        valid_answers = []

        context = example["context"]
        # Looping through all the features associated to the current example.
        for feature_index in feature_indices:
            # We grab the predictions of the model for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # This is what will allow us to map some the positions in our logits to span of texts in the original
            # context.
            offset_mapping = features[feature_index]["offset_mapping"]

            # Update minimum null prediction.
            cls_index = features[feature_index]["input_ids"].index(tokenizer.cls_token_id)
            feature_null_score = start_logits[cls_index] + end_logits[cls_index]
            if min_null_score is None or min_null_score < feature_null_score:
                min_null_score = feature_null_score

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or offset_mapping[end_index] is None
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    start_char = offset_mapping[start_index][0]
                    end_char = offset_mapping[end_index][1]
                    valid_answers.append(
                        {
                            "score": start_logits[start_index] + end_logits[end_index],
                            "text": context[start_char: end_char]
                        }
                    )

        if len(valid_answers) > 0:
            best_answer = sorted(valid_answers, key=lambda x: x["score"], reverse=True)[0]
        else:
            # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
            # failure.
            best_answer = {"text": "", "score": 0.0}

        # Let's pick our final answer: the best one or the null answer (only for squad_v2)
        if not squad_v2:
            predictions[example["id"]] = best_answer["text"]
        else:
            answer = best_answer["text"] if best_answer["score"] > min_null_score else ""
            predictions[example["id"]] = answer

    return predictions

In [None]:
final_predictions = postprocess_qa_predictions(datasets["validation"], validation_features, raw_predictions.predictions)

Post-processing 10570 example predictions split into 10784 features.


  0%|          | 0/10570 [00:00<?, ?it/s]

In [None]:
metric = load_metric("squad_v2" if squad_v2 else "squad")

In [None]:
if squad_v2:
    formatted_predictions = [{"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in final_predictions.items()]
else:
    formatted_predictions = [{"id": k, "prediction_text": v} for k, v in final_predictions.items()]
references = [{"id": ex["id"], "answers": ex["answers"]} for ex in datasets["validation"]]
metric.compute(predictions=formatted_predictions, references=references)

{'exact_match': 35.28855250709555, 'f1': 47.450340124821345}

In [None]:
trainer.push_to_hub()