In [None]:
# !wget https://storage.googleapis.com/xlnet/released_models/cased_L-12_H-768_A-12.zip
# !unzip cased_L-12_H-768_A-12.zip

In [None]:
import sentencepiece as spm
from prepro_utils import preprocess_text, encode_ids

# Load the SentencePiece model stored in the XLNet checkpoint directory.
# The resulting processor is later handed to Model as its tokenizer.
sp_model = spm.SentencePieceProcessor()
sp_model.Load('xlnet_cased_L-12_H-768_A-12/spiece.model')

In [13]:
from prepro_utils import preprocess_text, encode_ids

def tokenize_fn(text, sp_model):
    """
    Preprocess ``text`` without lower-casing and encode it with ``sp_model``.

    Parameters
    ----------
    text : str
    sp_model : sentencepiece processor forwarded to ``encode_ids``

    Returns
    -------
    whatever ``encode_ids`` returns (presumably a list of piece ids)
    """
    return encode_ids(sp_model, preprocess_text(text, lower = False))

In [4]:
# Segment ids fed to the model alongside the token ids.
SEG_ID_A, SEG_ID_B, SEG_ID_CLS, SEG_ID_SEP, SEG_ID_PAD = range(5)

# Special pieces mapped to their ids, numbered consecutively from 0
# (presumably matching the ids inside spiece.model -- verify if the
# vocabulary is ever swapped).
special_symbols = {
    symbol: index
    for index, symbol in enumerate(
        [
            "<unk>",
            "<s>",
            "</s>",
            "<cls>",
            "<sep>",
            "<pad>",
            "<mask>",
            "<eod>",
            "<eop>",
        ]
    )
}

VOCAB_SIZE = 32000

# Shorthand for the special ids used elsewhere in the notebook.
UNK_ID, CLS_ID, SEP_ID, MASK_ID, EOD_ID = (
    special_symbols[key]
    for key in ("<unk>", "<cls>", "<sep>", "<mask>", "<eod>")
)

In [10]:
import collections
import re

def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    """
    Match graph variables against the variables stored in a checkpoint.

    Parameters
    ----------
    tvars : iterable of variables exposing a ``name`` attribute
    init_checkpoint : str
        checkpoint path handed to ``tf.train.list_variables``

    Returns
    -------
    (assignment_map, initialized_variable_names) :
        ``assignment_map`` is an OrderedDict from checkpoint variable name to
        the graph variable with the same name (output suffix such as ``:0``
        stripped); ``initialized_variable_names`` is a dict flagging every
        matched name, with and without the ``:0`` suffix, with ``1``.
    """
    output_suffix = re.compile('^(.*):\\d+$')

    # Index the graph variables by their name without the ":<n>" suffix.
    graph_variables = collections.OrderedDict()
    for variable in tvars:
        key = variable.name
        matched = output_suffix.match(key)
        if matched is not None:
            key = matched.group(1)
        graph_variables[key] = variable

    assignment_map = collections.OrderedDict()
    initialized_variable_names = {}
    for entry in tf.train.list_variables(init_checkpoint):
        ckpt_name, _ckpt_shape = entry[0], entry[1]
        if ckpt_name in graph_variables:
            assignment_map[ckpt_name] = graph_variables[ckpt_name]
            initialized_variable_names[ckpt_name] = 1
            initialized_variable_names[ckpt_name + ':0'] = 1

    return (assignment_map, initialized_variable_names)

In [17]:
import xlnet as xlnet_lib
import tensorflow as tf
import numpy as np

class Model:
    """
    Thin inference wrapper around an XLNet graph.

    Builds the graph and a session on construction and exposes two helpers:
    ``vectorize`` (pooled sentence vectors) and ``attention`` (per-word
    attention weights). Checkpoint weights are NOT restored here; the caller
    is expected to run ``model._saver.restore(model._sess, checkpoint)``.
    """

    def __init__(self, xlnet_config, tokenizer, checkpoint, pool_mode = 'last'):
        """
        Parameters
        ----------
        xlnet_config : configuration object passed to ``xlnet_lib.XLNetModel``
        tokenizer : sentencepiece processor used by ``xlnet_tokenization``
        checkpoint : str
            checkpoint path, only used to decide which variables the saver
            should cover
        pool_mode : str, optional (default='last')
            forwarded to ``XLNetModel.get_pooled_out``
        """

        # Dropout is disabled, so the graph is deterministic even though it
        # is built with is_training = True.
        kwargs = dict(
            is_training = True,
            use_tpu = False,
            use_bfloat16 = False,
            dropout = 0.0,
            dropatt = 0.0,
            init = 'normal',
            init_range = 0.1,
            init_std = 0.05,
            clamp_len = -1,
        )

        xlnet_parameters = xlnet_lib.RunConfig(**kwargs)

        self._tokenizer = tokenizer
        _graph = tf.Graph()
        with _graph.as_default():
            # Placeholders are fed as [batch, length]; they are transposed
            # before being handed to XLNetModel.
            self.X = tf.placeholder(tf.int32, [None, None])
            self.segment_ids = tf.placeholder(tf.int32, [None, None])
            self.input_masks = tf.placeholder(tf.float32, [None, None])

            xlnet_model = xlnet_lib.XLNetModel(
                xlnet_config = xlnet_config,
                run_config = xlnet_parameters,
                input_ids = tf.transpose(self.X, [1, 0]),
                seg_ids = tf.transpose(self.segment_ids, [1, 0]),
                input_mask = tf.transpose(self.input_masks, [1, 0]),
            )

            self.logits = xlnet_model.get_pooled_out(pool_mode, True)
            self._sess = tf.InteractiveSession()
            self._sess.run(tf.global_variables_initializer())

            # Restrict the saver to variables that exist in both the graph
            # and the checkpoint.
            tvars = tf.trainable_variables()
            assignment_map, _ = get_assignment_map_from_checkpoint(
                tvars, checkpoint
            )
            self._saver = tf.train.Saver(var_list = assignment_map)

            # Collect every attention softmax output, in graph-definition
            # order (presumably one per layer -- TODO confirm).
            g = tf.get_default_graph()
            attentions = [
                n.name
                for n in g.as_graph_def().node
                if 'rel_attn/Softmax' in n.name
            ]
            self.attention_nodes = [
                g.get_tensor_by_name('%s:0' % (a)) for a in attentions
            ]

    @staticmethod
    def _as_string_list(strings):
        """Return ``strings`` as a non-empty list of str, or raise ValueError."""
        if isinstance(strings, str):
            return [strings]
        if (
            not isinstance(strings, list)
            or not strings
            or not all(isinstance(s, str) for s in strings)
        ):
            raise ValueError('input must be a list of strings or a string')
        return strings

    def vectorize(self, strings):
        """
        Vectorize string inputs using the xlnet pooled output.

        Parameters
        ----------
        strings : str / list of str

        Returns
        -------
        array: vectorized strings, one row per input string

        Raises
        ------
        ValueError
            if ``strings`` is not a string or a non-empty list of strings.
        """

        strings = self._as_string_list(strings)

        input_ids, input_masks, segment_ids, _ = xlnet_tokenization(
            self._tokenizer, strings
        )
        return self._sess.run(
            self.logits,
            feed_dict = {
                self.X: input_ids,
                self.segment_ids: segment_ids,
                self.input_masks: input_masks,
            },
        )

    def attention(self, strings, method = 'last', **kwargs):
        """
        Get attention string inputs from xlnet attention.

        Parameters
        ----------
        strings : str / list of str
        method : str, optional (default='last')
            Attention layer supported. Allowed values:

            * ``'last'`` - attention from last layer.
            * ``'first'`` - attention from first layer.
            * ``'mean'`` - average attentions from all layers.

        Returns
        -------
        list: for every input string, a list of ``(word, weight)`` tuples

        Raises
        ------
        ValueError
            if ``strings`` is not a string or a non-empty list of strings,
            or if ``method`` is not one of the supported values.
        """

        strings = self._as_string_list(strings)

        method = method.lower()
        if method not in ['last', 'first', 'mean']:
            raise ValueError(
                "method not supported, only support ['last', 'first', 'mean']"
            )

        input_ids, input_masks, segment_ids, s_tokens = xlnet_tokenization(
            self._tokenizer, strings
        )
        # NOTE(review): the token strings are right-padded here while
        # xlnet_tokenization left-pads the ids -- confirm that tokens and
        # weights line up for batches of unequal length.
        maxlen = max([len(s) for s in s_tokens])
        s_tokens = padding_sequence(s_tokens, maxlen, pad_int = '<cls>')
        attentions = self._sess.run(
            self.attention_nodes,
            feed_dict = {
                self.X: input_ids,
                self.segment_ids: segment_ids,
                self.input_masks: input_masks,
            },
        )

        # NOTE(review): the axis handling below depends on the layout of the
        # attention softmax tensor inside xlnet -- verify against the library.
        if method == 'first':
            cls_attn = np.transpose(attentions[0][:, 0], (1, 0, 2))
        elif method == 'last':
            cls_attn = np.transpose(attentions[-1][:, 0], (1, 0, 2))
        else:
            cls_attn = np.transpose(
                np.mean(attentions, axis = 0).mean(axis = 1), (1, 0, 2)
            )

        # Average over axis 1, then normalise every row to sum to one.
        cls_attn = np.mean(cls_attn, axis = 1)
        total_weights = np.sum(cls_attn, axis = -1, keepdims = True)
        attn = cls_attn / total_weights

        # Pair pieces with weights (zip stops at the shorter of the two) and
        # merge sub-word pieces back into whole words.
        output = []
        for i in range(attn.shape[0]):
            output.append(
                merge_sentencepiece_tokens(list(zip(s_tokens[i], attn[i])))
            )
        return output

In [7]:
def merge_sentencepiece_tokens(paired_tokens, weighted = True):
    """
    Merge sentencepiece sub-word pieces back into whole words.

    A piece starts a new word when it begins with '▁' or is a special token;
    every other piece is appended to the word before it and the word's weight
    becomes the mean of the weights of its pieces.

    Parameters
    ----------
    paired_tokens : list of (piece, weight) tuples
    weighted : bool, optional (default=True)
        if True, the returned weights are rescaled to sum to one.

    Returns
    -------
    list of (word, weight) tuples with '<cls>', '<sep>' and '<pad>' removed
    """
    special_tokens = ['<cls>', '<sep>', '<pad>']

    def _starts_word(token):
        # '<pad>' is a boundary too, so it is dropped later instead of being
        # glued onto the preceding word.
        return token.startswith('▁') or token in special_tokens

    new_paired_tokens = []
    n_tokens = len(paired_tokens)
    i = 0
    while i < n_tokens:
        current_token, current_weight = paired_tokens[i]

        # A leading continuation piece has nothing to attach to, so it opens
        # a word of its own instead of popping from an empty list.
        if _starts_word(current_token) or not new_paired_tokens:
            new_paired_tokens.append((current_token, current_weight))
            i = i + 1
            continue

        merged_token, previous_weight = new_paired_tokens.pop()
        merged_weight = [previous_weight]
        # Bounds are checked before every read, so a sequence that ends on a
        # continuation piece no longer runs past the end of the list.
        while i < n_tokens:
            current_token, current_weight = paired_tokens[i]
            if _starts_word(current_token):
                break
            merged_token = merged_token + current_token.replace('▁', '')
            merged_weight.append(current_weight)
            i = i + 1
        new_paired_tokens.append((merged_token, np.mean(merged_weight)))

    kept = [
        (token.replace('▁', ''), weight)
        for token, weight in new_paired_tokens
        if token not in special_tokens
    ]
    words = [word for word, _ in kept]
    weights = [weight for _, weight in kept]
    if weighted:
        weights = np.array(weights)
        weights = weights / np.sum(weights)
    return list(zip(words, weights))

In [8]:
def padding_sequence(seq, maxlen, padding = 'post', pad_int = 0):
    """
    Pad every sequence in ``seq`` up to ``maxlen`` items.

    Parameters
    ----------
    seq : iterable of sequences
    maxlen : int
        target length; sequences already at least this long are returned
        unpadded (they are never truncated).
    padding : str, optional (default='post')
        ``'post'`` appends the filler, ``'pre'`` prepends it.
    pad_int : object, optional (default=0)
        filler value.

    Returns
    -------
    list of lists

    Raises
    ------
    ValueError
        if ``padding`` is neither ``'post'`` nor ``'pre'``. Previously such a
        value silently produced an empty result.
    """
    if padding not in ('post', 'pre'):
        raise ValueError("padding must be 'post' or 'pre'")

    padded_seqs = []
    for s in seq:
        items = list(s)
        # A negative count yields an empty filler, i.e. no padding at all.
        filler = [pad_int] * (maxlen - len(items))
        if padding == 'post':
            padded_seqs.append(items + filler)
        else:
            padded_seqs.append(filler + items)
    return padded_seqs


def xlnet_tokenization(tokenizer, texts):
    """
    Turn raw texts into left-padded XLNet model inputs.

    Every text is encoded with ``tokenize_fn`` and followed by the ``<sep>``
    and ``<cls>`` ids. All sequences are then padded at the front up to the
    longest one in the batch.

    Parameters
    ----------
    tokenizer : sentencepiece processor
    texts : list of str

    Returns
    -------
    (input_ids, input_masks, segment_ids, s_tokens) :
        ``input_ids`` are padded with 0, ``input_masks`` hold 0 for real
        positions and 1 for padding, ``segment_ids`` are padded with
        ``SEG_ID_PAD``; ``s_tokens`` are the unpadded piece strings.
    """
    batch_ids, batch_masks, batch_segments, batch_pieces = [], [], [], []
    for text in texts:
        ids = list(tokenize_fn(text, tokenizer))
        # the text and its <sep> share segment A, <cls> has its own segment
        segments = [SEG_ID_A] * (len(ids) + 1) + [SEG_ID_CLS]
        ids = ids + [SEP_ID, CLS_ID]

        batch_ids.append(ids)
        batch_masks.append([0] * len(ids))
        batch_segments.append(segments)
        batch_pieces.append([tokenizer.IdToPiece(i) for i in ids])

    longest = max([len(ids) for ids in batch_ids])
    return (
        padding_sequence(batch_ids, longest, padding = 'pre'),
        padding_sequence(batch_masks, longest, padding = 'pre', pad_int = 1),
        padding_sequence(
            batch_segments, longest, padding = 'pre', pad_int = SEG_ID_PAD
        ),
        batch_pieces,
    )

In [11]:
# Build the XLNet graph from the pretrained configuration.
xlnet_checkpoint = 'xlnet_cased_L-12_H-768_A-12/xlnet_model.ckpt'
xlnet_config = xlnet_lib.XLNetConfig(
    json_path = 'xlnet_cased_L-12_H-768_A-12/xlnet_config.json'
)
model = Model(xlnet_config, sp_model, xlnet_checkpoint, pool_mode = 'last')

# Model.__init__ only initialises the variables; the pretrained weights are
# restored explicitly here.
model._saver.restore(model._sess, xlnet_checkpoint)

W0806 00:38:22.101047 140372522661696 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.


In [15]:
# Sanity check: one input sentence should give one row of pooled output.
model.vectorize('i love u').shape

(1, 768)

In [18]:
# Per-word weights taken from the first attention layer.
model.attention('i love u sooo much!', method = 'first')

[[('i', 0.2029094),
  ('love', 0.17279622),
  ('u', 0.21883217),
  ('sooo', 0.19822793),
  ('much!', 0.20723426)]]

In [19]:
# Per-word weights taken from the last attention layer.
model.attention('i love u sooo much!', method = 'last')

[[('i', 0.5867717),
  ('love', 0.08966514),
  ('u', 0.23113684),
  ('sooo', 0.057987895),
  ('much!', 0.03443839)]]

In [20]:
# Per-word weights averaged over all attention layers.
model.attention('i love u sooo much!', method = 'mean')

[[('i', 0.22920714),
  ('love', 0.2319069),
  ('u', 0.18129185),
  ('sooo', 0.13547097),
  ('much!', 0.22212313)]]