[jax2tf] Add a backward compatibility test for tf.call_tf_function
`tf.call_tf_function` arises from `jax2tf.call_tf(tf_fun, call_tf_graph=True)`. However, a function that contains this custom call can be lowered and executed only with `jax2tf.convert`, and it ought to be serialized as a `tf.Graph`, because the serialization includes a `tf.function` as well.

To support this, we add code to back_compat_test.py to serialize the test function as a tf.Graph and to run the serialized graph.
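
As a minimal sketch of the pattern in question (the TF function, shapes, and names below are illustrative, not part of this commit):

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def tf_sin(x):
  return tf.math.sin(x)

def func(x):
  # call_tf_graph=True emits the `tf.call_tf_function` custom call instead of
  # staging the TF computation into the lowered module.
  return jnp.cos(jax2tf.call_tf(tf_sin, call_tf_graph=True,
                                output_shape_dtype=x)(x))

# Such a function must go through jax2tf.convert, and the result is serialized
# as a tf.Graph because it carries an embedded tf.function.
tf_fn = tf.function(jax2tf.convert(func, native_serialization=True),
                    autograph=False, jit_compile=True)
graph_def = tf_fn.get_concrete_function(
    tf.TensorSpec([4], tf.float32)).graph.as_graph_def()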

PiperOrigin-RevId: 537062963
gnecula authored and jax authors committed Jun 1, 2023
1 parent ae9d149 commit 37e254e
Showing 3 changed files with 487 additions and 37 deletions.
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax_export.py
@@ -581,6 +581,7 @@ def _check_lowering(lowering) -> None:
"LuDecomposition",
# ApproxTopK on TPU
"ApproxTopK",
"tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True)
]

def _check_module(mod: ir.Module, *,
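For context, the list above is an allow-list of custom call targets that jax_export considers stable. A simplified sketch of such a check (hypothetical names; the real _check_module inspects the MLIR module rather than regex-matching its text, and the regex below is the one back_compat_test.py uses to collect targets):

import re

_STABLE_CUSTOM_CALL_TARGETS = {
    "Sharding", "LuDecomposition", "ApproxTopK",
    "tf.call_tf_function",  # From jax2tf.call_tf(func, call_tf_graph=True)
}

def check_custom_calls(module_text: str) -> None:
  # Collect `stablehlo.custom_call @target(...)` occurrences from the text.
  targets = set(re.findall(r"stablehlo.custom_call\s*@([^\(]+)\(", module_text))
  disallowed = targets - _STABLE_CUSTOM_CALL_TARGETS
  if disallowed:
    raise ValueError(
        f"Custom call targets not guaranteed stable: {sorted(disallowed)}")
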
163 changes: 126 additions & 37 deletions jax/experimental/jax2tf/tests/back_compat_test.py
@@ -43,18 +43,21 @@ def func(...): ...
The test will fail, but will save to a file the test data you will need. The
file name will be printed in the logs. Create a new
-file ./back_compat_testdata/cuda_foo_call.py and paste the test data that
+file ./back_compat_testdata/foo_call.py and paste the test data that
you will see printed in the logs. You may want to
edit the serialization string to remove any pathnames that may be included at
the end, or gxxxxx3 at the beginning.
Name the literal `data_YYYY_MM_DD` to include the date of serialization
-(for readability only). Then add here:
+(for readability only). Then add to this file:
from jax.experimental.jax2tf.tests.back_compat_testdata import foo_call
then update `test_custom_call_coverage`, and then update your `test_foo_call`:
def test_foo_call(self):
def func(...): ...
-data = load_testdata(foo_call.data_YYYY_MM_DD)
+data = load_testdata(foo_call.data_YYYY_MM_DD)  # <-- this is new
self.run_one_test(func, data)
"""
@@ -81,13 +84,15 @@ def func(...): ...
from jax import core
from jax import lax
from jax import tree_util
from jax.experimental import jax2tf
from jax.experimental.jax2tf import jax_export
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_lapack_geqrf
from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_lapack_syev
from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_cusolver_geqrf
from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_cusolver_syev
from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_threefry2x32
from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function
from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Eigh
from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Lu
from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_ApproxTopK
@@ -101,10 +106,14 @@ def func(...): ...
from jax.sharding import Mesh
from jax.sharding import PartitionSpec as P

from jax._src.lib import xla_extension
from jax._src import test_util as jtu
from jax._src.interpreters import pxla
from jax._src import xla_bridge as xb

import tensorflow as tf
from tensorflow.core.framework import graph_pb2 # type: ignore[import]

config.parse_flags_with_absl()


@@ -180,7 +189,8 @@ class CompatTest(jtu.JaxTestCase):
def run_one_test(self, func: Callable[..., jax.Array],
data: CompatTestData,
rtol = None,
-check_results: Optional[Callable[..., None]] = None):
+check_results: Optional[Callable[..., None]] = None,
+use_tf_graph = False):
"""Run one compatibility test.
Args:
@@ -189,29 +199,77 @@ def run_one_test(self, func: Callable[..., jax.Array],
rtol: relative tolerance for numerical comparisons
check_results: invoked with the results obtained from running the
serialized code, and those stored in the test data, and the kwarg rtol.
use_tf_graph: if False (default), uses jax_export to serialize JAX
functions and to invoke them. If True, uses tf.Graph to serialize
and run the functions; expects that `func` contains a `jax2tf.call_tf`
and uses `jax2tf.convert` to generate a tf.Graph containing an
XlaCallModule with the actual MLIR module.
"""
if not isinstance(data, CompatTestData):
raise ValueError(f"Expecting data: CompatTestData but got {data}. "
"Did you forget to `load_testdata`?")

if default_jax_backend() != data.platform:
self.skipTest(f"Test enabled only for {data.platform}")

logging.info("Running the function at the current version")
-# Check that it runs in JAX native
-res_from_jax_run_now = jax.jit(func)(*data.inputs)
if not use_tf_graph:
tf_func = None
jax_func_to_export = func
res_from_jax_run_now = jax.jit(func)(*data.inputs)
else:
# Is there a better way to serialize/deserialize TF functions? I thought
# about using tf.saved_model, but then we have to zip/unzip a whole
# directory.
@tf.function(autograph=False, jit_compile=True)
def tf_func(the_input): # Use recognizable names for input and result
res = jax2tf.convert(func, native_serialization=True)(the_input)
return tf.identity(res, name="the_result")
jax_func_to_export = jax2tf.call_tf(tf_func) # type: ignore
res_from_jax_run_now = tf_func(*data.inputs) # type: ignore

if not isinstance(res_from_jax_run_now, (list, tuple)):
res_from_jax_run_now = (res_from_jax_run_now,)
res_from_jax_run_now = tuple(np.array(a) for a in res_from_jax_run_now)
logging.info("Result of current version run is %s", res_from_jax_run_now)

if not use_tf_graph:
# Use the native exporter, to make sure we get the proper serialization.
exported = jax_export.export(
jax.jit(jax_func_to_export),
lowering_platform=default_jax_backend(),
# Must turn off strict checks to allow custom calls.
strict_checks=False
)(*(jax.ShapeDtypeStruct(a.shape, a.dtype) for a in data.inputs))

module_str = str(exported.mlir_module)
serialized = exported.mlir_module_serialized
module_version = exported.xla_call_module_version
else:
# We serialize as a tf.Graph
assert len(data.inputs) == 1 # We only support a single input now
tf_graph = tf_func.get_concrete_function(*data.inputs).graph
for op in tf_graph.get_operations():
if op.type == "XlaCallModule":
serialized_module = op.get_attr("module")
module_str = xla_extension.mlir.deserialize_portable_artifact(
serialized_module)
module_version = op.get_attr("version")
break
else:
raise ValueError("Cannot find an XlaCallModule")
tf_graph_def = tf_graph.as_graph_def()
# module_str is just for human readability; include both the MLIR module
# and the tf.Graph
module_str = ("# First the MLIR module:\n" + module_str +
"\n# Then the tf.Graph:\n" + str(tf_graph_def))
serialized = tf_graph_def.SerializeToString()

-# Use the native exporter, to make sure we get the proper serialized module.
-exported = jax_export.export(
-jax.jit(func),
-lowering_platform=default_jax_backend(),
-# Must turn off strict checks because the custom calls may be unallowed.
-strict_checks=False
-)(*(jax.ShapeDtypeStruct(a.shape, a.dtype) for a in data.inputs))
-
-module_str = str(exported.mlir_module)
custom_call_re = r"stablehlo.custom_call\s*@([^\(]+)\("
custom_call_targets = sorted(
-list(set(re.findall(custom_call_re, module_str)))
-)
+list(set(re.findall(custom_call_re, module_str))))

np.set_printoptions(threshold=sys.maxsize, floatmode="unique")
# Print the test data to simplify updating the test
updated_testdata = f"""
@@ -224,10 +282,13 @@ def run_one_test(self, func: Callable[..., jax.Array],
inputs={repr(data.inputs)},
expected_outputs={repr(res_from_jax_run_now)},
mlir_module_text=r\"\"\"\n{module_str}\"\"\",
-mlir_module_serialized={repr(exported.mlir_module_serialized)},
-xla_call_module_version={exported.xla_call_module_version},
+mlir_module_serialized={repr(serialized)},
+xla_call_module_version={module_version},
) # End paste
"""
# Replace the word that should not appear.
updated_testdata = re.sub(r"google.", "googlex", updated_testdata)
output_dir = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR",
"/tmp/back_compat_testdata")
if not os.path.exists(output_dir):
@@ -244,15 +305,18 @@ def run_one_test(self, func: Callable[..., jax.Array],
else:
self.assertAllClose(res_from_jax_run_now, data.expected_outputs, rtol=rtol)

-res_from_serialized_run_now = self.run_serialized(data)
+logging.info("Running the serialized module")
+res_from_serialized_run_now = self.run_serialized(data,
+use_tf_graph=use_tf_graph)
logging.info("Result of serialized run is %s", res_from_serialized_run_now)
if check_results is not None:
check_results(res_from_serialized_run_now, data.expected_outputs, rtol=rtol)
else:
self.assertAllClose(res_from_serialized_run_now, data.expected_outputs, rtol=rtol)
self.assertListEqual(custom_call_targets, data.custom_call_targets)

-def run_serialized(self, data: CompatTestData):
+def run_serialized(self, data: CompatTestData,
+use_tf_graph=False):
def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray:
return core.ShapedArray(a.shape, a.dtype)
in_avals_tree = tree_util.tree_map(ndarray_to_aval, data.inputs)
@@ -262,25 +326,37 @@ def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray:
out_avals, out_tree = tree_util.tree_flatten(out_avals_tree)
def _get_vjp(_):
assert False # We do not have and do not need VJP
-exported = jax_export.Exported(
-fun_name="run_serialized",
-in_tree=in_tree,
-in_avals=tuple(in_avals),
-out_tree=out_tree,
-out_avals=tuple(out_avals),
-in_shardings=(pxla.UNSPECIFIED,) * len(in_avals),
-out_shardings=(pxla.UNSPECIFIED,) * len(out_avals),
-lowering_platform=data.platform,
-strict_checks=True,
-mlir_module_serialized=data.mlir_module_serialized,
-xla_call_module_version=data.xla_call_module_version,
-module_kept_var_idx=tuple(range(len(in_avals))),
-module_uses_dim_vars=any(not core.is_constant_shape(a.shape)
-for a in in_avals),
if not use_tf_graph:
exported = jax_export.Exported(
fun_name="run_serialized",
in_tree=in_tree,
in_avals=tuple(in_avals),
out_tree=out_tree,
out_avals=tuple(out_avals),
in_shardings=(pxla.UNSPECIFIED,) * len(in_avals),
out_shardings=(pxla.UNSPECIFIED,) * len(out_avals),
lowering_platform=data.platform,
strict_checks=True,
mlir_module_serialized=data.mlir_module_serialized,
xla_call_module_version=data.xla_call_module_version,
module_kept_var_idx=tuple(range(len(in_avals))),
module_uses_dim_vars=any(not core.is_constant_shape(a.shape)
for a in in_avals),
_get_vjp=_get_vjp)

-# We use pjit in case there are shardings in the exported module.
-return pjit.pjit(jax_export.call_exported(exported))(*data.inputs)
+# We use pjit in case there are shardings in the exported module.
+return pjit.pjit(jax_export.call_exported(exported))(*data.inputs)
else:
loaded_f_tf_graph = graph_pb2.GraphDef()
loaded_f_tf_graph.ParseFromString(data.mlir_module_serialized)

@tf.function(autograph=False)
def loaded_fun(x):
result = tf.import_graph_def(loaded_f_tf_graph,
input_map={"the_input":x},
return_elements=["the_result:0"])
return result[0]
return (loaded_fun(*data.inputs).numpy(),)

def test_dummy(self):
# Tests the test mechanism. Let this test run on all platforms
@@ -314,6 +390,7 @@ def test_custom_call_coverage(self):
cpu_ducc_fft.data_2023_03_17, cpu_lapack_syev.data_2023_03_17,
cpu_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15,
cuda_cusolver_geqrf.data_2023_03_18, cuda_cusolver_syev.data_2023_03_17,
tf_call_tf_function.data_2023_05_31,
tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17,
tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17,
tpu_ApproxTopK.data_2023_05_16]
@@ -504,6 +581,18 @@ def func(x): # b: f32[2, 4]
with mesh:
self.run_one_test(func, data)

def test_tf_call_tf_function(self):
# A custom call tf.call_tf_function is generated when we lower call_tf
# with the call_tf_graph=True option.
def func(x):
def func_tf(x):
return tf.math.sin(x)
return jnp.cos(jax2tf.call_tf(func_tf, call_tf_graph=True,
output_shape_dtype=x)(x))

data = load_testdata(tf_call_tf_function.data_2023_05_31)
self.run_one_test(func, data, use_tf_graph=True)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
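
As a standalone sketch of the tf.Graph round-trip that run_one_test and run_serialized perform above (the names "the_input" and "the_result" match the diff; the sample function and data are illustrative):

import numpy as np
import tensorflow as tf
from tensorflow.core.framework import graph_pb2  # type: ignore[import]

@tf.function(autograph=False)
def tf_func(the_input):  # the argument name becomes the placeholder op name
  return tf.identity(tf.math.sin(the_input), name="the_result")

x = np.arange(4.0, dtype=np.float32)
# Serialize: trace the function and flatten its graph to bytes.
serialized = tf_func.get_concrete_function(
    x).graph.as_graph_def().SerializeToString()

# Deserialize: rebuild the GraphDef and splice it into a fresh function,
# feeding the input and fetching the result by their well-known names.
loaded_graph = graph_pb2.GraphDef()
loaded_graph.ParseFromString(serialized)

@tf.function(autograph=False)
def loaded_fun(x):
  result = tf.import_graph_def(loaded_graph,
                               input_map={"the_input": x},
                               return_elements=["the_result:0"])
  return result[0]

np.testing.assert_allclose(loaded_fun(x).numpy(), np.sin(x), rtol=1e-6)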
