In [1]:
import numpy as np
import tensorflow as tf

# Extract N rows from (N,R,C) tensor with row indices

You have a batch of N ```(R, C)``` matrices, e.g. (N=2, R=3, C=4).

```
[
    # N=0
    [
        [ 0,  1,  2,  3],
        [ 4,  5,  6,  7],    # <--- row index 1
        [ 8,  9, 10, 11]
    ],
    # N=1
    [
        [12, 13, 14, 15],    # <--- row index 0
        [16, 17, 18, 19],
        [20, 21, 22, 23]
    ]
]
```

You want to extract:
1. row with index=1 from the first batch (N=0)
2. row with index=0 from the second batch (N=1)

# Approach
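Encode each row index as a **one-hot** vector of length R, then contract it against the row axis with `einsum`; the contraction keeps only the selected row of each matrix.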

```einsum("nrc,nr->nc", x:shape(N,R,C), OHE=[[0,1,0]]:shape(N,R))``` will extract row index=1 (2nd row).

In [2]:
# Example dimensions: one batch entry (N) holding a single R x C matrix.
N, R, C = 1, 3, 4

In [3]:
# Toy input of shape (1, 3, 4): a single batch entry containing one 3x4 matrix.
# The values are built as float32 in NumPy first, then wrapped as a TF tensor.
x = tf.constant(
    np.array(
        [[
            [0, 1, 2, 3],
            [4, 5, 6, 7],  # <--- extract row index=1
            [8, 9, 10, 11],
        ]],
        dtype=np.float32,
    )
)
print(x)

tf.Tensor(
[[[ 0.  1.  2.  3.]
  [ 4.  5.  6.  7.]
  [ 8.  9. 10. 11.]]], shape=(1, 3, 4), dtype=float32)


One-hot encode (OHE) the row index to extract

In [4]:
# One-hot row selector of shape (N, R): a 1 at position 1 selects the 2nd row.
indices = tf.one_hot([1], R)
print(indices)

tf.Tensor([[0. 1. 0.]], shape=(1, 3), dtype=float32)


# Einsum

Extract the row with index=1

In [5]:
# Contract the one-hot selector (N, R) against the row axis of x (N, R, C):
# the result (N, C) keeps only the selected row of each batch entry.
# Lowercase subscripts match the equation used in the markdown and in the later example.
print(tf.einsum("nrc,nr->nc", x, indices))

tf.Tensor([[4. 5. 6. 7.]], shape=(1, 4), dtype=float32)


---
# Examples

From a tensor of shape ```(N=2,M=2,R=3,C=4)```, extract one row from each ```(R, C)``` matrix, using the row indices (0,2,1,0) listed in (n, m) order.

```
[
    # N = 0
    [
        # M = 0
        [
            [ 0,  1,  2,  3],    # <--- index 0
            [ 4,  5,  6,  7],
            [ 8,  9, 10, 11]
        ],
        # M = 1
        [
            [12, 13, 14, 15],
            [16, 17, 18, 19],
            [20, 21, 22, 23]     # <--- index 2
        ]
    ],
    # N = 1
    [
        # M = 0
        [
            [24, 25, 26, 27],
            [28, 29, 30, 31],    # <--- index 1
            [32, 33, 34, 35]
        ],
        # M = 1
        [
            [36, 37, 38, 39],    # <--- index 0
            [40, 41, 42, 43],
            [44, 45, 46, 47]
        ]
    ]
]
```

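Because the einsum below first flattens the (N, M) axes into a single batch axis, the result shape will be ```(N*M=4, C=4)```: one row per (n, m) matrix, in row-major (n, m) order. Reshape it to ```(N, M, C)``` if the batch structure is needed.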

## X

Target tensor ```X:shape(N,M,R,C)``` from which to extract rows.

In [6]:
N = 2  # Size of the first batch axis
M = 2  # Number of (R, C) matrices per batch entry (second batch axis)
R = 3  # Number of rows
C = 4  # Number of columns in a row=(x, y, h, w)

# Build x with the declared (N, M, R, C) layout. The original code used N for both
# batch axes, which only worked by coincidence because N == M here.
x = tf.constant(
    np.arange(N * M * R * C).reshape(N, M, R, C),
    dtype=np.float32,
)
print(x)

tf.Tensor(
[[[[ 0.  1.  2.  3.]
   [ 4.  5.  6.  7.]
   [ 8.  9. 10. 11.]]

  [[12. 13. 14. 15.]
   [16. 17. 18. 19.]
   [20. 21. 22. 23.]]]


 [[[24. 25. 26. 27.]
   [28. 29. 30. 31.]
   [32. 33. 34. 35.]]

  [[36. 37. 38. 39.]
   [40. 41. 42. 43.]
   [44. 45. 46. 47.]]]], shape=(2, 2, 3, 4), dtype=float32)


### OHE Indices

In [7]:
# One one-hot row selector per (n, m) matrix, in row-major (n, m) order.
# depth must equal the number of rows R, so use R rather than a hard-coded 3.
indices = tf.one_hot(indices=(0, 2, 1, 0), depth=R)
assert indices.shape == (N * M, R), indices.shape
print(indices)

tf.Tensor(
[[1. 0. 0.]
 [0. 0. 1.]
 [0. 1. 0.]
 [1. 0. 0.]], shape=(4, 3), dtype=float32)


### Einsum to extract

In [8]:
# Collapse the (N, M) batch axes into one axis of N*M matrices, then apply the
# same one-hot contraction as before: (N*M, R, C) x (N*M, R) -> (N*M, C).
print(tf.einsum(
    "nrc,nr->nc",
    tf.reshape(x, shape=(-1, R, C)),
    indices,
))

tf.Tensor(
[[ 0.  1.  2.  3.]
 [20. 21. 22. 23.]
 [28. 29. 30. 31.]
 [36. 37. 38. 39.]], shape=(4, 4), dtype=float32)


---
# Function


In [9]:
import sys
import os
import pathlib

## PYTHONPATH

In [10]:
# Absolute path of the project's lib directory, resolved relative to the notebook's
# current working directory.
path_to_lib: str = str((pathlib.Path.cwd() / "../../lib").resolve())
# Guard the append so re-running this cell does not add duplicate sys.path entries.
if path_to_lib not in sys.path:
    sys.path.append(path_to_lib)

In [48]:
# Reload imported modules automatically before each cell runs, so edits to the
# util_tf library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Project-local helper; importable because the lib directory was appended to sys.path.
from util_tf.tensor import (
    take_rows_by_indices
)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [49]:
N = 2  # Size of each of the two leading batch axes -> x has shape (N, N, M, D)
M = 3  # Number of rows per matrix
D = 4  # Number of columns per row
x = tf.constant(
    np.arange(N*N*M*D).reshape(N,N,M,D),
    dtype=np.float32
)
# Seeded generator so the chosen row indices (and every output below) are
# reproducible under Restart & Run All.
rng = np.random.default_rng(seed=42)
# One row index in [0, M) for each of the N*N matrices, in row-major order.
indices = rng.integers(low=0, high=M, size=N*N)
print(x)
print(indices)

tf.Tensor(
[[[[ 0.  1.  2.  3.]
   [ 4.  5.  6.  7.]
   [ 8.  9. 10. 11.]]

  [[12. 13. 14. 15.]
   [16. 17. 18. 19.]
   [20. 21. 22. 23.]]]


 [[[24. 25. 26. 27.]
   [28. 29. 30. 31.]
   [32. 33. 34. 35.]]

  [[36. 37. 38. 39.]
   [40. 41. 42. 43.]
   [44. 45. 46. 47.]]]], shape=(2, 2, 3, 4), dtype=float32)
[1 1 0 2]


In [50]:
rows = take_rows_by_indices(X=x, M=M, D=D, indices=indices)
# Cross-check the library helper against TensorFlow's built-in batched gather:
# for each flattened matrix i, pick row indices[i] -> shape (N*N, D).
expected = tf.gather(tf.reshape(x, (-1, M, D)), indices, batch_dims=1)
np.testing.assert_array_equal(np.asarray(rows), expected.numpy())
rows

<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[ 4.,  5.,  6.,  7.],
       [16., 17., 18., 19.],
       [24., 25., 26., 27.],
       [44., 45., 46., 47.]], dtype=float32)>