In [1]:
import tensorflow as tf
import timeit
from datetime import datetime


In [2]:
# 定义一个 Python 函数
def a_regular_function(x, y, b):
    x = tf.matmul(x, y)
    x = x + b
    return x


# `a_function_that_uses_a_graph` 是一个 TensorFlow `Function`.
a_function_that_uses_a_graph = tf.function(a_regular_function)

# 创建张量
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

orig_value = a_regular_function(x1, y1, b1).numpy()
# 调用 `Function` 和调用 Python 函数一样
tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()
assert orig_value == tf_function_value


In [3]:
def simple_relu(x):
    if tf.greater(x, 0):
        return x
    else:
        return 0


# `tf_simple_relu` is a TensorFlow `Function` that wraps `simple_relu`.
tf_simple_relu = tf.function(simple_relu)

print("First branch, with graph:", tf_simple_relu(tf.constant(1)).numpy())
print("Second branch, with graph:", tf_simple_relu(tf.constant(-1)).numpy())

First branch, with graph: 1
Second branch, with graph: 0


In [4]:
# This is the graph-generating output of AutoGraph.
print(tf.autograph.to_code(simple_relu))

def tf__simple_relu(x):
    with ag__.FunctionScope('simple_relu', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (do_return, retval_)

        def set_state(vars_):
            nonlocal do_return, retval_
            (do_return, retval_) = vars_

        def if_body():
            nonlocal do_return, retval_
            try:
                do_return = True
                retval_ = ag__.ld(x)
            except:
                do_return = False
                raise

        def else_body():
            nonlocal do_return, retval_
            try:
                do_return = True
                retval_ = 0
            except:
                do_return = False
                raise
        ag__.if_stmt(ag__.converted_call(ag__.ld(tf).greater, (ag__.ld(x), 0), None, fscope), if_bo

In [5]:
# This is the graph itself.
print(tf_simple_relu.get_concrete_function(tf.constant(1)).graph.as_graph_def())

node {
  name: "x"
  op: "Placeholder"
  attr {
    key: "_user_specified_name"
    value {
      s: "x"
    }
  }
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "shape"
    value {
      shape {
      }
    }
  }
}
node {
  name: "Greater/y"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
        }
        int_val: 0
      }
    }
  }
}
node {
  name: "Greater"
  op: "Greater"
  input: "x"
  input: "Greater/y"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
}
node {
  name: "cond"
  op: "StatelessIf"
  input: "Greater"
  input: "x"
  attr {
    key: "Tcond"
    value {
      type: DT_BOOL
    }
  }
  attr {
    key: "Tin"
    value {
      list {
        type: DT_INT32
      }
    }
  }
  attr {
    key: "Tout"
    value {
      list {
        type: DT_BOOL
        type: DT_INT32
      }
    }
  

In [6]:
@tf.function
def my_relu(x):
    return tf.maximum(0., x)


# `my_relu` creates new graphs as it observes more signatures.
print(my_relu(tf.constant(5.5)))
print(my_relu([1, -1]))
print(my_relu(tf.constant([3., -3.])))

tf.Tensor(5.5, shape=(), dtype=float32)
tf.Tensor([1. 0.], shape=(2,), dtype=float32)
tf.Tensor([3. 0.], shape=(2,), dtype=float32)


In [7]:
# These two calls do *not* create new graphs.
print(my_relu(tf.constant(-2.5)))  # Signature matches `tf.constant(5.5)`.
print(my_relu(tf.constant([-1., 1.])))  # Signature matches `tf.constant([3., -3.])`.

tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor([0. 1.], shape=(2,), dtype=float32)


In [8]:
# There are three `ConcreteFunction`s (one for each graph) in `my_relu`.
# The `ConcreteFunction` also knows the return type and shape!
print(my_relu.pretty_printed_concrete_signatures())

my_relu(x)
  Args:
    x: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

my_relu(x=[1, -1])
  Returns:
    float32 Tensor, shape=(2,)

my_relu(x)
  Args:
    x: float32 Tensor, shape=(2,)
  Returns:
    float32 Tensor, shape=(2,)


In [9]:
@tf.function
def get_MSE(y_true, y_pred):
    sq_diff = tf.pow(y_true - y_pred, 2)
    return tf.reduce_mean(sq_diff)

In [10]:
y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)
y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)
print(y_true)
print(y_pred)

tf.Tensor([2 8 7 7 7], shape=(5,), dtype=int32)
tf.Tensor([6 9 6 7 8], shape=(5,), dtype=int32)


In [11]:
get_MSE(y_true, y_pred)

<tf.Tensor: shape=(), dtype=int32, numpy=3>

In [12]:
tf.config.run_functions_eagerly(True)

In [13]:
get_MSE(y_true, y_pred)

<tf.Tensor: shape=(), dtype=int32, numpy=3>

In [14]:
# Don't forget to set it back when you are done.
tf.config.run_functions_eagerly(False)

In [15]:
@tf.function
def get_MSE(y_true, y_pred):
    print("Calculating MSE!")
    sq_diff = tf.pow(y_true - y_pred, 2)
    return tf.reduce_mean(sq_diff)

In [16]:
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)

Calculating MSE!


In [17]:
# Now, globally set everything to run eagerly to force eager execution.
tf.config.run_functions_eagerly(True)

In [18]:
# Observe what is printed below.
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)

Calculating MSE!
Calculating MSE!
Calculating MSE!


In [19]:
tf.config.run_functions_eagerly(False)


In [20]:
def unused_return_eager(x):
    # Get index 1 will fail when `len(x) == 1`
    tf.gather(x, [1])  # unused
    return x


try:
    print(unused_return_eager(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:
    # All operations are run during eager execution so an error is raised.
    print(f'{type(e).__name__}: {e}')

tf.Tensor([0.], shape=(1,), dtype=float32)


In [21]:
@tf.function
def unused_return_graph(x):
    tf.gather(x, [1])  # unused
    return x


# Only needed operations are run during graph execution. The error is not raised.
print(unused_return_graph(tf.constant([0.0])))

tf.Tensor([0.], shape=(1,), dtype=float32)


In [22]:
x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)


def power(x, y):
    result = tf.eye(10, dtype=tf.dtypes.int32)
    for _ in range(y):
        result = tf.matmul(x, result)
    return result

In [23]:
print("Eager execution:", timeit.timeit(lambda: power(x, 100), number=1000))

Eager execution: 4.8215627999998105


In [24]:
power_as_graph = tf.function(power)
print("Graph execution:", timeit.timeit(lambda: power_as_graph(x, 100), number=1000))

Graph execution: 0.6316222000004927


In [25]:
@tf.function
def a_function_with_python_side_effect(x):
    print("Tracing!")  # An eager-only side effect.
    return x * x + tf.constant(2)


# This is traced the first time.
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect.
print(a_function_with_python_side_effect(tf.constant(3)))

Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)


In [26]:
# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter.
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))

Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)


In [27]:
a = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3])

In [28]:
a

<tf.Tensor: shape=(2, 3), dtype=int32, numpy=
array([[1, 2, 3],
       [4, 5, 6]])>

In [29]:
a = tf.constant(np.arange(1, 13, dtype=np.int32), shape=[2, 2, 3])

NameError: name 'np' is not defined

In [30]:
import numpy as np

a = tf.constant(np.arange(1, 13, dtype=np.int32), shape=[2, 2, 3])

In [31]:
a

<tf.Tensor: shape=(2, 2, 3), dtype=int32, numpy=
array([[[ 1,  2,  3],
        [ 4,  5,  6]],

       [[ 7,  8,  9],
        [10, 11, 12]]])>