# Environment setup

The code below installs the necessary packages and configures the random seeds for reproducibility.

In [None]:
# Install the third-party packages imported by the next cell (datasets, transformers, tqdm).
# Each line is a shell command run through the notebook's `!` escape.
!pip install datasets
!pip install transformers
!pip install tqdm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
from transformers import AutoConfig, DistilBertTokenizerFast, DistilBertForQuestionAnswering, get_scheduler
from datasets import load_dataset, load_metric

import torch
from torch.utils.data.dataloader import DataLoader
import numpy as np
import random
from tqdm.auto import tqdm
import os

# Pin every random number generator used in this notebook to one seed so
# that repeated runs produce the same results.
seed = 123

# Python and NumPy generators
random.seed(seed)
np.random.seed(seed)

# PyTorch generators (CPU, current CUDA device, all CUDA devices)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Ask cuDNN for deterministic kernels and turn off its benchmark mode
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


In [None]:
# Mount Google Drive at /content/drive; both paths below point inside that mount.
from google.colab import drive
drive.mount('/content/drive')

# Folder that is handed to load_data() to read the dataset from.
data_folder = '/content/drive/MyDrive/Colab Notebooks/qa-final-data/'
# File the model's state_dict is written to and, if it already exists, restored from.
model_save_path = '/content/drive/MyDrive/Colab Notebooks/qa-model.model'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Data loading & preprocessing


In [None]:
def load_data(dataset_folder):
  '''Reads the dataset found at dataset_folder and hands back two of its splits.
  * input: the location passed straight to load_dataset
  * returns: the 'train' split
  * returns: the 'validation' split'''

  splits = load_dataset(dataset_folder)
  train_split, validation_split = splits['train'], splits['validation']
  return train_split, validation_split

In [None]:
class PreprocDataset(torch.utils.data.Dataset):
  '''Dataset wrapper that tokenizes one (question, document) pair per item.

  Every raw row is read as row['questions'][0]['input_text'] (the question),
  row['contexts'] (the document) and row['answers'][0] with the character
  offsets 'span_start' and 'span_end'.
  * input: raw_data - indexable collection of rows that supports len()
  * input: tokenizer - callable tokenizer; its result must offer input_ids,
    attention_mask and char_to_token()
  * input: max_length - length every encoded pair is padded/truncated to
    (defaults to 384, the value that used to be hard-coded)'''

  def __init__(self, raw_data, tokenizer, max_length=384):
    self.raw_data = raw_data
    self.tokenizer = tokenizer
    self.max_length = max_length

  def __len__(self):
    '''Number of raw rows.'''
    return len(self.raw_data)

  def _char_to_token_or_last(self, encoded_input, char_idx):
    '''Token index (within the document) for a character offset.

    Falls back to the last position (max_length - 1) when the tokenizer
    reports no token for that character, e.g. because the document was
    truncated before it.'''
    tok_idx = encoded_input.char_to_token(char_idx, sequence_index=1)
    return self.max_length - 1 if tok_idx is None else tok_idx

  def __getitem__(self, index):
    '''Tokenizes row `index`.
    * returns: dict with 'input_ids' and 'attention_mask' (flattened tensors),
      'span_start' and 'span_end' (token indices as long tensors), plus
      'encoded', 'question_text' and 'document_text' for evaluation'''

    # Fetch the row once; every field below is read from this single lookup.
    row = self.raw_data[index]
    question_text = row['questions'][0]['input_text']
    document_text = row['contexts']
    answer = row['answers'][0]

    # NOTE(review): both texts are lower-cased before tokenizing, so the
    # character offsets below assume lower() keeps the string length.
    encoded_input = self.tokenizer(
      text = question_text.lower(),
      text_pair = document_text.lower(),
      add_special_tokens = True,
      return_attention_mask = True,
      return_token_type_ids = False,
      padding = 'max_length',
      max_length = self.max_length,
      truncation = True,
      return_tensors = 'pt',
    )

    # Convert the character span to a token span. The end lookup uses
    # span_end + 1, exactly as before.
    start_tok_idx = self._char_to_token_or_last(encoded_input, answer['span_start'])
    end_tok_idx = self._char_to_token_or_last(encoded_input, answer['span_end'] + 1)

    return {
      'input_ids': torch.flatten(encoded_input.input_ids),
      'attention_mask': torch.flatten(encoded_input.attention_mask),
      'span_start': torch.tensor(start_tok_idx, dtype=torch.long),
      'span_end': torch.tensor(end_tok_idx, dtype=torch.long),
      'encoded': encoded_input,         # this and below fields used for evaluation only
      'question_text': question_text,
      'document_text': document_text,
    }

def preprocess_and_tokenize(data, tokenizer, batch_size, shuffle=False):
  '''Wraps raw data in a PreprocDataset and batches it with a DataLoader.
  * input: the raw dataset to preprocess & tokenize
  * input: the tokenizer to use
  * input: the number of examples per batch
  * input: whether the loader should reshuffle the examples on every pass
    (off by default, which keeps the original fixed ordering; typically
    switched on for the training split only)
  * returns: data loader yielding preprocessed batches'''

  preproc_data = PreprocDataset(data, tokenizer)
  return DataLoader(
    dataset=preproc_data,
    batch_size=batch_size,
    shuffle=shuffle,
  )

# Training and evaluation

In [None]:
def load_model(pretrained_model="distilbert-base-uncased"):
  '''Builds the question-answering model we will be finetuning, together
  with its matching fast tokenizer.
  * input: name of the pretrained checkpoint to start from
  * returns: DistilBertForQuestionAnswering created from that checkpoint
  * returns: DistilBertTokenizerFast created from that checkpoint'''

  # The tokenizer is created first, then the model; the model comes first in the result.
  qa_tokenizer = DistilBertTokenizerFast.from_pretrained(pretrained_model)
  qa_model = DistilBertForQuestionAnswering.from_pretrained(pretrained_model)

  return (qa_model, qa_tokenizer)

In [None]:
def get_validation_loss(model, validation_dataloader, device='cpu'):
  '''Averages the model's loss over every batch of the validation loader.
  * input: the model (switched to eval mode here)
  * input: the validation data loader
  * input: the device the batch tensors are moved to
  * returns: mean of the per-batch losses'''

  model.eval()

  loss_total = 0
  batch_count = 0

  with torch.no_grad():
    for batch in validation_dataloader:
      outputs = model(
        input_ids = batch['input_ids'].to(device),
        attention_mask = batch['attention_mask'].to(device),
        start_positions = batch['span_start'].to(device),
        end_positions = batch['span_end'].to(device),
      )

      # Accumulate the batch loss
      loss_total = loss_total + outputs.loss
      batch_count += 1

    # Unweighted mean over batches
    return loss_total / batch_count

def tensors_in_order(a, b, c, d):
  '''Elementwise check that the four tensors are non-decreasing: each entry
  of the result is True iff a <= b <= c <= d at that position.
  * inputs: a, b, c, d are tensors that can be compared elementwise
  * returns: a boolean tensor'''

  a_before_b = a <= b
  b_before_c = b <= c
  c_before_d = c <= d

  first_three_ordered = torch.logical_and(a_before_b, b_before_c)
  return torch.logical_and(first_three_ordered, c_before_d)

def _span_overlap_scores(predicted_start, predicted_end, actual_start, actual_end):
  '''Per-example precision, recall and F1 between predicted and actual token spans.

  Spans are compared by length (end - start) and by the length of their overlap.
  Zero-length spans get explicit values instead of a 0/0 division:
  a zero-length prediction has precision 1 iff it lies inside the actual span,
  and a zero-length actual span has recall 1 iff the prediction covers it.
  * inputs: four 1-dimensional integer tensors of equal length, with
    predicted_start <= predicted_end already guaranteed by the caller
  * returns: (precision, recall, f1) float tensors, one value per example'''

  overlap = (
    torch.minimum(predicted_end, actual_end) - torch.maximum(predicted_start, actual_start)
  ).clamp(min=0)
  predicted_len = predicted_end - predicted_start
  actual_len = actual_end - actual_start

  predicted_inside_actual = (actual_start <= predicted_start) & (predicted_end <= actual_end)
  actual_inside_predicted = (
    (predicted_start <= actual_start) & (actual_start <= actual_end) & (actual_end <= predicted_end)
  )

  # clamp(min=1) only protects the branch torch.where discards; the selected
  # branch always divides by the true (positive) length.
  precision = torch.where(
    predicted_len > 0,
    overlap / predicted_len.clamp(min=1),
    predicted_inside_actual.float(),
  )
  recall = torch.where(
    actual_len > 0,
    overlap / actual_len.clamp(min=1),
    actual_inside_predicted.float(),
  )

  # precision + recall == 0 gives 0/0 here; that case is scored as F1 = 0.
  f1_score = torch.nan_to_num((2 * precision * recall) / (precision + recall))
  return precision, recall, f1_score


def eval_loop(model, validation_dataloader, device='cpu'):
  '''Computes the precision, recall, and f1 score for the model on the validation data.

  Scores are computed per example and averaged over all examples, so a
  smaller final batch does not get extra weight.
  * input: the model to evaluate (switched to eval mode here)
  * input: the validation data loader to evaluate on
  * input: the device the batch tensors are moved to
  * returns: dict with 'precision', 'recall' and 'f1_score' (0-dim tensors)
  * raises: ValueError if the data loader yields no batches'''

  print("Evaluating metrics:")

  precisions = []
  recalls = []
  f1_scores = []

  model.eval()
  with torch.no_grad():
    for batch in tqdm(validation_dataloader):
      input_ids = batch['input_ids'].to(device)
      attention_mask = batch['attention_mask'].to(device)
      actual_start = batch['span_start'].to(device)
      actual_end = batch['span_end'].to(device)

      # Only the logits are needed here, so the gold positions are not passed
      # to the model (passing them would only make it compute an unused loss).
      outputs = model(
        input_ids = input_ids,
        attention_mask = attention_mask,
      )
      raw_start = outputs.start_logits.argmax(dim=1)
      raw_end = outputs.end_logits.argmax(dim=1)

      # Start and end are predicted independently; order them so start <= end.
      predicted_start = torch.minimum(raw_start, raw_end)
      predicted_end = torch.maximum(raw_start, raw_end)

      precision, recall, f1_score = _span_overlap_scores(
        predicted_start, predicted_end, actual_start, actual_end)

      precisions.append(precision)
      recalls.append(recall)
      f1_scores.append(f1_score)

  if not precisions:
    raise ValueError("validation_dataloader yielded no batches; cannot compute metrics")

  return {
    'precision': torch.cat(precisions).mean(),
    'recall': torch.cat(recalls).mean(),
    'f1_score': torch.cat(f1_scores).mean(),
  }

In [None]:
def train_loop(model, train_dataloader, validation_dataloader, num_epochs, device='cpu', save_path=None,
               learning_rate=3e-4, num_warmup_steps=50):
  '''Trains the model on train_dataloader for num_epochs, reporting losses and
  validation metrics after every epoch.
  * input: the model to train
  * input: the training data loader to train on
  * input: the validation data loader to report losses and metrics for
  * input: the number of passes over the training data
  * input: the device the batch tensors are moved to
  * input: optional file the model's state_dict is written to after each epoch
  * input: AdamW learning rate (defaults to the previously hard-coded 3e-4)
  * input: warmup steps of the linear schedule (defaults to the previously hard-coded 50)
  * returns: the trained model
  * returns: a list of training losses, one per epoch
  * returns: a list of validation losses, one per epoch'''

  optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
  lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=len(train_dataloader) * num_epochs,
  )

  train_losses = []
  validation_losses = []

  for e in range(num_epochs):
    # The per-epoch evaluation below switches to eval mode, so re-enable training mode.
    model.train()

    print(f"Epoch {e + 1} training:")

    batch_train_losses = []

    # Run the epoch training
    for batch in tqdm(train_dataloader):
      input_ids = batch['input_ids'].to(device)
      attention_mask = batch['attention_mask'].to(device)
      start_positions = batch['span_start'].to(device)
      end_positions = batch['span_end'].to(device)

      optimizer.zero_grad()

      outputs = model(
        input_ids = input_ids,
        attention_mask = attention_mask,
        start_positions = start_positions,
        end_positions = end_positions,
      )

      outputs.loss.backward()
      optimizer.step()
      lr_scheduler.step()

      # Keep only the value: a non-detached loss would keep its autograd
      # history referenced for every batch and leak it into the returned list.
      batch_train_losses.append(outputs.loss.detach())

    # Save the model checkpoint at this point if necessary
    if save_path is not None:
      torch.save(model.state_dict(), save_path)

    # Report the epoch training loss
    epoch_train_loss = sum(batch_train_losses) / len(batch_train_losses)
    train_losses.append(epoch_train_loss)
    print(f"Epoch {e + 1} training loss: {epoch_train_loss:.3f}")

    # Evaluate on the validation set and report the loss
    validation_loss = get_validation_loss(model, validation_dataloader, device=device)
    validation_losses.append(validation_loss)
    print(f"Epoch {e + 1} validation loss: {validation_loss:.3f}")

    # Also report current precision/recall/f1
    metrics = eval_loop(model, validation_dataloader, device=device)
    print(f"Precision: {metrics['precision']:.2f}; Recall: {metrics['recall']:.2f}; F1 score: {metrics['f1_score']:.2f}")

  return model, train_losses, validation_losses

# Running

The code below loads the base BERT model and the dataset, preprocesses the data, and then proceeds to train and evaluate the model.

We train with a batch size of 48. Each epoch takes 6-8 minutes to train, so the total training time is 30-40 minutes.

In [None]:
# Run configuration
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 48
num_epochs = 4

# Build the model/tokenizer and read the two dataset splits
model, tokenizer = load_model()
model.to(device)
train, validation = load_data(data_folder)


train_data_loader = preprocess_and_tokenize(train, tokenizer, batch_size)
validation_data_loader = preprocess_and_tokenize(validation, tokenizer, batch_size)

# Reuse a previously saved checkpoint when there is one; otherwise train from scratch.
if os.path.exists(model_save_path):
  print("Found saved model, restoring weights")
  # map_location lets a checkpoint written on a GPU runtime be restored on a
  # CPU-only runtime (and the other way round) instead of failing in torch.load.
  model.load_state_dict(torch.load(model_save_path, map_location=device))
else:
  print("No saved model found, training anew.")
  model, train_losses, val_losses = train_loop(model, train_data_loader, validation_data_loader, num_epochs, device=device, save_path=model_save_path)

# Final metrics on the validation split
metrics  = eval_loop(model, validation_data_loader, device=device)

torch.save(model.state_dict(), model_save_path)

print(f"precision: {metrics['precision']:.2f}")
print(f"recall: {metrics['recall']:.2f}")
print(f"f1-score: {metrics['f1_score']:.2f}")

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_projector.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this mode

  0%|          | 0/2 [00:00<?, ?it/s]

No saved model found, training anew.
Epoch 1 training:


  0%|          | 0/291 [00:00<?, ?it/s]

Epoch 1 training loss: 2.407
Epoch 1 validation loss: 1.628
Evaluating metrics:


  0%|          | 0/19 [00:00<?, ?it/s]

Precision: 0.61; Recall: 0.66; F1 score: 0.58
Epoch 2 training:


  0%|          | 0/291 [00:00<?, ?it/s]

Epoch 2 training loss: 1.164
Epoch 2 validation loss: 1.768
Evaluating metrics:


  0%|          | 0/19 [00:00<?, ?it/s]

Precision: 0.64; Recall: 0.66; F1 score: 0.60
Epoch 3 training:


  0%|          | 0/291 [00:00<?, ?it/s]

Epoch 3 training loss: 0.591
Epoch 3 validation loss: 1.993
Evaluating metrics:


  0%|          | 0/19 [00:00<?, ?it/s]

Precision: 0.65; Recall: 0.68; F1 score: 0.61
Epoch 4 training:


  0%|          | 0/291 [00:00<?, ?it/s]

Epoch 4 training loss: 0.288
Epoch 4 validation loss: 2.288
Evaluating metrics:


  0%|          | 0/19 [00:00<?, ?it/s]

Precision: 0.64; Recall: 0.69; F1 score: 0.61
Evaluating metrics:


  0%|          | 0/19 [00:00<?, ?it/s]

precision: 0.64
recall: 0.69
f1-score: 0.61


# Testing
Below, we implement a few helpers to actually use the model to return a character span given a question and a document to read from.

In [None]:
def get_answer(question, document, max_length=384):
  '''Gets the span that answers the question from the given document.

  Uses the notebook-level `tokenizer`, `model` and `device`.
  * input: the question
  * input: the document containing the answer
  * input: length the encoded pair is padded/truncated to (defaults to the
    previously hard-coded 384)
  * returns: a snippet from the document with the answer, or '' when the
    predicted span is empty or does not lie inside the document (for example
    when the model points at a special, padding or question token)'''

  encoded_input = tokenizer(
      text = question.lower(),
      text_pair = document.lower(),
      add_special_tokens = True,
      return_attention_mask = True,
      return_token_type_ids = False,
      padding = 'max_length',
      max_length = max_length,
      truncation = True,
      return_tensors = 'pt',
  )

  model.eval()
  with torch.no_grad():
    outputs = model(
      input_ids = encoded_input.input_ids.to(device),
      attention_mask = encoded_input.attention_mask.to(device),
    )

  # Plain ints, so they can be used as list indices and token indices.
  start_tok_idx = int(torch.argmax(outputs.start_logits, dim=1)[0])
  # The predicted end is treated as exclusive: the answer's last token is the one before it.
  last_tok_idx = int(torch.argmax(outputs.end_logits, dim=1)[0]) - 1

  # For each token: 0 = question, 1 = document, None = special/padding token.
  sequence_ids = encoded_input.sequence_ids(0)

  def _in_document(tok_idx):
    return 0 <= tok_idx < len(sequence_ids) and sequence_ids[tok_idx] == 1

  # Only document tokens have character offsets that are valid for slicing `document`.
  if last_tok_idx < start_tok_idx or not (_in_document(start_tok_idx) and _in_document(last_tok_idx)):
    return ''

  start_char_idx = encoded_input.token_to_chars(start_tok_idx).start
  end_char_idx = encoded_input.token_to_chars(last_tok_idx).end

  # Slice the original (not lower-cased) document so the answer keeps its casing.
  return document[start_char_idx:end_char_idx]

# Opening paragraph of Brown's Wikipedia page, followed by the questions asked about it
brown_document = "Brown University is a private Ivy League research university in Providence, Rhode Island. Brown is the seventh-oldest institution of higher education in the United States, founded in 1764 as the College in the English Colony of Rhode Island and Providence Plantations. Brown is one of the nine colonial colleges chartered before the American Revolution. Admission at Brown is among the most selective in the United States. In 2022, the university reported a first year acceptance rate of 5%."
for question in (
  "When was Brown founded?",
  "Where is Brown located?",
  "What is Brown's acceptance rate?",
  "What kind of university is Brown?",
):
  print(get_answer(question, brown_document))

# Second passage (about Rhode Island) and its questions
ri_document = "After the American Revolution, during which it was heavily occupied and contested, Rhode Island became the fourth state to ratify the Articles of Confederation on February 9, 1778. Favoring a weaker central government, it boycotted the 1787 convention that drafted the United States Constitution, which it initially refused to ratify it was the last of the original 13 states to do so, on May 29, 1790."
for question in (
  "When did Rhode Island ratify the Articles?",
  "Who was the last state to ratify the Constitution?",
):
  print(get_answer(question, ri_document))

1764
Providence, Rhode Island
5%
Ivy League research university
February 9, 1778
Rhode Island became the fourth state to ratify the Articles of Confederation on February 9, 1778. Favoring a weaker central government, it boycotted the 1787 convention that drafted the United States Constitution, which it initially refused to ratify it was the last of the original 13 states to do so, on May 29, 1790
