In [None]:
!pip install datasets evaluate transformers[sentencepiece]
!pip install accelerate

# Question answering

*Extractive* question answering involves posing questions about a document and identifying the answers as *spans of text* in the document itself.

We will fine-tune a BERT model on the SQuAD dataset, which consists of questions posed by crowdworkers on a set of Wikipedia articles.

Encoder-only models like BERT tend to be great at extracting answers to factoid questions, but fare poorly when given open-ended questions. In these more challenging cases, encoder-decoder models like T5 and BART are typically used to synthesize the information in a way that's quite similar to text summarization.

## Preparing the data

### The SQuAD dataset

In [None]:
from datasets import load_dataset

raw_datasets = load_dataset('squad')

In [3]:
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [4]:
idx = 0
print(f"Context: {raw_datasets['train'][idx]['context']}")
print(f"Question: {raw_datasets['train'][idx]['question']}")
print(f"Answer: {raw_datasets['train'][idx]['answers']}")

Context: Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
Question: To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answer: {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}


The `answers` field contains two fields that are both lists: the `text` field holds the answer strings, and the `answer_start` field holds the starting character index of each answer in the context.
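
To make the relationship between `text` and `answer_start` concrete, here is a minimal sketch on a made-up context (the sentence and the answer below are illustrative assumptions, not taken from SQuAD). It shows that slicing the context from the start character to the start character plus the answer length recovers the answer:

```python
# Toy context and answer, made up for illustration only.
context = "The Eiffel Tower is located in Paris, France."
answer_text = "Paris"

# SQuAD stores the character index where the answer begins in the context.
answer_start = context.find(answer_text)
answers = {"text": [answer_text], "answer_start": [answer_start]}

# Recover the answer span: start character, plus the length of the answer text.
start_char = answers["answer_start"][0]
end_char = start_char + len(answers["text"][0])
recovered = context[start_char:end_char]

print(start_char, end_char, recovered)
```

The same arithmetic (`end_char = start_char + len(text)`) is what the preprocessing code later in this notebook relies on.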

During training, there is only one possible answer.

In [None]:
raw_datasets['train'].filter(lambda x: len(x['answers']['text']) != 1)

For evaluation, however, there are several possible answers for each sample, which may be the same or different:

In [6]:
print(raw_datasets['validation'][0]['answers'])
print(raw_datasets['validation'][2]['answers'])

{'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answer_start': [177, 177, 177]}
{'text': ['Santa Clara, California', "Levi's Stadium", "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California."], 'answer_start': [403, 355, 355]}


Because some questions have several acceptable answers, the evaluation metric compares a predicted answer to all of them and keeps the best score. For example:

In [7]:
print(raw_datasets['validation'][2]['context'])
print(raw_datasets['validation'][2]['question'])

Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.
Where did Super Bowl 50 take place?
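
The "best score over all acceptable answers" idea can be sketched as follows. This is a simplified illustration, not the official SQuAD metric: the `normalize` helper here only lowercases and splits on whitespace, whereas the real metric also strips punctuation and articles. All function names are hypothetical.

```python
from collections import Counter


def normalize(text):
    # Simplified normalization (assumption): lowercase + whitespace split.
    return text.lower().split()


def exact_match(prediction, reference):
    return float(normalize(prediction) == normalize(reference))


def f1_score(prediction, reference):
    pred_tokens, ref_tokens = normalize(prediction), normalize(reference)
    # Multiset intersection counts the tokens shared by both answers.
    num_same = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def best_over_references(score_fn, prediction, references):
    # Score against every acceptable answer and keep the best one.
    return max(score_fn(prediction, ref) for ref in references)


# Reference answers copied from the validation example shown above.
references = [
    "Santa Clara, California",
    "Levi's Stadium",
    "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California.",
]

em = best_over_references(exact_match, "Levi's Stadium", references)
f1 = best_over_references(f1_score, "Levi's Stadium", references)
partial_f1 = best_over_references(f1_score, "Santa Clara", references)
print(em, f1, partial_f1)
```

A prediction that matches any one of the references exactly gets full credit, while a partially overlapping prediction gets a fractional F1.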


### Processing the training data

The hard part will be to generate labels for the question's answer, which will be the start and end positions of the tokens corresponding to the answer inside the context.

First, we need to convert the text in the input into IDs the model can make sense of, using a tokenizer:

In [None]:
from transformers import AutoTokenizer

model_checkpoint = 'bert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [9]:
tokenizer.is_fast

True

We can pass to our tokenizer the question and the context together, and it will properly insert the special tokens to form a sentence like this:
```
[CLS] question [SEP] context [SEP]
```

In [10]:
context = raw_datasets['train'][0]['context']
question = raw_datasets['train'][0]['question']

inputs = tokenizer(question, context)
tokenizer.decode(inputs['input_ids'])

'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'

As we saw when exploring the internals of the `question-answering` pipeline, we will deal with long contexts by creating several training features from one sample of our dataset, with a sliding window between them.

To see how this works using the current example, we can limit the length to 100 and use a sliding window of 50 tokens. We will use
* `max_length` to set the maximum length (here 100)
* `truncation='only_second'` to truncate the context (which is in the second position) when the question with its context is too long
* `stride` to set the number of overlapping tokens between two successive chunks (here 50)
* `return_overflowing_tokens=True` to let the tokenizer know we want the overflowing tokens
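
To build intuition for how `max_length` and `stride` interact, here is a rough pure-Python sketch of the chunking idea. It is not the tokenizer's actual implementation: the `sliding_windows` helper is hypothetical, it works on pre-split placeholder tokens, and it assumes three special tokens per feature (`[CLS]` and two `[SEP]`).

```python
def sliding_windows(question_tokens, context_tokens, max_length, stride):
    # Hypothetical helper that mimics the chunking idea only.
    # Room left for context tokens after [CLS] question [SEP] ... [SEP].
    room = max_length - len(question_tokens) - 3
    if room <= stride:
        raise ValueError("max_length is too small for this question and stride")

    step = room - stride  # successive windows overlap by `stride` tokens
    chunks = []
    start = 0
    while True:
        window = context_tokens[start : start + room]
        chunks.append(["[CLS]", *question_tokens, "[SEP]", *window, "[SEP]"])
        if start + room >= len(context_tokens):
            break
        start += step
    return chunks


question_tokens = ["q1", "q2"]
context_tokens = [f"c{i}" for i in range(12)]

chunks = sliding_windows(question_tokens, context_tokens, max_length=10, stride=2)
for chunk in chunks:
    print(chunk)
```

Every chunk repeats the question, only the context is windowed, and the last `stride` context tokens of one chunk reappear at the start of the next.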

In [11]:
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation='only_second',
    stride=50,
    return_overflowing_tokens=True,
)

for ids in inputs['input_ids']:
    print(tokenizer.decode(ids))

[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Next to the Main Building is the B

Dealing with long contexts in this way creates some training examples where the answer is not included in the context. For those examples, the labels will be `start_position = end_position = 0` (so we predict the `[CLS]` token).

We will also set those labels in the unfortunate case where the answer has been truncated so that we only have the start (or end) of it.

For the examples where the answer is fully in the context, the labels will be the index of the token where the answer starts and the index of the token where the answer ends.

The dataset provides us with the start character of the answer in the context, and by adding the length of the answer, we can find the end character in the context. To map those to token indices, we will need to use the offset mappings. We can have our tokenizer return these by passing along `return_offsets_mapping=True`:

In [12]:
inputs = tokenizer(
    question,
    context,
    max_length=100,
    truncation='only_second',
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

inputs.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

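Besides the usual input IDs, token type IDs, and attention mask, we get the offset mapping we asked for and an extra key, `overflow_to_sample_mapping`. Its value will be of use to us when we tokenize several texts at the same time. Since one sample can give several features, it maps each feature to the example it originated from. Because here we only tokenized one example, we get a list of `0`s: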

In [13]:
inputs['overflow_to_sample_mapping']

[0, 0, 0, 0]

If we tokenize more examples:

In [14]:
inputs = tokenizer(
    raw_datasets['train'][:6]['question'],
    raw_datasets['train'][:6]['context'],
    max_length=100,
    truncation='only_second',
    stride=50,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
)

print(f"The 6 examples give {len(inputs['input_ids'])} features.")
print(f"Here is where each comes from: {inputs['overflow_to_sample_mapping']}")

The 6 examples give 27 features.
Here is where each comes from: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5]


We can see that the first five examples each gave four features and the last example gave seven features.
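
The per-example counts can be checked directly from the mapping. A small sketch, using the list printed above and the standard library's `Counter`:

```python
from collections import Counter

# Mapping copied from the output above: one entry per feature,
# giving the index of the example that feature came from.
overflow_to_sample_mapping = [
    0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3,
    4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5,
]

features_per_example = Counter(overflow_to_sample_mapping)
for example_idx, count in sorted(features_per_example.items()):
    print(f"example {example_idx}: {count} features")
```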

This information will be useful to map each feature we get to its corresponding label. Those labels are:
* `(0, 0)` if the answer is not in the corresponding span of the context
* `(start_position, end_position)` if the answer is in the corresponding span of the context, with `start_position` being the index of the token (in the input IDs) at the start of the answer and `end_position` being the index of the token (in the input IDs) where the answer ends

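To determine which case applies and, if relevant, the token positions, we first find the indices where the context starts and ends in the input IDs. For this we will use the `sequence_ids()` method of the `BatchEncoding` our tokenizer returns.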

Once we have those token indices, we look at the corresponding offsets, which are tuples of two integers representing the span of characters inside the original context. We can thus detect if the chunk of the context in this feature starts after the answer or ends before the answer begins (in which case the label is `(0, 0)`). If that's not the case, we loop to find the first and last token of the answer:

In [15]:
answers = raw_datasets['train'][:6]['answers']
start_positions = []
end_positions = []

for i, offset in enumerate(inputs['offset_mapping']):
    sample_idx = inputs['overflow_to_sample_mapping'][i]
    answer = answers[sample_idx]
    start_char = answer['answer_start'][0]
    end_char = answer['answer_start'][0] + len(answer['text'][0])
    sequence_ids = inputs.sequence_ids(i)

    # Find the start and end of the context
    idx = 0
    while sequence_ids[idx] != 1:
        idx += 1
    context_start = idx

    while sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # If the answer is not fully inside the context, label is (0, 0)
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        start_positions.append(0)
        end_positions.append(0)
    else:
        # otherwise, it's the start and end token positions
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

print(start_positions)
print(end_positions)

[0, 0, 72, 40, 53, 17, 0, 0, 83, 51, 19, 0, 0, 64, 27, 0, 34, 0, 0, 0, 67, 34, 0, 0, 0, 0, 0]
[0, 0, 78, 46, 57, 21, 0, 0, 85, 53, 21, 0, 0, 70, 33, 0, 40, 0, 0, 0, 68, 35, 0, 0, 0, 0, 0]
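
The labeling logic can also be replayed on a tiny hand-written example, which makes the two loops easier to follow. Everything below is an illustrative assumption: the `find_label` helper is hypothetical, and the offsets and sequence IDs are written by hand for the context "The capital of France is Paris." rather than produced by a tokenizer.

```python
def find_label(offsets, sequence_ids, start_char, end_char):
    # First and last token index belonging to the context (sequence id 1).
    context_start = sequence_ids.index(1)
    context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)

    # Answer not fully inside this chunk of context: label is (0, 0).
    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
        return (0, 0)

    # Walk forward to the last token starting at or before start_char.
    idx = context_start
    while idx <= context_end and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1

    # Walk backward to the first token ending at or after end_char.
    idx = context_end
    while idx >= context_start and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return (start_token, end_token)


# [CLS] Which city ? [SEP] The capital of France is Paris . [SEP]
sequence_ids = [None, 0, 0, 0, None, 1, 1, 1, 1, 1, 1, 1, None]
offsets = [
    (0, 0), (0, 5), (6, 10), (10, 11), (0, 0),
    (0, 3), (4, 11), (12, 14), (15, 21), (22, 24), (25, 30), (30, 31),
    (0, 0),
]
# "Paris" spans characters 25..30 of the context.
full_label = find_label(offsets, sequence_ids, start_char=25, end_char=30)

# Same example, but the chunk stops after "is", so the answer is cut off.
truncated_label = find_label(
    offsets[:10] + [(0, 0)], sequence_ids[:10] + [None], start_char=25, end_char=30
)
print(full_label, truncated_label)
```

In the full chunk the label points at the single token covering "Paris"; in the truncated chunk it falls back to `(0, 0)`.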


Let's double-check a few results to verify that our approach is correct.

In [16]:
# the non-zero answer positions
idx = 2
sample_idx = inputs['overflow_to_sample_mapping'][idx]
answer = answers[sample_idx]['text'][0]

start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs['input_ids'][idx][start : end+1])

print(f"Theoretical answer: {answer}")
print(f"labels give: {labeled_answer}")

Theoretical answer: Saint Bernadette Soubirous
labels give: Saint Bernadette Soubirous


In [17]:
# the zero answer positions
idx = 6
sample_idx = inputs['overflow_to_sample_mapping'][idx]
answer = answers[sample_idx]['text'][0]

decoded_example = tokenizer.decode(inputs['input_ids'][idx])
print(f"Theoretical answer: {answer}")
print(f"labels give: {decoded_example}")

Theoretical answer: a copper statue of Christ
labels give: [CLS] What is in front of the Notre Dame Main Building? [SEP] of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern [SEP]


Indeed, the answer does not appear in this chunk of the context, so the `(0, 0)` label is correct.

Now that we have seen step by step how to preprocess our training data, we can group it in a function we will apply on the whole training dataset. We will pad every feature to the maximum length we set, as most of the contexts will be long (and the corresponding samples will be split into several features), so there is no real benefit to applying dynamic padding here:

In [18]:
max_length = 384
stride = 128

def preprocess_training_examples(examples):
    questions = [q.strip() for q in examples['question']]
    inputs = tokenizer(
        questions,
        examples['context'],
        max_length=max_length,
        truncation='only_second',
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding='max_length',
    )

    offset_mapping = inputs.pop('offset_mapping')
    sample_map = inputs.pop('overflow_to_sample_mapping')
    answers = examples['answers']
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer['answer_start'][0]
        end_char = answer['answer_start'][0] + len(answer['text'][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs['start_positions'] = start_positions
    inputs['end_positions'] = end_positions
    return inputs

To apply this function to the whole training set, we use the `Dataset.map()` method with the `batched=True` flag:

In [19]:
train_dataset = raw_datasets['train'].map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_datasets['train'].column_names,
)

len(raw_datasets['train']), len(train_dataset)

(87599, 88729)

As we can see, the preprocessing added more features to the training set.

### Processing the validation data

Preprocessing the validation data is slightly easier because we don't need to generate labels (unless we want to compute a validation loss, but that number will not really help us understand how good the model is).

We keep the offset mappings and record, for each feature, the ID of the example it came from (in an `example_id` column), so that predictions can later be matched back to the original contexts. The only other thing we will add is a tiny cleanup of the offset mappings. They will contain offsets for the question and the context, but once we are in the post-processing stage we do not have any way to know which part of the input IDs corresponded to the context and which part was the question (the `sequence_ids()` method we used is available for the output of the tokenizer only). So we will set the offsets corresponding to the question to `None`:

In [20]:
def preprocess_validation_examples(examples):
    question = [q.strip() for q in examples['question']]
    inputs = tokenizer(
        question,
        examples['context'],
        max_length=max_length,
        truncation='only_second',
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding='max_length',
    )

    sample_map = inputs.pop('overflow_to_sample_mapping')
    example_ids = []

    for i in range(len(inputs['input_ids'])):
        sample_idx = sample_map[i]
        example_ids.append(examples['id'][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs['offset_mapping'][i]
        inputs['offset_mapping'][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs['example_id'] = example_ids
    return inputs
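
The effect of the offset cleanup is easiest to see on a small hand-written feature. The values below are illustrative assumptions (not tokenizer output); the list comprehension is the same one used in `preprocess_validation_examples`:

```python
# [CLS] question question [SEP] context context context [SEP]
sequence_ids = [None, 0, 0, None, 1, 1, 1, None]
offsets = [(0, 0), (0, 5), (6, 10), (0, 0), (0, 3), (4, 11), (12, 14), (0, 0)]

# Keep offsets only for context tokens (sequence id 1); everything else is None.
masked = [o if sequence_ids[k] == 1 else None for k, o in enumerate(offsets)]
print(masked)
```

Without this step, a question token with offsets such as `(0, 5)` would be indistinguishable from a context token covering the first five characters of the context.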

In [None]:
validation_dataset = raw_datasets['validation'].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets['validation'].column_names,
)

len(raw_datasets['validation']), len(validation_dataset)

## Fine-tuning the model with the Trainer API

The most difficult thing is to write the `compute_metrics()` function. Since we padded all the samples to the maximum length we set, there is no data collator to define, so the real work is post-processing the model predictions into spans of text in the original examples.

### Post-processing

The model will output logits for the start and end positions of the answer in the input IDs. The post-processing step will involve:
* mask the start and end logits corresponding to tokens outside of the context
* convert the start and end logits into probabilities using a softmax
* attribute a score to each `(start_token, end_token)` pair by taking the product of the corresponding two probabilities
* look for the pair with the maximum score that yielded a valid answer (e.g., a `start_token` lower than `end_token`)

We will change this process slightly because we don't need to compute actual scores (just the predicted answer). This means that we can skip the softmax step. To go faster, we also won't score all the possible `(start_token, end_token)` pairs, but only the ones corresponding to the highest `n_best` logits (with `n_best=20`). Since we skip the softmax, those scores will be logit scores, obtained by taking the sum of the start and end logits (instead of the product, because of the rule $\log(ab) = \log(a) + \log(b)$).
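
The claim that summing logits gives the same best pair as multiplying softmax probabilities can be checked on a toy feature. This is a minimal sketch with made-up logits. It works because, within one feature, the log of the probability product equals the logit sum minus two constants (the softmax normalizers), so the ranking of pairs is unchanged:

```python
import math
from itertools import product


def softmax(logits):
    # Subtract the max for numerical stability.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for a feature with four tokens.
start_logits = [0.1, 2.0, -1.0, 0.5]
end_logits = [-0.3, 0.2, 1.5, 3.0]
start_probs = softmax(start_logits)
end_probs = softmax(end_logits)

# Only keep pairs where the start does not come after the end.
pairs = [(s, e) for s, e in product(range(4), repeat=2) if s <= e]

best_by_prob = max(pairs, key=lambda p: start_probs[p[0]] * end_probs[p[1]])
best_by_logit = max(pairs, key=lambda p: start_logits[p[0]] + end_logits[p[1]])
print(best_by_prob, best_by_logit)
```

Note that the normalizers differ from one feature to another, so comparing raw logit sums across the features of one example is a convenient approximation rather than an exact probability comparison.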

To demonstrate all of this, we will need some kind of predictions. Since we have not trained our model yet, we are going to use the default model for the QA pipeline to generate some predictions on a small part of the validation set. We can use the same processing function as before; because it relies on the global constant `tokenizer`, we just have to change that object to the tokenizer of the model we want to use temporarily:

In [None]:
small_eval_set = raw_datasets['validation'].select(range(100))

trained_checkpoint = 'distilbert-base-cased-distilled-squad'
tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)

eval_set = small_eval_set.map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets['validation'].column_names,
)

Now that the preprocessing is done, we change the tokenizer back to the one we originally picked:

In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
import torch
from transformers import AutoModelForQuestionAnswering

eval_set_for_model = eval_set.remove_columns(['example_id', 'offset_mapping'])
eval_set_for_model.set_format('torch')

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}

trained_model = AutoModelForQuestionAnswering.from_pretrained(trained_checkpoint).to(device)

with torch.no_grad():
    outputs = trained_model(**batch)

Since the `Trainer` will give us predictions as NumPy arrays, we grab the start and end logits and convert them to that format:

In [25]:
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

Now we need to find the predicted answer for each example in our `small_eval_set`. One example may have been split into several features in `eval_set`, so the first step is to map each example in `small_eval_set` to the corresponding features in `eval_set`:

In [26]:
import collections

example_to_features = collections.defaultdict(list)

for idx, feature in enumerate(eval_set):
    example_to_features[feature['example_id']].append(idx)

With this in hand, we can get to work by looping through all the examples and, for each example, through all the associated features. We will look at the logit scores for the `n_best` start logits and end logits, excluding positions that give:
* an answer that would not be inside the context
* an answer with negative length
* an answer that is too long (we limit the possibilities at `max_answer_length=30`)
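
Here is a small pure-Python sketch of that candidate selection on a toy feature, to show the three exclusion rules in action. The helper names (`top_n_indexes`, `is_valid_span`) and all the numbers are hypothetical. The notebook itself selects the top indices with a reversed `np.argsort` slice, which this sketch replaces with `sorted` for readability:

```python
def top_n_indexes(scores, n_best):
    # Indices of the n_best highest scores, best first.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:n_best]


def is_valid_span(start_index, end_index, offsets, max_answer_length):
    # Rule 1: both ends must be context tokens (offsets not masked to None).
    if offsets[start_index] is None or offsets[end_index] is None:
        return False
    # Rule 2: no negative-length answers.
    if end_index < start_index:
        return False
    # Rule 3: no answers longer than max_answer_length tokens.
    return end_index - start_index + 1 <= max_answer_length


# Toy feature: positions 0, 1 and 6 are not context tokens.
offsets = [None, None, (0, 3), (4, 9), (10, 12), (13, 20), None]
start_logit = [0.1, 5.0, 2.0, 1.0, 0.0, -1.0, 0.2]
end_logit = [0.0, 0.1, -1.0, 3.0, 0.5, 2.5, 4.0]

candidates = []
for s in top_n_indexes(start_logit, n_best=3):
    for e in top_n_indexes(end_logit, n_best=3):
        if is_valid_span(s, e, offsets, max_answer_length=2):
            candidates.append((s, e, start_logit[s] + end_logit[e]))

best = max(candidates, key=lambda c: c[2])
print(candidates, best)
```

Even though the highest start logit sits on a question token and the highest end logit on a special token, both are discarded and the best remaining valid span wins.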

In [28]:
import numpy as np

n_best = 20
max_answer_length = 30
predicted_answers = []

for example in small_eval_set:
    example_id = example['id']
    context = example['context']
    answers = []

    for feature_index in example_to_features[example_id]:
        start_logit = start_logits[feature_index]
        end_logit = end_logits[feature_index]
        offsets = eval_set['offset_mapping'][feature_index]

        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
        for start_index in start_indexes:
            for end_index in end_indexes:
                # Skip answers that are not fully in the context
                if offsets[start_index] is None or offsets[end_index] is None:
                    continue

                # Skip answers with a length that is either <0 or > max_answer_length
                if (
                    end_index < start_index
                    or end_index - start_index + 1 > max_answer_length
                ):
                    continue

                answers.append(
                    {
                        'text': context[offsets[start_index][0] : offsets[end_index][1]],
                        'logit_score': start_logit[start_index] + end_logit[end_index],
                    }
                )

    best_answer = max(answers, key=lambda x: x['logit_score'])
    predicted_answers.append(
        {
            'id': example_id,
            'prediction_text': best_answer['text'],
        }
    )

In [None]:
import evaluate

metric = evaluate.load('squad')

In [34]:
theoretical_answers = [
    {'id': ex['id'],
     'answers': ex['answers']}
    for ex in small_eval_set
]

In [35]:
print(predicted_answers[0])
print(theoretical_answers[0])

{'id': '56be4db0acb8001400a502ec', 'prediction_text': 'Denver Broncos'}
{'id': '56be4db0acb8001400a502ec', 'answers': {'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answer_start': [177, 177, 177]}}


In [36]:
metric.compute(predictions=predicted_answers, references=theoretical_answers)

{'exact_match': 83.0, 'f1': 88.25000000000004}

Normally, the `compute_metrics()` function only receives a tuple `eval_preds` with logits and labels. Here we need a bit more: we have to look in the dataset of features for the offsets and in the dataset of examples for the original contexts, so we will not be able to use this function to get regular evaluation results during training. We will only use it at the end of training to check the results.

In [37]:
from tqdm.auto import tqdm


def compute_metrics(start_logits, end_logits, features, examples):
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

In [38]:
compute_metrics(start_logits, end_logits, eval_set, small_eval_set)

{'exact_match': 83.0, 'f1': 88.25000000000004}

### Fine-tuning the model

In [39]:
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


As usual, we get a warning that some weights are not used (the ones from the pretraining head) and some others are initialized randomly (the ones for the question answering head).

Next is the `TrainingArguments`. As we saw when we defined our function to compute the metric, we will not be able to have a regular evaluation loop because of the signature of the `compute_metrics()` function. Instead, we will only evaluate the model at the end of training here:

In [40]:
from transformers import TrainingArguments

args = TrainingArguments(
    'bert-finetuned-squad',
    evaluation_strategy='no',
    save_strategy='epoch',
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=True,
    push_to_hub=False,
)



In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
)

In [None]:
trainer.train()

In [None]:
predictions, _, _ = trainer.predict(validation_dataset)
start_logits, end_logits = predictions
compute_metrics(start_logits, end_logits, validation_dataset, raw_datasets['validation'])

## A custom training loop

### Preparing everything for training

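First we need to build the `DataLoader`s from our datasets. We set the format of those datasets to `'torch'` and remove the columns in the validation set that are not used by the model. Then we can use the `default_data_collator` as the `collate_fn`, shuffling the training set but not the validation set: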

In [41]:
from torch.utils.data import DataLoader
from transformers import default_data_collator

train_dataset.set_format('torch')
valid_dataset = validation_dataset.remove_columns(['example_id', 'offset_mapping'])
valid_dataset.set_format('torch')

train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    collate_fn=default_data_collator,
    batch_size=8,
)
eval_dataloader = DataLoader(
    valid_dataset,
    collate_fn=default_data_collator,
    batch_size=8,
)

In [42]:
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [43]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=2e-5)

Next, we set up the `Accelerator` and pass the model, optimizer, and dataloaders through `accelerator.prepare()`:

In [44]:
from accelerate import Accelerator

accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader,
)

We can only use the `train_dataloader` length to compute the number of training steps after it has gone through the `accelerator.prepare()` method.
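
To see why the order matters, here is a rough back-of-the-envelope sketch of the step count. The `training_steps` helper is hypothetical, and the division by the number of processes is an approximation of how a prepared dataloader is sharded (it assumes no gradient accumulation and no dropped batches):

```python
import math


def training_steps(num_features, batch_size, num_epochs, num_processes=1):
    # Number of batches in the unsharded dataloader.
    batches = math.ceil(num_features / batch_size)
    # After preparation, each process iterates over roughly its share of batches.
    batches_per_process = math.ceil(batches / num_processes)
    return num_epochs * batches_per_process


# 88,729 training features (from the preprocessing above), batch size 8, 3 epochs.
single_process = training_steps(88729, batch_size=8, num_epochs=3)
two_processes = training_steps(88729, batch_size=8, num_epochs=3, num_processes=2)
print(single_process, two_processes)
```

Computing the length before `accelerator.prepare()` would use the unsharded batch count, which overestimates the number of steps whenever more than one process is used.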

In [45]:
from transformers import get_scheduler

num_train_epochs = 3
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
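
For intuition, the `'linear'` schedule with warmup can be sketched as a plain function of the step number. This is a hand-written approximation for illustration (the `linear_lr` helper is hypothetical), not the library's code: the learning rate ramps up linearly during warmup, then decays linearly to zero at the last training step.

```python
import math


def linear_lr(step, base_lr, num_warmup_steps, num_training_steps):
    # Warmup phase: ramp linearly from 0 up to base_lr.
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    # Decay phase: go linearly from base_lr down to 0.
    remaining = max(0, num_training_steps - step)
    return base_lr * remaining / max(1, num_training_steps - num_warmup_steps)


# Same base learning rate and warmup as above, with a toy total of 100 steps.
for step in (0, 25, 50, 100):
    print(step, linear_lr(step, base_lr=2e-5, num_warmup_steps=0, num_training_steps=100))
```

With `num_warmup_steps=0`, as in this notebook, the schedule starts at the full learning rate and simply decays.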

In [None]:
model_name = "bert-finetuned-squad-accelerate"
output_dir = "bert-finetuned-squad-accelerate"

### Training loop

In [None]:
from tqdm.auto import tqdm
import torch

progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):
    # Training
    model.train()
    for step, batch in enumerate(train_dataloader):
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Evaluation
    model.eval()
    start_logits = []
    end_logits = []
    accelerator.print("Evaluation!")
    for batch in tqdm(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)

        start_logits.append(accelerator.gather(outputs.start_logits).cpu().numpy())
        end_logits.append(accelerator.gather(outputs.end_logits).cpu().numpy())

    start_logits = np.concatenate(start_logits)
    end_logits = np.concatenate(end_logits)
    start_logits = start_logits[: len(validation_dataset)]
    end_logits = end_logits[: len(validation_dataset)]

    metrics = compute_metrics(
        start_logits, end_logits, validation_dataset, raw_datasets["validation"]
    )
    print(f"epoch {epoch}:", metrics)

    # Save the model and tokenizer locally
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)

## Using the fine-tuned model

In [47]:
from transformers import pipeline

# Replace this with your own checkpoint
model_checkpoint = "huggingface-course/bert-finetuned-squad"
question_answerer = pipeline("question-answering", model=model_checkpoint)

In [48]:
context = """
🤗 Transformers is backed by the three most popular deep learning libraries — Jax, PyTorch and TensorFlow — with a seamless integration
between them. It's straightforward to train your models with one before loading them for inference with the other.
"""
question = "Which deep learning libraries back 🤗 Transformers?"
question_answerer(question=question, context=context)

{'score': 0.9978997707366943,
 'start': 78,
 'end': 105,
 'answer': 'Jax, PyTorch and TensorFlow'}
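
The `start` and `end` values in the result are character offsets into the context, so slicing the context with them should recover the answer. A quick check, reusing the first line of the context and the values printed above (this assumes the emoji is a single code point, as it is in a Python string literal):

```python
# First line of the context from the cell above, including the leading newline.
context = (
    "\n🤗 Transformers is backed by the three most popular deep learning "
    "libraries — Jax, PyTorch and TensorFlow — with a seamless integration\n"
)

# Values copied from the pipeline output above.
result = {"start": 78, "end": 105, "answer": "Jax, PyTorch and TensorFlow"}

# `end` is exclusive, so a plain slice gives the answer text.
span = context[result["start"] : result["end"]]
print(repr(span))
```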