Skip to content

Commit

Permalink
Copybara import of the project:
Browse files Browse the repository at this point in the history
--
06bf5fe by George Necula <gcnecula@gmail.com>:

[jax2tf] Added a flag and environment variable to control the serialization version.

This allows us to control the serialization version to be compatible with
the deployed version of tf.XlaCallModule. In particular, we can run
most tests with the maximum available version, while keeping the
default lower.

COPYBARA_INTEGRATE_REVIEW=#16746 from gnecula:tf_version 06bf5fe
PiperOrigin-RevId: 548504243
  • Loading branch information
gnecula authored and jax authors committed Jul 16, 2023
1 parent cd39128 commit 603eeb1
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 11 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Expand Up @@ -54,6 +54,11 @@ Remember to align the itemized text with the first line of an item within a list
* The deprecated config option `jax_jit_pjit_api_merge`, which did nothing,
has been removed.

* New features
* JAX now supports a configuration flag --jax_serialization_version
and a JAX_SERIALIZATION_VERSION environment variable to control the
serialization version ({jax-issue}`#16746`).

## jaxlib 0.4.14

* Deprecations
Expand Down
15 changes: 15 additions & 0 deletions jax/_src/config.py
Expand Up @@ -687,6 +687,21 @@ def update_thread_local_jit_state(**kw):
)
)

jax_serialization_version = config.define_int_state(
name='jax_serialization_version',
# Note: bump the default serialization version at least one month after
# we update XlaCallModule to support the new version, so that serialized
# modules are forward compatible with deployed versions of XlaCallModule.
# Version 6 of XlaCallModule is supported since June 7th, 2023.
default=int_env('JAX_SERIALIZATION_VERSION', 6),
help=(
'The version number to use for native serialization. This must be '
'within the range of versions supported by the tf.XlaCallModule '
'used in your deployment environment. '
'See https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+VERSION_MAXIMUM_SUPPORTED%22&type=code.'
)
)

jax_platforms = config.define_string_state(
name='jax_platforms',
default=None,
Expand Down
14 changes: 4 additions & 10 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -843,15 +843,9 @@ def _convert_value(val, aval):
kept_args_avals = [aval for i, aval in enumerate(exported.in_avals) if i in exported.module_kept_var_idx]
kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx]

if hasattr(tfxla, "call_module_maximum_supported_version"):
max_version_supported = tfxla.call_module_maximum_supported_version()
else:
max_version_supported = 5
# TODO(necula): cleanup handling of Exported.xla_call_module_version
assert exported.xla_call_module_version == 6

version = exported.xla_call_module_version
call_module_attrs = dict(
version=max_version_supported,
version=version,
Tout=out_types,
Sout=out_shapes_tf,
function_list=[
Expand All @@ -861,12 +855,12 @@ def _convert_value(val, aval):
)

call_module_attrs["platforms"] = (exported.lowering_platform.upper(),)
if max_version_supported >= 6:
if version >= 6:
call_module_attrs["disabled_checks"] = tuple(
str(dc)
for dc in exported.disabled_checks)
else:
if exported.xla_call_module_version >= 3:
if version >= 3:
if DisabledSafetyCheck.platform() in exported.disabled_checks:
call_module_attrs["platforms"] = () # No platform checking

Expand Down
3 changes: 2 additions & 1 deletion jax/experimental/jax2tf/jax_export.py
Expand Up @@ -26,6 +26,7 @@
import numpy as np

import jax
from jax import config
from jax import sharding

from jax._src import core
Expand Down Expand Up @@ -438,7 +439,7 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:


def _serialize_module(module: ir.Module) -> tuple[bytes, int]:
xla_call_module_version = 6
xla_call_module_version = config.jax_serialization_version
mlir_str = mlir.module_to_bytecode(module)
if hlo.get_api_version() < 4:
target_version = hlo.get_earliest_forward_compatible_version()
Expand Down
10 changes: 10 additions & 0 deletions jax/experimental/jax2tf/tests/jax2tf_test.py
Expand Up @@ -1677,6 +1677,7 @@ def get_serialized_computation(

class XlaCallModuleTest(tf_test_util.JaxToTfTestCase):
"""Unit tests for XlaCallModule. Will move these eventually to TF."""
# TODO(necula): move these tests to TF
def test_simple(self):

def f_jax(x):
Expand Down Expand Up @@ -1796,6 +1797,15 @@ def func():
jax_result = func()
self.assertEqual(tf_result, jax_result)

class Jax2TfVersioningTest(tf_test_util.JaxToTfTestCase):
# Use a separate test case with the default jax_serialization_version
def setUp(self):
self.use_max_serialization_version = False
super().setUp()

def test_simple(self):
self.ConvertAndCompare(jnp.sin, 0.7)


if __name__ == "__main__":
# TODO: Remove once tensorflow is 2.10.0 everywhere.
Expand Down
18 changes: 18 additions & 0 deletions jax/experimental/jax2tf/tests/tf_test_util.py
Expand Up @@ -14,6 +14,7 @@

import contextlib
import dataclasses
import functools
import re
import os

Expand All @@ -34,6 +35,7 @@
import numpy as np
import tensorflow as tf # type: ignore[import]
from tensorflow.compiler.xla import xla_data_pb2 # type: ignore[import]
from tensorflow.compiler.tf2xla.python import xla as tfxla # type: ignore[import]

DType = Any

Expand Down Expand Up @@ -153,6 +155,9 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
@jtu.with_config(jax_numpy_rank_promotion="allow",
jax_numpy_dtype_promotion='standard')
class JaxToTfTestCase(jtu.JaxTestCase):
# We want most tests to use the maximum available version, from the locally
# installed tfxla module.
use_max_serialization_version = True

def setUp(self):
super().setUp()
Expand All @@ -167,6 +172,19 @@ def setUp(self):
self.assertEqual(jtu.device_under_test().upper(),
self.tf_default_device.device_type)

# We run the tests using the maximum version supported, even though
# the default serialization version may be held back for a while to
# ensure compatibility
version = config.jax_serialization_version
self.addCleanup(functools.partial(config.update,
"jax_serialization_version", version))
if self.use_max_serialization_version:
max_version = tfxla.call_module_maximum_supported_version()
self.assertLessEqual(version, max_version)
version = max_version
config.update("jax_serialization_version", max_version)
logging.info("Using JAX serialization version %s", version)

with contextlib.ExitStack() as stack:
stack.enter_context(tf.device(self.tf_default_device))
self.addCleanup(stack.pop_all().close)
Expand Down

0 comments on commit 603eeb1

Please sign in to comment.