diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index 2bd41c5a81ab..0791b9507410 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from functools import partial import math from unittest import SkipTest @@ -45,7 +44,6 @@ int_dtypes = jtu.dtypes.all_integer uint_dtypes = jtu.dtypes.all_unsigned -KEY_CTORS = [random.key, random.PRNGKey] @jtu.with_config(jax_legacy_prng_key='allow') class LaxRandomTest(jtu.JaxTestCase): @@ -1166,14 +1164,6 @@ def testLogNormal(self, sigma, dtype): for samples in [uncompiled_samples, compiled_samples]: self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.lognorm(s=sigma).cdf) - @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) - def test_copy(self, make_key): - key = make_key(8459302) - self.assertArraysEqual(key, key.copy()) - self.assertArraysEqual(key, copy.copy(key)) - self.assertArraysEqual(key, copy.deepcopy(key)) - self.assertArraysEqual(key, jax.jit(lambda k: k.copy())(key)) - threefry_seed = prng_internal.threefry_seed threefry_split = prng_internal.threefry_split diff --git a/tests/random_test.py b/tests/random_test.py index 4345a31de3b7..8d58bdd5bde1 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import enum from functools import partial import math @@ -585,6 +586,10 @@ class KeyArrayTest(jtu.JaxTestCase): # might also be a more general test of opaque element types. If # so, add a corresponding test to to CustomElementTypesTest as well. + def assertKeysEqual(self, key1, key2): + self.assertEqual(key1.dtype, key2.dtype) + self.assertArraysEqual(random.key_data(key1), random.key_data(key2)) + def test_construction(self): key = random.key(42) self.assertIsInstance(key, jax_random.PRNGKeyArray) @@ -662,6 +667,13 @@ def test_key_attributes(self): self.assertEqual(key.size, math.prod(key.shape)) self.assertEqual(key.ndim, len(key.shape)) + def test_key_copy(self): + key = self.make_keys() + self.assertKeysEqual(key, key.copy()) + self.assertKeysEqual(key, copy.copy(key)) + self.assertKeysEqual(key, copy.deepcopy(key)) + self.assertKeysEqual(key, jax.jit(lambda k: k.copy())(key)) + def test_isinstance(self): @jax.jit def f(k): @@ -916,19 +928,19 @@ def test_device_put(self): device = jax.devices()[0] keys = self.make_keys(4) keys_on_device = jax.device_put(keys, device) - self.assertArraysEqual(keys, keys_on_device) + self.assertKeysEqual(keys, keys_on_device) def test_device_put_sharded(self): devices = jax.devices() keys = self.make_keys(len(devices)) keys_on_device = jax.device_put_sharded(list(keys), devices) - self.assertArraysEqual(keys, keys_on_device) + self.assertKeysEqual(keys, keys_on_device) def test_device_put_replicated(self): devices = jax.devices() key = self.make_keys() keys_on_device = jax.device_put_replicated(key, devices) - self.assertArraysEqual(jnp.broadcast_to(key, keys_on_device.shape), keys_on_device) + self.assertKeysEqual(jnp.broadcast_to(key, keys_on_device.shape), keys_on_device) def test_make_array_from_callback(self): devices = jax.devices() @@ -940,7 +952,7 @@ def callback(index): return jax.vmap(random.key)(i) result = jax.make_array_from_callback(shape, sharding, callback) expected = jax.vmap(random.key)(jnp.arange(len(devices))) - self.assertArraysEqual(result, expected) + self.assertKeysEqual(result, expected) def test_make_array_from_single_device_arrays(self): devices = jax.devices() @@ -950,7 +962,7 @@ def test_make_array_from_single_device_arrays(self): keys = random.split(random.key(0), len(devices)) arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)] result = jax.make_array_from_single_device_arrays(shape, sharding, arrays) - self.assertArraysEqual(result, keys) + self.assertKeysEqual(result, keys) def test_key_array_custom_jvp(self): def f_raw(x, key): @@ -1034,7 +1046,7 @@ def test_delete(self): def test_async(self): key = self.make_keys(10) - self.assertArraysEqual(key, key.block_until_ready()) + self.assertKeysEqual(key, key.block_until_ready()) self.assertIsNone(key.copy_to_host_async()) # -- key construction and un/wrapping with impls @@ -1071,7 +1083,7 @@ def test_key_make_like_other_key(self, prng_name): k1 = jax.random.key(42, impl=prng_name) impl = jax.random.key_impl(k1) k2 = jax.random.key(42, impl=impl) - self.assertArraysEqual(k1, k2) + self.assertKeysEqual(k1, k2) self.assertEqual(k1.dtype, k2.dtype) @jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS]) @@ -1082,7 +1094,7 @@ def test_key_wrap_like_other_key(self, prng_name): data = jax.random.key_data(k1) impl = jax.random.key_impl(k1) k2 = jax.random.wrap_key_data(data, impl=impl) - self.assertArraysEqual(k1, k2) + self.assertKeysEqual(k1, k2) self.assertEqual(k1.dtype, k2.dtype) def test_key_impl_from_string_error(self): @@ -1136,6 +1148,10 @@ def _double_threefry_fold_in(key, data): class JnpWithKeyArrayTest(jtu.JaxTestCase): + def assertKeysEqual(self, key1, key2): + self.assertEqual(key1.dtype, key2.dtype) + self.assertArraysEqual(random.key_data(key1), random.key_data(key2)) + def check_shape(self, func, *args): like = lambda keys: jnp.ones(keys.shape) out_key = func(*args) @@ -1257,15 +1273,15 @@ def test_stack(self): def test_array(self): key = random.key(123) - self.assertArraysEqual(key, jnp.array(key)) - self.assertArraysEqual(key, jnp.asarray(key)) - self.assertArraysEqual(key, jax.jit(jnp.array)(key)) - self.assertArraysEqual(key, jax.jit(jnp.asarray)(key)) + self.assertKeysEqual(key, jnp.array(key)) + self.assertKeysEqual(key, jnp.asarray(key)) + self.assertKeysEqual(key, jax.jit(jnp.array)(key)) + self.assertKeysEqual(key, jax.jit(jnp.asarray)(key)) def test_array_user_dtype(self): key = random.key(123) - self.assertArraysEqual(key, jnp.array(key, dtype=key.dtype)) - self.assertArraysEqual(key, jnp.asarray(key, dtype=key.dtype)) + self.assertKeysEqual(key, jnp.array(key, dtype=key.dtype)) + self.assertKeysEqual(key, jnp.asarray(key, dtype=key.dtype)) @parameterized.parameters([ (0,),