In [1]:
import sys
sys.path.append('/mnt/home/TF_NEW/tf-transformers/src/')

In [4]:
import random
import collections
import itertools
import tensorflow as tf
from transformers import BertTokenizer

In [3]:
#### Define needed arguments .

# https://github.com/tensorflow/models/blob/master/official/nlp/data/create_pretraining_data.py
do_lower_case = True
do_whole_word_mask = False
max_ngram_size = None # "Mask contiguous whole words (n-grams) of up to `max_ngram_size` using a "
                      # "weighting scheme to favor shorter n-grams. "
                      # "Note: `--do_whole_word_mask=True` must also be set when n-gram masking.")
    
max_seq_length = 128
max_predictions_per_seq = 20
random_seed = 12345
dupe_factor = 1
masked_lm_prob = 0.15
short_seq_prob = 0.1


In [5]:
rng = random.Random(random_seed)
MaskedLmInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])

# A _Gram is a [half-open) interval of token indices which form a word.
# E.g.,
#   words:  ["The", "doghouse"]
#   tokens: ["The", "dog", "##house"]
#   grams:  [(0,1), (1,3)]
_Gram = collections.namedtuple("_Gram", ["begin", "end"])
model_name = 'bert-base-cased'
tokenizer_path = "../../PRE_MODELS/HuggingFace_models/tokenizers/bert_base_cased"
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
vocab_words = list(tokenizer.vocab.keys())

In [6]:

def _window(iterable, size):
  """Helper to create a sliding window iterator with a given size.
  E.g.,
    input = [1, 2, 3, 4]
    _window(input, 1) => [1], [2], [3], [4]
    _window(input, 2) => [1, 2], [2, 3], [3, 4]
    _window(input, 3) => [1, 2, 3], [2, 3, 4]
    _window(input, 4) => [1, 2, 3, 4]
    _window(input, 5) => None
  Args:
    iterable: elements to iterate over.
    size: size of the window.
  Yields:
    Elements of `iterable` batched into a sliding window of length `size`.
  """
  i = iter(iterable)
  window = []
  try:
    for e in range(0, size):
      window.append(next(i))
    yield window
  except StopIteration:
    # handle the case where iterable's length is less than the window size.
    return
  for e in i:
    window = window[1:] + [e]
    yield window


def _contiguous(sorted_grams):
  """Test whether a sequence of grams is contiguous.
  Args:
    sorted_grams: _Grams which are sorted in increasing order.
  Returns:
    True if `sorted_grams` are touching each other.
  E.g.,
    _contiguous([(1, 4), (4, 5), (5, 10)]) == True
    _contiguous([(1, 2), (4, 5)]) == False
  """
  for a, b in _window(sorted_grams, 2):
    if a.end != b.begin:
      return False
  return True

def _masking_ngrams(grams, max_ngram_size, max_masked_tokens, rng):
  """Create a list of masking {1, ..., n}-grams from a list of one-grams.
  This is an extention of 'whole word masking' to mask multiple, contiguous
  words such as (e.g., "the red boat").
  Each input gram represents the token indices of a single word,
     words:  ["the", "red", "boat"]
     tokens: ["the", "red", "boa", "##t"]
     grams:  [(0,1), (1,2), (2,4)]
  For a `max_ngram_size` of three, possible outputs masks include:
    1-grams: (0,1), (1,2), (2,4)
    2-grams: (0,2), (1,4)
    3-grams; (0,4)
  Output masks will not overlap and contain less than `max_masked_tokens` total
  tokens.  E.g., for the example above with `max_masked_tokens` as three,
  valid outputs are,
       [(0,1), (1,2)]  # "the", "red" covering two tokens
       [(1,2), (2,4)]  # "red", "boa", "##t" covering three tokens
  The length of the selected n-gram follows a zipf weighting to
  favor shorter n-gram sizes (weight(1)=1, weight(2)=1/2, weight(3)=1/3, ...).
  Args:
    grams: List of one-grams.
    max_ngram_size: Maximum number of contiguous one-grams combined to create
      an n-gram.
    max_masked_tokens: Maximum total number of tokens to be masked.
    rng: `random.Random` generator.
  Returns:
    A list of n-grams to be used as masks.
  """
  if not grams:
    return None

  grams = sorted(grams)
  num_tokens = grams[-1].end

  # Ensure our grams are valid (i.e., they don't overlap).
  for a, b in _window(grams, 2):
    if a.end > b.begin:
      raise ValueError("overlapping grams: {}".format(grams))

  # Build map from n-gram length to list of n-grams.
  ngrams = {i: [] for i in range(1, max_ngram_size+1)}
  for gram_size in range(1, max_ngram_size+1):
    for g in _window(grams, gram_size):
      if _contiguous(g):
        # Add an n-gram which spans these one-grams.
        ngrams[gram_size].append(_Gram(g[0].begin, g[-1].end))

  # Shuffle each list of n-grams.
  for v in ngrams.values():
    rng.shuffle(v)

  # Create the weighting for n-gram length selection.
  # Stored cummulatively for `random.choices` below.
  cummulative_weights = list(
      itertools.accumulate([1./n for n in range(1, max_ngram_size+1)]))

  output_ngrams = []
  # Keep a bitmask of which tokens have been masked.
  masked_tokens = [False] * num_tokens
  # Loop until we have enough masked tokens or there are no more candidate
  # n-grams of any length.
  # Each code path should ensure one or more elements from `ngrams` are removed
  # to guarentee this loop terminates.
  while (sum(masked_tokens) < max_masked_tokens and
         sum(len(s) for s in ngrams.values())):
    # Pick an n-gram size based on our weights.
    sz = random.choices(range(1, max_ngram_size+1),
                        cum_weights=cummulative_weights)[0]

    # Ensure this size doesn't result in too many masked tokens.
    # E.g., a two-gram contains _at least_ two tokens.
    if sum(masked_tokens) + sz > max_masked_tokens:
      # All n-grams of this length are too long and can be removed from
      # consideration.
      ngrams[sz].clear()
      continue

    # All of the n-grams of this size have been used.
    if not ngrams[sz]:
      continue

    # Choose a random n-gram of the given size.
    gram = ngrams[sz].pop()
    num_gram_tokens = gram.end-gram.begin

    # Check if this would add too many tokens.
    if num_gram_tokens + sum(masked_tokens) > max_masked_tokens:
      continue

    # Check if any of the tokens in this gram have already been masked.
    if sum(masked_tokens[gram.begin:gram.end]):
      continue

    # Found a usable n-gram!  Mark its tokens as masked and add it to return.
    masked_tokens[gram.begin:gram.end] = [True] * (gram.end-gram.begin)
    output_ngrams.append(gram)
  return output_ngrams


def _wordpieces_to_grams(tokens):
  """Reconstitue grams (words) from `tokens`.
  E.g.,
     tokens: ['[CLS]', 'That', 'lit', '##tle', 'blue', 'tru', '##ck', '[SEP]']
      grams: [          [1,2), [2,         4),  [4,5) , [5,       6)]
  Args:
    tokens: list of wordpieces
  Returns:
    List of _Grams representing spans of whole words
    (without "[CLS]" and "[SEP]").
  """
  grams = []
  gram_start_pos = None
  for i, token in enumerate(tokens):
    if gram_start_pos is not None and token.startswith("##"):
      continue
    if gram_start_pos is not None:
      grams.append(_Gram(gram_start_pos, i))
    if token not in ["[CLS]", "[SEP]"]:
      gram_start_pos = i
    else:
      gram_start_pos = None
  if gram_start_pos is not None:
    grams.append(_Gram(gram_start_pos, len(tokens)))
  return grams


def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng,
                                 do_whole_word_mask,
                                 max_ngram_size=None):
  """Creates the predictions for the masked LM objective."""
  if do_whole_word_mask:
    grams = _wordpieces_to_grams(tokens)
  else:
    # Here we consider each token to be a word to allow for sub-word masking.
    if max_ngram_size:
      raise ValueError("cannot use ngram masking without whole word masking")
    grams = [_Gram(i, i+1) for i in range(0, len(tokens))
             if tokens[i] not in ["[CLS]", "[SEP]"]]

  num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))
  # Generate masks.  If `max_ngram_size` in [0, None] it means we're doing
  # whole word masking or token level masking.  Both of these can be treated
  # as the `max_ngram_size=1` case.
  masked_grams = _masking_ngrams(grams, max_ngram_size or 1,
                                 num_to_predict, rng)
  masked_lms = []
  output_tokens = list(tokens)
  for gram in masked_grams:
    # 80% of the time, replace all n-gram tokens with [MASK]
    if rng.random() < 0.8:
      replacement_action = lambda idx: "[MASK]"
    else:
      # 10% of the time, keep all the original n-gram tokens.
      if rng.random() < 0.5:
        replacement_action = lambda idx: tokens[idx]
      # 10% of the time, replace each n-gram token with a random word.
      else:
        replacement_action = lambda idx: rng.choice(vocab_words)

    for idx in range(gram.begin, gram.end):
      output_tokens[idx] = replacement_action(idx)
      masked_lms.append(MaskedLmInstance(index=idx, label=tokens[idx]))

  assert len(masked_lms) <= num_to_predict
  masked_lms = sorted(masked_lms, key=lambda x: x.index)

  masked_lm_positions = []
  masked_lm_labels = []
  for p in masked_lms:
    masked_lm_positions.append(p.index)
    masked_lm_labels.append(p.label)

  return (output_tokens, masked_lm_positions, masked_lm_labels)

In [7]:
import pandas as pd
import glob
import json
from tf_transformers.data import TFWriter, TFReader, TFProcessor


In [8]:
# Get data
def get_text_list_from_files(files):
    text_list = []
    for name in files:
        with open(name) as f:
            for line in f:
                text_list.append(line)
    return text_list


def get_data_from_text_files(folder_name):

    pos_files = glob.glob("aclImdb/" + folder_name + "/pos/*.txt")
    pos_texts = get_text_list_from_files(pos_files)
    neg_files = glob.glob("aclImdb/" + folder_name + "/neg/*.txt")
    neg_texts = get_text_list_from_files(neg_files)
    df = pd.DataFrame(
        {
            "review": pos_texts + neg_texts,
            "sentiment": [0] * len(pos_texts) + [1] * len(neg_texts),
        }
    )
    df = df.sample(len(df)).reset_index(drop=True)
    return df


train_df = get_data_from_text_files("train")
test_df = get_data_from_text_files("test")

all_data = train_df.append(test_df)

In [80]:
def create_mlm(text):
    
    # text = tf.compat.as_str_any(text)
    text = text.numpy().decode()
    tokens = tokenizer.tokenize(text)[:max_seq_len-2]
    
    (tokens, masked_lm_positions,
             masked_lm_labels) = create_masked_lm_predictions(
                 tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
                 do_whole_word_mask, max_ngram_size)
    
    tokens    =  [tokenizer.sep_token]  + tokens + [tokenizer.cls_token]
    input_ids =  tokenizer.convert_tokens_to_ids(tokens)
    masked_lm_positions = [pos+1 for pos in masked_lm_positions]
    masked_lm_labels = tokenizer.convert_tokens_to_ids(masked_lm_labels)
    return [tokens, input_ids, masked_lm_positions, masked_lm_labels]

def create_mlm_map_fn(text):
    tokens, input_ids, masked_lm_positions, masked_lm_labels = tf.py_function(create_mlm, [text],
                                                                              [tf.string, tf.int32, tf.int32, tf.int32])
    
    result = {}
    result['input_ids'] = input_ids
    result['input_type_ids'] = tf.zeros_like(input_ids)
    result['input_mask'] = tf.ones_like(input_ids)
    result['masked_lm_positions'] = masked_lm_positions
    result['masked_lm_labels'] = masked_lm_labels
    return result

MIN_WORDS_IN_SENTENCE = 5
def get_data_as_generator():
    for index, row in all_data.iterrows():
        paragraph = row['review']
        for text in paragraph.split('. '):
            if text == '' or text == ' ' or text == None:
                continue
                
            if len(text.split()) < MIN_WORDS_IN_SENTENCE:
                continue
                
            yield text
            
text_dataset = tf.data.Dataset.from_generator(get_data_as_generator, 
                                             output_signature=tf.TensorSpec(shape=(), dtype=tf.string))

# all_texts = [text for text in get_data_as_generator()]

train_dataset = text_dataset.map(create_mlm_map_fn, num_parallel_calls =tf.data.experimental.AUTOTUNE)

In [66]:
text_dataset = tf.data.Dataset.from_tensor_slices(all_texts)

In [68]:
for item in text_dataset:
    print(item)
    break
    
for item in text_dataset.batch(8):
    print(item)
    break

tf.Tensor(b'This movie is just not worth your time', shape=(), dtype=string)
tf.Tensor(
[b'This movie is just not worth your time'
 b'Its reliance upon New-Age mysticism serves as its only semi-interesting distraction'
 b'The plot is one that has been re-cycled countless times.<br /><br />I was only prompted to even spend the time to put in a comment when I noted that some have tried to prop-up the reputation of this drivel'
 b'Their motivation & objectivity is dubious, since they encourage you not to look at the movies faults, but at its well intentioned message of New Age consciousness.<br /><br />So would it be alright for some twenty to thirty Evangelical Christians, or Islamic Fundamentalists to pour in positive ratings about movies/television that support their views? In spite of the poor qualities of production, or the lack of truth in any of its supposed historic basis? I hope not.<br /><br />I am sure the followers will come right behind me to say flowery things about this mov

In [81]:
train_dataset = text_dataset.map(create_mlm_map_fn, num_parallel_calls =tf.data.experimental.AUTOTUNE)

In [82]:
for index, item in enumerate(train_dataset):
    print(item)
    if index == 1:
        break

{'input_ids': <tf.Tensor: shape=(10,), dtype=int32, numpy=
array([ 102, 1188,  103, 1110, 1198, 1136, 3869, 1240, 1159,  101],
      dtype=int32)>, 'input_type_ids': <tf.Tensor: shape=(10,), dtype=int32, numpy=array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)>, 'input_mask': <tf.Tensor: shape=(10,), dtype=int32, numpy=array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)>, 'masked_lm_positions': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([2], dtype=int32)>, 'masked_lm_labels': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([2523], dtype=int32)>}
{'input_ids': <tf.Tensor: shape=(19,), dtype=int32, numpy=
array([  102,  2098, 24727,   103,  1203,   118,  4936,  1139,  5668,
        1863,  3411,  1112,  1157,   103,  3533,   118,  5426, 15879,
         101], dtype=int32)>, 'input_type_ids': <tf.Tensor: shape=(19,), dtype=int32, numpy=
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
      dtype=int32)>, 'input_mask': <tf.Tensor: shape=(19,), dtype=int32, numpy=
array

In [86]:
padded_shapes = {k: [None] for k in train_dataset.element_spec}
batch_size = 64
train_dataset = train_dataset.padded_batch(batch_size, padded_shapes=padded_shapes)

{'input_ids': <tf.Tensor: shape=(8, 142), dtype=int32, numpy=
array([[  102,  1188,  2523, ...,     0,     0,     0],
       [  102,  2098, 24727, ...,     0,     0,     0],
       [  102,  1109,  4928, ...,     0,     0,     0],
       ...,
       [  102,   103,  1642, ...,     0,     0,     0],
       [  102,  1135,   112, ...,  1105,   103,   101],
       [  102,  1327,   170, ...,     0,     0,     0]], dtype=int32)>, 'input_type_ids': <tf.Tensor: shape=(8, 142), dtype=int32, numpy=
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)>, 'input_mask': <tf.Tensor: shape=(8, 142), dtype=int32, numpy=
array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 0, 0, 0]], dtype=int32)>, 'maske

In [84]:
padded_shapes

{'input_ids': [None],
 'input_type_ids': [None],
 'input_mask': [None],
 'masked_lm_positions': [None],
 'masked_lm_labels': [None]}

In [49]:
    from tf_transformers.data import separate_x_y
    def auto_batch(
        tf_dataset,
        batch_size,
        padded_values=None,
        padded_shapes=None,
        x_keys=None,
        y_keys=None,
        shuffle=False,
        drop_remainder=False,
        shuffle_buffer_size=10000,
        prefetch_buffer_size=100,
    ):
        """Auto Batching

        Args:
            tf_dataset : TF dataset
            x_keys (optional): List of key names. We will filter based on this.
            y_keys (optional): List of key names.
            shuffle (bool, optional): [description]. Defaults to False.
            shuffle_buffer_size (int, optional): [description]. Defaults to 10000.

        Returns:
            batched tf dataset
        """
        element_spec = tf_dataset.element_spec
        _padded_values = {}
        if not padded_values:
            padded_values = {}
        if not padded_shapes:
            padded_shapes = {}
        # sometimes we might have to have sme custom values other than 0
        for k, v in element_spec.items():
            if k in padded_values:
                value = padded_values[k]
                _padded_values[k] = tf.constant(value, dtype=value.dtype)
            else:
                _padded_values[k] = tf.constant(0, dtype=v.dtype)
                
        _padded_shapes = {}
        for k, v in element_spec.items():
            if k in  padded_shapes:
                _padded_shapes[k] = padded_shapes[k]
            else:
                _padded_shapes[k] = [None]
        dataset = tf_dataset.padded_batch(
            padding_values=_padded_values,
            padded_shapes=_padded_shapes,
            batch_size=batch_size,
            drop_remainder=drop_remainder,
        )
        # fmt: off
        if x_keys and y_keys:
            dataset = dataset.map(lambda x: separate_x_y(x, x_keys, y_keys), num_parallel_calls=tf.data.experimental.AUTOTUNE)  # noqa
        # fmt: on
        if shuffle:
            dataset = dataset.shuffle(shuffle_buffer_size, seed=None, reshuffle_each_iteration=True)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset

x_keys = ['input_ids', 'input_mask', 'input_type_ids', 'masked_lm_positions']
y_keys = ['masked_lm_labels']
train_dataset = auto_batch(train_dataset,batch_size=batch_size, x_keys=x_keys, y_keys=y_keys)

In [18]:
MIN_WORDS_IN_SENTENCE = 5
def create_mlm_tfrecord():
    counter = 0
    for index, row in all_data.iterrows():
        paragraph = row['review']
        for text in paragraph.split('. '):
            if text == '' or text == ' ' or text == None:
                continue
                
            if len(text.split()) < MIN_WORDS_IN_SENTENCE:
                continue

            # slice off if exceeds tha max_seq_length
            tokens = tokenizer.tokenize(text)[:max_seq_length-2] 
            (tokens, masked_lm_positions,
                     masked_lm_labels) = create_masked_lm_predictions(
                         tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng,
                         do_whole_word_mask, max_ngram_size)

            tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token]
            input_ids =  tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(input_ids)
            input_type_ids = [0] * len(input_ids)

            # pos + 1 for accounting shifting by CLS token
            masked_lm_positions = [pos+1 for pos in masked_lm_positions]
            masked_lm_labels = tokenizer.convert_tokens_to_ids(masked_lm_labels)

            result = {}
            result['input_ids'] = input_ids
            result['input_type_ids'] = input_type_ids
            result['input_mask'] = input_mask

            result['masked_lm_positions'] = masked_lm_positions
            result['masked_lm_labels'] = masked_lm_labels
            
            counter += 1
            yield result
        
# Lets write using TF Writer
# Use TFProcessor for smalled data

schema = {'input_ids': ("var_len", "int"), 
         'input_mask': ("var_len", "int"), 
         'input_type_ids': ("var_len", "int"), 
         'masked_lm_positions': ("var_len", "int"),
         'masked_lm_labels': ("var_len", "int") 
         }

tfrecord_train_dir = '../OFFICIAL_TFRECORDS/squad/bert_mlm/train'
tfrecord_filename = 'imdb'
tfwriter = TFWriter(schema=schema, 
                    file_name=tfrecord_filename, 
                    model_dir=tfrecord_train_dir,
                    tag='train',
                    overwrite=True
                    )
tfwriter.process(parse_fn=create_mlm_tfrecord())

INFO:absl:Wrote 1000 tfrecods
INFO:absl:Wrote 2000 tfrecods
INFO:absl:Wrote 3000 tfrecods
INFO:absl:Wrote 4000 tfrecods
INFO:absl:Wrote 5000 tfrecods
INFO:absl:Wrote 6000 tfrecods
INFO:absl:Wrote 7000 tfrecods
INFO:absl:Wrote 8000 tfrecods
INFO:absl:Wrote 9000 tfrecods
INFO:absl:Wrote 10000 tfrecods
INFO:absl:Wrote 11000 tfrecods
INFO:absl:Wrote 12000 tfrecods
INFO:absl:Wrote 13000 tfrecods
INFO:absl:Wrote 14000 tfrecods
INFO:absl:Wrote 15000 tfrecods
INFO:absl:Wrote 16000 tfrecods
INFO:absl:Wrote 17000 tfrecods
INFO:absl:Wrote 18000 tfrecods
INFO:absl:Wrote 19000 tfrecods
INFO:absl:Wrote 20000 tfrecods
INFO:absl:Wrote 21000 tfrecods
INFO:absl:Wrote 22000 tfrecods
INFO:absl:Wrote 23000 tfrecods
INFO:absl:Wrote 24000 tfrecods
INFO:absl:Wrote 25000 tfrecods
INFO:absl:Wrote 26000 tfrecods
INFO:absl:Wrote 27000 tfrecods
INFO:absl:Wrote 28000 tfrecods
INFO:absl:Wrote 29000 tfrecods
INFO:absl:Wrote 30000 tfrecods
INFO:absl:Wrote 31000 tfrecods
INFO:absl:Wrote 32000 tfrecods
INFO:absl:Wrote 3

In [19]:
# Read Data
schema = json.load(open("{}/schema.json".format(tfrecord_train_dir)))
all_files = glob.glob("{}/*.tfrecord".format(tfrecord_train_dir))
tf_reader = TFReader(schema=schema, 
                    tfrecord_files=all_files)

x_keys = ['input_ids', 'input_mask', 'input_type_ids', 'masked_lm_positions']
y_keys = ['masked_lm_labels']
batch_size = 64

padded_values = {'masked_lm_psitions': tf.constant(tokenizer.pad_token_id), 
                'masked_lm_labels': tf.constant(tokenizer.pad_token_id)}
train_dataset = tf_reader.read_record(auto_batch=True, 
                                   keys=x_keys,
                                   batch_size=batch_size,
                                    padded_values=padded_values,
                                   x_keys = x_keys, 
                                   y_keys = y_keys,
                                   shuffle=True, 
                                   drop_remainder=True
                                  )

In [20]:
for batch_inputs, batch_labels in train_dataset.take(1):
    print(batch_inputs, batch_labels)

{'input_ids': <tf.Tensor: shape=(32, 120), dtype=int32, numpy=
array([[  101,  6853,   103, ...,     0,     0,     0],
       [  101,  4978,  2499, ...,     0,     0,     0],
       [  101,  1337,  1123, ...,     0,     0,     0],
       ...,
       [  101,  2119,  5819, ...,     0,     0,     0],
       [  101,  1327,   170, ...,     0,     0,     0],
       [  101, 16752, 21089, ...,   119,   119,   102]], dtype=int32)>, 'input_mask': <tf.Tensor: shape=(32, 120), dtype=int32, numpy=
array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 1, 1, 1]], dtype=int32)>, 'input_type_ids': <tf.Tensor: shape=(32, 120), dtype=int32, numpy=
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)>, 'ma

In [21]:
batch_inputs['masked_lm_positions']

<tf.Tensor: shape=(32, 18), dtype=int32, numpy=
array([[  2,  15,  17,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0],
       [  6,  10,  11,  15,  22,  26,  29,  33,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0],
       [  9,  12,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0],
       [  2,   9,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0],
       [ 11,  14,  16,  18,  25,  26,  34,  54,  56,  59,  61,  72,  89,
         96, 101, 105,   0,   0],
       [  4,   7,  15,  33,  34,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0],
       [  9,  12,  14,  31,  34,  38,  40,  42,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0],
       [  1,  10,  12,  13,  14,  27,  44,  52,  53,   0,   0,   0,   0,
          0,   0,   0,   0,   0],
       [  9,  16,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0, 

In [22]:
batch_labels['masked_lm_labels']

<tf.Tensor: shape=(32, 18), dtype=int32, numpy=
array([[ 1240,  1136,  1106,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0],
       [  117,  1114,   170, 20975,  1106,  1117,  2368, 13559,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0],
       [ 2641,  5750,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0],
       [ 1122,  5367,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0],
       [ 1114, 25470,  2349, 10401,   170,  1992,  1105,  2570,  8702,
        23878,  1113,  7831,  1402,  1105,  9052,   117,     0,     0],
       [ 2606,  1141,  1132,  1107,  1115,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0],
       [ 2627, 16387,   132,  1641, 17294,  2168,  4390,  3703,     0,
            0,     0,  

In [23]:
!nvidia-smi

Tue May 11 14:31:25 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 418.126.02   Driver Version: 418.126.02   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla V100-SXM2...  On   | 00000000:85:00.0 Off |                    0 |
| N/A   37C    P0    58W / 300W |  30587MiB / 32480MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
+-------

In [32]:
from tf_transformers.models import BERTEncoder

config = {
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "gelu",
    "intermediate_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "embedding_size": 768,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "max_position_embeddings": 512,
    "num_attention_heads": 12,
    "attention_head_size": 64,
    "num_hidden_layers": 12,
    "type_vocab_size": 2,
    "vocab_size": 28996,
    "layer_norm_epsilon": 1e-12,
    "mask_mode": "user_defined",
}

model = BERTEncoder(config=config, use_masked_lm_positions=True, return_all_layer_outputs=True)
model = model.get_model()

In [27]:
# model.load_checkpoint("/tmp/tf_transformers_cache/bert-base-cased/")

In [28]:
from tf_transformers.losses import cross_entropy_loss

def joint_mlm_loss_fn(y_true_dict, y_pred_dict):
    
    loss_dict = {}
    loss_holder = []
    masked_lm_mask = tf.cast(tf.not_equal(y_true_dict['masked_lm_labels'], tokenizer.pad_token_id),
                            y_pred_dict['token_logits'].dtype)
    for i, layer_output in enumerate(y_pred_dict['all_layer_token_logits']):
        layer_loss = cross_entropy_loss(labels=y_true_dict['masked_lm_labels'], 
                                       logits=layer_output, 
                                       label_weights=masked_lm_mask)
        loss_dict['layer_{}'.format(i)] = layer_loss
        loss_holder.append(layer_loss)
    loss_dict['loss'] = tf.reduce_mean(loss_holder)
    return tf.reduce_mean(loss_holder)


def mlm_loss_fn(y_true_dict, y_pred_dict):
    
    masked_lm_mask = tf.cast(tf.not_equal(y_true_dict['masked_lm_labels'], tokenizer.pad_token_id),
                            y_pred_dict['token_logits'].dtype)
    
    loss = cross_entropy_loss(labels=y_true_dict['masked_lm_labels'], 
                                       logits=y_pred_dict['token_logits'], 
                                       label_weights=masked_lm_mask)

    return loss

In [29]:
class MaskedTextGenerator(tf.keras.callbacks.Callback):
    def __init__(self, tokenizer, top_k=5):
        self.top_k = top_k
        self.tokenizer = tokenizer
        #model = BERTEncoder(config=config, return_all_layer_outputs=True)
        #model = model.get_model()
        #self.original_model = model


    def on_epoch_begin(self, epoch, logs=None):
        # self.original_model.set_weights(self.model.get_weights())
        sample_text = "I have watched this [MASK] and it was awesome"

        input_ids = tf.constant(self.tokenizer.encode(sample_text))

        masked_index = np.where(input_ids == self.tokenizer.mask_token_id)[0][0]
        input_ids = tf.expand_dims(input_ids, axis=0)
        input_type_ids = tf.zeros_like(input_ids)
        input_mask     = tf.ones_like(input_ids)
        inputs = {}
        inputs["input_ids"] = input_ids
        inputs["input_type_ids"] = input_type_ids
        inputs["input_mask"] = input_mask
        inputs['masked_lm_positions'] = tf.constant([range(len(input_ids[0]))])
        outputs = self.model(inputs)
        top_k_probs, top_k_indices = tf.nn.top_k(outputs['token_logits'][:, masked_index, :], k = self.top_k)
        top_k_probs = top_k_probs[0]
        top_k_indices = top_k_indices[0]
        for i in range(self.top_k):
            print("top {} , {} , {}".format(i, top_k_probs[i], tokenizer.decode([top_k_indices[i]])))
            
        #for layer_output in outputs['all_layer_token_logits']:
        #    top_k_probs, top_k_indices = tf.nn.top_k(layer_output[:, masked_index, :], k = self.top_k)
        #    top_k_probs = top_k_probs[0]
        #    top_k_indices = top_k_indices[0]
        #    for i in range(self.top_k):
        #        print("top {} , {} , {}".format(i, top_k_probs[i], tokenizer.decode([top_k_indices[i]])))
        #    print('-------------------------------------------------------------------------------------------')


generator_callback = MaskedTextGenerator(tokenizer)

In [None]:
from tf_transformers.core import optimization
train_data_size = 300000
learning_rate   = 2e-5
steps_per_epoch = int(train_data_size / batch_size)
EPOCHS = 5
num_train_steps = steps_per_epoch * EPOCHS
warmup_steps = int(EPOCHS * train_data_size * 0.1 / batch_size)
# creates an optimizer with learning rate schedule
optimizer_type = 'adamw'
optimizer, learning_rate_fn = optimization.create_optimizer(learning_rate,
                                                steps_per_epoch * EPOCHS,
                                                warmup_steps,
                                                optimizer_type)
model.compile2(optimizer=optimizer, custom_loss={'token_logits': mlm_loss_fn}, loss=None)

In [31]:
import numpy as np
history = model.fit(train_dataset, 
                   epochs=EPOCHS, 
                   callbacks=[generator_callback])

Epoch 1/5
top 0 , 2.2106940746307373 , sets
top 1 , 1.8485941886901855 , A2
top 2 , 1.719996690750122 , MGM
top 3 , 1.712893009185791 , filters
top 4 , 1.7128450870513916 , ##6th








Epoch 2/5
top 0 , 3.6591522693634033 , .
top 1 , 3.54352068901062 , the
top 2 , 3.472245931625366 , ,
top 3 , 2.9732494354248047 , a
top 4 , 2.895054817199707 , '
Epoch 3/5
top 0 , 4.1321258544921875 , .
top 1 , 3.961129665374756 , the
top 2 , 3.7622623443603516 , ,
top 3 , 3.263580799102783 , a
top 4 , 3.199124574661255 , of
Epoch 4/5
top 0 , 4.193861961364746 , .
top 1 , 3.902042865753174 , the
top 2 , 3.817805290222168 , ,
top 3 , 3.376988649368286 , a
top 4 , 3.255675792694092 , to
Epoch 5/5
top 0 , 4.476137161254883 , .
top 1 , 4.145719528198242 , the
top 2 , 4.076451778411865 , ,
top 3 , 3.651524543762207 , and
top 4 , 3.587021589279175 , to

KeyboardInterrupt: 

In [40]:
model2 = generator_callback.original_model
model2.set_weights(model.get_weights())

In [41]:
batch_inputs_copy = batch_inputs.copy()
del batch_inputs_copy['masked_lm_positions']

In [42]:
outputs = model(batch_inputs)

In [43]:
outputs2 = model2(batch_inputs_copy)

In [45]:
outputs['token_logits'][0, :, :]

<tf.Tensor: shape=(19, 28996), dtype=float32, numpy=
array([[-9.598297, -9.752734, -9.611607, ..., -9.480362, -9.760493,
        -9.632543],
       [-9.598297, -9.752733, -9.611607, ..., -9.480361, -9.760493,
        -9.632543],
       [-9.598297, -9.752733, -9.611607, ..., -9.480361, -9.760493,
        -9.632543],
       ...,
       [-9.598297, -9.752733, -9.611607, ..., -9.480362, -9.760493,
        -9.632543],
       [-9.598297, -9.752733, -9.611607, ..., -9.480362, -9.760493,
        -9.632543],
       [-9.598297, -9.752733, -9.611607, ..., -9.480362, -9.760493,
        -9.632543]], dtype=float32)>

In [50]:
for index in batch_inputs['masked_lm_positions'][0]:
    print(outputs2['token_logits'][0, :, :][index])

tf.Tensor([-9.598299  -9.752737  -9.6116085 ... -9.480363  -9.760494  -9.6325445], shape=(28996,), dtype=float32)
tf.Tensor([-9.598299  -9.752736  -9.6116085 ... -9.480363  -9.760494  -9.6325445], shape=(28996,), dtype=float32)
tf.Tensor([-9.598299  -9.752736  -9.6116085 ... -9.480363  -9.760494  -9.6325445], shape=(28996,), dtype=float32)
tf.Tensor([-9.598299 -9.752737 -9.611609 ... -9.480363 -9.760494 -9.632545], shape=(28996,), dtype=float32)
tf.Tensor([-9.598298  -9.752735  -9.6116085 ... -9.480363  -9.760494  -9.6325445], shape=(28996,), dtype=float32)
tf.Tensor([-9.598299  -9.752734  -9.6116085 ... -9.480363  -9.760494  -9.6325445], shape=(28996,), dtype=float32)
tf.Tensor([-9.598299 -9.752737 -9.611609 ... -9.480364 -9.760495 -9.632545], shape=(28996,), dtype=float32)
tf.Tensor([-9.598299 -9.752736 -9.611609 ... -9.480363 -9.760494 -9.632545], shape=(28996,), dtype=float32)
tf.Tensor([-9.598299 -9.752736 -9.611609 ... -9.480364 -9.760495 -9.632545], shape=(28996,), dtype=float32

<tf.Tensor: shape=(128, 28996), dtype=float32, numpy=
array([[-9.598299 , -9.752736 , -9.6116085, ..., -9.480363 , -9.760494 ,
        -9.6325445],
       [-9.598299 , -9.752737 , -9.611609 , ..., -9.480364 , -9.760494 ,
        -9.632545 ],
       [-9.598299 , -9.752736 , -9.611609 , ..., -9.480363 , -9.760495 ,
        -9.632545 ],
       ...,
       [-9.598299 , -9.752737 , -9.611609 , ..., -9.480364 , -9.760496 ,
        -9.632545 ],
       [-9.598299 , -9.752736 , -9.6116085, ..., -9.480363 , -9.760494 ,
        -9.6325445],
       [-9.598298 , -9.752735 , -9.6116085, ..., -9.480363 , -9.760494 ,
        -9.6325445]], dtype=float32)>

In [None]:
        sample_text = "I have watched this [MASK] and it was awesome"

        input_ids = tf.constant(tokenizer.encode(sample_text))

        masked_index = np.where(input_ids == tokenizer.mask_token_id)[0][0]
        input_ids = tf.expand_dims(input_ids, axis=0)
        input_type_ids = tf.zeros_like(input_ids)
        input_mask     = tf.ones_like(input_ids)
        inputs = {}
        inputs["input_ids"] = input_ids
        inputs["input_type_ids"] = input_type_ids
        inputs["input_mask"] = input_mask
        inputs['masked_lm_positions'] = tf.constant([range(len(input_ids[0]))])
        outputs = model2(inputs)
        
        for layer_output in outputs['all_layer_token_logits']:
            top_k_probs, top_k_indices = tf.nn.top_k(layer_output[:, masked_index, :], k = top_k)
            top_k_probs = top_k_probs[0]
            top_k_indices = top_k_indices[0]
            for i in range(top_k):
                print("top {} , {} , {}".format(i, top_k_probs[i], tokenizer.decode([top_k_indices[i]])))
            print('-------------------------------------------------------------------------------------------')

            
            
            
top 0 , 2.520089626312256 , also
top 1 , 2.4441680908203125 , bad
top 2 , 2.2934823036193848 , Paris
top 3 , 2.0417215824127197 , such
top 4 , 1.998399257659912 , who
-------------------------------------------------------------------------------------------
top 0 , 2.1321091651916504 , .
top 1 , 2.000959873199463 , seems
top 2 , 1.9706417322158813 , also
top 3 , 1.951348066329956 , is
top 4 , 1.9027767181396484 , Paris
-------------------------------------------------------------------------------------------
top 0 , 2.1556396484375 , .
top 1 , 1.9961425065994263 , performance
top 2 , 1.9506205320358276 , seems
top 3 , 1.9393377304077148 , role
top 4 , 1.8849267959594727 , right
-------------------------------------------------------------------------------------------
top 0 , 2.441511631011963 , people
top 1 , 2.38044810295105 , film
top 2 , 2.305978536605835 , movie
top 3 , 2.1593713760375977 , hard
top 4 , 2.1197104454040527 , ,
-------------------------------------------------------------------------------------------
top 0 , 3.1191365718841553 , people
top 1 , 3.0638527870178223 , film
top 2 , 2.9635488986968994 , movie
top 3 , 2.6810827255249023 , ,
top 4 , 2.4829795360565186 , one
-------------------------------------------------------------------------------------------
top 0 , 4.018131256103516 , movie
top 1 , 3.968170404434204 , film
top 2 , 3.78580904006958 , ,
top 3 , 3.4917008876800537 , people
top 4 , 3.391716957092285 , one
-------------------------------------------------------------------------------------------
top 0 , 4.816952705383301 , movie
top 1 , 4.428909778594971 , film
top 2 , 4.038279056549072 , ,
top 3 , 3.9647138118743896 , one
top 4 , 3.482236623764038 , people
-------------------------------------------------------------------------------------------
top 0 , 6.0528082847595215 , movie
top 1 , 5.275920391082764 , film
top 2 , 4.3914337158203125 , one
top 3 , 4.133042335510254 , time
top 4 , 3.9425129890441895 , ,
-------------------------------------------------------------------------------------------
top 0 , 7.545930862426758 , movie
top 1 , 6.443050861358643 , film
top 2 , 5.300261497497559 , one
top 3 , 4.732942581176758 , time
top 4 , 4.522573947906494 , show
-------------------------------------------------------------------------------------------
top 0 , 8.466971397399902 , movie
top 1 , 7.410390853881836 , film
top 2 , 5.797360897064209 , one
top 3 , 5.493937015533447 , show
top 4 , 5.378448963165283 , time
-------------------------------------------------------------------------------------------
top 0 , 9.444768905639648 , movie
top 1 , 7.948002338409424 , film
top 2 , 6.243518352508545 , one
top 3 , 6.10836124420166 , show
top 4 , 6.008388996124268 , time
-------------------------------------------------------------------------------------------
top 0 , 9.97020149230957 , movie
top 1 , 9.040471076965332 , film
top 2 , 7.235130786895752 , one
top 3 , 6.5173492431640625 , time
top 4 , 6.311537265777588 , show
-------------------------------------------------------------------------------------------