Skip to content

Commit

Permalink
Create the failure test when tf.SavedModel miss the XLACallModule fun…
Browse files Browse the repository at this point in the history
…ction_list after loading.

PiperOrigin-RevId: 554726455
  • Loading branch information
maxwillzq authored and jax authors committed Aug 8, 2023
1 parent d17adde commit dec2366
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions jax/experimental/jax2tf/tests/call_tf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Tests for call_tf."""
from functools import partial
import os
from typing import Callable
import unittest

Expand Down Expand Up @@ -1597,6 +1598,50 @@ def tf_f_2(x):
x = np.arange(3, dtype=np.int32)
_ = tf.function(tf_f_2, autograph=False).get_concrete_function(x)

# TODO(b/293927250): call_tf_graph=True only accept concrete_function. The
# workaround here is to set `module.call=concrete_fn.`.
@unittest.skip(
"The root cause here is because the XLACallModule.function_list attribute"
" depends on JAX call_tf lowering. The 2nd time tf.SavedModel TF tracing"
" will not trigger call_tf tracing since it was already cached. The"
" solution is to create the `CallTFContext` to make TF tracing and JAX"
" tracing work together correctly."
)
def test_call_tf_graph_save_and_load(self):
def jax_func(x):
def func_tf(x):
return tf.math.sin(x)

return jnp.cos(
jax2tf.call_tf(func_tf, output_shape_dtype=x, call_tf_graph=True)(x)
)
data_inputs = (np.array([0.5, 0.7], dtype=np.float32),)

def tf_func(the_input):
res = jax2tf.convert(jax_func, native_serialization=True)(the_input)
return tf.identity(res, name="the_result")

jit_tf_func = tf.function(
tf_func,
autograph=False,
jit_compile=True,
)
# The next line is necessary to reproduce this issue. It trigger TF
# ConcreteFunction tracing. Otherwise, you will fail with another error
# `Found zero restored functions for caller function`.
_ = jit_tf_func.get_concrete_function(*data_inputs)
module = tf.Module()
module.call = jit_tf_func # Switching to concrete_function works.
root_dir = self.create_tempdir()
saved_model_dir = os.path.join(root_dir, "saved_model")
tf.saved_model.save(
module,
saved_model_dir,
options=tf.saved_model.SaveOptions(experimental_custom_gradients=False),
)
loaded_model = tf.saved_model.load(saved_model_dir)
res = loaded_model.call(*data_inputs)
self.assertAllClose(jax_func(*data_inputs), res)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit dec2366

Please sign in to comment.