# Multi-lingual Inference on XNLI Dataset using BERT

## Summary
In this notebook, we demonstrate using the [Multi-lingual BERT model](https://github.com/google-research/bert/blob/master/multilingual.md) to do language inference in Chinese and Hindi. We use the [XNLI](https://github.com/facebookresearch/XNLI) dataset, and the task is to classify sentence pairs into three classes: contradiction, entailment, and neutral.   
The figure below shows how [BERT](https://arxiv.org/abs/1810.04805) classifies sentence pairs. It concatenates the tokens of each sentence pair and separates the two sentences with the [SEP] token. A [CLS] token is prepended to the token list and used as the aggregate sequence representation for the classification task.
<img src="https://nlpbp.blob.core.windows.net/images/bert_two_sentence.PNG">
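
As a rough illustration of this input layout, the following sketch assembles the token sequence and segment ids for one sentence pair. The function name `build_pair_input` and the toy English tokens are hypothetical and are not part of `utils_nlp`.

```python
def build_pair_input(tokens_a, tokens_b):
    """Concatenate two token lists in the sentence-pair layout used by BERT."""
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment ids: 0 covers [CLS], sentence A and its [SEP];
    # 1 covers sentence B and the final [SEP].
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    return tokens, segment_ids


premise = ["the", "cat", "sleeps"]
hypothesis = ["an", "animal", "rests"]
tokens, segment_ids = build_pair_input(premise, hypothesis)
print(tokens)
print(segment_ids)
```

The classifier reads its prediction from the position of `[CLS]`, which is why that token always comes first.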

In [1]:
import sys
import os
import random
import numpy as np
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder

import torch

nlp_path = os.path.abspath('../../')
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.bert.sequence_classification import BERTSequenceClassifier
from utils_nlp.bert.common import Language, Tokenizer
from utils_nlp.dataset.xnli import load_pandas_df

In [2]:
# set random seeds
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
num_cuda_devices = torch.cuda.device_count()
# seed every visible GPU, including the single-GPU case
if num_cuda_devices > 0:
    torch.cuda.manual_seed_all(RANDOM_SEED)

# model configurations
LANGUAGE_CHINESE = Language.CHINESE
LANGUAGE_MULTI = Language.MULTILINGUAL
TO_LOWER = True
MAX_SEQ_LENGTH = 128

# training configurations
NUM_GPUS = 2
BATCH_SIZE = 32
NUM_EPOCHS = 2

# optimizer configurations
LEARNING_RATE = 5e-5
WARMUP_PROPORTION = 0.1

# data configurations
TEXT_COL = "text"
LABEL_COL = "label"

CACHE_DIR = "./temp"

## Load Data
The XNLI dataset comes in two zip files:  
* XNLI-1.0.zip: dev and test datasets in 15 languages. The original English data was translated into other languages by human translators. 
* XNLI-MT-1.0.zip: training dataset in 15 languages. This dataset is a machine translation of the [MultiNLI](https://www.nyu.edu/projects/bowman/multinli/) dataset. It also contains English translations of the dev and test datasets, which are not used in this notebook.  

The `load_pandas_df` function downloads and extracts the zip files if they don't already exist in `local_cache_path`, and returns the data subset specified by `file_split` and `language`.
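
The download-if-missing pattern can be sketched as follows. This is not the `load_pandas_df` implementation: `ensure_cached` and its `fetch` callback are hypothetical names used only to illustrate the caching behavior, and the fake fetcher writes a placeholder file instead of downloading anything.

```python
import os
import tempfile


def ensure_cached(local_cache_path, file_name, fetch):
    """Return the path of file_name under local_cache_path, fetching it only if absent.

    fetch is a hypothetical callable that writes the file to the given path.
    """
    os.makedirs(local_cache_path, exist_ok=True)
    file_path = os.path.join(local_cache_path, file_name)
    if not os.path.exists(file_path):
        fetch(file_path)
    return file_path


calls = []


def fake_fetch(path):
    # Stand-in for a real download: record the call and write a placeholder file.
    calls.append(path)
    with open(path, "w") as f:
        f.write("placeholder")


with tempfile.TemporaryDirectory() as cache_dir:
    first = ensure_cached(cache_dir, "XNLI-1.0.zip", fake_fetch)
    second = ensure_cached(cache_dir, "XNLI-1.0.zip", fake_fetch)
    print(first == second, len(calls))
```

The second call finds the file already in place, so the fetcher runs only once.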

In [3]:
train_df_chinese = load_pandas_df(local_cache_path="./", file_split="train", language="zh")
dev_df_chinese = load_pandas_df(local_cache_path="./", file_split="dev", language="zh")
test_df_chinese = load_pandas_df(local_cache_path="./", file_split="test", language="zh")

train_df_hindi = load_pandas_df(local_cache_path="./", file_split="train", language="hi")
dev_df_hindi = load_pandas_df(local_cache_path="./", file_split="dev", language="hi")
test_df_hindi = load_pandas_df(local_cache_path="./", file_split="test", language="hi")

In [4]:
print("Chinese training dataset size: {}".format(train_df_chinese.shape[0]))
print("Chinese dev dataset size: {}".format(dev_df_chinese.shape[0]))
print("Chinese test dataset size: {}".format(test_df_chinese.shape[0]))
print()
print("Hindi training dataset size: {}".format(train_df_hindi.shape[0]))
print("Hindi dev dataset size: {}".format(dev_df_hindi.shape[0]))
print("Hindi test dataset size: {}".format(test_df_hindi.shape[0]))
print()
print(train_df_chinese.head())
print(train_df_hindi.head())

Chinese training dataset size: 392702
Chinese dev dataset size: 2490
Chinese test dataset size: 5010

Hindi training dataset size: 392702
Hindi dev dataset size: 2490
Hindi test dataset size: 5010

                                                text       label
0  (从 概念 上 看 , 奶油 收入 有 两 个 基本 方面 产品 和 地理 ., 产品 和 ...     neutral
1  (你 知道 在 这个 季节 , 我 猜 在 你 的 水平 你 把 他们 丢到 下 一个 水平...  entailment
2  (我们 的 一个 号码 会 非常 详细 地 执行 你 的 指示, 我 团队 的 一个 成员 ...  entailment
3   (你 怎么 知道 的 ? 所有 这些 都 是 他们 的 信息 ., 这些 信息 属于 他们 .)  entailment
4  (是 啊 , 我 告诉 你 , 如果 你 去 买 一些 网球鞋 , 我 可以 看到 为什么 ...     neutral
                                                text       label
0  (Conceptually क ् रीम एंजलिस में दो मूल आयाम ह...     neutral
1  (आप मौसम के दौरान जानते हैं और मैं अपने स ् तर...  entailment
2  (हमारे एक नंबर में से एक आपके निर ् देशों को म...  entailment
3  (आप कैसे जानते हैं ? ये सब उनकी जानकारी फिर से...  entailment
4  (हाँ मैं आपको बताता हूँ कि अगर आप उन टेनिस जूत...     neutral


In [5]:
train_df_chinese = train_df_chinese.loc[:1000]
dev_df_chinese = dev_df_chinese.loc[:1000]
test_df_chinese = test_df_chinese.loc[:1000]

train_df_hindi = train_df_hindi.loc[:1000]
dev_df_hindi = dev_df_hindi.loc[:1000]
test_df_hindi = test_df_hindi.loc[:1000]

Note that the texts are converted to Unicode, which can be processed by BERT models. 

## Language Inference on Chinese
For the Chinese dataset, we use the `bert-base-chinese` model, which was pretrained on Chinese data only. The `bert-base-multilingual-cased` model can also be used on Chinese, but its accuracy is 3% lower.

### Tokenize and Preprocess
Before training, we tokenize the sentence texts and convert them to lists of tokens. The following steps instantiate a BERT tokenizer given the language, and tokenize the text of the training and testing sets.
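
BERT tokenizers split words into WordPiece sub-tokens using a greedy longest-match-first search over a fixed vocabulary. The sketch below illustrates that idea on a toy English vocabulary. It is not the `Tokenizer` class used in this notebook, and `wordpiece_tokenize` and `toy_vocab` are hypothetical names.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first split of a single word into sub-tokens."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        piece = None
        # Shrink the candidate substring until it is found in the vocabulary.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:
            return [unk_token]  # no piece matches: the whole word is unknown
        pieces.append(piece)
        start = end
    return pieces


toy_vocab = {"play", "##ing", "##ed", "un", "##play", "##able"}
print(wordpiece_tokenize("playing", toy_vocab))
print(wordpiece_tokenize("unplayable", toy_vocab))
print(wordpiece_tokenize("xyz", toy_vocab))
```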

In [6]:
tokenizer_chinese = Tokenizer(LANGUAGE_CHINESE, to_lower=TO_LOWER, cache_dir=CACHE_DIR)

train_tokens_chinese = tokenizer_chinese.tokenize(train_df_chinese[TEXT_COL])
test_tokens_chinese = tokenizer_chinese.tokenize(test_df_chinese[TEXT_COL])

100%|██████████| 1001/1001 [00:00<00:00, 2612.95it/s]
100%|██████████| 1001/1001 [00:00<00:00, 3663.45it/s]


In addition, we perform the following preprocessing steps in the cell below:

* Convert the tokens into token indices corresponding to the BERT tokenizer's vocabulary
* Add the special tokens [CLS] and [SEP] to mark the beginning and end of a sentence
* Pad or truncate the token lists to the specified max length
* Return mask lists that indicate the positions of padding
* Return token type id lists that indicate which sentence the tokens belong to

*See the original [implementation](https://github.com/google-research/bert/blob/master/run_classifier.py) for more information on BERT's input format.*
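
The steps listed above can be sketched for a single sentence as follows. This is a simplified illustration under stated assumptions, not the `preprocess_classification_tokens` implementation: the `preprocess` helper and the toy vocabulary are hypothetical, and the mask convention (1 for real tokens, 0 for padding) follows the original BERT code.

```python
def preprocess(tokens, vocab, max_len, pad_id=0):
    """Turn one token list into padded ids, an input mask and token type ids."""
    # Reserve two positions for [CLS] and [SEP], truncating the sentence if needed.
    tokens = ["[CLS]"] + tokens[: max_len - 2] + ["[SEP]"]
    token_ids = [vocab.get(t, vocab["[UNK]"]) for t in tokens]
    n_pad = max_len - len(token_ids)
    input_mask = [1] * len(token_ids) + [0] * n_pad  # 1 = real token, 0 = padding
    token_ids = token_ids + [pad_id] * n_pad
    token_type_ids = [0] * max_len  # single sentence: every position is segment 0
    return token_ids, input_mask, token_type_ids


toy_vocab = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "hello": 4, "world": 5}

# A short input is padded up to max_len.
print(preprocess(["hello", "world"], toy_vocab, max_len=6))
# A long input is truncated so that [CLS] and [SEP] still fit.
print(preprocess(["hello", "world", "hello", "world", "hello"], toy_vocab, max_len=4))
```

For sentence pairs, the token type ids would be 0 for the first sentence and 1 for the second.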

In [7]:
train_token_ids_chinese, train_input_mask_chinese, train_token_type_ids_chinese = \
    tokenizer_chinese.preprocess_classification_tokens(train_tokens_chinese, max_len=MAX_SEQ_LENGTH)
test_token_ids_chinese, test_input_mask_chinese, test_token_type_ids_chinese = \
    tokenizer_chinese.preprocess_classification_tokens(test_tokens_chinese, max_len=MAX_SEQ_LENGTH)

In [8]:
label_encoder_chinese = LabelEncoder()
train_labels_chinese = label_encoder_chinese.fit_transform(train_df_chinese[LABEL_COL])
num_labels_chinese = len(np.unique(train_labels_chinese))
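
For reference, the label encoding step maps each class name to an integer. The pure-Python sketch below mirrors the sorted-class mapping that scikit-learn's `LabelEncoder` produces. `fit_label_encoder` is a hypothetical stand-in and not the library code.

```python
def fit_label_encoder(labels):
    """Map each distinct label to an integer, with classes kept in sorted order."""
    classes = sorted(set(labels))
    class_to_id = {c: i for i, c in enumerate(classes)}
    return classes, class_to_id


labels = ["neutral", "entailment", "entailment", "contradiction"]
classes, class_to_id = fit_label_encoder(labels)

encoded = [class_to_id[label] for label in labels]  # what fit_transform would return
decoded = [classes[i] for i in encoded]             # what inverse_transform would return
print(classes)
print(encoded)
print(decoded == labels)
```

Because the classes are sorted, `contradiction`, `entailment`, and `neutral` map to 0, 1, and 2 respectively.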

### Create Classifier

In [9]:
classifier_chinese = BERTSequenceClassifier(language=LANGUAGE_CHINESE,
                                            num_labels=num_labels_chinese,
                                            cache_dir=CACHE_DIR)

### Train Classifier

In [10]:
classifier_chinese.fit(token_ids=train_token_ids_chinese,
                       input_mask=train_input_mask_chinese,
                       token_type_ids=train_token_type_ids_chinese,
                       labels=train_labels_chinese,
                       num_gpus=NUM_GPUS,
                       num_epochs=NUM_EPOCHS,
                       batch_size=BATCH_SIZE,
                       lr=LEARNING_RATE,
                       warmup_proportion=WARMUP_PROPORTION)

epoch:1/2; batch:1->4/31; loss:1.273249
epoch:1/2; batch:5->8/31; loss:1.103003
epoch:1/2; batch:9->12/31; loss:1.107130
epoch:1/2; batch:13->16/31; loss:1.112338
epoch:1/2; batch:17->20/31; loss:1.334211
epoch:1/2; batch:21->24/31; loss:1.244677
epoch:1/2; batch:25->28/31; loss:1.146302
epoch:1/2; batch:29->31/31; loss:1.145210
epoch:2/2; batch:1->4/31; loss:1.084830
epoch:2/2; batch:5->8/31; loss:1.107789
epoch:2/2; batch:9->12/31; loss:1.125404
epoch:2/2; batch:13->16/31; loss:1.104571
epoch:2/2; batch:17->20/31; loss:1.115697
epoch:2/2; batch:21->24/31; loss:1.153866
epoch:2/2; batch:25->28/31; loss:1.093025
epoch:2/2; batch:29->31/31; loss:1.098436
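
`WARMUP_PROPORTION` sets the fraction of training steps during which the learning rate ramps up from zero. The sketch below shows one common schedule, linear warmup followed by linear decay. The exact schedule used inside `BERTSequenceClassifier.fit` is not shown in this notebook, so treat the function `linear_warmup_lr` as a hypothetical illustration rather than the library's behavior.

```python
def linear_warmup_lr(step, total_steps, base_lr, warmup_proportion):
    """Learning rate at a given step: linear ramp-up, then linear decay to zero.

    Assumes 0 < warmup_proportion < 1.
    """
    progress = step / total_steps
    if progress < warmup_proportion:
        return base_lr * progress / warmup_proportion
    return base_lr * (1.0 - progress) / (1.0 - warmup_proportion)


# 31 batches per epoch and 2 epochs, as in the log above, give 62 steps in total.
total_steps = 62
for step in (0, 3, 6, 31, 62):
    lr = linear_warmup_lr(step, total_steps, base_lr=5e-5, warmup_proportion=0.1)
    print(f"step {step:2d}: lr = {lr:.2e}")
```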


### Predict on Test Data

In [11]:
predictions_chinese = classifier_chinese.predict(token_ids=test_token_ids_chinese,
                                                 input_mask=test_input_mask_chinese,
                                                 token_type_ids=test_token_type_ids_chinese,
                                                 batch_size=BATCH_SIZE)

  0%|          | 0/1001 [00:00<?, ?it/s]



1024it [00:14, 86.87it/s]                         


### Evaluate

In [12]:
predictions_chinese = label_encoder_chinese.inverse_transform(predictions_chinese)
print(classification_report(test_df_chinese[LABEL_COL], predictions_chinese))

               precision    recall  f1-score   support

contradiction       0.00      0.00      0.00       333
   entailment       0.33      1.00      0.50       334
      neutral       0.00      0.00      0.00       334

     accuracy                           0.33      1001
    macro avg       0.11      0.33      0.17      1001
 weighted avg       0.11      0.33      0.17      1001
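
The per-class numbers in the report above are consistent with a classifier that assigns the same label to every test example. The following arithmetic sketch reproduces them from the support counts alone. The helper `precision_recall_f1` is a hypothetical name, and undefined ratios are set to 0.0, as in the report.

```python
def precision_recall_f1(true_positive, predicted_positive, actual_positive):
    """Precision, recall and F1 for one class; undefined ratios are set to 0.0."""
    precision = true_positive / predicted_positive if predicted_positive else 0.0
    recall = true_positive / actual_positive if actual_positive else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


# Support counts taken from the report above.
support = {"contradiction": 333, "entailment": 334, "neutral": 334}
total = sum(support.values())

# Suppose every example is predicted as "entailment".
scores = {}
for label, n in support.items():
    predicted = total if label == "entailment" else 0
    true_positive = n if label == "entailment" else 0
    scores[label] = precision_recall_f1(true_positive, predicted, n)

macro_f1 = sum(f1 for _, _, f1 in scores.values()) / len(scores)
for label, (p, r, f1) in scores.items():
    print(f"{label:>13}  {p:.2f}  {r:.2f}  {f1:.2f}")
print(f"macro f1: {macro_f1:.2f}")
```

An accuracy of 0.33 on three balanced classes is therefore the same as always guessing one class.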





## Language Inference on Hindi
For Hindi and all other languages except Chinese, we use the `bert-base-multilingual-cased` model.  
The preprocessing, model training, and prediction steps are the same as on the Chinese data, except for the underlying tokenizer and BERT model used.

### Tokenize and Preprocess

In [13]:
tokenizer_multi = Tokenizer(LANGUAGE_MULTI, cache_dir=CACHE_DIR)

train_tokens_hindi = tokenizer_multi.tokenize(train_df_hindi[TEXT_COL])
test_tokens_hindi = tokenizer_multi.tokenize(test_df_hindi[TEXT_COL])

train_token_ids_hindi, train_input_mask_hindi, train_token_type_ids_hindi = \
    tokenizer_multi.preprocess_classification_tokens(train_tokens_hindi, max_len=MAX_SEQ_LENGTH)
test_token_ids_hindi, test_input_mask_hindi, test_token_type_ids_hindi = \
    tokenizer_multi.preprocess_classification_tokens(test_tokens_hindi, max_len=MAX_SEQ_LENGTH)

label_encoder_hindi = LabelEncoder()
train_labels_hindi = label_encoder_hindi.fit_transform(train_df_hindi[LABEL_COL])
num_labels_hindi = len(np.unique(train_labels_hindi))

100%|██████████| 1001/1001 [00:00<00:00, 1645.88it/s]
100%|██████████| 1001/1001 [00:00<00:00, 2262.68it/s]


### Create and Train Classifier

In [14]:
classifier_multi = BERTSequenceClassifier(language=LANGUAGE_MULTI,
                                          num_labels=num_labels_hindi,
                                          cache_dir=CACHE_DIR)
classifier_multi.fit(token_ids=train_token_ids_hindi,
                     input_mask=train_input_mask_hindi,
                     token_type_ids=train_token_type_ids_hindi,
                     labels=train_labels_hindi,
                     num_gpus=NUM_GPUS,
                     num_epochs=NUM_EPOCHS,
                     batch_size=BATCH_SIZE,
                     lr=LEARNING_RATE,
                     warmup_proportion=WARMUP_PROPORTION)

epoch:1/2; batch:1->4/31; loss:1.128533
epoch:1/2; batch:5->8/31; loss:1.139760
epoch:1/2; batch:9->12/31; loss:1.128057
epoch:1/2; batch:13->16/31; loss:1.163460
epoch:1/2; batch:17->20/31; loss:1.091910
epoch:1/2; batch:21->24/31; loss:1.198568
epoch:1/2; batch:25->28/31; loss:0.941484
epoch:1/2; batch:29->31/31; loss:1.049881
epoch:2/2; batch:1->4/31; loss:1.109279
epoch:2/2; batch:5->8/31; loss:1.075177
epoch:2/2; batch:9->12/31; loss:1.122685
epoch:2/2; batch:13->16/31; loss:1.124175
epoch:2/2; batch:17->20/31; loss:1.109364
epoch:2/2; batch:21->24/31; loss:1.052536
epoch:2/2; batch:25->28/31; loss:1.074721
epoch:2/2; batch:29->31/31; loss:1.132380


### Predict and Evaluate

In [15]:
predictions_hindi = classifier_multi.predict(token_ids=test_token_ids_hindi,
                                             input_mask=test_input_mask_hindi,
                                             token_type_ids=test_token_type_ids_hindi,
                                             batch_size=BATCH_SIZE)
predictions_hindi = label_encoder_hindi.inverse_transform(predictions_hindi)
print(classification_report(test_df_hindi[LABEL_COL], predictions_hindi))

  0%|          | 0/1001 [00:00<?, ?it/s]



1024it [00:14, 86.39it/s]                         

               precision    recall  f1-score   support

contradiction       0.33      1.00      0.50       333
   entailment       0.00      0.00      0.00       334
      neutral       0.00      0.00      0.00       334

     accuracy                           0.33      1001
    macro avg       0.11      0.33      0.17      1001
 weighted avg       0.11      0.33      0.17      1001




