*Copyright (c) Microsoft Corporation. All rights reserved.*  

*Licensed under the MIT License.*

# Natural Language Inference on MultiNLI Dataset using Transformers

# Before You Start

The running times shown in this notebook were measured on a Standard_NC24s_v3 Azure Deep Learning Virtual Machine with 4 NVIDIA Tesla V100 GPUs. 
> **Tip:** If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. 

The table below lists reference running times for different machine configurations.  

|QUICK_RUN|Machine configuration|Running time|
|:---------|:----------------------|:------------|
|True|4 **CPU**s, 14GB memory| ~ 15 minutes|
|True|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| ~ 5 minutes|
|False|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| ~ 10.5 hours|
|False|4 NVIDIA Tesla V100 GPUs, 64GB GPU memory| ~ 2.5 hours|

If you run into CUDA out-of-memory errors, try reducing `BATCH_SIZE` and `MAX_SEQ_LENGTH`, but note that model performance may suffer; for example, a smaller `MAX_SEQ_LENGTH` truncates more of each sentence pair. 
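To see why `MAX_SEQ_LENGTH` matters so much for memory, the back-of-the-envelope sketch below counts only the self-attention score matrices of BERT-base (12 layers, 12 attention heads), assuming fp32 storage. It ignores weights, all other activations, gradients and optimizer state, so real GPU usage is far higher; the point is the scaling: these tensors grow linearly with batch size but quadratically with sequence length.

```python
# Rough size of the self-attention score tensors alone, assuming BERT-base
# (12 layers, 12 heads) and fp32. Weights, other activations, gradients and
# optimizer state are ignored, so actual memory use is much higher.
BYTES_PER_FLOAT32 = 4


def attention_scores_bytes(batch_size, seq_len, num_heads=12, num_layers=12):
    """Bytes for one (seq_len x seq_len) score matrix per example, head and layer."""
    return batch_size * num_heads * seq_len * seq_len * num_layers * BYTES_PER_FLOAT32


for batch_size, seq_len in [(32, 128), (16, 128), (32, 64)]:
    mib = attention_scores_bytes(batch_size, seq_len) / 2**20
    print(f"batch={batch_size:>2} seq_len={seq_len:>3}: {mib:7.1f} MiB")
# Prints 288.0, 144.0 and 72.0 MiB respectively.
```

Halving `BATCH_SIZE` halves this term, while halving `MAX_SEQ_LENGTH` cuts it to a quarter.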

In [1]:
## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.
QUICK_RUN = False

## Summary
In this notebook, we demonstrate how to use [BERT](https://arxiv.org/abs/1810.04805) to perform Natural Language Inference (NLI). Using the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset, the task is to classify each sentence pair into one of three classes: contradiction, entailment, or neutral.   
The figure below shows how [BERT](https://arxiv.org/abs/1810.04805) classifies sentence pairs. It concatenates the tokens of the two sentences in each pair and separates them with the [SEP] token. A [CLS] token is prepended to the token list, and its final hidden state is used as the aggregate sequence representation for the classification task.
<img src="https://nlpbp.blob.core.windows.net/images/bert_two_sentence.PNG">

In [2]:
import sys, os
nlp_path = os.path.abspath('../../')
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)
    
from tempfile import TemporaryDirectory

import numpy as np
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder

import torch

from utils_nlp.models.transformers.sequence_classification import Processor, SequenceClassifier
from utils_nlp.dataset.multinli import load_pandas_df
from utils_nlp.common.timer import Timer

I1107 22:10:21.768640 139623268476672 file_utils.py:39] PyTorch version 1.2.0 available.
I1107 22:10:21.812602 139623268476672 modeling_xlnet.py:194] Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .
I1107 22:10:22.139613 139623268476672 utils.py:141] NumExpr defaulting to 6 threads.


## Configurations

In [3]:
MODEL_NAME = "bert-base-uncased"

# Fractions (0 to 1) of the train and dev data to use; 1 means all of it
TRAIN_DATA_USED_PERCENT = 1
DEV_DATA_USED_PERCENT = 1
NUM_EPOCHS = 2
WARMUP_STEPS = 2500

if QUICK_RUN:
    TRAIN_DATA_USED_PERCENT = 0.001
    DEV_DATA_USED_PERCENT = 0.01
    NUM_EPOCHS = 1
    WARMUP_STEPS = 10

if torch.cuda.is_available():
    BATCH_SIZE = 32
else:
    BATCH_SIZE = 16

RANDOM_SEED = 42

# model configurations
TO_LOWER = True
MAX_SEQ_LENGTH = 128

# optimizer configurations
LEARNING_RATE = 5e-5

# data configurations
TEXT_COL = "text"
LABEL_COL = "gold_label"

# CACHE_DIR = TemporaryDirectory().name
CACHE_DIR = "./temp"

## Load Data
The MultiNLI dataset comes with three subsets: train, dev_matched, and dev_mismatched. The dev_matched subset is drawn from the same genres as the train subset, while the dev_mismatched subset is drawn from genres not seen in the training data.   
The `load_pandas_df` function downloads and extracts the zip files if they don't already exist in `local_cache_path` and returns the data subset specified by `file_split`.
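The download-if-missing behavior described above is a common caching pattern. The standard-library sketch below illustrates it; `maybe_download_and_extract` and its `expected_file` parameter are hypothetical and do not reproduce the internals of `load_pandas_df`. It also assumes the URL ends in the archive's file name.

```python
import os
import urllib.request
import zipfile


def maybe_download_and_extract(url, cache_dir, expected_file):
    """Return the path of `expected_file` in `cache_dir`, downloading and unzipping only if needed.

    Hypothetical helper: `expected_file` is a path inside the archive used as a cache marker.
    """
    target = os.path.join(cache_dir, expected_file)
    if os.path.exists(target):
        return target                                  # cache hit: no network access
    os.makedirs(cache_dir, exist_ok=True)
    zip_path = os.path.join(cache_dir, os.path.basename(url))
    if not os.path.exists(zip_path):
        urllib.request.urlretrieve(url, zip_path)      # download the archive once
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(cache_dir)                  # only extract archives you trust
    return target
```

A second call with the same arguments returns immediately, because the expected file is already in the cache directory.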

In [4]:
train_df = load_pandas_df(local_cache_path=CACHE_DIR, file_split="train")
dev_df_matched = load_pandas_df(local_cache_path=CACHE_DIR, file_split="dev_matched")
dev_df_mismatched = load_pandas_df(local_cache_path=CACHE_DIR, file_split="dev_mismatched")

In [5]:
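# Drop examples whose gold_label is "-" (no annotator consensus); .copy() avoids
# pandas' SettingWithCopyWarning when the text column is added later.
dev_df_matched = dev_df_matched.loc[dev_df_matched[LABEL_COL] != '-'].copy()
dev_df_mismatched = dev_df_mismatched.loc[dev_df_mismatched[LABEL_COL] != '-'].copy()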

In [6]:
print("Training dataset size: {}".format(train_df.shape[0]))
print("Development (matched) dataset size: {}".format(dev_df_matched.shape[0]))
print("Development (mismatched) dataset size: {}".format(dev_df_mismatched.shape[0]))
print()
print(train_df[['gold_label', 'sentence1', 'sentence2']].head())

Training dataset size: 392702
Development (matched) dataset size: 9815
Development (mismatched) dataset size: 9832

   gold_label                                          sentence1  \
0     neutral  Conceptually cream skimming has two basic dime...   
1  entailment  you know during the season and i guess at at y...   
2  entailment  One of our number will carry out your instruct...   
3  entailment  How do you know? All this is their information...   
4     neutral  yeah i tell you what though if you go price so...   

                                           sentence2  
0  Product and geography are what make cream skim...  
1  You lose the things to the following level if ...  
2  A member of my team will execute your orders w...  
3                  This information belongs to them.  
4           The tennis shoes have a range of prices.  


Combine the first and second sentences of each example into a `(sentence1, sentence2)` tuple, which serves as the input text for sentence-pair preprocessing.

In [7]:
train_df[TEXT_COL] = list(zip(train_df['sentence1'], train_df['sentence2']))
dev_df_matched[TEXT_COL] = list(zip(dev_df_matched['sentence1'], dev_df_matched['sentence2']))
dev_df_mismatched[TEXT_COL] = list(zip(dev_df_mismatched['sentence1'], dev_df_mismatched['sentence2']))
train_df[[TEXT_COL, LABEL_COL]].head()

|   | text | gold_label |
|:--|:-----|:-----------|
| 0 | (Conceptually cream skimming has two basic dim... | neutral |
| 1 | (you know during the season and i guess at at ... | entailment |
| 2 | (One of our number will carry out your instruc... | entailment |
| 3 | (How do you know? All this is their informatio... | entailment |
| 4 | (yeah i tell you what though if you go price s... | neutral |


In [8]:
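# Sample the configured fraction of each set; random_state makes the subsets reproducible
train_df = train_df.sample(frac=TRAIN_DATA_USED_PERCENT, random_state=RANDOM_SEED).reset_index(drop=True)
dev_df_matched = dev_df_matched.sample(frac=DEV_DATA_USED_PERCENT, random_state=RANDOM_SEED).reset_index(drop=True)
dev_df_mismatched = dev_df_mismatched.sample(frac=DEV_DATA_USED_PERCENT, random_state=RANDOM_SEED).reset_index(drop=True)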

In [9]:
label_encoder = LabelEncoder()
train_labels = label_encoder.fit_transform(train_df[LABEL_COL])
num_labels = len(np.unique(train_labels))
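`LabelEncoder` assigns integer ids to the sorted class names, so for MultiNLI `contradiction`, `entailment` and `neutral` become 0, 1 and 2, and `inverse_transform` later maps predicted ids back to names. A plain-Python sketch of the same mapping (the helper name `fit_label_encoder` is hypothetical):

```python
def fit_label_encoder(labels):
    """Return a label -> index mapping and the sorted class list (like LabelEncoder.classes_)."""
    classes = sorted(set(labels))
    return {label: index for index, label in enumerate(classes)}, classes


toy_labels = ["neutral", "entailment", "entailment", "contradiction", "neutral"]
to_index, classes = fit_label_encoder(toy_labels)
encoded = [to_index[label] for label in toy_labels]  # like fit_transform: [2, 1, 1, 0, 2]
decoded = [classes[i] for i in encoded]              # like inverse_transform
print(classes, encoded, len(classes))                # 3 classes, as in num_labels
```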

## Tokenize and Preprocess
Before training, we tokenize the sentence texts and convert them into lists of tokens. The following cells instantiate a BERT tokenizer for the chosen model (`MODEL_NAME`) and tokenize the text of the training and development sets.
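BERT's uncased tokenizer lower-cases the text and splits it on whitespace and punctuation, then breaks each word into sub-word pieces with WordPiece, using greedy longest-match-first against the vocabulary. The sketch below implements only that WordPiece step; the toy vocabulary is hypothetical, whereas the real one is loaded from the model's vocabulary file.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]", max_chars=100):
    """Split one lower-cased word into sub-word pieces by greedy longest-match-first."""
    if len(word) > max_chars:
        return [unk_token]
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        while start < end:                   # try the longest remaining substring first
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a "##" prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]               # no piece matches: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"un", "##aff", "##able", "cream", "skim", "##ming"}  # hypothetical vocabulary
print(wordpiece_tokenize("unaffable", toy_vocab))  # ['un', '##aff', '##able']
print(wordpiece_tokenize("skimming", toy_vocab))   # ['skim', '##ming']
```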

In [10]:
%load_ext autoreload

In [11]:
%autoreload 2

In [16]:
processor = Processor(model_name=MODEL_NAME, cache_dir=CACHE_DIR)
train_dataset = processor.preprocess_sentence_pair(
    train_df[TEXT_COL], train_labels, max_len=MAX_SEQ_LENGTH
)
dev_dataset_matched = processor.preprocess_sentence_pair(dev_df_matched[TEXT_COL], None, max_len=MAX_SEQ_LENGTH)
dev_dataset_mismatched = processor.preprocess_sentence_pair(dev_df_mismatched[TEXT_COL], None, max_len=MAX_SEQ_LENGTH)

I1107 22:12:45.331787 139623268476672 tokenization_utils.py:373] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at ./temp/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
100%|██████████| 393/393 [00:00<00:00, 1733.76it/s]
100%|██████████| 98/98 [00:00<00:00, 1693.88it/s]
100%|██████████| 98/98 [00:00<00:00, 1700.61it/s]


In addition to tokenization, the preprocessing in the cell above performs the following steps:

* Convert the tokens into token indices from the BERT tokenizer's vocabulary
* Add the special token [CLS] at the beginning of the sequence and [SEP] after each sentence
* Pad or truncate the token lists to the specified maximum length (`MAX_SEQ_LENGTH`)
* Return attention masks that indicate the positions of padding tokens
* Return token type ids that indicate which sentence each token belongs to

*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*
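The steps above can be sketched in plain Python for a single, already tokenized sentence pair. Truncation follows the longest-first rule of `_truncate_seq_pair` in BERT's `run_classifier.py`, which pops tokens from the longer sentence until the pair fits; the notebook's `Processor` may truncate differently. The toy vocabulary and its ids are hypothetical.

```python
def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens):
    """Trim in place, popping from the longer list (tokens_b on ties) until the pair fits."""
    while len(tokens_a) + len(tokens_b) > max_num_tokens:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def encode_pair(tokens_a, tokens_b, vocab, max_len):
    """Return (input_ids, attention_mask, token_type_ids), each exactly max_len long."""
    tokens_a, tokens_b = list(tokens_a), list(tokens_b)  # copy: do not mutate the caller's lists
    truncate_seq_pair(tokens_a, tokens_b, max_len - 3)   # reserve room for [CLS] + 2 x [SEP]
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Token type ids: 0 for [CLS], sentence A and its [SEP]; 1 for sentence B and the final [SEP]
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    input_ids = [vocab.get(tok, vocab["[UNK]"]) for tok in tokens]
    attention_mask = [1] * len(input_ids)                # 1 = real token
    padding = max_len - len(input_ids)
    input_ids += [vocab["[PAD]"]] * padding
    attention_mask += [0] * padding                      # 0 = padding, ignored by attention
    token_type_ids += [0] * padding
    return input_ids, attention_mask, token_type_ids


toy_vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "the": 4, "tennis": 5,
             "shoes": 6, "have": 7, "a": 8, "range": 9, "of": 10, "prices": 11}
ids, mask, types = encode_pair(["the", "shoes"], ["have", "a", "range", "of", "prices"], toy_vocab, 12)
print(ids)    # [2, 4, 6, 3, 7, 8, 9, 10, 11, 3, 0, 0]
print(mask)   # [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]
print(types)  # [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0]
```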

## Train and Predict

### Create Classifier

In [17]:
classifier = SequenceClassifier(
    model_name=MODEL_NAME, num_labels=num_labels, cache_dir=CACHE_DIR
)

I1107 22:23:48.046394 139623268476672 file_utils.py:296] https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json not found in cache or force_download set to True, downloading to /tmp/tmpbj409coy
100%|██████████| 313/313 [00:00<00:00, 250442.04B/s]
I1107 22:23:48.082683 139623268476672 file_utils.py:309] copying /tmp/tmpbj409coy to cache at ./temp/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c
I1107 22:23:48.083700 139623268476672 file_utils.py:313] creating metadata file for ./temp/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c
I1107 22:23:48.084568 139623268476672 file_utils.py:322] removing temp file /tmp/tmpbj409coy
I1107 22:23:48.085379 139623268476672 configuration_utils.py:151] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from c
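Creating the classifier puts a classification layer on top of the pre-trained encoder. Following the BERT paper, the final hidden vector of the [CLS] token is fed to a linear layer with one output per class, a softmax turns the resulting logits into class probabilities, and the predicted class is the argmax. The plain-Python sketch below uses a toy 4-dimensional vector and made-up weights. In Hugging Face's `BertForSequenceClassification`, the [CLS] vector additionally passes through a pooler (a dense layer with tanh) and dropout before the linear layer.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify_cls(cls_vector, weights, bias):
    """Linear layer (one weight row per class) followed by softmax and argmax."""
    logits = [sum(w * h for w, h in zip(row, cls_vector)) + b
              for row, b in zip(weights, bias)]
    probs = softmax(logits)
    predicted = max(range(len(probs)), key=probs.__getitem__)
    return predicted, probs


# Toy 4-dimensional [CLS] vector and 3 classes (bert-base actually uses 768 dimensions).
classes = ["contradiction", "entailment", "neutral"]
cls_vector = [0.5, -1.0, 0.25, 2.0]
weights = [[1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 1.0], [0.0, 1.0, 0.0, 0.0]]
bias = [0.0, 0.0, 0.0]
predicted, probs = classify_cls(cls_vector, weights, bias)
print(classes[predicted])  # entailment (the logits are [0.5, 2.0, -1.0])
```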

### Train Classifier

In [18]:
with Timer() as t:
    classifier.fit(
        train_dataset,
        num_epochs=NUM_EPOCHS,
        batch_size=BATCH_SIZE,
    )

print("Training time : {:.3f} hrs".format(t.interval / 3600))

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration:   0%|          | 0/13 [00:00<?, ?it/s]

Loss:0.037395

Iteration:   8%|▊         | 1/13 [00:01<00:19,  1.59s/it]
Iteration:  15%|█▌        | 2/13 [00:02<00:16,  1.53s/it]
Iteration:  23%|██▎       | 3/13 [00:04<00:14,  1.50s/it]
Iteration:  31%|███       | 4/13 [00:05<00:13,  1.47s/it]
Iteration:  38%|███▊      | 5/13 [00:07<00:11,  1.45s/it]
Iteration:  46%|████▌     | 6/13 [00:08<00:10,  1.44s/it]
Iteration:  54%|█████▍    | 7/13 [00:10<00:08,  1.43s/it]
Iteration:  62%|██████▏   | 8/13 [00:11<00:07,  1.42s/it]
Iteration:  69%|██████▉   | 9/13 [00:12<00:05,  1.41s/it]
Iteration:  77%|███████▋  | 10/13 [00:14<00:04,  1.41s/it]

Loss:0.034597

Iteration:  85%|████████▍ | 11/13 [00:15<00:02,  1.41s/it]
Iteration:  92%|█████████▏| 12/13 [00:17<00:01,  1.41s/it]
Epoch: 100%|██████████| 1/1 [00:17<00:00, 17.97s/it]

Training time : 0.005 hrs
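The `Timer` used above is a context manager whose `interval` attribute holds the elapsed time in seconds, which the notebook divides by 3600 to report hours. The class below is a hypothetical stand-in with that same usage, not the `utils_nlp.common.timer` source. As a sanity check on the output above, 0.005 hrs is about 18 seconds, in line with the 17.97 s epoch shown in the progress bar.

```python
import time


class Timer:
    """Minimal context-manager timer; `interval` holds elapsed seconds after the block."""

    def __enter__(self):
        self._start = time.perf_counter()
        return self

    def __exit__(self, *exc_info):
        self.interval = time.perf_counter() - self._start
        return False  # do not swallow exceptions raised inside the block


with Timer() as t:
    sum(range(1_000_000))
print("Elapsed: {:.3f} hrs".format(t.interval / 3600))
print("17.97 s = {:.3f} hrs".format(17.97 / 3600))  # 17.97 s = 0.005 hrs
```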





### Predict on Test Data

In [19]:
with Timer() as t:
    predictions_matched = classifier.predict(dev_dataset_matched, batch_size=BATCH_SIZE)
print("Prediction time : {:.3f} hrs".format(t.interval / 3600))

Evaluating: 100%|██████████| 4/4 [00:01<00:00,  2.56it/s]

Prediction time : 0.000 hrs





In [20]:
with Timer() as t:
    predictions_mismatched = classifier.predict(dev_dataset_mismatched, batch_size=BATCH_SIZE)
print("Prediction time : {:.3f} hrs".format(t.interval / 3600))

Evaluating: 100%|██████████| 4/4 [00:01<00:00,  2.55it/s]

Prediction time : 0.000 hrs





## Evaluate

In [21]:
predictions_matched = label_encoder.inverse_transform(predictions_matched)
print(classification_report(dev_df_matched[LABEL_COL], predictions_matched, digits=3))

               precision    recall  f1-score   support

contradiction      0.000     0.000     0.000        34
   entailment      0.351     0.971     0.515        35
      neutral      0.000     0.000     0.000        29

    micro avg      0.347     0.347     0.347        98
    macro avg      0.117     0.324     0.172        98
 weighted avg      0.125     0.347     0.184        98
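The report above can be reproduced by hand. Precision is the number of correct predictions of a class divided by all predictions of that class, recall divides by the class's support, and F1 is their harmonic mean; the macro average weights each class equally, the weighted average weights by support, and for single-label multiclass data the micro average equals accuracy. The sketch below (helper name hypothetical) reports zero-division cases as 0.0, which is what scikit-learn does by default, after warning, for classes that are never predicted. The labels in the usage example are one assignment consistent with the matched report: 97 of the 98 predictions are `entailment`, 34 of them correct. Which minority class received the single remaining prediction cannot be recovered from the report, so `neutral` is assumed.

```python
from collections import Counter


def per_class_report(y_true, y_pred):
    """Per-class (precision, recall, f1, support) plus micro/macro/weighted averages."""
    classes = sorted(set(y_true) | set(y_pred))
    support = Counter(y_true)
    predicted = Counter(y_pred)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    report = {}
    for c in classes:
        precision = correct[c] / predicted[c] if predicted[c] else 0.0  # never predicted -> 0.0
        recall = correct[c] / support[c] if support[c] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[c] = (precision, recall, f1, support[c])
    n = len(y_true)
    per_class = [report[c] for c in classes]
    accuracy = sum(correct.values()) / n
    report["micro avg"] = (accuracy, accuracy, accuracy, n)  # equals accuracy here
    report["macro avg"] = tuple(
        sum(row[i] for row in per_class) / len(classes) for i in range(3)
    ) + (n,)
    report["weighted avg"] = tuple(
        sum(row[i] * row[3] for row in per_class) / n for i in range(3)
    ) + (n,)
    return report


y_true = ["contradiction"] * 34 + ["entailment"] * 35 + ["neutral"] * 29
y_pred = ["entailment"] * 34 + ["entailment"] * 34 + ["neutral"] + ["entailment"] * 29
report = per_class_report(y_true, y_pred)
for name, (p, r, f1, s) in report.items():
    print(f"{name:>13} {p:9.3f} {r:9.3f} {f1:9.3f} {s:9d}")
```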





In [22]:
predictions_mismatched = label_encoder.inverse_transform(predictions_mismatched)
print(classification_report(dev_df_mismatched[LABEL_COL], predictions_mismatched, digits=3))

               precision    recall  f1-score   support

contradiction      0.000     0.000     0.000        47
   entailment      0.278     1.000     0.435        27
      neutral      0.000     0.000     0.000        24

    micro avg      0.276     0.276     0.276        98
    macro avg      0.093     0.333     0.145        98
 weighted avg      0.077     0.276     0.120        98

