In [1]:
# !wget https://storage.googleapis.com/xlnet/released_models/cased_L-12_H-768_A-12.zip -O xlnet.zip
# !unzip xlnet.zip

In [2]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [3]:
import sentencepiece as spm
from prepro_utils import preprocess_text, encode_ids

sp_model = spm.SentencePieceProcessor()
sp_model.Load('xlnet_cased_L-12_H-768_A-12/spiece.model')

def tokenize_fn(text):
    text = preprocess_text(text, lower= False)
    return encode_ids(sp_model, text)

In [4]:
SEG_ID_A   = 0
SEG_ID_B   = 1
SEG_ID_CLS = 2
SEG_ID_SEP = 3
SEG_ID_PAD = 4

special_symbols = {
    "<unk>"  : 0,
    "<s>"    : 1,
    "</s>"   : 2,
    "<cls>"  : 3,
    "<sep>"  : 4,
    "<pad>"  : 5,
    "<mask>" : 6,
    "<eod>"  : 7,
    "<eop>"  : 8,
}

VOCAB_SIZE = 32000
UNK_ID = special_symbols["<unk>"]
CLS_ID = special_symbols["<cls>"]
SEP_ID = special_symbols["<sep>"]
MASK_ID = special_symbols["<mask>"]
EOD_ID = special_symbols["<eod>"]

In [5]:
# data from https://github.com/cbaziotis/ekphrasis/blob/master/ekphrasis/utils/helpers.py
# reuploaded to husein's S3
# !wget https://malaya-dataset.s3-ap-southeast-1.amazonaws.com/counts_1grams.txt

In [6]:
with open('counts_1grams.txt') as fopen:
    f = fopen.read().split('\n')[:-1]
    
words = {}
for l in f:
    w, c = l.split('\t')
    c = int(c)
    words[w] = c + words.get(w, 0)

In [7]:
# original from https://github.com/cbaziotis/ekphrasis/blob/master/ekphrasis/classes/spellcorrect.py
# improved it

import re
from collections import Counter

class SpellCorrector:
    """
    The SpellCorrector extends the functionality of the Peter Norvig's
    spell-corrector in http://norvig.com/spell-correct.html
    """

    def __init__(self):
        """
        :param corpus: the statistics from which corpus to use for the spell correction.
        """
        super().__init__()
        self.WORDS = words
        self.N = sum(self.WORDS.values())
        
    @staticmethod
    def tokens(text):
        return REGEX_TOKEN.findall(text.lower())

    def P(self, word):
        """
        Probability of `word`.
        """
        return self.WORDS[word] / self.N

    def most_probable(self, words):
        _known = self.known(words)
        if _known:
            return max(_known, key=self.P)
        else:
            return []

    @staticmethod
    def edit_step(word):
        """
        All edits that are one edit away from `word`.
        """
        letters = 'abcdefghijklmnopqrstuvwxyz'
        splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
        deletes = [L + R[1:] for L, R in splits if R]
        transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
        replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
        inserts = [L + c + R for L, R in splits for c in letters]
        return set(deletes + transposes + replaces + inserts)

    def edits2(self, word):
        """
        All edits that are two edits away from `word`.
        """
        return (e2 for e1 in self.edit_step(word)
                for e2 in self.edit_step(e1))

    def known(self, words):
        """
        The subset of `words` that appear in the dictionary of WORDS.
        """
        return set(w for w in words if w in self.WORDS)

    def edit_candidates(self, word, assume_wrong=False, fast=True):
        """
        Generate possible spelling corrections for word.
        """

        if fast:
            ttt = self.known(self.edit_step(word)) or {word}
        else:
            ttt = self.known(self.edit_step(word)) or self.known(self.edits2(word)) or {word}
        
        ttt = self.known([word]) | ttt
        return list(ttt)

In [8]:
corrector = SpellCorrector()

In [9]:
possible_states = corrector.edit_candidates('wolking')
possible_states

['working', 'wolfing', 'walking', 'wilking', 'woking', 'wonking']

In [10]:
text = 'wolking is good for health'
text_mask = text.replace('wolking', '**mask**')
text_mask

'**mask** is good for health'

In [11]:
PADDING_TEXT = """
    The quick brown fox jumps over the lazy dog. A horrible, messy split second presents
    itself to the heart-shaped version as Scott is moved. The upcoming movie benefits at 
    the mental cost of ages 14 to 12. Nothing substantial is happened for almost 48 days. 
    When that happens, we lose our heart. <eod>
"""
padded_text = tokenize_fn(PADDING_TEXT)

def tokens_to_masked_ids(tokens, mask_ind):
    masked_tokens = tokens
    masked_tokens[mask_ind] = MASK_ID
    segment_id = [SEG_ID_A] * len(masked_tokens)
    input_mask = [0] * len(masked_tokens)
    perm_masks = np.zeros((1, len(masked_tokens)))
    perm_masks[0, mask_ind] = 1.0
    target_mappings = np.zeros((1, len(masked_tokens)))
    target_mappings[0, mask_ind] = 1.0
    
    return masked_tokens, segment_id, input_mask, perm_masks, target_mappings

In [12]:
replaced_masks = [text_mask.replace('**mask**', state) for state in possible_states]
replaced_masks

['working is good for health',
 'wolfing is good for health',
 'walking is good for health',
 'wilking is good for health',
 'woking is good for health',
 'wonking is good for health']

In [13]:
import xlnet
import tensorflow as tf
import model_utils

kwargs = dict(
      is_training=True,
      use_tpu=False,
      use_bfloat16=False,
      dropout=0.0,
      dropatt=0.0,
      init='normal',
      init_range=0.1,
      init_std=0.05,
      clamp_len=-1)

xlnet_parameters = xlnet.RunConfig(**kwargs)
xlnet_config = xlnet.XLNetConfig(json_path='xlnet_cased_L-12_H-768_A-12/xlnet_config.json')

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])






  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [14]:
class Model:
    def __init__(
        self,
    ):
        self.X = tf.placeholder(tf.int32, [None, None])
        self.segment_ids = tf.placeholder(tf.int32, [None, None])
        self.input_masks = tf.placeholder(tf.float32, [None, None])
        self.perm_masks = tf.placeholder(tf.float32, [None, None, None])
        self.target_mappings = tf.placeholder(tf.float32, [None, None, None])
        
        xlnet_model = xlnet.XLNetModel(
            xlnet_config=xlnet_config,
            run_config=xlnet_parameters,
            input_ids=self.X,
            seg_ids=self.segment_ids,
            input_mask=self.input_masks,
            perm_mask = self.perm_masks,
            target_mapping = self.target_mappings
        )
        
        output = xlnet_model.get_sequence_output()
        self.output = output
        lookup_table = xlnet_model.get_embedding_table()

        initializer = xlnet_model.get_initializer()
        with tf.variable_scope('model', reuse = tf.AUTO_REUSE):
            with tf.variable_scope('lm_loss'):
                softmax_w = lookup_table
                softmax_b = tf.get_variable(
                    'bias',
                    [xlnet_config.n_token],
                    dtype = output.dtype,
                    initializer = tf.zeros_initializer(),
                )
                logits = tf.einsum('ibd,nd->ibn', output, softmax_w) + softmax_b
                self.logits = logits

In [15]:
tf.reset_default_graph()
sess = tf.InteractiveSession()
model = Model()

sess.run(tf.global_variables_initializer())




INFO:tensorflow:memory input None
INFO:tensorflow:Use float type <dtype: 'float32'>

Instructions for updating:
Use keras.layers.dropout instead.
The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use keras.layers.dense instead.


In [16]:
import collections
import re

def get_assignment_map_from_checkpoint(tvars, init_checkpoint):
    """Compute the union of the current variables and checkpoint variables."""
    assignment_map = {}
    initialized_variable_names = {}

    name_to_variable = collections.OrderedDict()
    for var in tvars:
        name = var.name
        m = re.match('^(.*):\\d+$', name)
        if m is not None:
            name = m.group(1)
        name_to_variable[name] = var

    init_vars = tf.train.list_variables(init_checkpoint)

    assignment_map = collections.OrderedDict()
    for x in init_vars:
        (name, var) = (x[0], x[1])
        if name not in name_to_variable:
            continue
        assignment_map[name] = name_to_variable[name]
        initialized_variable_names[name] = 1
        initialized_variable_names[name + ':0'] = 1

    return (assignment_map, initialized_variable_names)

In [17]:
tvars = tf.trainable_variables()
checkpoint = 'xlnet_cased_L-12_H-768_A-12/xlnet_model.ckpt'
assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars, 
                                                                                checkpoint)

In [18]:
saver = tf.train.Saver(var_list = assignment_map)
saver.restore(sess, checkpoint)

Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from xlnet_cased_L-12_H-768_A-12/xlnet_model.ckpt


In [19]:
import numpy as np

In [20]:
tokens = tokenize_fn(replaced_masks[0])
input_ids = [tokens_to_masked_ids(tokens, i) for i in range(len(tokens))]
a = list(zip(*input_ids))
batch_x = np.array(a[0])
batch_segment = np.array(a[1])
batch_mask = np.array(a[2])
perm_masks = np.array(a[3])
target_mappings = np.array(a[4])
preds = sess.run(tf.nn.softmax(model.logits), 
                 feed_dict = {model.X: batch_x, 
                              model.segment_ids: batch_segment,
                              model.input_masks: batch_mask,
                              model.perm_masks: perm_masks,
                              model.target_mappings: target_mappings})
preds.shape

(5, 5, 32000)

In [21]:
def get_score(mask):
    tokens = tokenize_fn(mask)
    input_ids = [tokens_to_masked_ids(tokens, i) for i in range(len(tokens))]
    a = list(zip(*input_ids))
    batch_x = np.array(a[0])
    batch_segment = np.array(a[1])
    batch_mask = np.array(a[2])
    perm_masks = np.array(a[3])
    target_mappings = np.array(a[4])
    preds = sess.run(tf.nn.log_softmax(model.logits), 
                     feed_dict = {model.X: batch_x, 
                                  model.segment_ids: batch_segment,
                                  model.input_masks: batch_mask,
                                  model.perm_masks: perm_masks,
                                  model.target_mappings: target_mappings})
    tokens_ids = tokens
    preds = preds.astype('float64')
    return np.sum([preds[i, i, x] for i, x in enumerate(tokens_ids)])

In [22]:
scores = [get_score(mask) for mask in replaced_masks]
scores

[-100.41720771789551,
 -120.50064468383789,
 -100.41720771789551,
 -140.5840950012207,
 -140.5840950012207,
 -120.50064468383789]

In [23]:
scores = np.exp(np.array(scores).astype('float64'))

In [24]:
prob_scores = np.array(scores) / np.sum(scores)
prob_scores

array([4.99999999e-01, 9.48078180e-10, 4.99999999e-01, 1.79768047e-18,
       1.79768047e-18, 9.48078180e-10])

In [25]:
list(zip(possible_states, prob_scores))

[('working', 0.49999999905192183),
 ('wolfing', 9.480781803042756e-10),
 ('walking', 0.49999999905192183),
 ('wilking', 1.7976804735628786e-18),
 ('woking', 1.7976804735628786e-18),
 ('wonking', 9.480781803042756e-10)]