Skip to content

Commit

Permalink
Readd a default lowering rule for cumsum et al.
Browse files Browse the repository at this point in the history
A previous change removed the only non-constrained lowering rule, breaking lowering for platforms without explicit lowering rules

PiperOrigin-RevId: 633297839
  • Loading branch information
hawkinsp authored and jax authors committed May 13, 2024
1 parent 1e48adc commit 72a81e5
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jax/_src/lax/control_flow/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2449,14 +2449,14 @@ def register_lowering(fn, platform=None):

if xla_extension_version >= 263:
# In XLA, there's a rewriter for an O(N^2) reduce-window implementation.
# TODO(https://github.com/llvm/llvm-project/issues/91883): enable rewrite
# for CPU once the vectorizer crash is fixed..
for platform in ['cuda', 'rocm', 'tpu']:
register_lowering(
partial(cumred_reduce_window_impl, reduce_window_fn), platform
)

# TODO(https://github.com/llvm/llvm-project/issues/91883) Re-enable rewrite
# for CPU once the vectorizer crash is fixed..
register_lowering(partial(associative_scan, reduce_fn), 'cpu')
register_lowering(partial(associative_scan, reduce_fn))
else:
# Older XLA versions only have this rewrite for TPU.
register_lowering(
Expand Down

0 comments on commit 72a81e5

Please sign in to comment.