In [1]:
# data from https://github.com/cbaziotis/ekphrasis/blob/master/ekphrasis/utils/helpers.py
# reuploaded to husein's S3
# !wget https://malaya-dataset.s3-ap-southeast-1.amazonaws.com/counts_1grams.txt

In [2]:
with open('counts_1grams.txt') as fopen:
    f = fopen.read().split('\n')[:-1]
    
words = {}
for l in f:
    w, c = l.split('\t')
    c = int(c)
    words[w] = c + words.get(w, 0)

In [3]:
# original from https://github.com/cbaziotis/ekphrasis/blob/master/ekphrasis/classes/spellcorrect.py
# improved it

import re
from collections import Counter

class SpellCorrector:
    """
    The SpellCorrector extends the functionality of the Peter Norvig's
    spell-corrector in http://norvig.com/spell-correct.html
    """

    def __init__(self):
        """
        :param corpus: the statistics from which corpus to use for the spell correction.
        """
        super().__init__()
        self.WORDS = words
        self.N = sum(self.WORDS.values())
        
    @staticmethod
    def tokens(text):
        return REGEX_TOKEN.findall(text.lower())

    def P(self, word):
        """
        Probability of `word`.
        """
        return self.WORDS[word] / self.N

    def most_probable(self, words):
        _known = self.known(words)
        if _known:
            return max(_known, key=self.P)
        else:
            return []

    @staticmethod
    def edit_step(word):
        """
        All edits that are one edit away from `word`.
        """
        letters = 'abcdefghijklmnopqrstuvwxyz'
        splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
        deletes = [L + R[1:] for L, R in splits if R]
        transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
        replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
        inserts = [L + c + R for L, R in splits for c in letters]
        return set(deletes + transposes + replaces + inserts)

    def edits2(self, word):
        """
        All edits that are two edits away from `word`.
        """
        return (e2 for e1 in self.edit_step(word)
                for e2 in self.edit_step(e1))

    def known(self, words):
        """
        The subset of `words` that appear in the dictionary of WORDS.
        """
        return set(w for w in words if w in self.WORDS)

    def edit_candidates(self, word, assume_wrong=False, fast=True):
        """
        Generate possible spelling corrections for word.
        """

        if fast:
            ttt = self.known(self.edit_step(word)) or {word}
        else:
            ttt = self.known(self.edit_step(word)) or self.known(self.edits2(word)) or {word}
        
        ttt = self.known([word]) | ttt
        return list(ttt)

In [4]:
corrector = SpellCorrector()

In [5]:
possible_states = corrector.edit_candidates('eting')
possible_states

['eking',
 'geting',
 'eting',
 'eling',
 'ating',
 'edting',
 'reting',
 'eating',
 'eing',
 'beting',
 'epting',
 'etin',
 'etang',
 'kting',
 'meting',
 'elting',
 'sting',
 'eying',
 'ewing',
 'etting',
 'etling',
 'ting',
 'ering',
 'ebing',
 'enting']

In [6]:
# !wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
# !unzip uncased_L-12_H-768_A-12.zip

In [7]:
BERT_VOCAB = 'uncased_L-12_H-768_A-12/vocab.txt'
BERT_INIT_CHKPNT = 'uncased_L-12_H-768_A-12/bert_model.ckpt'
BERT_CONFIG = 'uncased_L-12_H-768_A-12/bert_config.json'

In [8]:
import bert
from bert import run_classifier
from bert import optimization
from bert import tokenization
from bert import modeling
import tensorflow as tf
import numpy as np

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])





  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [9]:
tokenization.validate_case_matches_checkpoint(True,BERT_INIT_CHKPNT)
tokenizer = tokenization.FullTokenizer(
      vocab_file=BERT_VOCAB, do_lower_case=True)




In [10]:
text = 'scientist suggests eting burger can lead to obesity'
text_mask = text.replace('eting', '**mask**')
text_mask

'scientist suggests **mask** burger can lead to obesity'

In [11]:
def tokens_to_masked_ids(tokens, mask_ind):
    masked_tokens = tokens[:]
    masked_tokens[mask_ind] = "[MASK]"
    masked_tokens = ["[CLS]"] + masked_tokens + ["[SEP]"]
    masked_ids = tokenizer.convert_tokens_to_ids(masked_tokens)
    return masked_ids

In [12]:
bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)

In [13]:
class Model:
    def __init__(
        self,
    ):
        self.X = tf.placeholder(tf.int32, [None, None])
        
        model = modeling.BertModel(
            config=bert_config,
            is_training=False,
            input_ids=self.X,
            use_one_hot_embeddings=False)
        
        output_layer = model.get_sequence_output()
        embedding = model.get_embedding_table()
        
        with tf.variable_scope('cls/predictions'):
            with tf.variable_scope('transform'):
                input_tensor = tf.layers.dense(
                    output_layer,
                    units = bert_config.hidden_size,
                    activation = modeling.get_activation(bert_config.hidden_act),
                    kernel_initializer = modeling.create_initializer(
                        bert_config.initializer_range
                    ),
                )
                input_tensor = modeling.layer_norm(input_tensor)
            
            output_bias = tf.get_variable(
            'output_bias',
            shape = [bert_config.vocab_size],
            initializer = tf.zeros_initializer(),
            )
            logits = tf.matmul(input_tensor, embedding, transpose_b = True)
            self.logits = tf.nn.bias_add(logits, output_bias)

In [14]:
tf.reset_default_graph()
sess = tf.InteractiveSession()
model = Model()

sess.run(tf.global_variables_initializer())
var_lists = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'bert')




The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Instructions for updating:
Use keras.layers.dense instead.


In [20]:
cls = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope = 'cls')
cls

[<tf.Variable 'cls/predictions/transform/dense/kernel:0' shape=(768, 768) dtype=float32_ref>,
 <tf.Variable 'cls/predictions/transform/dense/bias:0' shape=(768,) dtype=float32_ref>,
 <tf.Variable 'cls/predictions/transform/LayerNorm/beta:0' shape=(768,) dtype=float32_ref>,
 <tf.Variable 'cls/predictions/transform/LayerNorm/gamma:0' shape=(768,) dtype=float32_ref>,
 <tf.Variable 'cls/predictions/output_bias:0' shape=(30522,) dtype=float32_ref>]

In [21]:
saver = tf.train.Saver(var_list = var_lists + cls)
saver.restore(sess, BERT_INIT_CHKPNT)

Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from uncased_L-12_H-768_A-12/bert_model.ckpt


In [22]:
replaced_masks = [text_mask.replace('**mask**', state) for state in possible_states]
replaced_masks

['scientist suggests eking burger can lead to obesity',
 'scientist suggests geting burger can lead to obesity',
 'scientist suggests eting burger can lead to obesity',
 'scientist suggests eling burger can lead to obesity',
 'scientist suggests ating burger can lead to obesity',
 'scientist suggests edting burger can lead to obesity',
 'scientist suggests reting burger can lead to obesity',
 'scientist suggests eating burger can lead to obesity',
 'scientist suggests eing burger can lead to obesity',
 'scientist suggests beting burger can lead to obesity',
 'scientist suggests epting burger can lead to obesity',
 'scientist suggests etin burger can lead to obesity',
 'scientist suggests etang burger can lead to obesity',
 'scientist suggests kting burger can lead to obesity',
 'scientist suggests meting burger can lead to obesity',
 'scientist suggests elting burger can lead to obesity',
 'scientist suggests sting burger can lead to obesity',
 'scientist suggests eying burger can lead

In [23]:
def get_score(mask):
    tokens = tokenizer.tokenize(mask)
    input_ids = [tokens_to_masked_ids(tokens, i) for i in range(len(tokens))]
    preds = sess.run(tf.nn.softmax(model.logits), feed_dict = {model.X: input_ids})
    tokens_ids = tokenizer.convert_tokens_to_ids(tokens)
    return np.prod([preds[i, i + 1, x] for i, x in enumerate(tokens_ids)])

In [24]:
scores = [get_score(mask) for mask in replaced_masks]
scores

[9.990078e-24,
 1.7535894e-26,
 9.22324e-29,
 3.768507e-25,
 1.0349554e-25,
 1.438714e-28,
 2.1763072e-27,
 1.3728026e-17,
 2.9091794e-26,
 5.0968306e-28,
 1.5277633e-28,
 3.610259e-28,
 2.4073189e-23,
 6.3879305e-29,
 2.9510165e-26,
 7.9180676e-28,
 4.1319517e-26,
 1.0869552e-22,
 4.0500935e-25,
 2.3287322e-26,
 5.3360014e-29,
 1.0582614e-27,
 2.0237078e-22,
 1.3330295e-28,
 7.3765675e-27]

In [25]:
prob_scores = np.array(scores) / np.sum(scores)
prob_scores

array([7.2769569e-07, 1.2773469e-09, 6.7183785e-12, 2.7450502e-08,
       7.5388060e-09, 1.0479858e-11, 1.5852623e-10, 9.9997473e-01,
       2.1191000e-09, 3.7126253e-11, 1.1128510e-11, 2.6297792e-11,
       1.7535355e-06, 4.6530865e-12, 2.1495750e-09, 5.7676666e-11,
       3.0097900e-09, 7.9175825e-06, 2.9501628e-08, 1.6962916e-09,
       3.8868418e-12, 7.7085720e-11, 1.4741061e-05, 9.7100331e-12,
       5.3732280e-10], dtype=float32)

In [26]:
list(zip(possible_states, scores))

[('eking', 9.990078e-24),
 ('geting', 1.7535894e-26),
 ('eting', 9.22324e-29),
 ('eling', 3.768507e-25),
 ('ating', 1.0349554e-25),
 ('edting', 1.438714e-28),
 ('reting', 2.1763072e-27),
 ('eating', 1.3728026e-17),
 ('eing', 2.9091794e-26),
 ('beting', 5.0968306e-28),
 ('epting', 1.5277633e-28),
 ('etin', 3.610259e-28),
 ('etang', 2.4073189e-23),
 ('kting', 6.3879305e-29),
 ('meting', 2.9510165e-26),
 ('elting', 7.9180676e-28),
 ('sting', 4.1319517e-26),
 ('eying', 1.0869552e-22),
 ('ewing', 4.0500935e-25),
 ('etting', 2.3287322e-26),
 ('etling', 5.3360014e-29),
 ('ting', 1.0582614e-27),
 ('ering', 2.0237078e-22),
 ('ebing', 1.3330295e-28),
 ('enting', 7.3765675e-27)]