Skip to content

Commit

Permalink
random_test: add tests of random values for distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 23, 2022
1 parent e0d3946 commit e68e87a
Showing 1 changed file with 142 additions and 0 deletions.
142 changes: 142 additions & 0 deletions tests/random_test.py
Expand Up @@ -15,6 +15,8 @@

from functools import partial
from unittest import SkipTest, skipIf
from typing import Any, Tuple, NamedTuple, Optional
import zlib

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -56,6 +58,121 @@ def _prng_key_as_array(key):
('unsafe_rbg', prng.unsafe_rbg_prng_impl)]


class RandomValuesCase(NamedTuple):
name: str
prng_impl: str
shape: Tuple[int]
dtype: Any
params: dict
expected: np.ndarray
skip_on_x64: bool = False
atol: Optional[float] = None
rtol: Optional[float] = None

def _testname(self):
if self.dtype is None:
shape_dtype = str(self.shape)
else:
shape_dtype = jtu.format_shape_dtype_string(self.shape, self.dtype)
name = f"_{self.name}_{self.prng_impl}_{shape_dtype}"
if self.params:
fmt = lambda x: str(x).replace(' ', '').replace('\n', '')
name += "_" + "_".join(f"{k}={fmt(v)}" for k, v in self.params.items())
return name

def _seed(self):
# Generate a deterministic unique 32-bit seed given the name and prng impl
return zlib.adler32((self.name + self.prng_impl).encode())


_RANDOM_VALUES_CASES = [
# TODO(jakevdp) add coverage for other distributions.
RandomValuesCase("bernoulli", "threefry2x32", (5,), None, {'p': 0.5},
np.array([False, True, True, True, False]), skip_on_x64=True),
RandomValuesCase("bernoulli", "rbg", (5,), None, {'p': 0.5},
np.array([True, True, True, True, True]), skip_on_x64=True),
RandomValuesCase("beta", "threefry2x32", (5,), np.float32, {'a': 0.8, 'b': 0.9},
np.array([0.533685, 0.843179, 0.063495, 0.573444, 0.459514], dtype='float32')),
RandomValuesCase("beta", "rbg", (5,), np.float32, {'a': 0.8, 'b': 0.9},
np.array([0.841308, 0.669989, 0.731763, 0.985127, 0.022745], dtype='float32')),
RandomValuesCase("cauchy", "threefry2x32", (5,), np.float32, {},
np.array([ -0.088416, -10.169713, 3.49677, -1.18056, 0.34556], dtype='float32'), rtol=1E-5),
RandomValuesCase("cauchy", "rbg", (5,), np.float32, {},
np.array([0.008389, 0.108793, -0.031826, -0.01876, 0.963218], dtype='float32')),
RandomValuesCase("dirichlet", "threefry2x32", (2,), np.float32, {'alpha': np.array([0.5, 0.6, 0.7], dtype='float32')},
np.array([[0.556287, 0.304219, 0.139494], [0.15221 , 0.632251, 0.21554]], dtype='float32')),
RandomValuesCase("dirichlet", "rbg", (2,), np.float32, {'alpha': np.array([0.5, 0.6, 0.7], dtype='float32')},
np.array([[0.024769, 0.002189, 0.973041], [0.326, 0.00244, 0.67156]], dtype='float32')),
RandomValuesCase("double_sided_maxwell", "threefry2x32", (5,), np.float32, {"loc": 1, "scale": 2},
np.array([-2.408914, -3.370437, 3.235352, -0.907734, -1.708732], dtype='float32'), skip_on_x64=True),
RandomValuesCase("double_sided_maxwell", "rbg", (5,), np.float32, {"loc": 1, "scale": 2},
np.array([4.957495, 3.003086, 5.33935, 2.942878, -1.203524], dtype='float32'), skip_on_x64=True),
RandomValuesCase("exponential", "threefry2x32", (5,), np.float32, {},
np.array([0.526067, 0.043046, 0.039932, 0.46427 , 0.123886], dtype='float32')),
RandomValuesCase("exponential", "rbg", (5,), np.float32, {},
np.array([0.231303, 0.684814, 0.017181, 0.089552, 0.345087], dtype='float32')),
RandomValuesCase("gamma", "threefry2x32", (5,), np.float32, {'a': 0.8},
np.array([0.332641, 0.10187 , 1.816109, 0.023457, 0.487853], dtype='float32')),
RandomValuesCase("gamma", "rbg", (5,), np.float32, {'a': 0.8},
np.array([0.235293, 0.446747, 0.146372, 0.79252 , 0.294762], dtype='float32')),
RandomValuesCase("gumbel", "threefry2x32", (5,), np.float32, {},
np.array([2.06701, 0.911726, 0.145736, 0.185427, -0.00711], dtype='float32')),
RandomValuesCase("gumbel", "rbg", (5,), np.float32, {},
np.array([-0.099308, -1.123809, 1.007618, -0.077968, 3.421349], dtype='float32')),
RandomValuesCase("laplace", "threefry2x32", (5,), np.float32, {},
np.array([0.578939, -0.204902, 0.555733, 0.911053, -0.96456], dtype='float32')),
RandomValuesCase("laplace", "rbg", (5,), np.float32, {},
np.array([-2.970422, 1.925082, -0.757887, -4.444797, 0.561983], dtype='float32')),
RandomValuesCase("loggamma", "threefry2x32", (5,), np.float32, {'a': 0.8},
np.array([-0.899633, -0.424083, 0.631593, 0.102374, -1.07189], dtype='float32')),
RandomValuesCase("loggamma", "rbg", (5,), np.float32, {'a': 0.8},
np.array([-1.333825, 0.287259, -0.343074, -0.998258, -0.773598], dtype='float32')),
RandomValuesCase("logistic", "threefry2x32", (5,), np.float32, {},
np.array([0.19611, -1.709053, -0.274093, -0.208322, -1.675489], dtype='float32')),
RandomValuesCase("logistic", "rbg", (5,), np.float32, {},
np.array([-0.234923, -0.545184, 0.700992, -0.708609, -1.474884], dtype='float32')),
RandomValuesCase("maxwell", "threefry2x32", (5,), np.float32, {},
np.array([3.070779, 0.908479, 1.521317, 0.875551, 1.306137], dtype='float32')),
RandomValuesCase("maxwell", "rbg", (5,), np.float32, {},
np.array([2.048746, 0.470027, 1.053105, 1.01969, 2.710645], dtype='float32')),
RandomValuesCase("multivariate_normal", "threefry2x32", (2,), np.float32, {"mean": np.ones((1, 3)), "cov": np.eye(3)},
np.array([[ 1.067826, 1.215599, 0.234166], [-0.237534, 1.32591, 1.413987]], dtype='float32'), skip_on_x64=True),
RandomValuesCase("multivariate_normal", "rbg", (2,), np.float32, {"mean": np.ones((1, 3)), "cov": np.eye(3)},
np.array([[-0.036897, 0.770969, 0.756959], [1.755091, 2.350553, 0.627142]], dtype='float32'), skip_on_x64=True),
RandomValuesCase("normal", "threefry2x32", (5,), np.float32, {},
np.array([-1.173234, -1.511662, 0.070593, -0.099764, 1.052845], dtype='float32')),
RandomValuesCase("normal", "rbg", (5,), np.float32, {},
np.array([-0.479658, 0.565747, -1.065106, 0.997962, -1.478002], dtype='float32')),
RandomValuesCase("pareto", "threefry2x32", (5,), np.float32, {"b": 0.5},
np.array([2.751398, 1.281863, 87.85448, 1.254542, 2.824487], dtype='float32')),
RandomValuesCase("pareto", "rbg", (5,), np.float32, {"b": 0.5},
np.array([1.241914, 1.521864, 5.615384, 1911.502, 1.816702], dtype='float32')),
RandomValuesCase("poisson", "threefry2x32", (5,), np.int32, {"lam": 5},
np.array([7, 3, 6, 11, 6], dtype='int32')),
# Note: poisson not implemented for rbg sampler.
RandomValuesCase("rademacher", "threefry2x32", (5,), np.int32, {},
np.array([-1, -1, -1, -1, 1], dtype='int32'), skip_on_x64=True),
RandomValuesCase("rademacher", "rbg", (5,), np.int32, {},
np.array([1, 1, 1, -1, -1], dtype='int32'), skip_on_x64=True),
RandomValuesCase("randint", "threefry2x32", (5,), np.int32, {"minval": 0, "maxval": 10},
np.array([0, 5, 7, 7, 5], dtype='int32')),
RandomValuesCase("randint", "rbg", (5,), np.int32, {"minval": 0, "maxval": 10},
np.array([7, 1, 8, 5, 8], dtype='int32')),
RandomValuesCase("truncated_normal", "threefry2x32", (5,), np.float32, {"lower": 0, "upper": 2},
np.array([0.582807, 1.709771, 0.159513, 0.861376, 0.36148], dtype='float32')),
RandomValuesCase("truncated_normal", "rbg", (5,), np.float32, {"lower": 0, "upper": 2},
np.array([0.770068, 1.516464, 0.710406, 0.762801, 1.305324], dtype='float32')),
RandomValuesCase("uniform", "threefry2x32", (5,), np.float32, {},
np.array([0.298671, 0.073213, 0.873356, 0.260549, 0.412797], dtype='float32')),
RandomValuesCase("uniform", "rbg", (5,), np.float32, {},
np.array([0.477161, 0.706508, 0.656261, 0.432547, 0.057772], dtype='float32')),
RandomValuesCase("weibull_min", "threefry2x32", (5,), np.float32, {"scale": 1, "concentration": 1},
np.array([1.605863, 0.841809, 0.224218, 0.4826 , 0.027901], dtype='float32')),
RandomValuesCase("weibull_min", "rbg", (5,), np.float32, {"scale": 1, "concentration": 1},
np.array([1.370903, 0.086532, 0.061688, 3.407599, 0.215077], dtype='float32')),
]


class PrngTest(jtu.JaxTestCase):

def testThreefry2x32(self):
Expand Down Expand Up @@ -172,6 +289,31 @@ def testRngRandomBitsViewProperty(self):
rand_bits_32 = np.array([np.array(r).view(np.uint32) for r in rand_bits])
assert np.all(rand_bits_32 == rand_bits_32[0])


@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": case._testname(), "case": case}
for case in _RANDOM_VALUES_CASES))
@jtu.skip_on_devices("tpu") # TPU precision causes issues.
def testRandomDistributionValues(self, case):
"""
Tests values output by various distributions. This will catch any unintentional
changes to the implementations that could result in different random sequences.
Any refactoring of random distributions that leads to non-trivial differences in
this test should involve a deprecation cycle following the procedures outlined at
https://jax.readthedocs.io/en/latest/api_compatibility.html
"""
if config.x64_enabled and case.skip_on_x64:
self.skipTest("test produces different values when jax_enable_x64=True")
with jax.default_prng_impl(case.prng_impl):
func = getattr(random, case.name)
key = random.PRNGKey(case._seed())
if case.dtype:
actual = func(key, **case.params, shape=case.shape, dtype=case.dtype)
else:
actual = func(key, **case.params, shape=case.shape)
self.assertAllClose(actual, case.expected, atol=case.atol, rtol=case.rtol)

def testPRNGValues(self):
# Test to ensure consistent random values between JAX versions
k = random.PRNGKey(0)
Expand Down

0 comments on commit e68e87a

Please sign in to comment.