# Multi Choice

This notebook covers:

    I. Presentation of the problem
    II. Model
    III. Realization

## I. Presentation

### 1. Definition

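This is a special case of the machine reading comprehension (MRC) problem. Instead of locating the answer span in the text, as in extractive MRC, the model picks the answer from a list of given choices.

The problem can be summarized as:
 - given a context (optional) and a question,
 - and several candidate choices,
 - predict one or several answers.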

Example:
 - context: earthworms can regrow segments that break off
 - question: what do earthworms do when a segment breaks off
 - choices: (a) dies (b) regrows it (c) reproduces (d) sediment (e) root growth (f) migrate (g) stops growing (h) roots
 - answer: b



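### 2. Data processing

Each sample is expanded into one input sequence per choice. Every sequence contains the context, the question and a single choice. The [CLS] representation of each sequence is turned into a score, and the choice with the highest score is the prediction. In the implementation of section III, the question and the choice are concatenated into one segment, so there is no [SEP] between them.
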
    ______________________________________________________________________________
    |CLS|       context       |SEP|      question     |SEP|      choice1     |SEP|
    ------------------------------------------------------------------------------
    |CLS|       context       |SEP|      question     |SEP|      choice2     |SEP|
    ------------------------------------------------------------------------------
    ...                                     ... 
    ...                                     ...

    Then
     ___________________________
     |CLS1| |CLS2| |CLS3| |CLS4|    ->  choice
     ---------------------------
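
The expansion above can be sketched in plain Python. This is a minimal sketch, not the preprocessing used later in the notebook: `build_pairs` and `regroup` are hypothetical helper names, and plain strings stand in for token ids.

```python
def build_pairs(context, question, choices):
    """Return one (first_segment, second_segment) pair per choice."""
    return [(context, f"{question} {choice}") for choice in choices]


def regroup(flat, num_choices):
    """Group a flat list of per-sequence items back into per-sample lists."""
    return [flat[i:i + num_choices] for i in range(0, len(flat), num_choices)]


context = "earthworms can regrow segments that break off"
question = "what do earthworms do when a segment breaks off"
choices = ["dies", "regrows it", "reproduces", "sediment"]

pairs = build_pairs(context, question, choices)
for first, second in pairs:
    print(f"[CLS] {first} [SEP] {second} [SEP]")

# two identical samples flattened into one list, then regrouped per sample
groups = regroup(pairs + pairs, num_choices=len(choices))
print(len(groups), len(groups[0]))
```

All sequences of a sample share the same prefix and differ only in the choice, which is why the model can compare their scores directly.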

## II. Model

The model used is `AutoModelForMultipleChoice` with the BERT base checkpoint, which resolves to `BertForMultipleChoice`.

The model encodes each (context, question + choice) sequence with BERT and maps its pooled [CLS] representation to a single score. The scores of the choices that belong to the same sample are then grouped together, so the output dimension is [batch, num_choices].

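In the class's init function:

```python
    self.bert = BertModel(config)   # the pooling layer is kept: the pooled [CLS] output is used in forward
    ...
    self.classifier = nn.Linear(config.hidden_size, 1)
```
The classifier has a single output unit: it produces one score per sequence, so there is no number of labels to define. The number of choices is inferred from the shape of the input.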

In the forward function:

```python
    num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
    input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None # [batch * num_choice, seq_len]
    ...
    outputs = self.bert(
            input_ids,
            ...
        )
    pooled_output = outputs[1]                      # [batch * num_choice, hidden_size]

    pooled_output = self.dropout(pooled_output)
    logits = self.classifier(pooled_output)         # [batch * num_choice, 1]
    reshaped_logits = logits.view(-1, num_choices)  # [batch,  num_choice]
```

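The flattened input is encoded by the BERT model, one sequence per choice. The pooled output of each sequence goes through dropout and a linear layer that projects it to a single score, and the scores are reshaped so that each row holds the scores of all the choices of one sample.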
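
Written out, and assuming the usual cross-entropy loss over the choice scores (the symbols below are notation introduced here, not names from the library):

$$
s_i = \mathbf{w}^\top \mathbf{h}_i + b, \qquad
p_i = \frac{\exp(s_i)}{\sum_{j=1}^{C} \exp(s_j)}, \qquad
\mathcal{L} = -\log p_y
$$

Here $\mathbf{h}_i$ is the pooled output of the sequence built from choice $i$, $\mathbf{w}$ and $b$ are the weight and bias of the linear classifier, $C$ is the number of choices, and $y$ is the index of the correct choice. The predicted choice is $\arg\max_i s_i$.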

## III. Realization

In [1]:
# Select the GPU(s) to use.
# I have 2 GPUs and only want to use one, so I restrict the visible devices.
# This cell must be run first, before any CUDA library is imported.
# Skip it if you don't need it.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # or "0,1" for multiple GPUs
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
# define the hub repositories for the data and the model

# data

ckp_data = "layoric/labeled-multiple-choice-explained"

# model

ckp = "google-bert/bert-base-uncased"

### 1. import

In [3]:
import evaluate
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForMultipleChoice, TrainingArguments, Trainer



### 2. load data

In [4]:
data = load_dataset(ckp_data)
data

DatasetDict({
    train: Dataset({
        features: ['formatted_question', 'combinedfact', 'answerKey', 'topic', '__index_level_0__', 'explanation'],
        num_rows: 9098
    })
})

In [5]:
data["train"][4]

{'formatted_question': 'lightning can be bad for what? (a) the environment (b) rainstorms (c) destruction (d) visibility (e) thunder (f) the sun (g) the weather. (h) transportation',
 'combinedfact': 'lightning can be bad for the environment.',
 'answerKey': 'a',
 'topic': 'electricity',
 '__index_level_0__': 34080,
 'explanation': 'b) Rainstorms: Lightning is actually a natural phenomenon that occurs during rainstorms. It is not bad for rainstorms, but rather a part of the storm.\n\nc) Destruction: While lightning can cause destruction, the question is asking what lightning can be bad for, not what it can cause.\n\nd) Visibility: Lightning can actually improve visibility during a storm by illuminating the surroundings. It is not bad for visibility.\n\ne) Thunder: Thunder is actually caused by lightning, so it is not bad for thunder.\n\nf) The sun: Lightning is not related to the sun, so it is not bad for the sun.\n\ng) The weather: Lightning is a part of the weather, so it is not bad 

### 3. Split data

In [6]:
split_data = data["train"].train_test_split(test_size=0.2)
split_data

DatasetDict({
    train: Dataset({
        features: ['formatted_question', 'combinedfact', 'answerKey', 'topic', '__index_level_0__', 'explanation'],
        num_rows: 7278
    })
    test: Dataset({
        features: ['formatted_question', 'combinedfact', 'answerKey', 'topic', '__index_level_0__', 'explanation'],
        num_rows: 1820
    })
})

### 4. tokenization

In [7]:
# load tokenizer

tokenizer = AutoTokenizer.from_pretrained(ckp)
tokenizer

BertTokenizerFast(name_or_path='google-bert/bert-base-uncased', vocab_size=30522, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [8]:
label2id = {'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7}

In [9]:
# processing data
# this should be adapted to the dataset

import re

NUM_CHOICES = len(label2id)  # every sample is padded to 8 choices

def process(samples):

    contexts = []
    questions_choices = []
    answers = []

    for ind in range(len(samples["formatted_question"])):

        context = samples["combinedfact"][ind]

        # all choices are prefixed with (X), where X is in a-h,
        # so we get the question and the choices by splitting on (X)
        parts = re.split(r" \([a-h]\) ", samples["formatted_question"][ind])

        # the question precedes the choices
        question = parts[0]

        # pad with a dummy choice so that every sample has exactly 8 choices
        choices = parts[1:NUM_CHOICES + 1]
        choices += ["unKnown"] * (NUM_CHOICES - len(choices))

        # combine the input as in the presentation
        # - context + question + choice_X
        for c in choices:
            contexts.append(context)
            questions_choices.append(question + " " + c)

        # fail loudly on an unexpected answer key instead of storing None
        answers.append(label2id[samples["answerKey"][ind]])

    # tokenization
    # (batch * 8, seq_len)
    toks = tokenizer(contexts, questions_choices, truncation="only_first", max_length=128, padding="max_length")

    # regroup the sequences by sample
    # (batch, 8, seq_len)
    toks = {k: [v[i:i + NUM_CHOICES] for i in range(0, len(v), NUM_CHOICES)] for k, v in toks.items()}
    toks["labels"] = answers

    return toks
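
The splitting and padding step can be tried on its own, without a tokenizer. The following is a minimal sketch: `split_question` is a hypothetical helper written for illustration, the first string is the sample shown in section 2, and the second question is made up.

```python
import re

NUM_CHOICES = 8  # choices (a) to (h)


def split_question(formatted_question, num_choices=NUM_CHOICES, pad="unKnown"):
    """Split 'question (a) c1 (b) c2 ...' into the question and a padded list of choices."""
    parts = re.split(r" \([a-h]\) ", formatted_question)
    question, choices = parts[0], parts[1:num_choices + 1]
    choices = choices + [pad] * (num_choices - len(choices))
    return question, choices


sample = ("lightning can be bad for what? (a) the environment (b) rainstorms "
          "(c) destruction (d) visibility (e) thunder (f) the sun "
          "(g) the weather. (h) transportation")
question, choices = split_question(sample)
print(question)
print(choices)

# a made-up question with fewer than 8 choices is padded with the dummy choice
_, short = split_question("what is hot? (a) fire (b) ice")
print(short)
```

Padding every sample to the same number of choices keeps the tensor shape fixed at (batch, 8, seq_len), at the cost of scoring a few dummy sequences.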

In [10]:
tokenized_data = split_data.map(process, batched=True)
tokenized_data

DatasetDict({
    train: Dataset({
        features: ['formatted_question', 'combinedfact', 'answerKey', 'topic', '__index_level_0__', 'explanation', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 7278
    })
    test: Dataset({
        features: ['formatted_question', 'combinedfact', 'answerKey', 'topic', '__index_level_0__', 'explanation', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1820
    })
})

In [11]:
# the 8 sequences of a sample share the context and the question;
# they only differ in the choice at the end

print(tokenized_data["train"][0]["input_ids"])

[[101, 5492, 3596, 17530, 1997, 2300, 1012, 102, 2054, 3596, 17530, 1997, 2300, 1029, 13016, 2015, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 5492, 3596, 17530, 1997, 2300, 1012, 102, 2054, 3596, 17530, 1997, 2300, 1029, 5492, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 5492, 3596, 17530, 1997, 2300, 1012, 102, 2054, 3596, 17530, 1997, 2300, 1029, 3221, 17530, 1012, 102, 0, 0, 0, 0, 0,

In [12]:
# show the shape of the input

import numpy as np

print(np.array(tokenized_data["train"]["input_ids"]).shape)

(7278, 8, 128)


### 5. load model

In [13]:
model = AutoModelForMultipleChoice.from_pretrained(ckp)

Some weights of BertForMultipleChoice were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
model

BertForMultipleChoice(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, ele

### 6. define metric

In [15]:
# accuracy: exact match between the predicted choice index and the label

import numpy as np

acc = evaluate.load("accuracy")

def metric(pred):

    preds, refs = pred

    preds = preds.argmax(axis=-1)

    return acc.compute(predictions=preds, references=refs)
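
What the metric computes can be reproduced without any library. The following is a minimal sketch in plain Python with made-up scores; `argmax` and `accuracy` are helper names chosen here, not part of `evaluate`.

```python
def argmax(row):
    """Index of the largest value in a list of scores."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy(logits, labels):
    """Fraction of samples whose highest-scoring choice equals the label."""
    preds = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return correct / len(labels)


# made-up scores for 3 samples with 4 choices each
logits = [
    [0.1, 2.3, -0.5, 0.0],   # predicts choice 1
    [1.7, 0.2, 0.3, 0.1],    # predicts choice 0
    [0.0, 0.1, 0.2, 0.9],    # predicts choice 3
]
labels = [1, 0, 2]
print(accuracy(logits, labels))
```

Two of the three predictions match their label, so the sketch reports an accuracy of two thirds.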

### 7. train args

In [16]:
args = TrainingArguments(
        output_dir="../tmp/checkpoints",
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=3,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_steps=10,
        load_best_model_at_end=True,
)
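
With `load_best_model_at_end=True` and no `metric_for_best_model`, the best checkpoint is selected on the evaluation loss. If selecting on accuracy is preferred, the two arguments below can be added. This is an optional sketch, not what the run in this notebook used; `best_model_kwargs` is a hypothetical name.

```python
# Optional keyword arguments, to be merged into the TrainingArguments call above
# as TrainingArguments(..., **best_model_kwargs).
best_model_kwargs = {
    "metric_for_best_model": "accuracy",  # key returned by the metric function
    "greater_is_better": True,            # a higher accuracy is better
}
print(best_model_kwargs)
```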

### 8. trainer

In [17]:
trainer = Trainer(
    model = model,
    args=args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["test"],
    compute_metrics=metric
)

### 9. train + eval

In [18]:
trainer.evaluate(eval_dataset=tokenized_data["test"])

{'eval_loss': 2.0795347690582275,
 'eval_accuracy': 0.14725274725274726,
 'eval_runtime': 26.5249,
 'eval_samples_per_second': 68.615,
 'eval_steps_per_second': 2.149}
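
These numbers are what an untrained classifier should give. As a quick sanity check, assuming the loss is the cross-entropy over 8 choices, a model that spreads its probability uniformly has a loss of ln(8) and an accuracy of 1/8:

```python
import math

num_choices = 8
chance_loss = math.log(num_choices)   # cross-entropy of a uniform prediction
chance_accuracy = 1 / num_choices

print(f"chance loss:     {chance_loss:.4f}")
print(f"chance accuracy: {chance_accuracy:.3f}")
```

The measured loss of 2.0795 is almost exactly ln(8) ≈ 2.0794, and the accuracy of 0.147 is close to the chance level of 0.125.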

In [19]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1     | 0.0632        | 0.059893        | 0.984066 |
| 2     | 0.0258        | 0.056948        | 0.987363 |
| 3     | 0.0248        | 0.062097        | 0.987363 |


TrainOutput(global_step=684, training_loss=0.07388382962883085, metrics={'train_runtime': 811.4543, 'train_samples_per_second': 26.907, 'train_steps_per_second': 0.843, 'total_flos': 1.1489430405574656e+16, 'train_loss': 0.07388382962883085, 'epoch': 3.0})
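
The `global_step` value can be checked against the settings above. This is a small arithmetic sketch, assuming a single GPU, no gradient accumulation, and that the last partial batch of each epoch is kept:

```python
import math

train_samples = 7278
batch_size = 32      # per_device_train_batch_size, on a single GPU
epochs = 3

steps_per_epoch = math.ceil(train_samples / batch_size)
total_steps = steps_per_epoch * epochs
print(steps_per_epoch, total_steps)
```

This gives 228 steps per epoch and 684 steps in total, which matches the reported `global_step`.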

### 10. inference

In [23]:
# Hugging Face provides no predefined pipeline for the multiple-choice task,
# so we write a minimal one

import torch

class MultipleChoicePipeline:

    def __init__(self, model, tokenizer):

        self.model = model.eval()   # disable dropout for inference
        self.tokenizer = tokenizer

        self.device = model.device

    def preprocess(self, context, question, choices):

        # one (context, question + choice) pair per choice
        ctx, qs = [], []
        for choice in choices:
            ctx.append(context)
            qs.append(question + " " + choice)
        return self.tokenizer(ctx, qs, truncation="only_first", max_length=128, padding=True, return_tensors="pt")

    def predict(self, inputs):

        # add the batch dimension: (num_choices, seq_len) -> (1, num_choices, seq_len)
        inputs = {k: v.unsqueeze(0).to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            return self.model(**inputs).logits

    def postprocess(self, logits, choices):

        pred = torch.argmax(logits, dim=-1).cpu().item()

        return choices[pred]

    def __call__(self, context, question, choices):

        inputs = self.preprocess(context, question, choices)
        logits = self.predict(inputs)
        result = self.postprocess(logits, choices)
        return result

In [24]:
pipe = MultipleChoicePipeline(model, tokenizer)

In [26]:
pipe("lightning can be bad for the environment.", "lightning can be bad for what?", ["environment", "rainstorms", "destruction", "visibility"])

'environment'
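
The pipeline only returns the best choice. To also see how confident the model is, the logits can be turned into probabilities with a softmax. The following is a minimal sketch in plain Python; the scores are made up for illustration and are not the output of the trained model.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


choices = ["environment", "rainstorms", "destruction", "visibility"]
logits = [4.2, -1.0, 0.3, -0.7]   # made-up scores, not the model's output

probs = softmax(logits)
for choice, p in sorted(zip(choices, probs), key=lambda cp: cp[1], reverse=True):
    print(f"{choice:12s} {p:.3f}")
```

In the pipeline, the same idea would apply to the logits returned by `predict`, for example with `torch.softmax(logits, dim=-1)`.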