# Multi-lingual Entailment on XNLI Dataset using BERT

### Required packages
* pytorch-pretrained-bert
* pandas
* seqeval
* unicode

In [1]:
import sys
import os
import random
import numpy as np
import csv
import six

import torch

nlp_path = os.path.abspath('../../')
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.bert.sequence_classification import BERTSequenceClassifier
from utils_nlp.bert.common import Language, Tokenizer

In [15]:
# set random seeds
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
num_cuda_devices = torch.cuda.device_count()
if num_cuda_devices > 1:
    torch.cuda.manual_seed_all(random_seed)

# model configurations
language = Language.CHINESE
do_lower_case = True
max_seq_length = 128

# training configurations
device="gpu"
batch_size = 32
num_train_epochs = 2

# optimizer configurations
learning_rate = 5e-5
config_file = "config_multilingual.yaml"
train_data_dir = "./data/XNLI-MT-1.0/XNLI-MT-1.0/"
dev_data_dir = "./data/XNLI-MT-1.0/XNLI-MT-1.0/"
cache_dir="."

## Preprocess Data

In [3]:
def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        elif isinstance(text, unicode):
            return text
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")
        
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
        with open(input_file, "r", encoding="utf-8") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                if sys.version_info[0] == 2:
                    line = list(unicode(cell, 'utf-8') for cell in line)
                lines.append(line)
            return lines

        
class XnliProcessor(DataProcessor):
  """Processor for the XNLI data set."""

  def __init__(self):
    self.language = "zh"

  def get_train_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(
        os.path.join(data_dir, "multinli",
                     "multinli.train.%s.tsv" % self.language))
    text_list = []
    label_list = []
    for (i, line) in enumerate(lines):
      if i == 0:
        continue
      text_a = convert_to_unicode(line[0])
      text_b = convert_to_unicode(line[1])
      label = convert_to_unicode(line[2])
      if label == convert_to_unicode("contradictory"):
        label = convert_to_unicode("contradiction")
      text_list.append((text_a, text_b))
      label_list.append(label)
    return text_list, label_list

  def get_dev_examples(self, data_dir):
    """See base class."""
    lines = self._read_tsv(os.path.join(data_dir, "xnli", "xnli.dev.tsv"))
    text_list = []
    label_list = []
    for (i, line) in enumerate(lines):
      if i == 0:
        continue
      language = convert_to_unicode(line[0])
      if language != convert_to_unicode(self.language):
        continue
      text_a = convert_to_unicode(line[6])
      text_b = convert_to_unicode(line[7])
      label = convert_to_unicode(line[1])
    
      text_list.append((text_a, text_b))
      label_list.append(label)
    return text_list, label_list

  def get_labels(self):
    """See base class."""
    return ["contradiction", "entailment", "neutral"]

In [4]:
xnli_processor =  XnliProcessor()
train_text, train_labels = xnli_processor.get_train_examples(data_dir=train_data_dir)
dev_text, dev_labels= xnli_processor.get_dev_examples(data_dir=dev_data_dir)
label_list = xnli_processor.get_labels()

In [5]:
train_text = train_text[:1000]
train_labels = train_labels[:1000]
dev_text = dev_text[:1000]
dev_labels = dev_labels[:1000]

In [6]:
print(label_list)

['contradiction', 'entailment', 'neutral']


### Convert examples to features
The function `convert_examples_to_token_features` converts raw string data to numerical features, involving the following steps:
1. Tokenization
2. Convert tokens and labels to numerical values
3. Sequence padding or truncation

In [7]:
tokenizer = Tokenizer(language=language, 
                      to_lower=do_lower_case, 
                      cache_dir=cache_dir)

In [8]:
train_tokens = tokenizer.tokenize(train_text)
dev_tokens = tokenizer.tokenize(dev_text)

100%|██████████| 1000/1000 [00:00<00:00, 2916.60it/s]
100%|██████████| 1000/1000 [00:00<00:00, 3776.67it/s]


In [10]:
train_token_ids, train_input_mask, train_token_type_ids = \
    tokenizer.preprocess_classification_tokens(train_tokens, max_len=max_seq_length)
dev_token_ids, dev_input_mask, dev_token_type_ids = \
    tokenizer.preprocess_classification_tokens(dev_tokens, max_len=max_seq_length)

In [11]:
label_map = {label: i for i, label in enumerate(label_list)}
train_label_ids = [label_map[l] for l in train_labels]
dev_label_ids = [label_map[l] for l in dev_labels]

In [12]:
classifier = BERTSequenceClassifier(language=language,
                                    num_labels=len(label_list),
                                    cache_dir=cache_dir)

100%|██████████| 382072689/382072689 [00:07<00:00, 48295901.53B/s]


In [16]:
classifier.fit(token_ids=train_token_ids,
               input_mask=train_input_mask,
               token_type_ids=train_token_type_ids,
               labels=train_label_ids,
               num_gpus=2,
               num_epochs=num_train_epochs,
               batch_size=batch_size,
               lr=learning_rate,
               warmup_proportion=0.1)

epoch:1/2; batch:1->4/31; loss:1.304341
epoch:1/2; batch:5->8/31; loss:1.074271
epoch:1/2; batch:9->12/31; loss:1.063943
epoch:1/2; batch:13->16/31; loss:1.124226
epoch:1/2; batch:17->20/31; loss:1.100629
epoch:1/2; batch:21->24/31; loss:1.128772
epoch:1/2; batch:25->28/31; loss:1.216407
epoch:1/2; batch:29->32/31; loss:1.041168
epoch:2/2; batch:1->4/31; loss:1.054121
epoch:2/2; batch:5->8/31; loss:1.015953
epoch:2/2; batch:9->12/31; loss:1.072215
epoch:2/2; batch:13->16/31; loss:1.143950
epoch:2/2; batch:17->20/31; loss:0.927066
epoch:2/2; batch:21->24/31; loss:1.054619
epoch:2/2; batch:25->28/31; loss:0.914871
epoch:2/2; batch:29->32/31; loss:0.977393


In [17]:
predictions = classifier.predict(token_ids=dev_token_ids,
                                 input_mask=dev_input_mask,
                                 token_type_ids=dev_token_type_ids,
                                 batch_size=8)

  1%|          | 8/1000 [00:00<00:20, 49.08it/s]



100%|██████████| 1000/1000 [00:15<00:00, 65.77it/s]


In [18]:
print(len([l for l in predictions if l==0]))
print(len([l for l in predictions if l==1]))
print(len([l for l in predictions if l==2]))

480
520
0
