In [1]:
import sys
sys.path.append("/home/sidhu/Projects/tf-transformers/src/")

In [2]:
from tf_transformers.data import TFWriter, TFReader, TFProcessor

In [19]:
import tensorflow as tf
from tf_transformers.layers.mask import SelfAttentionMask
from tf_transformers.data.processors.mlm import dynamic_masking_from_features

In [4]:
# Create dummy tfrecord

def parse_train():
    for i in range(1000):
        random_length = tf.random.uniform(minval=10, maxval=128, shape=(1,), dtype=tf.int32)[0]
        vector = tf.random.uniform(minval=0, maxval=10000,shape=(random_length,), dtype=tf.int32)
        vector = vector.numpy().tolist()
        yield {"input_ids": vector}
    
    
schema = {
    "input_ids": ("var_len", "int"),
}

tfrecord_train_dir = 'tfrecord_dummy'
tfrecord_filename = 'dummy'
tfwriter = TFWriter(schema=schema, 
                    file_name=tfrecord_filename, 
                    model_dir=tfrecord_train_dir,
                    tag='train',
                    n_files=1,
                    overwrite=True
                    )
# Process
tfwriter.process(parse_fn=parse_train())


INFO:absl:Total individual observations/examples written is 1000 in 1.1995949745178223 seconds
INFO:absl:All writer objects closed


In [5]:
# Read tfrecord
import glob
def get_tfdataset_from_tfrecords(tfrecord_path_list):
    """Get tf dataset from tfrecords"""
    all_files = []
    for tfrecord_path in tfrecord_path_list:
        all_files.extend(glob.glob("{}/*.tfrecord".format(tfrecord_path)))
    schema    = json.load(open("{}/schema.json".format(tfrecord_path)))
    tf_reader = TFReader(schema=schema, 
                        tfrecord_files=all_files)
    train_dataset = tf_reader.read_record(
                                      )
    return train_dataset

dataset = get_tfdataset_from_tfrecords([tfrecord_train_dir])

In [6]:
for item in dataset:
    print(item)
    break

{'input_ids': <tf.Tensor: shape=(12,), dtype=int32, numpy=
array([1012, 3744, 2457, 4470, 6325, 1196, 8150, 7410, 4806, 6899, 3046,
       2511], dtype=int32)>}


In [392]:
from tf_transformers.layers.mask import prefix_mask
from tf_transformers.data.utils import auto_batch
from tf_transformers.utils import tf_utils

In [411]:
def dynamic_prefix_lm_from_features(max_seq_len, 
                                    cls_id, sep_id):
    
    def dynamic_map_prefix(item):
        input_ids = item['input_ids']
        input_ids = input_ids[:max_seq_len-1] # we need -2 for cls and sep, but in causal LM we shift one pos
                                              # so we use -1, length input_ids = max_seq_len + 1
        # Add cls sep
        input_ids = tf.concat([[cls_id], input_ids, [sep_id]], axis=0)
        labels    = input_ids[1:] # exclude first word till last
        input_ids = input_ids[:-1] # exclude last word
        
        input_seq_length = tf.shape(input_ids)[0]
        sentence_length = tf.random.uniform(minval=1, maxval=input_seq_length, shape=(1,), dtype=tf.int32)[0]
        remaining_length = input_seq_length - sentence_length
        input_mask = tf.concat([tf.ones(shape=(sentence_length,), dtype=tf.int32),
                   tf.zeros(shape=(remaining_length,), dtype=tf.int32)], axis=0)
        # Opposite to input_mask
        labels_mask = tf.concat([tf.zeros(shape=(sentence_length,), dtype=tf.int32),
                   tf.ones(shape=(remaining_length,), dtype=tf.int32)], axis=0)
        
        # input type ids
        input_type_ids = tf.zeros_like(input_ids)
        mask = prefix_mask(input_mask)
        inputs = {'input_ids': input_ids,
                  'input_type_ids': input_type_ids, 
                  '3d_mask': mask, 
                  'input_mask': input_mask,
                  'masked_lm_positions': tf.range(tf.shape(input_ids)[0])
                 }
        
        outputs = {
                  'masked_lm_labels': labels,
                  'masked_lm_weights': labels_mask}
        return inputs, outputs
    return dynamic_map_prefix


def dynamic_causal_lm_from_features(max_seq_len, 
                                    cls_id, sep_id):
    
    def attention_mask_square(nd):
        """1's in the lower triangle, counting from the lower right corner.

        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
        """
        dtype = tf_utils.get_dtype()
        ns = nd
        i = tf.range(nd)[:, None]
        j = tf.range(ns)
        m = i >= j - ns + nd
        return tf.cast(m, dtype)

    def mask_causal_mask(input_ids):
        input_ids = tf.expand_dims(input_ids, 0)
        from_shape = tf_utils.get_shape_list(input_ids, expected_rank=[2, 3])
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]

        # 2D Lower Triangular Mask
        from_mask = attention_mask_square(from_seq_length)

        # Replicate 2D `N` times
        mask = tf.cast(tf.ones([batch_size, 1, 1]), from_mask.dtype) * from_mask

        return tf.cast(tf.squeeze(mask, axis=0), tf.float32)
    
    def dynamic_map_causal(item):
        input_ids = item['input_ids']
        input_ids = input_ids[:max_seq_len-1] # we need -2 for cls and sep, but in causal LM we shift one pos
                                              # so we use -1, length input_ids = max_seq_len + 1
        # Add cls sep
        input_ids = tf.concat([[cls_id], input_ids, [sep_id]], axis=0)
        labels    = input_ids[1:] # exclude first word till last
        input_ids = input_ids[:-1] # exclude last word
        labels_mask = tf.ones_like(input_ids)
        input_mask = labels_mask
        # input type ids
        input_type_ids = tf.zeros_like(input_ids)
        mask = mask_causal_mask(input_ids)
        
        inputs = {'input_ids': input_ids,
                  'input_type_ids': input_type_ids, 
                  '3d_mask': mask, 
                  'input_mask': input_mask,
                  'masked_lm_positions': tf.range(tf.shape(input_ids)[0])
                 }
        
        outputs = {
                  'masked_lm_labels': labels,
                  'masked_lm_weights': labels_mask}
        
        return inputs, outputs
    return dynamic_map_causal

In [409]:
    def attention_mask_square(nd):
        """1's in the lower triangle, counting from the lower right corner.

        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
        """
        dtype = tf_utils.get_dtype()
        ns = nd
        i = tf.range(nd)[:, None]
        j = tf.range(ns)
        m = i >= j - ns + nd
        return tf.cast(m, dtype)

    def mask_causal_mask(input_ids):
        input_ids = tf.expand_dims(input_ids, 0)
        from_shape = tf_utils.get_shape_list(input_ids, expected_rank=[2, 3])
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]

        # 2D Lower Triangular Mask
        from_mask = attention_mask_square(from_seq_length)

        # Replicate 2D `N` times
        mask = tf.cast(tf.ones([batch_size, 1, 1]), from_mask.dtype) * from_mask

        return tf.cast(mask, tf.float32)

In [410]:
mask_causal_mask(tf.range(10))

<tf.Tensor: shape=(1, 10, 10), dtype=float32, numpy=
array([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]], dtype=float32)>

In [412]:
def get_dataset(
    tfrecord_path_list,
    max_seq_len,
    max_predictions_per_batch,
    vocab_size,
    cls_token_id,
    sep_token_id,
    unk_token_id,
    pad_token_id,
    mask_token_id,
    batch_size,
    min_sen_len,
):
    """Get dataset after mlm from TFRecords"""

    def filter_by_length(x, min_sen_len):
        """Filter by minimum sentence length (subwords)"""
        return tf.squeeze(tf.greater_equal(tf.shape(x['input_ids']), tf.constant(min_sen_len)), axis=0)

    def filter_by_batch(x, y, batch_size):
        """Filter by batch size"""
        x_batch = tf.shape(x['input_ids'])[0]
        return tf.equal(x_batch, tf.constant(batch_size))
    
    def prepare_3d_input_mask_mlm(input_mask):
        """Prepare 3D mask from 2D"""
        batch_size = tf.shape(input_mask)[0]
        seq_length = tf.shape(input_mask)[1]

        to_mask = tf.cast(tf.reshape(input_mask, [batch_size, 1, seq_length]), dtype=input_mask.dtype)
        broadcast_ones = tf.ones(shape=[batch_size, seq_length, 1], dtype=input_mask.dtype)

        mask = broadcast_ones * to_mask

        return tf.cast(mask, tf.float32)
    
    # Dynamic MLM
    dynamic_mlm_fn = dynamic_masking_from_features(
        max_seq_len,
        max_predictions_per_batch,
        vocab_size,
        cls_token_id,
        sep_token_id,
        unk_token_id,
        pad_token_id,
        mask_token_id,
    )
    

    # Dynamic Prefix LM
    dynamic_prefix_lm = dynamic_prefix_lm_from_features(max_seq_len, 
                                    cls_token_id, sep_token_id)
    
    # Dynamic Causal LM
    dynamic_causal_lm = dynamic_causal_lm_from_features(max_seq_len, 
                                    cls_token_id, sep_token_id)
    
    train_dataset = get_tfdataset_from_tfrecords(tfrecord_path_list)

    if min_sen_len and min_sen_len > 0:
        train_dataset = train_dataset.filter(lambda x: filter_by_length(x, min_sen_len))
    
    # prob check has to be inside map
    # otherwise things become deterministic
    def get_dataset_based_on_prob(item):
        """Map function"""
        
        def add_mark(x, mode, prob):
            """Check are we getting all if conditions with equal probability"""
            x['mode'] = mode
            x['prob'] = prob
            return x
        
        def map_mlm(x):
            """MLM"""
            x['input_ids'] = tf.RaggedTensor.from_tensor(tf.expand_dims(x['input_ids'], axis=0))
            x_copy , y_copy = dynamic_mlm_fn(x)
            x = {}
            for name, v_tensor in x_copy.items():
                x[name] = tf.squeeze(v_tensor, axis=0)
            y = {}
            for name, v_tensor in y_copy.items():
                y[name] = tf.squeeze(v_tensor, axis=0)
            x['3d_mask']   = tf.squeeze(prepare_3d_input_mask_mlm(x_copy['input_mask']), axis=0)
            
            for name, v_tensor in y.items():
                x[name] = v_tensor
            return x
        
        def map_pcmlm(x):
            """Prefix Causal LM"""
            x, y = dynamic_prefix_lm(x)
            for name, v_tensor in y.items():
                x[name] = v_tensor
            return x
        
        def map_cmlm(x):
            """Causal LM"""
            x, y = dynamic_causal_lm(x)
            for name, v_tensor in y.items():
                x[name] = v_tensor
            return x
    

        prob = tf.random.uniform(shape=())
        # Keep a copy like this importatnt
        # otherwise transformation in first if cond might affect other
        input_ids = item['input_ids']
        
        # Do MLM
        if prob <= 0.33:
            x = map_mlm(item)
            x['masked_lm_positions'] = tf.cast(x['masked_lm_positions'], dtype=tf.int32)
            x['masked_lm_weights']   = tf.cast(x['masked_lm_weights'], dtype=tf.int32)
            del x['input_mask']
            x = add_mark(x, "mlm", prob)
            
        # Prefix CLM
        elif prob < 0.66:
            x = map_pcmlm({"input_ids": input_ids})
            del x['input_mask']
            x = add_mark(x, "prefix", prob)
            
        else:
            x = map_cmlm({"input_ids": input_ids})
            del x['input_mask']
            x = add_mark(x, "causal", prob)
        return x
    
    train_dataset = train_dataset.map(get_dataset_based_on_prob, num_parallel_calls=tf.data.AUTOTUNE)
    train_dataset = auto_batch(train_dataset, 
                              batch_size, 
                              x_keys=['input_ids', 'input_type_ids', '3d_mask', 'masked_lm_positions'],
                              y_keys=['masked_lm_labels', 'masked_lm_weights', 'mode', 'prob'], 
                              shuffle=True
                              )
    train_dataset = train_dataset.filter(lambda x, y: filter_by_batch(x, y, batch_size))
    train_dataset = train_dataset.shuffle(100)
    train_dataset = train_dataset.prefetch(100)

    return train_dataset

In [421]:
max_seq_len = 128
max_predictions_per_batch = 20
vocab_size = 30200
cls_token_id = 2
sep_token_id = 3
unk_token_id = 1
pad_token_id = 0
mask_token_id = 5
batch_size = 5
min_sen_len = None
train_dataset = get_dataset(
    [tfrecord_train_dir],
    max_seq_len,
    max_predictions_per_batch,
    vocab_size,
    cls_token_id,
    sep_token_id,
    unk_token_id,
    pad_token_id,
    mask_token_id,
    batch_size,
    min_sen_len,
)

In [426]:
all_modes = []
all_probs = []
for (batch_inputs, batch_labels) in train_dataset:
    all_modes.extend(batch_labels['mode'].numpy())
    all_probs.extend(batch_labels['prob'].numpy())
    

In [428]:
from collections import Counter

Counter(all_modes)

Counter({b'causal': 336, b'mlm': 337, b'prefix': 327})

In [448]:
from tf_transformers.models import BertEncoder
from tf_transformers.core import LegacyModel

In [449]:
class MixEncoder(BertEncoder):
    
    def __init__(self, config, **kwargs):
        print(kwargs)
        super(MixEncoder, self).__init__(config, **kwargs)
        
    def get_model(self, initialize_only=False):
        """Convert tf.keras.Layer to a tf.keras.Model/LegacyModel.
        Args:
            self: model (tf.keras.Layer) instance
        """

        input_ids = tf.keras.layers.Input(
            shape=(self._sequence_length,),
            batch_size=self._batch_size,
            dtype=tf.int32,
            name="input_ids",
        )
        input_mask = tf.keras.layers.Input(
            shape=(self._sequence_length,self._sequence_length),
            batch_size=self._batch_size,
            dtype=tf.float32,
            name="input_mask",
        )
        input_type_ids = tf.keras.layers.Input(
            shape=(self._sequence_length,),
            batch_size=self._batch_size,
            dtype=tf.int32,
            name="input_type_ids",
        )
        masked_lm_positions = tf.keras.layers.Input(
            shape=(None,),
            batch_size=self._batch_size,
            dtype=tf.int32,
            name="masked_lm_positions",
        )
        inputs = {}
        inputs["input_ids"] = input_ids  # Default
        # if mask_mode != 'causal', user has to provde mask
        if self._mask_mode != "causal":
            inputs["input_mask"] = input_mask
        # If type mebddings required
        if self._type_embeddings_layer:
            inputs["input_type_ids"] = input_type_ids
        # if masked_lm_positions
        if self._use_masked_lm_positions:
            inputs["masked_lm_positions"] = masked_lm_positions


        layer_outputs = self(inputs)
        if initialize_only:
            return inputs, layer_outputs

        # Adding model_config is a hack
        model = LegacyModel(inputs=inputs, outputs=layer_outputs, name=self._model_name)
        model.model_config = self._config_dict
        return model

    def call_encoder(self, inputs):
        """Forward pass of an Encoder

        Args:
            inputs ([dict of tf.Tensor]): This is the input to the model.

            'input_ids'         --> tf.int32 (b x s)
            'input_mask'        --> tf.int32 (b x s) # optional
            'input_type_ids'    --> tf.int32 (b x s) # optional

        Returns:
            [dict of tf.Tensor]: Output from the model

            'cls_output'        --> tf.float32 (b x s) # optional
            'token_embeddings'  --> tf.float32 (b x s x h)
            'all_layer_token_embeddings' --> tf.float32 (List of (b x s x h)
                                              from all layers)
            'all_layer_cls_output'       --> tf.float32 (List of (b x s)
                                              from all layers)
        """

        # 1. Collect Word Embeddings
        input_ids = inputs["input_ids"]
        sequence_length = tf.shape(input_ids)[1]
        embeddings = self._embedding_layer(input_ids)
        # Add word_embeddings + position_embeddings + type_embeddings
        if self._type_embeddings_layer:
            input_type_ids = inputs["input_type_ids"]
            type_embeddings = self._type_embeddings_layer(input_type_ids)
            embeddings = embeddings + type_embeddings
        if self._positional_embedding_layer:
            positional_embeddings = self._positional_embedding_layer(tf.range(sequence_length))
            embeddings = embeddings + positional_embeddings

        # 2. Norm + dropout
        embeddings = self._embedding_norm(embeddings)
        embeddings = self._embedding_dropout(embeddings, training=self._use_dropout)

        # 3. Attention  Mask
        attention_mask = inputs['input_mask']

        # 4. Transformer Outputs
        encoder_outputs = []
        for i in range(self._config_dict["num_hidden_layers"]):
            layer = self._transformer_layers[i]
            embeddings, _, _ = layer([embeddings, attention_mask])
            encoder_outputs.append(embeddings)

        # First word of last layer outputs [CLS]
        cls_token_tensor = tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(encoder_outputs[-1])
        # batch_size x embedding_size
        cls_output = self._pooler_layer(cls_token_tensor)
        # batch_size x sequence_length x embedding_size
        token_embeddings = encoder_outputs[-1]

        # check for masked lm positions
        # only for encoder forward pass. This is for MaskedLM training
        if "masked_lm_positions" in inputs:
            masked_lm_positions = inputs["masked_lm_positions"]
        else:
            masked_lm_positions = None

        # MaskedLM layer only project it and normalize (b x s x h)
        token_embeddings_mlm = self._masked_lm_layer(token_embeddings, masked_lm_positions)
        token_logits = tf.matmul(
            token_embeddings_mlm, tf.cast(self.get_embedding_table(), dtype=tf_utils.get_dtype()), transpose_b=True
        )
        # token_logits         =  tf.nn.bias_add(token_logits, self._masked_lm_bias)
        token_logits = self._masked_lm_bias(token_logits)
        last_token_logits = tf.keras.layers.Lambda(lambda x: x[:, -1, :])(token_logits)

        result = {
            "cls_output": cls_output,
            "token_embeddings": token_embeddings,
            "token_logits": token_logits,
            "last_token_logits": last_token_logits,
        }

        if self._return_all_layer_outputs:
            all_cls_output = []
            all_token_logits = []
            for per_layer_token_embeddings in encoder_outputs:
                per_cls_token_tensor = tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
                    per_layer_token_embeddings
                )
                all_cls_output.append(self._pooler_layer(per_cls_token_tensor))

                # token logits per layer
                layer_token_embeddings_mlm = self._masked_lm_layer(per_layer_token_embeddings, masked_lm_positions)
                layer_token_logits = tf.matmul(
                    layer_token_embeddings_mlm,
                    tf.cast(self.get_embedding_table(), dtype=tf_utils.get_dtype()),
                    transpose_b=True,
                )
                layer_token_logits = self._masked_lm_bias(layer_token_logits)
                all_token_logits.append(layer_token_logits)

            result["all_layer_token_embeddings"] = encoder_outputs
            result["all_layer_cls_output"] = all_cls_output
            result["all_layer_token_logits"] = all_token_logits

        return result


In [450]:
config = {
    "attention_probs_dropout_prob": 0.1,
    "hidden_act": "gelu",
    "intermediate_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "embedding_size": 768,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "max_position_embeddings": 512,
    "num_attention_heads": 12,
    "attention_head_size": 64,
    "num_hidden_layers": 12,
    "type_vocab_size": 2,
    "vocab_size": 30000,
    "layer_norm_epsilon": 1e-12
}

tf.keras.backend.clear_session()
model = MixEncoder(config,
                         is_training=True,
                         use_dropout=True,
                         use_masked_lm_positions=True,
                         return_all_layer_outputs=True)
model = model.get_model()

{'is_training': True, 'use_dropout': True, 'use_masked_lm_positions': True, 'return_all_layer_outputs': True}


In [451]:
model.input

{'input_ids': <KerasTensor: shape=(None, None) dtype=int32 (created by layer 'input_ids')>,
 'input_mask': <KerasTensor: shape=(None, None, None) dtype=float32 (created by layer 'input_mask')>,
 'input_type_ids': <KerasTensor: shape=(None, None) dtype=int32 (created by layer 'input_type_ids')>,
 'masked_lm_positions': <KerasTensor: shape=(None, None) dtype=int32 (created by layer 'masked_lm_positions')>}

In [452]:
model.output

{'cls_output': <KerasTensor: shape=(None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
 'token_embeddings': <KerasTensor: shape=(None, None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
 'token_logits': <KerasTensor: shape=(None, None, 30000) dtype=float32 (created by layer 'tf_transformers/bert')>,
 'last_token_logits': <KerasTensor: shape=(None, 30000) dtype=float32 (created by layer 'tf_transformers/bert')>,
 'all_layer_token_embeddings': [<KerasTensor: shape=(None, None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
  <KerasTensor: shape=(None, None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
  <KerasTensor: shape=(None, None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
  <KerasTensor: shape=(None, None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
  <KerasTensor: shape=(None, None, 768) dtype=float32 (created by layer 'tf_transformers/bert')>,
  <KerasTensor: shape=(None, 