# Fine-Tuning BERT on SQuAD v1.0 and TriviaQA

The TriviaQA Dataset is also a question/answer dataset similar to the SQuAD dataset. The paper states that the F1 and EM score improves when finetuning the BERT large model first on the larger TriviaQA dataset and then on the SQuAD dataset.

## 0. Configuration

In [1]:
import os
# Store the huggingface data in a shared group folder on the provided JupyterLab instance.
os.environ['HF_HOME'] = '../../groups/192.039-2024W/bert/huggingface/cache'

In [2]:
from pathlib import Path
from transformers import set_seed

# RANDOMNESS SEED
SEED = 42
set_seed(SEED)

# Which datasets to load
DATASET_NAME_SQUAD = "squad"
DATASET_NAME_TRIVIAQA = "trivia_qa"

TRAIN_OUTPUT_DIR = (
    Path("../../groups/192.039-2024W/bert") / "training" / f"{DATASET_NAME_SQUAD}-{DATASET_NAME_TRIVIAQA}"
)

BATCH_SIZE = 32  # Original Paper claims to use 32 for the SQuAD task
NUM_EPOCHS = 3  # Original Paper claims to use 3 fine-tuning epochs for the SQuAD task

In [3]:
import torch

if torch.cuda.is_available():
  device = torch.device("cuda")
  device_count = torch.cuda.device_count()
  device_name = torch.cuda.get_device_name(0)

  print(f"There are {device_count} GPU(s) available.")
  print(f"GPU used: {device_name}")
  ! nvidia-smi -q --display=MEMORY,COMPUTE

else:
  print("No GPU available, using CPU.")
  device = torch.device("cpu")

There are 1 GPU(s) available.
GPU used: NVIDIA A40


Timestamp                                 : Tue Jan 28 13:47:49 2025
Driver Version                            : 550.90.07
CUDA Version                              : 12.4

Attached GPUs                             : 1
GPU 00000000:05:00.0
    FB Memory Usage
        Total                             : 46068 MiB
        Reserved                          : 665 MiB
        Used                              : 23791 MiB
        Free                              : 21613 MiB
    BAR1 Memory Usage
        Total                             : 65536 MiB
        Used                              : 4 MiB
        Free                              : 65532 MiB
    Conf Compute Protected Memory Usage
        Total                             : 0 MiB
        Used                              : 0 MiB
        Free                              : 0 MiB
    Compute Mode                          : Default



## 1. Datasets

In [4]:
from datasets import load_dataset
import pandas as pd

dataset_triviaqa = load_dataset("mandarjoshi/trivia_qa", "rc")
dataset_triviaqa

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/24 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['question', 'question_id', 'question_source', 'entity_pages', 'search_results', 'answer'],
        num_rows: 138384
    })
    validation: Dataset({
        features: ['question', 'question_id', 'question_source', 'entity_pages', 'search_results', 'answer'],
        num_rows: 17944
    })
    test: Dataset({
        features: ['question', 'question_id', 'question_source', 'entity_pages', 'search_results', 'answer'],
        num_rows: 17210
    })
})

Since the TriviaQA dataset is quite large and it will take a lot of hours for training, we will use the first 50000 samples of it.

In [5]:
small_train_dataset_triviaqa = dataset_triviaqa["train"].select(range(50000))
len(small_train_dataset_triviaqa)

50000

In [6]:
dataset_squad = load_dataset(DATASET_NAME_SQUAD)
dataset_squad

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

## 2. Finetuning on TriviaQA

In [7]:
PRE_TRAINED_CHECKPOINT = "google-bert/bert-large-uncased"

### 2.1 Tokenization

In [8]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_CHECKPOINT)

#### 2.1.1 Preprocessing training dataset

In [9]:
def preprocess_triviaqa_training_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    contexts = []
    for entity_page, search_result in zip(examples["entity_pages"], examples["search_results"]):
        wiki_context = entity_page.get("wiki_context", [])
        search_context = search_result.get("search_context", [])
        if wiki_context:
            context = wiki_context[0].strip()
        elif search_context:
            context = search_context[0].strip()
        else:
            context = ""
        contexts.append(context)

    inputs = tokenizer(
        questions,
        contexts,
        max_length=512,
        truncation="only_second",
        stride=min(128, tokenizer.model_max_length // 2),
        padding="max_length",
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
    )

    start_positions = []
    end_positions = []
    answers = examples["answer"]
    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]["value"]
        ct = contexts[sample_idx]
        answer_start = context.find(answer)
        answer_end = answer_start + len(answer)
        sequence_ids = inputs.sequence_ids(i)

        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        if offset[context_start][0] > answer_start or offset[context_end][1] < answer_end:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= answer_start:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= answer_end:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
        
    return inputs

In [10]:
train_dataset_triviaqa = small_train_dataset_triviaqa.map(
    preprocess_triviaqa_training_examples,
    batched=True,
    remove_columns=small_train_dataset_triviaqa.column_names,
)

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [11]:
len(small_train_dataset_triviaqa), len(train_dataset_triviaqa)

(50000, 840275)

In [12]:
import pandas as pd

with pd.option_context('display.max_colwidth', 400):
    display(pd.DataFrame(train_dataset_triviaqa[:1]).transpose())

Unnamed: 0,0
input_ids,"[101, 2029, 2137, 1011, 2141, 11881, 2180, 1996, 10501, 3396, 2005, 3906, 1999, 4479, 1029, 102, 1996, 10501, 3396, 1999, 3906, 4479, 1996, 10501, 3396, 1999, 3906, 4479, 11881, 4572, 1996, 10501, 3396, 1999, 3906, 4479, 11881, 4572, 3396, 3745, 1024, 1015, 1013, 1015, 1996, 10501, 3396, 1999, 3906, 4479, 2001, 3018, 2000, 11881, 4572, 1000, 2005, 2010, 21813, 1998, 8425, 2396, 1997, 6412, 199..."
token_type_ids,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"
attention_mask,"[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"
start_positions,0
end_positions,0


In [13]:
import pandas as pd

with pd.option_context('display.max_colwidth', 400):
    display(pd.DataFrame(small_train_dataset_triviaqa[:1]).transpose())

Unnamed: 0,0
question,Which American-born Sinclair won the Nobel Prize for Literature in 1930?
question_id,tc_1
question_source,http://www.triviacountry.com/
entity_pages,"{'doc_source': [], 'filename': [], 'title': [], 'wiki_context': []}"
search_results,"{'description': ['The Nobel Prize in Literature 1930 Sinclair ... The Nobel Prize in Literature 1930 was awarded to ... nobelprize.org/nobel_prizes/literature/laureates/1930/>', 'Why Don’t More Americans Win the Nobel Prize? By . ... When the Nobel Prize in Literature was awarded to Sinclair ... In 1930, Lewis told his Nobel audience that ...', '... Sauk Centre native Sinclair Lewis became the..."
answer,"{'aliases': ['(Harry) Sinclair Lewis', 'Harry Sinclair Lewis', 'Lewis, (Harry) Sinclair', 'Grace Hegger', 'Sinclair Lewis'], 'normalized_aliases': ['grace hegger', 'lewis harry sinclair', 'harry sinclair lewis', 'sinclair lewis'], 'matched_wiki_entity_name': '', 'normalized_matched_wiki_entity_name': '', 'normalized_value': 'sinclair lewis', 'type': 'WikipediaEntity', 'value': 'Sinclair Lewis'}"


#### 2.1.2 Preprocessing validation dataset

After preprocessing the training dataset, we will preprocess the validation dataset. This differs a little bit from the preprocessing of the training dataset, because we do not need to generate labels. This would only be necessary when we want to compute a validation loss, but since that number does not really tell us if the model is good or not, we will not compute it.

In [14]:
def preprocess_triviaqa_validation_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    contexts = []
    for entity_page, search_result in zip(examples["entity_pages"], examples["search_results"]):
        wiki_context = entity_page.get("wiki_context", [])
        search_context = search_result.get("search_context", [])
        if wiki_context:
            context = wiki_context[0].strip()
        elif search_context:
            context = search_context[0].strip()
        else:
            context = ""
        contexts.append(context)

    inputs = tokenizer(
        questions,
        contexts,
        max_length=512,
        truncation="only_second",
        stride=min(128, tokenizer.model_max_length // 2),
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    
    # Map each overflowed tokenization back to its corresponding example
    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["question_id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    # Add example IDs for evaluation
    inputs["example_id"] = example_ids
    return inputs


In [15]:
validation_dataset_triviaqa = dataset_triviaqa["validation"].map(
    preprocess_triviaqa_validation_examples,
    batched=True,
    remove_columns=dataset_triviaqa["validation"].column_names,
)

In [16]:
len(dataset_triviaqa["validation"]), len(validation_dataset_triviaqa)

(17944, 296523)

In [17]:
test_dataset_triviaqa = dataset_triviaqa["test"].map(
    preprocess_triviaqa_validation_examples,
    batched=True,
    remove_columns=dataset_triviaqa["test"].column_names,
)

In [18]:
len(dataset_triviaqa["test"]), len(test_dataset_triviaqa)

(17210, 284990)

### 2.2 Metrics

In [19]:
import evaluate

# Since there is no specific metric for the TriviaQA dataset, we will load the metric for the SQuAD dataset.
# We can do that, because they are both question/answering datasets and we are preprocessing them the same way.
metric = evaluate.load(DATASET_NAME_SQUAD)
metric

EvaluationModule(name: "squad", module_type: "metric", features: {'predictions': {'id': Value(dtype='string', id=None), 'prediction_text': Value(dtype='string', id=None)}, 'references': {'id': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}}, usage: """
Computes SQuAD scores (F1 and EM).
Args:
    predictions: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair as given in the references (see below)
        - 'prediction_text': the text of the answer
    references: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair (see above),
        - 'answers': a Dict in the SQuAD dataset format
            {
                'text': list of possible texts for the answer, as a list of strings
                'answer_start': list of start positions for 

In [20]:
import collections
import numpy as np
from tqdm.auto import tqdm

n_best = 20
max_answer_length = 30

def compute_metrics_triviaqa(eval_pred):
    predictions, _ = eval_pred
    start_logits, end_logits = predictions
    features = validation_dataset_triviaqa
    examples = dataset_triviaqa["validation"]
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["question_id"]
        contexts = []
        for entity_page, search_result in zip(examples["entity_pages"], examples["search_results"]):
            wiki_context = entity_page.get("wiki_context", [])
            search_context = search_result.get("search_context", [])
            if wiki_context:
                context = wiki_context[0].strip()
            elif search_context:
                context = search_context[0].strip()
            else:
                context = ""
            contexts.append(context)
            answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

### 2.3 Training

In [21]:
from transformers import AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained(PRE_TRAINED_CHECKPOINT)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at google-bert/bert-large-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [22]:
from transformers import TrainingArguments

training_arguments = TrainingArguments(
    output_dir=(TRAIN_OUTPUT_DIR / PRE_TRAINED_CHECKPOINT.replace("/", "_")).resolve(),
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    num_train_epochs=NUM_EPOCHS,
    learning_rate=2e-5,
    weight_decay=0.01,
    save_total_limit=2,
    fp16=True,
)

In [23]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset_triviaqa,
    eval_dataset=validation_dataset_triviaqa,
    processing_class=tokenizer,
    compute_metrics=compute_metrics_triviaqa
)

In [24]:
torch.cuda.empty_cache()

print(f"--- {training_arguments.output_dir=}")
training_summary_bert_large_triviaqa = trainer.train()

--- training_arguments.output_dir='/home/e12433721/groups/192.039-2024W/bert/training/squad-trivia_qa/google-bert_bert-large-uncased'


Epoch,Training Loss,Validation Loss
1,4.0208,No log
2,6.2395,No log
3,6.2395,No log


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [25]:
training_summary_bert_large_triviaqa

TrainOutput(global_step=157554, training_loss=5.499922447224444, metrics={'train_runtime': 68380.4298, 'train_samples_per_second': 36.865, 'train_steps_per_second': 2.304, 'total_flos': 2.3411078034310656e+18, 'train_loss': 5.499922447224444, 'epoch': 3.0})

In [26]:
# Define the save directory for the fine-tuned model
triviaqa_model_dir = (TRAIN_OUTPUT_DIR / "triviaqa_model").resolve()

# Save the fine-tuned model and tokenizer
model.save_pretrained(triviaqa_model_dir)
tokenizer.save_pretrained(triviaqa_model_dir)

print(f"Model and tokenizer saved to: {triviaqa_model_dir}")


Model and tokenizer saved to: /home/e12433721/groups/192.039-2024W/bert/training/squad-trivia_qa/triviaqa_model


### 2.4 Evaluation

In [27]:
training_history_bert_large_triviaqa = pd.DataFrame(trainer.state.log_history)
training_history_bert_large_triviaqa.epoch = training_history_bert_large_triviaqa.epoch.astype(int)
training_history_bert_large_triviaqa.groupby("epoch").first()

Unnamed: 0_level_0,loss,grad_norm,learning_rate,step,eval_runtime,eval_samples_per_second,eval_steps_per_second,train_runtime,train_samples_per_second,train_steps_per_second,total_flos,train_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
1,4.0208,inf,1.333613e-05,52518,2045.2202,144.983,9.062,,,,,
2,6.2395,2.873855,6.671998e-06,105036,2026.404,146.33,9.146,,,,,
3,6.2395,2.443461,8.124199e-09,157554,2018.2761,146.919,9.183,68380.4298,36.865,2.304,2.341108e+18,5.499922


## 3. Finetuning on SQuAD

In [28]:
PRE_TRAINED_CHECKPOINT = triviaqa_model_dir

### 3.1 Tokenization

In [29]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_CHECKPOINT)

#### 3.1.1 Preprocessing training dataset

In [30]:
def preprocess_squad_training_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [31]:
train_dataset_squad = dataset_squad["train"].map(
    preprocess_squad_training_examples,
    batched=True,
    remove_columns=dataset_squad["train"].column_names,
)

Map:   0%|          | 0/87599 [00:00<?, ? examples/s]

#### 3.1.2 Validation dataset

In [32]:
def preprocess_squad_validation_examples(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs

In [33]:
validation_dataset_squad = dataset_squad["validation"].map(
    preprocess_squad_validation_examples,
    batched=True,
    remove_columns=dataset_squad["validation"].column_names,
)

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

### 3.2 Metrics

In [34]:
import collections
import numpy as np
from tqdm.auto import tqdm

n_best = 20
max_answer_length = 30

def compute_metrics_squad(eval_pred):
    predictions, _ = eval_pred
    start_logits, end_logits = predictions
    features = validation_dataset_squad
    examples = dataset_squad["validation"]
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

### 3.3 Training

In [35]:
from transformers import AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained(PRE_TRAINED_CHECKPOINT)

In [40]:
from transformers import TrainingArguments

training_arguments = TrainingArguments(
    output_dir=(TRAIN_OUTPUT_DIR / PRE_TRAINED_CHECKPOINT).resolve(),
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    num_train_epochs=NUM_EPOCHS,
    learning_rate=2e-5,  # Original paper uses 5e-5
    weight_decay=0.01,
    save_total_limit=3,
)

In [41]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset_squad,
    eval_dataset=validation_dataset_squad,
    processing_class=tokenizer,
    compute_metrics=compute_metrics_squad
)

In [42]:
torch.cuda.empty_cache()

print(f"--- {training_arguments.output_dir=}")
training_summary_bert_large_squad = trainer.train()

--- training_arguments.output_dir='/home/e12433721/groups/192.039-2024W/bert/training/squad-trivia_qa/triviaqa_model'


Epoch,Training Loss,Validation Loss
1,5.9508,No log
2,5.9509,No log
3,5.951,No log


In [43]:
training_summary_bert_large_squad

TrainOutput(global_step=8301, training_loss=5.9509137641549215, metrics={'train_runtime': 4572.3403, 'train_samples_per_second': 58.082, 'train_steps_per_second': 1.815, 'total_flos': 1.849789299850629e+17, 'train_loss': 5.9509137641549215, 'epoch': 3.0})

In [44]:
training_history_bert_large_squad = pd.DataFrame(trainer.state.log_history)
training_history_bert_large_squad.epoch = training_history_bert_large_squad.epoch.astype(int)
training_history_bert_large_squad.groupby("epoch").first()

Unnamed: 0_level_0,loss,grad_norm,learning_rate,step,eval_runtime,eval_samples_per_second,eval_steps_per_second,train_runtime,train_samples_per_second,train_steps_per_second,total_flos,train_loss
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
1,5.9508,1.369943,1.333333e-05,2767,48.6471,221.678,6.927,,,,,
2,5.9509,1.302679,6.669076e-06,5534,48.7133,221.377,6.918,,,,,
3,5.951,1.335975,7.228045e-09,8301,48.7052,221.414,6.919,4572.3403,58.082,1.815,1.849789e+17,5.950914


In [45]:
import seaborn as sns

data = training_history_bert_large_squad[["eval_f1", "eval_exact_match", "epoch"]]
data.columns = ["F1", "EM", "Training Epoch"]
data = data[:-1]
data = pd.melt(data, ['Training Epoch']).dropna()

plot = sns.lineplot(data=data, x="Training Epoch", y="value", hue="variable", style="variable", markers=True)
plot.set_ylabel("")
plot.set(xticks=list(set(training_history_bert_large_squad.epoch)))
plot.set_ylim((0, plot.get_ylim()[1]))
plot.legend(title="")


from IPython.display import Markdown, display
display(Markdown(f"### Loss and Evaluation Metrics over Training Epochs ({PRE_TRAINED_CHECKPOINT})"))

KeyError: "['eval_f1', 'eval_exact_match'] not in index"

### 3.4 Evaluation

In [46]:
import evaluate

# Since there is no specific metric for the TriviaQA dataset, we will load the metric for the SQuAD dataset.
# We can do that, because they are both question/answering datasets and we are preprocessing them the same way.
metric = evaluate.load(DATASET_NAME_SQUAD)
metric

EvaluationModule(name: "squad", module_type: "metric", features: {'predictions': {'id': Value(dtype='string', id=None), 'prediction_text': Value(dtype='string', id=None)}, 'references': {'id': Value(dtype='string', id=None), 'answers': Sequence(feature={'text': Value(dtype='string', id=None), 'answer_start': Value(dtype='int32', id=None)}, length=-1, id=None)}}, usage: """
Computes SQuAD scores (F1 and EM).
Args:
    predictions: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair as given in the references (see below)
        - 'prediction_text': the text of the answer
    references: List of question-answers dictionaries with the following key-values:
        - 'id': id of the question-answer pair (see above),
        - 'answers': a Dict in the SQuAD dataset format
            {
                'text': list of possible texts for the answer, as a list of strings
                'answer_start': list of start positions for 

In [51]:
import collections
import numpy as np
from tqdm.auto import tqdm

n_best = 20
max_answer_length = 30

def compute_metrics_triviaqa_test(predictions):
    start_logits, end_logits = predictions
    features = test_dataset_triviaqa
    examples = dataset_triviaqa["test"]
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["question_id"]
        contexts = []
        for entity_page, search_result in zip(examples["entity_pages"], examples["search_results"]):
            wiki_context = entity_page.get("wiki_context", [])
            search_context = search_result.get("search_context", [])
            if wiki_context:
                context = wiki_context[0].strip()
            elif search_context:
                context = search_context[0].strip()
            else:
                context = ""
            contexts.append(context)
            answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

In [54]:
small_test_dataset_triviaqa = dataset_triviaqa["test"].select(range(500))

In [55]:
small_tokenized_test_dataset_triviaqa = small_test_dataset_triviaqa.map(
    preprocess_triviaqa_validation_examples,
    batched=True,
    remove_columns=dataset_triviaqa["test"].column_names,
)

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [57]:
predictions, _, _ = trainer.predict(small_tokenized_test_dataset_triviaqa)
scores_bert_large_squad = compute_metrics_triviaqa_test(predictions)

  0%|          | 0/17210 [00:00<?, ?it/s]

IndexError: index 154492 is out of bounds for axis 0 with size 9011

In [None]:
from IPython.display import Markdown, display

display(Markdown(f"### Model performance:"))
final_results = pd.DataFrame(
    [scores_bert_large_squad["exact_match"]] + [scores_bert_large_squad["f1"]],
    index=["EM"] + ["F1"],
    columns=["our BERT_LARGE on TriviaQA Dev"],
)

# Achieved scores from original BERT paper:
final_results["original BERT_LARGE on TriviaQA Dev"] = [84.2,91.1]
final_results["our BERT_LARGE on TriviaQA Test"] = [training_history_bert_large_squad["eval_exact_match"] + training_history_bert_large_squad["eval_f1"]]
final_results["original BERT_LARGE on TriviaQA Test"] = [85.1, 91.8]

print(
    '"BERT_LARGE" performance on the TriviaQA and SQuAD dataset as reported in the original paper.'
)
results