Skip to content

Commit

Permalink
Remove code that skips array PRNG tests on CUDA.
Browse files Browse the repository at this point in the history
#13037 added this skip, but I have no idea why. The test seems to pass on GPU.

PiperOrigin-RevId: 568216977
  • Loading branch information
hawkinsp authored and jax authors committed Sep 25, 2023
1 parent 3a9289e commit d3f5e7f
Showing 1 changed file with 0 additions and 2 deletions.
2 changes: 0 additions & 2 deletions tests/array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1125,7 +1125,6 @@ class RngShardingTest(jtu.JaxTestCase):
# tests that the PRNGs are automatically sharded as expected

@parameterized.named_parameters(("3", 3), ("4", 4), ("5", 5))
@jtu.skip_on_devices("gpu")
def test_random_bits_is_pure_map_1d(self, num_devices):
@jax.jit
def f(x):
Expand Down Expand Up @@ -1159,7 +1158,6 @@ def f(x):
"mesh_shape": mesh_shape, "pspec": pspec}
for mesh_shape in [(3, 2), (4, 2), (2, 3)]
for pspec in [P('x', None), P(None, 'y'), P('x', 'y')])
@jtu.skip_on_devices("gpu")
def test_random_bits_is_pure_map_2d(self, mesh_shape, pspec):
@jax.jit
def f(x):
Expand Down

0 comments on commit d3f5e7f

Please sign in to comment.