In [1]:
import sys
sys.path.append("/home/sidhu/Projects/tf-transformers/src/")

In [2]:
from tf_transformers.data import TFWriter, TFReader, TFProcessor

In [3]:
import tensorflow as tf
from tf_transformers.layers.mask import SelfAttentionMask

In [4]:
# Create dummy tfrecord

def parse_train():
    for i in range(1000):
        random_length = tf.random.uniform(minval=10, maxval=128, shape=(1,), dtype=tf.int32)[0]
        vector = tf.random.uniform(minval=0, maxval=10000,shape=(random_length,), dtype=tf.int32)
        vector = vector.numpy().tolist()
        yield {"input_ids": vector}
    
    
schema = {
    "input_ids": ("var_len", "int"),
}

tfrecord_train_dir = 'tfrecord_dummy'
tfrecord_filename = 'dummy'
tfwriter = TFWriter(schema=schema, 
                    file_name=tfrecord_filename, 
                    model_dir=tfrecord_train_dir,
                    tag='train',
                    n_files=1,
                    overwrite=True
                    )
# Process
tfwriter.process(parse_fn=parse_train())


INFO:absl:Total individual observations/examples written is 1000 in 0.7124912738800049 seconds
INFO:absl:All writer objects closed


In [5]:
# Read tfrecord
import glob
def get_tfdataset_from_tfrecords(tfrecord_path_list):
    """Get tf dataset from tfrecords"""
    all_files = []
    for tfrecord_path in tfrecord_path_list:
        all_files.extend(glob.glob("{}/*.tfrecord".format(tfrecord_path)))
    schema    = json.load(open("{}/schema.json".format(tfrecord_path)))
    tf_reader = TFReader(schema=schema, 
                        tfrecord_files=all_files)
    train_dataset = tf_reader.read_record(
                                      )
    return train_dataset

dataset = get_tfdataset_from_tfrecords([tfrecord_train_dir])

In [6]:
for item in dataset:
    print(item)
    break

{'input_ids': <tf.Tensor: shape=(22,), dtype=int32, numpy=
array([2292, 9249, 1818, 8795, 1841, 2282, 1352, 8085, 9466, 7381, 7992,
       5334, 9255, 8633,   25, 1242, 7332, 3498, 3038, 4843, 3807, 9422],
      dtype=int32)>}


In [7]:
from tf_transformers.layers.mask import prefix_mask
from tf_transformers.data.utils import auto_batch

In [8]:
MAX_SENTENCE = 128

def map_prefix(item):
    input_ids = item['input_ids']
    input_max_length = tf.shape(input_ids)[0]
    sentence_length = tf.random.uniform(minval=1, maxval=input_max_length, shape=(1,), dtype=tf.int32)[0]
    remaining_length = input_max_length - sentence_length
    input_mask = tf.concat([tf.ones(shape=(sentence_length,), dtype=tf.int32),
               tf.zeros(shape=(remaining_length,), dtype=tf.int32)], axis=0)
    input_type_ids = tf.zeros_like(input_ids)
    mask = prefix_mask(input_mask)
    result = {'input_ids': input_ids,
              'input_mask': input_mask, 
              'input_type_ids': input_type_ids, 
              'mask': mask}
    return result

In [12]:
def auto_batch(
    tf_dataset,
    batch_size,
    padded_values=None,
    padded_shapes=None,
    x_keys=None,
    y_keys=None,
    shuffle=False,
    drop_remainder=True,
    shuffle_buffer_size=100,
    prefetch_buffer_size=100,
):
    """Auto Batching

    Args:
        tf_dataset (tf.data.Dataset): TF dataset
        batch_size (int): Batch Size
        padded_values (dict): dict of key to padded values eg: {'key': tf.constant(0)}
        padded_shapes (dict): dict of key to padded shapes eg: 'key': (None,)}
        x_keys (list): List of key names. We will filter based on this.
        y_keys (list): List of key names.
        shuffle (bool):  Defaults to False.
        shuffle_buffer_size (int):  Defaults to 100.
        prefetch_buffer_size (int): Defaults to 100.

    Returns:
        tf.data.Dataset: Batched
    """
    element_spec = tf_dataset.element_spec
    _padded_values = {}
    if not padded_values:
        padded_values = {}
    if not padded_shapes:
        padded_shapes = {}
    # sometimes we might have to have sme custom values other than 0
    for k, v in element_spec.items():
        if k in padded_values:
            value = padded_values[k]
            _padded_values[k] = tf.constant(value, dtype=value.dtype)
        else:
            if v.dtype == tf.string:
                _padded_values[k] = tf.constant("0", dtype=v.dtype)
                continue

            _padded_values[k] = tf.constant(0, dtype=v.dtype)

    _padded_shapes = {}
    for k, v in element_spec.items():
        if k in padded_shapes:
            _padded_shapes[k] = padded_shapes[k]
        else:
            if len(v.shape.dims) == 1:
                _padded_shapes[k] = [None]
            if len(v.shape.dims) == 0:
                _padded_shapes[k] = []
            if len(v.shape.dims) == 2:
                _padded_shapes[k] = [None, None]
                #raise ValueError("Seems like `{}` has 2 dimensional or more".format(v))

    dataset = tf_dataset.padded_batch(
        padding_values=_padded_values,
        padded_shapes=_padded_shapes,
        batch_size=batch_size,
        drop_remainder=drop_remainder,
    )
    # fmt: off
    if x_keys and y_keys:
        dataset = dataset.map(lambda x: separate_x_y(x, x_keys, y_keys), num_parallel_calls=tf.data.experimental.AUTOTUNE)  # noqa
    # fmt: on
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size, seed=None, reshuffle_each_iteration=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

In [13]:
dataset_prefix = dataset.map(map_prefix)
#dataset_prefix = dataset_prefix.padded_batch(5)
dataset_prefix = auto_batch(dataset_prefix, batch_size=5, shuffle=True)

In [14]:
for item in dataset_prefix:
    print(item)
    break

{'input_ids': <tf.Tensor: shape=(5, 120), dtype=int32, numpy=
array([[9808, 4668, 9839, 1249, 4103, 5220, 9767, 2588, 3824, 3537, 7871,
        7590, 2892, 6575, 8490, 9166, 2834, 1114, 9205, 4033, 8614, 7614,
        3605, 6878,  363, 2153, 3995, 2337,  100, 3501, 9182, 9450, 2711,
          44, 1613, 5636, 5972, 5342, 9577, 7488,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0,    0,    0],
       [2653, 5191, 7848, 8146, 4583,  111, 6211,  827, 4844, 1824, 2152,
        5322, 3737, 3396, 9301, 6472, 7964, 9421, 3076,

In [16]:
[None] * 3

[None, None, None]

In [72]:
tf.reduce_sum(item['mask'], axis=1)

<tf.Tensor: shape=(98,), dtype=float32, numpy=
array([16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16., 16.,
       16., 16., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26.,
       27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39.,
       40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52.,
       53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64., 65.,
       66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77., 78.,
       79., 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91.,
       92., 93., 94., 95., 96., 97., 98.], dtype=float32)>

In [44]:
remaining_length

<tf.Tensor: shape=(), dtype=int32, numpy=72>

In [6]:
# Normal attention masking

input_mask = tf.constant([[1, 1, 1, 0, 0], 
              [1, 1, 1, 1, 1], 
              [1, 1, 1, 1, 1]
             ])
#attention_mask = SelfAttentionMask()([embeddings, input_mask])

In [7]:
batch_size = tf.shape(input_mask)[0]
seq_length = tf.shape(input_mask)[1]

In [9]:
to_mask = tf.cast(tf.reshape(input_mask, [batch_size, 1, seq_length]), dtype=input_mask.dtype)
broadcast_ones = tf.ones(shape=[batch_size, seq_length, 1], dtype=input_mask.dtype)

In [10]:
to_mask

<tf.Tensor: shape=(3, 1, 5), dtype=int32, numpy=
array([[[1, 1, 1, 0, 0]],

       [[1, 1, 1, 1, 1]],

       [[1, 1, 1, 1, 1]]], dtype=int32)>

In [12]:
broadcast_ones

<tf.Tensor: shape=(3, 5, 1), dtype=int32, numpy=
array([[[1],
        [1],
        [1],
        [1],
        [1]],

       [[1],
        [1],
        [1],
        [1],
        [1]],

       [[1],
        [1],
        [1],
        [1],
        [1]]], dtype=int32)>

In [13]:
mask = broadcast_ones * to_mask

In [14]:
mask

<tf.Tensor: shape=(3, 5, 5), dtype=int32, numpy=
array([[[1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0]],

       [[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]],

       [[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]]], dtype=int32)>