# NLP Tasks (Part 2)

Question answering task

| Date | User | Change Type | Remarks |  
| ---- | ---- | ----------- | ------- |
| 30/12/2025   | Martin | Created   | Notebook for Question-answering task | 
| 08/01/2025   | Martin | Update   | Completed Question-answering task | 
| 09/01/2025   | Martin | Update   | Added note on data collators | 

# Content

* [Introduction](#introduction)
* [A Note on Data Collators](#note-on-data-collators)

# Introduction

Extractive question answering: finding the answer to a question within a provided text

- Dataset: Question and Answers from `SQuAD`
- Training examples have a single answer per question, while evaluation examples can have several reference answers
- Model is trained to predict a __start__ and an __end__ logit for every token in the input

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer

In [2]:
raw_dataset = load_dataset("squad")
raw_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [3]:
print("Context: ", raw_dataset["train"][0]["context"])
print("Question: ", raw_dataset["train"][0]["question"])
print("Answer: ", raw_dataset["train"][0]["answers"])

Context:  Architecturally, the school has a Catholic character. Atop the Main Building's gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.
Question:  To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
Answer:  {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}


Define the tokenizer. Each input will be in the format:

> [CLS] question [SEP] context [SEP]

Contexts longer than the maximum length are truncated, so a single question-answer pair can produce several samples (features). Tokens that overflow are carried into the next feature, with an overlap controlled by `stride`.

- An additional entry is created for each overflowing chunk
- For features where the answer is not fully contained in the chunk, the label is set to `(0, 0)` for the start and end positions, which points at the `[CLS]` token
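
The chunking behaviour can be pictured with a simplified sliding window. This is only a sketch: `sliding_windows` is a hypothetical helper, it works on a plain list instead of real tokens, and it ignores the question and special tokens that the tokenizer also has to fit into `max_length`.

```python
def sliding_windows(tokens, window, stride):
    """Split `tokens` into chunks of at most `window` items.

    Consecutive chunks share `stride` items, mirroring the idea behind
    `stride` + `return_overflowing_tokens=True` (simplified assumption).
    """
    if stride >= window:
        raise ValueError("stride must be smaller than window")
    chunks = []
    start = 0
    while True:
        chunks.append(tokens[start:start + window])
        if start + window >= len(tokens):
            break  # the last chunk reached the end of the sequence
        start += window - stride  # advance, keeping `stride` items of overlap
    return chunks


context_tokens = list(range(10))  # stand-in for 10 context tokens
chunks = sliding_windows(context_tokens, window=4, stride=2)
for chunk in chunks:
    print(chunk)
```

With a window of 4 and a stride of 2, each chunk repeats the last 2 items of the previous one, so an answer cut by one chunk boundary has a chance of appearing whole in the next chunk.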

In [4]:
model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
tokenizer.is_fast

True

In [5]:
context = raw_dataset["train"][0]["context"]
question = raw_dataset["train"][0]["question"]

inputs = tokenizer(question, context)
print(tokenizer.decode(inputs["input_ids"]))

[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building ' s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]


In [6]:
inputs = tokenizer(
  question,
  context,
  max_length=100,                 # Maximum length of each feature (question + context + special tokens)
  truncation='only_second',       # Truncate only the context, which is the second sequence
  stride=50,                      # Number of tokens shared between consecutive chunks
  return_overflowing_tokens=True, # Return the overflowing chunks as additional features
  return_offsets_mapping=True     # Return the (start_char, end_char) span of each token
)

for ids in inputs['input_ids']:
  print(tokenizer.decode(ids))

[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building ' s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin [SEP]
[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Next to the Main Building is the

In [7]:
inputs.keys()

KeysView({'input_ids': [[101, 1706, 2292, 1225, 1103, 6567, 2090, 9273, 2845, 1107, 8109, 1107, 10111, 20500, 1699, 136, 102, 22182, 1193, 117, 1103, 1278, 1144, 170, 2336, 1959, 119, 1335, 4184, 1103, 4304, 4334, 112, 188, 2284, 10945, 1110, 170, 5404, 5921, 1104, 1103, 6567, 2090, 119, 13301, 1107, 1524, 1104, 1103, 4304, 4334, 1105, 4749, 1122, 117, 1110, 170, 7335, 5921, 1104, 4028, 1114, 1739, 1146, 14089, 5591, 1114, 1103, 7051, 107, 159, 21462, 1566, 24930, 2508, 152, 1306, 3965, 107, 119, 5893, 1106, 1103, 4304, 4334, 1110, 1103, 19349, 1104, 1103, 11373, 4641, 119, 13301, 1481, 1103, 171, 17506, 102], [101, 1706, 2292, 1225, 1103, 6567, 2090, 9273, 2845, 1107, 8109, 1107, 10111, 20500, 1699, 136, 102, 1103, 4304, 4334, 1105, 4749, 1122, 117, 1110, 170, 7335, 5921, 1104, 4028, 1114, 1739, 1146, 14089, 5591, 1114, 1103, 7051, 107, 159, 21462, 1566, 24930, 2508, 152, 1306, 3965, 107, 119, 5893, 1106, 1103, 4304, 4334, 1110, 1103, 19349, 1104, 1103, 11373, 4641, 119, 13301, 1481, 

In [8]:
inputs = tokenizer(
  raw_dataset["train"][2:6]["question"],
  raw_dataset["train"][2:6]["context"],
  max_length=100,
  truncation="only_second",
  stride=50,
  return_overflowing_tokens=True,
  return_offsets_mapping=True,
)

In [14]:
for k in inputs.keys():
  print(k)

input_ids
token_type_ids
attention_mask
offset_mapping
overflow_to_sample_mapping


In [10]:
for ids, mapping in zip(inputs['input_ids'], inputs['overflow_to_sample_mapping']):
  print(tokenizer.decode(ids))
  print(mapping)

[CLS] The Basilica of the Sacred heart at Notre Dame is beside to which structure? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building ' s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]
0
[CLS] The Basilica of the Sacred heart at Notre Dame is beside to which structure? [SEP] the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend " Venite Ad Me Omnes ". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin [SEP]
0
[CLS] The Basilica of the Sacred heart at Notre Dame is beside to which structure? [SEP] Next to the M

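Labels will be of the format: `(start_position, end_position)`

- Both are token indices in the feature; they are found using `offset_mapping`, whose entries are tuples of 2 integers giving the character span of each token in the original text
- If the answer is not fully inside the context of the feature, the label is `(0, 0)`
- Otherwise, loop over the context tokens to find the first and last token of the answer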
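
As a small worked example of the rules above, here is a sketch on hand-written offsets. The helper name `char_span_to_token_span` and the toy feature are assumptions made for illustration; the logic follows the same two loops used in the notebook cell further down.

```python
def char_span_to_token_span(offsets, context_start, context_end, start_char, end_char):
    """Map an answer's character span to (start_token, end_token).

    `context_start` / `context_end` are the first and last token indices
    of the context inside the feature. Returns (0, 0) when the answer is
    not fully contained in this feature.
    """
    if offsets[context_start][0] > start_char or offsets[context_end][1] < end_char:
        return 0, 0

    # Walk forward to the last token that starts at or before the answer start
    idx = context_start
    while idx <= context_end and offsets[idx][0] <= start_char:
        idx += 1
    start_token = idx - 1

    # Walk backward to the first token that ends at or after the answer end
    idx = context_end
    while idx >= context_start and offsets[idx][1] >= end_char:
        idx -= 1
    end_token = idx + 1
    return start_token, end_token


# Toy feature: [CLS] who [SEP] the cat sat down [SEP]
# Context is "the cat sat down"; the answer "cat sat" covers characters 4..11
offsets = [(0, 0), (0, 3), (0, 0), (0, 3), (4, 7), (8, 11), (12, 16), (0, 0)]

print(char_span_to_token_span(offsets, 3, 6, start_char=4, end_char=11))
# Same answer, but this chunk only holds "the cat": the answer is cut off
print(char_span_to_token_span(offsets, 3, 4, start_char=4, end_char=11))
```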

In [16]:
answers

[{'text': ['the Main Building'], 'answer_start': [279]},
 {'text': ['a Marian place of prayer and reflection'], 'answer_start': [381]},
 {'text': ['a golden statue of the Virgin Mary'], 'answer_start': [92]},
 {'text': ['September 1876'], 'answer_start': [248]}]

In [None]:
# Searching for answers within the training data
answers = raw_dataset['train'][2:6]['answers']
start_positions = []
end_positions = []

for i, offset in enumerate(inputs['offset_mapping']):
  answers_offset = inputs['overflow_to_sample_mapping'][i]
  answer = answers[answers_offset]
  start_char = answer['answer_start'][0]
  end_char = answer['answer_start'][0] + len(answer['text'][0])
  sequence_ids = inputs.sequence_ids(i)

  # Find the start and end of context
  cont_idx = 0
  while sequence_ids[cont_idx] != 1:
    cont_idx += 1
  start_cont_idx = cont_idx
  while sequence_ids[cont_idx] == 1:
    cont_idx += 1
  end_cont_idx = cont_idx - 1

  # Check if the answer is inside the context or not
  if offset[start_cont_idx][0] > start_char or offset[end_cont_idx][1] < end_char:
    start_positions.append(0)
    end_positions.append(0)
  else:
    idx = start_cont_idx
    while idx <= end_cont_idx and offset[idx][0] <= start_char:
      idx += 1
    start_positions.append(idx - 1)

    idx = end_cont_idx
    while idx >= start_cont_idx and offset[idx][1] >= end_char:
      idx -= 1
    end_positions.append(idx + 1)

inputs['overflow_to_sample_mapping'], start_positions, end_positions

([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3],
 [83, 51, 19, 0, 0, 64, 27, 0, 34, 0, 0, 0, 67, 34, 0, 0, 0, 0, 0],
 [85, 53, 21, 0, 0, 70, 33, 0, 40, 0, 0, 0, 68, 35, 0, 0, 0, 0, 0])

In [31]:
# Check results to verify approach
idx = 0
sample_idx = inputs['overflow_to_sample_mapping'][idx]
answer = answers[sample_idx]['text'][0]

start = start_positions[idx]
end = end_positions[idx]
labeled_answer = tokenizer.decode(inputs['input_ids'][idx][start : end + 1])
print(f"Theoretical answer: {answer} | Labels give: {labeled_answer}")

Theoretical answer: the Main Building | Labels give: the Main Building


In [101]:
inputs['overflow_to_sample_mapping']

[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3]

In [37]:
MAX_LEN = 384
STRIDE = 128

def preprocess_training_examples(examples):
  questions = [q.strip() for q in examples['question']]
  inputs = tokenizer(
    questions,
    examples['context'],
    max_length=MAX_LEN,
    truncation="only_second",
    stride=STRIDE,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length"
  )

  offset_mapping = inputs.pop('offset_mapping')
  sample_map = inputs.pop('overflow_to_sample_mapping')
  answers = examples['answers']
  start_positions = []
  end_positions = []

  for i, offset in enumerate(offset_mapping):
    sample_idx = sample_map[i]
    answer = answers[sample_idx]
    start_char = answer['answer_start'][0]
    end_char = answer['answer_start'][0] + len(answer['text'][0])
    sequence_ids = inputs.sequence_ids(i)

    # Find the start and end of context
    cont_idx = 0
    while sequence_ids[cont_idx] != 1:
      cont_idx += 1
    start_cont_idx = cont_idx
    while sequence_ids[cont_idx] == 1:
      cont_idx += 1
    end_cont_idx = cont_idx - 1

    # Check if the answer is inside the context or not
    if offset[start_cont_idx][0] > start_char or offset[end_cont_idx][1] < end_char:
      start_positions.append(0)
      end_positions.append(0)
    else:
      idx = start_cont_idx
      while idx <= end_cont_idx and offset[idx][0] <= start_char:
        idx += 1
      start_positions.append(idx - 1)

      idx = end_cont_idx
      while idx >= start_cont_idx and offset[idx][1] >= end_char:
        idx -= 1
      end_positions.append(idx + 1)
  
  inputs['start_positions'] = start_positions
  inputs['end_positions'] = end_positions

  return inputs

In [None]:
train_dataset = raw_dataset['train'].map(
  preprocess_training_examples,
  batched=True,
  remove_columns=raw_dataset['train'].column_names
)

len(train_dataset), len(raw_dataset['train'])

(88729, 87599)

Processing validation data

Set the offsets of all non-context tokens (question and special tokens) to `None`, so that predicted spans falling outside the context can be ignored during post-processing
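
A tiny illustration of that masking step, using a hand-written feature (the `sequence_ids` and `offsets` values below are made up for the example):

```python
# Toy feature: [CLS] who [SEP] the cat [SEP]
# None = special token, 0 = question token, 1 = context token
sequence_ids = [None, 0, None, 1, 1, None]
offsets = [(0, 0), (0, 3), (0, 0), (0, 3), (4, 7), (0, 0)]

# Keep the offset only for context tokens; everything else becomes None
masked_offsets = [
    o if sequence_ids[k] == 1 else None for k, o in enumerate(offsets)
]
print(masked_offsets)
```

Only the two context tokens keep their character spans, so a later check of the form `offsets[index] is None` is enough to reject a start or end position that lies in the question.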

In [44]:
def preprocess_validation_examples(examples):
  questions = [q.strip() for q in examples["question"]]
  inputs = tokenizer(
    questions,
    examples["context"],
    max_length=MAX_LEN,
    truncation="only_second",
    stride=STRIDE,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
  )

  sample_map = inputs.pop("overflow_to_sample_mapping")
  example_ids = []

  for i in range(len(inputs["input_ids"])):
    sample_idx = sample_map[i]
    example_ids.append(examples["id"][sample_idx])

    sequence_ids = inputs.sequence_ids(i)
    offset = inputs["offset_mapping"][i]
    inputs["offset_mapping"][i] = [
      o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
    ]

  inputs["example_id"] = example_ids

  return inputs

In [46]:
validation_dataset = raw_dataset["validation"].map(
  preprocess_validation_examples,
  batched=True,
  remove_columns=raw_dataset["validation"].column_names,
)
len(raw_dataset["validation"]), len(validation_dataset)

(10570, 10822)

## Finetuning with Trainer

- No data collator needs to be defined, since every sequence is already padded to the maximum length
- Post-process the model predictions into spans of text in the original examples
- Use `compute_metrics()` to measure performance
- Model outputs __logits__ for the start and end position of the answer for each input
  * Mask all tokens that are not within the context of the current input
  * Only score the top $n$ start and end tokens
  * Score each (start, end) pair by the product of the two probabilities; since logits are on a log scale, sum them instead
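
The scoring rule in the last bullets can be sketched in plain Python. The function `best_span`, the `valid` mask and the toy logits are assumptions for illustration; the notebook's own post-processing uses NumPy and the offset mapping instead.

```python
def best_span(start_logits, end_logits, valid, n_best=2, max_answer_len=3):
    """Return (start, end, score) of the best valid span, or None.

    `valid[i]` is True when token i belongs to the context.
    """
    # Indices of the n_best highest logits, best first
    top_start = sorted(range(len(start_logits)),
                       key=lambda i: start_logits[i], reverse=True)[:n_best]
    top_end = sorted(range(len(end_logits)),
                     key=lambda i: end_logits[i], reverse=True)[:n_best]

    best = None
    for s in top_start:
        for e in top_end:
            if not (valid[s] and valid[e]):
                continue  # start or end is outside the context
            if e < s or e - s + 1 > max_answer_len:
                continue  # negative length or too long
            score = start_logits[s] + end_logits[e]  # sum of logits
            if best is None or score > best[2]:
                best = (s, e, score)
    return best


start_logits = [5.0, 0.1, 3.0, 0.2, 0.0]
end_logits = [4.0, 0.3, 0.1, 2.5, 0.0]
valid = [False, True, True, True, True]  # token 0 plays the role of [CLS]

print(best_span(start_logits, end_logits, valid))
```

Token 0 has the highest logits but is masked out, so the best remaining pair is start 2 and end 3.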

### Sample evaluation

Sample of evaluation using another model - `distilbert-base-cased-distilled-squad`

In [61]:
small_eval_set = raw_dataset['validation'].select(range(100))
trained_checkpoint = "distilbert-base-cased-distilled-squad"

tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)
eval_set = small_eval_set.map(
  preprocess_validation_examples,
  batched=True,
  remove_columns=raw_dataset['validation'].column_names
)

# Reset tokenizer back to original
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [70]:
len(eval_set_for_model['input_ids'])

100

In [None]:
import torch
import collections
import numpy as np
import evaluate
from transformers import AutoModelForQuestionAnswering, TrainingArguments, Trainer
from tqdm.auto import tqdm

# Move data to GPU
eval_set_for_model = eval_set.remove_columns(['example_id', 'offset_mapping'])
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
eval_set_for_model.set_format('torch', device=device)

# Convert Columns to tensors and move model to GPU
batch = {k: eval_set_for_model[k][:] for k in eval_set_for_model.column_names}
trained_model = AutoModelForQuestionAnswering.from_pretrained(trained_checkpoint).to(
  device
)

# Make predictions with trained model
with torch.no_grad():
  outputs = trained_model(**batch)

# Get the start and end logit positions
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()

# Map each small_eval_set to the corresponding entry in the eval_set
example_to_features = collections.defaultdict(list)
for idx, feature in enumerate(eval_set):
  example_to_features[feature['example_id']].append(idx)

Loop through all examples and their associated features. Ignore candidate answers that:

- Are not in the context
- Have negative length
- Are too long

In [None]:
N_BEST = 20      # Take the top 20 logit scores
MAX_ANS_LEN = 30 # Maximum answer length
predicted_answers = []

for example in small_eval_set:
  example_id = example['id']
  context = example['context']
  answers = [] # Multiple answers since each question is broken up into many features

  for feature_index in example_to_features[example_id]:
    start_logit = start_logits[feature_index]
    end_logit = end_logits[feature_index]
    offsets = eval_set['offset_mapping'][feature_index]

    # Get the index of the top 20 logit values
    start_indexes = np.argsort(start_logit)[-1:-N_BEST-1:-1].tolist()
    end_indexes = np.argsort(end_logit)[-1:-N_BEST-1:-1].tolist()
    for start_index in start_indexes:
      for end_index in end_indexes:
        # Check 1: Ignore if index is not in context
        if offsets[start_index] is None or offsets[end_index] is None:
          continue

        # Check 2: Skip answers that are negative length or above max length
        if (
          end_index < start_index
          or end_index - start_index + 1 > MAX_ANS_LEN
        ):
          continue
        
        answers.append(
          {
            "text": context[offsets[start_index][0]:offsets[end_index][1]],
            "logit_score": start_logit[start_index] + end_logit[end_index]
          }
        )
  
  best_answer = max(answers, key=lambda x: x['logit_score'])
  predicted_answers.append({"id": example_id, "prediction_text": best_answer["text"]})

In [119]:
predicted_answers[:10]

[{'id': '56be4db0acb8001400a502ec', 'prediction_text': 'Denver Broncos'},
 {'id': '56be4db0acb8001400a502ed', 'prediction_text': 'Carolina Panthers'},
 {'id': '56be4db0acb8001400a502ee',
  'prediction_text': "Levi's Stadium in the San Francisco Bay Area at Santa Clara, California"},
 {'id': '56be4db0acb8001400a502ef', 'prediction_text': 'Carolina Panthers'},
 {'id': '56be4db0acb8001400a502f0', 'prediction_text': 'gold'},
 {'id': '56be8e613aeaaa14008c90d1', 'prediction_text': 'golden anniversary'},
 {'id': '56be8e613aeaaa14008c90d2', 'prediction_text': 'February 7, 2016'},
 {'id': '56be8e613aeaaa14008c90d3',
  'prediction_text': 'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference'},
 {'id': '56bea9923aeaaa14008c91b9', 'prediction_text': 'golden anniversary'},
 {'id': '56bea9923aeaaa14008c91ba',
  'prediction_text': 'American Football Conference'}]

In [122]:
# Evaluate the answers
metric = evaluate.load("squad")

# Predicted answers must be in the format
# id: example_id
# answers: text answer
theoretical_answers = [
  {
    "id": ex['id'],
    "answers": ex['answers']
  } for ex in small_eval_set
]

print(theoretical_answers[0])
print(predicted_answers[0])

{'id': '56be4db0acb8001400a502ec', 'answers': {'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answer_start': [177, 177, 177]}}
{'id': '56be4db0acb8001400a502ec', 'prediction_text': 'Denver Broncos'}


In [123]:
metric.compute(predictions=predicted_answers, references=theoretical_answers)

{'exact_match': 83.0, 'f1': 88.25000000000004}
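
For intuition about the two numbers, here is a simplified sketch of how exact match and F1 are commonly computed for SQuAD-style answers. It assumes the usual normalization (lowercasing, removing punctuation and the articles a/an/the). The helper names are hypothetical and this is not the code that `evaluate.load("squad")` runs.

```python
import collections
import re
import string


def normalize(text):
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction, reference):
    return float(normalize(prediction) == normalize(reference))


def f1(prediction, reference):
    """Token-level F1 between a prediction and one reference answer."""
    pred_tokens = normalize(prediction).split()
    ref_tokens = normalize(reference).split()
    common = collections.Counter(pred_tokens) & collections.Counter(ref_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("The Denver Broncos", "Denver Broncos"))
print(f1("Denver Broncos team", "Denver Broncos"))
```

When an example has several reference answers, the score for that example is typically the maximum over the references, and the final metric is the average over examples.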

### Finetuning

The `compute_metrics()` function needs additional information (the offsets and the example ids), so it is only called at the end of training to check the results

In [137]:
def compute_metrics(start_logits, end_logits, features, examples):
  # Map example_id to features
  example_to_features = collections.defaultdict(list)
  for idx, feature in enumerate(features):
    example_to_features[feature['example_id']].append(idx)

  predicted_answers = []
  for example in tqdm(examples):
    example_id = example['id']
    context = example['context']
    answers = [] # Multiple answers since each question is broken up into many features

    for feature_index in example_to_features[example_id]:
      start_logit = start_logits[feature_index]
      end_logit = end_logits[feature_index]
      # Use the features passed to the function, not the global eval_set
      offsets = features[feature_index]['offset_mapping']

      # Get the index of the top 20 logit values
      start_indexes = np.argsort(start_logit)[-1:-N_BEST-1:-1].tolist()
      end_indexes = np.argsort(end_logit)[-1:-N_BEST-1:-1].tolist()
      for start_index in start_indexes:
        for end_index in end_indexes:
          # Check 1: Ignore if index is not in context
          if offsets[start_index] is None or offsets[end_index] is None:
            continue

          # Check 2: Skip answers that are negative length or above max length
          if (
            end_index < start_index
            or end_index - start_index + 1 > MAX_ANS_LEN
          ):
            continue
          
          answers.append(
            {
              "text": context[offsets[start_index][0]:offsets[end_index][1]],
              "logit_score": start_logit[start_index] + end_logit[end_index]
            }
          )
    
    if len(answers) > 0:
      best_answer = max(answers, key=lambda x: x['logit_score'])
      predicted_answers.append({"id": example_id, "prediction_text": best_answer["text"]})
    else:
      predicted_answers.append({"id": example_id, "prediction_text": ""})
  
  theoretical_answers = [
    {
      "id": ex['id'],
      "answers": ex['answers']
    } for ex in examples
  ]

  return metric.compute(predictions=predicted_answers, references=theoretical_answers)

compute_metrics(start_logits, end_logits, eval_set, small_eval_set)

{'exact_match': 83.0, 'f1': 88.25000000000004}

Training loop begins here

In [None]:
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

args = TrainingArguments(
  "bert-finetuned-squad",
  evaluation_strategy="no",
  save_strategy="epoch",
  learning_rate=2e-5,
  num_train_epochs=3,
  weight_decay=0.01,
  fp16=True,
)

trainer = Trainer(
  model=model,
  args=args,
  train_dataset=train_dataset,
  eval_dataset=validation_dataset,
  tokenizer=tokenizer
)

trainer.train()

In [None]:
# Evaluation only happens at the end because of how the Trainer class is structured
predictions, _, _ = trainer.predict(validation_dataset)
start_logits, end_logits = predictions
compute_metrics(start_logits, end_logits, validation_dataset, raw_dataset['validation'])

## PyTorch implementation

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import default_data_collator, get_scheduler
from accelerate import Accelerator
from tqdm.auto import tqdm

In [None]:
train_dataset.set_format('torch')
validation_set = validation_dataset.remove_columns(["example_id", "offset_mapping"])
validation_set.set_format("torch")

train_dataloader = DataLoader(
  train_dataset,
  shuffle=True,
  collate_fn=default_data_collator,
  batch_size=8
)

eval_dataloader = DataLoader(
  validation_set,
  collate_fn=default_data_collator,
  batch_size=8
)

In [None]:
model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

optimizer = AdamW(model.parameters(), lr=2e-5)

accelerator = Accelerator()
model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
  model, optimizer, train_dataloader, eval_dataloader
)

num_train_epochs = 3
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
  'linear',
  optimizer=optimizer,
  num_warmup_steps=0,
  num_training_steps=num_training_steps
)
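
The `'linear'` schedule with zero warmup steps decays the learning rate from its initial value down to zero over `num_training_steps`. A minimal sketch of that rule, assuming the usual linear-warmup-then-linear-decay shape; `linear_lr` is a hypothetical helper and not the library's implementation.

```python
def linear_lr(step, base_lr, num_training_steps, num_warmup_steps=0):
    """Learning rate at `step` for linear warmup followed by linear decay."""
    if step < num_warmup_steps:
        # Warmup: ramp up from 0 to base_lr
        return base_lr * step / max(1, num_warmup_steps)
    # Decay: go linearly from base_lr to 0 at num_training_steps
    remaining = num_training_steps - step
    decay_steps = max(1, num_training_steps - num_warmup_steps)
    return base_lr * max(0.0, remaining / decay_steps)


for step in (0, 50, 100):
    print(step, linear_lr(step, base_lr=2e-5, num_training_steps=100))
```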

In [None]:
# Training loop
progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):
  # Training
  model.train()
  for step, batch in enumerate(train_dataloader):
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)

    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    progress_bar.update(1)
  
  # Evaluation
  model.eval()
  start_logits = []
  end_logits = []
  accelerator.print(f"Running evaluation on epoch: {epoch}")
  for batch in tqdm(eval_dataloader):
    with torch.no_grad():
      outputs = model(**batch)
    
    start_logits.append(accelerator.gather(outputs.start_logits).cpu().numpy())
    end_logits.append(accelerator.gather(outputs.end_logits).cpu().numpy())

  start_logits = np.concatenate(start_logits)
  end_logits = np.concatenate(end_logits)
  # Truncate to match the validation datasets length
  start_logits = start_logits[: len(validation_dataset)]
  end_logits = end_logits[: len(validation_dataset)]

  metrics = compute_metrics(
    start_logits, end_logits, validation_dataset, raw_dataset['validation']
  )
  print(f"Epoch: {epoch}:", metrics)

  # # Save and upload
  # accelerator.wait_for_everyone()
  # unwrapped_model = accelerator.unwrap_model(model)
  # unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
  # if accelerator.is_main_process:
  #   tokenizer.save_pretrained(output_dir)
  #   repo.push_to_hub(
  #     commit_message=f"Training in progress epoch {epoch}", blocking=False
  #   )



In [None]:
from transformers import pipeline

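# Inference
# Note: `model_checkpoint` is the base "bert-base-cased" checkpoint; point `model` at the
# fine-tuned weights (e.g. the saved output directory) to get meaningful answers
question_answerer = pipeline("question-answering", model=model_checkpoint)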

context = """
🤗 Transformers is backed by the three most popular deep learning libraries - Jax, PyTorch and TensorFlow - with a seamless integration
between them. It's straightforward to train your models with one before loading them for inference with the other.
"""
question = "Which deep learning libraries back 🤗 Transformers?"
question_answerer(question=question, context=context)

---

# Note on Data Collators

Data collators collate lists of samples into a single mini-batch. They may also apply additional preprocessing, such as padding, so that each batch forms tensors of uniform shape
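
To make the idea concrete, here is a plain-Python sketch of what a padding collator does conceptually. `pad_batch` is a hypothetical helper written for this note, and the pad token id of 0 is an assumption; real collators take the value from the tokenizer and return tensors.

```python
def pad_batch(samples, pad_token_id=0):
    """Pad variable-length samples to the longest one and build a mask."""
    max_len = max(len(s["input_ids"]) for s in samples)
    batch = {"input_ids": [], "attention_mask": []}
    for s in samples:
        ids = list(s["input_ids"])
        n_pad = max_len - len(ids)
        batch["input_ids"].append(ids + [pad_token_id] * n_pad)
        # 1 marks real tokens, 0 marks padding
        batch["attention_mask"].append([1] * len(ids) + [0] * n_pad)
    return batch


samples = [
    {"input_ids": [101, 7, 102]},
    {"input_ids": [101, 7, 8, 9, 102]},
]
batch = pad_batch(samples)
print(batch["input_ids"])
print(batch["attention_mask"])
```

In this notebook the inputs were already padded with `padding="max_length"`, which is why the default collator is enough for the question-answering task.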

In [None]:
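# Data Collators - PyTorch
# `model`, `training_args` and `dataset` are placeholders defined elsewhere
from transformers import DefaultDataCollator, Trainer

data_collator = DefaultDataCollator() # return_tensors defaults to 'pt'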
trainer = Trainer(
  model=model,
  args=training_args,
  train_dataset=dataset,
  data_collator=data_collator
)

# TensorFlow
data_collator = DefaultDataCollator(
  return_tensors='tf'
)
train_set = dataset.to_tf_dataset(
  columns=['input_ids', 'labels'],
  shuffle=True,
  batch_size=16,
  collate_fn=data_collator
)

In [None]:
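# Custom data collators
from transformers import (
  DataCollatorForLanguageModeling,
  DataCollatorForSeq2Seq,
  DataCollatorForTokenClassification,
  DataCollatorWithPadding,
  DefaultDataCollator,
)

data_collator = DefaultDataCollator()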

# 1. Basic padding - requires specific tokenizer for special padding token
padding = DataCollatorWithPadding(tokenizer=tokenizer)

# 2. Token classification & Seq2Seq - variable label length (both inputs and labels need to be padded)
DataCollatorForTokenClassification(tokenizer=tokenizer)
DataCollatorForSeq2Seq(tokenizer=tokenizer)

# 3. Language modeling
# Causal language modeling
DataCollatorForLanguageModeling(
  tokenizer=tokenizer,
  mlm=False
)
# Masked language modeling
DataCollatorForLanguageModeling(
  tokenizer=tokenizer,
  mlm=True,
  mlm_probability=0.15
)

In [4]:
%load_ext watermark
%watermark

The watermark extension is already loaded. To reload it, use:
  %reload_ext watermark
Last updated: 2025-12-30T18:04:06.937522+08:00

Python implementation: CPython
Python version       : 3.10.12
IPython version      : 8.37.0

Compiler    : GCC 11.4.0
OS          : Linux
Release     : 6.6.87.2-microsoft-standard-WSL2
Machine     : x86_64
Processor   : x86_64
CPU cores   : 20
Architecture: 64bit

