Skip to content

Commit

Permalink
Add regression test for lax.rev simplification error
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 476430486
  • Loading branch information
Jake VanderPlas authored and jax authors committed Sep 23, 2022
1 parent ecb27a9 commit a6b24b3
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3438,6 +3438,15 @@ def test_gather(self):
self.assertIsInstance(ys, FooArray)
self.assertEqual(ys.shape, (3, 2, 1))

def test_xla_reverse_bug(self):
# Regression test for b/248295786
# This was an XLA bug related to an incorrect optimization of reverse
def f(x):
y = jnp.array([2, 5])
return lax.rev(x * y, (0,))
x = jnp.array([1, 2])
self.assertArraysEqual(f(x), jax.jit(f)(x))

# TODO(frostig,mattjj): more polymorphic primitives tests

if __name__ == '__main__':
Expand Down

0 comments on commit a6b24b3

Please sign in to comment.