Merge pull request #14594 from gnecula:tf_platform_check
PiperOrigin-RevId: 511409386
jax authors committed Feb 22, 2023
2 parents 1d4b7a3 + b8f96f0 commit 7e001d8
Showing 2 changed files with 50 additions and 26 deletions.
65 changes: 43 additions & 22 deletions jax/experimental/jax2tf/jax2tf.py
@@ -35,7 +35,6 @@
from jax.experimental import maps
from jax.experimental.jax2tf import shape_poly
from jax.experimental.jax2tf import impl_no_xla
from jax.interpreters import mlir
from jax.interpreters import pxla
from jax.interpreters import xla

@@ -55,6 +54,7 @@
from jax._src import util
from jax._src.global_device_array import GlobalDeviceArray
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import lax as lax_internal
from jax._src.lax import linalg as lax_linalg
@@ -213,7 +213,8 @@ def convert(fun_jax: Callable,
polymorphic_shapes=None,
with_gradient=True,
enable_xla=True,
experimental_native_lowering="default") -> Callable:
experimental_native_lowering="default",
experimental_native_lowering_strict_checks=False) -> Callable:
"""Lowers `fun_jax` into a function that uses only TensorFlow ops.
See
@@ -273,6 +274,10 @@ def convert(fun_jax: Callable,
function and aborts if this is not possible.
experimental_native_lowering: DO NOT USE, for experimental purposes only.
The value "default" defers to --jax2tf_default_experimental_native_lowering.
experimental_native_lowering_strict_checks: DO NOT USE, for experimental purposes only.
In conjunction with `experimental_native_lowering`, enable the following
checks: the lowered computation is executed on a platform for which it
was lowered (more to come).
Returns:
A version of `fun_jax` that expects TfVals as arguments (or
@@ -358,11 +363,13 @@ def converted_fun_tf(*args_tf: TfVal, **kwargs_tf: TfVal) -> TfVal:

@tf.custom_gradient
def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:
outs_tf, out_avals = _interpret_fun_jax(fun_flat_jax,
args_flat_tf, args_avals_flat,
name_stack,
fresh_constant_cache=True,
experimental_native_lowering=experimental_native_lowering)
outs_tf, out_avals = _interpret_fun_jax(
fun_flat_jax,
args_flat_tf, args_avals_flat,
name_stack,
fresh_constant_cache=True,
experimental_native_lowering=experimental_native_lowering,
experimental_native_lowering_strict_checks=experimental_native_lowering_strict_checks)
return (tuple(outs_tf),
make_custom_gradient_fn_tf(
fun_flat_jax=fun_flat_jax,
@@ -373,11 +380,13 @@ def converted_fun_flat_with_custom_gradient_tf(*args_flat_tf: TfVal) -> TfVal:

out_flat_tf = converted_fun_flat_with_custom_gradient_tf(*args_flat_tf)
else:
outs_tf, out_avals = _interpret_fun_jax(fun_flat_jax,
args_flat_tf, args_avals_flat,
name_stack,
fresh_constant_cache=True,
experimental_native_lowering=experimental_native_lowering)
outs_tf, out_avals = _interpret_fun_jax(
fun_flat_jax,
args_flat_tf, args_avals_flat,
name_stack,
fresh_constant_cache=True,
experimental_native_lowering=experimental_native_lowering,
experimental_native_lowering_strict_checks=experimental_native_lowering_strict_checks)
message = ("The jax2tf-converted function does not support gradients. "
"Use `with_gradient` parameter to enable gradients")
# We use PreventGradient, which is propagated through a SavedModel.
@@ -566,11 +575,14 @@ def _interpret_fun_jax(
args_avals: Sequence[core.ShapedArray],
extra_name_stack: Optional[str],
fresh_constant_cache: bool = False,
experimental_native_lowering: bool = False
experimental_native_lowering: bool = False,
experimental_native_lowering_strict_checks: bool = True,
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
if experimental_native_lowering:
del extra_name_stack
return _lower_native_and_run(fun_jax, args_avals, args_tf)
return _lower_native_and_run(
fun_jax, args_avals, args_tf,
experimental_native_lowering_strict_checks=experimental_native_lowering_strict_checks)
else:
with core.new_base_main(TensorFlowTrace) as main: # type: ignore
subtrace_fun = _interpret_subtrace(lu.wrap_init(fun_jax), main, args_avals)
@@ -587,6 +599,8 @@ def _lower_native_and_run(fun_jax: Callable,
def _lower_native_and_run(fun_jax: Callable,
args_avals: Sequence[core.ShapedArray],
args_tf: Sequence[TfVal],
*,
experimental_native_lowering_strict_checks: bool,
) -> Tuple[Tuple[TfVal, ...], Tuple[core.ShapedArray, ...]]:
"""Lowers the function using native lowering and then invokes it.
@@ -637,7 +651,7 @@ def _lower_native_and_run(fun_jax: Callable,
lowered = fun_jax_lower(*arg_specs_jax)._lowering
if config.jax2tf_use_stablehlo:
mlir_module = lowered.stablehlo()
xla_call_module_version = 2
xla_call_module_version = 3
else:
mlir_module = lowered.mhlo()
xla_call_module_version = 1
@@ -729,18 +743,25 @@ def _out_type(jax_type):
args_tf = tuple(
map(_shard_value, args_tf, args_avals, lowered.compile_args["in_shardings"]))

if logging.vlog_is_on(3):
mlir_module_text = mlir.module_to_string(mlir_module)
logging.vlog(3, "XlaCallModule (version=%d, dim_args_spec=%s)\n%s",
xla_call_module_version, ", ".join(dim_args_spec),
mlir_module_text)
res = tfxla.call_module(
args_tf,
call_module_attrs = dict(
version=xla_call_module_version,
module=mlir_serialized_module,
Tout=out_types,
Sout=out_shapes,
dim_args_spec=dim_args_spec)
log_msg = f"version={xla_call_module_version} dim_args_spec=" + ", ".join(dim_args_spec)
if xla_call_module_version == 3:
if experimental_native_lowering_strict_checks:
call_module_attrs["platforms"] = (jax.default_backend().upper(),)
else:
call_module_attrs["platforms"] = () # No platform checking
log_msg += " platforms=" + ", ".join(call_module_attrs["platforms"]) # type: ignore
if logging.vlog_is_on(3):
mlir_module_text = mlir.module_to_string(mlir_module)
logging.vlog(3, "XlaCallModule (%s)\n%s",
log_msg,
mlir_module_text)
res = tfxla.call_module(args_tf, **call_module_attrs)
if "out_shardings" in lowered.compile_args:
res = list(map(_shard_value, res, out_avals, lowered.compile_args["out_shardings"]))

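A minimal usage sketch of the `experimental_native_lowering_strict_checks` parameter added above. The toy function, the explicit `experimental_native_lowering=True` value, and the assumption that StableHLO lowering (XlaCallModule version 3) is in effect are illustrative only, not taken from this diff:

    import jax.numpy as jnp
    import tensorflow as tf
    from jax.experimental import jax2tf

    def f(x):
      return jnp.sin(x) * 2.0

    # With native (StableHLO) lowering and strict checks on, the lowering
    # platform (jax.default_backend().upper(), e.g. "CPU" or "TPU") is recorded
    # in the XlaCallModule `platforms` attribute; with the flag off, the
    # attribute is left empty and no platform check is performed.
    f_tf = jax2tf.convert(
        f,
        experimental_native_lowering=True,
        experimental_native_lowering_strict_checks=True)

    print(f_tf(tf.constant(1.0)))

If the converted function is later executed on a platform other than the recorded one, the version-3 XlaCallModule op is meant to reject it; that is the platform check this change wires through.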
11 changes: 7 additions & 4 deletions jax/experimental/jax2tf/tests/primitives_test.py
@@ -293,10 +293,13 @@ def test_integer_div(self):
y = np.int32(3)
self.ConvertAndCompare(jnp.floor_divide, x, y)
expected = jnp.floor_divide(x, y)
# Try it with TF 1 as well (#5831)
with tf.compat.v1.Session() as sess:
tf1_res = sess.run(jax2tf.convert(jnp.floor_divide)(x, y))
self.assertAllClose(expected, tf1_res)
if not config.jax2tf_default_experimental_native_lowering:
# With native lowering TF1 seems to want to run the converted code
# on the CPU even when the default backend is the TPU.
# Try it with TF 1 as well (#5831)
with tf.compat.v1.Session() as sess:
tf1_res = sess.run(jax2tf.convert(jnp.floor_divide)(x, y))
self.assertAllClose(expected, tf1_res)

def test_boolean_gather(self):
values = np.array([[True, True], [False, True], [False, False]],
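A hedged sketch of the situation behind the skipped TF1 check above: when the module may run on a platform other than the one it was lowered for (the comment notes TF1 placing it on CPU even when the default backend is TPU), the strict check can be left off so that the empty `platforms` list disables platform checking. The inputs and flag values are illustrative and assume the TF1 graph-mode path otherwise works:

    import numpy as np
    import tensorflow as tf
    import jax.numpy as jnp
    from jax.experimental import jax2tf

    x = np.int32(6)
    y = np.int32(3)

    # Strict checks disabled: the emitted XlaCallModule carries platforms=(),
    # so the new platform check does not reject execution on a platform other
    # than the one the module was lowered for.
    div_tf = jax2tf.convert(
        jnp.floor_divide,
        experimental_native_lowering=True,
        experimental_native_lowering_strict_checks=False)

    with tf.compat.v1.Session() as sess:
      res = sess.run(div_tf(x, y))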
