In [1]:
import sys
sys.path.append("/home/sidhu/Projects/tf-transformers/src/")

In [2]:
from tf_transformers.data import TFWriter, TFReader, TFProcessor

In [19]:
import tensorflow as tf
from tf_transformers.layers.mask import SelfAttentionMask
from tf_transformers.data.processors.mlm import dynamic_masking_from_features

In [4]:
# Create dummy tfrecord

def parse_train():
    for i in range(1000):
        random_length = tf.random.uniform(minval=10, maxval=128, shape=(1,), dtype=tf.int32)[0]
        vector = tf.random.uniform(minval=0, maxval=10000,shape=(random_length,), dtype=tf.int32)
        vector = vector.numpy().tolist()
        yield {"input_ids": vector}
    
    
schema = {
    "input_ids": ("var_len", "int"),
}

tfrecord_train_dir = 'tfrecord_dummy'
tfrecord_filename = 'dummy'
tfwriter = TFWriter(schema=schema, 
                    file_name=tfrecord_filename, 
                    model_dir=tfrecord_train_dir,
                    tag='train',
                    n_files=1,
                    overwrite=True
                    )
# Process
tfwriter.process(parse_fn=parse_train())


INFO:absl:Total individual observations/examples written is 1000 in 1.1995949745178223 seconds
INFO:absl:All writer objects closed


In [5]:
# Read tfrecord
import glob
def get_tfdataset_from_tfrecords(tfrecord_path_list):
    """Get tf dataset from tfrecords"""
    all_files = []
    for tfrecord_path in tfrecord_path_list:
        all_files.extend(glob.glob("{}/*.tfrecord".format(tfrecord_path)))
    schema    = json.load(open("{}/schema.json".format(tfrecord_path)))
    tf_reader = TFReader(schema=schema, 
                        tfrecord_files=all_files)
    train_dataset = tf_reader.read_record(
                                      )
    return train_dataset

dataset = get_tfdataset_from_tfrecords([tfrecord_train_dir])

In [6]:
for item in dataset:
    print(item)
    break

{'input_ids': <tf.Tensor: shape=(12,), dtype=int32, numpy=
array([1012, 3744, 2457, 4470, 6325, 1196, 8150, 7410, 4806, 6899, 3046,
       2511], dtype=int32)>}


In [392]:
from tf_transformers.layers.mask import prefix_mask
from tf_transformers.data.utils import auto_batch
from tf_transformers.utils import tf_utils

In [411]:
def dynamic_prefix_lm_from_features(max_seq_len, 
                                    cls_id, sep_id):
    
    def dynamic_map_prefix(item):
        input_ids = item['input_ids']
        input_ids = input_ids[:max_seq_len-1] # we need -2 for cls and sep, but in causal LM we shift one pos
                                              # so we use -1, length input_ids = max_seq_len + 1
        # Add cls sep
        input_ids = tf.concat([[cls_id], input_ids, [sep_id]], axis=0)
        labels    = input_ids[1:] # exclude first word till last
        input_ids = input_ids[:-1] # exclude last word
        
        input_seq_length = tf.shape(input_ids)[0]
        sentence_length = tf.random.uniform(minval=1, maxval=input_seq_length, shape=(1,), dtype=tf.int32)[0]
        remaining_length = input_seq_length - sentence_length
        input_mask = tf.concat([tf.ones(shape=(sentence_length,), dtype=tf.int32),
                   tf.zeros(shape=(remaining_length,), dtype=tf.int32)], axis=0)
        # Opposite to input_mask
        labels_mask = tf.concat([tf.zeros(shape=(sentence_length,), dtype=tf.int32),
                   tf.ones(shape=(remaining_length,), dtype=tf.int32)], axis=0)
        
        # input type ids
        input_type_ids = tf.zeros_like(input_ids)
        mask = prefix_mask(input_mask)
        inputs = {'input_ids': input_ids,
                  'input_type_ids': input_type_ids, 
                  '3d_mask': mask, 
                  'input_mask': input_mask,
                  'masked_lm_positions': tf.range(tf.shape(input_ids)[0])
                 }
        
        outputs = {
                  'masked_lm_labels': labels,
                  'masked_lm_weights': labels_mask}
        return inputs, outputs
    return dynamic_map_prefix


def dynamic_causal_lm_from_features(max_seq_len, 
                                    cls_id, sep_id):
    
    def attention_mask_square(nd):
        """1's in the lower triangle, counting from the lower right corner.

        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
        """
        dtype = tf_utils.get_dtype()
        ns = nd
        i = tf.range(nd)[:, None]
        j = tf.range(ns)
        m = i >= j - ns + nd
        return tf.cast(m, dtype)

    def mask_causal_mask(input_ids):
        input_ids = tf.expand_dims(input_ids, 0)
        from_shape = tf_utils.get_shape_list(input_ids, expected_rank=[2, 3])
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]

        # 2D Lower Triangular Mask
        from_mask = attention_mask_square(from_seq_length)

        # Replicate 2D `N` times
        mask = tf.cast(tf.ones([batch_size, 1, 1]), from_mask.dtype) * from_mask

        return tf.cast(tf.squeeze(mask, axis=0), tf.float32)
    
    def dynamic_map_causal(item):
        input_ids = item['input_ids']
        input_ids = input_ids[:max_seq_len-1] # we need -2 for cls and sep, but in causal LM we shift one pos
                                              # so we use -1, length input_ids = max_seq_len + 1
        # Add cls sep
        input_ids = tf.concat([[cls_id], input_ids, [sep_id]], axis=0)
        labels    = input_ids[1:] # exclude first word till last
        input_ids = input_ids[:-1] # exclude last word
        labels_mask = tf.ones_like(input_ids)
        input_mask = labels_mask
        # input type ids
        input_type_ids = tf.zeros_like(input_ids)
        mask = mask_causal_mask(input_ids)
        
        inputs = {'input_ids': input_ids,
                  'input_type_ids': input_type_ids, 
                  '3d_mask': mask, 
                  'input_mask': input_mask,
                  'masked_lm_positions': tf.range(tf.shape(input_ids)[0])
                 }
        
        outputs = {
                  'masked_lm_labels': labels,
                  'masked_lm_weights': labels_mask}
        
        return inputs, outputs
    return dynamic_map_causal

In [409]:
    def attention_mask_square(nd):
        """1's in the lower triangle, counting from the lower right corner.

        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
        """
        dtype = tf_utils.get_dtype()
        ns = nd
        i = tf.range(nd)[:, None]
        j = tf.range(ns)
        m = i >= j - ns + nd
        return tf.cast(m, dtype)

    def mask_causal_mask(input_ids):
        input_ids = tf.expand_dims(input_ids, 0)
        from_shape = tf_utils.get_shape_list(input_ids, expected_rank=[2, 3])
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]

        # 2D Lower Triangular Mask
        from_mask = attention_mask_square(from_seq_length)

        # Replicate 2D `N` times
        mask = tf.cast(tf.ones([batch_size, 1, 1]), from_mask.dtype) * from_mask

        return tf.cast(mask, tf.float32)

In [410]:
mask_causal_mask(tf.range(10))

<tf.Tensor: shape=(1, 10, 10), dtype=float32, numpy=
array([[[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]], dtype=float32)>

In [412]:
def get_dataset(
    tfrecord_path_list,
    max_seq_len,
    max_predictions_per_batch,
    vocab_size,
    cls_token_id,
    sep_token_id,
    unk_token_id,
    pad_token_id,
    mask_token_id,
    batch_size,
    min_sen_len,
):
    """Get dataset after mlm from TFRecords"""

    def filter_by_length(x, min_sen_len):
        """Filter by minimum sentence length (subwords)"""
        return tf.squeeze(tf.greater_equal(tf.shape(x['input_ids']), tf.constant(min_sen_len)), axis=0)

    def filter_by_batch(x, y, batch_size):
        """Filter by batch size"""
        x_batch = tf.shape(x['input_ids'])[0]
        return tf.equal(x_batch, tf.constant(batch_size))
    
    def prepare_3d_input_mask_mlm(input_mask):
        """Prepare 3D mask from 2D"""
        batch_size = tf.shape(input_mask)[0]
        seq_length = tf.shape(input_mask)[1]

        to_mask = tf.cast(tf.reshape(input_mask, [batch_size, 1, seq_length]), dtype=input_mask.dtype)
        broadcast_ones = tf.ones(shape=[batch_size, seq_length, 1], dtype=input_mask.dtype)

        mask = broadcast_ones * to_mask

        return tf.cast(mask, tf.float32)
    
    # Dynamic MLM
    dynamic_mlm_fn = dynamic_masking_from_features(
        max_seq_len,
        max_predictions_per_batch,
        vocab_size,
        cls_token_id,
        sep_token_id,
        unk_token_id,
        pad_token_id,
        mask_token_id,
    )
    

    # Dynamic Prefix LM
    dynamic_prefix_lm = dynamic_prefix_lm_from_features(max_seq_len, 
                                    cls_token_id, sep_token_id)
    
    # Dynamic Causal LM
    dynamic_causal_lm = dynamic_causal_lm_from_features(max_seq_len, 
                                    cls_token_id, sep_token_id)
    
    train_dataset = get_tfdataset_from_tfrecords(tfrecord_path_list)

    if min_sen_len and min_sen_len > 0:
        train_dataset = train_dataset.filter(lambda x: filter_by_length(x, min_sen_len))
    
    # prob check has to be inside map
    # otherwise things become deterministic
    def get_dataset_based_on_prob(item):
        """Map function"""
        
        def add_mark(x, mode, prob):
            """Check are we getting all if conditions with equal probability"""
            x['mode'] = mode
            x['prob'] = prob
            return x
        
        def map_mlm(x):
            """MLM"""
            x['input_ids'] = tf.RaggedTensor.from_tensor(tf.expand_dims(x['input_ids'], axis=0))
            x_copy , y_copy = dynamic_mlm_fn(x)
            x = {}
            for name, v_tensor in x_copy.items():
                x[name] = tf.squeeze(v_tensor, axis=0)
            y = {}
            for name, v_tensor in y_copy.items():
                y[name] = tf.squeeze(v_tensor, axis=0)
            x['3d_mask']   = tf.squeeze(prepare_3d_input_mask_mlm(x_copy['input_mask']), axis=0)
            
            for name, v_tensor in y.items():
                x[name] = v_tensor
            return x
        
        def map_pcmlm(x):
            """Prefix Causal LM"""
            x, y = dynamic_prefix_lm(x)
            for name, v_tensor in y.items():
                x[name] = v_tensor
            return x
        
        def map_cmlm(x):
            """Causal LM"""
            x, y = dynamic_causal_lm(x)
            for name, v_tensor in y.items():
                x[name] = v_tensor
            return x
    

        prob = tf.random.uniform(shape=())
        # Keep a copy like this importatnt
        # otherwise transformation in first if cond might affect other
        input_ids = item['input_ids']
        
        # Do MLM
        if prob <= 0.33:
            x = map_mlm(item)
            x['masked_lm_positions'] = tf.cast(x['masked_lm_positions'], dtype=tf.int32)
            x['masked_lm_weights']   = tf.cast(x['masked_lm_weights'], dtype=tf.int32)
            del x['input_mask']
            x = add_mark(x, "mlm", prob)
            
        # Prefix CLM
        elif prob < 0.66:
            x = map_pcmlm({"input_ids": input_ids})
            del x['input_mask']
            x = add_mark(x, "prefix", prob)
            
        else:
            x = map_cmlm({"input_ids": input_ids})
            del x['input_mask']
            x = add_mark(x, "causal", prob)
#             x = map_pcmlm({"input_ids": input_ids})
#             del x['input_mask']
#             x = add_mark(x, "prefix", prob)
        return x
    
    train_dataset = train_dataset.map(get_dataset_based_on_prob, num_parallel_calls=tf.data.AUTOTUNE)
    train_dataset = auto_batch(train_dataset, 
                              batch_size, 
                              x_keys=['input_ids', 'input_type_ids', '3d_mask', 'masked_lm_positions'],
                              y_keys=['masked_lm_labels', 'masked_lm_weights', 'mode', 'prob'], 
                              shuffle=True
                              )
    # train_dataset = train_dataset.padded_batch(batch_size)
    #train_dataset = train_dataset.filter(lambda x, y: filter_by_batch(x, y, batch_size))
    #train_dataset = train_dataset.shuffle(100)
    #train_dataset = train_dataset.prefetch(100)

    return train_dataset

In [421]:
max_seq_len = 128
max_predictions_per_batch = 20
vocab_size = 30200
cls_token_id = 2
sep_token_id = 3
unk_token_id = 1
pad_token_id = 0
mask_token_id = 5
batch_size = 5
min_sen_len = None
train_dataset = get_dataset(
    [tfrecord_train_dir],
    max_seq_len,
    max_predictions_per_batch,
    vocab_size,
    cls_token_id,
    sep_token_id,
    unk_token_id,
    pad_token_id,
    mask_token_id,
    batch_size,
    min_sen_len,
)

In [426]:
all_modes = []
all_probs = []
for (batch_inputs, batch_labels) in train_dataset:
    all_modes.extend(batch_labels['mode'].numpy())
    all_probs.extend(batch_labels['prob'].numpy())
    

In [428]:
from collections import Counter

Counter(all_modes)

Counter({b'causal': 336, b'mlm': 337, b'prefix': 327})

In [348]:
    def attention_mask_square(nd):
        """1's in the lower triangle, counting from the lower right corner.

        Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
        """
        dtype = tf_utils.get_dtype()
        ns = nd
        i = tf.range(nd)[:, None]
        j = tf.range(ns)
        m = i >= j - ns + nd
        return tf.cast(m, dtype)

    def mask_causal_mask(input_ids):
        input_ids = tf.expand_dims(input_ids, 0)
        from_shape = tf_utils.get_shape_list(input_ids, expected_rank=[2, 3])
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]

        # 2D Lower Triangular Mask
        from_mask = attention_mask_square(from_seq_length)

        # Replicate 2D `N` times
        mask = tf.cast(tf.ones([batch_size, 1, 1]), from_mask.dtype) * from_mask

        return tf.cast(mask, tf.float32)

In [350]:
mask_causal_mask(batch_inputs['input_ids'][0])

NameError: name 'tf_utils' is not defined

In [346]:
batch_labels

{'masked_lm_labels': <tf.Tensor: shape=(5, 124), dtype=int32, numpy=
 array([[7360, 2671, 1797,  268,  385, 7854, 2081, 9934, 2489, 6778, 8202,
         3485, 4311, 1653, 7400, 5233, 2070,  505, 9003, 4187, 7152, 3756,
          122, 7860, 2346, 7479, 3400, 5167, 8520, 6820, 3688, 7064, 9408,
         9050, 9255, 5864, 6016, 4608, 4062, 1373, 6370, 1864,  576,  961,
         3370, 8433, 4274, 1888, 3030, 8563, 5498,  722, 5040, 3101, 6857,
         5985, 7828, 2840, 8387,   86, 9464, 4574, 5887, 1686, 5011, 5806,
         7871, 9073, 1729, 5455, 5356, 7183, 7743, 5720, 9512, 9855, 6594,
         3409, 9087,  944, 6050,  573, 6634, 3722, 9980, 3443, 7708,  461,
         4780,  754,  977, 6616, 4482, 6597, 8777, 9404, 4605, 4580, 7966,
         2954, 6345, 3159, 4283, 4873,    3,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0],
        [8845,  883, 7451, 9982, 4865,  610, 5307, 5411, 2123, 1204, 9199,
   

In [209]:
def map_mlm_check(x):
    """Map function"""

    def add_mark(x, y, mode, prob):
        y['mode'] = mode
        y['prob'] = prob
        return x, y

    prob = tf.random.uniform(shape=()) 
    #train_dataset = train_dataset.apply(tf.data.experimental.dense_to_ragged_batch(batch_size=batch_size))
    x['input_ids'] = tf.RaggedTensor.from_tensor(tf.expand_dims(x['input_ids'], axis=0))
    x , y = dynamic_mlm_fn(x)
    x['3d_mask']   = tf.cast(prepare_3d_input_mask_mlm(x['input_mask']), tf.int32)
    
    return x, y

def map_pclm_check(x):
    """Map function"""

    def add_mark(x, y, mode, prob):
        y['mode'] = mode
        y['prob'] = prob
        return x, y

    x, y = dynamic_prefix_lm(x)
    #x['masked_lm_positions'] = tf.range(tf.shape(x['input_ids']))
    return x, y


In [210]:
ts_mlm = train_dataset.map(map_mlm_check)
ts_pclm = train_dataset.map(map_pclm_check)

In [194]:
ts_mlm = ts_mlm.padded_batch(5)

In [195]:
for item in ts_mlm:
    print(item)
    break

({'input_ids': <tf.Tensor: shape=(5, 128), dtype=int32, numpy=
array([[    2,  1012,  3744,     5,  4470,  6325,  1196,  8150, 17811,
         4806,  6899,  3046,  2511,     3,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,

In [196]:
ts_pclm = ts_pclm.padded_batch(5)

In [211]:
for item in ts_pclm:
    print(item)
    break

({'input_ids': <tf.Tensor: shape=(13,), dtype=int32, numpy=
array([   2, 1012, 3744, 2457, 4470, 6325, 1196, 8150, 7410, 4806, 6899,
       3046, 2511], dtype=int32)>, 'input_type_ids': <tf.Tensor: shape=(13,), dtype=int32, numpy=array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32)>, '3d_mask': <tf.Tensor: shape=(128, 128), dtype=float32, numpy=
array([[1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       ...,
       [1., 1., 1., ..., 1., 0., 0.],
       [1., 1., 1., ..., 1., 1., 0.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)>, 'input_mask': <tf.Tensor: shape=(128,), dtype=int32, numpy=
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [176]:
tf.range(tf.shape(item[0]['input_ids']))

<tf.Tensor: shape=(13,), dtype=int32, numpy=array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12], dtype=int32)>

In [111]:
all_modes = []
all_probs = []
for (batch_inputs, batch_labels) in train_dataset:
    #print(batch_inputs, batch_labels)
    all_modes.append(batch_labels['mode'])
    all_probs.append(batch_labels['prob'])

In [112]:
all_modes

[<tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tensor: shape=(), dtype=string, numpy=b'prefix'>,
 <tf.Tenso

In [101]:
prob

<tf.Tensor: shape=(), dtype=float32, numpy=0.43081582>

In [113]:
all_probs

[<tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.5386268>,
 <tf.Tensor: shape=(), dtype=float32, numpy=0.53

In [6]:
# Normal attention masking

input_mask = tf.constant([[1, 1, 1, 0, 0], 
              [1, 1, 1, 1, 1], 
              [1, 1, 1, 1, 1]
             ])
#attention_mask = SelfAttentionMask()([embeddings, input_mask])

In [10]:
to_mask

<tf.Tensor: shape=(3, 1, 5), dtype=int32, numpy=
array([[[1, 1, 1, 0, 0]],

       [[1, 1, 1, 1, 1]],

       [[1, 1, 1, 1, 1]]], dtype=int32)>

In [12]:
broadcast_ones

<tf.Tensor: shape=(3, 5, 1), dtype=int32, numpy=
array([[[1],
        [1],
        [1],
        [1],
        [1]],

       [[1],
        [1],
        [1],
        [1],
        [1]],

       [[1],
        [1],
        [1],
        [1],
        [1]]], dtype=int32)>

In [14]:
mask

<tf.Tensor: shape=(3, 5, 5), dtype=int32, numpy=
array([[[1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0]],

       [[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]],

       [[1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1]]], dtype=int32)>