In [1]:
import numpy as np
import tensorflow as tf

# NOTE
TensorFlow indexing/slicing is basically **NOT compatible** with NumPy-style indexing using ```[...]```. 

* [How a seemingly straightforward operation in NumPy turns into a nightmare with TensorFlow ](https://towardsdatascience.com/how-to-replace-values-by-index-in-a-tensor-with-tensorflow-2-0-510994fe6c5f)

Re-learn the TensorFlow way of extracting and updating slices. Avoid half-measure slice notation, e.g. ```X[1:None:2]```, to prevent confusion.

* [Introduction to tensor slicing](https://www.tensorflow.org/guide/tensor_slicing)

## Slicing
* [tf.gather](https://www.tensorflow.org/api_docs/python/tf/gather) - Extract slices at **indices** along the **axis**.
```
tf.gather(
    params, indices, validate_indices=None, axis=None, batch_dims=0, name=None
)
```
* [tf.gather_nd](https://www.tensorflow.org/api_docs/python/tf/gather_nd) - Extract slices at junctions located at **indices**.
```
tf.gather_nd(
    params, indices, batch_dims=0, name=None
)
```

## Updating

* [tf.scatter_nd](https://www.tensorflow.org/api_docs/python/tf/scatter_nd) - Does **NOT** update an existing tensor; it creates a **new zeros** tensor of **shape** and *initializes* it with **updates** at **indices**.
```
tf.scatter_nd(
    indices, updates, shape, name=None
)
```
* [tf.tensor_scatter_nd_update](https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) - Returns a **new** tensor: a copy of an **existing** tensor with **updates** applied at **indices** (the input tensor itself is not modified).
```
tf.tensor_scatter_nd_update(
    tensor, indices, updates, name=None
)
```

---

TensorFlow also implements a subset of the NumPy API, but it is better not to rely on such half-way measures.

* [NumPy API on TensorFlow](https://www.tensorflow.org/guide/tf_numpy)

> TensorFlow implements a subset of the NumPy API, available as tf.experimental.numpy.

* [tf.tensor_scatter_nd_update](http://localhost:8888/notebooks/indexing/tf_slicing_update.ipynb)

```
tf.tensor_scatter_nd_update(
    tensor, indices, updates, name=None
)
```

# Data X

In [2]:
# 5x5 int32 matrix holding 0..24 in row-major order, wrapped in a tf.Variable.
X = tf.Variable(
    tf.reshape(tf.range(0, 25, dtype=tf.int32), (5, 5))
)
print(X)

<tf.Variable 'Variable:0' shape=(5, 5) dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]], dtype=int32)>


---
# tf.gather - Row or Column extraction

### Extract rows

In [3]:
# Rows 1 and 3 (axis=None gives the same result as axis=0 here).
tf.gather(params=X, indices=[1, 3], axis=0)

<tf.Tensor: shape=(2, 5), dtype=int32, numpy=
array([[ 5,  6,  7,  8,  9],
       [15, 16, 17, 18, 19]], dtype=int32)>

In [4]:
# Same rows via basic slicing: first with an explicit Python slice object,
# then with the equivalent start:stop:step shorthand.
every_other_row = slice(1, None, 2)
print(X[every_other_row])
print(X[1::2])

tf.Tensor(
[[ 5  6  7  8  9]
 [15 16 17 18 19]], shape=(2, 5), dtype=int32)
tf.Tensor(
[[ 5  6  7  8  9]
 [15 16 17 18 19]], shape=(2, 5), dtype=int32)


### Extract columns

In [5]:
# Columns 1 and 3: the same gather, taken along axis 1.
tf.gather(params=X, indices=[1, 3], axis=1)

<tf.Tensor: shape=(5, 2), dtype=int32, numpy=
array([[ 1,  3],
       [ 6,  8],
       [11, 13],
       [16, 18],
       [21, 23]], dtype=int32)>

# tf.gather_nd - extract junctions

In [6]:
# Each entry is a full (row, col) coordinate: here (1, 2) and (3, 4).
indices = [[1, 2], [3, 4]]
tf.gather_nd(params=X, indices=indices).numpy()

array([ 7, 19], dtype=int32)

---

# tf.scatter_nd - Initialize a new zeros tensor with values at indices


In [7]:
# One index per update row: rows 1 and 3 of the 5x5 result each receive a row
# of ones; every other row stays zero.
indices = [[1], [3]]
updates = tf.constant(np.ones((2, 5)))  # NumPy's default float64 carries into the result
tf.scatter_nd(indices, updates, shape=(5, 5))

<tf.Tensor: shape=(5, 5), dtype=float64, numpy=
array([[0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0.]])>

# tf.tensor_scatter_nd_update - return a copy with values updated at indices


In [8]:
# Rebuild X so this cell starts from the original 0..24 matrix.
X = tf.Variable(
    tf.reshape(tf.range(0, 25, dtype=tf.int32), (5, 5))
)
print(f"{X}\n")

# (row, col) coordinates to overwrite, paired element-wise with `updates`.
indices = [[1, 2], [2, 3]]
updates = [0, 0]
print(f"Update {tf.gather_nd(X, indices)} with {updates}\n")

# The call's result is a new tf.Tensor with the replacements applied.
tf.tensor_scatter_nd_update(tensor=X, indices=indices, updates=updates)

<tf.Variable 'Variable:0' shape=(5, 5) dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]], dtype=int32)>

Update [ 7 13] with [0, 0]



<tf.Tensor: shape=(5, 5), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  0,  8,  9],
       [10, 11, 12,  0, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]], dtype=int32)>