<a href="https://colab.research.google.com/github/josephho1013/M2R_Group_25_2023/blob/main/NLP_QnA_20230629.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

References: https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt

# Preparing the data

In [1]:
!pip install datasets evaluate transformers[sentencepiece]



In [2]:
from datasets import load_dataset # import from (Hugging Face) dataset library

raw_datasets = load_dataset("squad") # load SQuAD dataset https://rajpurkar.github.io/SQuAD-explorer/




In [3]:
raw_datasets # take a look at this object

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [4]:
# first element of our training set
print("Context: ", raw_datasets["train"][0]["context"])
print("Question: ", raw_datasets["train"][0]["question"])
print("Answer: ", raw_datasets["train"][0]["answers"])

Context:  Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
Question:  To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answer:  {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}


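The <code>answers</code> field is a dictionary with two list-valued fields: <code>text</code> (the answer strings) and <code>answer_start</code> (the character index in the context at which each answer begins).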
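As a small illustration of how the two fields relate, the sketch below uses a made-up record (the context and answer are invented for this example, not taken from SQuAD) and checks that slicing the context from <code>answer_start</code> recovers the answer text.

```python
# Hypothetical record laid out like a SQuAD example (values invented for illustration)
context = "The Grotto is a replica of the grotto at Lourdes, France."
answer_text = "Lourdes, France"
answers = {"text": [answer_text], "answer_start": [context.index(answer_text)]}

# The answer occupies the half-open character range [start, end) of the context
start = answers["answer_start"][0]
end = start + len(answers["text"][0])
recovered = context[start:end]

print(start, end, repr(recovered))
```

The same <code>start</code> and <code>end</code> character positions are what the labelling code later converts into token positions.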

In [5]:
# verify that every training example has exactly one answer
raw_datasets["train"].filter(lambda x: len(x["answers"]["text"]) != 1)



Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 0
})

In [6]:
# however, validation examples can have several acceptable answers
print(raw_datasets["validation"][0]["answers"])
print(raw_datasets["validation"][2]["answers"])

{'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answer_start': [177, 177, 177]}
{'text': ['Santa Clara, California', "Levi's Stadium", "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California."], 'answer_start': [403, 355, 355]}


In [7]:
print(raw_datasets["validation"][2]["context"])
print(raw_datasets["validation"][2]["question"])

Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.
Where did Super Bowl 50 take place?


## Processing the training data

Step 1: Convert the input text into token IDs

In [8]:
import torch
from transformers import AutoTokenizer

In [9]:
model_checkpoint = "bert-base-cased" # fine-tune a BERT model
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

The tokenizer inserts special tokens so that each input has the following form:\
<code>[CLS] question [SEP] context [SEP]</code>
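A minimal sketch of this layout, assuming whitespace-separated words stand in for real WordPiece tokens. The helper <code>build_pair</code> is hypothetical (it is not part of the tokenizer); its second return value imitates what a fast tokenizer's <code>sequence_ids()</code> reports for a BERT pair: <code>None</code> for special tokens, <code>0</code> for the question, and <code>1</code> for the context.

```python
def build_pair(question_tokens, context_tokens):
    """Hypothetical helper: lay out a question/context pair the way BERT expects."""
    tokens = ["[CLS]"] + question_tokens + ["[SEP]"] + context_tokens + ["[SEP]"]
    # None marks special tokens, 0 the first sequence, 1 the second sequence
    sequence_ids = (
        [None] + [0] * len(question_tokens) + [None] + [1] * len(context_tokens) + [None]
    )
    return tokens, sequence_ids


tokens, sequence_ids = build_pair(["Who", "sat", "?"], ["The", "cat", "sat", "."])

# First and last token positions that belong to the context
context_start = sequence_ids.index(1)
context_end = len(sequence_ids) - 1 - sequence_ids[::-1].index(1)

print(tokens)
print(sequence_ids)
print(context_start, context_end)
```

Locating the context this way is the same idea the labelling code uses when it scans <code>sequence_ids</code> for the value <code>1</code>.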

In [10]:
context = raw_datasets["train"][0]["context"]
question = raw_datasets["train"][0]["question"]

inputs = tokenizer(question, context)
tokenizer.decode(inputs["input_ids"])

'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'

Specifications:
- <code>max_length</code>: maximum number of tokens in each input (question, context chunk, and special tokens combined)
- <code>truncation="only_second"</code>: when the question plus its context exceeds <code>max_length</code>, truncate only the second sequence (the context)
- <code>stride</code>: number of tokens shared by two successive chunks. A higher value preserves more continuity between chunks, but it also increases the number of inputs.
- <code>return_overflowing_tokens=True</code>: tell the tokenizer to return the overflowing tokens as additional chunks instead of discarding them
- <code>return_offsets_mapping=True</code>: include <code>offset_mapping</code> in the tokenizer output (<code>offset_mapping</code> gives, for each token, the start and end character positions of that token in the original text)
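The interaction between the chunk length and <code>stride</code> can be sketched in plain Python. This is a simplified illustration rather than the tokenizer's own code: the hypothetical function <code>chunk_with_stride</code> works on the context tokens only and ignores the question and special tokens that the real tokenizer also counts towards <code>max_length</code>.

```python
def chunk_with_stride(tokens, window, stride):
    """Split tokens into chunks of at most `window` items.

    Consecutive chunks share `stride` items, so each new chunk starts
    `window - stride` positions after the previous one.
    """
    if not 0 <= stride < window:
        raise ValueError("stride must satisfy 0 <= stride < window")
    chunks = []
    start = 0
    while True:
        chunks.append(tokens[start:start + window])
        if start + window >= len(tokens):
            break  # this chunk already reaches the end of the sequence
        start += window - stride
    return chunks


# Integers stand in for context tokens
context_tokens = list(range(10))
chunks = chunk_with_stride(context_tokens, window=4, stride=2)
for chunk in chunks:
    print(chunk)
```

With a window of 4 and a stride of 2, ten tokens produce four chunks, and each chunk repeats the last two tokens of the previous one. Raising the stride towards the window size increases the overlap and therefore the number of chunks.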

In [11]:
def inputs(question, context,*,max_length=100, stride=50):
  """
  Initiate input from tokenization of sample.

  Parameters
  ---------------
  question: Question of the data
  context: Context of the data
  max_length: maximum length in each input
  stride: number of overlapping tokens between two successive chunks
  """
  inputs = tokenizer(
                    question,
                    context,
                    max_length=max_length,
                    truncation="only_second",
                    stride=stride,
                    return_overflowing_tokens=True,
                    return_offsets_mapping=True,
                )
  return inputs

inputs1 = inputs(raw_datasets["train"][0]["question"],
                 raw_datasets["train"][0]["context"])

for ids in inputs1["input_ids"]:
    print(tokenizer.decode(ids))

[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Next to the Main Building is the B

In [12]:
inputs1.keys() # get back usual input IDs, token type IDs and attention mask

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])

In [13]:
inputs1["offset_mapping"][1:2]
# for each token of the second chunk, the (start_char, end_char) span
# it covers in the original question or context string;
# special tokens such as [CLS] and [SEP] map to (0, 0)

[[(0, 0),
  (0, 2),
  (3, 7),
  (8, 11),
  (12, 15),
  (16, 22),
  (23, 27),
  (28, 37),
  (38, 44),
  (45, 47),
  (48, 52),
  (53, 55),
  (56, 59),
  (59, 63),
  (64, 70),
  (70, 71),
  (0, 0),
  (152, 155),
  (156, 160),
  (161, 169),
  (170, 173),
  (174, 180),
  (181, 183),
  (183, 184),
  (185, 187),
  (188, 189),
  (190, 196),
  (197, 203),
  (204, 206),
  (207, 213),
  (214, 218),
  (219, 223),
  (224, 226),
  (226, 229),
  (229, 232),
  (233, 237),
  (238, 241),
  (242, 248),
  (249, 250),
  (250, 251),
  (251, 254),
  (254, 256),
  (257, 259),
  (260, 262),
  (263, 264),
  (264, 265),
  (265, 268),
  (268, 269),
  (269, 270),
  (271, 275),
  (276, 278),
  (279, 282),
  (283, 287),
  (288, 296),
  (297, 299),
  (300, 303),
  (304, 312),
  (313, 315),
  (316, 319),
  (320, 326),
  (327, 332),
  (332, 333),
  (334, 345),
  (346, 352),
  (353, 356),
  (357, 358),
  (358, 361),
  (361, 365),
  (366, 368),
  (369, 372),
  (373, 374),
  (374, 377),
  (377, 379),
  (379, 380),
  (381,

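Next, tokenize several examples at once and label each resulting chunk with the start and end token positions of its answer.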

In [14]:
inputs2 = inputs(raw_datasets["train"][1:5]["question"],
                 raw_datasets["train"][1:5]["context"])

answers = raw_datasets["train"][1:5]["answers"]
start_positions = []
end_positions = []

for i, offset in enumerate(inputs2["offset_mapping"]):
    sample_idx = inputs2["overflow_to_sample_mapping"][i]
    answer = answers[sample_idx]
    start_char = answer["answer_start"][0]
    end_char = answer["answer_start"][0] + len(answer["text"][0])
    sequence_ids = inputs2.sequence_ids(i)

    # Find the start and end of the context
    idx = 0
    while sequence_ids[idx] != 1:
        idx += 1
    context_start = idx
    while sequence_ids[idx] == 1:
        idx += 1
    context_end = idx - 1

    # If the answer is not fully inside the context, label is (0, 0)
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        start_positions.append(0)
        end_positions.append(0)
    else:
        # Otherwise it's the start and end token positions
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

start_positions, end_positions

([53, 17, 0, 0, 83, 51, 19, 0, 0, 64, 27, 0, 34, 0, 0, 0],
 [57, 21, 0, 0, 85, 53, 21, 0, 0, 70, 33, 0, 40, 0, 0, 0])
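To make the labelling rule easier to follow, here is a toy version of the same logic applied to a single hand-written feature. The offsets and the function name <code>char_span_to_token_span</code> are invented for this sketch; the loop bodies mirror the cell above.

```python
def char_span_to_token_span(offset, context_start, context_end, start_char, end_char):
    """Map an answer's character span to (start_token, end_token); (0, 0) if not fully inside."""
    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
        return 0, 0
    # Walk right until a token starts after the answer start, then step back one
    idx = context_start
    while idx <= context_end and offset[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1
    # Walk left until a token ends before the answer end, then step forward one
    idx = context_end
    while idx >= context_start and offset[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return start_token, end_token


# Toy feature: [CLS] Who [SEP] The cat sat [SEP], with context "The cat sat"
toy_offset = [(0, 0), (0, 3), (0, 0), (0, 3), (4, 7), (8, 11), (0, 0)]

inside = char_span_to_token_span(toy_offset, 3, 5, start_char=4, end_char=7)     # "cat"
outside = char_span_to_token_span(toy_offset, 3, 5, start_char=20, end_char=25)  # beyond this chunk
print(inside, outside)
```

The first call lands on token 4 (the word "cat"), while the second falls outside the chunk and is labelled <code>(0, 0)</code>, which is why many chunks in the output above carry zeros.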

In [15]:
# verify first result (53,57)
idx = 0
sample_idx = inputs2["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]

start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs2["input_ids"][idx][start : end + 1])

print(f"Theoretical answer: {answer}, labels give: {labeled_answer}")
# exact match

Theoretical answer: a copper statue of Christ, labels give: a copper statue of Christ


In [16]:
# verify last result (0,0)
idx = -1
sample_idx = inputs2["overflow_to_sample_mapping"][idx]
answer = answers[sample_idx]["text"][0]

decoded_example = tokenizer.decode(inputs2["input_ids"][idx])
print(f"Theoretical answer: {answer}, decoded example: {decoded_example}")
# answer does not appear in the truncated text

Theoretical answer: a golden statue of the Virgin Mary, decoded example: [CLS] What sits on top of the Main Building at Notre Dame? [SEP]to at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]


In [17]:
# define function for preprocessing training data

def preprocess_training_examples(examples, *, max_length=384, stride=128):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

Apply the above function to the whole training set (using <code>Dataset.map()</code> method with <code>batched=True</code>)

In [18]:
train_dataset = raw_datasets["train"].map(
    preprocess_training_examples,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,
)

len(raw_datasets["train"]), len(train_dataset)
# preprocessing added 1,130 features (88,729 - 87,599)
# (since one example can give several training features)



(87599, 88729)

## Preprocessing the validation data

In [19]:
# define function for preprocessing validation data

def preprocess_validation_examples(examples, *, max_length=384, stride=128):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs

Apply the above function to the whole validation set (using the <code>Dataset.map()</code> method with <code>batched=True</code>)

In [20]:
validation_dataset = raw_datasets["validation"].map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)
len(raw_datasets["validation"]), len(validation_dataset)
# preprocessing added only 252 features, so few validation contexts exceed max_length


(10570, 10822)

# Fine-tuning BERT model with <code>Trainer</code> API

The model will output **logits** (raw scores, before they are transformed into probabilities) for the start and end positions of the answer, with one start logit and one end logit for each token in the input IDs.
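As a brief aside, logits are usually turned into probabilities with a softmax. The sketch below uses invented logit values and a hand-written <code>softmax</code> to show that the transformation preserves the ranking of positions, which is why the post-processing later in this notebook can work on the raw logits directly.

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Invented start logits for a four-token input
toy_start_logits = [-1.0, 0.5, 3.0, 0.0]
probs = softmax(toy_start_logits)

best_by_logit = toy_start_logits.index(max(toy_start_logits))
best_by_prob = probs.index(max(probs))
print(probs, best_by_logit, best_by_prob)
```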

Step 1: Use the default model of the question-answering pipeline to generate some predictions on a small part of the validation set.

In [21]:
small_eval_set = raw_datasets["validation"].select(range(100))
trained_checkpoint = "distilbert-base-cased-distilled-squad"

tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)
eval_set = small_eval_set.map(
    preprocess_validation_examples,
    batched=True,
    remove_columns=raw_datasets["validation"].column_names,
)



In [22]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

from transformers import AutoModelForQuestionAnswering

# remove columns of eval_set that are not expected by the model
eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
eval_set_for_model.set_format("torch")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch = {k: eval_set_for_model[k].to(device) for k in eval_set_for_model.column_names}
trained_model = AutoModelForQuestionAnswering.from_pretrained(trained_checkpoint).to(
    device
)

with torch.no_grad():
    outputs = trained_model(**batch)

In [23]:
outputs.keys()

odict_keys(['start_logits', 'end_logits'])

In [24]:
# convert the start and end logits into NumPy arrays
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

Find a predicted answer for each sample in the small evaluation set (<code>small_eval_set</code>).

In [25]:
import collections

# map each example in small_eval_set to the corresponding features in eval_set
example_to_features = collections.defaultdict(list)
for idx, feature in enumerate(eval_set):
    example_to_features[feature["example_id"]].append(idx)

For each example, loop through all of its associated features. For each feature, we look at the logit scores of the <code>n_best</code> start logits and end logits, excluding positions that give:
- An answer that wouldn’t be inside the context
- An answer with negative length
- An answer that is too long (we limit the possibilities at <code>max_answer_length=30</code>)

Once we have all the scored possible answers for one example, we just pick the one with the best logit score.
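The selection rule can be tried on a toy feature first. Everything in this sketch (the logits, the offsets, the context, and the helper name <code>best_span</code>) is invented for illustration, and plain Python sorting replaces NumPy so that the block runs on its own.

```python
def best_span(start_logit, end_logit, offsets, context, n_best, max_answer_length):
    """Return the highest-scoring valid (text, score) pair, or None if no span is valid."""
    def top(scores):
        # Indices of the n_best largest scores, best first
        return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:n_best]

    candidates = []
    for s in top(start_logit):
        for e in top(end_logit):
            if offsets[s] is None or offsets[e] is None:
                continue  # not inside the context
            if e < s or e - s + 1 > max_answer_length:
                continue  # negative length or too long
            text = context[offsets[s][0]:offsets[e][1]]
            candidates.append((text, start_logit[s] + end_logit[e]))
    return max(candidates, key=lambda c: c[1]) if candidates else None


toy_context = "The cat sat"
toy_offsets = [None, (0, 3), (4, 7), (8, 11)]  # position 0 is a special token
toy_start = [0.1, 2.0, 0.5, 1.0]
toy_end = [0.0, 0.3, 3.0, 2.5]

result = best_span(toy_start, toy_end, toy_offsets, toy_context, n_best=2, max_answer_length=2)
print(result)
```

Here the pair (start 1, end 3) is rejected for being too long and (start 3, end 2) for having negative length, so the two-token span starting at position 1 wins.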

In [26]:
import numpy as np

n_best = 20
max_answer_length = 30
predicted_answers = []

for example in small_eval_set:
    example_id = example["id"]
    context = example["context"]
    answers = []

    for feature_index in example_to_features[example_id]:
        start_logit = start_logits[feature_index]
        end_logit = end_logits[feature_index]
        offsets = eval_set["offset_mapping"][feature_index]

        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
        for start_index in start_indexes:
            for end_index in end_indexes:
                # Skip answers that are not fully in the context
                if offsets[start_index] is None or offsets[end_index] is None:
                    continue
                # Skip answers with a length that is either < 0 or > max_answer_length.
                if (
                    end_index < start_index
                    or end_index - start_index + 1 > max_answer_length
                ):
                    continue

                answers.append(
                    {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                )

    best_answer = max(answers, key=lambda x: x["logit_score"])
    predicted_answers.append({"id": example_id, "prediction_text": best_answer["text"]})

Evaluate the predicted answers with the SQuAD metric:

In [27]:
import evaluate # from (Hugging Face) evaluate library

metric = evaluate.load("squad")

Formatting answers for the metric:
- Predicted answers: a list of dictionaries with one key for the ID of the example and one key for the predicted text
- Theoretical answers: a list of dictionaries with one key for the ID of the example and one key for the possible answers
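The two formats, and the idea behind the two scores, can be sketched as follows. This is a simplified re-implementation written for this notebook, not the code of the <code>evaluate</code> library: the id <code>"q1"</code> and the prediction are invented, and the helpers <code>normalize</code>, <code>exact_match</code>, and <code>f1</code> are hypothetical names. It assumes the usual SQuAD convention of lowercasing, stripping punctuation and English articles, and taking the best score over all reference answers.

```python
import collections
import re
import string

PUNCTUATION = set(string.punctuation)


def normalize(text):
    """Lowercase, drop punctuation and articles, and collapse whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, references):
    return float(any(normalize(prediction) == normalize(ref) for ref in references))


def f1(prediction, references):
    def pair_f1(pred, ref):
        pred_tokens, ref_tokens = normalize(pred).split(), normalize(ref).split()
        overlap = sum((collections.Counter(pred_tokens) & collections.Counter(ref_tokens)).values())
        if overlap == 0:
            return 0.0
        precision, recall = overlap / len(pred_tokens), overlap / len(ref_tokens)
        return 2 * precision * recall / (precision + recall)

    return max(pair_f1(prediction, ref) for ref in references)


# One prediction and one reference in the layout the metric expects (id is made up)
toy_predictions = [{"id": "q1", "prediction_text": "the Denver Broncos team"}]
toy_references = [{"id": "q1", "answers": {"text": ["Denver Broncos"], "answer_start": [177]}}]

pred = toy_predictions[0]["prediction_text"]
refs = toy_references[0]["answers"]["text"]
em_score, f1_score = exact_match(pred, refs), f1(pred, refs)
print(em_score, f1_score)
```

The prediction is not an exact match because of the extra word "team", but two of its three normalized tokens overlap with the reference, so it still earns partial credit under F1.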

In [28]:
metric

EvaluationModule(name: "squad", module_type: "metric", features: {'predictions': {'id': Value(dtype='string', id=None), 'prediction_text': Value(dtype='string', id=None)}, 'references': {'id': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}}, usage: """
Computes SQuAD scores (F1 and EM).
Args:
    predictions: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair as given in the references (see below)
        - 'prediction_text': the text of the answer
    references: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair (see above),
        - 'answers': a Dict in the SQuAD dataset format
            {
                'text': list of possible texts for the answer, as a list of strings
                'answer_start': list of start positions for 

In [29]:
theoretical_answers = [
    {"id": ex["id"], "answers": ex["answers"]} for ex in small_eval_set
]

In [30]:
# metric scores
metric.compute(predictions=predicted_answers, references=theoretical_answers)

{'exact_match': 83.0, 'f1': 88.25000000000004}

In [31]:
# compute_metrics function

from tqdm.auto import tqdm


def compute_metrics(start_logits, end_logits, features, examples):
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

In [32]:
# verify with our predictions
compute_metrics(start_logits, end_logits, eval_set, small_eval_set)


{'exact_match': 83.0, 'f1': 88.25000000000004}

## Fine-tuning the model

In [33]:
# create our model
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForQuestionAnswering: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['qa_outputs

Set up hyperparameters in <code>TrainingArguments</code>:

In [34]:
!pip install transformers[torch]
!pip install accelerate -U

import accelerate
print(accelerate.__version__)

0.20.3


In [35]:
from transformers import TrainingArguments, Trainer

# set up hyperparameters
args = TrainingArguments(
    "bert-finetuned-squad",
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=True,
    #push_to_hub=True,
)

# launch training
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
)
trainer.train()
# use a GPU runtime for time-efficient training (it still takes > 2 h on this laptop)
# results saved in bert-finetuned-squad/runs

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
500,2.5011
1000,1.5882
1500,1.4649
2000,1.3556
2500,1.2976
3000,1.2347
3500,1.2533
4000,1.2163
4500,1.2095
5000,1.1394


TrainOutput(global_step=33276, training_loss=0.8290525221569576, metrics={'train_runtime': 7987.3415, 'train_samples_per_second': 33.326, 'train_steps_per_second': 4.166, 'total_flos': 5.216534983896422e+16, 'train_loss': 0.8290525221569576, 'epoch': 3.0})
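The reported <code>global_step</code> can be cross-checked with a little arithmetic. The sketch assumes a batch size of 8, which is the default <code>per_device_train_batch_size</code> when none is passed to <code>TrainingArguments</code>, and a single device with no gradient accumulation; the other numbers come from the outputs above.

```python
import math

num_features = 88729   # length of train_dataset after preprocessing
batch_size = 8         # assumption: Trainer default, single device, no gradient accumulation
num_epochs = 3
train_runtime_s = 7987.3415

steps_per_epoch = math.ceil(num_features / batch_size)
total_steps = steps_per_epoch * num_epochs
runtime_hours = train_runtime_s / 3600
samples_per_second = num_features * num_epochs / train_runtime_s

print(steps_per_epoch, total_steps, round(runtime_hours, 2), round(samples_per_second, 3))
```

Under these assumptions the step count matches the value reported by the trainer, and the runtime works out to a little over two hours.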

In [36]:
predictions, _, _ = trainer.predict(validation_dataset)
start_logits, end_logits = predictions
compute_metrics(start_logits, end_logits, validation_dataset, raw_datasets["validation"])


{'exact_match': 81.28666035950805, 'f1': 88.85165588263959}