diff --git a/CHANGELOG.md b/CHANGELOG.md index 33603fa0c7cc..299631235b28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,11 @@ Remember to align the itemized text with the first line of an item within a list * The deprecated config option `jax_jit_pjit_api_merge`, which did nothing, has been removed. +* New features + * JAX now supports a configuration flag --jax_serialization_version + and a JAX_SERIALIZATION_VERSION environment variable to control the + serialization version ({jax-issue}`#16746`). + ## jaxlib 0.4.14 * Deprecations diff --git a/jax/_src/config.py b/jax/_src/config.py index 3b2bea6a9c8a..ee1ce471c478 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -687,6 +687,21 @@ def update_thread_local_jit_state(**kw): ) ) +jax_serialization_version = config.define_int_state( + name='jax_serialization_version', + # Note: bump the default serialization version at least one month after + # we update XlaCallModule to support the new version, so that serialized + # modules are forward compatible with deployed versions of XlaCallModule. + # Version 6 of XlaCallModule is supported since June 7th, 2023. + default=int_env('JAX_SERIALIZATION_VERSION', 6), + help=( + 'The version number to use for native serialization. This must be ' + 'within the range of versions supported by the tf.XlaCallModule ' + 'used in your deployment environment. ' + 'See https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code.' + ) +) + jax_platforms = config.define_string_state( name='jax_platforms', default=None, diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index ca83082a4129..a2ed5eccecb9 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -843,15 +843,9 @@ def _convert_value(val, aval): kept_args_avals = [aval for i, aval in enumerate(exported.in_avals) if i in exported.module_kept_var_idx] kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx] - if hasattr(tfxla, "call_module_maximum_supported_version"): - max_version_supported = tfxla.call_module_maximum_supported_version() - else: - max_version_supported = 5 - # TODO(necula): cleanup handling of Exported.xla_call_module_version - assert exported.xla_call_module_version == 6 - + version = exported.xla_call_module_version call_module_attrs = dict( - version=max_version_supported, + version=version, Tout=out_types, Sout=out_shapes_tf, function_list=[ @@ -861,12 +855,12 @@ def _convert_value(val, aval): ) call_module_attrs["platforms"] = (exported.lowering_platform.upper(),) - if max_version_supported >= 6: + if version >= 6: call_module_attrs["disabled_checks"] = tuple( str(dc) for dc in exported.disabled_checks) else: - if exported.xla_call_module_version >= 3: + if version >= 3: if DisabledSafetyCheck.platform() in exported.disabled_checks: call_module_attrs["platforms"] = () # No platform checking diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py index 343206849a30..1b8f44a32817 100644 --- a/jax/experimental/jax2tf/jax_export.py +++ b/jax/experimental/jax2tf/jax_export.py @@ -26,6 +26,7 @@ import numpy as np import jax +from jax import config from jax import sharding from jax._src import core @@ -438,7 +439,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported: def _serialize_module(module: ir.Module) -> tuple[bytes, int]: - xla_call_module_version = 6 + xla_call_module_version = config.jax_serialization_version mlir_str = mlir.module_to_bytecode(module) if hlo.get_api_version() < 4: target_version = hlo.get_earliest_forward_compatible_version() diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 6fccff8c3b9d..fbec68c946e0 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -1677,6 +1677,7 @@ def get_serialized_computation( class XlaCallModuleTest(tf_test_util.JaxToTfTestCase): """Unit tests for XlaCallModule. Will move these eventually to TF.""" + # TODO(necula): move these tests to TF def test_simple(self): def f_jax(x): @@ -1796,6 +1797,15 @@ def func(): jax_result = func() self.assertEqual(tf_result, jax_result) +class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase): + # Use a separate test case with the default jax_serialization_version + def setUp(self): + self.use_max_serialization_version = False + super().setUp() + + def test_simple(self): + self.ConvertAndCompare(jnp.sin, 0.7) + if __name__ == "__main__": # TODO: Remove once tensorflow is 2.10.0 everywhere. diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 04bfe899bf5f..27b39c119bcc 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -14,6 +14,7 @@ import contextlib import dataclasses +import functools import re import os @@ -34,6 +35,7 @@ import numpy as np import tensorflow as tf # type: ignore[import] from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import] +from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import] DType = Any @@ -153,6 +155,9 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence, @jtu.with_config(jax_numpy_rank_promotion="allow", jax_numpy_dtype_promotion='standard') class JaxToTfTestCase(jtu.JaxTestCase): + # We want most tests to use the maximum available version, from the locally + # installed tfxla module. + use_max_serialization_version = True def setUp(self): super().setUp() @@ -167,6 +172,19 @@ def setUp(self): self.assertEqual(jtu.device_under_test().upper(), self.tf_default_device.device_type) + # We run the tests using the maximum version supported, even though + # the default serialization version may be held back for a while to + # ensure compatibility + version = config.jax_serialization_version + self.addCleanup(functools.partial(config.update, + "jax_serialization_version", version)) + if self.use_max_serialization_version: + max_version = tfxla.call_module_maximum_supported_version() + self.assertLessEqual(version, max_version) + version = max_version + config.update("jax_serialization_version", max_version) + logging.info("Using JAX serialization version %s", version) + with contextlib.ExitStack() as stack: stack.enter_context(tf.device(self.tf_default_device)) self.addCleanup(stack.pop_all().close)