In [1]:
%tensorflow_version 1.x 
import re
import os
import sys
import json
import nltk
import logging
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub 
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import Callback
from scipy.stats import spearmanr, pearsonr
from glob import glob
# Make sure the NLTK 'punkt' package is available locally (downloaded on first run).
nltk.download('punkt')
# Restrict TensorFlow logging to messages of level ERROR.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

TensorFlow 1.x selected.
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [0]:
# # Download the SNLI and MultiNLI datasets and the pretrained BERT model
# !wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
# !unzip uncased_L-12_H-768_A-12.zip
# !wget https://nlp.stanford.edu/projects/snli/snli_1.0.zip
# !unzip snli_1.0.zip
# !wget https://www.nyu.edu/projects/bowman/multinli/multinli_1.0.zip
# !unzip multinli_1.0.zip

In [0]:
# !git clone https://github.com/gaphex/bert_experimental
# !git clone https://github.com/brmson/dataset-sts

# Put the two cloned repositories at the front of the import search path so the
# imports that follow in this cell can be resolved from them.
sys.path.insert(0, 'bert_experimental')
sys.path.insert(0, 'dataset-sts/pysts')

from bert_experimental.finetuning.text_preprocessing import build_preprocessor
from bert_experimental.finetuning.bert_layer import BertLayer
from bert_experimental.finetuning.modeling import BertConfig, BertModel, build_bert_module

from loader import load_sts, load_sick2014

In [0]:
from collections import Counter, defaultdict

def load_snli(fpaths):
  """Read sentence pairs and gold labels from one or more JSON-lines files.

  Args:
    fpaths: a single path or a sequence of paths. Every non-blank line of each
      file must be a JSON object with the keys 'sentence1', 'sentence2' and
      'gold_label'.

  Returns:
    Three parallel lists ``(sa, sb, lb)`` holding the first sentences, the
    second sentences and the gold labels, in file order.

  Raises:
    json.JSONDecodeError: if a non-blank line is not valid JSON.
    KeyError: if a record lacks one of the three expected keys.
  """
  sa, sb, lb = [], [], []
  for fpath in np.atleast_1d(fpaths):
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(fpath, encoding='utf-8') as fi:
      for line in fi:
        if not line.strip():
          # Tolerate blank lines (e.g. a trailing newline at the end of a file).
          continue
        sample = json.loads(line)
        sa.append(sample['sentence1'])
        sb.append(sample['sentence2'])
        lb.append(sample['gold_label'])
  return sa, sb, lb

def prepare_snli(sa, sb, lb):
  """Group sentence pairs by anchor and keep only anchors usable for triplets.

  An anchor (a distinct first sentence) is kept only when its paired sentences
  cover both the 'entailment' and the 'contradiction' label, so that a positive
  and a negative sample can always be drawn for it.

  Args:
    sa: anchor sentences.
    sb: paired sentences, parallel to ``sa``.
    lb: labels, parallel to ``sa``.

  Returns:
    A dict mapping consecutive integer ids (starting at 0) to a
    ``defaultdict(list)`` with the key 'anchor' (a one-element list holding the
    anchor sentence) and one key per label seen for that anchor, each holding
    the list of paired sentences carrying that label.

  Side effects:
    Prints the number of kept and skipped anchors.
  """
  classes = {'entailment', 'contradiction'}

  anc_to_pairs = defaultdict(list)
  for xa, xb, y in zip(sa, sb, lb):
    anc_to_pairs[xa].append((xb, y))

  filtered = {}
  skipped = 0
  for anchor, payload in anc_to_pairs.items():
    labels = {label for _, label in payload}
    if not classes <= labels:
      skipped += 1
      continue
    # Create the entry only after the anchor has passed the filter. Creating it
    # up front left a stale, anchor-only entry in the result whenever the last
    # anchor was skipped, which breaks sampling positives/negatives later on.
    entry = defaultdict(list)
    entry['anchor'].append(anchor)
    for text, label in payload:
      entry[label].append(text)
    filtered[len(filtered)] = entry

  print('loaded:{}\n Skipped:{}'.format(len(filtered), skipped))
  return filtered

In [5]:
# JSON-lines files used for training and for evaluation; both corpora are read
# with the same loader.
train_data = ["./snli_1.0/snli_1.0_train.jsonl", "./multinli_1.0/multinli_1.0_train.jsonl"]
test_data = ["./snli_1.0/snli_1.0_test.jsonl", "./multinli_1.0/multinli_1.0_dev_matched.jsonl"]

# Parallel lists of first sentences, second sentences and gold labels.
tr_a, tr_b, tr_l = load_snli(train_data)
ts_a, ts_b, ts_l = load_snli(test_data)

# Group the pairs by anchor sentence; prepare_snli prints how many anchors were
# kept and how many were skipped.
fd_tr = prepare_snli(tr_a, tr_b, tr_l)
fd_ts = prepare_snli(ts_a, ts_b, ts_l)

loaded:277277
 Skipped:1603
loaded:5853
 Skipped:804


For training the model we will sample triplets, consisting of an anchor, a positive 
sample and a negative sample.
To handle complex batch generation logic we use the following code:

In [0]:
class TripletGenerator:
    """Endless source of (anchor, positive, negative) sentence triplets.

    ``datadict`` maps an anchor id to a mapping whose keys are 'anchor',
    'entailment', 'contradiction' and optionally 'neutral'; every value is a
    list of sentences. The ready-to-use generator is exposed as
    ``self._generator``.

    Args:
        datadict: the id -> sentences mapping described above.
        hard_frac: fraction of each batch whose negative is a contradiction of
            the same anchor ("hard" negative); the rest of the batch gets a
            sentence that belongs to another, randomly chosen anchor.
        batch_size: number of triplets per batch. Anchors are drawn without
            replacement, so it must not exceed the number of anchors.
    """

    # Number of rejection-sampling attempts made by ``get_random`` before it
    # switches to an explicit filter. Keeps the common case cheap for large
    # sequences while guaranteeing that the call terminates.
    _MAX_REJECTIONS = 32

    def __init__(self, datadict, hard_frac = 0.5, batch_size=256):
        self.datadict = datadict
        self._anchor_idx = np.array(list(self.datadict.keys()))
        self._hard_frac = hard_frac
        self._generator = self.generate_batch(batch_size)

    def generate_batch(self, size):
        """Yield ``([anchors, positives, negatives], labels)`` batches forever.

        ``labels`` is an all-ones array of shape ``(size, 1)``; the three
        sample arrays each hold ``size`` sentences.

        Raises:
            ValueError: from ``np.random.choice`` when ``size`` exceeds the
                number of anchors (sampling is done without replacement).
        """
        while True:
            # The first `hards` triplets of the batch use hard negatives.
            hards = int(size*self._hard_frac)
            anchor_ids = np.random.choice(self._anchor_idx, size, replace=False)

            anchors = self.get_anchors(anchor_ids)
            positives = self.get_positives(anchor_ids)
            negatives = np.hstack([self.get_hard_negatives(anchor_ids[:hards]),
                                   self.get_random_negatives(anchor_ids[hards:])])
            labels = np.ones((size,1))

            assert len(anchors) == len(positives) == len(negatives) == len(labels) == size

            yield [anchors, positives, negatives], labels

    def get_anchors(self, anchor_ids):
        """Return the anchor sentence of every id in ``anchor_ids``."""
        return self.get_samples_from_ids(anchor_ids, ['anchor'])

    def get_positives(self, anchor_ids):
        """Return one random entailment sentence per id in ``anchor_ids``."""
        return self.get_samples_from_ids(anchor_ids, ['entailment'])

    def get_hard_negatives(self, anchor_ids):
        """Return one random contradiction sentence per id in ``anchor_ids``."""
        return self.get_samples_from_ids(anchor_ids, ['contradiction'])

    def get_random_negatives(self, anchor_ids):
        """Return, per id, a random sentence attached to a *different* anchor.

        The sentence is drawn from whichever of the 'contradiction', 'neutral'
        and 'entailment' lists the other anchor actually has.
        """
        samples = []
        classes = ['contradiction', 'neutral','entailment']
        for anchor_id in anchor_ids:
            other_anchor_id = self.get_random(self._anchor_idx, anchor_id)
            avail_classes = list(set(self.datadict[other_anchor_id].keys()) & set(classes))
            sample_class = self.get_random(avail_classes)
            sample = self.get_random(self.datadict[other_anchor_id][sample_class])
            samples.append(sample)
        return np.array(samples)

    def get_samples_from_ids(self, anchor_ids, classes):
        """Return, per id, a random sentence from a randomly chosen class in ``classes``."""
        samples = []
        for anchor_id in anchor_ids:
            sample_class = self.get_random(classes)
            sample = self.get_random(self.datadict[anchor_id][sample_class])
            samples.append(sample)
        return np.array(samples)

    @staticmethod
    def get_random(seq, exc=None):
        """Pick a random element of ``seq``, avoiding ``exc`` whenever possible.

        A one-element ``seq`` returns that element even if it equals ``exc``.
        In the same spirit, when *every* element equals ``exc`` a random
        element is returned instead of looping forever.

        Raises:
            ValueError: from ``np.random.choice`` when ``seq`` is empty.
        """
        if len(seq) == 1:
            return seq[0]
        if exc is None:
            return np.random.choice(seq)

        # Rejection sampling is O(1) in the expected case, which matters when
        # `seq` is the full array of anchor ids.
        for _ in range(TripletGenerator._MAX_REJECTIONS):
            selected = np.random.choice(seq)
            if selected != exc:
                return selected

        # Practically only reached when (almost) all candidates equal `exc`:
        # filter explicitly so the call is guaranteed to terminate.
        allowed = [item for item in seq if item != exc]
        return np.random.choice(allowed if allowed else seq)

Batch anchor IDs are selected randomly from all available IDs.
Anchor samples are the anchor sentences stored under those IDs.
Positive samples are drawn from the entailment sentences of the same IDs.
Negative samples are drawn from the contradiction sentences of the same IDs. These may be considered hard negative samples, because they are often semantically similar to their anchors. To reduce overfitting we mix them with random negative samples drawn from the sentences of another, randomly chosen ID.

We can frame the problem of learning a measure of sentence similarity as a ranking problem. Suppose we have a corpus of k paraphrase sentence pairs x and y and want to learn a function that
estimates whether y is a paraphrase(解释;释义;意译) of x or not.

For a given x we have a single positive sample y and k-1 negative samples y_k. The probability that y is the paraphrase of x can be written as

$$P(y|x) = \frac{P(x,y)}{\sum_{k} P(x, y_k)}$$

The joint probability of P(x,y) is estimated using a scoring function, S

P(x,y) \approx e^{S(x,y)}

We will be minimizing the negative log probability of our data.
So, for a batch of K triplets, the loss can be written as

$$L(x,y,\theta) = -\frac{1}{K} \sum_{i=1}^{K} \log P(y_i|x_i) \approx 
\frac{1}{K} \sum_{i=1}^{K} \left[ \log \sum_{k} e^{S(x_i,y_k)} - S(x_i,y_i) \right]$$


In [0]:
def softmax_loss(vectors):
  """Per-sample ranking loss for a batch of (anchor, positive, negative) encodings.

  For every anchor the positive score is its dot product with its own positive,
  and the negative score is the log-sum-exp of its dot products with *all*
  negatives of the batch (``anc @ neg^T``). The loss is
  ``relu(negative_score - positive_score)``.

  Args:
    vectors: sequence ``(anc, pos, neg)`` of tensors. The matmul requires them
      to be at least rank 2; build_model passes (batch, dim) encodings.

  Returns:
    Tensor with the same leading dimensions as ``anc`` and a trailing
    dimension of size 1, holding one non-negative loss value per anchor.
  """
  anc, pos, neg = vectors
  pos_sim = tf.reduce_sum((anc*pos), axis=-1, keepdims=True)
  neg_mul = tf.matmul(anc, neg, transpose_b=True)
  # reduce_logsumexp subtracts the row maximum before exponentiating, so large
  # scores no longer overflow to inf/NaN as log(sum(exp(x))) did.
  neg_sim = tf.math.reduce_logsumexp(neg_mul, axis=-1, keepdims=True)
  loss = tf.nn.relu(neg_sim - pos_sim)
  return loss

In [0]:
# Directory of the unpacked pretrained BERT checkpoint (Colab form parameter).
BERT_DIR = "/content/uncased_L-12_H-768_A-12/" #@param {type:"string"}

# Build the module from the config, vocabulary and checkpoint files found in
# BERT_DIR. The last argument, "bert_module", is the same value that is later
# passed to build_model as module_path.
# NOTE(review): presumably the module is written to ./bert_module - confirm
# against bert_experimental's build_bert_module.
build_bert_module(BERT_DIR+"bert_config.json",
                  BERT_DIR+"vocab.txt",
                  BERT_DIR+"bert_model.ckpt", 
                  "bert_module")

The model has three inputs for the anchor, positive and negative samples. A BERT layer with a mean pooling operation is used as a shared text encoder. Text preprocessing is handled automatically by the layer. 

For convenience, we create 3 keras models: enc_model for encoding sentences, sim_model for computing the similarity between sentence pairs and trn_model for training. All models share the same weights.

In [0]:
def dot_product(tensor_pair):
  """Inner product of two tensors along the last axis.

  Args:
    tensor_pair: sequence of exactly two tensors that can be multiplied
      element-wise.

  Returns:
    The element-wise product summed over the last axis; that axis is kept
    with size 1.
  """
  left, right = tensor_pair
  products = left * right
  return tf.reduce_sum(products, axis=-1, keepdims=True)

def consine_similarity(tensor_pair):
  """Cosine similarity of two tensors along the last axis.

  Both tensors are L2-normalised along the last axis before their element-wise
  product is summed over that axis.

  Args:
    tensor_pair: sequence of exactly two tensors.

  Returns:
    The similarity with the last axis kept at size 1.
  """
  unit_u, unit_v = (tf.math.l2_normalize(t, axis=-1) for t in tensor_pair)
  return tf.reduce_sum(unit_u * unit_v, axis=-1, keepdims=True)

def mean_loss(y_true, y_pred):
  """Keras-style loss that averages ``y_pred`` over all of its elements.

  Args:
    y_true: target tensor; it is multiplied by 0, so its values do not affect
      the result.
    y_pred: tensor of values to average.

  Returns:
    Scalar tensor holding the mean of ``y_pred``.
  """
  # y_true stays in the expression but contributes nothing numerically.
  return tf.reduce_mean(y_pred - 0 * y_true)

def build_model(module_path, seq_len=24, tune_lr=6, loss=softmax_loss):
  """Create the training, encoding and similarity models around one shared encoder.

  Args:
    module_path: forwarded to BertLayer as its first argument.
    seq_len: forwarded to BertLayer as its second argument.
    tune_lr: forwarded to BertLayer as ``n_tune_layers``.
    loss: function applied (inside a Lambda layer) to the list
      ``[anchor, positive, negative]`` of scaled encodings.

  Returns:
    dict with the keys "enc_model" (anchor text -> scaled encoding),
    "sim_model" (anchor and positive text -> cosine similarity) and
    "trn_model" (three texts -> loss; compiled with Adam(1e-5) and mean_loss).
    All three are built on the same BertLayer instance.

  Side effects:
    Prints the summary of the training model.
  """
  anchor_in, positive_in, negative_in = [
      tf.keras.Input(shape=(1,), dtype=tf.string, name=input_name)
      for input_name in ('input_anchor', 'input_pos', 'input_neg')
  ]

  # One encoder instance is applied to all three inputs.
  encoder = BertLayer(module_path, seq_len, n_tune_layers=tune_lr, do_preprocessing=True,
                      verbose=False, pooling='mean', trainable=True, tune_embeddings=False)

  scale = 0.5  # avoid Nan loss
  anchor_vec, positive_vec, negative_vec = [
      scale * encoder(text_in) for text_in in (anchor_in, positive_in, negative_in)
  ]

  loss_out = tf.keras.layers.Lambda(loss, name='loss')([anchor_vec, positive_vec, negative_vec])
  sim_out = tf.keras.layers.Lambda(consine_similarity, name='sim')([anchor_vec, positive_vec])

  trn_model = tf.keras.models.Model(inputs=[anchor_in, positive_in, negative_in], outputs=[loss_out])
  enc_model = tf.keras.models.Model(inputs=anchor_in, outputs=[anchor_vec])
  sim_model = tf.keras.models.Model(inputs=[anchor_in, positive_in], outputs=[sim_out])

  # The training model's output already *is* the loss, so mean_loss only averages it.
  trn_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
                    loss=mean_loss, metrics=[])
  trn_model.summary()

  return {
      "enc_model": enc_model,
      "sim_model": sim_model,
      "trn_model": trn_model,
  }

In [0]:
class RankCorrCallback(Callback):
    """Keras callback that reports Spearman/Pearson correlation on sentence-pair data.

    At the start of every epoch and once more when training ends, the
    similarity model scores all loaded sentence pairs and the correlation with
    the gold labels is printed. When the Spearman coefficient improves,
    ``savemodel``'s weights are written to ``savepath`` (if both are given).

    Args:
        loader: callable ``loader(filepath) -> (sa, sb, lb)``. Each item of
            ``sa``/``sb`` is joined with single spaces, so it is presumably a
            sequence of tokens - confirm against the loader used.
        filepaths: non-empty iterable of files handed to ``loader``.
        name: prefix of the printed metric names.
        verbose: forwarded to ``sim_model.predict``.
        sim_model: model that maps ``[sentences_a, sentences_b]`` to scores.
        savemodel: optional model whose weights are saved on improvement.
        savepath: optional destination for ``savemodel.save_weights``.

    Raises:
        ValueError: if ``filepaths`` is empty.
    """

    def __init__(self, loader, filepaths, name=None, verbose=False,
                 sim_model=None, savemodel=None, savepath=None):

        self.savemodel = savemodel
        self.savepath = savepath
        self.sim_model = sim_model
        self.loader = loader
        self.verbose = verbose
        self.name = name

        self.samples, self.labels = self.load_datasets(filepaths)
        # Best coefficient seen so far per metric; missing entries default to 0,
        # so only positive coefficients are ever reported as a new best.
        self.best = defaultdict(int)
        super(RankCorrCallback, self).__init__()

    def load_datasets(self, filepaths):
        """Load and concatenate all files.

        Returns:
            ``([sentences_a, sentences_b], labels)`` with three parallel lists.

        Raises:
            ValueError: if ``filepaths`` is empty.
        """
        filepaths = list(filepaths)
        if not filepaths:
            # Typically a glob pattern that matched nothing. Fail here with a
            # clear message instead of obscurely at the first evaluation.
            raise ValueError(
                f"RankCorrCallback({self.name}): no dataset files were given; "
                "check that the file pattern matches existing files")

        _xa, _xb, _y = [], [], []
        for filepath in filepaths:
            sa, sb, lb = self.loader(filepath)
            _xa += self.join_by_whitespace(sa)
            _xb += self.join_by_whitespace(sb)
            _y += list(lb)
        return [_xa, _xb], _y

    @staticmethod
    def join_by_whitespace(list_of_str):
        """Join every item of ``list_of_str`` into one space-separated string."""
        return [" ".join(s) for s in list_of_str]

    def on_epoch_begin(self, epoch, logs=None):
        """Score all pairs, print both correlations and save on a Spearman improvement."""
        pred = self.sim_model.predict(self.samples, batch_size=128,
                                      verbose=self.verbose).reshape(-1,)

        for metric, func in [("spearman_r", spearmanr),("pearson_r", pearsonr)]:
            coef, _ = func(self.labels, pred)
            coef = np.round(coef, 4)

            metric_name = f"{self.name}_{metric}"
            message = f"{metric_name} = {coef}"
            if coef > self.best[metric_name]:
                self.best[metric_name] = coef
                message = "*** New best: " + message
                # Only the Spearman coefficient decides whether weights are saved.
                if self.savemodel and self.savepath and metric == "spearman_r":
                    self.savemodel.save_weights(self.savepath)

            print(message)

    def on_train_end(self, logs=None):
        """Run one final evaluation with the weights reached at the end of training."""
        self.on_epoch_begin(None)

In [11]:
# Build the encoder, similarity and training models; tune_lr=4 is forwarded to
# the BERT layer as n_tune_layers.
model_dict = build_model(module_path="bert_module", tune_lr=4, loss=softmax_loss)

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_anchor (InputLayer)       [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_pos (InputLayer)          [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_neg (InputLayer)          [(None, 1)]          0                                            
__________________________________________________________________________________________________
bert_layer (BertLayer)          (None, 768)          109482240   input_anchor[0][0]               
                                                                 input_pos[0][0]              

In [0]:
# Fraction of each batch that uses hard (same-anchor contradiction) negatives,
# and the number of triplets per batch.
HFRAC = 0.5
BSIZE = 200

enc_model = model_dict["enc_model"]
sim_model = model_dict["sim_model"]
trn_model = model_dict["trn_model"]

# Endless triplet generators for training and validation.
tr_gen = TripletGenerator(fd_tr, hard_frac=HFRAC, batch_size=BSIZE)
ts_gen = TripletGenerator(fd_ts, hard_frac=HFRAC, batch_size=BSIZE)

# Correlation callbacks. Only the STS callback is given a model and a path, so
# only an improved STS Spearman coefficient saves the encoder weights.
clb_sts = RankCorrCallback(load_sts, glob("./dataset-sts/data/sts/semeval-sts/all/*test*.tsv"), name='STS',
                               sim_model=sim_model, savemodel=enc_model, savepath="encoder_en.h5")
clb_sick = RankCorrCallback(load_sick2014, glob("./dataset-sts/data/sts/sick2014/SICK_test_annotated.txt"), 
                                name='SICK', sim_model=sim_model)

callbacks = [clb_sts, clb_sick]

In [0]:
# Train for 10 epochs of 256 generated batches each, validating on 32 generated
# batches. The callbacks evaluate at the start of every epoch and once more
# when training ends.
trn_model.fit_generator(
    tr_gen._generator, validation_data=ts_gen._generator,
    steps_per_epoch=256, validation_steps=32, epochs=10, callbacks=callbacks)

*** New best: STS_spearman_r = 0.5418
*** New best: STS_pearson_r = 0.5475
*** New best: SICK_spearman_r = 0.5799
*** New best: SICK_pearson_r = 0.6069
Epoch 1/10
 16/256 [>.............................] - ETA: 9:56 - loss: 2.0933 

In [0]:
from tensorflow.keras.utils import plot_model
# Draw the layer graph of the training model.
plot_model(trn_model)