*Copyright (c) Microsoft Corporation. All rights reserved.*  

*Licensed under the MIT License.*

# Natural Language Inference on MultiNLI Dataset using BERT

# Before You Start

The running times shown in this notebook were measured on a Standard_NC24s_v3 Azure Deep Learning Virtual Machine with 4 NVIDIA Tesla V100 GPUs. 
> **Tip:** If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. 

The table below provides reference running times on different machine configurations.  

|QUICK_RUN|Machine configuration|Running time|
|:---------|:----------------------|:------------|
|True|4 **CPU**s, 14GB memory| ~15 minutes|
|True|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| ~5 minutes|
|False|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| ~10.5 hours|
|False|4 NVIDIA Tesla V100 GPUs, 64GB GPU memory| ~2.5 hours|

If you run into a CUDA out-of-memory error, try reducing `BATCH_SIZE` and `MAX_SEQ_LENGTH`; note that this may degrade model performance. 

In [1]:
## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.
QUICK_RUN = False

## Summary
In this notebook, we demonstrate using [BERT](https://arxiv.org/abs/1810.04805) to perform Natural Language Inference (NLI). We use the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset, and the task is to classify sentence pairs into three classes: contradiction, entailment, and neutral.   
The figure below shows how [BERT](https://arxiv.org/abs/1810.04805) classifies sentence pairs. It concatenates the tokens of the two sentences in each pair and separates them with a [SEP] token. A [CLS] token is prepended to the token list, and its representation is used as the aggregate sequence representation for the classification task.
<img src="https://nlpbp.blob.core.windows.net/images/bert_two_sentence.PNG">
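As a rough illustration of the input layout described above, the sketch below assembles a sentence pair into `[CLS] A [SEP] B [SEP]` and assigns segment (token type) ids: 0 for the first sentence and 1 for the second. The helper `build_pair_input` and the pre-split toy tokens are hypothetical; in this notebook the real inputs are WordPiece tokens produced by the BERT tokenizer.

```python
def build_pair_input(tokens_a, tokens_b):
    """Lay out a sentence pair as [CLS] A ... [SEP] B ... [SEP] with segment ids.

    Hypothetical helper for illustration only.
    """
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # [CLS], sentence A, and the first [SEP] belong to segment 0;
    # sentence B and the final [SEP] belong to segment 1.
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, token_type_ids


tokens, segments = build_pair_input(["the", "cat", "sat"], ["a", "cat", "sat"])
print(tokens)    # ['[CLS]', 'the', 'cat', 'sat', '[SEP]', 'a', 'cat', 'sat', '[SEP]']
print(segments)  # [0, 0, 0, 0, 0, 1, 1, 1, 1]
```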

In [1]:
import sys
import os
import random
import numpy as np
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder

import torch

nlp_path = os.path.abspath('../../')
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.models.bert.sequence_classification import BERTSequenceClassifier
from utils_nlp.models.bert.common import Language, Tokenizer
from utils_nlp.dataset.multinli import load_pandas_df
from utils_nlp.common.timer import Timer

## Configurations

In [2]:
# Fraction (between 0 and 1) of each split to use; 1 means the full split.
TRAIN_DATA_USED_PERCENT = 1
DEV_DATA_USED_PERCENT = 1
NUM_EPOCHS = 2

if QUICK_RUN:
    TRAIN_DATA_USED_PERCENT = 0.001
    DEV_DATA_USED_PERCENT = 0.01
    NUM_EPOCHS = 1

if torch.cuda.is_available():
    BATCH_SIZE = 32
else:
    BATCH_SIZE = 16

# set random seeds
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
num_cuda_devices = torch.cuda.device_count()
if num_cuda_devices > 0:
    # also seed every visible GPU
    torch.cuda.manual_seed_all(RANDOM_SEED)

# model configurations
LANGUAGE = Language.ENGLISH
TO_LOWER = True
MAX_SEQ_LENGTH = 128

# optimizer configurations
LEARNING_RATE = 5e-5
WARMUP_PROPORTION = 0.1

# data configurations
TEXT_COL = "text"
LABEL_COL = "gold_label"

CACHE_DIR = "./temp"

## Load Data
The MultiNLI dataset comes with three subsets: train, dev_matched, and dev_mismatched. The dev_matched set contains examples from the same genres as the training set, while the dev_mismatched set contains examples from genres not seen in the training set.  
The `load_pandas_df` function downloads and extracts the zip files if they don't already exist in `local_cache_path` and returns the data subset specified by `file_split`.
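The internals of `load_pandas_df` are not shown here. As a hedged sketch, assuming the extracted files are JSON Lines with one example per line and the `gold_label`, `sentence1`, and `sentence2` fields used as columns in this notebook, a split could be parsed with the standard library as follows. The helper `read_nli_records` and the toy input are hypothetical; examples labeled `-` (no annotator consensus) are skipped, mirroring the filtering applied to the dev sets below.

```python
import io
import json


def read_nli_records(lines):
    """Parse NLI-style JSON Lines into (sentence1, sentence2, label) tuples.

    Hypothetical helper: skips blank lines and examples whose gold_label is '-'.
    """
    records = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        example = json.loads(line)
        if example["gold_label"] == "-":
            continue  # no consensus label, so the example is unusable
        records.append((example["sentence1"], example["sentence2"], example["gold_label"]))
    return records


# Toy input standing in for an extracted split file (made-up examples).
sample = io.StringIO(
    '{"gold_label": "entailment", "sentence1": "A man is sleeping.", "sentence2": "A person rests."}\n'
    '{"gold_label": "-", "sentence1": "It rained.", "sentence2": "The game was cancelled."}\n'
)
records = read_nli_records(sample)
print(records)  # [('A man is sleeping.', 'A person rests.', 'entailment')]
```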

In [3]:
train_df = load_pandas_df(local_cache_path=CACHE_DIR, file_split="train")
dev_df_matched = load_pandas_df(local_cache_path=CACHE_DIR, file_split="dev_matched")
dev_df_mismatched = load_pandas_df(local_cache_path=CACHE_DIR, file_split="dev_mismatched")

In [4]:
# Drop examples whose gold_label is '-', i.e. the annotators reached no consensus.
dev_df_matched = dev_df_matched.loc[dev_df_matched['gold_label'] != '-']
dev_df_mismatched = dev_df_mismatched.loc[dev_df_mismatched['gold_label'] != '-']

In [5]:
print("Training dataset size: {}".format(train_df.shape[0]))
print("Development (matched) dataset size: {}".format(dev_df_matched.shape[0]))
print("Development (mismatched) dataset size: {}".format(dev_df_mismatched.shape[0]))
print()
print(train_df[['gold_label', 'sentence1', 'sentence2']].head())

Training dataset size: 392702
Development (matched) dataset size: 9815
Development (mismatched) dataset size: 9832

   gold_label                                          sentence1  \
0     neutral  Conceptually cream skimming has two basic dime...   
1  entailment  you know during the season and i guess at at y...   
2  entailment  One of our number will carry out your instruct...   
3  entailment  How do you know? All this is their information...   
4     neutral  yeah i tell you what though if you go price so...   

                                           sentence2  
0  Product and geography are what make cream skim...  
1  You lose the things to the following level if ...  
2  A member of my team will execute your orders w...  
3                  This information belongs to them.  
4           The tennis shoes have a range of prices.  


Pair the first and second sentences of each example into a `(sentence1, sentence2)` tuple to form the input text.

In [6]:
train_df[TEXT_COL] = list(zip(train_df['sentence1'], train_df['sentence2']))
dev_df_matched[TEXT_COL] = list(zip(dev_df_matched['sentence1'], dev_df_matched['sentence2']))
dev_df_mismatched[TEXT_COL] = list(zip(dev_df_mismatched['sentence1'], dev_df_mismatched['sentence2']))
train_df[[TEXT_COL, LABEL_COL]].head()

|   | text | gold_label |
|:--|:-----|:-----------|
| 0 | (Conceptually cream skimming has two basic dim... | neutral |
| 1 | (you know during the season and i guess at at ... | entailment |
| 2 | (One of our number will carry out your instruc... | entailment |
| 3 | (How do you know? All this is their informatio... | entailment |
| 4 | (yeah i tell you what though if you go price s... | neutral |


In [7]:
train_df = train_df.sample(frac=TRAIN_DATA_USED_PERCENT).reset_index(drop=True)
dev_df_matched = dev_df_matched.sample(frac=DEV_DATA_USED_PERCENT).reset_index(drop=True)
dev_df_mismatched = dev_df_mismatched.sample(frac=DEV_DATA_USED_PERCENT).reset_index(drop=True)

## Tokenize and Preprocess
Before training, we tokenize the sentence texts and convert them to lists of tokens. The following steps instantiate a BERT tokenizer for the given language and tokenize the text of the training and development sets.
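BERT's tokenizer splits text into words (lowercasing first when `TO_LOWER` is set) and then breaks each word into WordPiece subwords using a greedy longest-match-first search against its vocabulary. The simplified sketch below shows only the WordPiece step; `wordpiece` and `toy_vocab` are hypothetical, and the real tokenizer additionally handles punctuation, accents, and a maximum word length.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single word.

    Continuation pieces carry the '##' prefix, as in BERT vocabularies.
    """
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Shrink the candidate from the right until it is in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]  # no piece matched, so the whole word is unknown
        pieces.append(match)
        start = end
    return pieces


toy_vocab = {"skim", "##ming", "cream", "un", "##aff", "##able"}
print(wordpiece("skimming", toy_vocab))   # ['skim', '##ming']
print(wordpiece("unaffable", toy_vocab))  # ['un', '##aff', '##able']
print(wordpiece("xyz", toy_vocab))        # ['[UNK]']
```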

In [8]:
tokenizer = Tokenizer(LANGUAGE, to_lower=TO_LOWER, cache_dir=CACHE_DIR)

train_tokens = tokenizer.tokenize(train_df[TEXT_COL])
dev_matched_tokens = tokenizer.tokenize(dev_df_matched[TEXT_COL])
dev_mismatched_tokens = tokenizer.tokenize(dev_df_mismatched[TEXT_COL])

100%|██████████| 392702/392702 [03:25<00:00, 1907.47it/s]
100%|██████████| 9815/9815 [00:05<00:00, 1961.13it/s]
100%|██████████| 9832/9832 [00:05<00:00, 1837.42it/s]


In addition, we perform the following preprocessing steps in the cell below:

* Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary
* Add the special tokens [CLS] and [SEP]: [CLS] marks the beginning of the sequence, and [SEP] marks the end of each sentence
* Pad or truncate the token lists to the specified max length
* Return input masks that distinguish real tokens from padding positions
* Return token type id lists that indicate which sentence each token belongs to

*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*
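The sketch below walks through these steps for a single pair, following the truncation strategy of the original BERT `run_classifier.py` (trim the longer sentence one token at a time until the pair fits). It assumes a toy vocabulary in which id 0 is `[PAD]`; the helper names are hypothetical, and whether `preprocess_classification_tokens` truncates in exactly this way is an assumption.

```python
def truncate_seq_pair(tokens_a, tokens_b, max_length):
    """Trim the longer of two token lists in place until they fit in max_length."""
    while len(tokens_a) + len(tokens_b) > max_length:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()


def encode_pair(tokens_a, tokens_b, vocab, max_len):
    """Return padded token ids, input mask, and token type ids for one pair (hypothetical)."""
    tokens_a, tokens_b = list(tokens_a), list(tokens_b)  # avoid mutating the caller's lists
    truncate_seq_pair(tokens_a, tokens_b, max_len - 3)  # leave room for [CLS], [SEP], [SEP]
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    token_type_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    token_ids = [vocab[t] for t in tokens]
    input_mask = [1] * len(token_ids)  # 1 marks a real token
    padding = max_len - len(token_ids)
    token_ids += [0] * padding         # assumes id 0 is [PAD]
    input_mask += [0] * padding        # 0 marks a padding position
    token_type_ids += [0] * padding
    return token_ids, input_mask, token_type_ids


toy_vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "a": 3, "b": 4, "c": 5, "d": 6, "e": 7}
print(encode_pair(["a", "b", "c", "d"], ["e"], toy_vocab, max_len=6))
# ([1, 3, 4, 2, 7, 2], [1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 1, 1])
print(encode_pair(["a"], ["b"], toy_vocab, max_len=8))
# ([1, 3, 2, 4, 2, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0], [0, 0, 0, 1, 1, 0, 0, 0])
```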

In [9]:
train_token_ids, train_input_mask, train_token_type_ids = \
    tokenizer.preprocess_classification_tokens(train_tokens, max_len=MAX_SEQ_LENGTH)
dev_matched_token_ids, dev_matched_input_mask, dev_matched_token_type_ids = \
    tokenizer.preprocess_classification_tokens(dev_matched_tokens, max_len=MAX_SEQ_LENGTH)
dev_mismatched_token_ids, dev_mismatched_input_mask, dev_mismatched_token_type_ids = \
    tokenizer.preprocess_classification_tokens(dev_mismatched_tokens, max_len=MAX_SEQ_LENGTH)

In [10]:
label_encoder = LabelEncoder()
train_labels = label_encoder.fit_transform(train_df[LABEL_COL])
num_labels = len(np.unique(train_labels))
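`LabelEncoder` maps each distinct label to an integer according to the sorted order of the labels, so contradiction, entailment, and neutral become 0, 1, and 2, and `inverse_transform` maps predicted integers back to strings. The pure-Python sketch below mimics that behavior on a toy label list; `fit_label_encoder` is a hypothetical helper, not part of scikit-learn.

```python
def fit_label_encoder(labels):
    """Map each distinct label to an integer in sorted order (hypothetical helper)."""
    classes = sorted(set(labels))
    index = {label: i for i, label in enumerate(classes)}
    return classes, index


labels = ["neutral", "entailment", "entailment", "contradiction", "neutral"]
classes, index = fit_label_encoder(labels)
encoded = [index[label] for label in labels]  # like fit_transform
decoded = [classes[i] for i in encoded]       # like inverse_transform
print(classes)  # ['contradiction', 'entailment', 'neutral']
print(encoded)  # [2, 1, 1, 0, 2]
```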

## Train and Predict

### Create Classifier

In [11]:
classifier = BERTSequenceClassifier(language=LANGUAGE,
                                    num_labels=num_labels,
                                    cache_dir=CACHE_DIR)

### Train Classifier

In [12]:
with Timer() as t:
    classifier.fit(token_ids=train_token_ids,
                   input_mask=train_input_mask,
                   token_type_ids=train_token_type_ids,
                   labels=train_labels,
                   num_epochs=NUM_EPOCHS,
                   batch_size=BATCH_SIZE,
                   lr=LEARNING_RATE,
                   warmup_proportion=WARMUP_PROPORTION)
print("Training time : {:.3f} hrs".format(t.interval / 3600))

Iteration:   0%|          | 1/12272 [00:10<35:06:53, 10.30s/it]

epoch:1/2; batch:1->1228/12272; average training loss:1.199178


Iteration:  10%|█         | 1229/12272 [07:20<1:03:16,  2.91it/s]

epoch:1/2; batch:1229->2456/12272; average training loss:0.783637


Iteration:  20%|██        | 2457/12272 [14:28<55:44,  2.93it/s]  

epoch:1/2; batch:2457->3684/12272; average training loss:0.692243


Iteration:  30%|███       | 3685/12272 [21:37<48:36,  2.94it/s]  

epoch:1/2; batch:3685->4912/12272; average training loss:0.653206


Iteration:  40%|████      | 4913/12272 [28:45<41:36,  2.95it/s]  

epoch:1/2; batch:4913->6140/12272; average training loss:0.625751


Iteration:  50%|█████     | 6141/12272 [35:54<34:44,  2.94it/s]  

epoch:1/2; batch:6141->7368/12272; average training loss:0.605123


Iteration:  60%|██████    | 7369/12272 [42:58<27:46,  2.94it/s]  

epoch:1/2; batch:7369->8596/12272; average training loss:0.590521


Iteration:  70%|███████   | 8597/12272 [50:07<20:52,  2.93it/s]  

epoch:1/2; batch:8597->9824/12272; average training loss:0.577829


Iteration:  80%|████████  | 9825/12272 [57:14<13:46,  2.96it/s]  

epoch:1/2; batch:9825->11052/12272; average training loss:0.566418


Iteration:  90%|█████████ | 11053/12272 [1:04:20<06:53,  2.95it/s]

epoch:1/2; batch:11053->12272/12272; average training loss:0.556558


Iteration: 100%|██████████| 12272/12272 [1:11:21<00:00,  2.88it/s]
Iteration:   0%|          | 1/12272 [00:00<1:12:29,  2.82it/s]

epoch:2/2; batch:1->1228/12272; average training loss:0.319802


Iteration:  10%|█         | 1229/12272 [07:09<1:02:29,  2.95it/s]

epoch:2/2; batch:1229->2456/12272; average training loss:0.331876


Iteration:  20%|██        | 2457/12272 [14:15<55:22,  2.95it/s]  

epoch:2/2; batch:2457->3684/12272; average training loss:0.333463


Iteration:  30%|███       | 3685/12272 [21:21<48:41,  2.94it/s]  

epoch:2/2; batch:3685->4912/12272; average training loss:0.331817


Iteration:  40%|████      | 4913/12272 [28:25<41:26,  2.96it/s]  

epoch:2/2; batch:4913->6140/12272; average training loss:0.327940


Iteration:  50%|█████     | 6141/12272 [35:31<34:34,  2.96it/s]  

epoch:2/2; batch:6141->7368/12272; average training loss:0.325802


Iteration:  60%|██████    | 7369/12272 [42:36<27:48,  2.94it/s]  

epoch:2/2; batch:7369->8596/12272; average training loss:0.324641


Iteration:  70%|███████   | 8597/12272 [49:42<20:53,  2.93it/s]  

epoch:2/2; batch:8597->9824/12272; average training loss:0.322036


Iteration:  80%|████████  | 9825/12272 [56:44<13:50,  2.95it/s]

epoch:2/2; batch:9825->11052/12272; average training loss:0.321205


Iteration:  90%|█████████ | 11053/12272 [1:03:49<06:54,  2.94it/s]

epoch:2/2; batch:11053->12272/12272; average training loss:0.319237


Iteration: 100%|██████████| 12272/12272 [1:10:52<00:00,  2.94it/s]


Training time : 2.374 hrs
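The iteration counts in the log above follow from the batch size: with 392,702 training examples and `BATCH_SIZE = 32`, each epoch takes ceil(392702 / 32) = 12,272 steps, and the same formula gives the 307 and 308 prediction iterations reported for the two dev sets below. `WARMUP_PROPORTION = 0.1` typically means the learning rate is warmed up over roughly the first 10% of the 24,544 total steps. The sketch below reproduces the arithmetic and, as an assumption, a common linear warmup-then-decay schedule; the exact schedule used inside `BERTSequenceClassifier.fit` is not shown in this notebook, and both helpers are hypothetical.

```python
import math


def num_batches(num_examples, batch_size):
    """Mini-batches needed to cover every example once (the last batch may be partial)."""
    return math.ceil(num_examples / batch_size)


def warmup_linear(step, total_steps, warmup_proportion):
    """Assumed learning-rate multiplier: linear warmup to 1.0, then linear decay to 0."""
    warmup_steps = warmup_proportion * total_steps
    if step < warmup_steps:
        return step / warmup_steps
    return max((total_steps - step) / (total_steps - warmup_steps), 0.0)


steps_per_epoch = num_batches(392702, 32)
total_steps = 2 * steps_per_epoch  # NUM_EPOCHS = 2
print(steps_per_epoch, total_steps)                  # 12272 24544
print(num_batches(9815, 32), num_batches(9832, 32))  # 307 308
```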


### Predict on Development Data

In [13]:
with Timer() as t:
    predictions_matched = classifier.predict(token_ids=dev_matched_token_ids,
                                             input_mask=dev_matched_input_mask,
                                             token_type_ids=dev_matched_token_type_ids,
                                             batch_size=BATCH_SIZE)
print("Prediction time : {:.3f} hrs".format(t.interval / 3600))

Iteration: 100%|██████████| 307/307 [00:40<00:00,  8.15it/s]

Prediction time : 0.011 hrs





In [14]:
with Timer() as t:
    predictions_mismatched = classifier.predict(token_ids=dev_mismatched_token_ids,
                                                input_mask=dev_mismatched_input_mask,
                                                token_type_ids=dev_mismatched_token_type_ids,
                                                batch_size=BATCH_SIZE)
print("Prediction time : {:.3f} hrs".format(t.interval / 3600))

Iteration: 100%|██████████| 308/308 [00:38<00:00,  8.30it/s]

Prediction time : 0.011 hrs





## Evaluate

In [15]:
predictions_matched = label_encoder.inverse_transform(predictions_matched)
print(classification_report(dev_df_matched[LABEL_COL], predictions_matched, digits=3))

               precision    recall  f1-score   support

contradiction      0.848     0.865     0.857      3213
   entailment      0.894     0.828     0.860      3479
      neutral      0.783     0.831     0.806      3123

    micro avg      0.841     0.841     0.841      9815
    macro avg      0.842     0.841     0.841      9815
 weighted avg      0.844     0.841     0.842      9815



In [16]:
predictions_mismatched = label_encoder.inverse_transform(predictions_mismatched)
print(classification_report(dev_df_mismatched[LABEL_COL], predictions_mismatched, digits=3))

               precision    recall  f1-score   support

contradiction      0.862     0.863     0.863      3240
   entailment      0.878     0.853     0.865      3463
      neutral      0.791     0.815     0.803      3129

    micro avg      0.844     0.844     0.844      9832
    macro avg      0.844     0.844     0.844      9832
 weighted avg      0.845     0.844     0.845      9832
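For reference, the per-class numbers in these reports can be reproduced from raw predictions: precision is TP / (TP + FP), recall is TP / (TP + FN), and F1 is their harmonic mean. In single-label multi-class classification every error is one false positive for the predicted class and one false negative for the true class, so the micro-averaged precision, recall, and F1 all equal accuracy, which is why the `micro avg` rows show identical values. The sketch below uses toy labels and a hypothetical `per_class_prf` helper; it is not scikit-learn's implementation.

```python
from collections import Counter


def per_class_prf(y_true, y_pred):
    """Per-class (precision, recall, F1) plus overall accuracy (hypothetical helper)."""
    classes = sorted(set(y_true) | set(y_pred))
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)  # correct predictions per class
    pred_count = Counter(y_pred)  # TP + FP per class
    true_count = Counter(y_true)  # TP + FN per class (the "support")
    report = {}
    for c in classes:
        precision = tp[c] / pred_count[c] if pred_count[c] else 0.0
        recall = tp[c] / true_count[c] if true_count[c] else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[c] = (precision, recall, f1)
    accuracy = sum(tp.values()) / len(y_true)
    return report, accuracy


y_true = ["entailment", "neutral", "contradiction", "neutral"]
y_pred = ["entailment", "contradiction", "contradiction", "neutral"]
report, accuracy = per_class_prf(y_true, y_pred)
macro_f1 = sum(f1 for _, _, f1 in report.values()) / len(report)
print(accuracy)            # 0.75
print(round(macro_f1, 3))  # 0.778
```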

