# TyDiQA

In this notebook, we will see how to fine-tune and evaluate a model on the TyDiQA dataset.

# Dependencies

If you have not already done so, install PrimeQA with the `notebooks` extras before getting started.

In [None]:
# If you want CUDA 11 uncomment and run this (for CUDA 10 or CPU you can ignore this line).
#! pip install 'torch~=1.11.0' --extra-index-url https://download.pytorch.org/whl/cu113

# Uncomment to install PrimeQA from source (PyPI package pending).
# The path should be the project root (e.g. '.' below).
#! pip install .[notebooks]

# Configuration

We start by setting the parameters that configure the run. Depending on the GPU you use, you may need to adjust the batch size so that training fits in memory.

In [1]:
# This needs to be filled in.
output_dir = '/Users/maltak/code/PQA_models/model_1'        # Save the results here.  Will overwrite if directory already exists.

# Optional parameters (feel free to leave as default).
model_name = 'xlm-roberta-base'  # Set this to select the LM.  Since this is a multilingual dataset, we use XLM-RoBERTa.
cache_dir = None                 # Set this if you have a cache directory for transformers.  Alternatively set the HF_HOME env var.
train_batch_size = 8             # Set this to change the number of features per batch during training.
eval_batch_size = 8              # Set this to change the number of features per batch during evaluation.
gradient_accumulation_steps = 8  # Set this to effectively increase training batch size.
max_train_samples = 100          # Set this to use a subset of the training data (or None for all).
max_eval_samples = 20            # Set this to use a subset of the evaluation data (or None for all).
num_train_epochs = 1             # Set this to change the number of training epochs.
fp16 = False                     # Set this to True to enable fp16 mixed precision (hardware support required).
filter_language = ['english', 'arabic']    # Set this to only use examples in these languages (or None for all).
num_examples_to_show = 10        # Set this to change the number of random train examples (and their features) to show.

In [2]:
from transformers import TrainingArguments
from transformers.trainer_utils import set_seed

seed = 42
set_seed(seed)

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=eval_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    num_train_epochs=num_train_epochs,
    evaluation_strategy='no',
    learning_rate=4e-05,
    warmup_ratio=0.1,
    weight_decay=0.1,
    save_steps=50000,
    fp16=fp16,
    seed=seed,
)

  from .autonotebook import tqdm as notebook_tqdm


ImportError: cannot import name 'COMMON_SAFE_ASCII_CHARACTERS' from 'charset_normalizer.constant' (/Users/maltak/anaconda3/envs/primeqa/lib/python3.9/site-packages/charset_normalizer/constant.py)
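
To make the `learning_rate` and `warmup_ratio` settings above more concrete, here is a minimal sketch of a linear warmup followed by a linear decay, which is the shape of the default `linear` scheduler in `transformers`. The helper `linear_schedule_lr` and the 1000-step run are hypothetical and only for illustration; the real number of steps depends on the dataset size and batch settings, and rounding the warmup length up is an assumption about the library's behavior.

```python
import math

def linear_schedule_lr(step, total_steps, base_lr=4e-05, warmup_ratio=0.1):
    """Learning rate at `step` under linear warmup then linear decay (illustrative)."""
    # Assumption: the warmup length is the ratio of total steps, rounded up.
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    if step < warmup_steps:
        # Ramp up from 0 to base_lr over the warmup steps.
        return base_lr * step / max(1, warmup_steps)
    # Decay linearly from base_lr to 0 over the remaining steps.
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))

total_steps = 1000  # hypothetical run length
for step in (0, 50, 100, 550, 1000):
    print(step, linear_schedule_lr(step, total_steps))
```

The rate peaks at `learning_rate` once warmup ends (step 100 here) and reaches zero at the final step.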

# Loading the Model

Here we load the model and tokenizer selected by the `model_name` parameter set above. The model has an extractive QA task head, which we will fine-tune later.

In [4]:
from transformers import AutoConfig, AutoTokenizer
from primeqa.mrc.models.heads.extractive import EXTRACTIVE_HEAD
from primeqa.mrc.models.task_model import ModelForDownstreamTasks

from primeqa.mrc.trainers.mrc import MRCTrainer

task_heads = EXTRACTIVE_HEAD
config = AutoConfig.from_pretrained(
    model_name,
    cache_dir=cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    cache_dir=cache_dir,
    use_fast=True,
    config=config,
)
model = ModelForDownstreamTasks.from_config(
    config,
    model_name,
    task_heads=task_heads,
    cache_dir=cache_dir,
)
model.set_task_head(next(iter(task_heads)))

print(model)  # Examine the model structure

Downloading config.json: 100%|██████████| 615/615 [00:00<00:00, 500kB/s]
Downloading tokenizer_config.json: 100%|██████████| 25.0/25.0 [00:00<00:00, 17.9kB/s]
Downloading (…)tencepiece.bpe.model: 100%|██████████| 5.07M/5.07M [00:00<00:00, 10.8MB/s]
Downloading tokenizer.json: 100%|██████████| 9.10M/9.10M [00:01<00:00, 5.92MB/s]
Downloading pytorch_model.bin: 100%|██████████| 1.12G/1.12G [00:42<00:00, 26.3MB/s]


{"time":"2024-05-08 16:14:55,342", "name": "ExtractiveQAHead", "level": "INFO", "message": "Loading dropout value 0.1 from config attribute 'hidden_dropout_prob'"}


Some weights of XLMRobertaModelForDownstreamTasks were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['task_heads.qa_head.classifier.out_proj.bias', 'task_heads.qa_head.classifier.dense.bias', 'task_heads.qa_head.classifier.dense.weight', 'task_heads.qa_head.qa_outputs.bias', 'task_heads.qa_head.qa_outputs.weight', 'task_heads.qa_head.classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


{"time":"2024-05-08 16:14:55,610", "name": "XLMRobertaModelForDownstreamTasks", "level": "INFO", "message": "Setting task head for first time to 'None'"}
XLMRobertaModelForDownstreamTasks(
  (roberta): XLMRobertaModel(
    (embeddings): XLMRobertaEmbeddings(
      (word_embeddings): Embedding(250002, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): XLMRobertaEncoder(
      (layer): ModuleList(
        (0): XLMRobertaLayer(
          (attention): XLMRobertaAttention(
            (self): XLMRobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(

# Loading Data

Here we load the TyDiQA dataset using the Hugging Face `datasets` library.

In [5]:
import datasets

raw_datasets = datasets.load_dataset(
    'tydiqa',
    'primary_task',
    cache_dir=cache_dir,
)

if filter_language:
    raw_datasets = raw_datasets.filter(lambda example: example['language'] in filter_language)

train_examples = raw_datasets["train"]
if max_train_samples is not None:
    # Keep only the first max_train_samples examples (capped at the dataset size).
    train_examples = train_examples.select(range(min(max_train_samples, train_examples.num_rows)))

print(f"Using {train_examples.num_rows} train examples.")

eval_examples = raw_datasets["validation"]
if max_eval_samples is not None:
    # Keep only the first max_eval_samples examples (capped at the dataset size).
    eval_examples = eval_examples.select(range(min(max_eval_samples, eval_examples.num_rows)))

print(f"Using {eval_examples.num_rows} eval examples.")

Downloading builder script: 13.3kB [00:00, 21.2MB/s]                   
Downloading metadata: 6.73kB [00:00, 12.1MB/s]                   


Downloading and preparing dataset tydiqa/primary_task (download: 1.82 GiB, generated: 5.62 GiB, post-processed: Unknown size, total: 7.44 GiB) to /Users/maltak/.cache/huggingface/datasets/tydiqa/primary_task/1.0.0/b8a6c4c0db10bf5703d7b36645e5dbae821b8c0e902dac9daeecd459a8337148...


Downloading data: 100%|██████████| 1.73G/1.73G [01:32<00:00, 18.8MB/s]
Downloading data: 100%|██████████| 161M/161M [00:09<00:00, 17.2MB/s]
Downloading data files: 100%|██████████| 2/2 [01:43<00:00, 51.81s/it]
Extracting data files: 100%|██████████| 2/2 [00:20<00:00, 10.23s/it]
Downloading data: 100%|██████████| 58.0M/58.0M [00:03<00:00, 15.9MB/s]
Downloading data: 100%|██████████| 5.62M/5.62M [00:01<00:00, 3.40MB/s]
Downloading data files: 100%|██████████| 2/2 [00:07<00:00,  3.63s/it]
Extracting data files: 100%|██████████| 2/2 [00:00<00:00, 1346.05it/s]
                                                                                           

Dataset tydiqa downloaded and prepared to /Users/maltak/.cache/huggingface/datasets/tydiqa/primary_task/1.0.0/b8a6c4c0db10bf5703d7b36645e5dbae821b8c0e902dac9daeecd459a8337148. Subsequent calls will reuse this data.


100%|██████████| 2/2 [00:00<00:00, 21.80it/s]
100%|██████████| 167/167 [00:25<00:00,  6.67ba/s]
100%|██████████| 19/19 [00:02<00:00,  8.00ba/s]

Using 100 train examples.
Using 20 eval examples.





# Preprocessing

Here we preprocess the data to create features which can be given to the model.

In [6]:
from primeqa.mrc.processors.preprocessors.tydiqa import TyDiQAPreprocessor

preprocessor = TyDiQAPreprocessor(
    stride=128,
    tokenizer=tokenizer,
)

# Train Feature Creation
with training_args.main_process_first(desc="train dataset map pre-processing"):
    train_examples, train_dataset = preprocessor.process_train(train_examples)

print(f"Preprocessing produced {train_dataset.num_rows} train features from {train_examples.num_rows} examples.")

# Validation Feature Creation
with training_args.main_process_first(desc="validation dataset map pre-processing"):
    eval_examples, eval_dataset = preprocessor.process_eval(eval_examples)

print(f"Preprocessing produced {eval_dataset.num_rows} eval features from {eval_examples.num_rows} examples.")

{"time":"2024-05-08 16:26:42,210", "name": "TyDiQAPreprocessor", "level": "INFO", "message": "TyDiQAPreprocessor only supports single context multiple passages -- enabling"}


100%|██████████| 100/100 [00:00<00:00, 7760.91ex/s]
100%|██████████| 100/100 [00:00<00:00, 1082.59ex/s]
100%|██████████| 1/1 [00:00<00:00, 51.64ba/s]
Running tokenizer on train dataset: 100%|██████████| 1/1 [00:05<00:00,  5.46s/ba]


Preprocessing produced 92 train features from 100 examples.


100%|██████████| 20/20 [00:00<00:00, 3670.36ex/s]
100%|██████████| 20/20 [00:00<00:00, 856.39ex/s]
100%|██████████| 1/1 [00:00<00:00, 126.88ba/s]
Running tokenizer on eval dataset: 100%|██████████| 1/1 [00:01<00:00,  1.27s/ba]

Preprocessing produced 415 eval features from 20 examples.
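
The eval set yields many more features than examples (415 from 20), most likely because long documents are split into overlapping windows that each fit the model's input length. The sketch below illustrates the idea with plain Python lists. It assumes that `stride` is the number of tokens shared by consecutive windows (as in Hugging Face tokenizers) and uses a hypothetical window size of 384 tokens; the helper `sliding_windows` is illustrative and not part of PrimeQA.

```python
def sliding_windows(tokens, window_size=384, stride=128):
    """Split `tokens` into windows of at most `window_size` that overlap by `stride` tokens."""
    if stride >= window_size:
        raise ValueError("stride must be smaller than window_size")
    step = window_size - stride  # how far each new window advances
    windows = []
    start = 0
    while True:
        windows.append(tokens[start:start + window_size])
        if start + window_size >= len(tokens):
            break  # the current window already reaches the end of the document
        start += step
    return windows

document = list(range(1000))  # stand-in for 1000 context token ids
windows = sliding_windows(document)
print(len(windows))                            # number of features for this document
print([len(w) for w in windows])               # window lengths
print(windows[0][-128:] == windows[1][:128])   # consecutive windows share `stride` tokens
```

In the real pipeline each window is also paired with the question tokens, so the space left for context is smaller than the full input length.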





In [7]:
from datasets import ClassLabel, Sequence
import random
import pandas as pd
from IPython.display import display, HTML

# Based on https://github.com/huggingface/notebooks/blob/main/examples/question_answering.ipynb
def show_elements(dataset):
    df = pd.DataFrame(dataset)
    for column, typ in dataset.features.items():
        if isinstance(typ, ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
        elif isinstance(typ, Sequence) and isinstance(typ.feature, ClassLabel):
            df[column] = df[column].transform(lambda x: [typ.feature.names[i] for i in x])
    display(HTML(df.to_html()))

In [8]:
import random

def trim_document(example, max_len=500):
    example['context'] = example['context'][0]
    doc_len = len(example['context'])
    if doc_len > max_len:
        example['context'] = f"{example['context'][:max_len - 3]}..."        
    return example

random_idxs = random.sample(range(len(train_examples)), num_examples_to_show)
random_train_examples = train_examples.select(random_idxs).remove_columns(['document_plaintext', 'passage_candidates'])
random_train_examples = random_train_examples.map(trim_document)

show_elements(random_train_examples)  # Show random train examples

100%|██████████| 10/10 [00:00<00:00, 5895.84ex/s]


Unnamed: 0,question,document_title,language,target,document_url,context,example_id
0,How many people died during WW1?,World War I casualties,english,"{'end_positions': [93], 'passage_indices': [0], 'start_positions': [83], 'yes_no_answer': ['NONE']}",https://en.wikipedia.org/wiki/World%20War%20I%20casualties,"\n\n\n\nThe total number of military and civilian casualties in World War I were about 40 million: estimates range from 15 to 19million deaths and about 23million wounded military personnel, ranking it among the deadliest conflicts in human history.\nThe total number of deaths includes from 9 to 11 million military personnel. The civilian death toll was about 8 million, including about 6 million due to war-related famine and disease. The Triple Entente (also known as the Allies) lost about 6 milli...",f799851c-e506-4366-9650-78aaa0ce4e8a
1,Can the DC character Nightwing fly?,Nightwing,english,"{'end_positions': [-1], 'passage_indices': [5], 'start_positions': [-1], 'yes_no_answer': ['YES']}",https://en.wikipedia.org/wiki/Nightwing,"\n\n\nNightwing is a fictional superhero appearing in American comic books published by DC Comics. The character has appeared in various incarnations, with the Nightwing identity most prominently being adopted by Dick Grayson when he moved on from his role as Batman's vigilante partner Robin.\nAlthough Nightwing is commonly associated with Batman, the title and concept have origins in classic Superman stories. The original Nightwing in DC Comics was an identity assumed by alien superhero Superman...",d235a8a4-7374-4e2d-8685-bfa1e51939d9
2,When did Juan Rivera start professional wrestling?,Ron Rivera,english,"{'end_positions': [-1], 'passage_indices': [-1], 'start_positions': [-1], 'yes_no_answer': ['NONE']}",https://en.wikipedia.org/wiki/Ron%20Rivera,"\n\n\nRonald Eugene ""Ron"" Rivera (born January 7, 1962)[1] also known as ""Riverboat Ron"" is an American football coach and former player who is the head coach of the Carolina Panthers of the National Football League (NFL). He has also been the defensive coordinator for the Chicago Bears and San Diego Chargers.\nRivera played college football at the University of California in Berkeley, and was recognized as an All-American linebacker. He was selected in the second round of the 1984 NFL draft by t...",352c9d7b-21b9-4d1c-82a4-2cfc47d1c90a
3,When were bluebonnets named the state flower of Texas?,Bluebonnet (plant),english,"{'end_positions': [629], 'passage_indices': [2], 'start_positions': [616], 'yes_no_answer': ['NONE']}",https://en.wikipedia.org/wiki/Bluebonnet%20%28plant%29,"\nBluebonnet is a name given to any number of blue-flowered species of the genus Lupinus predominantly found in southwestern United States and is collectively the state flower of Texas. The shape of the petals on the flower resembles the bonnet worn by pioneer women to shield them from the sun.[1]\nSpecies often called bluebonnets include:\nLupinus argenteus, silvery lupine\nLupinus concinnus, Bajada lupine\nLupinus havardii, Big Bend bluebonnet or Chisos bluebonnet\nLupinus plattensis, Nebraska lu...",aa18b029-01bc-4725-8d0a-5b800034534e
4,"Who wrote the song ""Happy Days""?",Happy Days (TV theme),english,"{'end_positions': [65], 'passage_indices': [0], 'start_positions': [36], 'yes_no_answer': ['NONE']}",https://en.wikipedia.org/wiki/Happy%20Days%20%28TV%20theme%29,"\n\n""Happy Days"" is a song written by Norman Gimbel and Charles Fox. It is the theme song of the 1970s television series Happy Days.[3] It can be heard during the TV show's opening and closing credits as it runs in perpetual rerun syndication.\nThe song was first recorded in 1974 by Jim Haas with a group of other session singers for the first two seasons.[4] These versions of the song were used only during the closing credits of Seasons 1 and 2, with an updated version of ""Rock Around the Clock...",6fe91382-54e3-434c-ae64-ddd2b7b04072
5,Where were the first dinosaur bones discovered?,History of paleontology,english,"{'end_positions': [-1], 'passage_indices': [-1], 'start_positions': [-1], 'yes_no_answer': ['NONE']}",https://en.wikipedia.org/wiki/History%20of%20paleontology,\n\n\n\n\n\n\n\n\nPart of a series onPaleontology\nFossils\nFossilization\nTrace fossil\nIndex fossil\nList of fossils\nList of fossil sites\nLagerstätte fossil beds\nList of transitional fossils\nList of human evolution fossils\nNatural history\nBiogeography\nExtinction event\nGeochronology\nGeologic time scale\nGeologic record\nHistory of life\nOrigin of life\nTimeline of evolution\nTransitional fossil\nOrgans and processes\nAvian flight\nCells\nMulticells\nEyes\nFlagella\nHair\nMammalian auditory ossicles\nMosaic evolution\nNe...,a4740c9f-5e44-4d83-8299-fa95e6dbbcd8
6,Where was Doris Hursley born?,Hursley (disambiguation),english,"{'end_positions': [-1], 'passage_indices': [-1], 'start_positions': [-1], 'yes_no_answer': ['NONE']}",https://en.wikipedia.org/wiki/Hursley%20%28disambiguation%29,"Hursley is a village in Hampshire, England.\nHursley may also refer to:\nHMS Hursley (L84), a Second World War escort destroyer\nHursley House, a mansion in Hursley, Hampshire, England\nPeople\nFrank and Doris Hursley (1902–1989 and 1898–1984), husband-and-wife team who wrote American serials\nFrank Hursley (1902–1989), American soap opera writer\nJoe Hursley (born 1979), actor and musician living in California, US\n",4be67a13-c8ba-42ee-8e09-60de61b2f87c
7,When is the dialectical method used?,Dialectic,english,"{'end_positions': [277], 'passage_indices': [0], 'start_positions': [130], 'yes_no_answer': ['NONE']}",https://en.wikipedia.org/wiki/Dialectic,"\nDialectic or dialectics (Greek: διαλεκτική, dialektikḗ; related to dialogue), also known as the dialectical method, is at base a discourse between two or more people holding different points of view about a subject but wishing to establish the truth through reasoned arguments. Dialectic resembles debate, but the concept excludes subjective elements such as emotional appeal and the modern pejorative sense of rhetoric.[1][2] Dialectic may be contrasted with the didactic method, wherein one sid...",35c7165e-38fd-4b7d-8f3f-399adff7cb9e
8,Do The Rough Riders have a special patch?,Rough Riders,english,"{'end_positions': [-1], 'passage_indices': [-1], 'start_positions': [-1], 'yes_no_answer': ['NONE']}",https://en.wikipedia.org/wiki/Rough%20Riders,"\n\n\nThe Rough Riders was a nickname given to the 1st United States Volunteer Cavalry, one of three such regiments raised in 1898 for the Spanish–American War and the only one to see action. The United States Army was small and understaffed in comparison to its status during the American Civil War roughly thirty years prior. As a measure towards rectifying this situation President William McKinley called upon 125,000 volunteers to assist in the war efforts.[1] The regiment was also called ""Wood...",bafc6dd1-75c9-415c-af79-827ca4de4e30
9,What are the ratings for Sábado Gigante?,Súper Sábado Sensacional,english,"{'end_positions': [-1], 'passage_indices': [-1], 'start_positions': [-1], 'yes_no_answer': ['NONE']}",https://en.wikipedia.org/wiki/S%C3%BAper%20S%C3%A1bado%20Sensacional,"Main Page\nSúper Sábado Sensacional (originally named Sábado Espectacular in 1968, renamed Sábado Sensacional in 1971) is a Spanish-language variety show created in Venezuela, and established on Radio Caracas Television in 1968. The show later moved to Venevisión network in 1971. Shown on a weekly basis, every Saturday from 3:00 pm to 8:00 pm (sometimes longer during special occasions) it is viewed internationally throughout Latin America, the Caribbean and the United States and it is consider...",a46ebf40-957e-47d8-a5cc-f65f6252a3a4


In [9]:
from primeqa.mrc.data_models.target_type import TargetType

def target_type_as_str(feature):
    feature['target_type'] = TargetType(feature['target_type']).name
    return feature

random_train_dataset = train_dataset.filter(lambda feature: feature['example_idx'] in random_idxs).remove_columns(['attention_mask', 'offset_mapping'])
show_elements(random_train_dataset.map(target_type_as_str))  # Show random train features

100%|██████████| 1/1 [00:00<00:00, 12.44ba/s]
100%|██████████| 7/7 [00:00<00:00, 5094.59ex/s]


Unnamed: 0,example_id,input_ids,example_idx,start_positions,end_positions,target_type
0,35c7165e-38fd-4b7d-8f3f-399adff7cb9e,"[0, 14847, 83, 70, 220734, 21533, 55300, 11814, 32, 2, 2, 4512, 133, 49086, 707, 220734, 28021, 15, 91127, 343, 12, 4437, 12596, 140157, 4, 139581, 783, 3, 74, 62548, 47, 144483, 247, 2843, 51529, 237, 70, 220734, 21533, 55300, 4, 83, 99, 3647, 10, 189413, 13, 17721, 6626, 707, 1286, 3395, 104064, 12921, 26847, 111, 21455, 1672, 10, 28368, 1284, 32599, 214, 47, 137633, 70, 85027, 8305, 31635, 297, 10750, 7, 5, 4512, 133, 49086, 3332, 195, 13566, 29865, 4, 1284, 70, 23755, 39041, 988, 28368, 5844, 80854, 6044, 237, 88965, 149528, 136, 70, 5744, 280, 15503, 45023, 10422, ...]",21,45,71,SPAN_ANSWER
1,35c7165e-38fd-4b7d-8f3f-399adff7cb9e,"[0, 14847, 83, 70, 220734, 21533, 55300, 11814, 32, 2, 2, 7612, 67, 4, 42459, 7, 10, 1238, 19729, 4, 707, 95134, 142, 82940, 46485, 7432, 4, 23, 2499, 61475, 159688, 2451, 217, 83, 5792, 164789, 136, 21, 22824, 70, 40907, 111, 110324, 4, 18499, 4, 136, 16981, 5, 1326, 1529, 2679, 4, 70, 122776, 4, 70, 142518, 90, 164, 4, 70, 89931, 4, 8110, 11343, 27875, 8305, 70, 93402, 111, 70, 40907, 4, 23, 70, 120696, 47, 21721, 1830, 4, 450, 83, 4, 2450, 1363, 5, 3293, 83, 70, 3533, 3956, 111, 2367, 83, 5700, 538, 35839, 1529, 83331, ...]",21,0,0,NO_ANSWER
2,aa18b029-01bc-4725-8d0a-5b800034534e,"[0, 14847, 3542, 57571, 145743, 933, 24, 4806, 70, 11341, 6, 132641, 111, 31464, 32, 2, 2, 22928, 145743, 18, 83, 10, 9351, 34475, 47, 2499, 14012, 111, 57571, 9, 132641, 297, 114149, 111, 70, 107396, 104702, 44297, 156531, 660, 538, 14037, 23, 127067, 1177, 48850, 14098, 46684, 136, 83, 143849, 538, 70, 11341, 6, 132641, 111, 31464, 5, 581, 115700, 111, 70, 280, 60380, 98, 70, 6, 132641, 3332, 195, 13566, 70, 18414, 18, 6, 23432, 19, 390, 53918, 11226, 24793, 47, 6, 221292, 2856, 1295, 70, 4262, 25432, 24990, 3387, 27983, 35839, 57571, 145743, 933, 26698, 12, 104702, ...]",29,181,184,SPAN_ANSWER
3,d235a8a4-7374-4e2d-8685-bfa1e51939d9,"[0, 4171, 70, 31455, 62816, 36151, 14775, 12403, 32, 2, 2, 36151, 14775, 83, 10, 127663, 289, 1601, 90865, 108975, 214, 23, 15672, 131259, 42840, 91376, 390, 31455, 111321, 7, 5, 581, 62816, 1556, 118775, 23, 67842, 23, 107032, 5256, 4, 678, 70, 36151, 14775, 182324, 2684, 197097, 538, 8035, 30666, 297, 390, 67468, 155438, 1681, 3229, 764, 109133, 98, 1295, 1919, 31486, 237, 82630, 25, 7, 86433, 1479, 4755, 65810, 5, 106073, 36151, 14775, 83, 39210, 538, 137272, 678, 82630, 4, 70, 44759, 136, 23755, 765, 59665, 7, 23, 54704, 183497, 43515, 5, 581, 7311, 36151, 14775, 23, 31455, ...]",74,0,0,YES
4,d235a8a4-7374-4e2d-8685-bfa1e51939d9,"[0, 4171, 70, 31455, 62816, 36151, 14775, 12403, 32, 2, 2, 289, 62816, 333, 87168, 1914, 9, 441, 3679, 164, 14825, 183497, 36151, 14775, 83, 5117, 8, 18695, 3674, 23, 70, 13765, 44, 73903, 669, 23, 2734, 1846, 58, 23, 183497, 468, 137197, 15, 67884, 2240, 53, 36102, 194, 1650, 83, 142, 55109, 11814, 390, 183497, 23, 479, 9, 441, 3679, 164, 43515, 5, 581, 13765, 83, 5423, 23, 2734, 1846, 4, 10, 35758, 40934, 72173, 26349, 450, 509, 90978, 3678, 33, 136, 9498, 56, 4126, 23, 10, 144521, 390, 6163, 943, 2263, 5, 360, 2734, 1846, 4, 183497, 1556, ...]",74,0,0,YES
5,6fe91382-54e3-434c-ae64-ddd2b7b04072,"[0, 40469, 54397, 70, 11531, 44, 184870, 97292, 38843, 2, 2, 44, 184870, 97292, 58, 83, 10, 11531, 59121, 390, 111413, 92150, 4063, 136, 28166, 49049, 5, 1650, 83, 70, 73986, 11531, 111, 70, 19340, 7, 113976, 36549, 32506, 97292, 71540, 1650, 831, 186, 49782, 20271, 70, 1910, 7639, 25, 7, 73432, 136, 20450, 6953, 22299, 7, 237, 442, 127877, 23, 155241, 141, 456, 16428, 226406, 1830, 5, 581, 11531, 509, 5117, 17164, 297, 23, 27898, 390, 41994, 1391, 162, 678, 10, 21115, 111, 3789, 56002, 5367, 1314, 100, 70, 5117, 6626, 34003, 7, 105977, 32255, 11389, 7, 111, 70, ...]",92,20,25,SPAN_ANSWER
6,f799851c-e506-4366-9650-78aaa0ce4e8a,"[0, 11249, 5941, 3395, 68, 71, 20271, 6, 95162, 418, 32, 2, 2, 581, 3622, 14012, 111, 116338, 136, 117907, 66, 63044, 2449, 23, 6661, 5550, 87, 3542, 1672, 1112, 19879, 12, 25902, 1636, 37457, 1295, 423, 47, 953, 39, 96222, 47219, 7, 136, 1672, 1105, 39, 96222, 148, 167457, 116338, 35768, 4, 77918, 442, 54940, 70, 103494, 150, 525, 79612, 7, 23, 14135, 32692, 5, 581, 3622, 14012, 111, 47219, 7, 96853, 1295, 483, 47, 534, 19879, 116338, 35768, 5, 581, 117907, 66, 47219, 43584, 509, 1672, 382, 19879, 4, 26719, 1672, 305, 19879, 4743, 47, 1631, 9, 174822, ...]",99,29,30,SPAN_ANSWER


# Fine-tuning

Here we fine-tune the model on the training set.

In [10]:
from operator import attrgetter
from transformers import DataCollatorWithPadding
from primeqa.mrc.data_models.eval_prediction_with_processing import EvalPredictionWithProcessing
from primeqa.mrc.metrics.tydi_f1.tydi_f1 import TyDiF1
from primeqa.mrc.processors.postprocessors.extractive import ExtractivePostProcessor
from primeqa.mrc.processors.postprocessors.scorers import SupportedSpanScorers

# When using mixed precision, pad to a multiple of 64 for efficient hardware acceleration.
using_mixed_precision = any(attrgetter('fp16', 'bf16')(training_args))
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=64 if using_mixed_precision else None)

# noinspection PyProtectedMember
postprocessor = ExtractivePostProcessor(
    k=3,
    n_best_size=20,
    max_answer_length=30,
    scorer_type=SupportedSpanScorers.WEIGHTED_SUM_TARGET_TYPE_AND_SCORE_DIFF,
    single_context_multiple_passages=preprocessor._single_context_multiple_passages,
)

def compute_metrics(p: EvalPredictionWithProcessing):
    return TyDiF1().compute(predictions=p.processed_predictions, references=p.label_ids)

trainer = MRCTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset if training_args.do_train else None,
    eval_dataset=eval_dataset if training_args.do_eval else None,
    eval_examples=eval_examples if training_args.do_eval else None,
    tokenizer=tokenizer,
    data_collator=data_collator,
    post_process_function=postprocessor.process_references_and_predictions,  # see QATrainer in Huggingface
    compute_metrics=compute_metrics,
)

train_result = trainer.train()
trainer.save_model()  # Saves the tokenizer too for easy upload

metrics = train_result.metrics
max_train_samples = max_train_samples or len(train_dataset)
metrics["train_samples"] = min(max_train_samples, len(train_dataset))

trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

{"time":"2024-05-08 16:27:14,716", "name": "primeqa.mrc.trainers.mrc", "level": "INFO", "message": "The following columns in the training set  don't have a corresponding argument in `XLMRobertaModelForDownstreamTasks.forward` and have been ignored: offset_mapping, example_id, example_idx."}


***** Running training *****
  Num examples = 92
  Num Epochs = 1
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 8
  Total optimization steps = 1
  Number of trainable parameters = 279481753
100%|██████████| 1/1 [00:46<00:00, 46.97s/it]

Training completed. Do not forget to share your model on huggingface.co/models =)


Saving model checkpoint to /Users/maltak/code/PQA_models


{'train_runtime': 47.0294, 'train_samples_per_second': 1.956, 'train_steps_per_second': 0.021, 'train_loss': 4.743960857391357, 'epoch': 0.67}


Model weights saved in /Users/maltak/code/PQA_models/pytorch_model.bin


***** train metrics *****
  epoch                    =       0.67
  train_loss               =      4.744
  train_runtime            = 0:00:47.02
  train_samples            =         92
  train_samples_per_second =      1.956
  train_steps_per_second   =      0.021
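
The numbers in the training log above follow from the configuration. Here is a rough sketch of the arithmetic, assuming a single device and assuming the trainer performs one optimizer update per `gradient_accumulation_steps` full batches; both assumptions are consistent with the log but are not confirmed by it.

```python
import math

num_features = 92                 # "Num examples" in the training log
train_batch_size = 8
gradient_accumulation_steps = 8
num_devices = 1                   # assumption: single device
num_train_epochs = 1

# Features that contribute to one optimizer update.
total_batch_size = train_batch_size * gradient_accumulation_steps * num_devices
# Batches produced by one pass over the features.
batches_per_epoch = math.ceil(num_features / train_batch_size)
# Assumption: only complete accumulation groups trigger an update (at least one).
updates_per_epoch = max(batches_per_epoch // gradient_accumulation_steps, 1)
total_steps = updates_per_epoch * num_train_epochs
# Fraction of the epoch's batches consumed by those updates.
epoch_fraction = total_steps * gradient_accumulation_steps / batches_per_epoch

print(total_batch_size)           # 64
print(batches_per_epoch)          # 12
print(total_steps)                # 1
print(round(epoch_fraction, 2))   # 0.67
```

With only one optimizer update, this run is a smoke test of the pipeline and not a meaningful training run; raising `max_train_samples` or lowering `gradient_accumulation_steps` yields more updates.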


# Evaluation

Here we evaluate the model on the validation set.
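
Evaluation relies on the `ExtractivePostProcessor` configured earlier to turn per-token start and end logits into answer spans. The sketch below shows the general idea behind `n_best_size` and `max_answer_length` in extractive QA: consider only the top-scoring start and end positions, discard reversed or overly long spans, and rank the rest by the sum of their logits. It is a simplified illustration with made-up logits and a hypothetical helper `best_spans`, not PrimeQA's implementation, whose scorer name suggests it also takes the target type and score differences into account.

```python
def best_spans(start_logits, end_logits, n_best_size=20, max_answer_length=30, k=3):
    """Return the top-k (score, start, end) spans ranked by start + end logit."""
    def top_indices(logits):
        # Indices of the n_best_size largest logits, best first.
        order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
        return order[:n_best_size]

    candidates = []
    for start in top_indices(start_logits):
        for end in top_indices(end_logits):
            length = end - start + 1
            if end < start or length > max_answer_length:
                continue  # skip reversed or overly long spans
            candidates.append((start_logits[start] + end_logits[end], start, end))
    return sorted(candidates, reverse=True)[:k]

# Made-up logits for a five-token input.
start_logits = [0.1, 2.5, 0.3, 0.2, 1.0]
end_logits = [0.0, 0.4, 3.0, 0.1, 2.0]
for score, start, end in best_spans(start_logits, end_logits):
    print(f"tokens {start}..{end}: score {score:.1f}")
```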

In [11]:
metrics = trainer.evaluate()

max_eval_samples = max_eval_samples or len(eval_dataset)
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

{"time":"2024-05-08 16:28:41,529", "name": "primeqa.mrc.trainers.mrc", "level": "INFO", "message": "The following columns in the evaluation set  don't have a corresponding argument in `XLMRobertaModelForDownstreamTasks.forward` and have been ignored: offset_mapping, example_id, example_idx."}


***** Running Evaluation *****
  Num examples = 415
  Batch size = 8
100%|██████████| 20/20 [00:00<00:00, 33.08it/s]
100%|██████████| 52/52 [01:11<00:00,  1.37s/it]

Passage & english & \fpr{8.3}{7.1}{10.0}
Minimal Answer & english & \fpr{0.0}{0.0}{0.0}
********************
english
Language: english (20)
********************
PASSAGE ANSWER R@P TABLE:
Optimal threshold: 0.273
 F1     /  P      /  R
  8.33% /   7.14% /  10.00%
R@P=0.5: 0.00% (actual p=0.00%, score threshold=0.0)
R@P=0.75: 0.00% (actual p=0.00%, score threshold=0.0)
R@P=0.9: 0.00% (actual p=0.00%, score threshold=0.0)
********************
MINIMAL ANSWER R@P TABLE:
Optimal threshold: 0.0
 F1     /  P      /  R
  0.00% /   0.00% /   0.00%
R@P=0.5: 0.00% (actual p=0.00%, score threshold=0.0)
R@P=0.75: 0.00% (actual p=0.00%, score threshold=0.0)
R@P=0.9: 0.00% (actual p=0.00%, score threshold=0.0)
Total # examples in gold: 20, # ex. in pred: 20 (including english)
*** Macro Over 0 Languages, excluding English **
Passage F1:0.000 P:0.000 R:0.000000
\fpr{0.0}{0.0}{0.0}
Minimal F1:0.000 P:0.000 R:0.000000
\fpr{0.0}{0.0}{0.0}
*** / Aggregate Scores ****
{"avg_passage_f1": 0, "avg_passage_reca
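
The F1 column in the tables above is the harmonic mean of precision and recall. A quick check against the passage answer and minimal answer rows, using the rounded values printed in the log (the helper name `f1_score` is illustrative):

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall; 0.0 when both are zero."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

passage_f1 = f1_score(0.0714, 0.10)   # passage answer: P = 7.14%, R = 10.00%
minimal_f1 = f1_score(0.0, 0.0)       # minimal answer: P = R = 0.00%

print(f"{passage_f1:.2%}")   # 8.33%
print(f"{minimal_f1:.2%}")   # 0.00%
```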




# Predictions

Here we examine the model predictions.
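
As the printed output in this section shows, `eval_predictions.json` maps each example id to a list of candidate answers, with fields such as `confidence_score`, `passage_index` and `span_answer`. A minimal sketch of picking the most confident candidate per example follows; the entries are made up to mirror that shape, and `top_prediction` is a hypothetical helper.

```python
def top_prediction(candidates):
    """Pick the candidate with the highest confidence_score."""
    return max(candidates, key=lambda c: c["confidence_score"])

# Made-up entries shaped like the JSON file; real values come from eval_predictions.json.
predictions = {
    "example-a": [
        {"confidence_score": 0.34, "passage_index": -1,
         "span_answer": {"start_position": 9550, "end_position": 9638}},
        {"confidence_score": 0.21, "passage_index": 2,
         "span_answer": {"start_position": 120, "end_position": 131}},
    ],
}

for example_id, candidates in predictions.items():
    best = top_prediction(candidates)
    span = best["span_answer"]
    print(example_id, best["confidence_score"], span["start_position"], span["end_position"])
```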

In [12]:
import json
import os
from pprint import pprint

with open(os.path.join(output_dir, 'eval_predictions.json'), 'r') as f:
    predictions = json.load(f)

pprint(predictions)

{'0740df34-c6c2-45fa-a871-3b12d6eead8d': [{'cls_score': -0.4955431669950485,
                                           'confidence_score': 0.3371063425565338,
                                           'end_index': 289,
                                           'end_logit': 0.060142915695905685,
                                           'end_stdev': 0.0,
                                           'example_id': '0740df34-c6c2-45fa-a871-3b12d6eead8d',
                                           'normalized_span_answer_score': 0.3371063425565338,
                                           'passage_index': -1,
                                           'query_passage_similarity': 0.0,
                                           'span_answer': {'end_position': 9638,
                                                           'start_position': 9550},
                                           'span_answer_score': 0.35771970078349113,
                                           'span_answer_te