In [1]:
import tensorflow as tf
import timeit
from datetime import datetime

In [2]:
# A plain Python function: matrix product followed by a bias add
def function_to_get_faster(x, y, b):
  """Return `tf.matmul(x, y) + b`.

  Written as ordinary Python so it can be wrapped with `tf.function`
  and executed as a graph.
  """
  product = tf.matmul(x, y)
  return product + b

# Sample inputs: a 1x2 matrix, a 2x1 matrix and a scalar bias
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)

# Wrap the Python function in a `Function` object that contains a graph
a_function_that_uses_a_graph = tf.function(function_to_get_faster)

# Call it exactly like the original function and convert the result to NumPy
a_function_that_uses_a_graph(x1, y1, b1).numpy()

array([[12.]], dtype=float32)

In [3]:
def inner_function(x, y, b):
  """Compute `tf.matmul(x, y) + b`; not decorated with `tf.function` itself."""
  return tf.matmul(x, y) + b

# Decorator form of `tf.function`: only the outer function is decorated
@tf.function
def outer_function(x):
  """Apply `inner_function` to `x` with a fixed 2x1 matrix and scalar bias."""
  rhs = tf.constant([[2.0], [3.0]])
  bias = tf.constant(4.0)
  return inner_function(x, rhs, bias)

# Calling the decorated function creates a graph that includes
# inner_function() as well as outer_function(), even though only
# outer_function() carries the @tf.function decorator.
# `.numpy()` converts the returned tensor so the cell displays an array.
outer_function(tf.constant([[1.0, 2.0]])).numpy()

array([[12.]], dtype=float32)

In [4]:
# Data-dependent branching on a tensor value:
# returns x * x when the sum of all elements of x is <= 1, otherwise x - 1.
# NOTE: a later cell prints tf.autograph.to_code(my_function), so the source
# of this function is deliberately kept exactly as written.
def my_function(x):
  if tf.reduce_sum(x) <= 1:
    return x * x
  else:
    return x-1

# Wrap my_function in a `Function` so its Python `if` runs as part of a graph
a_function = tf.function(my_function)

# Scalar 1.0: the sum is 1.0, so the `<= 1` branch (x * x) is taken
print("First branch, with graph:", a_function(tf.constant(1.0)).numpy())
# [5.0, 5.0]: the sum is 10.0, so the `else` branch (x - 1) is taken
print("Second branch, with graph:", a_function(tf.constant([5.0, 5.0])).numpy())

First branch, with graph: 1.0
Second branch, with graph: [4. 4.]


In [5]:
# Print the Python source that AutoGraph generates for my_function.
# The generated code is verbose; don't read the output too carefully.
print(tf.autograph.to_code(my_function))

def tf__my_function(x):
    with ag__.FunctionScope('my_function', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (do_return, retval_)

        def set_state(vars_):
            nonlocal do_return, retval_
            (do_return, retval_) = vars_

        def if_body():
            nonlocal do_return, retval_
            try:
                do_return = True
                retval_ = (ag__.ld(x) * ag__.ld(x))
            except:
                do_return = False
                raise

        def else_body():
            nonlocal do_return, retval_
            try:
                do_return = True
                retval_ = (ag__.ld(x) - 1)
            except:
                do_return = False
                raise
        ag__.if_stmt((ag__.converted_call(ag__.ld(tf).reduce_sum, (ag

In [6]:
# Create an override model to classify pictures
class SequentialModel(tf.keras.Model):
  """Subclassed Keras model: Flatten -> Dense(128, relu) -> Dropout(0.2) -> Dense(10)."""

  def __init__(self, **kwargs):
    # Keyword arguments are handed to tf.keras.Model unchanged
    super().__init__(**kwargs)
    self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
    self.dense_1 = tf.keras.layers.Dense(128, activation="relu")
    self.dropout = tf.keras.layers.Dropout(0.2)
    self.dense_2 = tf.keras.layers.Dense(10)

  def call(self, x):
    """Run `x` through the layers in order and return the last Dense output."""
    hidden = self.dense_1(self.flatten(x))
    return self.dense_2(self.dropout(hidden))

# One model instance, called eagerly and through a tf.function wrapper
eager_model = SequentialModel()
graph_model = tf.function(eager_model)

# Random batch of 60 inputs of size 28x28
input_data = tf.random.uniform([60, 28, 28])

# Time 10,000 forward passes for each variant
print("Eager time:", timeit.timeit(lambda: eager_model(input_data), number=10_000))
print("Graph time:", timeit.timeit(lambda: graph_model(input_data), number=10_000))


Eager time: 5.0723486938513815
Graph time: 2.7211625669151545


In [7]:
# Show the repr of the `Function` object itself
print(a_function)

# Call it with inputs that differ in dtype (int32 vs float32) and in shape
# (scalar vs rank-1). Every input sums to more than 1, so each call takes
# the `else` branch of my_function and returns x - 1.
print("Calling a `Function`:")
print("Int:", a_function(tf.constant(2)))
print("Float:", a_function(tf.constant(2.0)))
print("Rank-1 tensor of floats", a_function(tf.constant([2.0, 2.0, 2.0])))

<tensorflow.python.eager.def_function.Function object at 0x7f0a1453acf8>
Calling a `Function`:
Int: tf.Tensor(1, shape=(), dtype=int32)
Float: tf.Tensor(1.0, shape=(), dtype=float32)
Rank-1 tensor of floats tf.Tensor([1. 1. 1.], shape=(3,), dtype=float32)


In [8]:
# Get the concrete function that works on floats
print("Inspecting concrete functions")
print("Concrete function for float:")
# The input signature can be described with a TensorSpec (scalar float32 here) ...
print(a_function.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32)))
print("Concrete function for tensor of floats:")
# ... or with an example tensor (here a rank-1 float32 tensor of length 3)
print(a_function.get_concrete_function(tf.constant([2.0, 2.0, 2.0])))


Inspecting concrete functions
Concrete function for float:
ConcreteFunction my_function(x)
  Args:
    x: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()
Concrete function for tensor of floats:
ConcreteFunction my_function(x)
  Args:
    x: float32 Tensor, shape=(3,)
  Returns:
    float32 Tensor, shape=(3,)


In [9]:
# Concrete functions are callable
# Note: You won't normally do this, but instead just call the containing `Function`
# Fetch the concrete function for a scalar int32 input, then call it directly
cf = a_function.get_concrete_function(tf.constant(2))
print("Directly calling a concrete function:", cf(tf.constant(2)))

Directly calling a concrete function: tf.Tensor(1, shape=(), dtype=int32)


In [10]:
# Define an identity layer with an eager side effect
class EagerLayer(tf.keras.layers.Layer):
  """Identity layer that prints a timestamp whenever the Python body of `call` runs."""

  def __init__(self, **kwargs):
    # Nothing to set up beyond the base Layer initialisation
    super().__init__(**kwargs)

  def call(self, inputs):
    """Print the current time as a Python side effect and return `inputs` unchanged."""
    timestamp = str(datetime.now())
    print("\nCurrently running eagerly", timestamp)
    return inputs

# Create an override model to classify pictures, adding the custom layer
class SequentialModel(tf.keras.Model):
  """Flatten -> Dense(128, relu) -> Dropout(0.2) -> Dense(10) -> EagerLayer.

  Re-defines `SequentialModel` with an extra identity layer (`EagerLayer`)
  at the end, whose `call` prints a message as a Python side effect.

  Args:
    **kwargs: Forwarded unchanged to `tf.keras.Model.__init__`, matching the
      constructor of the earlier `SequentialModel` definition.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
    self.dense_1 = tf.keras.layers.Dense(128, activation="relu")
    self.dropout = tf.keras.layers.Dropout(0.2)
    self.dense_2 = tf.keras.layers.Dense(10)
    self.eager = EagerLayer()

  def call(self, x):
    """Run `x` through the layers in order; the final layer is the identity `EagerLayer`."""
    x = self.flatten(x)
    x = self.dense_1(x)
    x = self.dropout(x)
    x = self.dense_2(x)
    # EagerLayer returns its input unchanged and prints when its Python body runs
    return self.eager(x)

# Create an instance of this model
model = SequentialModel()

# Generate some nonsense pictures and labels
input_data = tf.random.uniform([60, 28, 28])
# Labels must be integer class ids for SparseCategoricalCrossentropy; draw one
# per sample from [0, 10) to match the 10 outputs of the model's last Dense
# layer. (Float labels in [0, 1) would all collapse to class 0.)
labels = tf.random.uniform([60], maxval=10, dtype=tf.int32)

# from_logits=True because the final Dense(10) layer has no activation
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [11]:
# Compile with run_eagerly=False so that fit() executes the model as a graph
# rather than eagerly; loss_fn is the loss used for training.
model.compile(run_eagerly=False, loss=loss_fn)

In [12]:
# Train for three epochs. The message printed by EagerLayer appears only while
# `call` is being traced, not on every training step.
model.fit(input_data, labels, epochs=3)
# Original author's note: not sure why the model is traced twice here.
# NOTE(review): presumably the model's variables are created during the first
# trace, which makes tf.function trace a second time - TODO confirm.

Epoch 1/3

Currently running eagerly 2020-10-05 08:10:30.720751

Currently running eagerly 2020-10-05 08:10:30.804063
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7f09a01c40f0>

In [13]:
print("Running eagerly")
# When compiling the model, set it to run eagerly
model.compile(run_eagerly=True, loss=loss_fn)

# A single epoch is enough to see the Python side effect of EagerLayer
model.fit(input_data, labels, epochs=1)
# Original author's note: even a single epoch runs the layer twice; reason unknown.
# NOTE(review): likely because fit() uses its default batch size of 32, which
# splits the 60 samples into two batches, i.e. two eager calls - TODO confirm.

Running eagerly

Currently running eagerly 2020-10-05 08:10:44.247480
Currently running eagerly 2020-10-05 08:10:44.261138


<tensorflow.python.keras.callbacks.History at 0x7f09a0317be0>

In [15]:
# Instantiate the layer on its own and call it twice outside any tf.function:
# the print inside `call` fires on both calls.
eager_layer = EagerLayer()
eager_layer(1)
eager_layer(1)


Currently running eagerly 2020-10-05 08:12:15.821909

Currently running eagerly 2020-10-05 08:12:15.822113


<tf.Tensor: shape=(), dtype=int32, numpy=1>

In [16]:
# Now, globally set everything to run eagerly
# NOTE: this is a global switch; it stays in effect for every later
# tf.function call until tf.config.run_functions_eagerly(False) is called.
tf.config.run_functions_eagerly(True)
print("Run all functions eagerly.")

# Create a polymorphic function
polymorphic_function = tf.function(model)

print("Tracing")
# This does, in fact, trace the function
# (the EagerLayer message is printed once while the concrete function is built)
print(polymorphic_function.get_concrete_function(input_data))

print("\nCalling twice eagerly")
# When you run the function again, you will see the side effect
# twice, as the function is running eagerly.
result = polymorphic_function(input_data)
result = polymorphic_function(input_data)

Run all functions eagerly.
Tracing

Currently running eagerly 2020-10-05 08:16:24.609254
ConcreteFunction function(self)
  Args:
    self: float32 Tensor, shape=(60, 28, 28)
  Returns:
    float32 Tensor, shape=(60, 10)

Calling twice eagerly

Currently running eagerly 2020-10-05 08:16:24.613181

Currently running eagerly 2020-10-05 08:16:24.614320


In [2]:
# Use @tf.function decorator
# Switch graph execution back on first: an earlier cell calls
# tf.config.run_functions_eagerly(True), and with that global setting still
# active the Python body below (including its print) would run on every call,
# so this cell would not reproduce under Restart & Run All.
tf.config.run_functions_eagerly(False)

@tf.function
def a_function_with_python_side_effect(x):
  """Return `x * x + 2`, printing "Tracing!" as a Python side effect.

  Args:
    x: A tensor or a Python number.

  Returns:
    The tensor `x * x + tf.constant(2)`.
  """
  print("Tracing!")  # Python side effect: runs only while the function is being traced
  return x * x + tf.constant(2)

# This is traced the first time
print(a_function_with_python_side_effect(tf.constant(2)))
# The second time through, you won't see the side effect
print(a_function_with_python_side_effect(tf.constant(3)))

# This retraces each time the Python argument changes,
# as a Python argument could be an epoch count or other
# hyperparameter
print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))

# No new trace here: an input of this kind (a scalar int32 tensor) was
# already received and traced above, so the existing graph is reused
print(a_function_with_python_side_effect(tf.constant(3)))

Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)
Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)


In [None]:
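# "Trace" literally means "to track" or "to follow". Tracing converts a dynamically typed Python function into a statically typed function inside a graph.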