Skip to content

Commit

Permalink
Add dynamic slice U8 index test
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 23, 2021
1 parent 92e8ec4 commit 27fc797
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions tests/lax_test.py
Expand Up @@ -2424,6 +2424,17 @@ def f(inputs):
self.assertArraysEqual(np.full((1, 30), np.float32(42)),
f(np.zeros((1, 24), dtype=np.float32)))

def testDynamicSliceU8Index(self):
# Regression test for u8 index in dynamic-slice (#6122)
# TODO(b/183216273): enable this test for CPU & GPU when possible.
if jtu.device_under_test() == "cpu":
raise unittest.SkipTest("DynamicSliceU8Index test is a known failure on CPU.")
if jtu.device_under_test() == "gpu":
raise unittest.SkipTest("DynamicSliceU8Index test is a known failure on GPU.")
x = np.arange(200)
np.testing.assert_equal(
np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128])


class LazyConstantTest(jtu.JaxTestCase):
def _Check(self, make_const, expected):
Expand Down

0 comments on commit 27fc797

Please sign in to comment.