In [1]:
# data from https://github.com/cbaziotis/ekphrasis/blob/master/ekphrasis/utils/helpers.py
# reuploaded to husein's S3
# !wget https://malaya-dataset.s3-ap-southeast-1.amazonaws.com/counts_1grams.txt

In [2]:
with open('counts_1grams.txt') as fopen:
    f = fopen.read().split('\n')[:-1]
    
words = {}
for l in f:
    w, c = l.split('\t')
    c = int(c)
    words[w] = c + words.get(w, 0)

In [3]:
# original from https://github.com/cbaziotis/ekphrasis/blob/master/ekphrasis/classes/spellcorrect.py
# improved it

import re
from collections import Counter

class SpellCorrector:
    """
    The SpellCorrector extends the functionality of the Peter Norvig's
    spell-corrector in http://norvig.com/spell-correct.html
    """

    def __init__(self):
        """
        :param corpus: the statistics from which corpus to use for the spell correction.
        """
        super().__init__()
        self.WORDS = words
        self.N = sum(self.WORDS.values())
        
    @staticmethod
    def tokens(text):
        return REGEX_TOKEN.findall(text.lower())

    def P(self, word):
        """
        Probability of `word`.
        """
        return self.WORDS[word] / self.N

    def most_probable(self, words):
        _known = self.known(words)
        if _known:
            return max(_known, key=self.P)
        else:
            return []

    @staticmethod
    def edit_step(word):
        """
        All edits that are one edit away from `word`.
        """
        letters = 'abcdefghijklmnopqrstuvwxyz'
        splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
        deletes = [L + R[1:] for L, R in splits if R]
        transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
        replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
        inserts = [L + c + R for L, R in splits for c in letters]
        return set(deletes + transposes + replaces + inserts)

    def edits2(self, word):
        """
        All edits that are two edits away from `word`.
        """
        return (e2 for e1 in self.edit_step(word)
                for e2 in self.edit_step(e1))

    def known(self, words):
        """
        The subset of `words` that appear in the dictionary of WORDS.
        """
        return set(w for w in words if w in self.WORDS)

    def edit_candidates(self, word, assume_wrong=False, fast=True):
        """
        Generate possible spelling corrections for word.
        """

        if fast:
            ttt = self.known(self.edit_step(word)) or {word}
        else:
            ttt = self.known(self.edit_step(word)) or self.known(self.edits2(word)) or {word}
        
        ttt = self.known([word]) | ttt
        return list(ttt)

In [4]:
corrector = SpellCorrector()

In [5]:
possible_states = corrector.edit_candidates('eting')
possible_states

['etling',
 'etting',
 'ewing',
 'meting',
 'etang',
 'beting',
 'enting',
 'edting',
 'eing',
 'sting',
 'ting',
 'eying',
 'eting',
 'reting',
 'ering',
 'kting',
 'epting',
 'ebing',
 'geting',
 'etin',
 'ating',
 'eating',
 'elting',
 'eking',
 'eling']

In [6]:
# !wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
# !unzip uncased_L-12_H-768_A-12.zip

In [7]:
BERT_VOCAB = 'uncased_L-12_H-768_A-12/vocab.txt'
BERT_INIT_CHKPNT = 'uncased_L-12_H-768_A-12/bert_model.ckpt'
BERT_CONFIG = 'uncased_L-12_H-768_A-12/bert_config.json'

In [8]:
import bert
from bert import run_classifier
from bert import optimization
from bert import tokenization
from bert import modeling
import tensorflow as tf
import numpy as np

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])





  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [9]:
import unicodedata

def whitespace_tokenize(text):
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens

class BasicTokenizer(object):

    def __init__(self, do_lower_case=True, never_split=None):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = never_split

    def tokenize(self, text, never_split=None):
        never_split = self.never_split + (never_split if never_split is not None else [])
        text = self._clean_text(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                token = self._run_strip_accents(token)
                split_tokens.extend(self._run_split_on_punc(token))
            else:
                split_tokens.append(token)

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]
    
    def _clean_text(self, text):
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)
    
def _is_control(char):
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False

def _is_whitespace(char):
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False

def _is_punctuation(char):
    cp = ord(char)
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False

In [10]:
from bert.tokenization import WordpieceTokenizer, load_vocab, convert_by_vocab

class FullTokenizer(object):
    def __init__(self, vocab_file, do_lower_case=True):
        self.vocab = load_vocab(vocab_file)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case, 
                                              never_split = ['[CLS]', '[MASK]', '[SEP]'])
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text):
            for sub_token in self.wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)

        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab(self.inv_vocab, ids)

In [11]:
tokenizer = FullTokenizer(vocab_file=BERT_VOCAB, do_lower_case=True)




In [12]:
text = '[CLS] scientist suggests eting burger can lead to obesity [SEP]'
text_mask = text.replace('eting', '**mask**')
text_mask

'[CLS] scientist suggests **mask** burger can lead to obesity [SEP]'

In [13]:
def get_indices(mask, word):
    splitted = mask.split('**mask**')
    left = tokenizer.tokenize(splitted[0])
    middle = tokenizer.tokenize(word)
    right = tokenizer.tokenize(splitted[1])
    indices = [i for i in range(len(left))]
    for i in range(len(right)):
        indices.append(i + len(middle) + len(left))
    
    indices = indices[1:-1]
    tokenized = tokenizer.tokenize(mask.replace('**mask**',word))
    ids = tokenizer.convert_tokens_to_ids(tokenized)
    ids_left = tokenizer.convert_tokens_to_ids(left)
    ids_right = tokenizer.convert_tokens_to_ids(right)
    indices_word = ids_left + ids_right
    return ids, indices, indices_word[1:-1]

In [14]:
indices = [get_indices(text_mask, word) for word in possible_states]
ids, seq_ids, word_ids = list(zip(*indices))
ids

([101, 7155, 6083, 3802, 2989, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 3802, 3436, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 24023, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 2777, 2075, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 27859, 3070, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 6655, 2075, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 4372, 3436, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 3968, 3436, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 16417, 2290, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 12072, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 28642, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 1041, 14147, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 3802, 2075, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 2128, 3436, 15890, 2064, 2599, 2000, 24552, 102],
 [101, 7155, 6083, 11781, 2290, 15890, 2064, 2

In [15]:
bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)

In [16]:
class Model:
    def __init__(
        self,
    ):
        self.X = tf.placeholder(tf.int32, [None, None])
        
        model = modeling.BertModel(
            config=bert_config,
            is_training=False,
            input_ids=self.X,
            use_one_hot_embeddings=False)
        
        output_layer = model.get_sequence_output()
        embedding = model.get_embedding_table()
        
        with tf.variable_scope('cls/predictions'):
            with tf.variable_scope('transform'):
                input_tensor = tf.layers.dense(
                    output_layer,
                    units = bert_config.hidden_size,
                    activation = modeling.get_activation(bert_config.hidden_act),
                    kernel_initializer = modeling.create_initializer(
                        bert_config.initializer_range
                    ),
                )
                input_tensor = modeling.layer_norm(input_tensor)
            
            output_bias = tf.get_variable(
            'output_bias',
            shape = [bert_config.vocab_size],
            initializer = tf.zeros_initializer(),
            )
            logits = tf.matmul(input_tensor, embedding, transpose_b = True)
            self.logits = tf.nn.bias_add(logits, output_bias)

In [17]:
tf.reset_default_graph()
sess = tf.InteractiveSession()
model = Model()

sess.run(tf.global_variables_initializer())
var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')




The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use keras.layers.dense instead.


In [18]:
cls = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'cls')
cls

[<tf.Variable 'cls/predictions/transform/dense/kernel:0' shape=(768, 768) dtype=float32_ref>,
 <tf.Variable 'cls/predictions/transform/dense/bias:0' shape=(768,) dtype=float32_ref>,
 <tf.Variable 'cls/predictions/transform/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>,
 <tf.Variable 'cls/predictions/transform/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>,
 <tf.Variable 'cls/predictions/output_bias:0' shape=(30522,) dtype=float32_ref>]

In [20]:
saver = tf.train.Saver(var_list = var_lists + cls)
saver.restore(sess, BERT_INIT_CHKPNT)

INFO:tensorflow:Restoring parameters from uncased_L-12_H-768_A-12/bert_model.ckpt


In [21]:
masked_padded = tf.keras.preprocessing.sequence.pad_sequences(ids,padding='post')
masked_padded.shape

(25, 11)

In [22]:
preds = sess.run(tf.nn.softmax(model.logits), feed_dict = {model.X: masked_padded})
preds.shape

(25, 11, 30522)

In [24]:
scores = []

for no, ids in enumerate(seq_ids):
    scores.append(np.prod(preds[no, ids, word_ids[no]]))
    
scores

[2.689204e-05,
 3.36932e-05,
 4.1889663e-05,
 2.2712533e-05,
 3.127968e-05,
 1.5012656e-05,
 3.465448e-05,
 5.2485917e-05,
 7.6286415e-05,
 3.5186342e-05,
 1.9021903e-05,
 3.2630334e-05,
 1.2884642e-05,
 4.779812e-05,
 9.0476e-05,
 3.1589767e-05,
 4.9742277e-05,
 4.847102e-05,
 3.391391e-05,
 1.6768146e-05,
 1.9604393e-05,
 9.826456e-05,
 3.581642e-05,
 4.2474054e-05,
 4.910487e-05]

In [25]:
prob_scores = np.array(scores) / np.sum(scores)
prob_scores

array([0.0269283 , 0.03373863, 0.04194615, 0.02274316, 0.03132186,
       0.0150329 , 0.03470121, 0.05255669, 0.07638928, 0.03523379,
       0.01904755, 0.03267433, 0.01290202, 0.04786257, 0.090598  ,
       0.03163236, 0.04980935, 0.04853638, 0.03395964, 0.01679076,
       0.01963083, 0.09839706, 0.03586472, 0.04253133, 0.04917109],
      dtype=float32)

In [26]:
probs = list(zip(possible_states, prob_scores))
probs.sort(key = lambda x: x[1])  
probs[::-1]

[('eating', 0.09839706),
 ('ering', 0.090598),
 ('eing', 0.07638928),
 ('edting', 0.052556694),
 ('epting', 0.04980935),
 ('eling', 0.049171086),
 ('ebing', 0.048536383),
 ('reting', 0.047862574),
 ('eking', 0.04253133),
 ('ewing', 0.04194615),
 ('elting', 0.03586472),
 ('sting', 0.03523379),
 ('enting', 0.03470121),
 ('geting', 0.033959642),
 ('etting', 0.03373863),
 ('eying', 0.032674335),
 ('kting', 0.031632364),
 ('etang', 0.03132186),
 ('etling', 0.026928302),
 ('meting', 0.02274316),
 ('ating', 0.019630829),
 ('ting', 0.019047555),
 ('etin', 0.016790757),
 ('beting', 0.0150329005),
 ('eting', 0.012902017)]