diff --git a/tests/random_test.py b/tests/random_test.py index 543172de18fd..79b314c96b9b 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -126,9 +126,6 @@ def testRngRandomBitsViewProperty(self): N = 10 key = random.PRNGKey(1701) nbits = [8, 16, 32] - if jtu.device_under_test() == "tpu": - # U8 and U16 are not supported on TPU. - nbits = [32] rand_bits = [random._random_bits(key, n, (N * 64 // n,)) for n in nbits] rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits]) assert np.all(rand_bits_32 == rand_bits_32[0]) @@ -137,15 +134,13 @@ def testRngRandomBits(self): # Test specific outputs to ensure consistent random values between JAX versions. key = random.PRNGKey(1701) - # U8 and U16 are not supported on TPU. - if jtu.device_under_test() != "tpu": - bits8 = random._random_bits(key, 8, (3,)) - expected8 = np.array([216, 115, 43], dtype=np.uint8) - self.assertArraysEqual(bits8, expected8) + bits8 = random._random_bits(key, 8, (3,)) + expected8 = np.array([216, 115, 43], dtype=np.uint8) + self.assertArraysEqual(bits8, expected8) - bits16 = random._random_bits(key, 16, (3,)) - expected16 = np.array([41682, 1300, 55017], dtype=np.uint16) - self.assertArraysEqual(bits16, expected16) + bits16 = random._random_bits(key, 16, (3,)) + expected16 = np.array([41682, 1300, 55017], dtype=np.uint16) + self.assertArraysEqual(bits16, expected16) bits32 = random._random_bits(key, 32, (3,)) expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32) @@ -163,8 +158,6 @@ def testRngRandomBits(self): {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype} for dtype in float_dtypes)) def testRngUniform(self, dtype): - if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3: - raise SkipTest("random.uniform() not supported on TPU for 16-bit types.") key = random.PRNGKey(0) rand = lambda key: random.uniform(key, (10000,), dtype) crand = api.jit(rand) @@ -180,8 +173,6 @@ def testRngUniform(self, dtype): {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype} for dtype in int_dtypes + uint_dtypes)) def testRngRandint(self, dtype): - if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3: - raise SkipTest("random.randint() not supported on TPU for 8- or 16-bit types.") lo = 5 hi = 10 @@ -200,8 +191,6 @@ def testRngRandint(self, dtype): {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype} for dtype in [np.float16, np.float32, np.float64])) def testNormal(self, dtype): - if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3: - raise SkipTest("random.normal() not supported on TPU for 16-bit types.") key = random.PRNGKey(0) rand = lambda key: random.normal(key, (10000,), dtype) crand = api.jit(rand) @@ -380,8 +369,6 @@ def testBeta(self, a, b, dtype): {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype} for dtype in [np.float16, np.float32, np.float64])) def testCauchy(self, dtype): - if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3: - raise SkipTest("random.cauchy() not supported on TPU for 16-bit types.") key = random.PRNGKey(0) rand = lambda key: random.cauchy(key, (10000,), dtype) crand = api.jit(rand) @@ -418,8 +405,6 @@ def testDirichlet(self, alpha, dtype): {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype} for dtype in float_dtypes)) def testExponential(self, dtype): - if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3: - raise SkipTest("random.exponential() not supported on TPU for 16-bit types.") key = random.PRNGKey(0) rand = lambda key: random.exponential(key, (10000,), dtype) crand = api.jit(rand) @@ -484,8 +469,6 @@ def testGammaGradType(self): for lam in [0.5, 3, 9, 11, 50, 500] for dtype in [np.int16, np.int32, np.int64])) def testPoisson(self, lam, dtype): - if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3: - raise SkipTest("random.poisson() not supported on TPU for 16-bit types.") key = random.PRNGKey(0) rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype) crand = api.jit(rand) @@ -530,8 +513,6 @@ def testGumbel(self, dtype): {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype} for dtype in float_dtypes)) def testLaplace(self, dtype): - if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3: - raise SkipTest("random.laplace() not supported on TPU for 16-bit types.") key = random.PRNGKey(0) rand = lambda key: random.laplace(key, (10000,), dtype) crand = api.jit(rand) @@ -546,8 +527,6 @@ def testLaplace(self, dtype): {"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype} for dtype in float_dtypes)) def testLogistic(self, dtype): - if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3: - raise SkipTest("random.logistic() not supported on TPU for 16-bit types.") key = random.PRNGKey(0) rand = lambda key: random.logistic(key, (10000,), dtype) crand = api.jit(rand) @@ -602,8 +581,6 @@ def testT(self, df, dtype): for dim in [1, 3, 5] for dtype in float_dtypes)) def testMultivariateNormal(self, dim, dtype): - if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3: - raise SkipTest("random.multivariate_normal() not supported on TPU for 16-bit types.") r = np.random.RandomState(dim) mean = r.randn(dim) cov_factor = r.randn(dim, dim)