Skip to content

Commit

Permalink
random_test: fix deprecation warnings for key tests
Browse files Browse the repository at this point in the history
Some versions of numpy on some platforms raise warnings when custom PRNG keys
are passed to np.assert_array_equal. Address this by creating a specific function
for comparing key values.
  • Loading branch information
jakevdp committed Oct 6, 2023
1 parent f0e4ea2 commit b407620
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 24 deletions.
10 changes: 0 additions & 10 deletions tests/random_lax_test.py
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from functools import partial
import math
from unittest import SkipTest
Expand Down Expand Up @@ -45,7 +44,6 @@
int_dtypes = jtu.dtypes.all_integer
uint_dtypes = jtu.dtypes.all_unsigned

KEY_CTORS = [random.key, random.PRNGKey]

@jtu.with_config(jax_legacy_prng_key='allow')
class LaxRandomTest(jtu.JaxTestCase):
Expand Down Expand Up @@ -1166,14 +1164,6 @@ def testLogNormal(self, sigma, dtype):
for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.lognorm(s=sigma).cdf)

@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_copy(self, make_key):
key = make_key(8459302)
self.assertArraysEqual(key, key.copy())
self.assertArraysEqual(key, copy.copy(key))
self.assertArraysEqual(key, copy.deepcopy(key))
self.assertArraysEqual(key, jax.jit(lambda k: k.copy())(key))


threefry_seed = prng_internal.threefry_seed
threefry_split = prng_internal.threefry_split
Expand Down
44 changes: 30 additions & 14 deletions tests/random_test.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import enum
from functools import partial
import math
Expand Down Expand Up @@ -585,6 +586,10 @@ class KeyArrayTest(jtu.JaxTestCase):
# might also be a more general test of opaque element types. If
# so, add a corresponding test to to CustomElementTypesTest as well.

def assertKeysEqual(self, key1, key2):
self.assertEqual(key1.dtype, key2.dtype)
self.assertArraysEqual(random.key_data(key1), random.key_data(key2))

def test_construction(self):
key = random.key(42)
self.assertIsInstance(key, jax_random.PRNGKeyArray)
Expand Down Expand Up @@ -662,6 +667,13 @@ def test_key_attributes(self):
self.assertEqual(key.size, math.prod(key.shape))
self.assertEqual(key.ndim, len(key.shape))

def test_key_copy(self):
key = self.make_keys()
self.assertKeysEqual(key, key.copy())
self.assertKeysEqual(key, copy.copy(key))
self.assertKeysEqual(key, copy.deepcopy(key))
self.assertKeysEqual(key, jax.jit(lambda k: k.copy())(key))

def test_isinstance(self):
@jax.jit
def f(k):
Expand Down Expand Up @@ -916,19 +928,19 @@ def test_device_put(self):
device = jax.devices()[0]
keys = self.make_keys(4)
keys_on_device = jax.device_put(keys, device)
self.assertArraysEqual(keys, keys_on_device)
self.assertKeysEqual(keys, keys_on_device)

def test_device_put_sharded(self):
devices = jax.devices()
keys = self.make_keys(len(devices))
keys_on_device = jax.device_put_sharded(list(keys), devices)
self.assertArraysEqual(keys, keys_on_device)
self.assertKeysEqual(keys, keys_on_device)

def test_device_put_replicated(self):
devices = jax.devices()
key = self.make_keys()
keys_on_device = jax.device_put_replicated(key, devices)
self.assertArraysEqual(jnp.broadcast_to(key, keys_on_device.shape), keys_on_device)
self.assertKeysEqual(jnp.broadcast_to(key, keys_on_device.shape), keys_on_device)

def test_make_array_from_callback(self):
devices = jax.devices()
Expand All @@ -940,7 +952,7 @@ def callback(index):
return jax.vmap(random.key)(i)
result = jax.make_array_from_callback(shape, sharding, callback)
expected = jax.vmap(random.key)(jnp.arange(len(devices)))
self.assertArraysEqual(result, expected)
self.assertKeysEqual(result, expected)

def test_make_array_from_single_device_arrays(self):
devices = jax.devices()
Expand All @@ -950,7 +962,7 @@ def test_make_array_from_single_device_arrays(self):
keys = random.split(random.key(0), len(devices))
arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)]
result = jax.make_array_from_single_device_arrays(shape, sharding, arrays)
self.assertArraysEqual(result, keys)
self.assertKeysEqual(result, keys)

def test_key_array_custom_jvp(self):
def f_raw(x, key):
Expand Down Expand Up @@ -1034,7 +1046,7 @@ def test_delete(self):
def test_async(self):
key = self.make_keys(10)

self.assertArraysEqual(key, key.block_until_ready())
self.assertKeysEqual(key, key.block_until_ready())
self.assertIsNone(key.copy_to_host_async())

# -- key construction and un/wrapping with impls
Expand Down Expand Up @@ -1071,7 +1083,7 @@ def test_key_make_like_other_key(self, prng_name):
k1 = jax.random.key(42, impl=prng_name)
impl = jax.random.key_impl(k1)
k2 = jax.random.key(42, impl=impl)
self.assertArraysEqual(k1, k2)
self.assertKeysEqual(k1, k2)
self.assertEqual(k1.dtype, k2.dtype)

@jtu.sample_product(prng_name=[name for name, _ in PRNG_IMPLS])
Expand All @@ -1082,7 +1094,7 @@ def test_key_wrap_like_other_key(self, prng_name):
data = jax.random.key_data(k1)
impl = jax.random.key_impl(k1)
k2 = jax.random.wrap_key_data(data, impl=impl)
self.assertArraysEqual(k1, k2)
self.assertKeysEqual(k1, k2)
self.assertEqual(k1.dtype, k2.dtype)

def test_key_impl_from_string_error(self):
Expand Down Expand Up @@ -1136,6 +1148,10 @@ def _double_threefry_fold_in(key, data):


class JnpWithKeyArrayTest(jtu.JaxTestCase):
def assertKeysEqual(self, key1, key2):
self.assertEqual(key1.dtype, key2.dtype)
self.assertArraysEqual(random.key_data(key1), random.key_data(key2))

def check_shape(self, func, *args):
like = lambda keys: jnp.ones(keys.shape)
out_key = func(*args)
Expand Down Expand Up @@ -1257,15 +1273,15 @@ def test_stack(self):

def test_array(self):
key = random.key(123)
self.assertArraysEqual(key, jnp.array(key))
self.assertArraysEqual(key, jnp.asarray(key))
self.assertArraysEqual(key, jax.jit(jnp.array)(key))
self.assertArraysEqual(key, jax.jit(jnp.asarray)(key))
self.assertKeysEqual(key, jnp.array(key))
self.assertKeysEqual(key, jnp.asarray(key))
self.assertKeysEqual(key, jax.jit(jnp.array)(key))
self.assertKeysEqual(key, jax.jit(jnp.asarray)(key))

def test_array_user_dtype(self):
key = random.key(123)
self.assertArraysEqual(key, jnp.array(key, dtype=key.dtype))
self.assertArraysEqual(key, jnp.asarray(key, dtype=key.dtype))
self.assertKeysEqual(key, jnp.array(key, dtype=key.dtype))
self.assertKeysEqual(key, jnp.asarray(key, dtype=key.dtype))

@parameterized.parameters([
(0,),
Expand Down

0 comments on commit b407620

Please sign in to comment.