[jax2tf] Added error for attempting to use wrong jax_serialization_version

Previously, the serialization would use the specified serialization version
without checking if it is supported by the serializer.
This could result in invalid serializations.

Also add some compatibility tests for all supported versions.
gnecula committed Jul 25, 2023
1 parent 14a6089 commit 4081035
Showing 5 changed files with 93 additions and 33 deletions.
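As a quick illustration (not part of the commit itself; a sketch based on the check and error message added in jax_export.py and the tests below), requesting a serialization version outside the supported range now fails at export time instead of silently producing an invalid serialization:

import numpy as np
import jax.numpy as jnp
from jax import config
from jax.experimental.jax2tf import jax_export

# 5 is below minimum_supported_serialization_version (6), so export() raises.
config.update("jax_serialization_version", 5)
try:
  jax_export.export(jnp.sin)(
      jax_export.poly_spec((3, 4), np.float32, "w, h"))
except ValueError as e:
  print(e)  # "The requested jax_serialization version 5 is outside the range of supported versions [6..7]"
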
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/jax2tf.py
@@ -834,7 +834,7 @@ def _convert_value(val, aval):
kept_args_avals = [aval for i, aval in enumerate(exported.in_avals) if i in exported.module_kept_var_idx]
kept_args_flat_tf = [atf for i, atf in enumerate(args_flat_tf) if i in exported.module_kept_var_idx]

version = exported.xla_call_module_version
version = exported.serialization_version
call_module_attrs = dict(
version=version,
Tout=out_types,
35 changes: 25 additions & 10 deletions jax/experimental/jax2tf/jax_export.py
@@ -25,7 +25,6 @@
from typing import Any, Callable, Optional, Union

from absl import logging
import numpy as np

import jax
from jax import config
@@ -113,6 +112,9 @@ def __hash__(self) -> int:
return hash(self._impl)


minimum_supported_serialization_version = 6
maximum_supported_serialization_version = 7

@dataclasses.dataclass(frozen=True)
class Exported:
"""A JAX function lowered to StableHLO.
@@ -135,14 +137,15 @@ class Exported:
out_shardings: the flattened output shardings, as long as `in_avals`.
lowering_platform: one of 'tpu', 'cpu', 'cuda', 'rocm'
mlir_module_serialized: the serialized lowered VHLO module.
xla_call_module_version: a version number for the serialized module.
serialization_version: a version number for the serialized module.
See more versioning details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions.
module_kept_var_idx: the sorted indices of the arguments among `in_avals` that
must be passed to the module. The other arguments have been dropped
because they are not used. Same length as `in_shardings`.
module_uses_dim_vars: whether the `mlir_module_serialized` uses shape
polymorphic dimension variables. This may be from `in_avals` but also
from inner calls of shape-polymorphic Exported modules.
uses_shape_polymorphism: whether the `mlir_module_serialized` uses shape
polymorphism. This may be because `in_avals` contains dimension
variables, but also from inner calls of shape-polymorphic
Exported modules.
disabled_checks: a list of descriptors of safety checks that have been
disabled at export time. See docstring for `DisabledSafetyCheck`.
_get_vjp: an optional function that takes the current exported function and
@@ -164,9 +167,9 @@ class Exported:
disabled_checks: Sequence[DisabledSafetyCheck]

mlir_module_serialized: bytes
xla_call_module_version: int
serialization_version: int
module_kept_var_idx: tuple[int, ...]
module_uses_dim_vars: bool
uses_shape_polymorphism: bool

_get_vjp: Optional[Callable[["Exported"], "Exported"]]

@@ -311,6 +314,12 @@ def f_jax(*args, **kwargs): ...
"""
fun_name = getattr(fun_jax, "__name__", "unknown")
version = config.jax_serialization_version
if (version < minimum_supported_serialization_version or
version > maximum_supported_serialization_version):
raise ValueError(
f"The requested jax_serialization version {version} is outside the "
f"range of supported versions [{minimum_supported_serialization_version}"
f"..{maximum_supported_serialization_version}]")

def do_export(*args_specs, **kwargs_specs) -> Exported:
if not hasattr(fun_jax, "lower"):
@@ -395,8 +404,8 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
disabled_checks=tuple(disabled_checks),
mlir_module_serialized=mlir_module_serialized,
module_kept_var_idx=module_kept_var_idx,
module_uses_dim_vars=shape_poly_state.uses_dim_vars,
xla_call_module_version=version, # type: ignore
uses_shape_polymorphism=shape_poly_state.uses_dim_vars,
serialization_version=version, # type: ignore
_get_vjp=lambda exported: _export_native_vjp(fun_jax, exported))

return do_export
@@ -826,6 +835,12 @@ def flattened_primal_fun_jax(*args_flat):
### Importing

def call_exported(exported: Exported) -> Callable[..., jax.Array]:

if (exported.serialization_version >= 7 and
exported.uses_shape_polymorphism):
if xla_client.mlir_api_version < 52:
raise NotImplementedError(
"Current jaxlib does not support shape polymorphism with serialization version >= 7")
@jax.custom_vjp
def f_flat(*args_flat):
return call_exported_p.bind(*args_flat, exported=exported)
@@ -945,7 +960,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
f"The exported function '{exported.fun_name}' was lowered for "
f"platform '{exported.lowering_platform}' but it is used "
f"on '{platform}'.")
if exported.module_uses_dim_vars:
if exported.uses_shape_polymorphism:
ctx.module_context.shape_poly_state.uses_dim_vars = True

submodule = ir.Module.parse(exported.mlir_module)
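For context, a minimal sketch (mirroring the new test_shape_poly_basic_versions added below, and assuming a recent enough jaxlib, i.e. xla_client.mlir_api_version >= 52, per the new check in call_exported) of a shape-polymorphic round trip at the maximum supported version:

import numpy as np
import jax.numpy as jnp
from jax import config
from jax.experimental.jax2tf import jax_export

config.update("jax_serialization_version",
              jax_export.maximum_supported_serialization_version)  # currently 7
exp = jax_export.export(jnp.sin)(
    jax_export.poly_spec((3, 4), np.float32, "w, h"))
assert exp.serialization_version == jax_export.maximum_supported_serialization_version
assert exp.uses_shape_polymorphism  # "w, h" introduces dimension variables
x = np.arange(30, dtype=np.float32).reshape((5, 6))
np.testing.assert_allclose(jax_export.call_exported(exp)(x), np.sin(x),
                           rtol=1e-6, atol=1e-6)
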
8 changes: 4 additions & 4 deletions jax/experimental/jax2tf/tests/back_compat_test_util.py
@@ -292,7 +292,7 @@ def serialize(self,

module_str = str(exported.mlir_module)
serialized = exported.mlir_module_serialized
module_version = exported.xla_call_module_version
module_version = exported.serialization_version
return serialized, module_str, module_version

def run_serialized(self, data: CompatTestData,
@@ -323,10 +323,10 @@ def _get_vjp(_):
lowering_platform=data.platform,
disabled_checks=(),
mlir_module_serialized=data.mlir_module_serialized,
xla_call_module_version=data.xla_call_module_version,
serialization_version=data.xla_call_module_version,
module_kept_var_idx=tuple(range(len(in_avals))),
module_uses_dim_vars=any(not core.is_constant_shape(a.shape)
for a in in_avals),
uses_shape_polymorphism=any(not core.is_constant_shape(a.shape)
for a in in_avals),
_get_vjp=_get_vjp)

# We use pjit in case there are shardings in the exported module.
59 changes: 50 additions & 9 deletions jax/experimental/jax2tf/tests/jax_export_test.py
@@ -13,6 +13,8 @@
# limitations under the License.
import contextlib
import math
import functools
import logging
import re
from typing import Optional
import unittest
@@ -39,6 +41,23 @@

class JaxExportTest(jtu.JaxTestCase):

def override_serialization_version(self, version_override: int):
version = config.jax_serialization_version
if version != version_override:
self.addCleanup(functools.partial(config.update,
"jax_serialization_version",
version_override))
config.update("jax_serialization_version", version_override)
logging.info(
"Using JAX serialization version %s",
config.jax_serialization_version)

def setUp(self):
super().setUp()
# Run tests with the maximum supported version by default
self.override_serialization_version(
jax_export.maximum_supported_serialization_version)

def test_basic_export_only(self):
def my_fun(x):
return jnp.sin(x)
@@ -250,7 +269,6 @@ def f1_exp(a, b): # For VJP, make a function without kwargs
exp_vjp = jax.vjp(f1_exp, a, b)[1](out_ct)
self.assertAllClose(jax_vjp, exp_vjp)


def test_roundtrip(self):
def f1(x):
return jnp.sin(x)
@@ -265,6 +283,29 @@ def f2(x):
self.assertAllClose(jnp.cos(jnp.sin(jnp.sin(a))),
jax_export.call_exported(exp_f2)(a))

@jtu.parameterized_filterable(
kwargs=[
dict(v=v)
for v in range(jax_export.minimum_supported_serialization_version - 1,
jax_export.maximum_supported_serialization_version + 2)])
def test_shape_poly_basic_versions(self, v: int):
self.override_serialization_version(v)
with contextlib.ExitStack() as e:
if not (jax_export.minimum_supported_serialization_version <= v
<= jax_export.maximum_supported_serialization_version):
e.enter_context(self.assertRaisesRegex(
ValueError,
f"The requested jax_serialization version {v} is outside the range of supported versions"))

if (xc.mlir_api_version <= 51 and
config.jax_serialization_version >= 7):
raise unittest.SkipTest("Not supported in old jaxlib")
exp = jax_export.export(jnp.sin)(
jax_export.poly_spec((3, 4), np.float32, "w, h"))
x = np.arange(30, dtype=np.float32).reshape((5, 6))
res = jax_export.call_exported(exp)(x)
self.assertAllClose(res, np.sin(x))

# A function is exported with f32[poly_spec] and is called with different arg
# shapes. We use jax_export.call_exported and we also run the shape check
# module.
@@ -299,6 +340,8 @@ def test_poly_shape_checks(
arg_shape=(3, 4, 12), arg_dtype=np.float32,
expect_error=None): # If given, error from running the exported module

if xc.mlir_api_version <= 51:
raise unittest.SkipTest("Not supported in old jaxlib")
def f(x): # x: f32[poly_spec]
return jnp.reshape(x, (-1, x.shape[1]))

@@ -308,7 +351,7 @@ def f(x): # x: f32[poly_spec]
disabled_checks = ()
exp_f = jax_export.export(f, disabled_checks=disabled_checks)(
jax_export.poly_spec((3, 4, 12), np.float32, poly_spec))
self.assertEqual(exp_f.module_uses_dim_vars, poly_spec != "3,4,12")
self.assertEqual(exp_f.uses_shape_polymorphism, poly_spec != "3,4,12")
arg = np.arange(np.prod(arg_shape),
dtype=arg_dtype).reshape(arg_shape) # arg : f32[3,4,12]

@@ -404,18 +447,16 @@ def test_poly_shape_checks_nested(
expect_error_run=None):
# Polymorphic export called with static or polymorphic shapes
if xc.mlir_api_version <= 51:
disabled_checks = (jax_export.DisabledSafetyCheck.shape_assertions(),)
else:
disabled_checks = ()
raise unittest.SkipTest("Not supported in old jaxlib")
def inner(x): # x: inner_poly_spec
return jnp.reshape(x, (-1, x.shape[1]))

arg = np.arange(np.prod(arg_shape),
dtype=arg_dtype).reshape(arg_shape) # x : f32[3,4,12]
inner_exp = jax_export.export(inner, disabled_checks=disabled_checks)(
inner_exp = jax_export.export(inner)(
jax_export.poly_spec((3, 4, 12), np.float32, inner_poly_spec))

self.assertEqual(inner_exp.module_uses_dim_vars,
self.assertEqual(inner_exp.uses_shape_polymorphism,
(inner_poly_spec != "3,4,12"))
def outer(x): # x: outer_poly_spec
# Use an addition to test that the shapes are refined properly for the
@@ -427,13 +468,13 @@ def outer(x): # x: outer_poly_spec
stack.push(self.assertRaisesRegex(ValueError, expect_error_outer_exp))

# Call it after exporting again, with polymorphic shapes
outer_exp = jax_export.export(outer, disabled_checks=disabled_checks)(
outer_exp = jax_export.export(outer)(
jax_export.poly_spec(arg.shape, arg.dtype, outer_poly_spec))

if expect_error_outer_exp is not None:
return

self.assertEqual(outer_exp.module_uses_dim_vars,
self.assertEqual(outer_exp.uses_shape_polymorphism,
(inner_poly_spec != "3,4,12" or outer_poly_spec != "3,4,12"))

with contextlib.ExitStack() as stack:
22 changes: 13 additions & 9 deletions jax/experimental/jax2tf/tests/tf_test_util.py
@@ -31,6 +31,7 @@

from jax import config
from jax.experimental import jax2tf
from jax.experimental.jax2tf import jax_export
from jax._src import xla_bridge
import numpy as np
import tensorflow as tf # type: ignore[import]
@@ -156,7 +157,7 @@ def ComputeTfValueAndGrad(tf_f: Callable, tf_args: Sequence,
jax_numpy_dtype_promotion='standard')
class JaxToTfTestCase(jtu.JaxTestCase):
# We want most tests to use the maximum available version, from the locally
# installed tfxla module.
# installed tfxla module and jax_export.
use_max_serialization_version = True

def setUp(self):
@@ -179,14 +180,17 @@ def setUp(self):
self.addCleanup(functools.partial(config.update,
"jax_serialization_version", version))
if self.use_max_serialization_version:
# The largest version we support is 7
max_version = min(7, tfxla.call_module_maximum_supported_version())
self.assertLessEqual(version, max_version)
version = max_version
config.update("jax_serialization_version", max_version)
logging.info("Using JAX serialization version %s%s",
version,
" (max_version)" if self.use_max_serialization_version else "")
# Use the largest supported by both jax_export and tfxla.call_module
version = min(jax_export.maximum_supported_serialization_version,
tfxla.call_module_maximum_supported_version())
self.assertGreaterEqual(version,
jax_export.minimum_supported_serialization_version)
config.update("jax_serialization_version", version)
logging.info(
"Using JAX serialization version %s (jax_export.max_version %s, tf.XlaCallModule max version %s)",
version,
jax_export.maximum_supported_serialization_version,
tfxla.call_module_maximum_supported_version())

with contextlib.ExitStack() as stack:
stack.enter_context(tf.device(self.tf_default_device))
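As a usage note (a sketch, not part of the diff; the hypothetical class NonMaxVersionTest and the import path follow the file layout shown above), a test class can opt out of forcing the negotiated maximum version via the existing use_max_serialization_version attribute and keep whatever jax_serialization_version is already configured:

from jax.experimental.jax2tf.tests import tf_test_util

class NonMaxVersionTest(tf_test_util.JaxToTfTestCase):
  # Keep the externally configured jax_serialization_version instead of the
  # maximum supported by both jax_export and tf.XlaCallModule.
  use_max_serialization_version = False
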
