[MHLO] Prefer backend-specific HLO lowerings instead of non-backend-specific MHLO lowerings.

This allows us, in subsequent changes, to switch the generic case for translating a primitive to MHLO even when we cannot yet use an MHLO lowering for a backend-specific case.

Add a handful of direct MLIR lowerings for primitives that lacked them.

PiperOrigin-RevId: 439912093
hawkinsp authored and jax authors committed Apr 6, 2022
1 parent 4012267 commit b9bb613
Showing 5 changed files with 40 additions and 20 deletions.
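
For context, the registrations touched throughout this diff follow one pattern, sketched below using the internal APIs that appear in the changed files (xla.register_translation, xla.lower_fun, mlir.register_lowering, mlir.lower_fun). The primitive example_p and its implementation _example_impl are hypothetical names used only for illustration; a real primitive also needs abstract evaluation and other rules that are omitted here.

from jax import core, lax
from jax.interpreters import mlir, xla

example_p = core.Primitive('example')  # hypothetical primitive

def _example_impl(x):
  # Implementation in terms of existing lax ops; lower_fun traces this
  # function to produce a lowering for the primitive.
  return lax.mul(x, x)

# Old-style path: an XLA (HLO) translation rule derived from the impl.
xla.register_translation(
    example_p,
    xla.lower_fun(_example_impl, multiple_results=False, new_style=True))

# Generic MHLO lowering: used on any platform lacking a more specific rule.
mlir.register_lowering(
    example_p, mlir.lower_fun(_example_impl, multiple_results=False))

# Platform-specific MHLO lowering: with this change, backend-specific rules
# (MHLO or XLA) take precedence over the generic MHLO rule above.
mlir.register_lowering(
    example_p, mlir.lower_fun(_example_impl, multiple_results=False),
    platform='cpu')
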
7 changes: 5 additions & 2 deletions jax/_src/lax/control_flow.py
@@ -2107,14 +2107,14 @@ def scan_bind(*args, **params):
xla.register_translation(scan_p, xla.lower_fun(_scan_impl, new_style=True,
multiple_results=True),
initial_style=True)
mlir.register_lowering(scan_p,
mlir.lower_fun(_scan_impl, multiple_results=True))
batching.axis_primitive_batchers[scan_p] = _scan_batching_rule
masking.masking_rules[scan_p] = _scan_masking_rule
core.custom_typechecks[scan_p] = partial(_scan_typecheck, False)
pe.partial_eval_jaxpr_custom_rules[scan_p] = \
partial(pe.partial_eval_jaxpr_custom_rule_not_implemented, 'scan')

mlir.register_lowering(scan_p,
mlir.lower_fun(_scan_impl, multiple_results=True))


@api_boundary
@@ -2667,6 +2667,9 @@ def _linear_solve_batching_rule(axis_size, axis_name, main_type, args, dims,
linear_solve_p, xla.lower_fun(_custom_linear_solve_impl, new_style=True,
multiple_results=True),
initial_style=True)
mlir.register_lowering(
linear_solve_p, mlir.lower_fun(_custom_linear_solve_impl,
multiple_results=True))
ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
batching.axis_primitive_batchers[linear_solve_p] = _linear_solve_batching_rule
pe.partial_eval_jaxpr_custom_rules[linear_solve_p] = \
33 changes: 21 additions & 12 deletions jax/_src/lax/lax.py
@@ -1674,31 +1674,33 @@ def _round_lower(ctx, x, *, rounding_method):
ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x))))
mlir.register_lowering(cos_p, partial(_nary_lower_mhlo, mhlo.CosOp))

@partial(xla.lower_fun, multiple_results=False, new_style=True)
@_upcast_fp16_for_computation
def tan_translation_rule(x):
def _tan_impl(x):
return div(sin(x), cos(x))

tan_p = standard_unop(_float | _complex, 'tan',
translation_rule=tan_translation_rule)
tan_p = standard_unop(
_float | _complex, 'tan',
translation_rule=xla.lower_fun(_tan_impl, multiple_results=False,
new_style=True))
ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans)))
mlir.register_lowering(tan_p, mlir.lower_fun(_tan_impl, multiple_results=False))


def asin_translation_rule(x):
def asin_impl(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):
return mul(_const(x, -1j), asinh(mul(_const(x, 1j), x)))
else:
return mul(_const(x, 2),
atan2(x, add(_const(x, 1), sqrt(sub(_const(x, 1), square(x))))))

asin_p = standard_unop(_float | _complex, 'asin',
translation_rule=xla.lower_fun(asin_translation_rule,
translation_rule=xla.lower_fun(asin_impl,
multiple_results=False,
new_style=True))
ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x))))
mlir.register_lowering(asin_p, mlir.lower_fun(asin_impl,
multiple_results=False))


def acos_translation_rule(x):
def acos_impl(x):
if dtypes.issubdtype(_dtype(x), np.complexfloating):
result = mul(_const(x, 1j), acosh(x))
# By convention, numpy chooses the branch with positive real part.
@@ -1716,19 +1718,23 @@ def acos_translation_rule(x):
full_like(x, np.pi))

acos_p = standard_unop(_float | _complex, 'acos',
translation_rule=xla.lower_fun(acos_translation_rule,
translation_rule=xla.lower_fun(acos_impl,
multiple_results=False,
new_style=True))
ad.defjvp(acos_p, lambda g, x: mul(g, -rsqrt(_const(x, 1) - square(x))))
mlir.register_lowering(acos_p,
mlir.lower_fun(acos_impl, multiple_results=False))

def atan_translation_rule(x):
def atan_impl(x):
return atan2(x, _const(x, 1))

atan_p = standard_unop(_float | _complex, 'atan',
translation_rule=xla.lower_fun(atan_translation_rule,
translation_rule=xla.lower_fun(atan_impl,
multiple_results=False,
new_style=True))
ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x)))
mlir.register_lowering(atan_p, mlir.lower_fun(atan_impl,
multiple_results=False))

atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2')
ad.defjvp(atan2_p,
@@ -2660,6 +2666,9 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers,
dot_dnums, precision_attr(precision)).result]

mlir.register_lowering(dot_general_p, _dot_general_lower)
# Explicitly register a CPU lowering so we don't fall back to the XLA lowering
# on CPU.
mlir.register_lowering(dot_general_p, _dot_general_lower, platform="cpu")


def _broadcast_in_dim_shape_rule(operand, *, shape, broadcast_dimensions):
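
The explicit CPU registration for dot_general above is a direct consequence of the reordered lookup in jax/interpreters/mlir.py (the last file in this diff): backend-specific XLA translations now outrank generic MHLO lowerings, so a primitive that still has a CPU-specific XLA translation registered would revert to the XLA fallback on CPU unless its MHLO rule is also pinned for that platform. A minimal sketch of the idiom, with a hypothetical primitive p and hypothetical rule names:

# Generic MHLO lowering: applies on every platform with no more specific rule.
mlir.register_lowering(p, _p_mhlo_lowering)

# If an old CPU-specific XLA translation for p exists, e.g.
#   xla.register_translation(p, _p_cpu_translation, platform='cpu')
# it would now win on CPU, so register the MHLO rule for CPU explicitly too:
mlir.register_lowering(p, _p_mhlo_lowering, platform='cpu')
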
7 changes: 7 additions & 0 deletions jax/_src/random.py
@@ -33,6 +33,7 @@
from jax.numpy.linalg import cholesky, svd, eigh
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import xla
from jax._src.util import prod, canonicalize_axis

@@ -1018,6 +1019,12 @@ def _gamma_batching_rule(batched_args, batch_dims, *, prng_impl, log_space):
xla.register_translation(random_gamma_p, xla.lower_fun(
partial(_gamma_impl, use_vmap=False),
multiple_results=False, new_style=True), platform='cpu')
mlir.register_lowering(random_gamma_p, mlir.lower_fun(
partial(_gamma_impl, use_vmap=True),
multiple_results=False))
mlir.register_lowering(random_gamma_p, mlir.lower_fun(
partial(_gamma_impl, use_vmap=False),
multiple_results=False), platform='cpu')
batching.primitive_batchers[random_gamma_p] = _gamma_batching_rule

def gamma(key: KeyArray,
6 changes: 3 additions & 3 deletions jax/experimental/jax2tf/jax2tf.py
@@ -1118,11 +1118,11 @@ def _integer_pow(x, *, y: int, _in_avals: Sequence[core.ShapedArray],
tf_impl[lax.cos_p] = tf.math.cos
tf_impl[lax.cosh_p] = tf.math.cosh
tf_impl_with_avals[lax.acos_p] = _convert_jax_impl(
lax_internal.acos_translation_rule, multiple_results=False)
lax_internal.acos_impl, multiple_results=False)
tf_impl_with_avals[lax.asin_p] = _convert_jax_impl(
lax_internal.asin_translation_rule, multiple_results=False)
lax_internal.asin_impl, multiple_results=False)
tf_impl_with_avals[lax.atan_p] = _convert_jax_impl(
lax_internal.atan_translation_rule, multiple_results=False)
lax_internal.atan_impl, multiple_results=False)

def _atan2(y, x, **kwargs):
if x.dtype.is_complex or y.dtype.is_complex:
7 changes: 4 additions & 3 deletions jax/interpreters/mlir.py
@@ -721,10 +721,11 @@ def write(v: core.Var, node: Sequence[ir.Value]):
with source_info_util.user_context(eqn.source_info.traceback), loc:
if eqn.primitive in _platform_specific_lowerings[ctx.platform]:
rule = _platform_specific_lowerings[ctx.platform][eqn.primitive]
elif eqn.primitive in xla._backend_specific_translations[ctx.platform]:
rule = xla_fallback_lowering(eqn.primitive)
elif eqn.primitive in _lowerings:
rule = _lowerings[eqn.primitive]
elif (eqn.primitive in xla._translations or
eqn.primitive in xla._backend_specific_translations[ctx.platform]):
elif eqn.primitive in xla._translations:
rule = xla_fallback_lowering(eqn.primitive)
else:
raise NotImplementedError(
@@ -741,7 +742,7 @@ def write(v: core.Var, node: Sequence[ir.Value]):
out_nodes = tuple(map(wrap_singleton_ir_values, ans))
except TypeError as e:
raise ValueError("Output of translation rule must be iterable: "
f"{eqn}") from e
f"{eqn}, got output {ans}") from e

assert all(isinstance(v, tuple) for v in out_nodes), (ans, eqn)
assert all(isinstance(v, ir.Value) for w in out_nodes for v in w), (ans, eqn)
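
Condensing the mlir.py hunk above, the per-equation rule selection now proceeds in four tiers. The helper below is hypothetical (in the actual code this logic is inline in the jaxpr-lowering loop), but the names and the order match the diff:

def _find_lowering_rule(primitive, platform):
  # 1. Platform-specific MHLO lowering (mlir.register_lowering(..., platform=...)).
  if primitive in _platform_specific_lowerings[platform]:
    return _platform_specific_lowerings[platform][primitive]
  # 2. Platform-specific XLA translation, wrapped so it can be emitted from the
  #    MLIR path (xla_fallback_lowering). Moving this tier up is the behavior
  #    change: backend-specific HLO rules now beat generic MHLO rules.
  if primitive in xla._backend_specific_translations[platform]:
    return xla_fallback_lowering(primitive)
  # 3. Generic MHLO lowering (mlir.register_lowering with no platform argument).
  if primitive in _lowerings:
    return _lowerings[primitive]
  # 4. Generic XLA translation, again via the fallback path.
  if primitive in xla._translations:
    return xla_fallback_lowering(primitive)
  raise NotImplementedError(f"MLIR lowering rule for {primitive} not found")
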
