---
title: Tensorflow2 function
tags: 小书匠,tensorflow2,function
grammar_cjkRuby: true
renderNumberedHeading: true
---

[toc]

# Tensorflow2 function

In TensorFlow 2, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.

You can use `tf.function` to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This will help you create performant and portable models, and it is required to use `SavedModel`.

This guide will help you conceptualize how `tf.function` works under the hood so you can use it effectively.

The main takeaways and recommendations are:

- Debug in eager mode, then decorate with `@tf.function`.
- Don't rely on Python side effects like object mutation or list appends.
- `tf.function` works best with TensorFlow ops; NumPy and Python calls are converted to constants.
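Here is a minimal sketch of the first recommendation. It assumes a toy function `scaled_sum`, which is hypothetical and not taken from the guide. First get the plain Python function working eagerly, then wrap it. Calling `tf.function(fn)` is equivalent to decorating `fn` with `@tf.function`.

```python
import tensorflow as tf

def scaled_sum(x, scale):
    # Plain Python function: runs eagerly, so print() and pdb show real values.
    return tf.reduce_sum(x) * scale

x = tf.constant([1.0, 2.0, 3.0])
eager_result = scaled_sum(x, 2.0)  # step 1: debug in eager mode

# Step 2: once it works, wrap it. tf.function(fn) is equivalent to @tf.function.
graph_scaled_sum = tf.function(scaled_sum)
graph_result = graph_scaled_sum(x, 2.0)

print(float(eager_result), float(graph_result))
```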

## Setup

```python
import tensorflow as tf
print(tf.__version__)
```

```
2.4.1
```


Define a helper function to demonstrate the kinds of errors you might encounter:

```python
import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
    try:
        yield
    except error_class as e:
        print('Caught expected exception \n  {}:'.format(error_class))
        traceback.print_exc(limit=2)
    except Exception as e:
        raise e
    else:
        raise Exception('Expected {} to be raised but no error was raised!'.format(error_class))
```

## Basics

### Usage

A `Function` you define (for example by applying the `@tf.function` decorator) is just like a core TensorFlow operation: You can execute it eagerly; you can compute gradients; and so on.

```python
@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
    return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]
```

```
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
```

```python
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
    result = add(v, 1.0)
tape.gradient(result, v)
```

```
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
```

You can use `Function`s inside other `Function`s.

```python
@tf.function
def dense_layer(x, w, b):
    return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
```

```
<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>
```

`Function`s can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.


```python
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
    return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])

# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
```

```
Eager conv: 0.8449165550118778
Function conv: 0.7333571990020573
Note how there's not much difference in performance for convolutions
```


### Tracing

This section exposes how `Function` works under the hood, including implementation details *which may change in the future*. However, once you understand why and when tracing happens, it's much easier to use `tf.function` effectively!

#### What is "tracing"?

A `Function` runs your program in a [TensorFlow Graph](https://www.tensorflow.org/guide/intro_to_graphs#what_are_graphs). However, a `tf.Graph` cannot represent all the things that you'd write in an eager TensorFlow program. For instance, Python supports polymorphism, but `tf.Graph` requires its inputs to have a specified data type and dimension. Or you may perform side tasks like reading command-line arguments, raising an error, or working with a more complex Python object; none of these things can run in a `tf.Graph`.

`Function` bridges this gap by separating your code into two stages:

  1)  In the first stage, referred to as "**tracing**", `Function` creates a new `tf.Graph`. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are *deferred*: they are captured by the `tf.Graph` and not run.

  2) In the second stage, a `tf.Graph` which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage.

Depending on its inputs, `Function` will not always run the first stage when it is called.  See ["Rules of tracing"](#rules_of_tracing) below to get a better sense of how it makes that determination. Skipping the first stage and only executing the second stage is what gives you TensorFlow's high performance.

When `Function` does decide to trace, the tracing stage is immediately followed by the second stage, so calling the `Function` both creates and runs the `tf.Graph`.
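The sketch below makes the two stages visible. It uses a Python side effect as a trace counter, with a hypothetical `double` function and `trace_log` list. The `append` is plain Python, so it only runs during tracing. By contrast, `x * 2` is captured in the graph and runs on every call.

```python
import tensorflow as tf

trace_log = []  # filled only by Python code, i.e. only during tracing

@tf.function
def double(x):
    trace_log.append("traced")  # stage 1 only: plain Python runs while tracing
    return x * 2                # captured into the tf.Graph, executed in stage 2

first = double(tf.constant(1.0))   # traces, then runs the new graph
second = double(tf.constant(5.0))  # same dtype and shape: runs the cached graph only
print(len(trace_log), float(first), float(second))
```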

When you pass arguments of different types into a `Function`, both stages are run:

```python
@tf.function
def my_func(a):
    print("Tracing with", a)
    return a + a

print(my_func(tf.constant(1))) # this will create a trace for tf.int32 tensor
print()
print(my_func(tf.constant(1.1))) # this will create a trace for tf.float32 tensor
print()
print(my_func(tf.constant("a"))) # this will create a trace for tf.string tensor
print()
```

```
Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)
```



Note that if you repeatedly call a `Function` with the same argument type, TensorFlow will skip the tracing stage and reuse a previously traced graph, as the generated graph would be identical.

```python
# This doesn't print 'Tracing with ...'
print(my_func(tf.constant("b")))
```

```
tf.Tensor(b'bb', shape=(), dtype=string)
```


You can use `pretty_printed_concrete_signatures()` to see all of the available traces:

```python
print(my_func.pretty_printed_concrete_signatures())
```

```
my_func(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

my_func(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

my_func(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()
```


#### Key concepts

So far, you've seen that `tf.function` creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:

- A `tf.Graph` is the raw, language-agnostic, portable representation of a TensorFlow computation.
- A `ConcreteFunction` wraps a `tf.Graph`.
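- A `Function` manages a cache of `ConcreteFunction`s and picks the right one for your inputs. `Function` and `ConcreteFunction` are loosely analogous to generics. A `Function` is like a template whose types have not been specified yet, and a `ConcreteFunction` is the function instantiated from that template for specific input types.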
- `tf.function` wraps a Python function, returning a `Function` object.
- **Tracing** creates a `tf.Graph` and wraps it in a `ConcreteFunction`, also known as a **trace.**
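This short sketch ties the terms together, assuming a toy `square` function. `tf.function` returns a `Function`, and `get_concrete_function` produces one trace (a `ConcreteFunction`). The trace's `.graph` attribute is the underlying `tf.Graph`.

```python
import tensorflow as tf

def square(x):  # a plain Python function
    return x * x

fn = tf.function(square)  # a `Function`: owns a cache of ConcreteFunctions
# One trace for scalar float32 inputs: a ConcreteFunction.
concrete = fn.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32))
graph = concrete.graph    # the tf.Graph wrapped by the ConcreteFunction

print(isinstance(graph, tf.Graph))        # the trace's graph is a tf.Graph
print(float(concrete(tf.constant(3.0))))  # calling the trace directly
```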


#### Rules of tracing

A `Function` determines whether to reuse a traced `ConcreteFunction` by computing a **cache key** from an input's args and kwargs. A **cache key** is a key that identifies a `ConcreteFunction` based on the input args and kwargs of the `Function` call, according to the following rules (which may change):

- The key generated for a `tf.Tensor` is its shape and dtype.
- The key generated for a `tf.Variable` is a unique variable id.
- The key generated for a Python primitive (like `int`, `float`, `str`) is its value. 
- The key generated for nested `dict`s, `list`s, `tuple`s, `namedtuple`s, and [`attr`](https://www.attrs.org/en/stable/)s is the flattened tuple of leaf-keys (see `nest.flatten`).
- For all other Python types, the keys are based on the object `id()` so that methods are traced independently for each instance of a class.

Note: Cache keys are based on the `Function` input parameters, so changes to global and free variables alone will not create a new trace. See the "Known Issues" section below for recommended practices when dealing with Python global and free variables.
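To illustrate the rules above, the sketch below counts traces using a hypothetical `identity` function. Tensors are keyed by dtype and shape, and Python primitives by value. It assumes default `tf.function` options, that is, no `input_signature` and no shape relaxation.

```python
import tensorflow as tf

traces = []  # one entry per trace (Python side effect)

@tf.function
def identity(x):
    traces.append(x)  # records the traced input; runs only when tracing
    return x

identity(tf.constant(1))       # int32, shape () -> new trace
identity(tf.constant(2))       # same dtype and shape -> reuses the trace
identity(tf.constant([1, 2]))  # shape (2,) -> new trace
identity(1)                    # Python int, keyed by its value -> new trace
identity(1)                    # same value -> reuses the trace
identity(2)                    # different value -> new trace
print(len(traces))
```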

#### Controlling retracing



Retracing, which is when your `Function` creates more than one trace, helps ensure that TensorFlow generates correct graphs for each set of inputs. However, tracing is an expensive operation! If your `Function` retraces a new graph for every call, you'll find that your code executes more slowly than if you didn't use `tf.function`.

To control the tracing behavior, you can use the following techniques:

- Specify `input_signature` in `tf.function` to limit tracing.
- Cast Python arguments to Tensors to reduce retracing.

##### Use input_signature to restrict the argument types a function accepts

```python
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
    print("Tracing with", x)
    return tf.where(x % 2 == 0, x // 2, 3 * x + 1)

print(next_collatz(tf.constant([1, 2])))

# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
    next_collatz(tf.constant([[1, 2], [3, 4]]))

# We specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
    next_collatz(tf.constant([1.0, 2.0]))
```

```
Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:

Traceback (most recent call last):
  File "<ipython-input-2-af6dbe3f6b0e>", line 8, in assert_raises
    yield
  File "<ipython-input-8-f5f57cd00bde>", line 10, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
Traceback (most recent call last):
  File "<ipython-input-2-af6dbe3f6b0e>", line 8, in assert_raises
    yield
  File "<ipython-input-8-f5f57cd00bde>", line 14, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))
```


##### Convert Python objects to TensorFlow objects

Often, Python arguments are used to control hyperparameters and graph construction: for example, `num_layers=10` or `training=True` or `nonlinearity='relu'`. So if the Python argument changes, it makes sense that you'd have to retrace the graph.

However, it's possible that a Python argument is not being used to control graph construction. In these cases, a change in the Python value can trigger needless retracing. Take, for example, this training loop, which AutoGraph will dynamically unroll. Despite the multiple traces, the generated graph is actually identical, so retracing is unnecessary.

```python
def train_one_step():
    pass

@tf.function
def train(num_steps):
    print("Tracing with num_steps = ", num_steps)
    tf.print("Executing with num_steps = ", num_steps)
    for _ in tf.range(num_steps):
        train_one_step()


print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
```

```
Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20
```


### Obtaining concrete functions

Every time a function is traced, a new concrete function is created. You can obtain a concrete function directly by using `get_concrete_function`.

```python
print("Obtaining concrete trace")
double_strings = my_func.get_concrete_function(tf.constant("a"))

print("Executing traced function")
print(double_strings(tf.constant('a')))
print(double_strings(tf.constant("b")))
```

```
Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
```


```python
# You can also call get_concrete_function on a TensorSpec
double_strings_from_inputspec = my_func.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
```

```
tf.Tensor(b'cc', shape=(), dtype=string)
```


Printing a `ConcreteFunction` displays a summary of its input arguments (with types) and its output type.

```python
print(double_strings)
```

```
ConcreteFunction my_func(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()
```


You can also directly retrieve a concrete function's signature.

```python
print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
```

```
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)
```


Calling a concrete trace with incompatible types raises an error:

```python
with assert_raises(tf.errors.InvalidArgumentError):
    double_strings(tf.constant(1))
```

```
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:

Traceback (most recent call last):
  File "<ipython-input-30-af6dbe3f6b0e>", line 8, in assert_raises
    yield
  File "<ipython-input-53-6ba07b14b756>", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_my_func_478 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_my_func_478]
```


You may notice that Python arguments are given special treatment in a concrete function's input signature. Prior to TensorFlow 2.3, Python arguments were simply removed from the concrete function's signature. Starting with TensorFlow 2.3, Python arguments remain in the signature, but are constrained to take the value set during tracing.

```python
@tf.function
def pow(a, b):
    return a ** b

square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
```

```
ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>
```


```python
assert square(tf.constant(10.0)) == 100

with assert_raises(TypeError):
    square(tf.constant(10.0), b=3)
```

```
Caught expected exception 
  <class 'TypeError'>:

Traceback (most recent call last):
  File "/anaconda3/envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1683, in _call_impl
    cancellation_manager)
  File "/anaconda3/envs/tensorflow2/lib/python3.6/site-packages/tensorflow/python/eager/function.py", line 1728, in _call_with_flat_signature
    self._flat_signature_summary(), ", ".join(sorted(kwargs))))
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "<ipython-input-30-af6dbe3f6b0e>", line 8, in assert_raises
    yield
  File "<ipython-input-56-196bd40bd949>", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3
```


### Obtaining graphs

Each concrete function is a callable wrapper around a `tf.Graph`. Although retrieving the actual `tf.Graph` object is not something you'll normally need to do, you can obtain it easily from any concrete function.

```python
graph = double_strings.graph
for node in graph.as_graph_def().node:
    print(f'{node.input} -> {node.name}')
```

```
[] -> a
['a', 'a'] -> add
['add'] -> Identity
```


### How to Debug

In general, debugging code is easier in eager mode than inside `tf.function`. You should ensure that your code executes error-free in eager mode before decorating with `tf.function`. To assist in the debugging process, you can call `tf.config.run_functions_eagerly(True)` to globally disable and reenable `tf.function`.

When tracking down issues that only appear within `tf.function`, here are some tips:
- Plain old Python `print` calls only execute during tracing, helping you track down when your function gets (re)traced.
- `tf.print` calls will execute every time, and can help you track down intermediate values during execution.
- `tf.debugging.enable_check_numerics` is an easy way to track down where NaNs and Inf are created.
- `pdb` can help you understand what's going on during tracing. (Caveat: PDB will drop you into AutoGraph-transformed source code.)
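The sketch below illustrates the first tip using a hypothetical `add_one` function that records `tf.executing_eagerly()` from inside its body. With `tf.config.run_functions_eagerly(True)`, the body runs as ordinary eager Python. After switching back, the call traces a graph, and eager execution is off inside it. The `try`/`finally` ensures the global setting is always restored.

```python
import tensorflow as tf

modes = []  # whether the function body saw eager execution, per Python run

@tf.function
def add_one(x):
    modes.append(tf.executing_eagerly())  # Python side effect
    return x + 1

tf.config.run_functions_eagerly(True)  # Function bodies now run as eager Python
try:
    eager_out = add_one(tf.constant(1))
finally:
    tf.config.run_functions_eagerly(False)  # restore graph execution

graph_out = add_one(tf.constant(1))  # traces: executing_eagerly() is False inside
print(modes, int(eager_out), int(graph_out))
```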

## AutoGraph Transformations

AutoGraph is a library that is on by default in `tf.function`, and transforms a subset of Python eager code into graph-compatible TensorFlow ops. This includes control flow like `if`, `for`, `while`.

TensorFlow ops like `tf.cond` and `tf.while_loop` continue to work, but control flow is often easier to write and understand when written in Python.

```python
# Simple loop

@tf.function
def f(x):
    while tf.reduce_sum(x) > 1:
        tf.print(x)
        x = tf.tanh(x)
    return x

f(tf.random.uniform([5]))
```

```
[0.648117423 0.200547457 0.858983517 0.445949674 0.367247343]
[0.570401251 0.197901383 0.695733666 0.418563932 0.351581633]
[0.515653908 0.195357621 0.601652801 0.395719945 0.337777466]
[0.47433874 0.192909718 0.538224638 0.376280844 0.325491756]
[0.441698253 0.190551803 0.491642922 0.359473228 0.314464122]
[0.415051132 0.188278496 0.455519408 0.344749928 0.304492801]
[0.392753124 0.186084837 0.426425368 0.331711322 0.295418739]
[0.373731226 0.183966264 0.402329624 0.320057631 0.287114531]
[0.357251018 0.181918606 0.381940514 0.309559017 0.279476881]
[0.342790306 0.179937974 0.364391506 0.300035864 0.272420824]
[0.329966187 0.17802082 0.349076271 0.291345417 0.265875965]
[0.318490386 0.176163763 0.33555609 0.283372641 0.259783238]
[0.308141291 0.174363762 0.323504299 0.276023626 0.254092753]
[0.298745185 0.172617927 0.312672079 0.269220859 0.248762026]
[0.290163845 0.170923606 0.302866 0.262899667 0.243754581]
[0.282285601 0.169278309 0.293933213 0.257005662 0.239038929]
[0.275019109 0

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.21683306, 0.1514686 , 0.2219612 , 0.20475668, 0.1953355 ],
      dtype=float32)>
```

If you're curious, you can inspect the code AutoGraph generates.

```python
print(tf.autograph.to_code(f.python_function))
```

```
def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)
```



### Conditionals

AutoGraph will convert some `if <condition>` statements into the equivalent `tf.cond` calls. This substitution is made if `<condition>` is a Tensor. Otherwise, the `if` statement is executed as a Python conditional.

A Python conditional executes during tracing, so exactly one branch of the conditional will be added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate branch if there is data-dependent control flow.

`tf.cond` traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; see [AutoGraph tracing effects](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#effects-of-the-tracing-process) for more.
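The sketch below contrasts the two cases using hypothetical functions `abs_value` and `maybe_negate`. A Tensor condition becomes a `tf.cond`, so the Python code in both branches runs during tracing. A Python `bool` condition is evaluated at trace time, so only the taken branch is traced.

```python
import tensorflow as tf

traced = []  # names of branches whose Python code ran during tracing

@tf.function
def abs_value(x):
    if x > 0:      # Tensor condition: AutoGraph emits tf.cond
        traced.append("positive branch")
        y = x
    else:
        traced.append("non-positive branch")
        y = -x
    return y

@tf.function
def maybe_negate(x, negate):
    if negate:     # Python bool: an ordinary Python `if`, decided at trace time
        traced.append("negate branch")
        return -x
    traced.append("identity branch")
    return x

abs_value(tf.constant(-2.0))          # both tf.cond branches are traced
maybe_negate(tf.constant(3.0), True)  # only the taken Python branch is traced
print(sorted(set(traced)))
```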

```python
@tf.function
def fizzbuzz(n):
    for i in tf.range(1, n + 1):
        print('Tracing for loop')
        if i % 15 == 0:
            print('Tracing fizzbuzz branch')
            tf.print('fizzbuzz')
        elif i % 3 == 0:
            print('Tracing fizz branch')
            tf.print('fizz')
        elif i % 5 == 0:
            print('Tracing buzz branch')
            tf.print('buzz')
        else:
            print('Tracing default branch')
        tf.print(i)

fizzbuzz(tf.constant(3))
print()
fizzbuzz(tf.constant(5))
```

```
Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
3

1
2
fizz
3
4
buzz
5
```


### Loops

AutoGraph will convert some `for` and `while` statements into the equivalent TensorFlow looping ops, like `tf.while_loop`. If not converted, the `for` or `while` loop is executed as a Python loop.

This substitution is made in the following situations:

- `for x in y`: if `y` is a Tensor, convert to `tf.while_loop`. In the special case where `y` is a `tf.data.Dataset`, a combination of `tf.data.Dataset` ops are generated.
- `while <condition>`: if `<condition>` is a Tensor, convert to `tf.while_loop`.

A Python loop executes during tracing, adding additional ops to the `tf.Graph` for every iteration of the loop.

A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time.  The loop body only appears once in the generated `tf.Graph`.

See the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#while-statements) for additional restrictions on AutoGraph-converted `for` and `while` statements.

#### A helper function to count the nodes in a graph

```python
def measure_graph_size(f, *args):
    g = f.get_concrete_function(*args).graph
    print("{}({}) contains {} nodes in its graph".format(
        f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))
```

#### Looping over Python data

A common pitfall is to loop over Python/Numpy data within a `tf.function`. This loop will execute during the tracing process, adding a copy of your model to the `tf.Graph` for each iteration of the loop.

If you want to wrap the entire training loop in `tf.function`, the safest way to do this is to wrap your data as a `tf.data.Dataset` so that AutoGraph will dynamically unroll the training loop.

```python
@tf.function
def train(dataset):
    loss = tf.constant(0)
    for x, y in dataset:
        loss += tf.abs(y - x)  # Some dummy computation.
    return loss


small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10

measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))
```

```
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 10 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 10 nodes in its graph
```


When wrapping Python/Numpy data in a Dataset, be mindful of `tf.data.Dataset.from_generator` versus `tf.data.Dataset.from_tensors`. The former keeps the data in Python and fetches it via `tf.py_function`, which can have performance implications. The latter bundles a copy of the data as one large `tf.constant()` node in the graph, which can have memory implications.
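The sketch below shows both approaches, with made-up data `xs`/`ys` and a hypothetical `total_abs_diff` function. Instead of `from_tensors`, it uses its sibling `tf.data.Dataset.from_tensor_slices`, which splits the data along the first dimension. Like `from_tensors`, it turns the Python lists into constant tensors. `from_generator` is called with the older `output_types` argument, as in the example above.

```python
import tensorflow as tf

xs = [0, 1, 2, 3]
ys = [1, 3, 5, 7]

@tf.function
def total_abs_diff(dataset):
    loss = tf.constant(0)
    for x, y in dataset:  # iterating a Dataset: AutoGraph emits a tf.data loop
        loss += tf.abs(y - x)
    return loss

# Copies the Python lists into the input pipeline as constant tensors (memory cost).
in_graph = tf.data.Dataset.from_tensor_slices((xs, ys))
# Keeps the data in Python and pulls each element through tf.py_function (speed cost).
from_python = tf.data.Dataset.from_generator(lambda: zip(xs, ys), (tf.int32, tf.int32))

print(int(total_abs_diff(in_graph)), int(total_abs_diff(from_python)))
```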

Reading data from files via `TFRecordDataset`/`CsvDataset`/etc. is the most effective way to consume data, because TensorFlow itself can then manage the asynchronous loading and prefetching of data without involving Python. To learn more, see the [tf.data guide](https://www.tensorflow.org/guide/data).

#### Use tf.range instead of range

With Python `range`, the loop is unrolled during tracing, so the graph contains many nodes:

```python
@tf.function
def train(n):
    for i in range(n):
        tf.print(i)

measure_graph_size(train, 100)
```

```
train(100) contains 200 nodes in its graph
```


With `tf.range`, by contrast, the loop is not unrolled:

```python
@tf.function
def train(n):
    for i in tf.range(n):
        tf.print(i)

measure_graph_size(train, 100)
```

```
train(100) contains 16 nodes in its graph
```


#### Use tf.TensorArray to accumulate values in a loop

A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, as these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use `tf.TensorArray` to accumulate results from a dynamically unrolled loop.
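Here is a minimal sketch of the pattern before the longer RNN example, using a hypothetical `squares` function. `write` returns an updated `TensorArray`, so the result must be reassigned. `stack()` then turns the accumulated values into a single tensor.

```python
import tensorflow as tf

@tf.function
def squares(n):
    # size may be a Tensor, so the array length is decided at run time
    acc = tf.TensorArray(tf.int32, size=n)
    for i in tf.range(n):          # Tensor range: converted to tf.while_loop
        acc = acc.write(i, i * i)  # write() returns the updated TensorArray
    return acc.stack()             # a tensor of shape (n,)

print(squares(tf.constant(4)).numpy())
```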

```python
batch_size = 2
seq_len = 3
feature_size = 4


def rnn_step(inp, state):
    return inp + state


@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    max_seq_len = input_data.shape[0]

    states = tf.TensorArray(tf.float32, size=max_seq_len)
    state = initial_state
    for i in tf.range(max_seq_len):
        state = rnn_step(input_data[i], state)
        states = states.write(i, state)
    return tf.transpose(states.stack(), [1, 0, 2])


dynamic_rnn(rnn_step, tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))
```

```
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.94137156, 0.9822793 , 0.64925754, 0.2964362 ],
        [0.99526787, 1.4181706 , 1.0210431 , 0.47227967],
        [1.1410419 , 2.0621977 , 1.3571397 , 0.99516356]],

       [[0.801937  , 0.52088714, 0.38815653, 0.54956937],
        [1.7857704 , 0.87730646, 0.41332817, 0.8816924 ],
        [2.5854597 , 1.5726612 , 0.8366827 , 1.3635818 ]]], dtype=float32)>
```

## Limitations

TensorFlow `Function` has a few limitations by design that you should be aware of when converting a Python function to a `Function`.

### Executing Python side effects

Side effects, like printing, appending to lists, and mutating globals, can behave unexpectedly inside a `Function`, sometimes executing twice or not at all. They only happen the first time you call a `Function` with a set of inputs. Afterwards, the traced `tf.Graph` is re-executed without executing the Python code.

The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces. Otherwise, TensorFlow APIs like `tf.data`, `tf.print`, `tf.summary`, `tf.Variable.assign`, and `tf.TensorArray` are the best way to ensure your code will be executed by the TensorFlow runtime with each call.

```python
@tf.function
def f(x):
    print("Traced with", x)
    tf.print("Executed with", x)

f(1)
f(1)
f(2)
```

```
Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2
```


If you would like to execute Python code during each invocation of a `Function`, `tf.py_function` is an exit hatch. The drawback of `tf.py_function` is that it's not portable or particularly performant, cannot be saved with SavedModel, and does not work well in distributed (multi-GPU, TPU) setups. Also, since `tf.py_function` has to be wired into the graph, it casts all inputs/outputs to tensors.
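Here is a minimal sketch of `tf.py_function` as that exit hatch, assuming a hypothetical `record_and_double` helper. The Python callable runs each time the graph executes, not only during tracing, so the `calls` list grows on every call. `Tout` declares the output dtype, because the graph cannot infer it from the Python code.

```python
import tensorflow as tf

calls = []  # Python state updated on every execution

def record_and_double(x):
    # Runs in the Python runtime each time the graph executes; x is an eager tensor.
    calls.append(int(x))
    return x * 2

@tf.function
def f(x):
    # Wrap the Python callable as a graph op; Tout is the output dtype.
    return tf.py_function(record_and_double, inp=[x], Tout=tf.int32)

f(tf.constant(1))
f(tf.constant(2))
print(calls)
```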

#### Changing Python global and free variables

Changing Python global and free variables counts as a Python side effect, so it only happens during tracing.


```python
external_list = []

@tf.function
def side_effect(x):
    print('Python side effect') # only traced once
    external_list.append(x) # only traced once

side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
```

```
Python side effect
```


You should avoid mutating containers like lists, dicts, and other objects that live outside the `Function`. Instead, use arguments and TF objects. For example, the `tf.TensorArray` section above shows one way to implement list-like operations.

You can, in some cases, capture and manipulate state if it is a [`tf.Variable`](https://www.tensorflow.org/guide/variable). This is how the weights of Keras models are updated with repeated calls to the same `ConcreteFunction`.
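The following sketch uses a hypothetical `counter` variable. `assign_add` is a TensorFlow op recorded in the graph, so the state changes on every call, unlike the list append above.

```python
import tensorflow as tf

counter = tf.Variable(0)  # TF state captured by the Function

@tf.function
def increment():
    counter.assign_add(1)  # a TF op, so it runs on every call, not only at trace time
    return counter.read_value()

for _ in range(3):
    increment()
print(counter.numpy())
```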

#### Using Python iterators and generators

Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in eager mode, they are examples of Python side effects and therefore only happen during tracing.

For iterators and generators, this means `next()` is evaluated only once, during tracing. Later calls reuse the traced value instead of advancing the iterator.

```python
@tf.function
def buggy_consume_next(iterator):
    tf.print("Value:", next(iterator))

iterator = iter([1, 2, 3])
buggy_consume_next(iterator)

# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
```

```
Value: 1
Value: 1
Value: 1
```


The fix is to use a `tf.data.Iterator` instead of a Python iterator or generator:

```python
@tf.function
def good_consume_next(iterator):
    # This is ok, iterator is a tf.data.Iterator
    tf.print("Value:", next(iterator))

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
```

```
Value: 1
Value: 2
Value: 3
```


### Deleting tf.Variables between `Function` calls

Another error you may encounter is a garbage-collected variable. `ConcreteFunction`s only retain [WeakRefs](https://docs.python.org/3/library/weakref.html) to the variables they close over, so you must retain a reference to any variables.

```python
external_var = tf.Variable(3)
@tf.function
def f(x):
    return x * external_var


traced_f = f.get_concrete_function(4)
print("Calling concrete function...")
print(traced_f(4))

# The original variable object gets garbage collected, since there are no more
# references to it.
external_var = tf.Variable(4)
print()
print("Calling concrete function after garbage collecting its closed Variable...")
with assert_raises(tf.errors.FailedPreconditionError):
    traced_f(4)
```

```
Calling concrete function...
tf.Tensor(12, shape=(), dtype=int32)

Calling concrete function after garbage collecting its closed Variable...
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.FailedPreconditionError'>:

Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-9a93d2e07632>", line 16, in <module>
    traced_f(4)
tensorflow.python.framework.errors_impl.FailedPreconditionError: 2 root error(s) found.
  (0) Failed precondition:  Error while reading resource variable _AnonymousVar3 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar3/N10tensorflow3VarE does not exist.
	 [[node ReadVariableOp (defined at <ipython-input-1-9a93d2e07632>:4) ]]
  (1) Failed precondition:  Error while reading resource variable _AnonymousVar3 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar3/N10tensorflow3VarE does not exist.
	 [[node ReadVariableOp (defined at <ipython-input-1-9a93d2e07632>:4) ]]
	 [[ReadVariableOp/_2]]
0 successful operations.
0 derived errors ignored. [Op:__inference_f_78
```
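One way to avoid this error is sketched below, under the assumption that a `tf.Module` subclass fits your code (the `Scaler` class is hypothetical). A long-lived object owns the variable, and its value is changed with `assign` instead of rebinding the name. The concrete function keeps working and sees the new value.

```python
import tensorflow as tf

class Scaler(tf.Module):
    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(3)  # owned by the module, so it stays alive

    @tf.function
    def __call__(self, x):
        return x * self.factor

scaler = Scaler()
traced = scaler.__call__.get_concrete_function(4)
print(int(traced(4)))     # reads factor = 3
scaler.factor.assign(5)   # mutate in place instead of rebinding the name
print(int(traced(4)))     # same trace, new variable value
```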

## Known Issues

If your `Function` is not evaluating correctly, the error may be explained by these known issues, which are planned to be fixed in the future.

### Closing over free variables

In [123]:
@tf.function
def buggy_add():
    return 1 + foo

@tf.function
def recommended_add(foo):
    return 1 + foo

foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))

print()
print("Updating the value of `foo` to 2!")
foo += 1
print("Buggy:", buggy_add())  # Did not change!
print("Correct:", recommended_add(foo))

Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)

Updating the value of `foo` to 2!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(3, shape=(), dtype=int32)


You can close over outer names, **as long as you don't update their values.**
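To see why `buggy_add` keeps returning 2, here is a toy, pure-Python model of the trace cache. `toy_function` is a hypothetical helper, not a TensorFlow API. It models only one aspect of tracing: the Python body runs once per distinct tuple of Python arguments, and the Python values it read at that moment are reused on later calls.

```python
def toy_function(python_fn):
    """Hypothetical toy model of tf.function's trace cache (not a TensorFlow API).

    The Python body runs ("is traced") once per distinct argument tuple; later
    calls with the same arguments replay the stored result, much as Python
    values read during tracing end up frozen into the graph as constants.
    """
    cache = {}

    def wrapper(*args):
        if args not in cache:
            cache[args] = python_fn(*args)  # Python code runs only here
        return cache[args]

    return wrapper


@toy_function
def buggy_add():
    return 1 + foo  # reads the global `foo` at "trace" time only


@toy_function
def recommended_add(foo):
    return 1 + foo  # `foo` is part of the cache key


foo = 1
print(buggy_add(), recommended_add(foo))  # 2 2
foo += 1
print(buggy_add(), recommended_add(foo))  # 2 3
```

With a real `tf.function`, passing a Python `int` like this causes a retrace for every new value. Passing a tensor such as `tf.constant(foo)` avoids the retrace while still reading the current value.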

What happens if we replace `foo` above with a `tf.Variable`?

In [125]:
@tf.function
def buggy_add():
    return 1 + foo

@tf.function
def recommended_add(foo):
    return 1 + foo

foo = tf.Variable(0)
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))

print()
print("Updating the value of `foo` to 1!")
foo.assign_add(1)  # In-place update: `foo` is now 1.
print("Buggy:", buggy_add())  # Now correct: reads the current value of `foo`.
print("Correct:", recommended_add(foo))  # Also correct.

Buggy: tf.Tensor(1, shape=(), dtype=int32)
Correct: tf.Tensor(1, shape=(), dtype=int32)

Updating the value of `foo` to 1!
Buggy: tf.Tensor(2, shape=(), dtype=int32)
Correct: tf.Tensor(2, shape=(), dtype=int32)


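As you can see, the results are now correct. During tracing, the `tf.Variable` `foo` is captured into the graph by reference rather than as a constant. When its value is updated in place, the graph still reads the current value of `foo`.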

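However, the next version raises an error. This time `foo` is updated with `foo = foo + 1` instead of `foo.assign_add(1)`. `assign_add` modifies the variable in place, whereas `foo = foo + 1` computes a new tensor and rebinds the Python name `foo` to it. The original `tf.Variable` then has no remaining references and is garbage-collected. As a result, the graph traced for `buggy_add` can no longer find the variable it captured, and the call fails. This is the same failure shown above for deleted `tf.Variable`s.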

In [138]:
@tf.function
def buggy_add():
    return 1 + foo

@tf.function
def recommended_add(foo):
    return 1 + foo

foo = tf.Variable(0)
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))

print()
print("Updating the value of `foo` to 1!")
foo = foo + 1  # Rebinds the name `foo` to a new tensor.
with assert_raises(Exception):
    print("Buggy:", buggy_add())  # Error!
print("Correct:", recommended_add(foo))

Buggy: tf.Tensor(1, shape=(), dtype=int32)
Correct: tf.Tensor(1, shape=(), dtype=int32)

Updating the value of `foo` to 1!
Caught expected exception 
  <class 'Exception'>:
Correct: tf.Tensor(2, shape=(), dtype=int32)


Traceback (most recent call last):
  File "<ipython-input-129-af6dbe3f6b0e>", line 8, in assert_raises
    yield
  File "<ipython-input-138-4c7ba1ed9810>", line 17, in <module>
    print("Buggy:", buggy_add())  # Error!
tensorflow.python.framework.errors_impl.FailedPreconditionError:  Error while reading resource variable _AnonymousVar40 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/_AnonymousVar40/N10tensorflow3VarE does not exist.
	 [[node ReadVariableOp (defined at <ipython-input-117-ac0e3fd57131>:3) ]] [Op:__inference_buggy_add_3299]

Function call stack:
buggy_add
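The difference between `assign_add` and `foo = foo + 1` comes down to ordinary Python name binding combined with weak references. The sketch below is a plain-Python analogy with no TensorFlow involved. `ToyVariable` is a hypothetical stand-in for `tf.Variable`, and `weakref.ref` plays the role of the traced graph's capture. It assumes CPython reference counting, which frees an object as soon as its last strong reference disappears.

```python
import gc
import weakref


class ToyVariable:
    """Hypothetical stand-in for tf.Variable (plain Python, no TensorFlow)."""

    def __init__(self, value):
        self.value = value

    def assign_add(self, delta):
        # In-place update: same object, new value (like tf.Variable.assign_add).
        self.value += delta

    def __add__(self, delta):
        # Like `tf.Variable + 1`: returns a new value, not the variable itself.
        return self.value + delta


foo = ToyVariable(0)
captured = weakref.ref(foo)  # a traced graph only holds a weak reference

foo.assign_add(1)
gc.collect()
value_after_assign = captured().value  # 1: the same object is still alive

foo = foo + 1  # rebinds the name `foo` to the int 2
gc.collect()
alive_after_rebind = captured() is not None  # the original object is gone

print(value_after_assign, alive_after_rebind, foo)  # 1 False 2
```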



#### Depending on Python objects

The recommendation to pass Python objects as arguments into `tf.function` has a number of known issues that are expected to be fixed in the future. In general, you can rely on consistent tracing if you use a Python primitive or `tf.nest`-compatible structure as an argument or pass in a *different* instance of an object into a `Function`. However, `Function` will *not* create a new trace when you pass **the same object and only change its attributes**.

In [74]:
class SimpleModel(tf.Module):
    def __init__(self):
        # These values are *not* tf.Variables.
        self.bias = 0.
        self.weight = 2.


@tf.function
def evaluate(model, x):
    return model.weight * x + model.bias


simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))

tf.Tensor(20.0, shape=(), dtype=float32)


In [75]:
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x))  # Didn't change :(

Adding bias!
tf.Tensor(20.0, shape=(), dtype=float32)


Using the same `Function` to evaluate the updated instance of the model will be buggy since the updated model has the [same cache key](#rules_of_tracing) as the original model.

For that reason, we recommend that you write your `Function` so that it does not depend on mutable object attributes, or that you create new objects instead of mutating existing ones.
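A minimal sketch of the "create new objects" approach. It redefines `SimpleModel` with optional `weight` and `bias` constructor arguments, an assumption made for this illustration since the class above takes none. Instead of mutating `bias` on an existing instance, a new instance is built. Because a different object is passed in, `Function` traces again and picks up the new value. Both instances are kept in named variables so they stay alive for the comparison.

```python
import tensorflow as tf


class SimpleModel(tf.Module):
    def __init__(self, weight=2., bias=0.):
        super().__init__()
        # Plain Python floats, *not* tf.Variables (as in the cell above).
        self.weight = weight
        self.bias = bias


@tf.function
def evaluate(model, x):
    return model.weight * x + model.bias


x = tf.constant(10.)
base_model = SimpleModel()
print(evaluate(base_model, x))  # tf.Tensor(20.0, shape=(), dtype=float32)

# Instead of `base_model.bias += 5.0`, build a new instance with the new bias.
biased_model = SimpleModel(bias=5.)
print(evaluate(biased_model, x))  # tf.Tensor(25.0, shape=(), dtype=float32)
```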

If that is not possible, one workaround is to make new `Function`s each time you modify your object to force retracing:

In [76]:
def evaluate(model, x):
    return model.weight * x + model.bias


new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))

tf.Tensor(20.0, shape=(), dtype=float32)


In [77]:

print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.

Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)


Since [retracing can be expensive](https://www.tensorflow.org/guide/intro_to_graphs#tracing_and_performance), you can instead use `tf.Variable`s as object attributes. They can be mutated in place (for example with `assign_add`) but must not be rebound to new objects, and they give a similar effect without needing a retrace.

In [39]:
class BetterModel:

    def __init__(self):
        self.bias = tf.Variable(0.)
        self.weight = tf.Variable(2.)


@tf.function
def evaluate(model, x):
    return model.weight * x + model.bias


better_model = BetterModel()
print(evaluate(better_model, x))

tf.Tensor(20.0, shape=(), dtype=float32)


In [40]:
print("Adding bias!")
better_model.bias.assign_add(5.0)  # Note: instead of better_model.bias += 5
print(evaluate(better_model, x))  # This works!

Adding bias!
tf.Tensor(25.0, shape=(), dtype=float32)


### Creating tf.Variables

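`Function` only supports creating variables once, the first time it is called, and then reusing them. Creating new `tf.Variable`s in later traces or calls is currently not allowed, although this may be supported in the future. The restriction applies even to the very first call: after a trace that creates variables, `Function` traces the Python function again, and that second trace must not create any. A function that unconditionally creates a `tf.Variable` therefore fails immediately.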


Example:

In [41]:
@tf.function
def f(x):
    v = tf.Variable(1.0)
    return v


with assert_raises(ValueError):
    f(1.0)

Caught expected exception 
  <class 'ValueError'>:


Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-8a0913e250e0>", line 7, in <module>
    f(1.0)
ValueError: in user code:

    <ipython-input-1-8a0913e250e0>:3 f  *
        v = tf.Variable(1.0)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:262 __call__  **
        return cls._variable_v2_call(*args, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:256 _variable_v2_call
        shape=shape)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/ops/variables.py:67 getter
        return captured_getter(captured_previous, **kwargs)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/eager/def_function.py:731 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first
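The simplest way around this error is to create the variable once, outside the `Function`, and only read or update it inside. A minimal sketch follows; `v` and `scale` are illustrative names, not from the original example.

```python
import tensorflow as tf

# Create the variable once, eagerly, outside the Function.
v = tf.Variable(1.0)


@tf.function
def scale(x):
    # Only reads the existing variable; no tf.Variable is created during tracing.
    return v * x


print(scale(1.0))  # tf.Tensor(1.0, shape=(), dtype=float32)
```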

You can create variables inside a `Function` as long as those variables are only created the first time the function is executed.

In [42]:
class Count(tf.Module):
    def __init__(self):
        self.count = None

    @tf.function
    def __call__(self):
        if self.count is None:
            self.count = tf.Variable(0)
        return self.count.assign_add(1)

c = Count()
print(c())
print(c())

tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)


#### Using with multiple Keras optimizers
You may encounter `ValueError: tf.function-decorated function tried to create variables on non-first call.` when using more than one Keras optimizer with a `tf.function`. This error occurs because optimizers internally create `tf.Variables` when they apply gradients for the first time.

In [43]:
opt1 = tf.keras.optimizers.Adam(learning_rate=1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(w, x, y, optimizer):
    with tf.GradientTape() as tape:
        L = tf.reduce_sum(tf.square(w*x - y))
    gradients = tape.gradient(L, [w])
    optimizer.apply_gradients(zip(gradients, [w]))


w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
    train_step(w, x, y, opt2)

Calling `train_step` with different optimizer...
Caught expected exception 
  <class 'ValueError'>:


Traceback (most recent call last):
  File "<ipython-input-1-73d0ca52e838>", line 8, in assert_raises
    yield
  File "<ipython-input-1-d3d3937dbf1a>", line 18, in <module>
    train_step(w, x, y, opt2)
ValueError: in user code:

    <ipython-input-1-d3d3937dbf1a>:9 train_step  *
        optimizer.apply_gradients(zip(gradients, [w]))
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:604 apply_gradients  **
        self._create_all_weights(var_list)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:781 _create_all_weights
        _ = self.iterations
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:788 __getattribute__
        return super(OptimizerV2, self).__getattribute__(name)
    /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:926 iterations
        aggreg

If you need to change the optimizer during training, a workaround is to create a new `Function` for each optimizer, calling the [`ConcreteFunction`](#obtaining_concrete_functions) directly.

In [78]:
opt1 = tf.keras.optimizers.Adam(learning_rate=1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Not a tf.function.
def train_step(w, x, y, optimizer):
    with tf.GradientTape() as tape:
        L = tf.reduce_sum(tf.square(w*x - y))
    gradients = tape.gradient(L, [w])
    optimizer.apply_gradients(zip(gradients, [w]))


w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])

# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
    if i % 2 == 0:
        train_step_1(w, x, y)  # `opt1` is not used as a parameter.
    else:
        train_step_2(w, x, y)  # `opt2` is not used as a parameter.

#### Using with multiple Keras models

You may also encounter `ValueError: tf.function-decorated function tried to create variables on non-first call.` when passing different model instances to the same `Function`.

This error occurs because Keras models (which [do not have their input shape defined](https://www.tensorflow.org/guide/keras/custom_layers_and_models#best_practice_deferring_weight_creation_until_the_shape_of_the_inputs_is_known)) and Keras layers create `tf.Variable`s when they are first called. You may be attempting to initialize those variables inside a `Function` that has already been called. To avoid this error, try calling `model.build(input_shape)` to initialize all the weights before training the model.
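A minimal sketch of that workaround, assuming a small `tf.keras.Sequential` model with a single `Dense` layer and 3 input features. The `make_model` and `predict` names are illustrative. Each model is built before it is ever passed to the `Function`. Passing a new model still causes a retrace, but the call inside the `Function` only reads weights that already exist.

```python
import tensorflow as tf


def make_model():
    model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    # Create the Dense kernel and bias now, outside any Function.
    model.build(input_shape=(None, 3))
    return model


@tf.function
def predict(model, x):
    # Each new model instance causes a retrace, but no variables are created.
    return model(x)


x = tf.ones([2, 3])
model_a = make_model()
model_b = make_model()
print(predict(model_a, x).shape)  # (2, 1)
print(predict(model_b, x).shape)  # (2, 1)
```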


## Further reading

To learn about how to export and load a `Function`, see the [SavedModel guide](https://www.tensorflow.org/guide/saved_model). To learn more about graph optimizations that are performed after tracing, see the [Grappler guide](https://www.tensorflow.org/guide/graph_optimization). To learn how to optimize your data pipeline and profile your model, see the [Profiler guide](https://www.tensorflow.org/guide/profiler).

# References
- [Better performance with tf.function  |  TensorFlow Core](https://www.tensorflow.org/guide/function)