In [2]:
import numpy as np
import tensorflow as tf


# NOTE
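TensorFlow indexing/slicing is largely **NOT compatible** with NumPy-style indexing with ```[...]```. Basic slices can be read, but item assignment such as ```X[0, 1] = 100``` is not supported, and list-based (fancy) indexing does not behave as in NumPy.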

* [How a seemingly straightforward operation in NumPy turns into a nightmare with TensorFlow](https://towardsdatascience.com/how-to-replace-values-by-index-in-a-tensor-with-tensorflow-2-0-510994fe6c5f)

Instead, learn TensorFlow's own functions for extracting and updating slices. Avoid half-measures such as the slice notation ```X[1:None:2]``` to avoid confusion.

* [Introduction to tensor slicing](https://www.tensorflow.org/guide/tensor_slicing)
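A minimal sketch of the incompatibility described above: reading a basic slice works, but NumPy-style item assignment raises a ```TypeError``` because a ```tf.Tensor``` is immutable. The TensorFlow way is to build a new tensor with ```tf.tensor_scatter_nd_update```. The names ```X```, ```Y``` and ```assignment_failed``` are illustrative only.

```python
import tensorflow as tf

X = tf.reshape(tf.range(6), (2, 3))   # [[0 1 2]
                                      #  [3 4 5]]
first_row = X[0]                      # reading a basic slice works: [0 1 2]

# NumPy-style item assignment does not work: tf.Tensor has no __setitem__.
assignment_failed = False
try:
    X[0, 1] = 100
except TypeError:
    assignment_failed = True

# The TensorFlow way: build a NEW tensor with the element at (0, 1) replaced.
Y = tf.tensor_scatter_nd_update(X, indices=[[0, 1]], updates=[100])
print(assignment_failed)  # True
print(Y.numpy())
# [[  0 100   2]
#  [  3   4   5]]
```

Note that ```X``` itself is left unchanged; every update produces a new tensor.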

## Slicing
* [tf.gather](https://www.tensorflow.org/api_docs/python/tf/gather) - Extract slices at **indices** along the **axis**.
```
tf.gather(
    params, indices, validate_indices=None, axis=None, batch_dims=0, name=None
)
```
* [tf.gather_nd](https://www.tensorflow.org/api_docs/python/tf/gather_nd) - Extract slices at junctions located at **indices**.
```
tf.gather_nd(
    params, indices, batch_dims=0, name=None
)
```
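A small sketch contrasting the two calls above, assuming the same 5x5 tensor used later in this notebook. With ```tf.gather_nd```, the length of each index (the index depth) decides what is selected: one number picks a whole row, like ```tf.gather```, while a ```(row, col)``` pair picks a single element (a junction).

```python
import tensorflow as tf

X = tf.reshape(tf.range(25), (5, 5))

# tf.gather: each index picks a whole slice along `axis`.
rows = tf.gather(X, indices=[1, 3], axis=0)           # shape (2, 5)

# tf.gather_nd with index depth 1: each index [r] also picks a whole row.
rows_nd = tf.gather_nd(X, indices=[[1], [3]])         # shape (2, 5)

# tf.gather_nd with index depth 2: each index [r, c] picks one element.
elements = tf.gather_nd(X, indices=[[1, 2], [3, 4]])  # shape (2,)
print(elements.numpy())  # [ 7 19]
```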

## Updating

* [tf.scatter_nd](https://www.tensorflow.org/api_docs/python/tf/scatter_nd) - Does **NOT** update an existing tensor; it creates a **new** zero-filled tensor of **shape** and writes **updates** into it at **indices**.
```
tf.scatter_nd(
    indices, updates, shape, name=None
)
```
* [tf.tensor_scatter_nd_update](https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update) - Updates an **existing** tensor with **updates** at **indices**.
```
tf.tensor_scatter_nd_update(
    tensor, indices, updates, name=None
)
```
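A minimal sketch of the difference between the two update functions, on a hypothetical 1D ```base``` tensor. It also shows ```tf.Variable.scatter_nd_update```, which is not covered above: it modifies a ```tf.Variable``` in place instead of returning a new tensor.

```python
import tensorflow as tf

base = tf.constant([10, 20, 30, 40])
indices = [[1], [3]]
updates = tf.constant([-1, -2])

# tf.scatter_nd never looks at `base`: it starts from zeros of the given shape.
from_zeros = tf.scatter_nd(indices, updates, shape=[4])          # [ 0 -1  0 -2]

# tf.tensor_scatter_nd_update starts from an existing tensor and returns a copy.
from_base = tf.tensor_scatter_nd_update(base, indices, updates)  # [10 -1 30 -2]

# A tf.Variable can instead be modified in place.
v = tf.Variable(base)
v.scatter_nd_update(indices, updates)                            # v is now [10 -1 30 -2]
```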

---

TensorFlow also implements a subset of the NumPy API, but it is better not to use such half-way measures.

* [NumPy API on TensorFlow](https://www.tensorflow.org/guide/tf_numpy)

> TensorFlow implements a subset of the NumPy API, available as tf.experimental.numpy.


# Data X

In [2]:
X = tf.Variable(tf.reshape(tf.range(25, dtype=tf.int32), shape=(5,5)))
print(X)

<tf.Variable 'Variable:0' shape=(5, 5) dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]], dtype=int32)>


---
# tf.gather - Row or Column extraction

### Extract rows

In [3]:
tf.gather(X, indices=[1, 3], axis=0)  # Rows 1 and 3 (same result with axis=None)

<tf.Tensor: shape=(2, 5), dtype=int32, numpy=
array([[ 5,  6,  7,  8,  9],
       [15, 16, 17, 18, 19]], dtype=int32)>

In [4]:
# Same result by slicing with a Python slice object
print(X[slice(1, None, 2)])
print(X[1::2])

tf.Tensor(
[[ 5  6  7  8  9]
 [15 16 17 18 19]], shape=(2, 5), dtype=int32)
tf.Tensor(
[[ 5  6  7  8  9]
 [15 16 17 18 19]], shape=(2, 5), dtype=int32)
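Since the advice above is to prefer explicit TensorFlow calls over slice notation, here is a sketch of two equivalents of ```X[1::2]```: ```tf.gather``` with the row indices spelled out by ```tf.range```, and ```tf.strided_slice```, the op that ```[...]``` slicing is built on. The end value ```5``` assumes the 5-row ```X``` defined above.

```python
import tensorflow as tf

X = tf.reshape(tf.range(25, dtype=tf.int32), shape=(5, 5))

by_slice = X[1::2]                                   # NumPy-style slice
by_gather = tf.gather(X, tf.range(1, 5, 2), axis=0)  # explicit row indices [1, 3]
by_strided = tf.strided_slice(X, begin=[1], end=[5], strides=[2])  # rows 1, 3

print(by_gather.numpy())
# [[ 5  6  7  8  9]
#  [15 16 17 18 19]]
```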


### Extract columns

In [5]:
tf.gather(X, indices=[1, 3], axis=1)  # Columns 1 and 3

<tf.Tensor: shape=(5, 2), dtype=int32, numpy=
array([[ 1,  3],
       [ 6,  8],
       [11, 13],
       [16, 18],
       [21, 23]], dtype=int32)>

# tf.gather_nd - extract junctions

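```tf.gather``` can extract columns with the ```axis=``` argument. However, ```tf.scatter_nd``` and ```tf.tensor_scatter_nd_update``` have no ```axis=``` argument (indices always address the leading dimensions), so there is no direct way to update columns.

To update columns, build the updates in the transposed shape, call ```tf.scatter_nd```, and transpose the result back to the target shape.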

In [6]:
indices = [[1,2], [3, 4]]  # Junction (row, col)=(1,2) and (3,4)
tf.gather_nd(X, indices).numpy()

array([ 7, 19], dtype=int32)

---

# tf.scatter_nd - Initialize a new zeros tensor with values at indices

* [tf.scatter_nd](https://www.tensorflow.org/api_docs/python/tf/scatter_nd)

>``` 
> tf.scatter_nd(
>    indices, updates, shape, name=None
>)
>```
> Calling ```tf.scatter_nd(indices, updates, shape)``` is identical to calling ```tf.tensor_scatter_nd_add(tf.zeros(shape, updates.dtype), indices, updates)```.
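A quick sketch checking the identity quoted above on made-up data. It also shows a consequence of the ```tensor_scatter_nd_add``` form: when an index appears more than once, ```tf.scatter_nd``` sums the duplicate updates rather than keeping only one of them.

```python
import tensorflow as tf

indices = tf.constant([[1], [1], [3]])  # index [1] appears twice
updates = tf.constant([5, 7, 9])
shape = tf.constant([5])

a = tf.scatter_nd(indices, updates, shape)
b = tf.tensor_scatter_nd_add(tf.zeros(shape, updates.dtype), indices, updates)

print(a.numpy())  # [ 0 12  0  9  0]  position 1 receives 5 + 7
```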

In [12]:
indices = [
    [1],   # index depth = 1 (< rank 2): each index selects a whole row of shape (5,)
    [3]
]
updates = tf.constant(np.ones(shape=(2,5)))
print(f"updates=\n{updates}")

tf.scatter_nd(indices=indices, updates=updates, shape=(5,5))

updates=
[[1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1.]]


<tf.Tensor: shape=(5, 5), dtype=float64, numpy=
array([[0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0.]])>

In [17]:
indices = [
    [0, 0],   # index depth = 2 (= rank 2): each index selects a single element
    [4, 4]
]

updates = [1, 1]
print(f"updates=\n{updates}")

tf.scatter_nd(indices=indices, updates=updates, shape=(5,5))

updates=
[1, 1]


<tf.Tensor: shape=(5, 5), dtype=int32, numpy=
array([[1, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 1]], dtype=int32)>

## Update column with tf.scatter_nd

Update columns 1 and 3 (0-based) of x:
```
x = [
    [ 0,  1,  2,  3,  4,  5,  6,  7,  8],
    [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
    [18, 19, 20, 21, 22, 23, 24, 25, 26]
]
```

In [3]:
x = tf.reshape(tf.range(3*3*3), (3,9))
x



<tf.Tensor: shape=(3, 9), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
       [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23, 24, 25, 26]], dtype=int32)>

In [17]:
indices_transposed = tf.constant([
    [1],    # row 1 in the transposed x which is column 1 in original x  
    [3]     # row 3 in the transposed x which is column 3
])
shape_transposed = (x.shape[1], x.shape[0])    # swap the shape

# Values for rows 1 and 3 of the transposed x (i.e. columns 1 and 3 of the original x)
updates_transposed = tf.constant([
    [100,200,300],
    [400,500,600]
])

# Create a zero tensor shaped like the transposed x and write updates_transposed into its rows.
update_material_transposed = tf.scatter_nd(
    indices=indices_transposed, 
    updates=updates_transposed, 
    shape=shape_transposed
)
update_x_material = tf.transpose(update_material_transposed)
print(f"Tensor to use to update x\n{update_x_material}")

# Add to the original x. Note this ADDS to the existing column values instead of replacing them.
x + update_x_material

Tensor to use to update x
[[  0 100   0 400   0   0   0   0   0]
 [  0 200   0 500   0   0   0   0   0]
 [  0 300   0 600   0   0   0   0   0]]


<tf.Tensor: shape=(3, 9), dtype=int32, numpy=
array([[  0, 101,   2, 403,   4,   5,   6,   7,   8],
       [  9, 210,  11, 512,  13,  14,  15,  16,  17],
       [ 18, 319,  20, 621,  22,  23,  24,  25,  26]], dtype=int32)>
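The cell above adds the new values to the columns (e.g. ```1 + 100 = 101```). If the goal is to replace the columns instead, one sketch is to transpose ```x```, overwrite rows of the transposed tensor with ```tf.tensor_scatter_nd_update```, and transpose back. The variable names here are illustrative.

```python
import tensorflow as tf

x = tf.reshape(tf.range(3 * 3 * 3), (3, 9))

# Rows of the transposed tensor are the columns of x.
x_t = tf.transpose(x)                        # shape (9, 3)
x_t = tf.tensor_scatter_nd_update(
    x_t,
    indices=[[1], [3]],                      # rows 1 and 3 of x_t = columns 1 and 3 of x
    updates=[[100, 200, 300], [400, 500, 600]],
)
replaced = tf.transpose(x_t)                 # back to shape (3, 9)
print(replaced.numpy())
# [[  0 100   2 400   4   5   6   7   8]
#  [  9 200  11 500  13  14  15  16  17]
#  [ 18 300  20 600  22  23  24  25  26]]
```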

# tf.tensor_scatter_nd_update - update values at indices

* [Tensorflow 2 - what is 'index depth' in tensor_scatter_nd_update?](https://stackoverflow.com/questions/67361081)

> indices has at least two axes, the last axis is the **depth of the index vectors**. For a higher rank input tensor scalar updates can be inserted by using an index_depth that matches tf.rank(tensor):


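When the index depth equals the rank, **each index selects a scalar value**. For a tensor ```X:(N,S,D)```, each index is ```(n,s,d)```, the subscripts that identify a unique element in ```X```.

This is the case ***```index depth == rank(X)```***. With a smaller index depth, each index selects a slice instead, e.g. ```[1]``` selects a whole row of a matrix.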
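A sketch of the general rule on a hypothetical zero tensor ```X:(N,S,D) = (2,3,4)```. Per the ```tf.tensor_scatter_nd_update``` documentation, ```updates.shape``` must equal ```indices.shape[:-1] + X.shape[index_depth:]```, so lowering the index depth makes each update a larger slice.

```python
import tensorflow as tf

X = tf.zeros((2, 3, 4), dtype=tf.int32)  # X:(N,S,D) = (2,3,4), rank 3

# depth 3 == rank(X): each index (n,s,d) selects a scalar -> updates shape (2,)
d3 = tf.tensor_scatter_nd_update(X, [[0, 1, 2], [1, 2, 3]], [7, 8])

# depth 2: each index (n,s) selects a vector of length D -> updates shape (1, 4)
d2 = tf.tensor_scatter_nd_update(X, [[1, 0]], tf.ones((1, 4), tf.int32))

# depth 1: each index (n,) selects an (S,D) matrix -> updates shape (1, 3, 4)
d1 = tf.tensor_scatter_nd_update(X, [[0]], tf.ones((1, 3, 4), tf.int32))

print(int(tf.reduce_sum(d3)), int(tf.reduce_sum(d2)), int(tf.reduce_sum(d1)))  # 15 4 12
```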

## index depth = rank(input)

* [Tensorflow 2 - what is 'index depth' in tensor_scatter_nd_update?](https://stackoverflow.com/a/67363360/4281353)

> * The **index depth** of indices must equal the **rank of the input tensor**
> * The length of updates must equal the length of the indices

For scalar updates, if the shape of input X is ```(N,S,D)```, the index depth is ```rank(X)=3```.

### Indices for rank 1

<img src="image/ScatterNd1_1D.png" align="left" width="500"/><br>


### Index format

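Slices cannot be used as indices. ```[:, 1:3]``` is not even valid Python outside a subscript, so the cell below fails with a ```SyntaxError```; the indices must be listed explicitly.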

In [6]:
x:tf.Tensor = tf.constant(np.arange(12).reshape(3, 4))
print(f"x before mutate: \n{x}")

tf.tensor_scatter_nd_update(  
    tensor=x,
    indices=[:, 1:3],   # Cannot use slice expression 
    updates=0
)

SyntaxError: invalid syntax (2836163019.py, line 6)
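One way to get what the failed cell intended (NumPy's ```x[:, 1:3] = 0```) is to list every ```(row, col)``` coordinate of the slice explicitly. This sketch builds them with ```tf.meshgrid``` and writes one scalar per coordinate; it is an illustration, not the only approach.

```python
import numpy as np
import tensorflow as tf

x = tf.constant(np.arange(12).reshape(3, 4))        # shape (3, 4)

# Emulate x[:, 1:3] = 0 by enumerating every (row, col) coordinate of the slice.
rows = tf.range(tf.shape(x)[0])                     # [0 1 2]
cols = tf.range(1, 3)                               # [1 2], i.e. the slice 1:3
rr, cc = tf.meshgrid(rows, cols, indexing="ij")     # each of shape (3, 2)
indices = tf.reshape(tf.stack([rr, cc], axis=-1), (-1, 2))  # shape (6, 2)

updates = tf.zeros_like(indices[:, 0], dtype=x.dtype)       # one scalar per coordinate
y = tf.tensor_scatter_nd_update(x, indices, updates)
print(y.numpy())
# [[ 0  0  0  3]
#  [ 4  0  0  7]
#  [ 8  0  0 11]]
```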

In [16]:
X = tf.Variable(tf.reshape(tf.range(5, dtype=tf.int32), shape=(5,)))
print(f"{X.numpy()}\n")

indices = [   # indices has shape (N,1) = (2,1): index depth 1 = rank(X)
    [1],      # 1st index -> position 1
    [3]       # 2nd index -> position 3
]
updates = [0, 0]
print(f"Update values {tf.gather_nd(X, indices)} at indices {indices} with {updates}.\n")

print(f"Result {tf.tensor_scatter_nd_update(X, indices, updates).numpy()}")

[0 1 2 3 4]

Update values [1 3] at indices [[1], [3]] with [0, 0].

Result [0 0 2 0 4]


### More examples for 1D indices

In [29]:
def random_choice(a, size):
    """Random choice from 'a' based on size without duplicates
    Args:
        a: Tensor
        size: int or shape as tuple of ints e.g., (m, n, k).
    Returns: Tensor of the shape specified with 'size' arg.

    Examples:
        X = tf.constant([[1,2,3],[4,5,6]])
        random_choice(X, (2,1,2)).numpy()
        -----
        [
          [
            [5 4]
          ],
          [
            [1 2]
          ]
        ]
    """
    is_size_scalar: bool = \
        isinstance(size, (int, np.integer)) or \
        (tf.is_tensor(size) and size.shape == () and size.dtype.is_integer)
    if is_size_scalar:
        shape = (size,)
    elif isinstance(size, tuple) and len(size) > 0:
        shape = size
    else:
        raise AssertionError(f"Unexpected size arg {size}")

    sample_size = tf.math.reduce_prod(shape)
    assert sample_size > 0

    # --------------------------------------------------------------------------------
    # Select elements from a flat array
    # --------------------------------------------------------------------------------
    a = tf.reshape(a, (-1,))
    length = tf.size(a)
    assert sample_size <= length

    # --------------------------------------------------------------------------------
    # Shuffle a sequential numbers (0, ..., length-1) and take size.
    # To select 'sample_size' elements from a 1D array of shape (length,),
    # TF Indices needs to have the shape (sample_size,1) where each index
    # has shape (1,),
    # --------------------------------------------------------------------------------
    indices = tf.reshape(
        tensor=tf.random.shuffle(tf.range(0, length, dtype=tf.int32))[:sample_size],
        shape=(-1, 1)   # Convert to the shape:(sample_size,1)
    )
    return tf.reshape(tensor=tf.gather_nd(a, indices), shape=shape)

In [30]:
X = tf.constant([[1,2,3],[4,5,6]])
print(random_choice(X, (2,1)).numpy())

[[5]
 [3]]


### Indices for rank > 1

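Further confusing: if rank > 1, each index is a full coordinate, so ```indices``` has shape ```(number of updates, rank(X))```. For ```X:(N,D)```, each index is an ```(n, d)``` pair, and ```indices``` below has shape ```(2, 2)```: two scalar updates, each addressed by a ```(row, col)``` pair.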

In [20]:
X = tf.Variable(tf.reshape(tf.range(25, dtype=tf.int32), shape=(5,5)))
print(f"{X}\n")

indices = [
    [1, 2],     # (n=1,d=2)
    [2, 3]      # (n=2,d=3)
]
updates = [0, 0]
print(f"Update values {tf.gather_nd(X, indices).numpy()} at indices {indices} with {updates}\n")

print(f"result {tf.tensor_scatter_nd_update(X, indices, updates).numpy()}")

<tf.Variable 'Variable:0' shape=(5, 5) dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]], dtype=int32)>

Update values [ 7 13] at indices [[1, 2], [2, 3]] with [0, 0]

result [[ 0  1  2  3  4]
 [ 5  6  0  8  9]
 [10 11 12  0 14]
 [15 16 17 18 19]
 [20 21 22 23 24]]
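Writing rank > 1 indices by hand gets tedious. A sketch of one common pattern: ```tf.where(condition)``` returns the coordinates of the ```True``` elements as an ```int64``` tensor of shape ```(num_true, rank(X))```, which is already in the format ```tf.tensor_scatter_nd_update``` expects. The condition here (multiples of 5, i.e. the first column) is an arbitrary example.

```python
import tensorflow as tf

X = tf.reshape(tf.range(25, dtype=tf.int32), shape=(5, 5))

# Coordinates of elements where the condition holds: (0,0), (1,0), ..., (4,0).
indices = tf.where(X % 5 == 0)                # int64, shape (5, 2)
updates = tf.fill(tf.shape(indices)[:1], -1)  # one scalar per coordinate

Y = tf.tensor_scatter_nd_update(X, indices, updates)
print(Y.numpy()[:, 0])  # [-1 -1 -1 -1 -1]
```

Because the condition happens to select a whole column, this is also a way to replace a column without transposing.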
