Skip to content

Commit

Permalink
Bump minimum jaxlib version to 0.3.10
Browse files (browse the repository at this point in the history)
  • Loading branch information
sharadmv committed Jun 28, 2022
1 parent 10320cb commit fcf65ac
Show file tree
Hide file tree
Showing 10 changed files with 20 additions and 96 deletions.
38 changes: 7 additions & 31 deletions jax/_src/lax/lax.py
Expand Up @@ -1653,11 +1653,7 @@ def _tan_impl(x):

tan_p = standard_unop(_float | _complex, 'tan')
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
if jax._src.lib.mlir_api_version >= 11:
mlir.register_lowering(tan_p, partial(_nary_lower_mhlo, chlo.TanOp))
else:
mlir.register_lowering(tan_p,
mlir.lower_fun(_tan_impl, multiple_results=False))
mlir.register_lowering(tan_p, partial(_nary_lower_mhlo, chlo.TanOp))

def asin_impl(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):
Expand Down Expand Up @@ -1713,35 +1709,21 @@ def atan_impl(x):

cosh_p = standard_unop(_float | _complex, 'cosh')
ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x)))
if jax._src.lib.mlir_api_version >= 10:
mlir.register_lowering(cosh_p, partial(_nary_lower_mhlo, chlo.CoshOp))
else:
xla.register_translation(cosh_p, standard_translate(cosh_p))
if jax._src.lib.mlir_api_version >= 8:
mlir.register_lowering(cosh_p, partial(_nary_lower_mhlo, chlo.CoshOp))
mlir.register_lowering(cosh_p, partial(_nary_lower_mhlo, chlo.CoshOp))

asinh_p = standard_unop(_float | _complex, 'asinh')
ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x))))
if jax._src.lib.mlir_api_version >= 10:
mlir.register_lowering(asinh_p, partial(_nary_lower_mhlo, chlo.AsinhOp))
else:
xla.register_translation(asinh_p, standard_translate(asinh_p))
mlir.register_lowering(asinh_p, partial(_nary_lower_mhlo, chlo.AsinhOp))

acosh_p = standard_unop(_float | _complex, 'acosh')
ad.defjvp(acosh_p,
lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x)))))
if jax._src.lib.mlir_api_version >= 10:
mlir.register_lowering(acosh_p, partial(_nary_lower_mhlo, chlo.AcoshOp))
else:
xla.register_translation(acosh_p, standard_translate(acosh_p))
mlir.register_lowering(acosh_p, partial(_nary_lower_mhlo, chlo.AcoshOp))

atanh_p = standard_unop(_float | _complex, 'atanh')
ad.defjvp(atanh_p,
lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x))))
if jax._src.lib.mlir_api_version >= 10:
mlir.register_lowering(atanh_p, partial(_nary_lower_mhlo, chlo.AtanhOp))
else:
xla.register_translation(atanh_p, standard_translate(atanh_p))
mlir.register_lowering(atanh_p, partial(_nary_lower_mhlo, chlo.AtanhOp))

regularized_incomplete_beta_p = standard_naryop(
[_float, _float, _float], 'regularized_incomplete_beta')
Expand Down Expand Up @@ -1816,18 +1798,12 @@ def _bessel_i1e_jvp(g, y, x):
erf_p = standard_unop(_float, 'erf')
ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)),
mul(g, exp(neg(square(x))))))
if jax._src.lib.mlir_api_version >= 12:
mlir.register_lowering(erf_p, partial(_nary_lower_mhlo, chlo.ErfOp))
else:
xla.register_translation(erf_p, standard_translate(erf_p))
mlir.register_lowering(erf_p, partial(_nary_lower_mhlo, chlo.ErfOp))

erfc_p = standard_unop(_float, 'erfc')
ad.defjvp(erfc_p, lambda g, x: mul(_const(x, -2. / np.sqrt(np.pi)),
mul(g, exp(neg(square(x))))))
if jax._src.lib.mlir_api_version >= 12:
mlir.register_lowering(erfc_p, partial(_nary_lower_mhlo, chlo.ErfcOp))
else:
xla.register_translation(erfc_p, standard_translate(erfc_p))
mlir.register_lowering(erfc_p, partial(_nary_lower_mhlo, chlo.ErfcOp))

erf_inv_p = standard_unop(_float, 'erf_inv')
ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.),
Expand Down
10 changes: 2 additions & 8 deletions jax/_src/lax/slicing.py
Expand Up @@ -905,14 +905,8 @@ def _dynamic_slice_batching_rule(batched_args, batch_dims, *, slice_sizes):
batching.primitive_batchers[dynamic_slice_p] = _dynamic_slice_batching_rule

def _dynamic_slice_lower(ctx, x, *start_indices, slice_sizes):
if jax._src.lib.mlir_api_version < 13:
aval_out, = ctx.avals_out
return mhlo.DynamicSliceOp(mlir.aval_to_ir_type(aval_out), x,
start_indices,
mlir.dense_int_elements(slice_sizes)).results
else:
return mhlo.DynamicSliceOp(x, start_indices,
mlir.dense_int_elements(slice_sizes)).results
return mhlo.DynamicSliceOp(x, start_indices,
mlir.dense_int_elements(slice_sizes)).results

mlir.register_lowering(dynamic_slice_p, _dynamic_slice_lower)

Expand Down
10 changes: 2 additions & 8 deletions jax/_src/lax/windowed_reductions.py
Expand Up @@ -646,14 +646,8 @@ def _select_and_gather_add_lowering(
const = lambda dtype, x: mlir.ir_constant(np.array(x, dtype=dtype),
canonicalize_types=False)

if jax._src.lib.mlir_api_version >= 9:
def _broadcast(x, dims):
return mhlo.BroadcastOp(x, mlir.dense_int_elements(dims))
else:
def _broadcast(x, dims):
etype = ir.RankedTensorType(x.type).element_type
return mhlo.BroadcastOp(ir.RankedTensorType(dims, etype), x,
mlir.dense_int_elements(dims))
def _broadcast(x, dims):
return mhlo.BroadcastOp(x, mlir.dense_int_elements(dims))

if double_word_reduction:
# TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so
Expand Down
6 changes: 2 additions & 4 deletions jax/experimental/host_callback.py
Expand Up @@ -1011,8 +1011,7 @@ def _outside_call_translation_rule(ctx, avals_in, avals_out,
use_outfeed = _use_outfeed(ctx.platform)
# TODO(sharadmv): Delete non-outfeed path when jaxlib minimum version is
# bumped past 0.3.8.
assert use_outfeed or jaxlib.version < (0, 3, 8), (
'Should be using MLIR path for `CustomCall` lowering')
assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering'
send_infeed = use_outfeed and need_callback_results_on_device
generated_infeed = False # Keep track if we emitted an infeed op
if use_outfeed:
Expand Down Expand Up @@ -1198,8 +1197,7 @@ def wrapped_callback(*args):
f"identity = {identity}")
return results + [next_token, next_itoken]

if jaxlib.version >= (0, 3, 8):
mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu")
mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu")

def _outside_call_run_callback(
arrays, device, *,
Expand Down
17 changes: 3 additions & 14 deletions jax/interpreters/mlir.py
Expand Up @@ -208,7 +208,6 @@ def _numpy_array_constant(x: np.ndarray, canonicalize_types
if canonicalize_types:
x = np.asarray(x, dtypes.canonicalize_dtype(x.dtype))
element_type = dtype_to_ir_type(x.dtype)
ir_type = ir.RankedTensorType.get(x.shape, element_type)
shape = x.shape
if x.dtype == np.bool_:
nelems = x.size
Expand All @@ -222,15 +221,9 @@ def _numpy_array_constant(x: np.ndarray, canonicalize_types
x = np.ascontiguousarray(x)
attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape)
if jax._src.lib.mlir_api_version < 21:
if jax._src.lib.xla_extension_version >= 64:
return (mhlo.ConstOp(attr).result,)
else:
return (mhlo.ConstOp(ir_type, attr).result,)
return (mhlo.ConstOp(attr).result,)
else:
if jax._src.lib.xla_extension_version >= 64:
return (mhlo.ConstantOp(attr).result,)
else:
return (mhlo.ConstantOp(ir_type, attr).result,)
return (mhlo.ConstantOp(attr).result,)



Expand Down Expand Up @@ -1047,11 +1040,7 @@ def _named_call_lowering(ctx, *args, name, backend=None,
def full_like_aval(value, aval: core.ShapedArray) -> ir.Value:
"""Returns an IR constant shaped full of `value` shaped like `aval`."""
zero = ir_constant(np.array(value, aval.dtype))
if jax._src.lib.mlir_api_version < 9:
return mhlo.BroadcastOp(aval_to_ir_type(aval), zero,
dense_int_elements(aval.shape)).result
else:
return mhlo.BroadcastOp(zero, dense_int_elements(aval.shape)).result
return mhlo.BroadcastOp(zero, dense_int_elements(aval.shape)).result


def zeros_like_lowering(ctx, x):
Expand Down
2 changes: 1 addition & 1 deletion jax/version.py
Expand Up @@ -18,5 +18,5 @@ def _version_as_tuple(version_str):
__version__ = "0.3.15"
__version_info__ = _version_as_tuple(__version__)

_minimum_jaxlib_version = "0.3.7"
_minimum_jaxlib_version = "0.3.10"
_minimum_jaxlib_version_info = _version_as_tuple(_minimum_jaxlib_version)
12 changes: 3 additions & 9 deletions jaxlib/pocketfft.py
Expand Up @@ -149,15 +149,9 @@ def pocketfft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]):
ir.RankedTensorType.get([], out_type),
ir.DenseElementsAttr.get(
np.array(0, dtype=out_dtype), type=out_type))
if jax._src.lib.mlir_api_version < 9:
return mhlo.BroadcastOp(
ir.RankedTensorType.get(out_shape, out_type),
zero,
ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
else:
return mhlo.BroadcastOp(
zero,
ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result
return mhlo.BroadcastOp(
zero,
ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result

u8_type = ir.IntegerType.get_unsigned(8)
if xla_client._version >= 64:
Expand Down
6 changes: 0 additions & 6 deletions tests/debugging_primitives_test.py
Expand Up @@ -401,11 +401,5 @@ def f(x):
lines = [f"{i}\n" for i in range(40)]
self._assertLinesEqual(output(), "".join(lines))

if jaxlib.version < (0, 3, 8):
# No lowering for `emit_python_callback` in older jaxlibs.
del DebugPrintTest
del DebugPrintControlFlowTest
del DebugPrintParallelTest

if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
12 changes: 0 additions & 12 deletions tests/jaxpr_effects_test.py
Expand Up @@ -578,11 +578,8 @@ class EffectOrderingTest(jtu.JaxTestCase):

@jtu.skip_on_devices(*disabled_backends)
def test_can_execute_python_callback(self):
# TODO(sharadmv): remove jaxlib check when minimum version is bumped
# TODO(sharadmv): enable this test on GPU and TPU when backends are
# supported
if jaxlib.version < (0, 3, 8):
raise unittest.SkipTest("`emit_python_callback` only supported in jaxlib >= 0.3.8")
log = []
def log_value(x):
log.append(x)
Expand All @@ -600,11 +597,8 @@ def f(x):

@jtu.skip_on_devices(*disabled_backends)
def test_ordered_effect_remains_ordered_across_multiple_devices(self):
# TODO(sharadmv): remove jaxlib check when minimum version is bumped
# TODO(sharadmv): enable this test on GPU and TPU when backends are
# supported
if jaxlib.version < (0, 3, 8):
raise unittest.SkipTest("`emit_python_callback` only supported in jaxlib >= 0.3.8")
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")
log = []
Expand Down Expand Up @@ -636,11 +630,8 @@ def g(x):

@jtu.skip_on_devices(*disabled_backends)
def test_different_threads_get_different_tokens(self):
# TODO(sharadmv): remove jaxlib check when minimum version is bumped
# TODO(sharadmv): enable this test on GPU and TPU when backends are
# supported
if jaxlib.version < (0, 3, 8):
raise unittest.SkipTest("`emit_python_callback` only supported in jaxlib >= 0.3.8")
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")
tokens = []
Expand Down Expand Up @@ -693,11 +684,8 @@ def f(x):

@jtu.skip_on_devices(*disabled_backends)
def test_can_pmap_unordered_callback(self):
# TODO(sharadmv): remove jaxlib check when minimum version is bumped
# TODO(sharadmv): enable this test on GPU and TPU when backends are
# supported
if jaxlib.version < (0, 3, 8):
raise unittest.SkipTest("`emit_python_callback` only supported in jaxlib >= 0.3.8")
if jax.device_count() < 2:
raise unittest.SkipTest("Test requires >= 2 devices.")
log = set()
Expand Down
3 changes: 0 additions & 3 deletions tests/linalg_test.py
Expand Up @@ -41,8 +41,6 @@
float_types = jtu.dtypes.floating
complex_types = jtu.dtypes.complex

jaxlib_version = jax._src.lib.version


class NumpyLinalgTest(jtu.JaxTestCase):

Expand Down Expand Up @@ -739,7 +737,6 @@ def compare_orthogonal(q1, q2):
qr = partial(jnp.linalg.qr, mode=mode)
jtu.check_jvp(qr, partial(jvp, qr), (a,), atol=3e-3)

@unittest.skipIf(jaxlib_version < (0, 3, 8), "test requires jaxlib>=0.3.8")
@jtu.skip_on_devices("tpu")
def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16):
# Regression test for https://github.com/google/jax/issues/10530
Expand Down

0 comments on commit fcf65ac

Please sign in to comment.