The main takeaways and recommendations are:

* Debug in eager mode, then decorate with @tf.function.
* Don't rely on Python side effects like object mutation or list appends.
* tf.function works best with TensorFlow ops; NumPy and Python calls are converted to constants.

In [12]:
import tensorflow as tf
import numpy as np

### Variables may only be created once

In [13]:
# The variable is created once, at module level, outside the traced function;
# out_var only updates this existing object.
state = tf.Variable(1)


@tf.function
def out_var(x):
    """Add `x` to the module-level variable `state` in place.

    Nothing is returned; the result is observable only through `state`.
    """
    state.assign_add(x)


out_var(tf.constant(2))  # Non-pure functional style
state

<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=3>

In [14]:
# Expected failure: this cell demonstrates an error.
@tf.function
def inner_var(x):
    """Create a tf.Variable inside the traced function, add `x`, and return it.

    This is a deliberate anti-pattern: per the recorded output, calling the
    function raises ValueError ("tf.function only supports singleton
    tf.Variables created on the first call").
    """
    y = tf.Variable(1)  # A tf.Variable must not be defined inside a @tf.function-decorated function
    y.assign_add(x)
    return y


# The ValueError is the point of this cell. Catch and print it so that the
# message is still shown while "Restart Kernel -> Run All" can continue to
# the cells below instead of stopping here.
try:
    inner_var(tf.constant(2))
except ValueError as err:
    print("ValueError:", err)

ValueError: in user code:

    File "<ipython-input-3-294165622916>", line 4, in inner_var  *
        y = tf.Variable(1)  # 不能在@tf.function修饰的函数内部定义tf.Variable

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.


### Python operations execute only once per trace

In [15]:
@tf.function
def f(a, b):
    """Return `b` unchanged after reporting both arguments with Python `print`.

    `print` is a Python side effect rather than a TensorFlow op, so the
    message appears only while the function is being traced, not on every
    call.
    """
    print('this runs at trace time; a is', a, 'and b is', b)
    return b

In [16]:
# First call with these arguments: the function is traced, so the Python print fires.
f(1, tf.constant(1))

this runs at trace time; a is 1 and b is Tensor("b:0", shape=(), dtype=int32)


<tf.Tensor: shape=(), dtype=int32, numpy=1>

In [17]:
# Same Python value for `a` and a tensor of the same dtype/shape for `b`:
# per the recorded output nothing is printed this time, only the result is returned.
f(1, tf.constant(2))

<tf.Tensor: shape=(), dtype=int32, numpy=2>

In [18]:
# Module-level Python list that the traced function below appends to.
l = []


@tf.function
def f(x):
    """Anti-pattern demo: append `i + 1` to the Python list `l` while iterating `x`.

    `list.append` is a Python side effect, so it is executed during tracing
    rather than once per element at run time. Nothing is returned.
    """
    for i in x:
        l.append(i + 1)  # Caution! Will only happen once when tracing


f(tf.constant([1, 2, 3]))
# Per the recorded output, `l` holds a single symbolic graph tensor, not three values.
l

[<tf.Tensor 'while/add:0' shape=() dtype=int32>]

In [19]:
# Python side effects (appending to a list, calling print, ...) take place
# only once, while the function is being traced. For an effect to happen
# inside the tf.function itself it has to be expressed with TF ops, here by
# collecting the results in a tf.TensorArray:
@tf.function
def f(x):
    """Return a tensor containing every element of `x` incremented by one.

    The per-element results are written into an int32 `tf.TensorArray` that
    grows on demand and is stacked into a single tensor at the end.
    """
    incremented = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
    for idx in range(len(x)):
        incremented = incremented.write(idx, x[idx] + 1)
    return incremented.stack()


f(tf.constant([1, 2, 3]))

<tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 3, 4])>

In [20]:
@tf.function
def np_random():
    """Draw a 3x3 sample with NumPy and show it via `print` and `tf.print`.

    The NumPy call is ordinary Python, not a TensorFlow op; per the recorded
    output, repeated calls print the same values. Nothing is returned.
    """
    sample = np.random.randn(3, 3)
    print(sample)
    tf.print(sample)


@tf.function
def tf_random():
    """Draw a 3x3 sample with `tf.random.normal` and show it via `print` and `tf.print`.

    Per the recorded output, Python `print` shows the symbolic tensor while
    `tf.print` shows the values, which differ between calls. Nothing is returned.
    """
    # Prefer TensorFlow's own functions wherever possible
    sample = tf.random.normal(shape=(3, 3))
    print(sample)
    tf.print(sample)

In [21]:
# Both calls print the same random values
np_random()
np_random()

[[-0.70106162 -0.13377005  0.34766666]
 [-0.76547601 -0.41500235 -1.03717462]
 [-0.65215427 -0.48686552 -0.58266352]]
array([[-0.70106162, -0.13377005,  0.34766666],
       [-0.76547601, -0.41500235, -1.03717462],
       [-0.65215427, -0.48686552, -0.58266352]])
array([[-0.70106162, -0.13377005,  0.34766666],
       [-0.76547601, -0.41500235, -1.03717462],
       [-0.65215427, -0.48686552, -0.58266352]])


In [22]:
# Per the recorded output, the symbolic tensor is printed once and each call
# then prints a different set of values.
tf_random()
tf_random()

Tensor("random_normal:0", shape=(3, 3), dtype=float32)
[[-0.330740333 -0.549511969 0.142941669]
 [-0.695861876 0.651481628 -0.749943256]
 [0.0407161936 -0.24985671 -1.41266918]]
[[-1.29149783 1.04094684 0.972941]
 [-0.0687840134 -0.141839415 -0.313603848]
 [0.857643664 -1.21635115 0.0439375415]]
