In [298]:
import sys
sys.path.append("/home/sidhu/Documents/tf-transformers/src/")

In [None]:
import tensorflow_text as text
import tensorflow as tf
import numpy as np

In [2]:
import tensorflow_datasets as tfds

In [30]:
# Load tokenizer

vocab_file = 'bert_tokenizer_dir/vocab.txt'
def _create_vocab_table_and_initializer(vocab_file):
    vocab_initializer = tf.lookup.TextFileInitializer(
        vocab_file,
        key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
        value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
    vocab_table = tf.lookup.StaticHashTable(vocab_initializer, default_value=-1)
    return vocab_table, vocab_initializer

vocab_table , vocab_initializer = _create_vocab_table_and_initializer(vocab_file)
bert_tokenizer = text.BertTokenizer(
        vocab_table, lower_case=False)

CLS_ID, SEP_ID, PAD_ID, UNK_ID, MASK_ID = (101, 102, 0, 100, 103)

In [3]:
imdb = tfds.load('imdb_reviews')

[1mDownloading and preparing dataset 80.23 MiB (download: 80.23 MiB, generated: Unknown size, total: 80.23 MiB) to /home/sidhu/tensorflow_datasets/imdb_reviews/plain_text/1.0.0...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]





Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]

Shuffling imdb_reviews-train.tfrecord...:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test examples...:   0%|          | 0/25000 [00:00<?, ? examples/s]

Shuffling imdb_reviews-test.tfrecord...:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]

Shuffling imdb_reviews-unsupervised.tfrecord...:   0%|          | 0/50000 [00:00<?, ? examples/s]

[1mDataset imdb_reviews downloaded and prepared to /home/sidhu/tensorflow_datasets/imdb_reviews/plain_text/1.0.0. Subsequent calls will reuse this data.[0m


In [292]:
def get_masked_input_and_labels(encoded_texts):
    # 15% BERT masking
    inp_mask = np.random.rand(*encoded_texts.shape) < 0.15
    # Do not mask special tokens
    inp_mask[encoded_texts == CLS_ID] = False
    inp_mask[encoded_texts == SEP_ID] = False    
    # Set targets to -1 by default, it means ignore
    labels = -1 * np.ones(encoded_texts.shape, dtype=int)
    # Set labels for masked tokens
    labels[inp_mask] = encoded_texts[inp_mask]

    # Prepare input
    encoded_texts_masked = np.copy(encoded_texts)
    # Set input to [MASK] which is the last token for the 90% of tokens
    # This means leaving 10% unchanged
    inp_mask_2mask = inp_mask & (np.random.rand(*encoded_texts.shape) < 0.90)
    encoded_texts_masked[
        inp_mask_2mask
    ] = MASK_ID  # mask token is the last in the dict

    # Set 10% to a random token
    inp_mask_2random = inp_mask_2mask & (np.random.rand(*encoded_texts.shape) < 1 / 9)
    encoded_texts_masked[inp_mask_2random] = np.random.randint(
        3, MASK_ID, inp_mask_2random.sum()
    )
    
    
    # Prepare sample_weights to pass to .fit() method
    sample_weights = np.ones(labels.shape)
    sample_weights[labels == -1] = 0

    # y_labels would be same as encoded_texts i.e input tokens
    y_labels = np.copy(encoded_texts)
    
    # Extract masked lm positions (where we have masked)
    # and use tf.ragged to convert it into tensor
    indexes, positions = np.where(encoded_texts_masked == MASK_ID)
    unique, counts = np.unique(indexes, return_counts=True)
    counts = counts[:-1] # an extra at last (dont know)
    counts = np.cumsum(counts)
    masked_lm_positions = tf.ragged.constant(np.split(positions, counts))
    masked_lm_positions = masked_lm_positions.to_tensor(PAD_ID)
    
    # Gather the positions we want
    y_labels = np.take_along_axis(y_labels, masked_lm_positions.numpy(), axis=1)
    sample_weights = np.take_along_axis(sample_weights, masked_lm_positions.numpy(), axis=1)
    
    return encoded_texts_masked, y_labels, sample_weights, masked_lm_positions

In [293]:
def create_mlm(encoded_text):
    # Input to `augment()` is a TensorFlow tensor which
    # is not supported by `imgaug`. This is why we first
    # convert it to its `numpy` variant.
    return get_masked_input_and_labels(encoded_text.numpy())

def add_start_end(ragged):
    count = ragged.bounding_shape()[0]
    starts = tf.fill([count,1], CLS_ID)
    ends = tf.fill([count,1], SEP_ID)
    return tf.concat([starts, ragged, ends], axis=1)


MAX_SEQ_LEN = 128
def text_to_instance(batch):
    text = batch['text']
    encoded_text = bert_tokenizer.tokenize(text)
    encoded_text = tf.cast(encoded_text.merge_dims(-2, -1), tf.int32)
    encoded_text = encoded_text[:, :MAX_SEQ_LEN-2]
    encoded_text = add_start_end(encoded_text)
    encoded_text = encoded_text.to_tensor()
    
    input_ids, masked_lm_labels, masked_lm_weights, masked_lm_positions = tf.py_function(create_mlm, [encoded_text],
                                                                    [tf.int32, tf.int32, tf.float32, tf.int32])
        
    # masked_lm_labels = get_2d_from_2d(masked_lm_labels, masked_lm_positions)
    # masked_lm_weights = get_2d_from_2d(masked_lm_weights, masked_lm_positions)
    
    input_type_ids = tf.zeros_like(input_ids)
    input_mask     = tf.ones_like(input_ids)
    
    inputs = {}
    inputs['input_ids'] = input_ids
    inputs['input_type_ids'] = input_type_ids
    inputs['input_mask'] = input_mask
    inputs['masked_lm_positions'] = masked_lm_positions
    
    labels = {}
    labels['masked_lm_labels'] = masked_lm_labels
    labels['masked_lm_mask']   = masked_lm_weights
    
    return (inputs, labels)

In [294]:
batch_size = 5
dataset_unsupervised = imdb['unsupervised']
dataset_unsupervised = dataset_unsupervised.batch(batch_size)
dataset_unsupervised = dataset_unsupervised.map(text_to_instance)

In [308]:
for (batch_inputs, batch_labels) in dataset_unsupervised.take(1):
    print(batch_inputs, batch_labels)

{'input_ids': <tf.Tensor: shape=(5, 128), dtype=int32, numpy=
array([[  101, 16625,  2346, 17656,  9637,   118,  1986,  3650,  1103,
         3830,   146,  1525,  1122,  1177,   103,  1115,  1103,  2006,
         2523,  2274,  1282,  1107,   103, 18976,  1105,  1296,  1959,
         1144,   170,  1472,   103,   103,  2431, 21718,  1673,  1234,
         1138,  1242,  1472,  5402,   103,  1147,  5935,   117,  1133,
          103,  1274,   103,   189,  1519,  1172,   103,  7065,   118,
          118,  1152,  1132,  4013,   119,  8491,   112,   188,  1672,
          103,  1105,  2993,  1127,  1825,   103,  1107,  1296,  1959,
          119,  1109, 21803,  1534,   113,  7872,   153,    69,   103,
          114,   117,  1103,  1226,    30,  1140,  1150,   103,  1123,
         1111,  1217,   170, 21803,   113,  5931,   114,   117,  1103,
           37,   103,  1119,  3683,  1119,  1125,   117,  1103,  9207,
         1401,  1119,  3683,  1119,  1125,   117,  1103, 15589,   103,
         1104, 

In [313]:
from tf_transformers.models import BertModel
from tf_transformers.losses import cross_entropy_loss

In [305]:
model_name = 'bert-base-cased'
model, config = BertModel.get_model(model_name=model_name,
                                    use_masked_lm_positions=True,
                                    return_all_layer_outputs=True)

INFO:absl:Successful: Model checkpoints matched and loaded from /tmp/tf_transformers_cache/bert-base-cased


In [306]:
model.input

{'input_ids': <KerasTensor: shape=(None, None) dtype=int32 (created by layer 'input_ids')>,
 'input_mask': <KerasTensor: shape=(None, None) dtype=int32 (created by layer 'input_mask')>,
 'input_type_ids': <KerasTensor: shape=(None, None) dtype=int32 (created by layer 'input_type_ids')>,
 'masked_lm_positions': <KerasTensor: shape=(None, None) dtype=int32 (created by layer 'masked_lm_positions')>}

In [307]:
model.output

{'cls_output': <KerasTensor: shape=(None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
 'token_embeddings': <KerasTensor: shape=(None, None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
 'token_logits': <KerasTensor: shape=(None, None, 28996) dtype=float32 (created by layer 'tf_transformers/bert')>,
 'last_token_logits': <KerasTensor: shape=(None, 28996) dtype=float32 (created by layer 'tf_transformers/bert')>,
 'all_layer_token_embeddings': [<KerasTensor: shape=(None, None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
  <KerasTensor: shape=(None, None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
  <KerasTensor: shape=(None, None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
  <KerasTensor: shape=(None, None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
  <KerasTensor: shape=(None, None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
  <KerasTensor: shape=(None, 

In [309]:
model_outputs = model(batch_inputs)

In [312]:
model_outputs['all_layer_token_logits']

[<tf.Tensor: shape=(5, 17, 28996), dtype=float32, numpy=
 array([[[ 0.6364071 ,  0.39737868, -0.09157881, ...,  0.34175146,
           0.5136198 , -0.78899205],
         [-0.06405342,  0.44488218,  0.09692273, ...,  0.3577524 ,
           0.01640294, -0.48131403],
         [ 0.02143574,  0.16771254, -0.10223933, ..., -0.44842073,
          -0.3124991 , -0.39341444],
         ...,
         [ 0.10923192,  0.46006092,  0.51024324, ...,  0.14270039,
          -0.16740501, -0.6073307 ],
         [ 0.3565812 ,  0.2618315 ,  1.0749931 , ...,  0.11836857,
          -0.02579454, -0.8573086 ],
         [-0.6148638 ,  0.6964495 , -0.4577043 , ...,  0.3505425 ,
           0.25855872, -0.96123344]],
 
        [[ 0.6176162 ,  0.7161437 ,  0.6636988 , ...,  0.37515405,
           0.3667674 , -0.5562779 ],
         [-0.08186221,  0.58370185,  0.04298308, ...,  0.10876828,
          -0.6345585 , -0.6571902 ],
         [ 0.31409347,  0.4121155 ,  0.10250878, ...,  0.18999144,
          -0.12050951, -0.3

In [316]:
def loss_fn(y_true_dict, y_pred_dict):
    
    loss_dict = {}
    loss_holder = []
    for i, layer_output in enumerate(y_pred_dict['all_layer_token_logits']):
        layer_loss = cross_entropy_loss(labels=y_true_dict['masked_lm_labels'], 
                                       logits=layer_output, 
                                       label_weights=y_true_dict['masked_lm_mask'])
        loss_dict['layer_{}'.format(i)] = layer_loss
        loss_holder.append(layer_loss)
    loss_dict['loss'] = tf.reduce_mean(loss_holder)
    return loss_dict

In [317]:
loss_fn(batch_labels, model_outputs)

{'layer_0': <tf.Tensor: shape=(), dtype=float32, numpy=10.438091>,
 'layer_1': <tf.Tensor: shape=(), dtype=float32, numpy=10.436256>,
 'layer_2': <tf.Tensor: shape=(), dtype=float32, numpy=10.396082>,
 'layer_3': <tf.Tensor: shape=(), dtype=float32, numpy=10.433259>,
 'layer_4': <tf.Tensor: shape=(), dtype=float32, numpy=10.42015>,
 'layer_5': <tf.Tensor: shape=(), dtype=float32, numpy=10.434258>,
 'layer_6': <tf.Tensor: shape=(), dtype=float32, numpy=10.450171>,
 'layer_7': <tf.Tensor: shape=(), dtype=float32, numpy=10.464069>,
 'layer_8': <tf.Tensor: shape=(), dtype=float32, numpy=10.49289>,
 'layer_9': <tf.Tensor: shape=(), dtype=float32, numpy=10.50401>,
 'layer_10': <tf.Tensor: shape=(), dtype=float32, numpy=10.475094>,
 'layer_11': <tf.Tensor: shape=(), dtype=float32, numpy=10.440686>,
 'loss': <tf.Tensor: shape=(), dtype=float32, numpy=10.448751>}

In [None]:
class MaskedTextGenerator(tf.keras.callbacks.Callback):
    def __init__(self, tokenizer, top_k=5):
        self.tokenizer = tokenizer
        model , config = BertModel.get_model(model_name='bert-base-cased',
                                    use_masked_lm_positions=False, 
                                    return_all_layer_outputs=True)
        self.original_model = model


    def on_epoch_end(self, epoch, logs=None):
        self.original_model.set_weights(self.model.get_weights())
        sample_text = "I have watched this [MASK] and it was awesome"

        input_ids = tf.constant(self.tokenizer.encode(sample_text))

        masked_index = np.where(input_ids == MASK_ID)[0][0]
        input_ids = tf.expand_dims(input_ids, axis=0)
        input_type_ids = tf.zeros_like(input_ids)
        input_mask     = tf.ones_like(input_ids)
        inputs = {}
        inputs["input_ids"] = input_ids
        inputs["input_type_ids"] = input_type_ids
        inputs["input_mask"] = input_mask
        outputs = self.original_model(inputs)

        for i, layer_output in enumerate(outputs['all_layer_token_logits']):
          prob_value = tf.reduce_max(layer_output, axis=-1)[0][masked_index]
          predicted_token = self.tokenizer.decode([tf.argmax(layer_output, axis=-1)[0][masked_index]])
          print("Layer {}, {}, {}".format(i, predicted_token, prob_value))


generator_callback = MaskedTextGenerator(tokenizer)