Skip to content

Commit

Permalink
[MHLO] Add MHLO lowerings for name_p and unreachable_p.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 441746096
  • Loading branch information
hawkinsp authored and jax authors committed Apr 14, 2022
1 parent e187428 commit 8980bc4
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import partial_eval as pe
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.tree_util import tree_flatten, tree_unflatten
from jax._src import ad_util
Expand Down Expand Up @@ -394,6 +395,7 @@ def name_jvp(primals, tangents, *, name):

xla.register_translation(name_p,
lambda ctx, avals_in, avals_out, x, *, name: [x])
mlir.register_lowering(name_p, lambda ctx, x, *, name: [x])

def name_batcher(args, dims, *, name):
(x,), (d,) = args, dims
Expand Down
1 change: 1 addition & 0 deletions jax/_src/custom_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,7 @@ def unreachable_impl(*_, out_avals, exc_type, message):
# that errs. Since translation over-approximates concrete evaluation,
# we err on translation for the time being.
xla.register_translation(unreachable_p, unreachable_impl)
mlir.register_lowering(unreachable_p, unreachable_impl)

# Abstract evaluation proceeds without issue, to allow for staging
unreachable_p.def_abstract_eval(lambda *_, out_avals, **__: out_avals)
Expand Down

0 comments on commit 8980bc4

Please sign in to comment.