From 37e254e982396c0ce786edadb55e366ef75d3cb6 Mon Sep 17 00:00:00 2001
From: George Necula
Date: Thu, 1 Jun 2023 10:29:12 -0700
Subject: [PATCH] [jax2tf] Add a backward compatibility test for
 tf.call_tf_function

The `tf.call_tf_function` custom call arises from
`jax2tf.call_tf(tf_fun, call_tf_graph=True)`. A function that contains this
custom call can be lowered and executed only with `jax2tf.convert`, and it has
to be serialized as a `tf.Graph`, because the serialization must also include
a `tf.function`. To support this, we add code to back_compat_test.py to
serialize the test function as a tf.Graph and to run the serialized graph.

PiperOrigin-RevId: 537062963
---
 jax/experimental/jax2tf/jax_export.py         |   1 +
 .../jax2tf/tests/back_compat_test.py          | 163 ++++++--
 .../tf_call_tf_function.py                    | 360 ++++++++++++++++++
 3 files changed, 487 insertions(+), 37 deletions(-)
 create mode 100644 jax/experimental/jax2tf/tests/back_compat_testdata/tf_call_tf_function.py

diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py
index 8cf5fe2b6d9e..61ee392dbe73 100644
--- a/jax/experimental/jax2tf/jax_export.py
+++ b/jax/experimental/jax2tf/jax_export.py
@@ -581,6 +581,7 @@ def _check_lowering(lowering) -> None:
     "LuDecomposition",
     # ApproxTopK on TPU
     "ApproxTopK",
+    "tf.call_tf_function",  # From jax2tf.call_tf(func, call_tf_graph=True)
 ]

 def _check_module(mod: ir.Module, *,
diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py
index d24a0960be57..d8f831a3999d 100644
--- a/jax/experimental/jax2tf/tests/back_compat_test.py
+++ b/jax/experimental/jax2tf/tests/back_compat_test.py
@@ -43,18 +43,21 @@ def func(...): ...

 The test will fail, but will save to a file the test data you will
 need. The file name will be printed in the logs. Create a new
-file ./back_compat_testdata/cuda_foo_call.py and paste the test data that
+file ./back_compat_testdata/foo_call.py and paste the test data that
 you will see printed in the logs. You may want to edit the serialization
 string to remove any pathnames that may be included at the end, or gxxxxx3
 at the beginning.

 Name the literal `data_YYYYY_MM_DD` to include the date of serializaton
-(for readability only). Then add here:
+(for readability only). Then add to this file:

   from jax.experimental.jax2tf.tests.back_compat_testdata import foo_call
+
+then update `test_custom_call_coverage`, and then update your `test_foo_call`:
+
 def test_foo_call(self):
   def func(...): ...
-  data = load_testdata(foo_call.data_YYYY_MM_DD)
+  data = load_testdata(foo_call.data_YYYY_MM_DD)  # <-- this is new
   self.run_one_test(func, data)
 """

@@ -81,6 +84,7 @@ def func(...): ...
 from jax import core
 from jax import lax
 from jax import tree_util
+from jax.experimental import jax2tf
 from jax.experimental.jax2tf import jax_export
 from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_ducc_fft
 from jax.experimental.jax2tf.tests.back_compat_testdata import cpu_lapack_geqrf
@@ -88,6 +92,7 @@ def func(...): ...
from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_cusolver_geqrf from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_cusolver_syev from jax.experimental.jax2tf.tests.back_compat_testdata import cuda_threefry2x32 +from jax.experimental.jax2tf.tests.back_compat_testdata import tf_call_tf_function from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Eigh from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Lu from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_ApproxTopK @@ -101,10 +106,14 @@ def func(...): ... from jax.sharding import Mesh from jax.sharding import PartitionSpec as P +from jax._src.lib import xla_extension from jax._src import test_util as jtu from jax._src.interpreters import pxla from jax._src import xla_bridge as xb +import tensorflow as tf +from tensorflow.core.framework import graph_pb2 # type: ignore[import] + config.parse_flags_with_absl() @@ -180,7 +189,8 @@ class CompatTest(jtu.JaxTestCase): def run_one_test(self, func: Callable[..., jax.Array], data: CompatTestData, rtol = None, - check_results: Optional[Callable[..., None]] = None): + check_results: Optional[Callable[..., None]] = None, + use_tf_graph = False): """Run one compatibility test. Args: @@ -189,29 +199,77 @@ def run_one_test(self, func: Callable[..., jax.Array], rtol: relative tolerance for numerical comparisons check_results: invoked with the results obtained from running the serialized code, and those stored in the test data, and the kwarg rtol. + use_tf_graph: if False (default), uses jax_export to serialize JAX + functions and to invoke them. If True then uses tf.Graph to serialize + and run the functions; expects that `func` contains a `jax2tf.call_tf` + and uses `jax2tf.convert` to generate a tf.Graph containing a + XlaCallModule with the actual MLIR module. """ + if not isinstance(data, CompatTestData): + raise ValueError(f"Expecting data: CompatTestData but got {data}. " + "Did you forget to `load_testdata`?") + if default_jax_backend() != data.platform: self.skipTest(f"Test enabled only for {data.platform}") + logging.info("Running the function at the current version") # Check that it runs in JAX native - res_from_jax_run_now = jax.jit(func)(*data.inputs) + if not use_tf_graph: + tf_func = None + jax_func_to_export = func + res_from_jax_run_now = jax.jit(func)(*data.inputs) + else: + # Is there a better way to serialize/deserialize TF functions? I thought + # about using tf.saved_model, but then we have to zip/unzip a whole + # directory. + @tf.function(autograph=False, jit_compile=True) + def tf_func(the_input): # Use recognizeable names for input and result + res = jax2tf.convert(func, native_serialization=True)(the_input) + return tf.identity(res, name="the_result") + jax_func_to_export = jax2tf.call_tf(tf_func) # type: ignore + res_from_jax_run_now = tf_func(*data.inputs) # type: ignore + if not isinstance(res_from_jax_run_now, (list, tuple)): res_from_jax_run_now = (res_from_jax_run_now,) res_from_jax_run_now = tuple(np.array(a) for a in res_from_jax_run_now) + logging.info("Result of current version run is %s", res_from_jax_run_now) + + if not use_tf_graph: + # Use the native exporter, to make sure we get the proper serialization. + exported = jax_export.export( + jax.jit(jax_func_to_export), + lowering_platform=default_jax_backend(), + # Must turn off strict checks to allow custom calls. 
+ strict_checks=False + )(*(jax.ShapeDtypeStruct(a.shape, a.dtype) for a in data.inputs)) + + module_str = str(exported.mlir_module) + serialized = exported.mlir_module_serialized + module_version = exported.xla_call_module_version + else: + # We serialize as a tf.Graph + assert len(data.inputs) == 1 # We only support a single input now + tf_graph = tf_func.get_concrete_function(*data.inputs).graph + for op in tf_graph.get_operations(): + if op.type == "XlaCallModule": + serialized_module = op.get_attr("module") + module_str = xla_extension.mlir.deserialize_portable_artifact( + serialized_module) + module_version = op.get_attr("version") + break + else: + raise ValueError("Cannot find an XlaCallModule") + tf_graph_def = tf_graph.as_graph_def() + # module_str is just for human readability, add both the MLIR module + # and the tf.Graph + module_str = ("# First the MLIR module:\n" + module_str + + "\n# Then the tf.Graph:\n" + str(tf_graph_def)) + serialized = tf_graph_def.SerializeToString() - # Use the native exporter, to make sure we get the proper serialized module. - exported = jax_export.export( - jax.jit(func), - lowering_platform=default_jax_backend(), - # Must turn off strict checks because the custom calls may be unallowed. - strict_checks=False - )(*(jax.ShapeDtypeStruct(a.shape, a.dtype) for a in data.inputs)) - - module_str = str(exported.mlir_module) custom_call_re = r"stablehlo.custom_call\s*@([^\(]+)\(" custom_call_targets = sorted( - list(set(re.findall(custom_call_re, module_str))) - ) + list(set(re.findall(custom_call_re, module_str)))) + np.set_printoptions(threshold=sys.maxsize, floatmode="unique") # Print the test data to simplify updating the test updated_testdata = f""" @@ -224,10 +282,13 @@ def run_one_test(self, func: Callable[..., jax.Array], inputs={repr(data.inputs)}, expected_outputs={repr(res_from_jax_run_now)}, mlir_module_text=r\"\"\"\n{module_str}\"\"\", - mlir_module_serialized={repr(exported.mlir_module_serialized)}, - xla_call_module_version={exported.xla_call_module_version}, + mlir_module_serialized={repr(serialized)}, + xla_call_module_version={module_version}, ) # End paste + """ + # Replace the word that should not appear. 
+ updated_testdata = re.sub(r"google.", "googlex", updated_testdata) output_dir = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR", "/tmp/back_compat_testdata") if not os.path.exists(output_dir): @@ -244,7 +305,9 @@ def run_one_test(self, func: Callable[..., jax.Array], else: self.assertAllClose(res_from_jax_run_now, data.expected_outputs, rtol=rtol) - res_from_serialized_run_now = self.run_serialized(data) + logging.info("Running the serialized module") + res_from_serialized_run_now = self.run_serialized(data, + use_tf_graph=use_tf_graph) logging.info("Result of serialized run is %s", res_from_serialized_run_now) if check_results is not None: check_results(res_from_serialized_run_now, data.expected_outputs, rtol=rtol) @@ -252,7 +315,8 @@ def run_one_test(self, func: Callable[..., jax.Array], self.assertAllClose(res_from_serialized_run_now, data.expected_outputs, rtol=rtol) self.assertListEqual(custom_call_targets, data.custom_call_targets) - def run_serialized(self, data: CompatTestData): + def run_serialized(self, data: CompatTestData, + use_tf_graph=False): def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray: return core.ShapedArray(a.shape, a.dtype) in_avals_tree = tree_util.tree_map(ndarray_to_aval, data.inputs) @@ -262,25 +326,37 @@ def ndarray_to_aval(a: np.ndarray) -> core.ShapedArray: out_avals, out_tree = tree_util.tree_flatten(out_avals_tree) def _get_vjp(_): assert False # We do not have and do not need VJP - exported = jax_export.Exported( - fun_name="run_serialized", - in_tree=in_tree, - in_avals=tuple(in_avals), - out_tree=out_tree, - out_avals=tuple(out_avals), - in_shardings=(pxla.UNSPECIFIED,) * len(in_avals), - out_shardings=(pxla.UNSPECIFIED,) * len(out_avals), - lowering_platform=data.platform, - strict_checks=True, - mlir_module_serialized=data.mlir_module_serialized, - xla_call_module_version=data.xla_call_module_version, - module_kept_var_idx=tuple(range(len(in_avals))), - module_uses_dim_vars=any(not core.is_constant_shape(a.shape) - for a in in_avals), + if not use_tf_graph: + exported = jax_export.Exported( + fun_name="run_serialized", + in_tree=in_tree, + in_avals=tuple(in_avals), + out_tree=out_tree, + out_avals=tuple(out_avals), + in_shardings=(pxla.UNSPECIFIED,) * len(in_avals), + out_shardings=(pxla.UNSPECIFIED,) * len(out_avals), + lowering_platform=data.platform, + strict_checks=True, + mlir_module_serialized=data.mlir_module_serialized, + xla_call_module_version=data.xla_call_module_version, + module_kept_var_idx=tuple(range(len(in_avals))), + module_uses_dim_vars=any(not core.is_constant_shape(a.shape) + for a in in_avals), _get_vjp=_get_vjp) - # We use pjit in case there are shardings in the exported module. - return pjit.pjit(jax_export.call_exported(exported))(*data.inputs) + # We use pjit in case there are shardings in the exported module. + return pjit.pjit(jax_export.call_exported(exported))(*data.inputs) + else: + loaded_f_tf_graph = graph_pb2.GraphDef() + loaded_f_tf_graph.ParseFromString(data.mlir_module_serialized) + + @tf.function(autograph=False) + def loaded_fun(x): + result = tf.import_graph_def(loaded_f_tf_graph, + input_map={"the_input":x}, + return_elements=["the_result:0"]) + return result[0] + return (loaded_fun(*data.inputs).numpy(),) def test_dummy(self): # Tests the test mechanism. 
Let this test run on all platforms @@ -314,6 +390,7 @@ def test_custom_call_coverage(self): cpu_ducc_fft.data_2023_03_17, cpu_lapack_syev.data_2023_03_17, cpu_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15, cuda_cusolver_geqrf.data_2023_03_18, cuda_cusolver_syev.data_2023_03_17, + tf_call_tf_function.data_2023_05_31, tpu_Eigh.data, tpu_Lu.data_2023_03_21, tpu_Qr.data_2023_03_17, tpu_Sharding.data_2023_03_16, tpu_ApproxTopK.data_2023_04_17, tpu_ApproxTopK.data_2023_05_16] @@ -504,6 +581,18 @@ def func(x): # b: f32[2, 4] with mesh: self.run_one_test(func, data) + def test_tf_call_tf_function(self): + # A custom call tf.call_tf_function is generated when we lower call_tf + # with the call_tf_graph=True option. + def func(x): + def func_tf(x): + return tf.math.sin(x) + return jnp.cos(jax2tf.call_tf(func_tf, call_tf_graph=True, + output_shape_dtype=x)(x)) + + data = load_testdata(tf_call_tf_function.data_2023_05_31) + self.run_one_test(func, data, use_tf_graph=True) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/experimental/jax2tf/tests/back_compat_testdata/tf_call_tf_function.py b/jax/experimental/jax2tf/tests/back_compat_testdata/tf_call_tf_function.py new file mode 100644 index 000000000000..e2e39cad8498 --- /dev/null +++ b/jax/experimental/jax2tf/tests/back_compat_testdata/tf_call_tf_function.py @@ -0,0 +1,360 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import datetime +from numpy import array, float32 + + +# Pasted from the test output (see back_compat_test.py module docstring) +data_2023_05_31 = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['tf.call_tf_function'], + serialized_date=datetime.date(2023, 5, 31), + inputs=(array([0.5, 0.7], dtype=float32),), + expected_outputs=(array([0.88726 , 0.79956985], dtype=float32),), + mlir_module_text=r""" +# First the MLIR module: +#loc = loc(unknown) +module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2xf32> {jax.arg_info = "x", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<2xf32> {jax.result_info = ""}) { + %0 = stablehlo.custom_call @tf.call_tf_function(%arg0) {api_version = 2 : i32, has_side_effect = true, tf.backend_config = {called_index = 0 : i64, called_name = "__inference_callable_flat_tf_10", has_token_input_output = false}} : (tensor<2xf32>) -> tensor<2xf32> loc(#loc2) + %1 = stablehlo.cosine %0 : tensor<2xf32> loc(#loc3) + return %1 : tensor<2xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":586:0) +#loc2 = loc("jit(func)/jit(main)/call_tf[callable_flat_tf=.make_call..callable_flat_tf at 0x7f565ae72170> function_flat_tf= args_flat_sig_tf=(TensorSpec(shape=(2,), dtype=tf.float32, name=None),) output_avals=(ShapedArray(float32[2]),) has_side_effects=True ordered=False call_tf_graph=True]"(#loc1)) +#loc3 = loc("jit(func)/jit(main)/cos"(#loc1)) + +# Then the tf.Graph: +node { + name: "the_input" + op: "Placeholder" + attr { + key: "_user_specified_name" + value { + s: "the_input" + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: 2 + } + } + } + } +} +node { + name: "jax2tf_arg_0" + op: "Identity" + input: "the_input" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "XlaSharding" + op: "XlaSharding" + input: "jax2tf_arg_0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_XlaSharding" + value { + s: "" + } + } + attr { + key: "sharding" + value { + s: "" + } + } + attr { + key: "unspecified_dims" + value { + list { + } + } + } +} +node { + name: "XlaCallModule" + op: "XlaCallModule" + input: "XlaSharding" + attr { + key: "Sout" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "Tin" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "Tout" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "dim_args_spec" + value { + list { + } + } + } + attr { + key: "function_list" + value { + list { + func { + name: "__inference_callable_flat_tf_10" + } + } + } + } + attr { + key: "has_token_input_output" + value { + b: false + } + } + attr { + key: "module" + value { + s: 
"ML\357R\001StableHLO_v0.9.0\000\001\031\005\001\003\001\003\005\003\t\007\t\013\r\003\207i\013\0019\007\017\013\027#\013\013\0133\013\013\013\013S\013\013\013\013\013\013\013\013\013\017\013\013\017\013\0031\013\013\017\033\013\013\013\013\013\017\023\013\013\013\013\013\013#\013\017\013\013\013\013\001\003\017\003\t\023\027\007\007\002\262\002\037\021\001\005\005\017\0273*\t\001\003\007\013\003\r\003\005\017\005\021\005\023\005\025\003\013\023=\025I\027K\005Q\031S\005\027\005\031\005\033\005\035\003\023\035U\037;!W#9%Y\'9)9+9-[\005\037\005!\005#\005%\005\'\005)\005+\005-\005/\0351\007\0051\0053\0357\007\0055\003\001\0357\003\003?\r\005ACEG\0359\035;\035=\035?#\005\003\003M\r\003O;\035A\035C\035E\013\005\035G\005\003\r\007]_aceg\035I\023\t\001\035K\035M\035O\005\001\001\002\002)\003\t\007\021\003\003\003\003\t\035\004Q\005\001\021\001\t\007\003\001\005\003\021\001\021\005\003\007\017\003\003\001\005\007/\033\003\003\003\001\007\0065\003\003\003\003\t\004\001\003\005\006\003\001\005\001\000\202\020Q/A\031\033)\017\013!\033\035\005\033\0031\203\312\006%\037/!!)#\037\031\037\025\035\025\023%)\023\025\025\037\021\017\013\021builtin\000vhlo\000module\000func_v1\000custom_call_v1\000cosine_v1\000return_v1\000sym_name\000mhlo.num_partitions\000mhlo.num_replicas\000jit_func\000arg_attrs\000function_type\000res_attrs\000sym_visibility\000api_version\000backend_config\000call_target_name\000called_computations\000has_side_effect\000operand_layouts\000output_operand_aliases\000result_layouts\000tf.backend_config\000jit(func)/jit(main)/call_tf[callable_flat_tf=.make_call..callable_flat_tf at 0x7f565ae72170> function_flat_tf= args_flat_sig_tf=(TensorSpec(shape=(2,), dtype=tf.float32, name=None),) output_avals=(ShapedArray(float32[2]),) has_side_effects=True ordered=False call_tf_graph=True]\000third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\000jit(func)/jit(main)/cos\000\000jax.arg_info\000x\000mhlo.sharding\000{replicated}\000jax.result_info\000main\000public\000tf.call_tf_function\000called_index\000called_name\000__inference_callable_flat_tf_10\000has_token_input_output\000" + } + } + attr { + key: "platforms" + value { + list { + s: "CPU" + } + } + } + attr { + key: "version" + value { + i: 5 + } + } +} +node { + name: "Identity" + op: "Identity" + input: "XlaCallModule" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "IdentityN" + op: "IdentityN" + input: "XlaCallModule" + input: "jax2tf_arg_0" + attr { + key: "T" + value { + list { + type: DT_FLOAT + type: DT_FLOAT + } + } + } + attr { + key: "_gradient_op_type" + value { + s: "CustomGradient-11" + } + } +} +node { + name: "jax2tf_out" + op: "Identity" + input: "IdentityN" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "the_result" + op: "Identity" + input: "jax2tf_out" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "Identity_1" + op: "Identity" + input: "the_result" + input: "^NoOp" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "NoOp" + op: "NoOp" + input: "^XlaCallModule" +} +library { + function { + signature { + name: "__inference_callable_flat_tf_10" + input_arg { + name: "args_tf_flat_0" + type: DT_FLOAT + } + output_arg { + name: "identity" + type: DT_FLOAT + } + } + node_def { + name: "Sin" + op: "Sin" + input: "args_tf_flat_0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + node_def { + name: "EnsureShape" + op: "EnsureShape" + input: "Sin:y:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + 
attr { + key: "shape" + value { + shape { + dim { + size: 2 + } + } + } + } + } + node_def { + name: "Identity" + op: "Identity" + input: "EnsureShape:output:0" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + } + ret { + key: "identity" + value: "Identity:output:0" + } + attr { + key: "_XlaMustCompile" + value { + b: false + } + } + attr { + key: "_construction_context" + value { + s: "kEagerRuntime" + } + } + arg_attr { + key: 0 + value { + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 2 + } + } + } + } + } + attr { + key: "_user_specified_name" + value { + s: "args_tf_flat_0" + } + } + } + } + } +} +versions { + producer: 1513 + min_consumer: 12 +} +""", + mlir_module_serialized=b'\n[\n\tthe_input\x12\x0bPlaceholder*#\n\x14_user_specified_name\x12\x0b\x12\tthe_input*\x0b\n\x05dtype\x12\x020\x01*\x0f\n\x05shape\x12\x06:\x04\x12\x02\x08\x02\n,\n\x0cjax2tf_arg_0\x12\x08Identity\x1a\tthe_input*\x07\n\x01T\x12\x020\x01\nm\n\x0bXlaSharding\x12\x0bXlaSharding\x1a\x0cjax2tf_arg_0*\x16\n\x10unspecified_dims\x12\x02\n\x00*\x07\n\x01T\x12\x020\x01*\x0e\n\x08sharding\x12\x02\x12\x00*\x12\n\x0c_XlaSharding\x12\x02\x12\x00\n\xe5\x0c\n\rXlaCallModule\x12\rXlaCallModule\x1a\x0bXlaSharding*\x0c\n\x03Tin\x12\x05\n\x032\x01\x01*\xf8\n\n\x06module\x12\xed\n\x12\xea\nML\xefR\x01StableHLO_v0.9.0\x00\x01\x19\x05\x01\x03\x01\x03\x05\x03\t\x07\t\x0b\r\x03\x87i\x0b\x019\x07\x0f\x0b\x17#\x0b\x0b\x0b3\x0b\x0b\x0b\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0b\x0f\x0b\x031\x0b\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0f\x13\x0b\x0b\x0b\x0b\x0b\x0b#\x0b\x0f\x0b\x0b\x0b\x0b\x01\x03\x0f\x03\t\x13\x17\x07\x07\x02\xb2\x02\x1f\x11\x01\x05\x05\x0f\x173*\t\x01\x03\x07\x0b\x03\r\x03\x05\x0f\x05\x11\x05\x13\x05\x15\x03\x0b\x13=\x15I\x17K\x05Q\x19S\x05\x17\x05\x19\x05\x1b\x05\x1d\x03\x13\x1dU\x1f;!W#9%Y\'9)9+9-[\x05\x1f\x05!\x05#\x05%\x05\'\x05)\x05+\x05-\x05/\x1d1\x07\x051\x053\x1d7\x07\x055\x03\x01\x1d7\x03\x03?\r\x05ACEG\x1d9\x1d;\x1d=\x1d?#\x05\x03\x03M\r\x03O;\x1dA\x1dC\x1dE\x0b\x05\x1dG\x05\x03\r\x07]_aceg\x1dI\x13\t\x01\x1dK\x1dM\x1dO\x05\x01\x01\x02\x02)\x03\t\x07\x11\x03\x03\x03\x03\t\x1d\x04Q\x05\x01\x11\x01\t\x07\x03\x01\x05\x03\x11\x01\x11\x05\x03\x07\x0f\x03\x03\x01\x05\x07/\x1b\x03\x03\x03\x01\x07\x065\x03\x03\x03\x03\t\x04\x01\x03\x05\x06\x03\x01\x05\x01\x00\x82\x10Q/A\x19\x1b)\x0f\x0b!\x1b\x1d\x05\x1b\x031\x83\xca\x06%\x1f/!!)#\x1f\x19\x1f\x15\x1d\x15\x13%)\x13\x15\x15\x1f\x11\x0f\x0b\x11builtin\x00vhlo\x00module\x00func_v1\x00custom_call_v1\x00cosine_v1\x00return_v1\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00tf.backend_config\x00jit(func)/jit(main)/call_tf[callable_flat_tf=.make_call..callable_flat_tf at 0x7f565ae72170> function_flat_tf= args_flat_sig_tf=(TensorSpec(shape=(2,), dtype=tf.float32, name=None),) output_avals=(ShapedArray(float32[2]),) has_side_effects=True ordered=False 
call_tf_graph=True]\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00jit(func)/jit(main)/cos\x00\x00jax.arg_info\x00x\x00mhlo.sharding\x00{replicated}\x00jax.result_info\x00main\x00public\x00tf.call_tf_function\x00called_index\x00called_name\x00__inference_callable_flat_tf_10\x00has_token_input_output\x00*\r\n\x07version\x12\x02\x18\x05*\x10\n\x04Sout\x12\x08\n\x06:\x04\x12\x02\x08\x02*\x14\n\tplatforms\x12\x07\n\x05\x12\x03CPU*\x1c\n\x16has_token_input_output\x12\x02(\x00*\r\n\x04Tout\x12\x05\n\x032\x01\x01*\x13\n\rdim_args_spec\x12\x02\n\x00*6\n\rfunction_list\x12%\n#J!\n\x1f__inference_callable_flat_tf_10\n,\n\x08Identity\x12\x08Identity\x1a\rXlaCallModule*\x07\n\x01T\x12\x020\x01\nj\n\tIdentityN\x12\tIdentityN\x1a\rXlaCallModule\x1a\x0cjax2tf_arg_0*(\n\x11_gradient_op_type\x12\x13\x12\x11CustomGradient-11*\x0b\n\x01T\x12\x06\n\x042\x02\x01\x01\n*\n\njax2tf_out\x12\x08Identity\x1a\tIdentityN*\x07\n\x01T\x12\x020\x01\n+\n\nthe_result\x12\x08Identity\x1a\njax2tf_out*\x07\n\x01T\x12\x020\x01\n2\n\nIdentity_1\x12\x08Identity\x1a\nthe_result\x1a\x05^NoOp*\x07\n\x01T\x12\x020\x01\n\x1c\n\x04NoOp\x12\x04NoOp\x1a\x0e^XlaCallModule\x12\x8d\x03\n\x8a\x03:J\x08\x00\x12F\n(\n\x14_user_specified_name\x12\x10\x12\x0eargs_tf_flat_0\n\x1a\n\x0e_output_shapes\x12\x08\n\x06:\x04\x12\x02\x08\x02*\x15\n\x0f_XlaMustCompile\x12\x02(\x00*(\n\x15_construction_context\x12\x0f\x12\rkEagerRuntime"\x1d\n\x08identity\x12\x11Identity:output:0\x1a#\n\x03Sin\x12\x03Sin\x1a\x0eargs_tf_flat_0*\x07\n\x01T\x12\x020\x01\x1a=\n\x0bEnsureShape\x12\x0bEnsureShape\x1a\x07Sin:y:0*\x0f\n\x05shape\x12\x06:\x04\x12\x02\x08\x02*\x07\n\x01T\x12\x020\x01\x1a3\n\x08Identity\x12\x08Identity\x1a\x14EnsureShape:output:0*\x07\n\x01T\x12\x020\x01\nC\n\x1f__inference_callable_flat_tf_10\x12\x12\n\x0eargs_tf_flat_0\x18\x01\x1a\x0c\n\x08identity\x18\x01"\x05\x08\xe9\x0b\x10\x0c', + xla_call_module_version=5, +) # End paste
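
A minimal standalone sketch of the round trip that `run_one_test(..., use_tf_graph=True)`
and `run_serialized(..., use_tf_graph=True)` above perform. It uses only APIs that appear
in this patch (`jax2tf.call_tf(..., call_tf_graph=True)`, `jax2tf.convert(..., native_serialization=True)`,
`tf.import_graph_def`); the names `func`, `tf_func`, `the_input`, and `the_result` mirror
the patch, and the script itself is illustrative rather than part of the change.

# Sketch: serialize a jax2tf.call_tf(call_tf_graph=True) function as a tf.Graph
# and run it again from the serialized GraphDef, mirroring run_one_test above.
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from jax.experimental import jax2tf

def func(x):
  def func_tf(x):
    return tf.math.sin(x)
  # call_tf_graph=True makes jax2tf emit the tf.call_tf_function custom call.
  return jnp.cos(jax2tf.call_tf(func_tf, call_tf_graph=True,
                                output_shape_dtype=x)(x))

@tf.function(autograph=False, jit_compile=True)
def tf_func(the_input):  # the traced placeholder is named "the_input"
  res = jax2tf.convert(func, native_serialization=True)(the_input)
  return tf.identity(res, name="the_result")

x = np.array([0.5, 0.7], dtype=np.float32)

# Serialize the whole tf.Graph; it carries the called tf.function in its
# function library, which a bare StableHLO module could not.
graph_def = tf_func.get_concrete_function(x).graph.as_graph_def()
serialized = graph_def.SerializeToString()

# Deserialize and run, as run_serialized(use_tf_graph=True) does.
loaded_graph_def = tf.compat.v1.GraphDef()  # same proto as graph_pb2.GraphDef
loaded_graph_def.ParseFromString(serialized)

@tf.function(autograph=False)
def loaded_fun(x):
  (res,) = tf.import_graph_def(loaded_graph_def, input_map={"the_input": x},
                               return_elements=["the_result:0"])
  return res

np.testing.assert_allclose(loaded_fun(x).numpy(), np.cos(np.sin(x)), rtol=1e-6)

Serializing the enclosing tf.Graph rather than only the StableHLO module is what lets the
`tf.call_tf_function` custom call find its callee: the called `tf.function`
(here `__inference_callable_flat_tf_10`) travels in the graph's function library,
as seen in the GraphDef pasted above.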