In [1]:
%%html
<style>
/* Float every rendered table to the left (presumably so the markdown
   argument table below is not centred - confirm in the rendered notebook). */
table {float:left}
</style>

# Tensorflow while loop

By default, graph computation (Church) does not support loops, and Python loop primitives will not work. Hence, we need [tf.while_loop](https://www.tensorflow.org/api_docs/python/tf/while_loop), or ```@tf.function``` to run Python code within TF.

```
result = tf.while_loop(
    cond,
    body,
    loop_vars,
    shape_invariants=None,
    parallel_iterations=10,
    back_prop=True,
    swap_memory=False,
    maximum_iterations=None,
    name=None
) -> loop_vars

The result value has the same structure as loop_vars.
``` 

| Args                |                                                                                                                                                                                                                       |
|---------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| cond                | A callable that represents the termination condition of the loop.                                                                                                                                                     |
| body                | A callable that represents the loop body. The function gets the elements of ```loop_vars``` as args, and returns the updated ```loop_vars```.                                                                         |
| loop_vars           | A (possibly nested) tuple, namedtuple or list of numpy array, Tensor, and TensorArray objects.                                                                                                                        |
| shape_invariants    | The shape invariants for the loop variables.                                                                                                                                                                          |
| parallel_iterations | The number of iterations allowed to run in parallel. It must be a positive integer.                                                                                                                                   |
| back_prop           | (optional) Deprecated. False disables support for back propagation. Prefer using tf.stop_gradient instead.                                                                                                            |
| swap_memory         | Whether GPU-CPU memory swap is enabled for this loop.                                                                                                                                                                 |
| maximum_iterations  | Optional maximum number of iterations of the while loop to run. If provided, the cond output is AND-ed with an additional condition ensuring the number of iterations executed is no greater than maximum_iterations. |
| name                | Optional name prefix for the returned tensors.                                                                                                                                                                        |

## Constraints

**body** is a callable returning a ```(possibly nested) tuple, namedtuple or list of tensors of the same arity (length and structure) and types as loop_vars```. **loop_vars** is a (possibly nested) tuple, namedtuple or list of tensors that is passed to both cond and body. cond and body both take as many arguments as there are loop_vars.

In [2]:
import numpy as np
import tensorflow as tf

2023-11-14 16:10:37.506118: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-11-14 16:10:37.532507: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-14 16:10:37.532533: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-14 16:10:37.532549: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-14 16:10:37.537685: I tensorflow/core/platform/cpu_feature_g

In [3]:
# loop_vars must be a (possibly nested) tuple, namedtuple or list -- see Constraints.
# A bare scalar tensor is rejected:
#     loop_vars = tf.constant(1)   ->  "TypeError: Scalar tensor has no `len()`"
# NOTE: `(tf.constant(0))` without a trailing comma is NOT a tuple; Python evaluates
# it to the scalar tensor itself. That form was observed to raise
# "ValueError: 'loop_vars' must be provided.". A real one-element tuple is written
# `(tf.constant(0),)` and is a documented loop_vars structure, like the list used here.
loop_vars = [tf.constant(0)]

def condition(loop_var_element):
    """Return a boolean tensor that is True while the counter is below 10."""
    upper_bound = tf.constant(10)
    return tf.less(loop_var_element, upper_bound)

def body(loop_var_element, as_tuple: bool = True):
    """Increment the loop counter by one.

    tf.while_loop calls this with the elements of loop_vars as positional
    arguments and uses the returned structure as the next loop_vars:
    loop_vars -> body -> loop_vars.

    Args:
        loop_var_element: the single element of loop_vars (the counter tensor).
        as_tuple: when True (default) the update is returned as a 1-tuple,
            otherwise as a 1-element list. tf.while_loop never passes this
            argument, so the default applies inside the loop.

    Returns:
        A one-element tuple or list holding loop_var_element + 1, to be used
        as the next loop_vars.
    """
    # Returning the bare tensor (`return tf.add(loop_var_element, 1)`) is not
    # accepted: "TypeError: Cannot iterate over a scalar tensor." See Constraints.
    incremented = tf.add(loop_var_element, 1)
    return (incremented,) if as_tuple else [incremented]
    
# Run the loop; the returned structure mirrors loop_vars, so the counter is element 0.
final_loop_vars = tf.while_loop(cond=condition, body=body, loop_vars=loop_vars)
result: tf.Tensor = final_loop_vars[0]

result.numpy()

2023-11-14 16:10:38.673350: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:894] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2023-11-14 16:10:38.713089: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


10

# Loop through (N,S,S,C+2P)

P=(cp,x,y,w,h)  

In [4]:
# dtypes used for every tensor created below
TYPE_INT = np.int32
TYPE_FLOAT = np.float32

# Dimensions of the (N, S, S, C + B*P) prediction tensor
N, S = 2, 3    # N: size of the leading axis (presumably batch size - confirm); S: grid is S x S cells
C = 5          # values at the start of each cell, before the boxes (presumably class scores - confirm)
B, P = 2, 5    # B: boxes per cell (P0, P1); P: values per box = (cp, x, y, w, h)

## Tensor to loop through

In [5]:
# All-zero float tensor of shape (N, S, S, C + B*P), built directly in its final
# shape instead of reshaping a flat vector of zeros.
predictions: tf.Tensor = tf.zeros(shape=(N, S, S, C + B * P), dtype=TYPE_FLOAT)
predictions.shape

TensorShape([2, 3, 3, 15])

### loop_vars

In [13]:
# Scalar integer tensors read by the loop's condition and body.
num_cells_in_batch = tf.constant(S * S, dtype=TYPE_INT)
# Computed from Python ints: tf.constant is meant for constant values, and wrapping an
# already-built tensor (N * num_cells_in_batch) only works for eager tensors.
num_total_cells = tf.constant(N * S * S, dtype=TYPE_INT)
current_cell_index = tf.constant(0, dtype=TYPE_INT)

# loop_vars = (running cell index, predictions flattened to one row per cell)
loop_vars = (
    current_cell_index,
    tf.reshape(tensor=predictions, shape=(-1, C + B * P)),
)

TensorShape([2, 3, 3, 15])

### Condition

In [17]:
def condition(_current_cell_index, _predictions):
    """Continue-condition of the loop.

    Returns a boolean tensor that is True while _current_cell_index is below
    num_total_cells, i.e. the loop exits once the index reaches the total.
    _predictions is not inspected; it is a parameter only because cond takes
    as many arguments as there are loop_vars.
    """
    keep_looping = tf.less(_current_cell_index, num_total_cells)
    return keep_looping

### body

In [34]:
def body(
        _current_cell_index,
        _predictions
):
    """Add the cell's (col, row) grid position to the (x, y) of boxes P0 and P1.

    Args:
        _current_cell_index: scalar integer tensor; index of the row of
            _predictions (one row per cell) to update in this iteration.
        _predictions: 2-D tensor with one row per cell; each row holds C values
            followed by the boxes P0 and P1, each laid out as (cp, x, y, w, h).

    Returns:
        [_current_cell_index + 1, updated _predictions], which tf.while_loop
        feeds back in as the next loop_vars.
    """
    # Position of the cell inside its own S x S grid (0 .. S*S-1).
    # Both operands are TYPE_INT tensors, so no cast is needed.
    _cell_index_in_current_batch = _current_cell_index % num_cells_in_batch

    # Integer floor division yields the row directly, avoiding the round trip
    # through floating-point division followed by floor.
    row: tf.Tensor = tf.cast(_cell_index_in_current_batch // S, dtype=TYPE_FLOAT)
    col: tf.Tensor = tf.cast(_cell_index_in_current_batch % S, dtype=TYPE_FLOAT)

    _predictions = tf.tensor_scatter_nd_add(
        tensor=_predictions,
        indices=[
            [_current_cell_index, C + 1],       # p0_x: x in (C,(cp,x,y,w,h))
            [_current_cell_index, C + 2],       # p0_y
            [_current_cell_index, C + P + 1],   # p1_x
            [_current_cell_index, C + P + 2],   # p1_y
        ],
        # A plain list of scalar tensors is used here. The author's earlier attempt to
        # wrap the list in tf.constant([...]) was abandoned, see
        # https://stackoverflow.com/q/77478517/4281353
        updates=[
            col,    # p0_x + col
            row,    # p0_y + row
            col,    # p1_x + col
            row,    # p1_y + row
        ],
    )
    return [
        _current_cell_index + 1,
        _predictions
    ]

In [39]:
result = tf.while_loop(
    cond=condition,
    body=body,
    loop_vars=loop_vars
)
final_cell_index = result[0]
updated_predictions = result[1]    # still flattened: one row per cell

# The loop must have run until the index reached the total number of cells.
tf.assert_equal(final_cell_index, num_total_cells)

# Restore the (N, S, S, C+B*P) layout. Previously the reshaped tensor was neither
# assigned nor the last expression of the cell, so it was discarded. It gets its
# own name so `updated_predictions` keeps its flattened shape, and it is the last
# expression so the cell displays it.
updated_predictions_grid = tf.reshape(updated_predictions, (N, S, S, -1))
updated_predictions_grid
