A multiple-choice task is similar to question answering, except that several candidate answers are provided along with a context, and the model is trained to select the correct one.

This guide shows how to:

1. Fine-tune BERT on the `regular` configuration of the SWAG dataset to select the best answer given multiple options and some context.
2. Use the fine-tuned model for inference.

# Libraries

In [None]:
pip install transformers datasets evaluate

In [None]:
import torch
import evaluate
import numpy as np
from datasets import load_dataset
from dataclasses import dataclass
from typing import Optional, Union
from transformers import AutoTokenizer, AutoModelForMultipleChoice, TrainingArguments, Trainer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy

# Pick the best available accelerator: Apple-silicon MPS first, then CUDA, then CPU.
# A hard-coded torch.device("mps") makes `model.to(mps_device)` fail on machines
# without MPS support, so fall back instead. The variable keeps its original name
# because later cells reference `mps_device`.
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
elif torch.cuda.is_available():
    mps_device = torch.device("cuda")
else:
    mps_device = torch.device("cpu")

# Load Data

In [None]:
# Download SWAG in its "regular" configuration; the splits are accessed by name
# later on (e.g. "train" and "validation").
swag = load_dataset(path="swag", name="regular")

In [None]:
# Inspection: display the first training example.
# sent1 and sent2: together they form how a sentence starts (joined, they give the startphrase field).
# ending0 ... ending3: four candidate endings for the sentence; only one of them is correct.
# label: index of the correct ending.
swag["train"][0]

# Preprocess

In [None]:
# BERT tokenizer used to encode each (sentence start, candidate ending) pair.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="google-bert/bert-base-uncased"
)

In [None]:
# Preprocessing for a batched `Dataset.map`:
# 1. Repeat each sent1 once per candidate ending (first segment of each pair).
# 2. Build the second segment as "sent2 + ending" for every candidate ending.
# 3. Flatten the per-example lists so the tokenizer sees one flat batch of pairs.
# 4. Regroup the tokenizer outputs so each example gets, per field (input_ids,
#    attention_mask, ...), a list with one entry per candidate.
#    The "label" column is not produced here; `Dataset.map` keeps the existing
#    input columns unless `remove_columns` is given.
ending_names = ["ending0", "ending1", "ending2", "ending3"]


def preprocess_function(examples):
    """Tokenize a batch of SWAG examples into one sequence pair per candidate ending.

    Args:
        examples: A batch (column name -> list of values) as passed by
            ``Dataset.map(batched=True)``. Must contain ``sent1``, ``sent2`` and
            every column listed in ``ending_names``.

    Returns:
        A dict mapping each tokenizer output field to a list with one entry per
        example; each entry is a list of ``len(ending_names)`` tokenized candidates.
    """
    num_choices = len(ending_names)

    # Build the flat pair lists directly; this avoids sum(list_of_lists, []),
    # which copies the accumulator on every step and is quadratic in batch size.
    first_sentences = [
        context for context in examples["sent1"] for _ in range(num_choices)
    ]
    second_sentences = [
        f"{header} {examples[end][i]}"
        for i, header in enumerate(examples["sent2"])
        for end in ending_names
    ]

    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)

    # Regroup each flat output into consecutive chunks of num_choices,
    # i.e. one chunk per original example.
    return {
        k: [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
        for k, v in tokenized_examples.items()
    }


In [None]:
# Tokenize every split; batched=True hands preprocess_function whole batches of rows.
tokenized_swag = swag.map(function=preprocess_function, batched=True)

In [None]:
# Data collator for multiple choice, adapted from DataCollatorWithPadding.
# NOTE(review): newer transformers releases may provide their own
# DataCollatorForMultipleChoice; this local version keeps the notebook self-contained.
@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that dynamically pads multiple-choice inputs into a batch.

    Each feature holds, per tokenizer field, a list with one entry per choice.
    The collator flattens all choices, pads them together with ``tokenizer.pad``,
    and reshapes every padded tensor to ``(batch_size, num_choices, seq_len)``.

    Fields:
        tokenizer: Tokenizer whose ``pad`` method performs the padding.
        padding: Padding strategy forwarded to ``tokenizer.pad``.
        max_length: Optional maximum length forwarded to ``tokenizer.pad``.
        pad_to_multiple_of: Optional length multiple forwarded to ``tokenizer.pad``.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Collate a list of feature dicts into a dict of batched tensors.

        Sequences are padded at collation time (to the longest in the batch with
        the default ``padding=True``) instead of padding the whole dataset up front.
        The input features are not modified. If they carry a ``label`` or
        ``labels`` key, the batch gets an int64 ``labels`` tensor; otherwise no
        labels are added (e.g. when collating for inference).
        """
        label_name = "label" if "label" in features[0] else "labels"
        has_labels = label_name in features[0]
        # Read labels instead of pop()-ing them so the caller's dicts stay intact.
        labels = [feature[label_name] for feature in features] if has_labels else None

        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        # One dict per (example, choice), excluding the label key. A single
        # comprehension avoids the quadratic sum(list_of_lists, []) flattening.
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != label_name}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Unflatten back to (batch_size, num_choices, seq_len).
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        if labels is not None:
            batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch

# Evaluation

In [None]:
# Accuracy metric from the `evaluate` library; compute_metrics uses it.
accuracy = evaluate.load(path="accuracy")

In [None]:
# Metric hook handed to the Trainer below.
def compute_metrics(eval_pred):
    """Return the accuracy of the predicted choices for an evaluation pass.

    ``eval_pred`` unpacks into (logits, label_ids); the predicted choice for each
    example is the argmax of its logits along axis 1.
    """
    logits, label_ids = eval_pred
    chosen = np.argmax(logits, axis=1)
    return accuracy.compute(references=label_ids, predictions=chosen)

# Training

In [None]:
# BERT with a multiple-choice head, placed on the device selected earlier.
model = AutoModelForMultipleChoice.from_pretrained(
    pretrained_model_name_or_path="google-bert/bert-base-uncased"
)
model.to(mps_device)

In [None]:
# Training configuration: evaluate and checkpoint once per epoch, and reload the
# best checkpoint when training ends.
training_args = TrainingArguments(
    output_dir="multiple_choice_model",
    # `evaluation_strategy` was renamed to `eval_strategy`; recent transformers
    # releases no longer accept the old name.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_swag["train"],
    eval_dataset=tokenized_swag["validation"],
    # `processing_class` replaces the deprecated `tokenizer` argument; the Trainer
    # saves it together with the model.
    processing_class=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
)

trainer.train()

# Per-epoch checkpoints land in output_dir/checkpoint-*; nothing is written to
# output_dir itself. Save the final (best) model and tokenizer there so that
# from_pretrained("multiple_choice_model") in the inference cells works.
trainer.save_model()

# Inference

In [None]:
# A prompt and three candidate continuations for the fine-tuned model to rank.
prompt = "France has a bread law, Le Décret Pain, with strict rules on what is allowed in a traditional baguette."
candidate1, candidate2, candidate3 = (
    "The law does not apply to croissants and brioche.",
    "The law applies to baguettes.",
    "The law has nothing to do with bread as we know it, but with money instead.",
)

In [None]:
# Tokenize each (prompt, candidate) pair with the fine-tuned tokenizer and return
# PyTorch tensors; padding=True gives all pairs the same length.
tokenizer = AutoTokenizer.from_pretrained("multiple_choice_model")
inputs = tokenizer(
    [[prompt, candidate1], [prompt, candidate2], [prompt, candidate3]],
    return_tensors="pt",
    padding=True,
)
# Label 0 marks candidate1 as the correct answer; shape (1,) for a batch of one.
labels = torch.tensor(0).unsqueeze(0)

# Load the fine-tuned model. unsqueeze(0) adds a batch dimension so each input is
# (1, num_candidates, seq_len). Gradients are not needed for inference, so the
# forward pass runs under no_grad to avoid building an autograd graph.
model = AutoModelForMultipleChoice.from_pretrained("multiple_choice_model")
with torch.no_grad():
    outputs = model(**{k: v.unsqueeze(0) for k, v in inputs.items()}, labels=labels)
logits = outputs.logits

In [None]:
# Index of the highest-scoring candidate (0 -> candidate1, 1 -> candidate2, 2 -> candidate3).
predicted_class = int(logits.argmax())
predicted_class