In [1]:
import numpy as np
import tensorflow as tf

In [2]:
B = 6
T = 4
H = 5

In [3]:
def make_batch(B, T, H):
    b = np.arange(1, B + 1).reshape(-1, 1)
    t = np.arange(1, T + 1).reshape(1, -1) / 10
    h = (np.arange(1, H + 1) / 100)
    batch = (np.expand_dims(b + t, -1) + h).astype(np.float32)
    lengths = make_lengths(B, T)
    #for i, l in enumerate(lengths):
    #    batch[i, l:] = 0
    return batch, lengths

def make_lengths(B, T):
    lengths = np.full((B,), T)
    idx = np.random.randint(0, B, size=(B // 2 + 1))
    lengths[idx] = np.random.randint(1, T, size=(idx.shape))
    return lengths

In [4]:
data, lengths = make_batch(B, T, H)

In [5]:
lengths

array([4, 2, 4, 2, 3, 4])

In [6]:
batch = tf.convert_to_tensor(data)
lengths = tf.convert_to_tensor(lengths)
mask = tf.sequence_mask(lengths, tf.reduce_max(lengths))

In [7]:
# This doesn't work when the mask has holes in it
def dense_masked_select(tensor, mask):
    B = tf.shape(mask)[0]
    T = tf.shape(mask)[1]
    all_indices = tf.transpose(tf.unravel_index(indices=tf.range(B * T), dims=[B, T]))
    indices = tf.where(tf.reshape(tf.cast(mask, tf.bool), (-1, 1)), all_indices, tf.zeros_like(all_indices))
    selected = tf.gather_nd(tensor, indices)
    dense = tf.reshape(selected, (B, T, -1))
    return tf.multiply(dense, tf.expand_dims(tf.cast(mask, dense.dtype), -1))

In [8]:
d = dense_masked_select(batch, mask)

In [9]:
d.shape

TensorShape([6, 4, 5])

In [10]:
d

<tf.Tensor: shape=(6, 4, 5), dtype=float32, numpy=
array([[[1.11, 1.12, 1.13, 1.14, 1.15],
        [1.21, 1.22, 1.23, 1.24, 1.25],
        [1.31, 1.32, 1.33, 1.34, 1.35],
        [1.41, 1.42, 1.43, 1.44, 1.45]],

       [[2.11, 2.12, 2.13, 2.14, 2.15],
        [2.21, 2.22, 2.23, 2.24, 2.25],
        [0.  , 0.  , 0.  , 0.  , 0.  ],
        [0.  , 0.  , 0.  , 0.  , 0.  ]],

       [[3.11, 3.12, 3.13, 3.14, 3.15],
        [3.21, 3.22, 3.23, 3.24, 3.25],
        [3.31, 3.32, 3.33, 3.34, 3.35],
        [3.41, 3.42, 3.43, 3.44, 3.45]],

       [[4.11, 4.12, 4.13, 4.14, 4.15],
        [4.21, 4.22, 4.23, 4.24, 4.25],
        [0.  , 0.  , 0.  , 0.  , 0.  ],
        [0.  , 0.  , 0.  , 0.  , 0.  ]],

       [[5.11, 5.12, 5.13, 5.14, 5.15],
        [5.21, 5.22, 5.23, 5.24, 5.25],
        [5.31, 5.32, 5.33, 5.34, 5.35],
        [0.  , 0.  , 0.  , 0.  , 0.  ]],

       [[6.11, 6.12, 6.13, 6.14, 6.15],
        [6.21, 6.22, 6.23, 6.24, 6.25],
        [6.31, 6.32, 6.33, 6.34, 6.35],
        [6.41, 6.42

In [11]:
less_lengths = tf.maximum(lengths - 1, tf.ones_like(lengths))
new_mask = tf.sequence_mask(less_lengths, tf.reduce_max(less_lengths))

In [12]:
d2 = dense_masked_select(batch, new_mask)

In [13]:
d2.shape

TensorShape([6, 3, 5])

In [14]:
d2

<tf.Tensor: shape=(6, 3, 5), dtype=float32, numpy=
array([[[1.11, 1.12, 1.13, 1.14, 1.15],
        [1.21, 1.22, 1.23, 1.24, 1.25],
        [1.31, 1.32, 1.33, 1.34, 1.35]],

       [[2.11, 2.12, 2.13, 2.14, 2.15],
        [0.  , 0.  , 0.  , 0.  , 0.  ],
        [0.  , 0.  , 0.  , 0.  , 0.  ]],

       [[3.11, 3.12, 3.13, 3.14, 3.15],
        [3.21, 3.22, 3.23, 3.24, 3.25],
        [3.31, 3.32, 3.33, 3.34, 3.35]],

       [[4.11, 4.12, 4.13, 4.14, 4.15],
        [0.  , 0.  , 0.  , 0.  , 0.  ],
        [0.  , 0.  , 0.  , 0.  , 0.  ]],

       [[5.11, 5.12, 5.13, 5.14, 5.15],
        [5.21, 5.22, 5.23, 5.24, 5.25],
        [0.  , 0.  , 0.  , 0.  , 0.  ]],

       [[6.11, 6.12, 6.13, 6.14, 6.15],
        [6.21, 6.22, 6.23, 6.24, 6.25],
        [6.31, 6.32, 6.33, 6.34, 6.35]]], dtype=float32)>

In [15]:
new_mask = tf.convert_to_tensor(np.array(
    [
        [1, 0, 1, 0],
        [0, 1, 1, 1],
        [1, 1, 0, 0],
        [1, 1, 1, 1],
        [0, 1, 0, 0],
        [0, 1, 1, 1],
    ], dtype=np.bool
))

In [16]:
d = dense_masked_select(batch, new_mask)
d

<tf.Tensor: shape=(6, 4, 5), dtype=float32, numpy=
array([[[1.11, 1.12, 1.13, 1.14, 1.15],
        [0.  , 0.  , 0.  , 0.  , 0.  ],
        [1.31, 1.32, 1.33, 1.34, 1.35],
        [0.  , 0.  , 0.  , 0.  , 0.  ]],

       [[0.  , 0.  , 0.  , 0.  , 0.  ],
        [2.21, 2.22, 2.23, 2.24, 2.25],
        [2.31, 2.32, 2.33, 2.34, 2.35],
        [2.41, 2.42, 2.43, 2.44, 2.45]],

       [[3.11, 3.12, 3.13, 3.14, 3.15],
        [3.21, 3.22, 3.23, 3.24, 3.25],
        [0.  , 0.  , 0.  , 0.  , 0.  ],
        [0.  , 0.  , 0.  , 0.  , 0.  ]],

       [[4.11, 4.12, 4.13, 4.14, 4.15],
        [4.21, 4.22, 4.23, 4.24, 4.25],
        [4.31, 4.32, 4.33, 4.34, 4.35],
        [4.41, 4.42, 4.43, 4.44, 4.45]],

       [[0.  , 0.  , 0.  , 0.  , 0.  ],
        [5.21, 5.22, 5.23, 5.24, 5.25],
        [0.  , 0.  , 0.  , 0.  , 0.  ],
        [0.  , 0.  , 0.  , 0.  , 0.  ]],

       [[0.  , 0.  , 0.  , 0.  , 0.  ],
        [6.21, 6.22, 6.23, 6.24, 6.25],
        [6.31, 6.32, 6.33, 6.34, 6.35],
        [6.41, 6.42

In [17]:
indices = tf.where(new_mask)

In [18]:
indices

<tf.Tensor: shape=(15, 2), dtype=int64, numpy=
array([[0, 0],
       [0, 2],
       [1, 1],
       [1, 2],
       [1, 3],
       [2, 0],
       [2, 1],
       [3, 0],
       [3, 1],
       [3, 2],
       [3, 3],
       [4, 1],
       [5, 1],
       [5, 2],
       [5, 3]])>

In [19]:
# This doesn't work because the calculated index just puts it back where it is supposed to go in the 2D
flat_idx = indices[:, 0] * T + indices[:, 1]
tf.scatter_nd(tf.reshape(flat_idx, (-1, 1)), indices, shape=(B * T, 2))

<tf.Tensor: shape=(24, 2), dtype=int64, numpy=
array([[0, 0],
       [0, 0],
       [0, 2],
       [0, 0],
       [0, 0],
       [1, 1],
       [1, 2],
       [1, 3],
       [2, 0],
       [2, 1],
       [0, 0],
       [0, 0],
       [3, 0],
       [3, 1],
       [3, 2],
       [3, 3],
       [0, 0],
       [4, 1],
       [0, 0],
       [0, 0],
       [0, 0],
       [5, 1],
       [5, 2],
       [5, 3]])>

In [20]:
@tf.function
def to_dense_idx(indices, T):
    idx = tf.TensorArray(dtype=tf.int64, size=tf.shape(indices)[0], dynamic_size=False)
    prev = tf.cast(-1, tf.int64)
    j = tf.cast(0, tf.int64)
    T = tf.cast(T, tf.int64)
    for i in range(tf.shape(indices)[0]):
        if indices[i][0] == prev:
            j += 1
        else:
            prev = indices[i][0]
            j = tf.cast(0, tf.int64)
        idx = idx.write(i, prev * T + j)
    return idx.stack()

In [21]:
to_dense_idx(indices, T)

<tf.Tensor: shape=(15,), dtype=int64, numpy=array([ 0,  1,  4,  5,  6,  8,  9, 12, 13, 14, 15, 16, 20, 21, 22])>

In [22]:
def to_dense_idx2(indices, T):
    _, mapping, count = tf.unique_with_counts(indices[:, 0])
    batch_offsets = indices[:, 0] * T
    prebatch_counts = tf.gather(count, mapping)
    print(batch_offsets)
    print(prebatch_counts)
    batch_idx = indices[:, 1]
    print(batch_idx)
    
    

In [23]:
to_dense_idx2(indices, T)

tf.Tensor([ 0  0  4  4  4  8  8 12 12 12 12 16 20 20 20], shape=(15,), dtype=int64)
tf.Tensor([2 2 3 3 3 2 2 4 4 4 4 1 3 3 3], shape=(15,), dtype=int32)
tf.Tensor([0 2 1 2 3 0 1 0 1 2 3 1 1 2 3], shape=(15,), dtype=int64)


In [24]:
tf.scatter_nd(tf.reshape(to_dense_idx(indices, T), (-1, 1)), indices, shape=(B * T, 2))

<tf.Tensor: shape=(24, 2), dtype=int64, numpy=
array([[0, 0],
       [0, 2],
       [0, 0],
       [0, 0],
       [1, 1],
       [1, 2],
       [1, 3],
       [0, 0],
       [2, 0],
       [2, 1],
       [0, 0],
       [0, 0],
       [3, 0],
       [3, 1],
       [3, 2],
       [3, 3],
       [4, 1],
       [0, 0],
       [0, 0],
       [0, 0],
       [5, 1],
       [5, 2],
       [5, 3],
       [0, 0]])>

In [25]:
def span_select(tensor, mask):
    B = tf.shape(mask)[0]
    T = tf.shape(mask)[1]
    
    indices = tf.where(mask)
    dense_indices = to_dense_idx(indices, T)
    indices = tf.scatter_nd(tf.reshape(dense_indices, (-1, 1)), indices, shape=(B * T, 2))

    selected = tf.gather_nd(tensor, indices)
    dense = tf.reshape(selected, (B, T, -1))
    dense_lengths = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1)
    dense_mask = tf.sequence_mask(dense_lengths, tf.shape(dense)[1])
    return tf.multiply(dense, tf.expand_dims(tf.cast(dense_mask, dense.dtype), -1)), dense_lengths

In [26]:
span_select(batch, new_mask)

(<tf.Tensor: shape=(6, 4, 5), dtype=float32, numpy=
 array([[[1.11, 1.12, 1.13, 1.14, 1.15],
         [1.31, 1.32, 1.33, 1.34, 1.35],
         [0.  , 0.  , 0.  , 0.  , 0.  ],
         [0.  , 0.  , 0.  , 0.  , 0.  ]],
 
        [[2.21, 2.22, 2.23, 2.24, 2.25],
         [2.31, 2.32, 2.33, 2.34, 2.35],
         [2.41, 2.42, 2.43, 2.44, 2.45],
         [0.  , 0.  , 0.  , 0.  , 0.  ]],
 
        [[3.11, 3.12, 3.13, 3.14, 3.15],
         [3.21, 3.22, 3.23, 3.24, 3.25],
         [0.  , 0.  , 0.  , 0.  , 0.  ],
         [0.  , 0.  , 0.  , 0.  , 0.  ]],
 
        [[4.11, 4.12, 4.13, 4.14, 4.15],
         [4.21, 4.22, 4.23, 4.24, 4.25],
         [4.31, 4.32, 4.33, 4.34, 4.35],
         [4.41, 4.42, 4.43, 4.44, 4.45]],
 
        [[5.21, 5.22, 5.23, 5.24, 5.25],
         [0.  , 0.  , 0.  , 0.  , 0.  ],
         [0.  , 0.  , 0.  , 0.  , 0.  ],
         [0.  , 0.  , 0.  , 0.  , 0.  ]],
 
        [[6.21, 6.22, 6.23, 6.24, 6.25],
         [6.31, 6.32, 6.33, 6.34, 6.35],
         [6.41, 6.42, 6.43, 6.4

In [27]:
example = indices.numpy()

def to_dense_idx_imperative(indices, T):
    dense = []
    prev = -1
    j = 0
    for i, idx in enumerate(indices):
        if idx[0] == prev:
            j += 1
        else:
            prev = idx[0]
            j = 0
        dense.append(prev * T + j)
    return np.array(dense)

print(example)
print(to_dense_idx_imperative(example, T))

[[0 0]
 [0 2]
 [1 1]
 [1 2]
 [1 3]
 [2 0]
 [2 1]
 [3 0]
 [3 1]
 [3 2]
 [3 3]
 [4 1]
 [5 1]
 [5 2]
 [5 3]]
[ 0  1  4  5  6  8  9 12 13 14 15 16 20 21 22]


In [28]:
def unique_prefix_sum(idx):
    ps = []
    prev = -1
    count = 0
    for i in idx:
        if prev == i:
            count += 1
        else:
            prev = i
            count = 0
        ps.append(count)
    return np.array(ps)

def to_dense_idx_vect(indices, T):
    uniq = unique_prefix_sum(indices[:, 0])
    batch_offset = indices[:, 0] * T
    return batch_offset + uniq

In [29]:
to_dense_idx_vect(example, T)

array([ 0,  1,  4,  5,  6,  8,  9, 12, 13, 14, 15, 16, 20, 21, 22])

In [30]:
new_mask

<tf.Tensor: shape=(6, 4), dtype=bool, numpy=
array([[ True, False,  True, False],
       [False,  True,  True,  True],
       [ True,  True, False, False],
       [ True,  True,  True,  True],
       [False,  True, False, False],
       [False,  True,  True,  True]])>

In [31]:
indices

<tf.Tensor: shape=(15, 2), dtype=int64, numpy=
array([[0, 0],
       [0, 2],
       [1, 1],
       [1, 2],
       [1, 3],
       [2, 0],
       [2, 1],
       [3, 0],
       [3, 1],
       [3, 2],
       [3, 3],
       [4, 1],
       [5, 1],
       [5, 2],
       [5, 3]])>