diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py
index c1dd3a4f4188..c11a000c8179 100644
--- a/jax/_src/lax/control_flow.py
+++ b/jax/_src/lax/control_flow.py
@@ -2107,14 +2107,14 @@ def scan_bind(*args, **params):
 xla.register_translation(scan_p, xla.lower_fun(_scan_impl, new_style=True,
                                                multiple_results=True),
                          initial_style=True)
+mlir.register_lowering(scan_p,
+                       mlir.lower_fun(_scan_impl, multiple_results=True))
 
 batching.axis_primitive_batchers[scan_p] = _scan_batching_rule
 masking.masking_rules[scan_p] = _scan_masking_rule
 core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
 pe.partial_eval_jaxpr_custom_rules[scan_p] = \
     partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan')
-mlir.register_lowering(scan_p,
-                       mlir.lower_fun(_scan_impl, multiple_results=True))
 
 
 @api_boundary
@@ -2667,6 +2667,9 @@ def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims,
     linear_solve_p, xla.lower_fun(_custom_linear_solve_impl, new_style=True,
                                   multiple_results=True),
     initial_style=True)
+mlir.register_lowering(
+    linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl,
+                                   multiple_results=True))
 ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
 batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
 pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = \
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index d9c227bcb31a..e69fa01b77ca 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -1674,17 +1674,18 @@ def _round_lower(ctx, x, *, rounding_method):
 ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
 mlir.register_lowering(cos_p, partial(_nary_lower_mhlo, mhlo.CosOp))
 
-@partial(xla.lower_fun, multiple_results=False, new_style=True)
 @_upcast_fp16_for_computation
-def tan_translation_rule(x):
+def _tan_impl(x):
   return div(sin(x), cos(x))
 
-tan_p = standard_unop(_float | _complex, 'tan',
-                      translation_rule=tan_translation_rule)
+tan_p = standard_unop(
+    _float | _complex, 'tan',
+    translation_rule=xla.lower_fun(_tan_impl, multiple_results=False,
+                                   new_style=True))
 ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
+mlir.register_lowering(tan_p, mlir.lower_fun(_tan_impl, multiple_results=False))
 
-
-def asin_translation_rule(x):
+def asin_impl(x):
   if dtypes.issubdtype(_dtype(x), np.complexfloating):
     return mul(_const(x, -1j), asinh(mul(_const(x, 1j), x)))
   else:
@@ -1692,13 +1693,14 @@ def asin_translation_rule(x):
             atan2(x, add(_const(x, 1), sqrt(sub(_const(x, 1), square(x))))))
 
 asin_p = standard_unop(_float | _complex, 'asin',
-                       translation_rule=xla.lower_fun(asin_translation_rule,
+                       translation_rule=xla.lower_fun(asin_impl,
                                                       multiple_results=False,
                                                       new_style=True))
 ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
+mlir.register_lowering(asin_p, mlir.lower_fun(asin_impl,
+                                              multiple_results=False))
 
-
-def acos_translation_rule(x):
+def acos_impl(x):
   if dtypes.issubdtype(_dtype(x), np.complexfloating):
     result = mul(_const(x, 1j), acosh(x))
     # By convention, numpy chooses the branch with positive real part.
@@ -1716,19 +1718,23 @@ def acos_translation_rule(x):
                     full_like(x, np.pi))
 
 acos_p = standard_unop(_float | _complex, 'acos',
-                       translation_rule=xla.lower_fun(acos_translation_rule,
+                       translation_rule=xla.lower_fun(acos_impl,
                                                       multiple_results=False,
                                                       new_style=True))
 ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x))))
+mlir.register_lowering(acos_p,
+                       mlir.lower_fun(acos_impl, multiple_results=False))
 
-def atan_translation_rule(x):
+def atan_impl(x):
   return atan2(x, _const(x, 1))
 
 atan_p = standard_unop(_float | _complex, 'atan',
-                       translation_rule=xla.lower_fun(atan_translation_rule,
+                       translation_rule=xla.lower_fun(atan_impl,
                                                       multiple_results=False,
                                                       new_style=True))
 ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))
+mlir.register_lowering(atan_p, mlir.lower_fun(atan_impl,
+                                              multiple_results=False))
 
 atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
 ad.defjvp(atan2_p,
@@ -2660,6 +2666,9 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
                             dot_dnums,
                             precision_attr(precision)).result]
 
 mlir.register_lowering(dot_general_p, _dot_general_lower)
+# Explicitly register a CPU lowering so we don't fall back to the XLA lowering
+# on CPU.
+mlir.register_lowering(dot_general_p, _dot_general_lower, platform="cpu")
 
 def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
diff --git a/jax/_src/random.py b/jax/_src/random.py
index 6818284dd306..a5b288a17dbd 100644
--- a/jax/_src/random.py
+++ b/jax/_src/random.py
@@ -33,6 +33,7 @@ from jax.numpy.linalg import cholesky, svd, eigh
 
 from jax.interpreters import ad
 from jax.interpreters import batching
+from jax.interpreters import mlir
 from jax.interpreters import xla
 
 from jax._src.util import prod, canonicalize_axis
@@ -1018,6 +1019,12 @@ def _gamma_batching_rule(batched_args, batch_dims, *, prng_impl, log_space):
 xla.register_translation(random_gamma_p, xla.lower_fun(
     partial(_gamma_impl, use_vmap=False),
     multiple_results=False, new_style=True), platform='cpu')
+mlir.register_lowering(random_gamma_p, mlir.lower_fun(
+    partial(_gamma_impl, use_vmap=True),
+    multiple_results=False))
+mlir.register_lowering(random_gamma_p, mlir.lower_fun(
+    partial(_gamma_impl, use_vmap=False),
+    multiple_results=False), platform='cpu')
 batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule
 
 def gamma(key: KeyArray,
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 38cb7a16fab8..955bae971596 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -1118,11 +1118,11 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray],
 tf_impl[lax.cos_p] = tf.math.cos
 tf_impl[lax.cosh_p] = tf.math.cosh
 tf_impl_with_avals[lax.acos_p] = _convert_jax_impl(
-    lax_internal.acos_translation_rule, multiple_results=False)
+    lax_internal.acos_impl, multiple_results=False)
 tf_impl_with_avals[lax.asin_p] = _convert_jax_impl(
-    lax_internal.asin_translation_rule, multiple_results=False)
+    lax_internal.asin_impl, multiple_results=False)
 tf_impl_with_avals[lax.atan_p] = _convert_jax_impl(
-    lax_internal.atan_translation_rule, multiple_results=False)
+    lax_internal.atan_impl, multiple_results=False)
 
 def _atan2(y, x, **kwargs):
   if x.dtype.is_complex or y.dtype.is_complex:
diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py
index 8fd60cff1e87..d0926a11be90 100644
--- a/jax/interpreters/mlir.py
+++ b/jax/interpreters/mlir.py
@@ -721,10 +721,11 @@ def write(v: core.Var, node: Sequence[ir.Value]):
     with source_info_util.user_context(eqn.source_info.traceback), loc:
       if eqn.primitive in _platform_specific_lowerings[ctx.platform]:
         rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
+      elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
+        rule = xla_fallback_lowering(eqn.primitive)
      elif eqn.primitive in _lowerings:
        rule = _lowerings[eqn.primitive]
-      elif (eqn.primitive in xla._translations or
-            eqn.primitive in xla._backend_specific_translations[ctx.platform]):
+      elif eqn.primitive in xla._translations:
        rule = xla_fallback_lowering(eqn.primitive)
      else:
        raise NotImplementedError(
@@ -741,7 +742,7 @@ def write(v: core.Var, node: Sequence[ir.Value]):
         out_nodes = tuple(map(wrap_singleton_ir_values, ans))
       except TypeError as e:
         raise ValueError("Output of translation rule must be iterable: "
-                         f"{eqn}") from e
+                         f"{eqn}, got output {ans}") from e
 
       assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
       assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (ans, eqn)
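
For reference, a minimal sketch of the registration pattern this patch applies throughout: mlir.lower_fun wraps a Python implementation written in terms of existing JAX ops and lowers it by tracing, and mlir.register_lowering installs it either generically or per platform. The primitive sinc_p and helper _sinc_impl below are hypothetical stand-ins for illustration only; they are not part of the change.

import jax.numpy as jnp
from jax import core
from jax.interpreters import mlir

# Hypothetical primitive, used only to illustrate the pattern.
sinc_p = core.Primitive('sinc')

def _sinc_impl(x):
  # Implemented with existing JAX ops so that mlir.lower_fun can trace it and
  # emit MHLO for those ops; no hand-written lowering rule is needed.
  return jnp.sin(x) / x

# Generic MLIR lowering, used on every platform unless overridden.
mlir.register_lowering(sinc_p,
                       mlir.lower_fun(_sinc_impl, multiple_results=False))

# A platform-specific registration takes precedence over the generic rule,
# mirroring the CPU-only registrations for random_gamma_p and dot_general_p
# in the patch above.
mlir.register_lowering(sinc_p,
                       mlir.lower_fun(_sinc_impl, multiple_results=False),
                       platform='cpu')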