The main takeaways and recommendations are:

* Debug in eager mode, then decorate with @tf.function.
* Don't rely on Python side effects like object mutation or list appends.
* tf.function works best with TensorFlow ops; NumPy and plain Python calls run only while the function is traced, and their results are captured in the graph as constants.

In [90]:
import tensorflow as tf
import numpy as np

### Variables may only be created once

In [91]:
# The variable is created once, at module level; the traced function below
# only updates it.
state = tf.Variable(1)


@tf.function
def out_var(x):
    """Increment the module-level ``state`` variable by ``x`` in place.

    Nothing is returned; the effect is observable only through ``state``.
    """
    state.assign_add(delta=x)


out_var(tf.constant(2))  # Non-pure functional style
state

<tf.Variable 'Variable:0' shape=() dtype=int32, numpy=3>

In [92]:
# Raises ValueError (see the traceback below): the variable is created inside the traced function.
@tf.function
def inner_var(x):
    y = tf.Variable(1)  # Avoid creating a tf.Variable inside a function decorated with @tf.function
    y.assign_add(x)
    return y


inner_var(tf.constant(2))

ValueError: in user code:

    File "<ipython-input-9-6338904b622b>", line 4, in inner_var  *
        y = tf.Variable(1)

    ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.


### Python operations execute only once per trace

In [93]:
@tf.function
def f(a, b):
    """Return ``b`` unchanged, announcing each trace with a Python ``print``.

    The ``print`` is an ordinary Python side effect, so (as its own message
    says) it runs while the function is being traced, not on every call.
    """
    print('this runs at trace time; a is', a, 'and b is', b)
    return b

In [94]:
# First call: the recorded output shows the trace-time print, with `a` as a
# plain Python value and `b` as a symbolic Tensor.
f(1, tf.constant(1))

this runs at trace time; a is 1 and b is Tensor("b:0", shape=(), dtype=int32)


<tf.Tensor: shape=(), dtype=int32, numpy=1>

In [95]:
# Second call with the same Python `a` and another int32 scalar for `b`:
# the recorded output has no print line, only the returned tensor.
f(1, tf.constant(2))

<tf.Tensor: shape=(), dtype=int32, numpy=2>

In [96]:
# Module-level Python list that the traced function below appends to.
l = []


@tf.function
def f(x):
    """Loop over ``x`` and append ``i + 1`` to the module-level list ``l``.

    Kept as a demonstration of what not to rely on: the append is a Python
    side effect, so it happens while the function is traced rather than on
    every execution of the resulting graph.
    """
    for i in x:
        l.append(i + 1)  # Caution! Will only happen once when tracing


f(tf.constant([1, 2, 3]))

In [97]:
# Python side effects (appending to a list, calling print, ...) happen only
# once, while the function is being traced. For an effect that must be part
# of the tf.function itself, express it with TF ops instead -- here a
# TensorArray takes the place of the Python list.
@tf.function
def f(x):
    """Return a tensor holding every element of ``x`` incremented by one."""
    incremented = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
    for idx in range(len(x)):
        # write() returns the updated TensorArray, so rebind it each step.
        incremented = incremented.write(idx, x[idx] + 1)
    return incremented.stack()


f(tf.constant([1, 2, 3]))

<tf.Tensor: shape=(3,), dtype=int32, numpy=array([2, 3, 4])>

In [98]:
@tf.function
def np_random():
    """Draw a 3x3 sample with NumPy and show it via ``print`` and ``tf.print``.

    The NumPy call is plain Python, so it is evaluated while tracing and the
    resulting array is captured as a constant: the recorded output shows the
    same values on every call.
    """
    sample = np.random.randn(3, 3)
    print(sample)
    tf.print(sample)


@tf.function
def tf_random():
    """Draw a 3x3 sample with a TF op and show it via ``print`` and ``tf.print``.

    In the recorded output the Python ``print`` appears once, showing a
    symbolic Tensor, while ``tf.print`` shows a different sample on each call.
    """
    # Prefer TensorFlow's own functions inside a tf.function.
    sample = tf.random.normal(shape=(3, 3))
    print(sample)
    tf.print(sample)

In [99]:
# Both calls print the same values.
np_random()
np_random()

[[ 1.31277232 -2.11379283  0.01048271]
 [-0.0120855  -0.15500913 -0.66727723]
 [ 0.4849212  -0.12695594 -0.29195129]]
array([[ 1.31277232, -2.11379283,  0.01048271],
       [-0.0120855 , -0.15500913, -0.66727723],
       [ 0.4849212 , -0.12695594, -0.29195129]])
array([[ 1.31277232, -2.11379283,  0.01048271],
       [-0.0120855 , -0.15500913, -0.66727723],
       [ 0.4849212 , -0.12695594, -0.29195129]])


In [100]:
# Each call prints a different sample via tf.print; the plain print shows up
# only once, as a symbolic Tensor.
tf_random()
tf_random()

Tensor("random_normal:0", shape=(3, 3), dtype=float32)
[[1.09408426 -0.149554357 0.353617042]
 [-0.530942678 -0.484824598 -0.156965494]
 [-0.0118089858 -0.0196795296 -1.14975464]]
[[0.477424145 -1.410748 -2.50422335]
 [0.569126129 -0.376163065 0.544540107]
 [0.207152724 -0.778396904 -0.743118286]]
