In [None]:
!pip install -q transformers datasets evaluate

# Multiple Choice

A multiple choice task is similar to question answering, except several candidate answers are provided along with a context and the model is trained to select the correct answer.

## Load the SWAG dataset

In [None]:
from datasets import load_dataset

# SWAG, "regular" configuration: train / validation / test splits.
swag = load_dataset(path='swag', name='regular')

README.md:   0%|          | 0.00/9.20k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/14.8M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/4.81M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/4.78M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/73546 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/20006 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/20005 [00:00<?, ? examples/s]

In [None]:
swag["train"][0]  # peek at the first training example

{'video-id': 'anetv_jkn6uvmqwh4',
 'fold-ind': '3416',
 'startphrase': 'Members of the procession walk down the street holding small horn brass instruments. A drum line',
 'sent1': 'Members of the procession walk down the street holding small horn brass instruments.',
 'sent2': 'A drum line',
 'gold-source': 'gold',
 'ending0': 'passes by walking down the street playing their instruments.',
 'ending1': 'has heard approaching them.',
 'ending2': "arrives and they're outside dancing and asleep.",
 'ending3': 'turns the lead singer watches the performance.',
 'label': 0}

* `sent1` and `sent2`: these fields show how a sentence starts, and if we put the two together, we get the `startphrase` field.
* `ending0`–`ending3`: each suggests a possible way the sentence can end, but only one of them is correct.
* `label`: the index of the correct sentence ending (e.g. `0` means `ending0` is correct).

## Preprocess

In [None]:
from transformers import AutoTokenizer

# Base (not yet fine-tuned) checkpoint shared by the tokenizer and, later, the model.
checkpoint = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



The preprocessing function needs to:
* Make four copies of the `sent1` field, one per candidate ending. Each copy becomes the first segment of a pair whose second segment starts with `sent2`, so together they recreate how the sentence starts.
* Combine `sent2` with each of the four possible sentence endings.
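* Flatten these two lists so we can tokenize them in one call, and then unflatten them afterward so each example has corresponding `input_ids` and `attention_mask` fields (plus any other tokenizer outputs) with one entry per choice. The dataset's existing `label` column is left untouched; the data collator below turns it into `labels`.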

In [None]:
ending_names = ['ending0', 'ending1', 'ending2', 'ending3']


def preprocess_function(examples):
    """Tokenize a batch of SWAG examples into per-choice (context, ending) pairs.

    For every example, ``sent1`` is paired with ``f"{sent2} {ending}"`` for each
    column named in ``ending_names``. All pairs are tokenized in a single flat
    call and then regrouped so that every returned field holds, per example, a
    list with one encoding per candidate ending.

    Args:
        examples: batch dict as passed by ``Dataset.map(batched=True)``; must
            contain list-valued ``sent1``, ``sent2`` and every ``ending_names``
            column, all of the same length.

    Returns:
        dict mapping each key produced by ``tokenizer`` (e.g. ``input_ids``) to
        a list of per-example lists of length ``len(ending_names)``.
    """
    num_choices = len(ending_names)

    # One copy of the context per candidate ending, already flat. A nested
    # comprehension replaces ``sum(list_of_lists, [])``, which is quadratic.
    first_sentences = [
        context
        for context in examples['sent1']
        for _ in range(num_choices)
    ]
    # The question header (sent2) prefixed to every candidate ending, in the same
    # example-major / ending-minor order as ``first_sentences``.
    second_sentences = [
        f"{header} {examples[end][i]}"
        for i, header in enumerate(examples['sent2'])
        for end in ending_names
    ]

    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)
    # Un-flatten: regroup consecutive runs of num_choices encodings per example.
    return {
        k: [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
        for k, v in tokenized_examples.items()
    }

In [None]:
# Apply the preprocessing to every split, feeding examples to it in batches.
tokenized_swag = swag.map(function=preprocess_function, batched=True)

Map:   0%|          | 0/73546 [00:00<?, ? examples/s]

Map:   0%|          | 0/20006 [00:00<?, ? examples/s]

Map:   0%|          | 0/20005 [00:00<?, ? examples/s]

Older versions of the Transformers library do not include a data collator for multiple choice, so we adapt `DataCollatorWithPadding` to create a batch of examples.

The following `DataCollatorForMultipleChoice` flattens all the model inputs, applies padding, and then unflattens the results:

In [None]:
from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch


@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that dynamically pads the inputs for multiple choice received.

    Each incoming feature maps every model-input key (``input_ids``, ...) to a
    list with one entry per choice. The collator flattens all (example, choice)
    pairs, pads them with a single ``tokenizer.pad`` call, and reshapes every
    resulting tensor back to ``(batch_size, num_choices, -1)``.

    Attributes:
        tokenizer: tokenizer whose ``pad`` method performs the padding.
        padding, max_length, pad_to_multiple_of: forwarded unchanged to
            ``tokenizer.pad``.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Collate a list of per-example feature dicts into one padded batch.

        Args:
            features: non-empty list of dicts as produced by the preprocessing
                step, optionally carrying a ``label`` or ``labels`` entry.

        Returns:
            dict of ``torch`` tensors shaped ``(batch_size, num_choices, -1)``,
            plus an int64 ``labels`` tensor of shape ``(batch_size,)`` when the
            features carry labels. The input ``features`` are not modified.
        """
        first = features[0]
        label_name = 'label' if 'label' in first else 'labels'
        has_labels = label_name in first
        batch_size = len(features)
        num_choices = len(first['input_ids'])

        # Build fresh per-choice dicts (label excluded) instead of popping the
        # label out of the caller's features, so collating the same features
        # twice works. A flat comprehension avoids the quadratic sum(..., []).
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != label_name}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors='pt',
        )

        # Restore the per-example grouping of choices.
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        # Label-less batches (e.g. pure inference) are collated without 'labels'.
        if has_labels:
            batch['labels'] = torch.tensor(
                [feature[label_name] for feature in features], dtype=torch.int64
            )
        return batch

## Evaluate

In [None]:
import evaluate

# Accuracy metric used by compute_metrics below.
accuracy = evaluate.load(path="accuracy")

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [None]:
import numpy as np


def compute_metrics(eval_pred):
    """Turn per-choice logits into predicted indices and score them with `accuracy`.

    Args:
        eval_pred: a ``(logits, label_ids)`` pair; the highest logit along
            axis 1 is taken as the predicted choice.

    Returns:
        Whatever ``accuracy.compute`` returns for those predictions.
    """
    logits, references = eval_pred
    return accuracy.compute(
        predictions=np.argmax(logits, axis=1),
        references=references,
    )

## Train

In [None]:
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer

# Base encoder plus a freshly initialised multiple-choice classification head.
model = AutoModelForMultipleChoice.from_pretrained(pretrained_model_name_or_path=checkpoint)

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForMultipleChoice were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Evaluate and save once per epoch; keep the best checkpoint at the end.
training_args = TrainingArguments(
    output_dir="my_swag_model",
    num_train_epochs=3,
    learning_rate=5e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # NOTE(review): pushing presumably requires a logged-in Hugging Face Hub session.
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),  # custom multiple-choice collator
    train_dataset=tokenized_swag["train"],
    eval_dataset=tokenized_swag["validation"],
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune from scratch (no checkpoint resume); eval + save run every epoch.
trainer.train(resume_from_checkpoint=None)

## Inference

In [None]:
# A context sentence and two candidate continuations for the model to choose between.
prompt = "France has a bread law, Le Décret Pain, with strict rules on what is allowed in a traditional baguette."
candidate1, candidate2 = (
    "The law does not apply to croissants and brioche.",
    "The law applies to baguettes.",
)

In [None]:
from transformers import AutoTokenizer, AutoModelForMultipleChoice

# Load an already fine-tuned SWAG model from the Hub for inference.
# A dedicated name is used so the base `checkpoint` defined for training is not
# clobbered: re-running the training cells afterwards would otherwise silently
# start from this fine-tuned model instead of the base one.
# To try the model trained above (pushed because push_to_hub=True), replace this
# with your own Hub repo id.
inference_checkpoint = 'stevhliu/my_awesome_swag_model'
# These intentionally rebind `tokenizer` and `model`, which the inference cells use.
tokenizer = AutoTokenizer.from_pretrained(inference_checkpoint)
model = AutoModelForMultipleChoice.from_pretrained(inference_checkpoint)

tokenizer_config.json:   0%|          | 0.00/348 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/670 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

In [None]:
# Encode each (prompt, candidate) pair; padding=True gives both pairs one length.
inputs = tokenizer(
    [[prompt, candidate1], [prompt, candidate2]],
    return_tensors='pt',
    padding=True,
)
# The multiple-choice model takes (batch_size, num_choices, seq_len) inputs, so
# add a leading batch dimension of 1. No gradients are needed for inference, so
# skip building the autograd graph. (Pass labels=... too if you want outputs.loss.)
with torch.no_grad():
    outputs = model(**{k: v.unsqueeze(0) for k, v in inputs.items()})
logits = outputs.logits
logits

tensor([[5.9728, 5.7664]], grad_fn=<ViewBackward0>)

In [None]:
# Position of the highest-scoring candidate (0 -> candidate1, 1 -> candidate2).
predicted_class = int(torch.argmax(logits))
predicted_class

0