This notebook largely follows the [Hugging Face Question Answering tutorial](https://huggingface.co/learn/nlp-course/chapter7/7?fw=pt), with some adjustments (e.g. using squad-v2 instead of squad)

In [1]:
! pip install torch transformers transformers[torch] datasets evaluate

[0m

In [2]:
import torch
from datasets import load_dataset
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

squad2 = load_dataset("squad_v2")

## Pre-process Training Data

In [3]:
MAX_TOKEN_LENGTH = 512
STRIDE = 128


def preprocess_training_examples(
    examples,
    tokenizer,
    max_token_length=MAX_TOKEN_LENGTH,
    stride=STRIDE,
):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_token_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
        return_tensors="pt",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]

        # If the answer doesn't exist, the label is (0, 0)
        if not answer["answer_start"]:
          start_positions.append(0)
          end_positions.append(0)
          continue

        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

## Pre-process Test Data

In [4]:
def preprocess_validation_examples(
    examples,
    tokenizer,
    max_token_length=MAX_TOKEN_LENGTH,
    stride=STRIDE,
):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_token_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs

## Inference

In [5]:
from tqdm.auto import trange


BATCH_SIZE = 700


def do_inference(model, inputs, batch_size=BATCH_SIZE):
    inputs.set_format("torch")

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        print("No GPU available, using CPU for inference")
        device = torch.device("cpu")
    
    model.to(device)

    num_examples = len(inputs)
    outputs = None
    for batch_start in trange(0, num_examples, batch_size, desc="Inference"):
      batch_end = min(batch_start+batch_size, num_examples)
      batch_inputs = inputs.select(range(batch_start, batch_end))
      batch_inputs_for_model = {
          k: batch_inputs[k].to(device)
          for k in batch_inputs.column_names
          if k not in ["offset_mapping", "example_id"]
      }

      with torch.no_grad():
        batch_outputs = model(**batch_inputs_for_model)

      # Free memory for inputs
      for v in batch_inputs_for_model.values():
        del v

      if outputs is None:
        outputs = batch_outputs
        for v in outputs.values():
          v.cpu()
      else:
        for k in batch_outputs.keys():
          outputs[k] = torch.cat((outputs[k], batch_outputs[k]), dim=0)

      # Free memory for outputs
      for v in batch_outputs.values():
        del v

    return outputs

## Post-processing

In [6]:
import collections
import numpy as np
from tqdm.auto import tqdm


def post_process(
    start_logits,
    end_logits,
    validation_dataset,
    validation_features,
    n_best=20,
    max_answer_length=30,
):
    validation_features.set_format()
    offset_mapping = validation_features["offset_mapping"]

    example_to_features = collections.defaultdict(list)
    example_ids = validation_features["example_id"]
    for idx, example_id in enumerate(example_ids):
        example_to_features[example_id].append(idx)

    predicted_answers = []
    for example in tqdm(validation_dataset, total=len(validation_dataset), desc="Post process"):
        example_id = example["id"]
        context = example["context"]
        answers = []

        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = offset_mapping[feature_index]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Prediction is that there's no answer
                    if start_index == 0 and end_index == 0:
                      answers.append(
                          {
                              "text": None,
                              "logit_score": start_logit[start_index] + end_logit[end_index],
                          }
                      )
                    # Skip answers that are not fully in the context
                    elif offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length.
                    elif (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue
                    else:
                      answers.append(
                          {
                              "text": context[offsets[start_index][0] : offsets[end_index][1]],
                              "logit_score": start_logit[start_index] + end_logit[end_index],
                          }
                      )

        best_answer = max(answers, key=lambda x: x["logit_score"])
        predicted_answers.append(
            {
                "id": example_id,
                "prediction_text": best_answer["text"] or "",
                # TODO: can this be improved?
                "no_answer_probability": 1.0 if best_answer["text"] else 0.0
            }
        )
    return predicted_answers

In [7]:
def find_examples_with_no_answer(dataset):
  example_idxs = []
  for idx, ex in enumerate(dataset):
    if not ex["answers"].get("text"):
      example_idxs.append(idx)
  return example_idxs

In [8]:
import evaluate

squad_v2_metric = evaluate.load("squad_v2")

In [9]:
from functools import partial
from pprint import pprint


def compute_metrics(
    start_logits,
    end_logits,
    validation_dataset,
    validation_features,
    metric,
):
    predicted_answers = post_process(
        start_logits,
        end_logits,
        validation_dataset=validation_dataset,
        validation_features=validation_features,
    )
    theoretical_answers = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in validation_dataset
    ]

    return metric.compute(predictions=predicted_answers, references=theoretical_answers)


def evaluate_model(
        model_name=None,
        model=None,
        tokenizer=None,
        validation_dataset=None,
        inference_batch_size=500,
):
    if model_name:
        if model is not None or tokenizer is not None:
            raise ValueError("Cannot specify both model_name and model/tokenizer")
        model = AutoModelForQuestionAnswering.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)
    elif model is None or tokenizer is None:
        raise ValueError("Must specify either model_name or model and tokenizer")

    if validation_dataset is None:
        raise ValueError("Must specify validation_dataset")

    validation_features = validation_dataset.map(
        partial(preprocess_validation_examples, tokenizer=tokenizer),
        batched=True,
        remove_columns=validation_dataset.column_names,
    )
    outputs = do_inference(model, validation_features, batch_size=inference_batch_size)
    
    start_logits = outputs.start_logits.cpu().to(torch.float32).numpy()
    end_logits = outputs.end_logits.cpu().to(torch.float32).numpy()
    
    metrics = compute_metrics(
        start_logits,
        end_logits,
        validation_dataset=validation_dataset,
        validation_features=validation_features,
        metric=squad_v2_metric,
    )
    pprint(metrics)

In [10]:
# Evaluate distilbert
evaluate_model(
    model_name="distilbert/distilbert-base-cased-distilled-squad",
    validation_dataset=squad2["validation"],
    inference_batch_size=1000,
)

Map:   0%|          | 0/11873 [00:00<?, ? examples/s]

Inference:   0%|          | 0/12 [00:00<?, ?it/s]

Post process:   0%|          | 0/11873 [00:00<?, ?it/s]

{'HasAns_exact': 79.53778677462888,
 'HasAns_f1': 87.11129018902147,
 'HasAns_total': 5928,
 'NoAns_exact': 0.2691337258200168,
 'NoAns_f1': 0.2691337258200168,
 'NoAns_total': 5945,
 'best_exact': 50.11370336056599,
 'best_exact_thresh': 1.0,
 'best_f1': 50.11370336056599,
 'best_f1_thresh': 1.0,
 'exact': 39.84671102501474,
 'f1': 43.62804078501804,
 'total': 11873}


In [11]:
# Evaluate tiny RoBERTa
evaluate_model(
    model_name="deepset/tinyroberta-squad2",
    validation_dataset=squad2["validation"],
    inference_batch_size=1000,
)

Map:   0%|          | 0/11873 [00:00<?, ? examples/s]

Inference:   0%|          | 0/12 [00:00<?, ?it/s]

Post process:   0%|          | 0/11873 [00:00<?, ?it/s]

{'HasAns_exact': 76.48448043184885,
 'HasAns_f1': 82.57139438733313,
 'HasAns_total': 5928,
 'NoAns_exact': 81.37931034482759,
 'NoAns_f1': 81.37931034482759,
 'NoAns_total': 5945,
 'best_exact': 78.93539964625622,
 'best_exact_thresh': 1.0,
 'best_f1': 81.97449894113613,
 'best_f1_thresh': 1.0,
 'exact': 78.93539964625622,
 'f1': 81.97449894113623,
 'total': 11873}


## Fine-tuning

In [12]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [13]:
model_name = "distilbert/distilbert-base-cased-distilled-squad"
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

train_test_dataset = squad2["train"].train_test_split(test_size=0.1)
shuffle_seed = 999999

raw_train_dataset = train_test_dataset["train"]
train_dataset = raw_train_dataset.map(
    partial(preprocess_training_examples, tokenizer=tokenizer),
    batched=True,
    remove_columns=raw_train_dataset.column_names,
)
train_dataset = train_dataset.shuffle(seed=shuffle_seed)

raw_test_dataset = train_test_dataset["test"]
test_dataset = raw_test_dataset.map(
    partial(preprocess_training_examples, tokenizer=tokenizer),
    batched=True,
    remove_columns=raw_test_dataset.column_names,
)
test_dataset = test_dataset.shuffle(seed=shuffle_seed)

raw_validation_dataset = squad2["validation"]
validation_dataset = raw_validation_dataset.map(
    partial(preprocess_validation_examples, tokenizer=tokenizer),
    batched=True,
    remove_columns=raw_validation_dataset.column_names,
)
validation_dataset = validation_dataset.shuffle(seed=shuffle_seed)

print(f"Train examples: {len(raw_train_dataset)}, features: {len(train_dataset)}")
print(f"Test examples: {len(raw_test_dataset)}, features: {len(test_dataset)}")
print(f"Validation examples: {len(raw_validation_dataset)}, features: {len(validation_dataset)}")

Map:   0%|          | 0/104255 [00:00<?, ? examples/s]

Map:   0%|          | 0/26064 [00:00<?, ? examples/s]

Map:   0%|          | 0/11873 [00:00<?, ? examples/s]

Train examples: 104255, features: 104439
Test examples: 26064, features: 26105
Validation examples: 11873, features: 11974


In [14]:
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    "models/distilbert-base-cased-distilled-squad-v2",
    evaluation_strategy="steps",
    eval_steps=0.1,
    save_strategy="epoch",
    save_total_limit=3,
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=True,
    push_to_hub=True,
)


trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer
)

In [15]:
trainer.train()

Step,Training Loss,Validation Loss
3917,0.9861,0.893741
7834,0.9142,0.835824
11751,0.8906,0.792013
15668,0.6405,0.840566
19585,0.6474,0.802553
23502,0.6154,0.786321
27419,0.4052,0.905506
31336,0.4157,0.927355
35253,0.4134,0.927741


Checkpoint destination directory models/distilbert-base-cased-distilled-squad-v2/checkpoint-13055 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/distilbert-base-cased-distilled-squad-v2/checkpoint-26110 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/distilbert-base-cased-distilled-squad-v2/checkpoint-39165 already exists and is non-empty.Saving will proceed but saved results may be invalid.


TrainOutput(global_step=39165, training_loss=0.6611109690359995, metrics={'train_runtime': 5707.85, 'train_samples_per_second': 54.892, 'train_steps_per_second': 6.862, 'total_flos': 4.093583734272614e+16, 'train_loss': 0.6611109690359995, 'epoch': 3.0})

In [16]:
trainer.push_to_hub(commit_message="Training complete")

CommitInfo(commit_url='https://huggingface.co/jackfriedson/distilbert-base-cased-distilled-squad-v2/commit/f8ebe183c921dcc57516df779de92bf2aa9c7c8e', commit_message='Training complete', commit_description='', oid='f8ebe183c921dcc57516df779de92bf2aa9c7c8e', pr_url=None, pr_revision=None, pr_num=None)

In [17]:
predictions, _, _ = trainer.predict(validation_dataset)
start_logits, end_logits = predictions

In [18]:
metrics = compute_metrics(
    start_logits,
    end_logits,
    validation_dataset=raw_validation_dataset,
    validation_features=validation_dataset,
    metric=squad_v2_metric,
)
metrics

Post process:   0%|          | 0/11873 [00:00<?, ?it/s]

{'exact': 66.88284342626126,
 'f1': 70.26828959712651,
 'total': 11873,
 'HasAns_exact': 66.21120107962213,
 'HasAns_f1': 72.99180202204461,
 'HasAns_total': 5928,
 'NoAns_exact': 67.55256518082422,
 'NoAns_f1': 67.55256518082422,
 'NoAns_total': 5945,
 'best_exact': 66.88284342626126,
 'best_exact_thresh': 1.0,
 'best_f1': 70.26828959712655,
 'best_f1_thresh': 1.0}