In [2]:
import numpy as np

batch_size = 4
seq_len = 10
n_vocab = 200

# synthetic sequences of tokens
sequences = np.random.randint(0, n_vocab, size=(batch_size, seq_len))

# synthetic logits matrix, which I would obtain by running my network over the sequences
logits = np.random.normal(size=(batch_size, seq_len, n_vocab)) 

# select the appropriate logits from the matrix using the sequences tokens as indices
# question: is there a more numpyic way of doing thing? 
seq_scores = np.array([
    [logits[batch, step, token]for step, token in enumerate(sequence)] 
    for batch, sequence in enumerate(sequences)
])

seq_scores[0,0] == logits[0,0,sequences[0,0]] # True
seq_scores.shape == sequences.shape # True

True

## Optimization on one sequence.

In [11]:
[logits[0, i, token] for i, token in enumerate(sequences[0])] \
    == logits[0, list(range(seq_len)), sequences[0]] 

array([ True,  True,  True,  True,  True,  True,  True,  True,  True,
        True])

In [8]:
%timeit [logits[0, i, token] for i, token in enumerate(sequences[0])]

3.28 µs ± 27.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [9]:
%timeit logits[0, list(range(seq_len)), sequences[0]] 

3.31 µs ± 45.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


## Optimization with batches.

In [105]:
%timeit np.repeat(list(range(batch_size)), seq_len).reshape(-1, seq_len)

6.37 µs ± 241 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [109]:
%timeit  [[x] * seq_len for x in range(batch_size)]

987 ns ± 19.1 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [122]:
%timeit list(zip(*seq_len * [list(range(batch_size))]))

1.51 µs ± 51.6 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [123]:
list(zip(*seq_len * [list(range(batch_size))]))

[(0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
 (2, 2, 2, 2, 2, 2, 2, 2, 2, 2),
 (3, 3, 3, 3, 3, 3, 3, 3, 3, 3)]

In [111]:
[[x] * seq_len for x in range(batch_size)]

[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
 [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
 [3, 3, 3, 3, 3, 3, 3, 3, 3, 3]]

In [97]:
i1 = np.repeat(list(range(batch_size)), seq_len).reshape(-1, seq_len)
i1

array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
       [3, 3, 3, 3, 3, 3, 3, 3, 3, 3]])

In [98]:
i2 = batch_size * [list(range(seq_len))]
i2

[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]

In [85]:
sequences

array([[187,  89, 153, 101,   3,  14, 157,  25,  71,  93],
       [129, 109, 128,  46,  16,  77,  75, 126, 127, 142],
       [ 68,  54, 146, 163, 103, 188,  38,   4,  60, 153],
       [ 29, 183, 178,  31, 186,  73,  61, 143, 196, 149]])

### Let's do it!

In [99]:
logits[i1,i2, sequences].shape

(4, 10)

In [101]:
seq_scores2 = logits[i1,i2, sequences]

In [102]:
seq_scores == seq_scores2

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True]])

### Tests.

In [104]:
%timeit seq_scores = np.array([[logits[batch, step, token]for step, token in enumerate(sequence)] for batch, sequence in enumerate(sequences)])

23.3 µs ± 393 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [107]:
%timeit logits[np.repeat(list(range(batch_size)), seq_len).reshape(-1, seq_len), batch_size * [list(range(seq_len))], sequences]

15 µs ± 553 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [112]:
%timeit logits[[[x] * seq_len for x in range(batch_size)], batch_size * [list(range(seq_len))], sequences]

13.3 µs ± 257 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


---

## Similar business but with tf

In [1]:
import tensorflow as tf
tf.enable_eager_execution() # for tf 1.14

In [11]:
tf.__version__

'1.14.0'

In [3]:
t = tf.cast(tf.random.normal(shape=(5,2,10))*100, tf.int32) 

In [4]:
t

<tf.Tensor: id=8, shape=(5, 2, 10), dtype=int32, numpy=
array([[[  -8,  -17,   13,  199,  -56,   29, -142, -229, -120,  105],
        [ -88,   63, -160,   -1,   61,  -68, -190, -199,    0,  -19]],

       [[ 196, -163,  162,  195,  117,  -71,  -69,    1,  179,   19],
        [-106,  -61,  -99,   54,  178,  -35,   36,  -95,   -4, -155]],

       [[ -86,    0,   12, -111,   27, -145,  107,  136,   37,  202],
        [-100,    0,  112, -168,   -7,  -29,   82,   35,   90,   31]],

       [[ 100,   62, -175,   35,  165,  -20, -149,    5,   21, -107],
        [  89,   58,   38,  150,  -80,   45,  217,   76,   81,   53]],

       [[ -70,   91,   90,  -42,  227,  235,  151,  -13,  -89,   20],
        [   3,  -11,  -38,  102,    8,   13,  120,  -25, -217,  140]]],
      dtype=int32)>

In [5]:
ind = tf.cast(tf.random.uniform(shape=(5,2,1))*10, tf.int32)

In [6]:
ind

<tf.Tensor: id=19, shape=(5, 2, 1), dtype=int32, numpy=
array([[[1],
        [7]],

       [[5],
        [4]],

       [[8],
        [7]],

       [[6],
        [5]],

       [[9],
        [8]]], dtype=int32)>

In [58]:
dim1 = tf.transpose(tf.reshape(tf.tile(tf.range(5), [2]), (2,5,1)), [1,0,2]) 
dim1

<tf.Tensor: id=406, shape=(5, 2, 1), dtype=int32, numpy=
array([[[0],
        [0]],

       [[1],
        [1]],

       [[2],
        [2]],

       [[3],
        [3]],

       [[4],
        [4]]], dtype=int32)>

In [33]:
dim2 = tf.reshape(tf.tile(tf.range(2), [5]), (5,2))[..., None]
dim2

<tf.Tensor: id=165, shape=(5, 2, 1), dtype=int32, numpy=
array([[[0],
        [1]],

       [[0],
        [1]],

       [[0],
        [1]],

       [[0],
        [1]],

       [[0],
        [1]]], dtype=int32)>

In [59]:
indz = tf.concat([dim1, dim2, ind], axis=-1)
indz

<tf.Tensor: id=409, shape=(5, 2, 3), dtype=int32, numpy=
array([[[0, 0, 1],
        [0, 1, 7]],

       [[1, 0, 5],
        [1, 1, 4]],

       [[2, 0, 8],
        [2, 1, 7]],

       [[3, 0, 6],
        [3, 1, 5]],

       [[4, 0, 9],
        [4, 1, 8]]], dtype=int32)>

In [60]:
t

<tf.Tensor: id=8, shape=(5, 2, 10), dtype=int32, numpy=
array([[[  -8,  -17,   13,  199,  -56,   29, -142, -229, -120,  105],
        [ -88,   63, -160,   -1,   61,  -68, -190, -199,    0,  -19]],

       [[ 196, -163,  162,  195,  117,  -71,  -69,    1,  179,   19],
        [-106,  -61,  -99,   54,  178,  -35,   36,  -95,   -4, -155]],

       [[ -86,    0,   12, -111,   27, -145,  107,  136,   37,  202],
        [-100,    0,  112, -168,   -7,  -29,   82,   35,   90,   31]],

       [[ 100,   62, -175,   35,  165,  -20, -149,    5,   21, -107],
        [  89,   58,   38,  150,  -80,   45,  217,   76,   81,   53]],

       [[ -70,   91,   90,  -42,  227,  235,  151,  -13,  -89,   20],
        [   3,  -11,  -38,  102,    8,   13,  120,  -25, -217,  140]]],
      dtype=int32)>

In [61]:
tf.gather_nd(t, indz)

<tf.Tensor: id=412, shape=(5, 2), dtype=int32, numpy=
array([[ -17, -199],
       [ -71,  178],
       [  37,   35],
       [-149,   45],
       [  20, -217]], dtype=int32)>