A multiple choice task is similar to question answering, except several candidate answers are provided along with a context and the model is trained to select the correct answer.

This guide shows how to:

1. Finetune BERT on the regular configuration of the SWAG dataset to select the best answer given multiple options and some context.
2. Use the finetuned model for inference.

# Libraries

In [None]:
pip install transformers datasets evaluate

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Load Data

In [None]:
swag = load_dataset("swag", "regular")

In [None]:
# Inspect the first training example. Fields of interest:
# sent1 and sent2: how the sentence starts (joined together, they form the startphrase field).
# ending0 to ending3: four candidate endings for the sentence; only one of them is correct.
# label: the index of the correct ending.
swag["train"][0]
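
To make the field layout concrete, here is a minimal sketch of how the four candidate sentences can be assembled from one record. The record below is hypothetical: it reuses the SWAG field names, but the text is made up for illustration and is not taken from the dataset.

```python
# Hypothetical record with the same field names as SWAG; the text is invented
# for illustration and does not come from the dataset.
record = {
    "sent1": "A chef places a pan on the stove.",
    "sent2": "The chef",
    "ending0": "turns on the burner.",
    "ending1": "swims across the lake.",
    "ending2": "paints the pan blue.",
    "ending3": "reads the stove a story.",
    "label": 0,
}

ending_names = ["ending0", "ending1", "ending2", "ending3"]

# Each candidate is the shared context (sent1) followed by sent2 plus one ending.
candidates = [f'{record["sent1"]} {record["sent2"]} {record[name]}' for name in ending_names]

# The label is the index of the correct ending.
correct = candidates[record["label"]]
print(correct)
```

The model is trained to assign the highest score to the candidate at index `label`.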

# Preprocess

In [None]:
# Load a BERT tokenizer to process the sentence starts and the four possible endings.
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")

In [None]:
# Create a preprocessing function that:
# 1. Makes four copies of the sent1 field, one per candidate ending.
# 2. Combines sent2 with each of the four possible endings to form the second sentence of each pair.
# 3. Flattens both lists so they can be tokenized as sentence pairs.
# 4. Unflattens the result so each example has four input_ids and attention_mask entries, one per candidate
#    (the existing label column is left unchanged).
ending_names = ["ending0", "ending1", "ending2", "ending3"]


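def preprocess_function(examples):
    # Repeat each context four times, once per candidate ending.
    first_sentences = [[context] * 4 for context in examples["sent1"]]
    question_headers = examples["sent2"]
    # Pair each sentence start (sent2) with each of its four endings.
    second_sentences = [
        [f"{header} {examples[end][i]}" for end in ending_names] for i, header in enumerate(question_headers)
    ]

    # Flatten the nested lists so the tokenizer receives plain lists of strings.
    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])

    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)
    # Regroup every four consecutive entries so each example holds its four candidates.
    return {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in tokenized_examples.items()}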
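
The flatten and unflatten steps above can be checked on toy data without downloading anything. This is only a sketch: `toy_tokenizer` is a hypothetical stand-in that splits on whitespace, and the contexts and endings are invented placeholders. Only the grouping logic mirrors the preprocessing function.

```python
def toy_tokenizer(first, second):
    # Hypothetical stand-in for a real tokenizer: returns one list per
    # sentence pair, keyed the same way as a batched tokenizer output.
    input_ids = [(a + " " + b).split() for a, b in zip(first, second)]
    return {"input_ids": input_ids, "attention_mask": [[1] * len(ids) for ids in input_ids]}


# Two made-up examples, each with four candidate endings.
contexts = ["ctx A", "ctx B"]
endings = [["a0", "a1", "a2", "a3"], ["b0", "b1", "b2", "b3"]]

# Flatten: 2 examples x 4 choices -> 8 sentence pairs.
first = [c for c in contexts for _ in range(4)]
second = [e for group in endings for e in group]

flat = toy_tokenizer(first, second)

# Unflatten: regroup every 4 consecutive entries so each example holds 4 choices.
grouped = {k: [v[i : i + 4] for i in range(0, len(v), 4)] for k, v in flat.items()}

print(len(flat["input_ids"]), len(grouped["input_ids"]), len(grouped["input_ids"][0]))
```

After regrouping, every key maps to one list per example, and each of those lists contains four entries, one per candidate ending.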


In [None]:
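# Apply the preprocessing function to every split; batched=True passes batches of examples to each call.
tokenized_swag = swag.map(preprocess_function, batched=True)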