From 603eeb19017d50526a85e9c6c49f76330254bd47 Mon Sep 17 00:00:00 2001 From: George Necula Date: Sun, 16 Jul 2023 09:26:27 -0700 Subject: [PATCH] Copybara import of the project: -- 06bf5fe7b2ac97156df541bab989dc5beb1aff0c by George Necula : [jax2tf] Added a flag and environment variable to control the serialization version. This allows us to control the serialization version to be compatible with the deployed version of tf.XlaCallModule. In particular, we can run most tests with the maximum available version, while keeping the default lower. COPYBARA_INTEGRATE_REVIEW=https://github.com/google/jax/pull/16746 from gnecula:tf_version 06bf5fe7b2ac97156df541bab989dc5beb1aff0c PiperOrigin-RevId: 548504243 --- CHANGELOG.md | 5 +++++ jax/_src/config.py | 15 +++++++++++++++ jax/experimental/jax2tf/jax2tf.py | 14 ++++---------- jax/experimental/jax2tf/jax_export.py | 3 ++- jax/experimental/jax2tf/tests/jax2tf_test.py | 10 ++++++++++ jax/experimental/jax2tf/tests/tf_test_util.py | 18 ++++++++++++++++++ 6 files changed, 54 insertions(+), 11 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 33603fa0c7cc..299631235b28 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,11 @@ Remember to align the itemized text with the first line of an item within a list * The deprecated config option `jax_jit_pjit_api_merge`, which did nothing, has been removed. +* New features + * JAX now supports a configuration flag --jax_serialization_version + and a JAX_SERIALIZATION_VERSION environment variable to control the + serialization version ({jax-issue}`#16746`). + ## jaxlib 0.4.14 * Deprecations diff --git a/jax/_src/config.py b/jax/_src/config.py index 3b2bea6a9c8a..ee1ce471c478 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -687,6 +687,21 @@ def update_thread_local_jit_state(**kw): ) ) +jax_serialization_version = config.define_int_state( + name='jax_serialization_version', + # Note: bump the default serialization version at least one month after + # we update XlaCallModule to support the new version, so that serialized + # modules are forward compatible with deployed versions of XlaCallModule. + # Version 6 of XlaCallModule is supported since June 7th, 2023. + default=int_env('JAX_SERIALIZATION_VERSION', 6), + help=( + 'The version number to use for native serialization. This must be ' + 'within the range of versions supported by the tf.XlaCallModule ' + 'used in your deployment environment. ' + 'See https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code.' + ) +) + jax_platforms = config.define_string_state( name='jax_platforms', default=None, diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index ca83082a4129..a2ed5eccecb9 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -843,15 +843,9 @@ def _convert_value(val, aval): kept_args_avals = [aval for i, aval in enumerate(exported.in_avals) if i in exported.module_kept_var_idx] kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx] - if hasattr(tfxla, "call_module_maximum_supported_version"): - max_version_supported = tfxla.call_module_maximum_supported_version() - else: - max_version_supported = 5 - # TODO(necula): cleanup handling of Exported.xla_call_module_version - assert exported.xla_call_module_version == 6 - + version = exported.xla_call_module_version call_module_attrs = dict( - version=max_version_supported, + version=version, Tout=out_types, Sout=out_shapes_tf, function_list=[ @@ -861,12 +855,12 @@ def _convert_value(val, aval): ) call_module_attrs["platforms"] = (exported.lowering_platform.upper(),) - if max_version_supported >= 6: + if version >= 6: call_module_attrs["disabled_checks"] = tuple( str(dc) for dc in exported.disabled_checks) else: - if exported.xla_call_module_version >= 3: + if version >= 3: if DisabledSafetyCheck.platform() in exported.disabled_checks: call_module_attrs["platforms"] = () # No platform checking diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py index 343206849a30..1b8f44a32817 100644 --- a/jax/experimental/jax2tf/jax_export.py +++ b/jax/experimental/jax2tf/jax_export.py @@ -26,6 +26,7 @@ import numpy as np import jax +from jax import config from jax import sharding from jax._src import core @@ -438,7 +439,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported: def _serialize_module(module: ir.Module) -> tuple[bytes, int]: - xla_call_module_version = 6 + xla_call_module_version = config.jax_serialization_version mlir_str = mlir.module_to_bytecode(module) if hlo.get_api_version() < 4: target_version = hlo.get_earliest_forward_compatible_version() diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index 6fccff8c3b9d..fbec68c946e0 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -1677,6 +1677,7 @@ def get_serialized_computation( class XlaCallModuleTest(tf_test_util.JaxToTfTestCase): """Unit tests for XlaCallModule. Will move these eventually to TF.""" + # TODO(necula): move these tests to TF def test_simple(self): def f_jax(x): @@ -1796,6 +1797,15 @@ def func(): jax_result = func() self.assertEqual(tf_result, jax_result) +class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase): + # Use a separate test case with the default jax_serialization_version + def setUp(self): + self.use_max_serialization_version = False + super().setUp() + + def test_simple(self): + self.ConvertAndCompare(jnp.sin, 0.7) + if __name__ == "__main__": # TODO: Remove once tensorflow is 2.10.0 everywhere. diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py index 04bfe899bf5f..27b39c119bcc 100644 --- a/jax/experimental/jax2tf/tests/tf_test_util.py +++ b/jax/experimental/jax2tf/tests/tf_test_util.py @@ -14,6 +14,7 @@ import contextlib import dataclasses +import functools import re import os @@ -34,6 +35,7 @@ import numpy as np import tensorflow as tf # type: ignore[import] from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import] +from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import] DType = Any @@ -153,6 +155,9 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence, @jtu.with_config(jax_numpy_rank_promotion="allow", jax_numpy_dtype_promotion='standard') class JaxToTfTestCase(jtu.JaxTestCase): + # We want most tests to use the maximum available version, from the locally + # installed tfxla module. + use_max_serialization_version = True def setUp(self): super().setUp() @@ -167,6 +172,19 @@ def setUp(self): self.assertEqual(jtu.device_under_test().upper(), self.tf_default_device.device_type) + # We run the tests using the maximum version supported, even though + # the default serialization version may be held back for a while to + # ensure compatibility + version = config.jax_serialization_version + self.addCleanup(functools.partial(config.update, + "jax_serialization_version", version)) + if self.use_max_serialization_version: + max_version = tfxla.call_module_maximum_supported_version() + self.assertLessEqual(version, max_version) + version = max_version + config.update("jax_serialization_version", max_version) + logging.info("Using JAX serialization version %s", version) + with contextlib.ExitStack() as stack: stack.enter_context(tf.device(self.tf_default_device)) self.addCleanup(stack.pop_all().close)