In [1]:
! pip install torch transformers datasets evaluate



In [2]:
import torch
from datasets import load_dataset
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

squad2 = load_dataset("squad_v2")

## Pre-process Training Data

In [3]:
MAX_TOKEN_LENGTH = 512
STRIDE = 128


def preprocess_training_examples(
    examples,
    tokenizer,
    max_token_length=MAX_TOKEN_LENGTH,
    stride=STRIDE,
):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_token_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
        return_tensors="pt",
    )

    offset_mapping = inputs.pop("offset_mapping")
    sample_map = inputs.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answer = answers[sample_idx]

        # If the answer doesn't exist, the label is (0, 0)
        if not answer["answer_start"]:
          start_positions.append(0)
          end_positions.append(0)
          continue

        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label is (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [4]:
# from functools import partial

# train_dataset = squad2["train"].select(range(5000))
# preprocessed_train_dataset = train_dataset.map(
#     partial(preprocess_training_examples, tokenizer=roberta_squad2_tokenizer),
#     batched=True,
#     remove_columns=train_dataset.column_names,
# )
# len(train_dataset), len(preprocessed_train_dataset)

## Pre-process Test Data

In [5]:
def preprocess_test_examples(
    examples,
    tokenizer,
    max_token_length=MAX_TOKEN_LENGTH,
    stride=STRIDE,
):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=max_token_length,
        truncation="only_second",
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    sample_map = inputs.pop("overflow_to_sample_mapping")
    example_ids = []

    for i in range(len(inputs["input_ids"])):
        sample_idx = sample_map[i]
        example_ids.append(examples["id"][sample_idx])

        sequence_ids = inputs.sequence_ids(i)
        offset = inputs["offset_mapping"][i]
        inputs["offset_mapping"][i] = [
            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
        ]

    inputs["example_id"] = example_ids
    return inputs

## Inference

In [6]:
from tqdm import trange


BATCH_SIZE = 700


def infer_outputs(model, preprocessed_inputs, batch_size=BATCH_SIZE):
    preprocessed_inputs.set_format("torch")

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        print("No GPU available, using CPU for inference")
        device = torch.device("cpu")
    
    model.to(device)

    num_examples = len(preprocessed_inputs)
    outputs = None
    for batch_start in trange(0, num_examples, batch_size):
      batch_end = min(batch_start+batch_size, num_examples)
      batch_inputs = preprocessed_inputs.select(range(batch_start, batch_end))
      batch_inputs_for_model = {
          k: batch_inputs[k].to(device)
          for k in batch_inputs.column_names
          if k not in ["offset_mapping", "example_id"]
      }

      with torch.no_grad():
        batch_outputs = model(**batch_inputs_for_model)

      # Free memory for inputs
      for v in batch_inputs_for_model.values():
        del v

      if outputs is None:
        outputs = batch_outputs
        for v in outputs.values():
          v.cpu()
      else:
        for k in batch_outputs.keys():
          outputs[k] = torch.cat((outputs[k], batch_outputs[k]), dim=0)

      # Free memory for outputs
      for v in batch_outputs.values():
        del v

    return outputs

## Post-processing

In [7]:
import collections
import numpy as np
import tqdm


def post_process(
    outputs,
    test_dataset,
    preprocessed_test_dataset,
    n_best=20,
    max_answer_length=30,
):
    preprocessed_test_dataset.set_format()

    start_logits = outputs.start_logits.cpu().to(torch.float32).numpy()
    end_logits = outputs.end_logits.cpu().to(torch.float32).numpy()
    offset_mapping = preprocessed_test_dataset["offset_mapping"]

    example_to_features = collections.defaultdict(list)
    example_ids = preprocessed_test_dataset["example_id"]
    for idx, example_id in tqdm.tqdm(enumerate(example_ids), total=len(example_ids)):
        example_to_features[example_id].append(idx)

    predicted_answers = []
    for example in tqdm.tqdm(test_dataset, total=len(test_dataset)):
        example_id = example["id"]
        context = example["context"]
        answers = []

        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = offset_mapping[feature_index]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Prediction is that there's no answer
                    if start_index == 0 and end_index == 0:
                      answers.append(
                          {
                              "text": None,
                              "logit_score": start_logit[start_index] + end_logit[end_index],
                          }
                      )
                    # Skip answers that are not fully in the context
                    elif offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length.
                    elif (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue
                    else:
                      answers.append(
                          {
                              "text": context[offsets[start_index][0] : offsets[end_index][1]],
                              "logit_score": start_logit[start_index] + end_logit[end_index],
                          }
                      )

        best_answer = max(answers, key=lambda x: x["logit_score"])
        predicted_answers.append(
            {
                "id": example_id,
                "prediction_text": best_answer["text"] or "",
                # TODO: can this be improved?
                "no_answer_probability": 1.0 if best_answer["text"] else 0.0
            }
        )
    return predicted_answers

In [8]:
def find_examples_with_no_answer(dataset):
  example_idxs = []
  for idx, ex in enumerate(dataset):
    if not ex["answers"].get("text"):
      example_idxs.append(idx)
  return example_idxs

In [9]:
import evaluate

metric = evaluate.load("squad_v2")

In [10]:
from functools import partial


def evaluate_model(model_name, test_dataset, inference_batch_size=500):
    model = AutoModelForQuestionAnswering.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    preprocessed_test_dataset = test_dataset.map(
        partial(preprocess_test_examples, tokenizer=tokenizer),
        batched=True,
        remove_columns=test_dataset.column_names,
    )

    outputs = infer_outputs(model, preprocessed_test_dataset, batch_size=inference_batch_size)

    predicted_answers = post_process(
        outputs,
        test_dataset,
        preprocessed_test_dataset=preprocessed_test_dataset,
    )
    theoretical_answers = [
        {"id": ex["id"], "answers": ex["answers"]} for ex in test_dataset
    ]

    print(metric.compute(predictions=predicted_answers, references=theoretical_answers))

In [11]:
# Evaluate RoBERTa
evaluate_model("deepset/roberta-base-squad2", squad2["validation"], inference_batch_size=750)

100%|██████████| 16/16 [01:06<00:00,  4.18s/it]
100%|██████████| 11955/11955 [00:00<00:00, 720961.96it/s]
100%|██████████| 11873/11873 [00:04<00:00, 2845.62it/s]


{'exact': 79.91240630000843, 'f1': 82.91556320321106, 'total': 11873, 'HasAns_exact': 77.9689608636977, 'HasAns_f1': 83.98388696216712, 'HasAns_total': 5928, 'NoAns_exact': 81.85029436501262, 'NoAns_f1': 81.85029436501262, 'NoAns_total': 5945, 'best_exact': 79.91240630000843, 'best_exact_thresh': 1.0, 'best_f1': 82.91556320321097, 'best_f1_thresh': 1.0}


In [12]:
# Evaluate MDeBERTa-v3
evaluate_model("timpal0l/mdeberta-v3-base-squad2", squad2["validation"], inference_batch_size=350)

100%|██████████| 35/35 [02:28<00:00,  4.24s/it]
100%|██████████| 12054/12054 [00:00<00:00, 821963.30it/s]
100%|██████████| 11873/11873 [00:05<00:00, 2063.48it/s]


{'exact': 80.30826244420113, 'f1': 83.47166598148709, 'total': 11873, 'HasAns_exact': 79.53778677462888, 'HasAns_f1': 85.87366568795501, 'HasAns_total': 5928, 'NoAns_exact': 81.07653490328006, 'NoAns_f1': 81.07653490328006, 'NoAns_total': 5945, 'best_exact': 80.30826244420113, 'best_exact_thresh': 1.0, 'best_f1': 83.47166598148702, 'best_f1_thresh': 1.0}


In [13]:
# Evaluate tiny RoBERTa
evaluate_model("deepset/tinyroberta-squad2", squad2["validation"], inference_batch_size=1000)

config.json:   0%|          | 0.00/835 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/326M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/383 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Map:   0%|          | 0/11873 [00:00<?, ? examples/s]

100%|██████████| 12/12 [00:32<00:00,  2.73s/it]
100%|██████████| 11955/11955 [00:00<00:00, 790673.06it/s]
100%|██████████| 11873/11873 [00:04<00:00, 2472.88it/s]


{'exact': 78.93539964625622, 'f1': 81.97449894113623, 'total': 11873, 'HasAns_exact': 76.48448043184885, 'HasAns_f1': 82.57139438733313, 'HasAns_total': 5928, 'NoAns_exact': 81.37931034482759, 'NoAns_f1': 81.37931034482759, 'NoAns_total': 5945, 'best_exact': 78.93539964625622, 'best_exact_thresh': 1.0, 'best_f1': 81.97449894113613, 'best_f1_thresh': 1.0}


In [14]:
# Evaluate distilbert
evaluate_model("distilbert/distilbert-base-cased-distilled-squad", squad2["validation"], inference_batch_size=1000)

config.json:   0%|          | 0.00/473 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/261M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

Map:   0%|          | 0/11873 [00:00<?, ? examples/s]

100%|██████████| 12/12 [00:32<00:00,  2.70s/it]
100%|██████████| 11974/11974 [00:00<00:00, 1193332.61it/s]
100%|██████████| 11873/11873 [00:03<00:00, 3134.47it/s]


{'exact': 39.84671102501474, 'f1': 43.62804078501804, 'total': 11873, 'HasAns_exact': 79.53778677462888, 'HasAns_f1': 87.11129018902147, 'HasAns_total': 5928, 'NoAns_exact': 0.2691337258200168, 'NoAns_f1': 0.2691337258200168, 'NoAns_total': 5945, 'best_exact': 50.11370336056599, 'best_exact_thresh': 1.0, 'best_f1': 50.11370336056599, 'best_f1_thresh': 1.0}
