# BertForMultipleChoice : Fine-tuning on SWAG with BERT

In [1]:
# !pip install transformers datasets evaluate

In [2]:
import torch
import evaluate
import numpy as np
from datasets import load_dataset, Value
from DataCollatorForMultipleChoice import DataCollatorForMultipleChoice
from transformers import AutoTokenizer, AutoModelForMultipleChoice, TrainingArguments, Trainer

Hyperparameters from the paper: "**4.4 SWAG** We fine-tune the model for 3 epochs with a learning rate of 2e-5 and a batch size of 16."

In [3]:
lr = 2e-5
epochs = 3
batch_size = 16
weight_decay = 0.01
model_name = "bert-base-uncased"

## Dataset
### Swag: A Large-Scale Adversarial Dataset for Grounded Commonsense Inference
**Abstract** : *Given a partial description like “she opened the hood of the car,” humans can reason about the situation and anticipate what might come next (“then, she examined the engine”). In this paper, we introduce the task of grounded commonsense inference, unifying natural language inference and commonsense reasoning. We present Swag, a new dataset with 113k multiple choice questions about a rich spectrum of grounded situations. To address the recurring challenges of the annotation artifacts and human biases found in many existing datasets, we propose Adversarial Filtering (AF), a novel procedure that constructs a de-biased dataset by iteratively training an ensemble of stylistic classifiers, and using them to filter the data. To account for the aggressive adversarial filtering, we use state-of-the-art language models to massively oversample a diverse set of potential counterfactuals. Empirical results demonstrate that while humans can solve the resulting inference problems with high accuracy (88%), various competitive models struggle on our task. We provide comprehensive analysis that indicates significant opportunities for future research.*

An example of the dataset looks like this:
```
{
    'video-id': 'anetv_jkn6uvmqwh4',
    'fold-ind': '3417',
    'startphrase': 'A drum line passes by walking down the street playing their instruments. Members of the procession',
    'sent1': 'A drum line passes by walking down the street playing their instruments.',
    'sent2': 'Members of the procession',
    'gold-source': 'gen',
    'ending0': 'are playing ping pong and celebrating one left each in quick.',
    'ending1': 'wait slowly towards the cadets.',
    'ending2': 'makes a square call and ends by jumping down into snowy streets where fans begin to take their positions.',
    'ending3': 'play and go back and forth hitting the drums while the audience claps for them.',
    'label': 3
}
```
where,
- video-id = identifier of the source video
- fold-ind = identifier of the example
- startphrase = the full context to be completed (`sent1` followed by `sent2`)
- **sent1** = the first sentence
- **sent2** = the start of the second sentence (to be completed)
- gold-source = whether the correct ending was generated (`gen`) or comes from the found completion (`gold`)
- **ending0** = first proposition
- **ending1** = second proposition
- **ending2** = third proposition
- **ending3** = fourth proposition
- **label** = index (0 to 3) of the correct proposition

In [4]:
# Load the dataset
swag = load_dataset("swag", "regular")
swag["train"].to_pandas().head()

|   | video-id | fold-ind | startphrase | sent1 | sent2 | gold-source | ending0 | ending1 | ending2 | ending3 | label |
|---|----------|----------|-------------|-------|-------|-------------|---------|---------|---------|---------|-------|
| 0 | anetv_jkn6uvmqwh4 | 3416 | Members of the procession walk down the street... | Members of the procession walk down the street... | A drum line | gold | passes by walking down the street playing thei... | has heard approaching them. | arrives and they're outside dancing and asleep. | turns the lead singer watches the performance. | 0 |
| 1 | anetv_jkn6uvmqwh4 | 3417 | A drum line passes by walking down the street ... | A drum line passes by walking down the street ... | Members of the procession | gen | are playing ping pong and celebrating one left... | wait slowly towards the cadets. | continues to play as well along the crowd alon... | continue to play marching, interspersed. | 3 |
| 2 | anetv_jkn6uvmqwh4 | 3415 | A group of members in green uniforms walks wav... | A group of members in green uniforms walks wav... | Members of the procession | gold | pay the other coaches to cheer as people this ... | walk down the street holding small horn brass ... | is seen in the background. | are talking a couple of people playing a game ... | 1 |
| 3 | anetv_jkn6uvmqwh4 | 3417 | A drum line passes by walking down the street ... | A drum line passes by walking down the street ... | Members of the procession | gen | are playing ping pong and celebrating one left... | wait slowly towards the cadets. | makes a square call and ends by jumping down i... | play and go back and forth hitting the drums w... | 3 |
| 4 | anetv_Bri_myFFu4A | 2408 | The person plays a song on the violin. The man | The person plays a song on the violin. | The man | gold | finishes the song and lowers the instrument. | hits the saxophone and demonstrates how to pro... | finishes massage the instrument again and cont... | continues dancing while the man gore the music... | 0 |


In [5]:
# An example of how a human would view this task
example = swag["train"][0]
print(f"{example['sent1']}\n\
    A - {example['sent2']} {example['ending0']}\n\
    B - {example['sent2']} {example['ending1']}\n\
    C - {example['sent2']} {example['ending2']}\n\
    D - {example['sent2']} {example['ending3']}\n\
Ground truth: option {['A', 'B', 'C', 'D'][example['label']]}")

Members of the procession walk down the street holding small horn brass instruments.
    A - A drum line passes by walking down the street playing their instruments.
    B - A drum line has heard approaching them.
    C - A drum line arrives and they're outside dancing and asleep.
    D - A drum line turns the lead singer watches the performance.
Ground truth: option A


## Pre-processing
### Tokenizer


We tokenize the texts with the tokenizer that matches the `bert-base-uncased` checkpoint. The tokenizer returns three keys: `input_ids`, `token_type_ids` and `attention_mask`.
- input_ids = the vocabulary indices of the tokens that make up the sentences
- token_type_ids = segment ids, indicating which sentence each token belongs to
- attention_mask = which tokens should be attended to (1) and which should not (0), e.g. padding
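
To make the three keys concrete, here is a minimal sketch of how a sentence pair is laid out. It is only an illustration: `encode_pair` is a hypothetical helper, whitespace splitting stands in for WordPiece, and tokens are kept as strings instead of vocabulary ids.

```python
def encode_pair(sent_a, sent_b, max_len):
    """Toy stand-in for a BERT tokenizer call on a sentence pair.

    Assumptions: whitespace splitting replaces WordPiece, tokens stay as
    strings, and inputs longer than max_len are not truncated.
    """
    tokens_a = sent_a.lower().split()
    tokens_b = sent_b.lower().split()
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment 0 covers [CLS], sentence A and the first [SEP]; segment 1 covers the rest
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    # Real tokens get 1; padding positions get 0
    attention_mask = [1] * len(tokens)
    pad = max(0, max_len - len(tokens))
    return {
        "tokens": tokens + ["[PAD]"] * pad,
        "token_type_ids": token_type_ids + [0] * pad,
        "attention_mask": attention_mask + [0] * pad,
    }


toy = encode_pair(
    "The person plays a song on the violin.",
    "The man finishes the song.",
    max_len=18,
)
for key, value in toy.items():
    print(key, value)
```

The real tokenizer returns integer `input_ids` in place of the `tokens` list shown here.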

From the paper, "**4.4 SWAG** *When fine-tuning on the SWAG dataset, we construct four input sequences, each containing the concatenation of the given sentence (sentence `A`) and a possible continuation (sentence `B`). The only task-specific parameters introduced is a vector whose dot product with the `[CLS]` token representation `C` denotes a score for each choice which is normalized with a softmax layer*".
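
The scoring rule in the quoted passage can be sketched in a few lines of plain Python. This is an illustration under assumptions, not the library's implementation: `choice_probabilities` is a hypothetical helper, the three-dimensional vectors are made-up stand-ins for the H-dimensional `[CLS]` representations, and the optional `bias` is an extra that the quoted passage does not mention.

```python
import math


def choice_probabilities(cls_vectors, weight, bias=0.0):
    """Score each choice as weight . C (+ bias) and normalize with a softmax."""
    scores = [sum(w * c for w, c in zip(weight, vec)) + bias for vec in cls_vectors]
    # Subtract the maximum score for numerical stability
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up [CLS] representations, one per candidate ending (hypothetical values)
cls_vectors = [
    [0.2, -0.1, 0.4],
    [0.0, 0.3, -0.2],
    [1.0, 0.5, 0.1],
    [-0.3, 0.2, 0.0],
]
weight = [0.5, -1.0, 2.0]  # the single task-specific vector (hypothetical values)

probs = choice_probabilities(cls_vectors, weight)
predicted = max(range(len(probs)), key=probs.__getitem__)
print(probs, predicted)
```

Because the same vector scores all four sequences, the head adds very few parameters on top of the pre-trained encoder.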

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
endings = ["ending0", "ending1", "ending2", "ending3"]

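def preprocess_swag(data):
    # Repeat each context sentence four times, once per candidate ending
    sentA = sum([[A] * 4 for A in data["sent1"]], [])
    # Build the four continuations: start of the second sentence + each ending
    sentB = sum([[f"{B} {data[ending][i]}" for ending in endings] for i, B in enumerate(data["sent2"])], [])

    # Tokenize the flattened (sentA, sentB) pairs; no padding is applied here
    tokenized_sents = tokenizer(sentA, sentB, truncation=True)
    # Regroup the flat lists into chunks of four, one chunk per example
    return {k: [v[i:i+4] for i in range(0, len(v), 4)] for k, v in tokenized_sents.items()}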

encoded_swag = swag.map(preprocess_swag, batched=True)
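
The flatten-then-regroup pattern used in the pre-processing can be checked without a tokenizer. The sketch below is only an illustration: `build_pairs` and `regroup` are hypothetical helpers, and the batch reuses two SWAG examples shown earlier in this notebook.

```python
from itertools import chain

endings = ["ending0", "ending1", "ending2", "ending3"]

# A batch in the column-oriented layout that a batched map passes to the function
batch = {
    "sent1": [
        "Members of the procession walk down the street holding small horn brass instruments.",
        "A drum line passes by walking down the street playing their instruments.",
    ],
    "sent2": ["A drum line", "Members of the procession"],
    "ending0": [
        "passes by walking down the street playing their instruments.",
        "are playing ping pong and celebrating one left each in quick.",
    ],
    "ending1": ["has heard approaching them.", "wait slowly towards the cadets."],
    "ending2": [
        "arrives and they're outside dancing and asleep.",
        "makes a square call and ends by jumping down into snowy streets where fans begin to take their positions.",
    ],
    "ending3": [
        "turns the lead singer watches the performance.",
        "play and go back and forth hitting the drums while the audience claps for them.",
    ],
}


def build_pairs(batch):
    # Each context is repeated once per ending: 2 examples -> 8 first sentences
    first = list(chain.from_iterable([s] * 4 for s in batch["sent1"]))
    # Each second sentence is the shared start followed by one candidate ending
    second = [
        f"{start} {batch[name][i]}"
        for i, start in enumerate(batch["sent2"])
        for name in endings
    ]
    return first, second


def regroup(flat, size=4):
    # Undo the flattening: consecutive chunks of four belong to one example
    return [flat[i:i + size] for i in range(0, len(flat), size)]


first, second = build_pairs(batch)
grouped = regroup(second)
print(len(first), len(second), len(grouped))
```

Keeping the pairs flat lets a single tokenizer call handle the whole batch, and the regrouping restores one list of four sequences per example.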

In [7]:
# Sanity check
# [tokenizer.decode(encoded_swag["train"]["input_ids"][0][i]) for i in range(4)]

In [8]:
# Sanity check
# accepted_keys = ["input_ids", "attention_mask", "label"]
# features = [{k: v for k, v in encoded_swag["train"][i].items() if k in accepted_keys} for i in range(10)]
# batch = DataCollatorForMultipleChoice(tokenizer)(features)
# [tokenizer.decode(batch["input_ids"][0][i].tolist()) for i in range(4)]
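
The `DataCollatorForMultipleChoice` module imported at the top is a local file whose contents are not shown here. As a rough sketch of what such a collator typically does, the hypothetical `collate_multiple_choice` below pads every choice to a common length and keeps the (batch, choices, length) nesting. It uses plain lists instead of tensors, toy token ids, and assumes a padding id of 0.

```python
def collate_multiple_choice(features, pad_id=0):
    """Pad multiple-choice features to a common length.

    Each feature holds `input_ids` and `attention_mask` as one token list per
    choice, plus an integer `label`. The result is nested as
    (batch, num_choices, max_len). Hypothetical stand-in, not the local module.
    """
    max_len = max(len(ids) for f in features for ids in f["input_ids"])

    def pad(seq, value):
        return seq + [value] * (max_len - len(seq))

    return {
        "input_ids": [[pad(ids, pad_id) for ids in f["input_ids"]] for f in features],
        "attention_mask": [[pad(m, 0) for m in f["attention_mask"]] for f in features],
        "labels": [f["label"] for f in features],
    }


# Toy features with two choices each (hypothetical token ids)
features = [
    {"input_ids": [[101, 7, 8, 102], [101, 7, 102]],
     "attention_mask": [[1, 1, 1, 1], [1, 1, 1]],
     "label": 0},
    {"input_ids": [[101, 9, 102], [101, 9, 10, 11, 102]],
     "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1]],
     "label": 1},
]

batch = collate_multiple_choice(features)
print(batch["input_ids"])
```

Padding per batch rather than to a global maximum keeps the sequences as short as the longest choice in that batch.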

## Model
### BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
**Abstract**: *We introduce a new language representation model called BERT, which stands for Bidirectional Encoder Representations from Transformers. Unlike recent language representation models (Peters et al., 2018a; Radford et al., 2018), BERT is designed to pretrain deep bidirectional representations from unlabeled text by jointly conditioning on both left and right context in all layers. As a result, the pre-trained BERT model can be finetuned with just one additional output layer to create state-of-the-art models for a wide range of tasks, such as question answering and language inference, without substantial task-specific architecture modifications. BERT is conceptually simple and empirically powerful. It obtains new state-of-the-art results on eleven natural language processing tasks, including pushing the GLUE score to 80.5% (7.7% point absolute improvement), MultiNLI accuracy to 86.7% (4.6% absolute improvement), SQuAD v1.1 question answering Test F1 to 93.2 (1.5 point absolute improvement) and SQuAD v2.0 Test F1 to 83.1 (5.1 point absolute improvement).*

We evaluate the pre-trained **BERT<sub>BASE</sub>** (L=12, H=768, A=12, Total Parameters=110M), where:
- L = number of layers (i.e., Transformer blocks)
- H = hidden size
- A = number of self-attention heads
- Feed-forward/filter size = 4H, i.e., 3072 for H = 768
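
As a sanity check on the quoted total, the parameter count can be reproduced from L and H with simple arithmetic. The sketch below rests on values that are not stated in this notebook and are assumed for `bert-base-uncased`: a vocabulary of 30,522 tokens, 512 positions and 2 segment types. `bert_encoder_parameters` is a hypothetical helper.

```python
def bert_encoder_parameters(L=12, H=768, vocab_size=30522, max_positions=512, type_vocab=2):
    """Count weights and biases of a BERT encoder with a pooler.

    vocab_size, max_positions and type_vocab are assumptions about the
    bert-base-uncased configuration, not values given in this notebook.
    """
    layer_norm = 2 * H  # scale and shift
    # Token, position and segment embeddings, followed by a LayerNorm
    embeddings = (vocab_size + max_positions + type_vocab) * H + layer_norm
    # Query, key, value and output projections, followed by a LayerNorm
    attention = 4 * (H * H + H) + layer_norm
    # Two linear layers (H -> 4H -> H), followed by a LayerNorm
    feed_forward = (H * 4 * H + 4 * H) + (4 * H * H + H) + layer_norm
    # Dense layer applied to the [CLS] representation
    pooler = H * H + H
    return embeddings + L * (attention + feed_forward) + pooler


encoder = bert_encoder_parameters()
head = 768 + 1  # multiple-choice head: one weight vector plus one bias
print(f"encoder: {encoder / 1e6:.1f}M parameters, head: {head} parameters")
```

The number of attention heads A does not appear in the count, because the heads only split the H-dimensional projections into A slices.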


In [9]:
model = AutoModelForMultipleChoice.from_pretrained(model_name)

Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


The warning tells us that the `classifier.weight` and `classifier.bias` parameters are not in the `bert-base-uncased` checkpoint and are therefore randomly initialized. This is expected: the checkpoint was saved with the heads used for pre-training (masked language modeling and next sentence prediction), whereas the multiple-choice head is new and has no pretrained weights. The library therefore warns us that we should fine-tune this model before using it for inference, which is exactly what we are going to do.

In [10]:
# Include a metric for training
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

From here, we follow four steps:
1. Define training hyperparameters in `TrainingArguments`. At the end of each epoch, the `Trainer` will evaluate the accuracy and save the training checkpoint
2. Pass training arguments to `Trainer` along with `model`, `dataset`, `tokenizer`, `data collator`, and `compute_metrics`
3. Call `train()` to fine-tune the model
4. Evaluate the model using `evaluate()`/`predict()`

In [11]:
# Uncomment this to re-run the fine-tuning process
# training_args = TrainingArguments(
#     f"{model_name}-finetuned-swag",
#     evaluation_strategy="epoch",
#     save_strategy="epoch",
#     load_best_model_at_end=True,
#     learning_rate=lr,
#     per_device_train_batch_size=batch_size,
#     per_device_eval_batch_size=batch_size,
#     num_train_epochs=epochs,
#     weight_decay=weight_decay,
# )
# trainer = Trainer(
#     model=model,
#     args=training_args,
#     train_dataset=encoded_swag["train"],
#     eval_dataset=encoded_swag["validation"],
#     tokenizer=tokenizer,
#     data_collator=DataCollatorForMultipleChoice(tokenizer),
#     compute_metrics=compute_metrics,
# )
# trainer.train()

In [12]:
# Comment this if re-running the fine-tuning process
model_name = "chakraborty-de/bert-base-uncased-finetuned-swag"
model = AutoModelForMultipleChoice.from_pretrained(model_name)

In [13]:
# Comment this if re-running the fine-tuning process
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer),
    compute_metrics=compute_metrics,
)
model.eval()
trainer.evaluate(encoded_swag["validation"])

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: chakraborty-de. Use `wandb login --relogin` to force relogin


{'eval_loss': 1.032250165939331,
 'eval_accuracy': 0.790962711186644,
 'eval_runtime': 76.695,
 'eval_samples_per_second': 260.851,
 'eval_steps_per_second': 32.61}

## Problems with the paper and new insights
### HellaSwag: Can a Machine Really Finish Your Sentence?
**Abstract** : *Recent work by Zellers et al. (2018) introduced a new task of commonsense natural language inference: given an event description such as "A woman sits at a piano," a machine must select the most likely followup: "She sets her fingers on the keys." With the introduction of BERT, near human-level performance was reached. Does this mean that machines can perform human level commonsense inference?
In this paper, we show that commonsense inference still proves difficult for even state-of-the-art models, by presenting HellaSwag, a new challenge dataset. Though its questions are trivial for humans (>95% accuracy), state-of-the-art models struggle (<48%). We achieve this via Adversarial Filtering (AF), a data collection paradigm wherein a series of discriminators iteratively select an adversarial set of machine-generated wrong answers. AF proves to be surprisingly robust. The key insight is to scale up the length and complexity of the dataset examples towards a critical 'Goldilocks' zone wherein generated text is ridiculous to humans, yet often misclassified by state-of-the-art models.
Our construction of HellaSwag, and its resulting difficulty, sheds light on the inner workings of deep pretrained models. More broadly, it suggests a new path forward for NLP research, in which benchmarks co-evolve with the evolving state-of-the-art in an adversarial way, so as to present ever-harder challenges.*

An example of the dataset looks like this:
```
{
    'ind': 4,
    'activity_label': 'Removing ice from car',
    'ctx_a': 'Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles.',
    'ctx_b': 'then',
    'ctx': 'Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles. then',
    'endings': [
        ', the man adds wax to the windshield and cuts it.',
        ', a person board a ski lift, while two men supporting the head of the person wearing winter clothes snow as the we girls sled.',
        ', the man puts on a christmas coat, knitted with netting.',
        ', the man continues removing the snow on his car.'
    ],
    'source_id': 'activitynet~v_-1IBHYS3L-Y',
    'split': 'train',
    'split_type': 'indomain',
    'label': 3
}
```
where,
- ind = identifier of the example
- source_id = identifier of the source
- activity_label = additional activity context
- ctx = the full context to be completed (`ctx_a` followed by `ctx_b`)
- **ctx_a** = the first sentence
- **ctx_b** = the start of the second sentence (to be completed)
- **endings** = list of four propositions
- split = train/validation/test
- split_type = indomain/zeroshot
- **label** = index of the correct proposition in `endings`

In [14]:
# Load the dataset
hella_swag_trn = load_dataset("Rowan/hellaswag", split="train")
hella_swag_val = load_dataset("Rowan/hellaswag", split="validation")
hella_swag_trn = hella_swag_trn.cast_column("label", Value(dtype='int32', id=None))
hella_swag_val = hella_swag_val.cast_column("label", Value(dtype='int32', id=None))
hella_swag_trn.to_pandas().head()

|   | ind | activity_label | ctx_a | ctx_b | ctx | endings | source_id | split | split_type | label |
|---|-----|----------------|-------|-------|-----|---------|-----------|-------|------------|-------|
| 0 | 4 | Removing ice from car | Then, the man writes over the snow covering th... | then | Then, the man writes over the snow covering th... | [, the man adds wax to the windshield and cuts... | activitynet~v_-1IBHYS3L-Y | train | indomain | 3 |
| 1 | 8 | Baking cookies | A female chef in white uniform shows a stack o... | the pans | A female chef in white uniform shows a stack o... | [contain egg yolks and baking soda., are then ... | activitynet~v_-2dxp-mv2zo | train | indomain | 3 |
| 2 | 9 | Baking cookies | A female chef in white uniform shows a stack o... | a knife | A female chef in white uniform shows a stack o... | [is seen moving on a board and cutting out its... | activitynet~v_-2dxp-mv2zo | train | indomain | 3 |
| 3 | 12 | Baking cookies | A tray of potatoes is loaded into the oven and... | a large tray of meat | A tray of potatoes is loaded into the oven and... | [is placed onto a baked potato., , ls, and pic... | activitynet~v_-2dxp-mv2zo | train | indomain | 3 |
| 4 | 27 | Getting a haircut | The man in the center is demonstrating a hairs... | the man in the blue shirt | The man in the center is demonstrating a hairs... | [is standing on the sponge cutting the hair of... | activitynet~v_-JqLjPz-07E | train | indomain | 2 |


In [15]:
# An example of how a human would view this task
example = hella_swag_trn[0]
print(f"{example['ctx_a']}\n\
    A - {example['ctx_b']} {example['endings'][0]}\n\
    B - {example['ctx_b']} {example['endings'][1]}\n\
    C - {example['ctx_b']} {example['endings'][2]}\n\
    D - {example['ctx_b']} {example['endings'][3]}\n\
Ground truth: option {['A', 'B', 'C', 'D'][example['label']]}")

Then, the man writes over the snow covering the window of a car, and a woman wearing winter clothes smiles.
    A - then , the man adds wax to the windshield and cuts it.
    B - then , a person board a ski lift, while two men supporting the head of the person wearing winter clothes snow as the we girls sled.
    C - then , the man puts on a christmas coat, knitted with netting.
    D - then , the man continues removing the snow on his car.
Ground truth: option D


In [16]:
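def preprocess_hella_swag(data):
    # Repeat each context four times, once per candidate ending
    sentA = sum([[A] * 4 for A in data["ctx_a"]], [])
    # Note: ctx_b and the ending are joined without a space, so "the pans" + "contain ..."
    # becomes "the panscontain ..."; f"{B} {ending}" would keep the words separate
    sentB = sum([[f"{B}{ending}" for ending in data["endings"][i]] for i, B in enumerate(data["ctx_b"])], [])

    # Tokenize the flattened pairs, then regroup them into chunks of four per example
    tokenized_sents = tokenizer(sentA, sentB, truncation=True)
    return {k: [v[i:i+4] for i in range(0, len(v), 4)] for k, v in tokenized_sents.items()}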

encoded_hella_swag_trn = hella_swag_trn.map(preprocess_hella_swag, batched=True)
encoded_hella_swag_val = hella_swag_val.map(preprocess_hella_swag, batched=True)

In [17]:
# Sanity check
# [tokenizer.decode(encoded_hella_swag_trn["input_ids"][0][i]) for i in range(4)]

In [18]:
# Sanity check
# accepted_keys = ["input_ids", "attention_mask", "label"]
# features = [{k: v for k, v in encoded_hella_swag_trn[i].items() if k in accepted_keys} for i in range(10)]
# batch = DataCollatorForMultipleChoice(tokenizer)(features)
# [tokenizer.decode(batch["input_ids"][0][i].tolist()) for i in range(4)]

### Results on the full validation

In [19]:
trainer.evaluate(encoded_hella_swag_val)

{'eval_loss': 2.54646372795105,
 'eval_accuracy': 0.32613025293766185,
 'eval_runtime': 88.1225,
 'eval_samples_per_second': 113.955,
 'eval_steps_per_second': 14.253}

### Results on the in-domain (with SWAG) validation (split_type = "indomain")

In [20]:
encoded_hella_swag_val_indomain = encoded_hella_swag_val.filter(lambda x: x["split_type"] == "indomain")
trainer.evaluate(encoded_hella_swag_val_indomain)

{'eval_loss': 2.72586727142334,
 'eval_accuracy': 0.3311337732453509,
 'eval_runtime': 43.0477,
 'eval_samples_per_second': 116.174,
 'eval_steps_per_second': 14.542}

### Results on the zero-shot validation (split_type = "zeroshot")

In [21]:
encoded_hella_swag_val_zeroshot = encoded_hella_swag_val.filter(lambda x: x["split_type"] == "zeroshot")
trainer.evaluate(encoded_hella_swag_val_zeroshot)

{'eval_loss': 2.368483543395996,
 'eval_accuracy': 0.32116643523110494,
 'eval_runtime': 46.0911,
 'eval_samples_per_second': 109.37,
 'eval_steps_per_second': 13.69}
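
The three HellaSwag results can be cross-checked against each other. The sketch below copies the numbers from the outputs above, estimates each subset size as runtime times throughput (an approximation, since both values are rounded), and compares the size-weighted average of the two subsets with the overall accuracy. The helper `approx_size` is hypothetical.

```python
# Values copied from the evaluation outputs above
runs = {
    "overall": {"accuracy": 0.32613025293766185, "runtime": 88.1225, "samples_per_second": 113.955},
    "indomain": {"accuracy": 0.3311337732453509, "runtime": 43.0477, "samples_per_second": 116.174},
    "zeroshot": {"accuracy": 0.32116643523110494, "runtime": 46.0911, "samples_per_second": 109.37},
}


def approx_size(run):
    # Throughput times runtime approximates the number of evaluated examples
    return round(run["runtime"] * run["samples_per_second"])


sizes = {name: approx_size(run) for name, run in runs.items()}

# Size-weighted average of the in-domain and zero-shot accuracies
weighted = (
    runs["indomain"]["accuracy"] * sizes["indomain"]
    + runs["zeroshot"]["accuracy"] * sizes["zeroshot"]
) / (sizes["indomain"] + sizes["zeroshot"])

print(sizes)
print(f"weighted: {weighted:.4f}, overall: {runs['overall']['accuracy']:.4f}")
```

If the two subsets partition the validation set, the weighted average should agree with the overall figure.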

A comparison of BERT-Base fine-tuned on the SWAG dataset, then evaluated on both SWAG and the more difficult HellaSwag:

|                         | SWAG   | HellaSwag (Overall) | HellaSwag (In-Domain) | HellaSwag (Zero-Shot) |
|-------------------------|--------|---------------------|-----------------------|-----------------------|
| **Validation loss**     | 1.032  | 2.546               | 2.726                 | 2.368                 |
| **Validation accuracy** | 79.10% | 32.61%              | 33.11%                | 32.12%                |
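
For context, it helps to compare these numbers with a model that guesses uniformly among the four choices. The sketch below is plain arithmetic on the figures reported in the outputs above: uniform guessing gives 25% accuracy and a cross-entropy of ln 4.

```python
import math

num_choices = 4
chance_accuracy = 1 / num_choices   # uniform guessing
chance_loss = math.log(num_choices)  # cross-entropy of a uniform prediction

# (loss, accuracy) pairs rounded from the evaluation outputs above
reported = {
    "SWAG": (1.032, 0.7910),
    "HellaSwag (Overall)": (2.546, 0.3261),
    "HellaSwag (In-Domain)": (2.726, 0.3311),
    "HellaSwag (Zero-Shot)": (2.368, 0.3212),
}

for name, (loss, acc) in reported.items():
    print(
        f"{name}: accuracy {acc - chance_accuracy:+.1%} vs. chance, "
        f"loss {loss - chance_loss:+.3f} vs. ln(4)"
    )
```

On HellaSwag the accuracy stays a few points above chance while the loss exceeds ln 4, which suggests that the SWAG-tuned model is often confidently wrong on these examples.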