[jax2tf] Raise errors for experimental_native_lowering and custom_call
Raise an explicit error when experimental_native_lowering encounters an
mhlo.custom_call, which would otherwise fail only later when trying to run in TF.
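
As an illustration only (not part of this commit), a minimal sketch of how the new error surfaces. It assumes jax2tf.convert accepts an experimental_native_lowering parameter mirroring the jax2tf_default_experimental_native_lowering flag used in the tests, and it uses jnp.linalg.qr, which lowers to a LAPACK custom call on CPU:

import jax.numpy as jnp
from jax.experimental import jax2tf

def f(x):
  return jnp.linalg.qr(x)  # lowers to an mhlo.custom_call on CPU

# Hypothetical usage: the parameter name here mirrors the config flag.
f_tf = jax2tf.convert(f, experimental_native_lowering=True)
# On CPU, tracing the converted function now raises NotImplementedError
# naming the custom call targets, instead of producing a module that
# fails only later when run in TF.
f_tf(jnp.eye(3, dtype=jnp.float32))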
gnecula committed Jul 21, 2022
1 parent 07fcf79 commit 6c9d2a0
Showing 3 changed files with 45 additions and 5 deletions.
27 changes: 27 additions & 0 deletions jax/experimental/jax2tf/jax2tf.py
@@ -52,6 +52,7 @@
from jax._src.lax import linalg as lax_linalg
from jax._src.lax import slicing as lax_slicing
from jax._src.lax import windowed_reductions as lax_windowed_reductions
from jax._src import lib as jaxlib
from jax._src.lib import xla_client

from jax.experimental.jax2tf import shape_poly
@@ -609,6 +610,24 @@ def _lower_native(fun: lu.WrappedFun, in_vals: Sequence[TfVal],

  mhlo_module = lowered.mhlo()
  mhlo_module_text = mlir.module_to_string(mhlo_module)
  if jaxlib.version <= (0, 3, 14):
    mhlo_module_text = _fixup_mhlo_module_text(mhlo_module_text)
  # We do not support custom_call; raise an explicit error for now.
  if "mhlo.custom_call" in mhlo_module_text:
    # Try to give a nice error message. We could just dump the module...
    msg = ("experimental_native_lowering does not work with custom calls. "
           "Most likely you are lowering on CPU or GPU a JAX program that "
           "uses custom calls on those platforms. The serialization should "
           "work on TPU.")
    custom_calls = re.findall(
        r'mhlo.custom_call.*call_target_name\s+=\s+"([^"]+)".*loc\(([^\)]+)\)',
        mhlo_module_text)
    for cc in custom_calls:
      msg += f"\n{cc[0]}"
      # Find the line that defines this location
      m = re.search('^' + cc[1] + ' =.*', mhlo_module_text, re.MULTILINE)
      if m:
        msg += f"\n from line {m.group(0)}"
    raise NotImplementedError(msg)
  logging.vlog(2, f"XlaCallModule {mhlo_module_text}")

  # Figure out the result types and shapes
@@ -645,6 +664,14 @@ def _convert_res(res_val, res_jax_type):
for res_val, out_aval in zip(res, out_avals))
return zip(res, out_avals)

def _fixup_mhlo_module_text(mhlo_module_text: str) -> str:
  # A workaround for MHLO not (yet) having backwards compatibility: jaxlib
  # 0.3.14 and older use an old serialization method that puts quotes around
  # the contents of MHLO attributes (#mhlo<"...">), which the newer parser
  # does not accept. We fix that up here, temporarily.
  return re.sub(r'#mhlo<"([^"]+)">', r"#mhlo<\1>", mhlo_module_text)
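
A quick illustration of the substitution this helper performs (a sketch; the attribute text is a made-up example, and the helper is reached the same way the test change below reaches it):

from jax.experimental import jax2tf

text = 'transpose_a = #mhlo<"transpose NO_TRANSPOSE">'
print(jax2tf.jax2tf._fixup_mhlo_module_text(text))
# transpose_a = #mhlo<transpose NO_TRANSPOSE>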


def _call_wrapped_with_new_constant_cache(fun: lu.WrappedFun,
in_vals: Sequence[TfVal],
4 changes: 3 additions & 1 deletion jax/experimental/jax2tf/tests/jax2tf_test.py
@@ -36,6 +36,7 @@
from jax.experimental.jax2tf.tests import tf_test_util
import jax.interpreters.mlir as mlir
from jax._src import source_info_util
from jax._src import lib as jaxlib
import jax._src.lib.xla_bridge

import numpy as np
@@ -1245,6 +1246,8 @@ def get_serialized_computation(
  lowered = jax.jit(f_jax, abstracted_axes=abstracted_axes).lower(*args)
  mhlo_module = lowered.compiler_ir(dialect='mhlo')
  mhlo_module_text = mlir.module_to_string(mhlo_module)
  if jaxlib.version <= (0, 3, 14):
    mhlo_module_text = jax2tf.jax2tf._fixup_mhlo_module_text(mhlo_module_text)
  logging.info(f'Serialized ir.Module = {mhlo_module_text}')
  return mhlo_module_text

@@ -1266,7 +1269,6 @@ def f_jax(x):
    self.assertAllClose(tf.nest.map_structure(lambda t: t.numpy(), res),
                        [jax_res])

  @unittest.skip("TODO(necula): Cannot deserialize MHLO computation")
  def test_while(self):
    # With nested computation
    def f_jax(count, x):
19 changes: 15 additions & 4 deletions jax/experimental/jax2tf/tests/primitives_test.py
@@ -56,6 +56,7 @@
from typing import Any, Dict, Tuple
import unittest

from absl import logging
from absl.testing import absltest
from absl.testing import parameterized

@@ -101,7 +102,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
  @primitive_harness.parameterized(
      primitive_harness.all_harnesses,
      include_jax_unimpl=False,
      #one_containing="conv_general_dilated_dtype_precision_lhs=float16[2,3,9,10]_rhs=float16[3,3,4,5]_windowstrides=(1,1)_padding=((0,0),(0,0))_lhsdilation=(1,1)_rhsdilation=(1,1)_dimensionnumbers=('NCHW','OIHW','NCHW')_featuregroupcount=1_batchgroupcount=1_precision=DEFAULT_preferred=float64_enablexla=True"
      #one_containing="custom_linear_solve_"
  )
  @jtu.ignore_warning(
      category=UserWarning, message="Using reduced precision for gradient.*")
@@ -113,10 +114,17 @@ def test_prim(self, harness: primitive_harness.Harness):
    func_jax = harness.dyn_fun
    args = harness.dyn_args_maker(self.rng())
    enable_xla = harness.params.get("enable_xla", True)
    if config.jax2tf_default_experimental_native_lowering and not enable_xla:
      return
    associative_scan_reductions = harness.params.get("associative_scan_reductions", False)
    try:
      with jax.jax2tf_associative_scan_reductions(associative_scan_reductions):
        self.ConvertAndCompare(func_jax, *args, limitations=limitations,
                               enable_xla=enable_xla)
    except Exception as e:
      if (config.jax2tf_default_experimental_native_lowering and
          "does not work with custom calls" in str(e)):
        logging.warning("Suppressing error %s", e)
      else:
        raise

  def test_primitive_coverage(self):
    """Fail if there are JAX primitives that are not implemented."""
@@ -139,6 +147,9 @@ def test_primitive_coverage(self):
    for p in all_primitives:
      if p.name == "axis_index":
        continue
      # TODO: remove once we delete sharded_jit.py
      if p.name in ["sharded_call", "sharding_constraint"]:
        continue
      # TODO: Remove once tensorflow is 2.10.0 everywhere.
      if p.name == "optimization_barrier":
        continue
