In TensorFlow 2, eager execution is turned on by default.

You can use `tf.function` to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of Python code. This helps you create performant and portable models, and it is required to use `SavedModel`.

The main takeaways and recommendations are:

1. Debug in eager mode, then decorate with `@tf.function`.
2. Don't rely on Python side effects like object mutation or list appends.
3. `tf.function` works best with TensorFlow ops; NumPy and Python calls are converted to constants.
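The second and third recommendations can be illustrated with a minimal sketch. The names `trace_log`, `add_one`, and `add_noise` are hypothetical and chosen only for this illustration; the behavior shown assumes default `tf.function` settings.

```python
import random

import tensorflow as tf

trace_log = []  # hypothetical Python list, used only to observe side effects


@tf.function
def add_one(x):
    # Python side effect: executes while the function is traced,
    # not every time the resulting graph runs.
    trace_log.append("traced")
    return x + 1


@tf.function
def add_noise():
    # The Python call is evaluated once during tracing and its result
    # is baked into the graph as a constant.
    return tf.constant(0.0) + random.random()


add_one(tf.constant(1))
add_one(tf.constant(2))  # same dtype and shape, so no new trace
print(len(trace_log))    # the append ran once, although add_one ran twice

a = add_noise()
b = add_noise()
print(float(a) == float(b))  # both calls return the same "random" value
```

If the value has to change on every call, a TensorFlow op such as `tf.random.uniform` is the usual replacement for the Python call.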

### Setup

In [1]:
import contextlib
import traceback

import tensorflow as tf

Define a helper function to demonstrate the kinds of errors you might encounter:

In [2]:
# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raise(error_class):
    try:
        yield
    except error_class as e:
        print("Caught expected exception \n {}:".format(error_class))
        traceback.print_exc(limit=2)
    except Exception as e:
        raise e
    else:
        raise Exception("Expected {} to be raised but no error was raised!".format(error_class))

### Basics

#### Usage

A `tf.function` that you define (for example by applying the `@tf.function` decorator) is just like a core TensorFlow operation: you can execute it eagerly, you can compute gradients, and so on.

In [3]:
# The decorator converts `add` into a `PolymorphicFunction`
@tf.function
def add(a, b):
    return a + b

# [[2., 2.], [2., 2.]]
add(tf.ones([2, 2]), tf.ones([2, 2]))


<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>

In [4]:
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
    result = add(v, 1.0)
tape.gradient(result, v)

<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

You can use `tf.function`s inside other `tf.function`s.

In [5]:
@tf.function
def dense_layer(x, w, b):
    return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))

<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

`tf.function`s can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.

In [6]:
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
    return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image)
conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

Eager conv: 0.6259469399997215
Function conv: 0.5865928770003848
Note how there's not much difference in performance for convolutions
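For contrast, the following sketch times a function made of many small ops, which is the case where a graph is expected to help most. The function `many_small_ops` and the chosen sizes are assumptions made for this illustration; actual timings depend on hardware and TensorFlow version, so no numbers are shown.

```python
import timeit

import tensorflow as tf


def many_small_ops(x):
    # 100 cheap element-wise additions. Under tf.function the Python loop
    # is unrolled during tracing into 100 add ops in a single graph.
    for _ in range(100):
        x = x + 1.0
    return x


graph_fn = tf.function(many_small_ops)

x = tf.zeros([10])
# Warm up so that tracing time is not included in the measurement.
many_small_ops(x)
graph_fn(x)

print("Eager small ops:", timeit.timeit(lambda: many_small_ops(x), number=100))
print("Function small ops:", timeit.timeit(lambda: graph_fn(x), number=100))
```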


#### Tracing

##### What is "tracing"?

A `tf.function` runs your program in a TensorFlow Graph. However, a `tf.Graph` cannot represent all the things that you'd write in an eager TensorFlow program. For instance, Python supports polymorphism, but `tf.Graph` requires its inputs to have a specified data type and dimension. Or you may perform side tasks like reading command-line arguments, raising an error, or working with a more complex Python object; none of these things can run in a `tf.Graph`.

`tf.function` bridges this gap by separating your code into two stages:

1. In the first stage, referred to as **"tracing"**, `tf.function` creates a new `tf.Graph`. Python code runs normally, but all TensorFlow operations (like adding two Tensors) are *deferred*: they are captured by the `tf.Graph` and not run.
2. In the second stage, a `tf.Graph` which contains everything that was deferred in the first stage is run. This stage is much faster than the tracing stage.

Depending on its inputs, `tf.function` will not always run the first stage when it is called (see "Rules of tracing" below). Skipping the first stage and only executing the second stage is what gives you TensorFlow's high performance.

When `tf.function` does decide to trace, the tracing stage is immediately followed by the second stage, so calling the `tf.function` both creates and runs the `tf.Graph`.
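The split between the two stages can be observed from inside the function body. The sketch below is only an illustration: `eager_flags` and `square` are hypothetical names, and it assumes functions are not being forced to run eagerly.

```python
import tensorflow as tf

eager_flags = []  # hypothetical list recording what the Python body observes


@tf.function
def square(x):
    # Plain Python: runs only in the tracing stage, where eager execution
    # is off and `x` is a symbolic tensor rather than a concrete value.
    eager_flags.append(tf.executing_eagerly())
    # TensorFlow op: recorded in the graph while tracing, executed in stage two.
    return x * x


first = square(tf.constant(3.0))   # traces, then runs the new graph
second = square(tf.constant(4.0))  # same input type: runs the cached graph only

print(eager_flags)             # one entry, recorded during the single trace
print(tf.executing_eagerly())  # outside the function, eager execution is on
```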

When passing arguments of different types into a `tf.function`, both stages are run:

In [7]:
@tf.function
def double(a):
    print("Tracing with", a)
    return a + a

In [8]:
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant('a')))
print()

Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)



Note that if you repeatedly call a `tf.function` with the same argument type, TensorFlow will skip the tracing stage and reuse a previously traced graph, as the generated graph would be identical.

In [9]:
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))

tf.Tensor(b'bb', shape=(), dtype=string)
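Another way to confirm that no retracing happened is to compare the trace count before and after a repeated call. This sketch redefines `double` without the `print` so that it is self-contained, and it relies on the `experimental_get_tracing_count()` method; as the name suggests, that method is experimental and may change between releases.

```python
import tensorflow as tf


@tf.function
def double(a):
    return a + a


# Three different input types, so three traces are expected.
double(tf.constant(1))
double(tf.constant(1.1))
double(tf.constant("a"))
count_before = double.experimental_get_tracing_count()

# A string scalar has already been traced, so the cached graph is reused.
double(tf.constant("b"))
count_after = double.experimental_get_tracing_count()

print(count_before, count_after)  # the two counts are equal
```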


You can use `pretty_printed_concrete_signatures()` to see all of the available traces:

In [11]:
print(double.pretty_printed_concrete_signatures())

Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.int32, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.int32, name=None)
Captures:
  None

Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.float32, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.float32, name=None)
Captures:
  None

Input Parameters:
  a (POSITIONAL_OR_KEYWORD): TensorSpec(shape=(), dtype=tf.string, name=None)
Output Type:
  TensorSpec(shape=(), dtype=tf.string, name=None)
Captures:
  None


`tf.function` creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic.

- A `tf.Graph` is the raw, language-agnostic, portable representation of a TensorFlow computation.
- Tracing is the process through which new `tf.Graph`s are generated from Python code.
- An instance of `tf.Graph` is specialized to the specific input types it was traced with. Differing types require retracing.
- Each traced `tf.Graph` has a corresponding `ConcreteFunction`.
- A `tf.function` manages a cache of `ConcreteFunction`s and picks the right one for your inputs.
- `tf.function` wraps the Python function that will be traced, returning a `tf.types.experimental.PolymorphicFunction` object.
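The relationship between these pieces can be inspected directly. The following is a minimal sketch rather than a complete tour of the API: it redefines `double` so that it runs on its own, and the variable name `float_fn` is chosen only for this illustration.

```python
import tensorflow as tf


@tf.function
def double(a):
    return a + a


# Ask the tf.function for the ConcreteFunction specialized to float32 scalars.
# This traces a graph for that signature if one is not cached yet.
float_fn = double.get_concrete_function(tf.TensorSpec(shape=(), dtype=tf.float32))

# Each ConcreteFunction wraps a single traced graph.
print(isinstance(float_fn.graph, tf.Graph))
print([op.type for op in float_fn.graph.get_operations()])

# A ConcreteFunction is callable, but only with inputs matching its signature.
print(float_fn(tf.constant(2.0)))
```

Calling `double` itself with a `float32` scalar dispatches to this same trace, while calling it with another dtype makes the `tf.function` trace a new `ConcreteFunction`.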

##### Rules of tracing

