In [3]:
import inspect
import time
import numpy as np
import tensorflow as tf
from pprint import pprint
print(tf.__version__)
tf.random.set_seed(42)
np.random.seed(42)
true_weights = tf.constant(list(range(5)), dtype=tf.float32)[:, tf.newaxis]
x = tf.constant(tf.random.uniform((32, 5)), dtype=tf.float32)
y = tf.constant(x @ true_weights, dtype=tf.float32)

2.2.0


In [4]:
def f(a, b, power=2, d=3):
    return tf.pow(a, power) + d * b
converted_f = tf.autograph.to_graph(f)
print(inspect.getsource(converted_f))

        def tf__f(a, b, power=None, d=None):
            do_return = False
            retval_ = ag__.UndefinedReturnValue()
            with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
                try:
                    do_return = True
                    retval_ = fscope.mark_return_value((ag__.converted_call(tf.pow, (a, power), None, fscope) + (d * b)))
                except:
                    do_return = False
                    raise
            (do_return,)
            return ag__.retval(retval_)



In [6]:
def cube(x):
    o = x
    for _ in range(2):
        o *= x
    return o

converted_cube = tf.autograph.to_graph(cube)
print(inspect.getsource(converted_cube))

        def tf__cube(x):
            do_return = False
            retval_ = ag__.UndefinedReturnValue()
            with ag__.FunctionScope('cube', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:
                o = x

                def get_state():
                    return (o,)

                def set_state(loop_vars):
                    nonlocal o
                    (o,) = loop_vars

                def loop_body(itr):
                    nonlocal o
                    _ = itr
                    o *= x
                ag__.for_stmt(ag__.converted_call(range, (2,), None, fscope), None, loop_body, get_state, set_state, ('o',), {})
                try:
                    do_return = True
                    retval_ = fscope.mark_return_value(o)
                except:
                    do_return = False
                    raise
            (do_return,)
            return ag__.retval(ret

In [8]:
def g(x):
    if tf.reduce_any(x < 0):
        return tf.square(x)
    return x
converted_g = tf.autograph.to_graph(g)
print(inspect.getsource(converted_g))

        def tf__g(x):
            do_return = False
            retval_ = ag__.UndefinedReturnValue()
            with ag__.FunctionScope('g', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:

                def get_state():
                    return ()

                def set_state(loop_vars):
                    pass

                def if_true():
                    try:
                        do_return = True
                        retval_ = fscope.mark_return_value(ag__.converted_call(tf.square, (x,), None, fscope))
                    except:
                        do_return = False
                        raise
                    return (do_return, retval_)

                def if_false():
                    try:
                        do_return = True
                        retval_ = fscope.mark_return_value(x)
                    except:
                        do_return = False


In [11]:
tf_func_f = tf.function(autograph=False)(f)
tf_func_g = tf.function(autograph=False)(converted_g)
tf_func_g2 = tf.function(autograph=True)(g)
print(tf_func_f.python_function is f)
print(tf_func_g.python_function is converted_g)
print(tf_func_g2.python_function is g)

True
True
True


In [15]:
a = tf.function(autograph=False)(tf.autograph.to_graph(g))
# is roughly equivlent to
b = tf.function(autograph=True)(g)
print(a)
print(b)

<tensorflow.python.eager.def_function.Function object at 0x0000022D26E1CB88>
<tensorflow.python.eager.def_function.Function object at 0x0000022D26E1CEC8>


In [16]:
concrete_g = tf_func_g.get_concrete_function(x=tf.TensorSpec(shape=[3], dtype=tf.float32))
print(concrete_g)

<tensorflow.python.eager.function.ConcreteFunction object at 0x0000022D01F7EE48>


In [17]:
pprint(concrete_g(tf.constant([-1, 1, -2], dtype=tf.float32)))
pprint(tf_func_g(tf.constant([-1, 1, -2], dtype=tf.float32)))

<tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 1., 4.], dtype=float32)>
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 1., 4.], dtype=float32)>




In [19]:
concrete_f = tf_func_f.get_concrete_function(a=tf.TensorSpec(shape=[1], dtype=tf.float32), b=tf.TensorSpec(shape=[1], dtype=tf.float32))
print(concrete_f)
pprint(concrete_f(tf.constant(1.), tf.constant(2.)))
pprint(tf_func_f(1., 2.))
pprint(tf_func_f(a=tf.constant(1., dtype=tf.float32), b=2, power=2.))
pprint(tf_func_f(a=tf.constant(1., dtype=tf.float32), b=2., d=3))
pprint(tf_func_f(a=tf.constant(1., dtype=tf.float32), b=2., d=3., power=3.))
print(tf_func_f._get_tracing_count())

<tensorflow.python.eager.function.ConcreteFunction object at 0x0000022D17606B88>
<tf.Tensor: shape=(), dtype=float32, numpy=7.0>
<tf.Tensor: shape=(), dtype=float32, numpy=7.0>
<tf.Tensor: shape=(), dtype=float32, numpy=7.0>
<tf.Tensor: shape=(), dtype=float32, numpy=7.0>
<tf.Tensor: shape=(), dtype=float32, numpy=7.0>
4


In [20]:
for i, f in enumerate(tf_func_f._list_all_concrete_functions_for_serialization()):
    print(i, f.structured_input_signature)

0 ((TensorSpec(shape=(1,), dtype=tf.float32, name='a'), TensorSpec(shape=(1,), dtype=tf.float32, name='b'), 2, 3), {})
1 ((TensorSpec(shape=(), dtype=tf.float32, name='a'), 2, 2.0, 3), {})
2 ((1.0, 2.0, 2, 3), {})
3 ((TensorSpec(shape=(), dtype=tf.float32, name='a'), 2.0, 3.0, 3.0), {})


In [22]:
# The following two are equivalent.
@tf.function(autograph=False)
def square(x):
    return x * x

def square(x):
    return x * x
square = tf.function(autograph=False)(square)