Skip to content

Commit

Permalink
Remove type restrictions
Browse files Browse the repository at this point in the history
We support s8, u8, s16, u16, half16 on TPU
  • Loading branch information
majnemer committed Aug 10, 2020
1 parent eb086f9 commit 6bab85b
Showing 1 changed file with 6 additions and 29 deletions.
35 changes: 6 additions & 29 deletions tests/random_test.py
Expand Up @@ -126,9 +126,6 @@ def testRngRandomBitsViewProperty(self):
N = 10
key = random.PRNGKey(1701)
nbits = [8, 16, 32]
if jtu.device_under_test() == "tpu":
# U8 and U16 are not supported on TPU.
nbits = [32]
rand_bits = [random._random_bits(key, n, (N * 64 // n,)) for n in nbits]
rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
assert np.all(rand_bits_32 == rand_bits_32[0])
Expand All @@ -137,15 +134,13 @@ def testRngRandomBits(self):
# Test specific outputs to ensure consistent random values between JAX versions.
key = random.PRNGKey(1701)

# U8 and U16 are not supported on TPU.
if jtu.device_under_test() != "tpu":
bits8 = random._random_bits(key, 8, (3,))
expected8 = np.array([216, 115, 43], dtype=np.uint8)
self.assertArraysEqual(bits8, expected8)
bits8 = random._random_bits(key, 8, (3,))
expected8 = np.array([216, 115, 43], dtype=np.uint8)
self.assertArraysEqual(bits8, expected8)

bits16 = random._random_bits(key, 16, (3,))
expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
self.assertArraysEqual(bits16, expected16)
bits16 = random._random_bits(key, 16, (3,))
expected16 = np.array([41682, 1300, 55017], dtype=np.uint16)
self.assertArraysEqual(bits16, expected16)

bits32 = random._random_bits(key, 32, (3,))
expected32 = np.array([56197195, 4200222568, 961309823], dtype=np.uint32)
Expand All @@ -163,8 +158,6 @@ def testRngRandomBits(self):
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
for dtype in float_dtypes))
def testRngUniform(self, dtype):
if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
raise SkipTest("random.uniform() not supported on TPU for 16-bit types.")
key = random.PRNGKey(0)
rand = lambda key: random.uniform(key, (10000,), dtype)
crand = api.jit(rand)
Expand All @@ -180,8 +173,6 @@ def testRngUniform(self, dtype):
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
for dtype in int_dtypes + uint_dtypes))
def testRngRandint(self, dtype):
if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
raise SkipTest("random.randint() not supported on TPU for 8- or 16-bit types.")
lo = 5
hi = 10

Expand All @@ -200,8 +191,6 @@ def testRngRandint(self, dtype):
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
for dtype in [np.float16, np.float32, np.float64]))
def testNormal(self, dtype):
if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
raise SkipTest("random.normal() not supported on TPU for 16-bit types.")
key = random.PRNGKey(0)
rand = lambda key: random.normal(key, (10000,), dtype)
crand = api.jit(rand)
Expand Down Expand Up @@ -380,8 +369,6 @@ def testBeta(self, a, b, dtype):
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
for dtype in [np.float16, np.float32, np.float64]))
def testCauchy(self, dtype):
if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
raise SkipTest("random.cauchy() not supported on TPU for 16-bit types.")
key = random.PRNGKey(0)
rand = lambda key: random.cauchy(key, (10000,), dtype)
crand = api.jit(rand)
Expand Down Expand Up @@ -418,8 +405,6 @@ def testDirichlet(self, alpha, dtype):
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
for dtype in float_dtypes))
def testExponential(self, dtype):
if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
raise SkipTest("random.exponential() not supported on TPU for 16-bit types.")
key = random.PRNGKey(0)
rand = lambda key: random.exponential(key, (10000,), dtype)
crand = api.jit(rand)
Expand Down Expand Up @@ -484,8 +469,6 @@ def testGammaGradType(self):
for lam in [0.5, 3, 9, 11, 50, 500]
for dtype in [np.int16, np.int32, np.int64]))
def testPoisson(self, lam, dtype):
if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
raise SkipTest("random.poisson() not supported on TPU for 16-bit types.")
key = random.PRNGKey(0)
rand = lambda key, lam: random.poisson(key, lam, (10000,), dtype)
crand = api.jit(rand)
Expand Down Expand Up @@ -530,8 +513,6 @@ def testGumbel(self, dtype):
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
for dtype in float_dtypes))
def testLaplace(self, dtype):
if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
raise SkipTest("random.laplace() not supported on TPU for 16-bit types.")
key = random.PRNGKey(0)
rand = lambda key: random.laplace(key, (10000,), dtype)
crand = api.jit(rand)
Expand All @@ -546,8 +527,6 @@ def testLaplace(self, dtype):
{"testcase_name": "_dtype={}".format(np.dtype(dtype).name), "dtype": dtype}
for dtype in float_dtypes))
def testLogistic(self, dtype):
if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
raise SkipTest("random.logistic() not supported on TPU for 16-bit types.")
key = random.PRNGKey(0)
rand = lambda key: random.logistic(key, (10000,), dtype)
crand = api.jit(rand)
Expand Down Expand Up @@ -602,8 +581,6 @@ def testT(self, df, dtype):
for dim in [1, 3, 5]
for dtype in float_dtypes))
def testMultivariateNormal(self, dim, dtype):
if jtu.device_under_test() == "tpu" and jnp.dtype(dtype).itemsize < 3:
raise SkipTest("random.multivariate_normal() not supported on TPU for 16-bit types.")
r = np.random.RandomState(dim)
mean = r.randn(dim)
cov_factor = r.randn(dim, dim)
Expand Down

0 comments on commit 6bab85b

Please sign in to comment.