diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 4517a869a5c3..e268ffa30d66 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -834,7 +834,7 @@ def _convert_value(val, aval): kept_args_avals = [aval for i, aval in enumerate(exported.in_avals) if i in exported.module_kept_var_idx] kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx] - version = exported.xla_call_module_version + version = exported.serialization_version call_module_attrs = dict( version=version, Tout=out_types, diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py index 1712b33c455b..f6a35baa3ed8 100644 --- a/jax/experimental/jax2tf/jax_export.py +++ b/jax/experimental/jax2tf/jax_export.py @@ -25,7 +25,6 @@ from typing import Any, Callable, Optional, Union from absl import logging -import numpy as np import jax from jax import config @@ -113,6 +112,9 @@ def __hash__(self) -> int: return hash(self._impl) +minimum_supported_serialization_version = 6 +maximum_supported_serialization_version = 7 + @dataclasses.dataclass(frozen=True) class Exported: """A JAX function lowered to StableHLO. @@ -135,14 +137,15 @@ class Exported: out_shardings: the flattened output shardings, as long as `in_avals`. lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm' mlir_module_serialized: the serialized lowered VHLO module. - xla_call_module_version: a version number for the serialized module. + serialization_version: a version number for the serialized module. See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions. module_kept_var_idx: the sorted indices of the arguments among `in_avals` that must be passed to the module. The other arguments have been dropped because they are not used. Same length as `in_shardings`. - module_uses_dim_vars: whether the `mlir_module_serialized` uses shape - polymorphic dimension variables. This may be from `in_avals` but also - from inner calls of shape-polymorphic Exported modules. + uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape + polymorphism. This may be because `in_avals` contains dimension + variables, but also from inner calls of shape-polymorphic + Exported modules. disabled_checks: a list of descriptors of safety checks that have been disabled at export time. See docstring for `DisabledSafetyCheck`. _get_vjp: an optional function that takes the current exported function and @@ -164,9 +167,9 @@ class Exported: disabled_checks: Sequence[DisabledSafetyCheck] mlir_module_serialized: bytes - xla_call_module_version: int + serialization_version: int module_kept_var_idx: tuple[int, ...] - module_uses_dim_vars: bool + uses_shape_polymorphism: bool _get_vjp: Optional[Callable[["Exported"], "Exported"]] @@ -311,6 +314,12 @@ def f_jax(*args, **kwargs): ... """ fun_name = getattr(fun_jax, "__name__", "unknown") version = config.jax_serialization_version + if (version < minimum_supported_serialization_version or + version > maximum_supported_serialization_version): + raise ValueError( + f"The requested jax_serialization version {version} is outside the " + f"range of supported versions [{minimum_supported_serialization_version}" + f"..{maximum_supported_serialization_version}]") def do_export(*args_specs, **kwargs_specs) -> Exported: if not hasattr(fun_jax, "lower"): @@ -395,8 +404,8 @@ def do_export(*args_specs, **kwargs_specs) -> Exported: disabled_checks=tuple(disabled_checks), mlir_module_serialized=mlir_module_serialized, module_kept_var_idx=module_kept_var_idx, - module_uses_dim_vars=shape_poly_state.uses_dim_vars, - xla_call_module_version=version, # type: ignore + uses_shape_polymorphism=shape_poly_state.uses_dim_vars, + serialization_version=version, # type: ignore _get_vjp=lambda exported: _export_native_vjp(fun_jax, exported)) return do_export @@ -826,6 +835,12 @@ def flattened_primal_fun_jax(*args_flat): ### Importing def call_exported(exported: Exported) -> Callable[..., jax.Array]: + + if (exported.serialization_version >= 7 and + exported.uses_shape_polymorphism): + if xla_client.mlir_api_version < 52: + raise NotImplementedError( + "Current jaxlib does not support shape polymorphism with serialization version >= 7") @jax.custom_vjp def f_flat(*args_flat): return call_exported_p.bind(*args_flat, exported=exported) @@ -945,7 +960,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args, f"The exported function '{exported.fun_name}' was lowered for " f"platform '{exported.lowering_platform}' but it is used " f"on '{platform}'.") - if exported.module_uses_dim_vars: + if exported.uses_shape_polymorphism: ctx.module_context.shape_poly_state.uses_dim_vars = True submodule = ir.Module.parse(exported.mlir_module) diff --git a/jax/experimental/jax2tf/tests/back_compat_test_util.py b/jax/experimental/jax2tf/tests/back_compat_test_util.py index 06993683ccb7..12056360960f 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test_util.py +++ b/jax/experimental/jax2tf/tests/back_compat_test_util.py @@ -292,7 +292,7 @@ def serialize(self, module_str = str(exported.mlir_module) serialized = exported.mlir_module_serialized - module_version = exported.xla_call_module_version + module_version = exported.serialization_version return serialized, module_str, module_version def run_serialized(self, data: CompatTestData, @@ -323,10 +323,10 @@ def _get_vjp(_): lowering_platform=data.platform, disabled_checks=(), mlir_module_serialized=data.mlir_module_serialized, - xla_call_module_version=data.xla_call_module_version, + serialization_version=data.xla_call_module_version, module_kept_var_idx=tuple(range(len(in_avals))), - module_uses_dim_vars=any(not core.is_constant_shape(a.shape) - for a in in_avals), + uses_shape_polymorphism=any(not core.is_constant_shape(a.shape) + for a in in_avals), _get_vjp=_get_vjp) # We use pjit in case there are shardings in the exported module. diff --git a/jax/experimental/jax2tf/tests/jax_export_test.py b/jax/experimental/jax2tf/tests/jax_export_test.py index 4ac1320bd5a7..46be6d4365c4 100644 --- a/jax/experimental/jax2tf/tests/jax_export_test.py +++ b/jax/experimental/jax2tf/tests/jax_export_test.py @@ -13,6 +13,8 @@ # limitations under the License. import contextlib import math +import functools +import logging import re from typing import Optional import unittest @@ -39,6 +41,23 @@ class JaxExportTest(jtu.JaxTestCase): + def override_serialization_version(self, version_override: int): + version = config.jax_serialization_version + if version != version_override: + self.addCleanup(functools.partial(config.update, + "jax_serialization_version", + version_override)) + config.update("jax_serialization_version", version_override) + logging.info( + "Using JAX serialization version %s", + config.jax_serialization_version) + + def setUp(self): + super().setUp() + # Run tests with the maximum supported version by default + self.override_serialization_version( + jax_export.maximum_supported_serialization_version) + def test_basic_export_only(self): def my_fun(x): return jnp.sin(x) @@ -250,7 +269,6 @@ def f1_exp(a, b): # For VJP, make a function without kwargs exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct) self.assertAllClose(jax_vjp, exp_vjp) - def test_roundtrip(self): def f1(x): return jnp.sin(x) @@ -265,6 +283,29 @@ def f2(x): self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))), jax_export.call_exported(exp_f2)(a)) + @jtu.parameterized_filterable( + kwargs=[ + dict(v=v) + for v in range(jax_export.minimum_supported_serialization_version - 1, + jax_export.maximum_supported_serialization_version + 2)]) + def test_shape_poly_basic_versions(self, v: int): + self.override_serialization_version(v) + with contextlib.ExitStack() as e: + if not (jax_export.minimum_supported_serialization_version <= v + <= jax_export.maximum_supported_serialization_version): + e.enter_context(self.assertRaisesRegex( + ValueError, + f"The requested jax_serialization version {v} is outside the range of supported versions")) + + if (xc.mlir_api_version <= 51 and + config.jax_serialization_version >= 7): + raise unittest.SkipTest("Not supported in old jaxlib") + exp = jax_export.export(jnp.sin)( + jax_export.poly_spec((3, 4), np.float32, "w, h")) + x = np.arange(30, dtype=np.float32).reshape((5, 6)) + res = jax_export.call_exported(exp)(x) + self.assertAllClose(res, np.sin(x)) + # A function is exported with f32[poly_spec] and is called with different arg # shapes. We use jax_export.call_exported and we also run the shape check # module. @@ -299,6 +340,8 @@ def test_poly_shape_checks( arg_shape=(3, 4, 12), arg_dtype=np.float32, expect_error=None): # If given, error from running the exported module + if xc.mlir_api_version <= 51: + raise unittest.SkipTest("Not supported in old jaxlib") def f(x): # x: f32[poly_spec] return jnp.reshape(x, (-1, x.shape[1])) @@ -308,7 +351,7 @@ def f(x): # x: f32[poly_spec] disabled_checks = () exp_f = jax_export.export(f, disabled_checks=disabled_checks)( jax_export.poly_spec((3, 4, 12), np.float32, poly_spec)) - self.assertEqual(exp_f.module_uses_dim_vars, poly_spec != "3,4,12") + self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12") arg = np.arange(np.prod(arg_shape), dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12] @@ -404,18 +447,16 @@ def test_poly_shape_checks_nested( expect_error_run=None): # Polymorphic export called with static or polymorphic shapes if xc.mlir_api_version <= 51: - disabled_checks = (jax_export.DisabledSafetyCheck.shape_assertions(),) - else: - disabled_checks = () + raise unittest.SkipTest("Not supported in old jaxlib") def inner(x): # x: inner_poly_spec return jnp.reshape(x, (-1, x.shape[1])) arg = np.arange(np.prod(arg_shape), dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12] - inner_exp = jax_export.export(inner, disabled_checks=disabled_checks)( + inner_exp = jax_export.export(inner)( jax_export.poly_spec((3, 4, 12), np.float32, inner_poly_spec)) - self.assertEqual(inner_exp.module_uses_dim_vars, + self.assertEqual(inner_exp.uses_shape_polymorphism, (inner_poly_spec != "3,4,12")) def outer(x): # x: outer_poly_spec # Use an addition to test that the shapes are refined properly for the @@ -427,13 +468,13 @@ def outer(x): # x: outer_poly_spec stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp)) # Call it after exporting again, with polymorphic shapes - outer_exp = jax_export.export(outer, disabled_checks=disabled_checks)( + outer_exp = jax_export.export(outer)( jax_export.poly_spec(arg.shape, arg.dtype, outer_poly_spec)) if expect_error_outer_exp is not None: return - self.assertEqual(outer_exp.module_uses_dim_vars, + self.assertEqual(outer_exp.uses_shape_polymorphism, (inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12")) with contextlib.ExitStack() as stack: diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 74d3f1afdf03..d77898f41ad6 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -31,6 +31,7 @@ from jax import config from jax.experimental import jax2tf +from jax.experimental.jax2tf import jax_export from jax._src import xla_bridge import numpy as np import tensorflow as tf # type: ignore[import] @@ -156,7 +157,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence, jax_numpy_dtype_promotion='standard') class JaxToTfTestCase(jtu.JaxTestCase): # We want most tests to use the maximum available version, from the locally - # installed tfxla module. + # installed tfxla module and jax_export. use_max_serialization_version = True def setUp(self): @@ -179,14 +180,17 @@ def setUp(self): self.addCleanup(functools.partial(config.update, "jax_serialization_version", version)) if self.use_max_serialization_version: - # The largest version we support is 7 - max_version = min(7, tfxla.call_module_maximum_supported_version()) - self.assertLessEqual(version, max_version) - version = max_version - config.update("jax_serialization_version", max_version) - logging.info("Using JAX serialization version %s%s", - version, - " (max_version)" if self.use_max_serialization_version else "") + # Use the largest supported by both jax_export and tfxla.call_module + version = min(jax_export.maximum_supported_serialization_version, + tfxla.call_module_maximum_supported_version()) + self.assertGreaterEqual(version, + jax_export.minimum_supported_serialization_version) + config.update("jax_serialization_version", version) + logging.info( + "Using JAX serialization version %s (jax_export.max_version %s, tf.XlaCallModule max version %s)", + version, + jax_export.maximum_supported_serialization_version, + tfxla.call_module_maximum_supported_version()) with contextlib.ExitStack() as stack: stack.enter_context(tf.device(self.tf_default_device))