In [1]:
import tensorflow as tf

In [2]:
import traceback
import contextlib


# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
    """Context manager asserting that its body raises `error_class`.

    When the body raises `error_class`, the exception is swallowed after a
    notice is printed and a traceback limited to two frames is emitted with
    `traceback.print_exc`.

    Args:
        error_class: The exception class the body is expected to raise.

    Raises:
        Exception: If the body completes without raising anything.

    Any exception from the body that does not match `error_class` propagates
    to the caller unchanged.
    """
    try:
        yield
    except error_class:
        print('Caught expected exception \n  {}:'.format(error_class))
        traceback.print_exc(limit=2)
    else:
        # Reached only when the body finished without raising at all.
        raise Exception('Expected {} to be raised but no error was raised!'.format(
            error_class))

In [3]:
@tf.function  # The decorator converts `add` into a `Function`.
def add(a, b):
    """Return `a + b`."""
    return a + b


# The decorated function is called like a plain Python function.
add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]

<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>

In [4]:
# Differentiate through the `tf.function`-decorated `add`.
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
    # `add` is called while the tape is active.
    result = add(v, 1.0)
# Gradient of `result` with respect to `v`; as the cell's last expression it
# is displayed as the cell output.
tape.gradient(result, v)

<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

In [5]:
@tf.function
def dense_layer(x, w, b):
    """Return `tf.matmul(x, w)` plus `b`, with the sum computed by `add`."""
    # One `tf.function`-decorated function calling another (`add`).
    return add(tf.matmul(x, w), b)


dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))

<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

In [5]:
import timeit

conv_layer = tf.keras.layers.Conv2D(100, 3)


@tf.function
def conv_fn(image):
    """Forward `image` to the module-level `conv_layer` from inside a `tf.function`."""
    return conv_layer(image)


image = tf.zeros([1, 200, 200, 100])

# Warm up: call each path once before anything is timed.
conv_layer(image)
conv_fn(image)

# Time 10 calls of the layer directly and 10 calls through the wrapper.
eager_seconds = timeit.timeit(lambda: conv_layer(image), number=10)
function_seconds = timeit.timeit(lambda: conv_fn(image), number=10)

print("Eager conv:", eager_seconds)
print("Function conv:", function_seconds)
print("Note how there's not much difference in performance for convolutions")


Eager conv: 0.003970499999994104
Function conv: 0.004773100000001307
Note how there's not much difference in performance for convolutions


In [6]:
@tf.function
def double(a):
    """Return `a + a`."""
    # Plain Python `print`; the message is labelled "Tracing with" and shows
    # the argument the Python body received.
    print("Tracing with", a)
    return a + a


# Call `double` with an int, a float, and a string constant in turn.
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()

Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)



In [7]:
# This doesn't print 'Tracing with ...': `double` has already been called
# with a scalar string tensor in the previous cell.
print(double(tf.constant("b")))

tf.Tensor(b'bb', shape=(), dtype=string)


In [8]:
# List the signatures `double` has accumulated from the calls above.
print(double.pretty_printed_concrete_signatures())

double(a)
  Args:
    a: int32 Tensor, shape=()
  Returns:
    int32 Tensor, shape=()

double(a)
  Args:
    a: float32 Tensor, shape=()
  Returns:
    float32 Tensor, shape=()

double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()


In [9]:
# The `input_signature` declares a single int32 argument of shape `[None]`.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32), ))
def next_collatz(x):
    """Return `x // 2` where `x % 2 == 0` and `3 * x + 1` elsewhere."""
    print("Tracing with", x)
    return tf.where(x % 2 == 0, x // 2, 3 * x + 1)


print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
    next_collatz(tf.constant([[1, 2], [3, 4]]))

# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
    next_collatz(tf.constant([1.0, 2.0]))


Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:
Caught expected exception 
  <class 'ValueError'>:


Traceback (most recent call last):
  File "C:\Users\happy\AppData\Local\Temp\ipykernel_15120\1159894030.py", line 9, in assert_raises
    yield
  File "C:\Users\happy\AppData\Local\Temp\ipykernel_15120\4031220677.py", line 10, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Traceback (most recent call last):
  File "C:\Users\happy\AppData\Local\Temp\ipykernel_15120\1159894030.py", line 9, in assert_raises
    yield
  File "C:\Users\happy\AppData\Local\Temp\ipykernel_15120\4031220677.py", line 14, in <module>
    next_collatz(tf.constant([1.0, 2.0]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor([1. 2.], shape=(2,), dtype=float32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None)).


In [10]:
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32), ))
def g(x):
    """Return `x` unchanged, printing a message when the Python body runs."""
    print('Tracing with', x)
    return x


# No retrace! The two vectors have lengths 3 and 5; the signature's
# `shape=[None]` leaves the length unspecified.
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))

Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([1 2 3], shape=(3,), dtype=int32)
tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)


In [11]:
def train_one_step():
    """Placeholder training step; does nothing."""
    pass


@tf.function
def train(num_steps):
    """Call `train_one_step` once per element of `tf.range(num_steps)`."""
    # A Python `print` next to a `tf.print`, so the two kinds of message can
    # be told apart in the output.
    print("Tracing with num_steps = ", num_steps)
    tf.print("Executing with num_steps = ", num_steps)
    for _ in tf.range(num_steps):
        train_one_step()


# First pass: `num_steps` is a plain Python int.
print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)

# Second pass: `num_steps` is a scalar tensor.
print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))

Retracing occurs for different Python arguments.
Tracing with num_steps =  10
Executing with num_steps =  10
Tracing with num_steps =  20
Executing with num_steps =  20

Traces are reused for Tensor arguments.
Tracing with num_steps =  Tensor("num_steps:0", shape=(), dtype=int32)
Executing with num_steps =  10
Executing with num_steps =  20


In [12]:
def f():
    """Emit one Python `print` message and one `tf.print` message."""
    print('Tracing!')
    tf.print('Executing')


# Each line wraps `f` in a brand-new `tf.function` object before calling it.
tf.function(f)()
tf.function(f)()

Tracing!
Executing
Tracing!
Executing


In [13]:
@tf.function
def get_mixed_flavor(fruit_a, fruit_b):
    """Return the sum of the two arguments' `flavor` attributes."""
    return fruit_a.flavor + fruit_b.flavor


class Fruit:
    """Base fruit carrying a class-level `flavor` constant."""
    flavor = tf.constant([0, 0])


class Apple(Fruit):
    """`Fruit` subclass whose `flavor` is `[1, 2]`."""
    flavor = tf.constant([1, 2])


class Mango(Fruit):
    """`Fruit` subclass whose `flavor` is `[3, 4]`."""
    flavor = tf.constant([3, 4])


# As described in the above rules, a generic TraceType for `Apple` and `Mango`
# is generated (and a corresponding ConcreteFunction is traced) but it fails to
# match the second function call since the first pair of Apple() and Mango()
# have gone out of scope by then and been deleted.
get_mixed_flavor(Apple(), Mango())  # Traces a new concrete function
get_mixed_flavor(Apple(), Mango())  # Traces a new concrete function again

# However, each subclass of the `Fruit` class has a fixed flavor, and you
# can reuse an existing traced concrete function if it was the same
# subclass. Avoiding such unnecessary tracing of concrete functions
# can have significant performance benefits.


class FruitTraceType(tf.types.experimental.TraceType):
    """Trace type that identifies a fruit purely by the class it wraps."""

    def __init__(self, fruit_type):
        # The wrapped class; comparison and hashing are both based on it.
        self.fruit_type = fruit_type

    def is_subtype_of(self, other):
        """True only for another `FruitTraceType` wrapping the very same class."""
        if type(other) is not FruitTraceType:
            return False
        return self.fruit_type is other.fruit_type

    def most_specific_common_supertype(self, others):
        """Return `self` when every entry of `others` equals it, otherwise `None`."""
        for candidate in others:
            if not self == candidate:
                return None
        return self

    def __eq__(self, other):
        if type(other) is not FruitTraceType:
            return False
        return self.fruit_type == other.fruit_type

    def __hash__(self):
        return hash(self.fruit_type)


class FruitWithTraceType:
    """Base class whose instances supply their own trace type."""

    def __tf_tracing_type__(self, context):
        """Return a `FruitTraceType` for this instance's class; `context` is not used."""
        return FruitTraceType(type(self))


class AppleWithTraceType(FruitWithTraceType):
    """`FruitWithTraceType` subclass whose `flavor` is `[1, 2]`."""
    flavor = tf.constant([1, 2])


class MangoWithTraceType(FruitWithTraceType):
    """`FruitWithTraceType` subclass whose `flavor` is `[3, 4]`."""
    flavor = tf.constant([3, 4])


# Now if you try calling it again:
# (the second call passes fresh instances of the same two classes)
get_mixed_flavor(AppleWithTraceType(),
                 MangoWithTraceType())  # Traces a new concrete function
get_mixed_flavor(AppleWithTraceType(),
                 MangoWithTraceType())  # Re-uses the traced concrete function


<tf.Tensor: shape=(2,), dtype=int32, numpy=array([4, 6])>

In [14]:
print("Obtaining concrete trace")
# Ask `double` for the concrete function matching a scalar string tensor.
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
# Call the concrete function once positionally and once by keyword.
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))

Obtaining concrete trace
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)


In [15]:
# You can also call get_concrete_function on a TensorSpec
# (here: a scalar string spec) instead of on an actual tensor.
double_strings_from_inputspec = double.get_concrete_function(
    tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))

tf.Tensor(b'cc', shape=(), dtype=string)


In [16]:
# Print the concrete function obtained from `double` above.
print(double_strings)

ConcreteFunction double(a)
  Args:
    a: string Tensor, shape=()
  Returns:
    string Tensor, shape=()


In [17]:
# Inspect the concrete function's input signature and its outputs.
print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)

((TensorSpec(shape=(), dtype=tf.string, name='a'),), {})
Tensor("Identity:0", shape=(), dtype=string)


In [18]:
# `double_strings` was obtained for a string tensor, so calling it with an
# integer tensor is expected to raise.
with assert_raises(tf.errors.InvalidArgumentError):
    double_strings(tf.constant(1))


Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:


Traceback (most recent call last):
  File "C:\Users\happy\AppData\Local\Temp\ipykernel_15120\1159894030.py", line 9, in assert_raises
    yield
  File "C:\Users\happy\AppData\Local\Temp\ipykernel_15120\2585775733.py", line 2, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_141 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_141]


In [19]:
@tf.function
def pow(a, b):
    """Return `a ** b`.

    NOTE(review): this module-level name shadows Python's built-in `pow`
    for the rest of the notebook.
    """
    return a**b


# Fix `b` to the Python value 2; `a` is a float32 tensor whose shape is left
# unspecified (`TensorSpec(None, ...)`).
square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)

ConcreteFunction pow(a, b=2)
  Args:
    a: float32 Tensor, shape=<unknown>
  Returns:
    float32 Tensor, shape=<unknown>


In [20]:
# 10.0 ** 2 == 100
assert square(tf.constant(10.0)) == 100

# `square` was created with b=2, so passing b=3 is expected to raise.
with assert_raises(TypeError):
    square(tf.constant(10.0), b=3)


Caught expected exception 
  <class 'TypeError'>:


Traceback (most recent call last):
  File "d:\conda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py", line 1617, in _call_impl
    return self._call_with_flat_signature(args, kwargs,
  File "d:\conda3\envs\tf\lib\site-packages\tensorflow\python\eager\function.py", line 1662, in _call_with_flat_signature
    raise TypeError(f"{self._flat_signature_summary()} got unexpected "
TypeError: pow(a) got unexpected keyword arguments: b.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Users\happy\AppData\Local\Temp\ipykernel_15120\1159894030.py", line 9, in assert_raises
    yield
  File "C:\Users\happy\AppData\Local\Temp\ipykernel_15120\2991996595.py", line 4, in <module>
    square(tf.constant(10.0), b=3)
TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.


In [21]:
# Walk the nodes of the concrete function's graph and print, for each one,
# its list of inputs followed by its name.
graph = double_strings.graph
for node in graph.as_graph_def().node:
    print(f'{node.input} -> {node.name}')


[] -> a
['a', 'a'] -> add
['add'] -> Identity


In [22]:
# A simple loop
# NOTE: this rebinds the name `f` that an earlier cell also defined.


@tf.function
def f(x):
    # Keep replacing `x` with `tf.tanh(x)`, printing it each round, while the
    # sum of its elements is greater than 1.
    while tf.reduce_sum(x) > 1:
        tf.print(x)
        x = tf.tanh(x)
    return x


# NOTE(review): no seed is set in this cell, so the printed values are
# expected to differ from run to run.
f(tf.random.uniform([5]))

[0.281349301 0.302678704 0.061591506 0.982579 0.350352168]
[0.274153411 0.293762058 0.0615137368 0.754180193 0.336687833]
[0.26748535 0.285593718 0.061436262 0.637636244 0.32451725]
[0.261283368 0.278074205 0.0613590814 0.563287914 0.313585699]
[0.255495489 0.27112174 0.0612821914 0.510412872 0.303695619]
[0.250077486 0.264668286 0.0612055846 0.470266789 0.294690937]
[0.244991496 0.25865671 0.0611292645 0.438414872 0.286446601]
[0.240204811 0.253038675 0.0610532276 0.412329614 0.278860956]
[0.235689193 0.247772902 0.0609774776 0.390448898 0.271850526]
[0.231419861 0.242823988 0.0609020069 0.371747166 0.265345901]
[0.227375224 0.23816134 0.0608268119 0.355518967 0.259288877]
[0.223536208 0.233758301 0.0607519 0.34126091 0.253630251]
[0.219885901 0.229591593 0.0606772639 0.328602582 0.248328075]
[0.216409296 0.225640744 0.0606029034 0.317264557 0.243346363]
[0.213093013 0.221887738 0.0605288148 0.307031393 0.238654]
[0.2099251 0.21831654 0.0604549944 0.297734022 0.234223977]
[0.206894785

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.20399237, 0.21166414, 0.06030817, 0.2814329 , 0.22605935],
      dtype=float32)>

In [23]:
# Show the code `tf.autograph.to_code` produces for the Python function
# behind `f`.
print(tf.autograph.to_code(f.python_function))

def tf__f(x):
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
        do_return = False
        retval_ = ag__.UndefinedReturnValue()

        def get_state():
            return (x,)

        def set_state(vars_):
            nonlocal x
            (x,) = vars_

        def loop_body():
            nonlocal x
            ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope)
            x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope)

        def loop_test():
            return ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = ag__.ld(x)
        except:
            do_return = False
            raise
        return fscope.ret(retval_, do_return)



In [24]:
@tf.function
def fizzbuzz(n):
    """For each `i` in `tf.range(1, n + 1)`, `tf.print` fizzbuzz/fizz/buzz or `i`.

    Multiples of 15 give 'fizzbuzz', other multiples of 3 give 'fizz', other
    multiples of 5 give 'buzz', and everything else gives the number itself.
    """
    # Every branch pairs a Python `print` ("Tracing ...") with a `tf.print`,
    # so the two kinds of message can be compared in the output.
    for i in tf.range(1, n + 1):
        print('Tracing for loop')
        if i % 15 == 0:
            print('Tracing fizzbuzz branch')
            tf.print('fizzbuzz')
        elif i % 3 == 0:
            print('Tracing fizz branch')
            tf.print('fizz')
        elif i % 5 == 0:
            print('Tracing buzz branch')
            tf.print('buzz')
        else:
            print('Tracing default branch')
            tf.print(i)


# Two calls with scalar tensors of different values.
fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))

Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz


In [25]:
def measure_graph_size(f, *args):
    """Print how many nodes the graph of `f`'s concrete function for `args` has.

    Args:
        f: Object exposing `get_concrete_function(*args)` and `__name__`.
        *args: Arguments forwarded to `get_concrete_function`; they are also
            rendered with `str` in the printed message.
    """
    concrete = f.get_concrete_function(*args)
    node_count = len(concrete.graph.as_graph_def().node)
    rendered_args = ', '.join(str(arg) for arg in args)
    print(f"{f.__name__}({rendered_args}) contains {node_count} nodes in its graph")


@tf.function
def train(dataset):
    """Return the running sum of `tf.abs(y - x)` over each `(x, y)` in `dataset`."""
    # NOTE: this rebinds the name `train` that an earlier cell also defined.
    loss = tf.constant(0)
    for x, y in dataset:
        loss += tf.abs(y - x)  # Some dummy computation.
    return loss


# The same (1, 1) pair repeated 3 and 10 times.
small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
# Measure with the Python lists passed in directly.
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

# Measure again with the same data wrapped by `tf.data.Dataset.from_generator`.
measure_graph_size(
    train,
    tf.data.Dataset.from_generator(lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(
    train,
    tf.data.Dataset.from_generator(lambda: big_data, (tf.int32, tf.int32)))


train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph
train(<FlatMapDataset element_spec=(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None), TensorSpec(shape=<unknown>, dtype=tf.int32, name=None))>) contains 6 nodes in its graph
