In [3]:
import tensorflow as tf
import numpy as np

# Update the tensor values where the condition matches


Numpy can use boolean indexing to directy update the array.

In [2]:
x = np.random.uniform(-1, 1, size=(3, 4))
print(x)
x[x > 0] = 0
print(x)

[[-0.21862459 -0.11904506 -0.65759051 -0.02053271]
 [ 0.85862027  0.59237733 -0.78714513 -0.0176912 ]
 [-0.94936706 -0.30497186  0.25053833  0.18807091]]
[[-0.21862459 -0.11904506 -0.65759051 -0.02053271]
 [ 0.          0.         -0.78714513 -0.0176912 ]
 [-0.94936706 -0.30497186  0.          0.        ]]


Tensorflow does not have such syntax, hence need to use ```tf.where```.

[tf.where(condition, x=None, y=None, name=None)](https://www.tensorflow.org/api_docs/python/tf/where)

```
Returns:
    If x and y are provided: 
        A Tensor with the same type as x and y, and shape that is broadcast from condition, x, and y.
    Otherwise: 
        A Tensor with shape (num_true, dim_size(condition)).
```

* y is the **target** Tensor to set the element values from source Tensor ```x```
* x is the **source** Tensor the element values from which get set in the target tensor ```y```


It does **NOT** directly update the ```tf.Variable``` itself but need to use tf.Variable.assign() method. It can **NOT update elements** but **requires a entire tensor** that has the same shape of the target.

* [tensorflow 2 - how to conditionally update values directly in tf.Variable](https://stackoverflow.com/questions/66980404/tensorflow-2-how-to-conditionally-update-values-directly-in-tf-variable)

In [12]:
x = tf.Variable(np.random.uniform(-1, 1, size=(3,4)), dtype=tf.float32)
print(f"x:\n{x}\n")

# --------------------------------------------------------------------------------
# Boolean indices with the condition
# --------------------------------------------------------------------------------
select = x > 0
print(f"Boolean indices (x > 0):\n{select}\n")

# --------------------------------------------------------------------------------
# Update values in Variable wih boolean indices
# --------------------------------------------------------------------------------
x.assign(tf.where(select, 1, x))
print(f"x.assign(tf.where(x>0, 1, x)):\n{x}")

x:
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[-0.29402795,  0.85458314,  0.635479  ,  0.53761697],
       [ 0.48885283,  0.23887686, -0.8549232 ,  0.5042577 ],
       [ 0.5697476 , -0.7121538 ,  0.03971618, -0.14969674]],
      dtype=float32)>

Boolean indices (x > 0):
[[False  True  True  True]
 [ True  True False  True]
 [ True False  True False]]

x.assign(tf.where(x>0, 1, x)):
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[-0.29402795,  1.        ,  1.        ,  1.        ],
       [ 1.        ,  1.        , -0.8549232 ,  1.        ],
       [ 1.        , -0.7121538 ,  1.        , -0.14969674]],
      dtype=float32)>


* [How to efficiently update a tensor slice? #36559](https://github.com/tensorflow/tensorflow/issues/36559)

```
import tensorflow.keras.backend as K
units, timesteps = 4, 6
x = K.zeros((units, timesteps), dtype='float32', name='x')
x_new = x[:units, 0].assign(K.ones((units,), dtype='float32'))  # dummy example
K.set_value(x, K.get_value(x_new))
print(K.get_value(x))
```

---
# Usage of Where
## Update target Y with X upon condition tensor T

NOTE: The target ```Y``` itself will not be updated but creates a new Tensor of updated version of ```Y```.

In [18]:
TYPE_FLOAT = np.float32
N = 3
D = 3
shape = (N,D)

# Target
Y = tf.random.uniform(shape=shape, dtype=tf.dtypes.as_dtype(TYPE_FLOAT))
print(f"Target Y: \n{Y}\n")

# Condition (using numpy as doing the same is not simple in TF)
T = np.zeros(shape=shape, dtype=TYPE_FLOAT)
T[
    ::,
    0
] = 1
print(f"Condition T: \n{T}\n")

# Source
X = tf.reshape(tf.range(N*D, dtype=tf.dtypes.as_dtype(TYPE_FLOAT)), shape=shape)
print(f"Source X: \n{X}\n")

tf.where(condition=T, x=X, y=Y)

Target Y: 
[[0.5828624  0.62737167 0.13059998]
 [0.4447347  0.7729218  0.8134774 ]
 [0.87045276 0.90917516 0.52166355]]

Condition T: 
[[1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]]

Source X: 
[[0. 1. 2.]
 [3. 4. 5.]
 [6. 7. 8.]]



<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[0.        , 0.62737167, 0.13059998],
       [3.        , 0.7729218 , 0.8134774 ],
       [6.        , 0.90917516, 0.52166355]], dtype=float32)>