<a href="https://colab.research.google.com/github/o0windseed0o/tensorflow_code_examples/blob/master/tensorflow_code_example.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
import numpy as np

'''
Use a matrix as the accumulator (state) of tf.scan's step function and, at each
step, overwrite one slice of it with values taken from the current input element.

Demonstrates the usage of:
* tf.concat -> tf.split: pack values and indices into one tensor, then unpack them
* tf.tensor_scatter_nd_update: overwrite slices of a tensor at given indices
'''

def step(matrix, inputs):
    """Overwrite one slice of ``matrix`` with the value columns of ``inputs``.

    Used as the ``fn`` of ``tf.scan``: ``matrix`` is the accumulator and
    ``inputs`` is one 2-D element of the scanned tensor.

    The last column of ``inputs`` holds the target slice index, stored as a
    float. All remaining columns form the block written into
    ``matrix[index]``. Their width is inferred rather than hard-coded, so any
    value width matching ``matrix``'s trailing dimension works.

    Only the index in the first row is read. Every row is therefore assumed
    to carry the same index.

    Returns:
        The updated matrix (same shape as ``matrix``).
    """
    # -1 lets tf.split infer the width of the value part from `inputs`.
    update, idx = tf.split(inputs, [-1, 1], axis=1)
    idx = idx[0, :]
    # Shape [1, 1]: a single index into the leading axis of `matrix`.
    idx = tf.reshape(tf.cast(idx, tf.int32), [1, 1])
    # Shape [1, rows, cols]: one update slice matching the single index above.
    update = tf.expand_dims(update, 0)
    matrix = tf.tensor_scatter_nd_update(matrix, idx, update)
    return matrix

# Initial scan state: four 3x3 slices, all zeros.
a = tf.zeros([4, 3, 3])
# Values to scatter: one 3x3 block of ones per scan step.
values = tf.ones([4, 3, 3])

# [4,3,1]: for scan element i every row carries index i, i.e. the slice of
# the state that step i overwrites.
indices = tf.constant([[[0], [0], [0]],
                       [[1], [1], [1]],
                       [[2], [2], [2]],
                       [[3], [3], [3]]], dtype=tf.float32)

# [4,3,4]: pack values and indices so tf.scan hands both to `step` together.
input_array = tf.concat([values, indices], axis=2)

# Split back apart, only to show that tf.split inverts the concat.
split0, split1 = tf.split(input_array, [3, 1], axis=2)

# states[i] is the accumulated matrix after scan step i.
states = tf.scan(step, input_array, initializer=a)

# TF1-style graph-mode session.
with tf.Session() as sess:
    # Fetch everything in one run. Discarding a sess.run result and then
    # calling .eval() per tensor would execute the graph once per call.
    split0_val, split1_val, input_val, states_val = sess.run(
        [split0, split1, input_array, states])
    print(split0_val)
    print(split1_val)
    print(input_val)
    print(states_val)

[[[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]]
[[[0.]
  [0.]
  [0.]]

 [[1.]
  [1.]
  [1.]]

 [[2.]
  [2.]
  [2.]]

 [[3.]
  [3.]
  [3.]]]
[[[1. 1. 1. 0.]
  [1. 1. 1. 0.]
  [1. 1. 1. 0.]]

 [[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]

 [[1. 1. 1. 2.]
  [1. 1. 1. 2.]
  [1. 1. 1. 2.]]

 [[1. 1. 1. 3.]
  [1. 1. 1. 3.]
  [1. 1. 1. 3.]]]
[[[[1. 1. 1.]
   [1. 1. 1.]
   [1. 1. 1.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]]


 [[[1. 1. 1.]
   [1. 1. 1.]
   [1. 1. 1.]]

  [[1. 1. 1.]
   [1. 1. 1.]
   [1. 1. 1.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]]


 [[[1. 1. 1.]
   [1. 1. 1.]
   [1. 1. 1.]]

  [[1. 1. 1.]
   [1. 1. 1.]
   [1. 1. 1.]]

  [[1. 1. 1.]
   [1. 1. 1.]
   [1. 1. 1.]]

  [[0. 0. 0.]
   [0. 0. 0.]
   [0. 0. 0.]]]


 [[[1. 1. 

In [2]:
"""
Dynamic mask: take a source mask and clear every position before a given index.
"""
# Cut position: entries in [0, index) of the source mask are cleared.
index = tf.convert_to_tensor(2)

source_mask = tf.constant([1,1,1,0,0,0,0])

# 1 at positions [0, index), 0 elsewhere. The length follows source_mask
# instead of a hard-coded 7, so the two masks always line up.
update_mask = tf.sequence_mask(index, tf.size(source_mask))

update_mask = tf.cast(update_mask, tf.int32)

# Keep source entries only where update_mask is 0. Multiplying (rather than
# subtracting) keeps the result a 0/1 mask even when index runs past the
# last 1 in source_mask; subtraction would produce -1 entries there.
result_mask = source_mask * (1 - update_mask)

# TF1-style graph-mode session.
with tf.Session() as sess:
    # One graph execution for all fetches instead of one per .eval() call.
    index_val, source_val, update_val, result_val = sess.run(
        [index, source_mask, update_mask, result_mask])
    print(index_val)
    print(source_val)
    print(update_val)
    print(result_val)

2
[1 1 1 0 0 0 0]
[1 1 0 0 0 0 0]
[0 0 1 0 0 0 0]
