From 75b992f30bc0a852c64f576691c56cfd379bf9ed Mon Sep 17 00:00:00 2001 From: Xiaoyu Chen <55552143+c-xy17@users.noreply.github.com> Date: Wed, 11 May 2022 16:11:30 +0800 Subject: [PATCH 1/9] add @wraps(np.random.xxx) for API documentation --- brainpy/math/random.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/brainpy/math/random.py b/brainpy/math/random.py index f0dc18f18..b4eea4d22 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -7,6 +7,8 @@ from brainpy.math.jaxarray import JaxArray, Variable +from .utils import wraps + __all__ = [ 'RandomState', @@ -207,28 +209,34 @@ def lognormallognormal(self, mean=0.0, sigma=1.0, size=None): DEFAULT = RandomState(np.random.randint(0, 10000, size=2, dtype=np.uint32)) +@wraps(np.random.seed) def seed(seed=None): global DEFAULT DEFAULT.seed(np.random.randint(0, 100000) if seed is None else seed) +@wraps(np.random.rand) def rand(*dn): return JaxArray(jr.uniform(DEFAULT.split_key(), shape=dn, minval=0., maxval=1.)) +@wraps(np.random.randint) def randint(low, high=None, size=None, dtype=int): return JaxArray(jr.randint(DEFAULT.split_key(), shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)) +@wraps(np.random.randn) def randn(*dn): return JaxArray(jr.normal(DEFAULT.split_key(), shape=dn)) +@wraps(np.random.random) def random(size=None): return JaxArray(jr.uniform(DEFAULT.split_key(), shape=_size2shape(size), minval=0., maxval=1.)) +@wraps(np.random.random_sample) def random_sample(size=None): return JaxArray(jr.uniform(DEFAULT.split_key(), shape=_size2shape(size), minval=0., maxval=1.)) @@ -237,87 +245,105 @@ def random_sample(size=None): sample = random_sample +@wraps(np.random.choice) def choice(a, size=None, replace=True, p=None): a = a.value if isinstance(a, JaxArray) else a return JaxArray(jr.choice(DEFAULT.split_key(), a=a, shape=_size2shape(size), replace=replace, p=p)) +@wraps(np.random.permutation) def permutation(x): x = x.value if isinstance(x, JaxArray) else x return JaxArray(jr.permutation(DEFAULT.split_key(), x)) +@wraps(np.random.shuffle) def shuffle(x, axis=0): assert isinstance(x, JaxArray), f'Must be a JaxArray, but got {type(x)}' x.value = jr.permutation(DEFAULT.split_key(), x.value, axis=axis) +@wraps(np.random.beta) def beta(a, b, size=None): a = a.value if isinstance(a, JaxArray) else a b = b.value if isinstance(b, JaxArray) else b return JaxArray(jr.beta(DEFAULT.split_key(), a=a, b=b, shape=_size2shape(size))) +@wraps(np.random.exponential) def exponential(scale=1.0, size=None): assert scale == 1. return JaxArray(jr.exponential(DEFAULT.split_key(), shape=_size2shape(size))) +@wraps(np.random.gamma) def gamma(shape, scale=1.0, size=None): assert scale == 1. return JaxArray(jr.gamma(DEFAULT.split_key(), a=shape, shape=_size2shape(size))) +@wraps(np.random.gumbel) def gumbel(loc=0.0, scale=1.0, size=None): assert loc == 0. assert scale == 1. return JaxArray(jr.gumbel(DEFAULT.split_key(), shape=_size2shape(size))) +@wraps(np.random.laplace) def laplace(loc=0.0, scale=1.0, size=None): assert loc == 0. assert scale == 1. return JaxArray(jr.laplace(DEFAULT.split_key(), shape=_size2shape(size))) +@wraps(np.random.logistic) def logistic(loc=0.0, scale=1.0, size=None): assert loc == 0. assert scale == 1. return JaxArray(jr.logistic(DEFAULT.split_key(), shape=_size2shape(size))) +@wraps(np.random.normal) def normal(loc=0.0, scale=1.0, size=None): return JaxArray(jr.normal(DEFAULT.split_key(), shape=_size2shape(size)) * scale + loc) +@wraps(np.random.pareto) def pareto(a, size=None): return JaxArray(jr.pareto(DEFAULT.split_key(), b=a, shape=_size2shape(size))) +@wraps(np.random.poisson) def poisson(lam=1.0, size=None): return JaxArray(jr.poisson(DEFAULT.split_key(), lam=lam, shape=_size2shape(size))) +@wraps(np.random.standard_cauchy) def standard_cauchy(size=None): return JaxArray(jr.cauchy(DEFAULT.split_key(), shape=_size2shape(size))) +@wraps(np.random.standard_exponential) def standard_exponential(size=None): return JaxArray(jr.exponential(DEFAULT.split_key(), shape=_size2shape(size))) +@wraps(np.random.standard_gamma) def standard_gamma(shape, size=None): return JaxArray(jr.gamma(DEFAULT.split_key(), a=shape, shape=_size2shape(size))) +@wraps(np.random.standard_normal) def standard_normal(size=None): return JaxArray(jr.normal(DEFAULT.split_key(), shape=_size2shape(size))) +@wraps(np.random.standard_t) def standard_t(df, size=None): return JaxArray(jr.t(DEFAULT.split_key(), df=df, shape=_size2shape(size))) +@wraps(np.random.uniform) def uniform(low=0.0, high=1.0, size=None): return JaxArray(jr.uniform(DEFAULT.split_key(), shape=_size2shape(size), minval=low, maxval=high)) @@ -373,6 +399,7 @@ def bernoulli(p, size=None): return JaxArray(jr.bernoulli(DEFAULT.split_key(), p=p, shape=_size2shape(size))) +@wraps(np.random.lognormal) def lognormal(mean=0.0, sigma=1.0, size=None): samples = jr.normal(DEFAULT.split_key(), shape=_size2shape(size)) samples = samples * sigma + mean From 72fd5791ee16bc145ae365ef9bd0559290839bf8 Mon Sep 17 00:00:00 2001 From: Xiaoyu Chen <55552143+c-xy17@users.noreply.github.com> Date: Wed, 11 May 2022 21:11:24 +0800 Subject: [PATCH 2/9] update some functions in random.py --- brainpy/math/random.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/brainpy/math/random.py b/brainpy/math/random.py index b4eea4d22..9b4e4881e 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -34,6 +34,15 @@ def _size2shape(size): raise ValueError(f'Must be a list/tuple of int, but got {size}') +def _check_shape(name, shape, *param_shapes): + for s in param_shapes: + if s != shape: + msg = ("{} parameter shapes must be broadcast-compatible with shape " + "argument, and the result of broadcasting the shapes must equal " + "the shape argument, but got result {} for shape argument {}.") + raise ValueError(msg.format(name, s, shape)) + + class RandomState(Variable): """RandomState that track the random generator state. """ __slots__ = () @@ -222,6 +231,12 @@ def rand(*dn): @wraps(np.random.randint) def randint(low, high=None, size=None, dtype=int): + if high is None: + high = low + low = 0 + # todo: randint does not support multi-minval/maxval + if size is None: + size = np.shape(low) if len(np.shape(low)) >= len(np.shape(high)) else np.shape(high) return JaxArray(jr.randint(DEFAULT.split_key(), shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)) From c4dd9f6da33111ed8e530afb7f8564f047bc05e4 Mon Sep 17 00:00:00 2001 From: Xiaoyu Chen <55552143+c-xy17@users.noreply.github.com> Date: Wed, 11 May 2022 21:11:55 +0800 Subject: [PATCH 3/9] add tests for random sampling in random.py --- brainpy/math/tests/test_random.py | 56 +++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 brainpy/math/tests/test_random.py diff --git a/brainpy/math/tests/test_random.py b/brainpy/math/tests/test_random.py new file mode 100644 index 000000000..4a4e75d87 --- /dev/null +++ b/brainpy/math/tests/test_random.py @@ -0,0 +1,56 @@ +import unittest + +import jax.numpy as jnp +import jax.random as jr +import brainpy.math as bm +import brainpy.math.random as br +import numpy.random as nr + +class TestRandom(unittest.TestCase): + def test_seed(self): + test_seed = 299 + br.seed(test_seed) + a = br.rand(3) + br.seed(test_seed) + b = br.rand(3) + self.assertTrue(bm.array_equal(a, b)) + + def test_rand(self): + a = br.rand(3, 2) + self.assertTupleEqual(a.shape, (3, 2)) + self.assertTrue((a >= 0).all() and (a < 1).all()) + + def test_randint1(self): + a = br.randint(5, size=10) + self.assertTupleEqual(a.shape, (10,)) + self.assertTrue((a >= 0).all() and (a < 5).all()) + + def test_randint2(self): + a = br.randint(2, 6, size=(4, 3)) + self.assertTupleEqual(a.shape, (4, 3)) + self.assertTrue((a >= 2).all() and (a < 6).all()) + + # def test_randint3(self): + # a = br.randint([1, 2, 3], [10, 7, 8], size=3) + # self.assertTupleEqual(a.shape, (3,)) + # self.assertTrue((a - bm.array([1, 2, 3]) >= 0).all() + # and (-a + bm.array([10, 7, 8]) > 0).all()) + + def test_randn(self): + a = br.randn(3, 2) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_random1(self): + a = br.random() + self.assertIsInstance(a, bm.jaxarray.JaxArray) + self.assertTrue(0. <= a < 1) + + def test_random2(self): + a = br.random(size=(3, 2)) + self.assertTupleEqual(a.shape, (3, 2)) + self.assertTrue((a >= 0).all() and (a < 1).all()) + + def test_random_sample(self): + a = br.random_sample(size=(3, 2)) + self.assertTupleEqual(a.shape, (3, 2)) + self.assertTrue((a >= 0).all() and (a < 1).all()) \ No newline at end of file From c96984e08d084cb4a4da533b52462fd83cfad22f Mon Sep 17 00:00:00 2001 From: Xiaoyu Chen <55552143+c-xy17@users.noreply.github.com> Date: Thu, 12 May 2022 20:39:00 +0800 Subject: [PATCH 4/9] add tests for random sampling in random.py --- brainpy/math/tests/test_random.py | 195 ++++++++++++++++++++++++++++-- 1 file changed, 186 insertions(+), 9 deletions(-) diff --git a/brainpy/math/tests/test_random.py b/brainpy/math/tests/test_random.py index 4a4e75d87..256b291ef 100644 --- a/brainpy/math/tests/test_random.py +++ b/brainpy/math/tests/test_random.py @@ -21,20 +21,24 @@ def test_rand(self): self.assertTrue((a >= 0).all() and (a < 1).all()) def test_randint1(self): - a = br.randint(5, size=10) - self.assertTupleEqual(a.shape, (10,)) - self.assertTrue((a >= 0).all() and (a < 5).all()) + a = br.randint(5) + self.assertTupleEqual(a.shape, ()) + self.assertTrue(0 <= a < 5) def test_randint2(self): a = br.randint(2, 6, size=(4, 3)) self.assertTupleEqual(a.shape, (4, 3)) self.assertTrue((a >= 2).all() and (a < 6).all()) - # def test_randint3(self): - # a = br.randint([1, 2, 3], [10, 7, 8], size=3) - # self.assertTupleEqual(a.shape, (3,)) - # self.assertTrue((a - bm.array([1, 2, 3]) >= 0).all() - # and (-a + bm.array([10, 7, 8]) > 0).all()) + def test_randint3(self): + a = br.randint([1, 2, 3], [10, 7, 8]) + self.assertTupleEqual(a.shape, (3,)) + self.assertTrue((a - bm.array([1, 2, 3]) >= 0).all() + and (-a + bm.array([10, 7, 8]) > 0).all()) + + def test_randint4(self): + a = br.randint([1, 2, 3], [10, 7, 8], size=(2, 3)) + self.assertTupleEqual(a.shape, (2, 3)) def test_randn(self): a = br.randn(3, 2) @@ -53,4 +57,177 @@ def test_random2(self): def test_random_sample(self): a = br.random_sample(size=(3, 2)) self.assertTupleEqual(a.shape, (3, 2)) - self.assertTrue((a >= 0).all() and (a < 1).all()) \ No newline at end of file + self.assertTrue((a >= 0).all() and (a < 1).all()) + + def test_choice1(self): + a = bm.random.choice(5) + self.assertTupleEqual(jnp.shape(a), ()) + self.assertTrue(0 <= a < 5) + + def test_choice2(self): + a = bm.random.choice(5, 3, p=[0.1, 0.4, 0.2, 0., 0.3]) + self.assertTupleEqual(a.shape, (3,)) + self.assertTrue((a >= 0).all() and (a < 5).all()) + + def test_choice3(self): + a = bm.random.choice(bm.arange(2, 20), size=(4, 3), replace=False) + self.assertTupleEqual(a.shape, (4, 3)) + self.assertTrue((a >= 2).all() and (a < 20).all()) + self.assertEqual(len(bm.unique(a)), 12) + + def test_permutation1(self): + a = bm.random.permutation(10) + self.assertTupleEqual(a.shape, (10,)) + self.assertEqual(len(bm.unique(a)), 10) + + def test_permutation2(self): + a = bm.random.permutation(bm.arange(10)) + self.assertTupleEqual(a.shape, (10,)) + self.assertEqual(len(bm.unique(a)), 10) + + def test_shuffle1(self): + a = bm.arange(10) + bm.random.shuffle(a) + self.assertTupleEqual(a.shape, (10,)) + self.assertEqual(len(bm.unique(a)), 10) + + def test_shuffle2(self): + a = bm.arange(12).reshape(4, 3) + bm.random.shuffle(a, axis=1) + self.assertTupleEqual(a.shape, (4, 3)) + self.assertEqual(len(bm.unique(a)), 12) + + # test that a is only shuffled along axis 1 + uni = bm.unique(bm.diff(a, axis=0)) + self.assertEqual(uni, bm.JaxArray([3])) + + def test_beta1(self): + a = bm.random.beta(2, 2) + self.assertTupleEqual(a.shape, ()) + + def test_beta2(self): + a = bm.random.beta([2, 2, 3], 2, size=(3,)) + self.assertTupleEqual(a.shape, (3,)) + + def test_exponential1(self): + a = bm.random.exponential(10., size=[3, 2]) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_exponential2(self): + a = bm.random.exponential([1., 2., 5.]) + self.assertTupleEqual(a.shape, (3,)) + + def test_gamma(self): + a = bm.random.gamma(2, 10., size=[3, 2]) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_gumbel(self): + a = bm.random.gumbel(0., 2., size=[3, 2]) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_laplace(self): + a = bm.random.laplace(0., 2., size=[3, 2]) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_logistic(self): + a = bm.random.logistic(0., 2., size=[3, 2]) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_normal1(self): + a = bm.random.normal() + self.assertTupleEqual(a.shape, ()) + + def test_normal2(self): + a = bm.random.normal(loc=[0., 2., 4.], scale=[1., 2., 3.]) + self.assertTupleEqual(a.shape, (3,)) + + def test_normal3(self): + a = bm.random.normal(loc=[0., 2., 4.], scale=[[1., 2., 3.], [1., 1., 1.]]) + print(a) + self.assertTupleEqual(a.shape, (2, 3)) + + def test_pareto(self): + a = bm.random.pareto([1, 2, 2], size=3) + self.assertTupleEqual(a.shape, (3,)) + + def test_poisson(self): + a = bm.random.poisson([1., 2., 2.], size=3) + self.assertTupleEqual(a.shape, (3,)) + + def test_standard_cauchy(self): + a = bm.random.standard_cauchy(size=(3, 2)) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_standard_exponential(self): + a = bm.random.standard_exponential(size=(3, 2)) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_standard_gamma(self): + a = bm.random.standard_gamma(shape=[1, 2, 4], size=3) + self.assertTupleEqual(a.shape, (3,)) + + def test_standard_normal(self): + a = bm.random.standard_normal(size=(3, 2)) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_standard_t(self): + a = bm.random.standard_t(df=[1, 2, 4], size=3) + self.assertTupleEqual(a.shape, (3,)) + + def test_standard_uniform1(self): + a = bm.random.uniform() + self.assertTupleEqual(a.shape, ()) + self.assertTrue(0 <=a< 1) + + def test_uniform2(self): + a = bm.random.uniform(low=[-1., 5., 2.], high=[2., 6., 10.], size=3) + self.assertTupleEqual(a.shape, (3,)) + self.assertTrue((a - bm.array([-1., 5., 2.]) >= 0).all() + and (-a + bm.array([2., 6., 10.]) > 0).all()) + + def test_uniform3(self): + a = bm.random.uniform(low=[-1., 5., 2.], high=[[2., 6., 10.], [10., 10., 10.]]) + self.assertTupleEqual(a.shape, (2, 3)) + + def test_truncated_normal1(self): + a = bm.random.truncated_normal(-1., 1.) + self.assertTupleEqual(a.shape, ()) + self.assertTrue(-1. <= a <= 1.) + + def test_truncated_normal2(self): + a = bm.random.truncated_normal(-1., 1., size=(4, 3)) + self.assertTupleEqual(a.shape, (4, 3)) + self.assertTrue((a >= -1.).all() and (a <= 1.).all()) + + def test_truncated_normal3(self): + a = bm.random.truncated_normal([-1., 0., 1.], [[2., 2., 4.], [2., 2., 4.]]) + self.assertTupleEqual(a.shape, (2, 3)) + self.assertTrue((a - bm.array([-1., 0., 1.]) >= 0.).all() + and (- a + bm.array([2., 2., 4.]) >= 0.).all()) + + def test_bernoulli1(self): + a = bm.random.bernoulli() + self.assertTupleEqual(a.shape, ()) + self.assertTrue(a == 0 or a == 1) + + def test_bernoulli2(self): + a = bm.random.bernoulli([0.5, 0.6, 0.8]) + self.assertTupleEqual(a.shape, (3,)) + self.assertTrue(bm.logical_xor(a == 1, a == 0).all()) + + def test_bernoulli3(self): + a = bm.random.bernoulli(size=(3, 2)) + self.assertTupleEqual(a.shape, (3, 2)) + self.assertTrue(bm.logical_xor(a == 1, a == 0).all()) + + def test_lognormal1(self): + a = bm.random.lognormal() + self.assertTupleEqual(a.shape, ()) + + def test_lognormal2(self): + a = bm.random.lognormal(sigma=[2., 1.], size=[3, 2]) + self.assertTupleEqual(a.shape, (3, 2)) + + def test_lognormal3(self): + a = bm.random.lognormal([2., 0.], [[2., 1.], [3., 1.2]]) + self.assertTupleEqual(a.shape, (2, 2)) From 69cfb79b5a869a9b2f807200ba1db3866aa02243 Mon Sep 17 00:00:00 2001 From: Xiaoyu Chen <55552143+c-xy17@users.noreply.github.com> Date: Thu, 12 May 2022 20:39:19 +0800 Subject: [PATCH 5/9] update random sampling functions in random.py --- brainpy/math/random.py | 71 ++++++++++++++++++++++++++++++------------ 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/brainpy/math/random.py b/brainpy/math/random.py index 9b4e4881e..7daf24fe0 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -234,9 +234,11 @@ def randint(low, high=None, size=None, dtype=int): if high is None: high = low low = 0 - # todo: randint does not support multi-minval/maxval + high = jnp.asarray(high) + low = jnp.asarray(low) if size is None: - size = np.shape(low) if len(np.shape(low)) >= len(np.shape(high)) else np.shape(high) + size = jnp.shape(low) if len(jnp.shape(low)) >= len(jnp.shape(high)) else jnp.shape(high) + return JaxArray(jr.randint(DEFAULT.split_key(), shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)) @@ -263,6 +265,8 @@ def random_sample(size=None): @wraps(np.random.choice) def choice(a, size=None, replace=True, p=None): a = a.value if isinstance(a, JaxArray) else a + if p is not None: + p = jnp.asarray(p) return JaxArray(jr.choice(DEFAULT.split_key(), a=a, shape=_size2shape(size), replace=replace, p=p)) @@ -280,56 +284,61 @@ def shuffle(x, axis=0): @wraps(np.random.beta) def beta(a, b, size=None): - a = a.value if isinstance(a, JaxArray) else a - b = b.value if isinstance(b, JaxArray) else b + a = jnp.asarray(a) + b = jnp.asarray(b) return JaxArray(jr.beta(DEFAULT.split_key(), a=a, b=b, shape=_size2shape(size))) @wraps(np.random.exponential) def exponential(scale=1.0, size=None): - assert scale == 1. - return JaxArray(jr.exponential(DEFAULT.split_key(), shape=_size2shape(size))) + scale = jnp.asarray(scale) + return JaxArray(jr.exponential(DEFAULT.split_key(), shape=_size2shape(size)) / scale) @wraps(np.random.gamma) def gamma(shape, scale=1.0, size=None): - assert scale == 1. - return JaxArray(jr.gamma(DEFAULT.split_key(), a=shape, shape=_size2shape(size))) + shape = jnp.asarray(shape) + scale = jnp.asarray(scale) + return JaxArray(jr.gamma(DEFAULT.split_key(), a=shape, shape=_size2shape(size)) * scale) @wraps(np.random.gumbel) def gumbel(loc=0.0, scale=1.0, size=None): - assert loc == 0. - assert scale == 1. - return JaxArray(jr.gumbel(DEFAULT.split_key(), shape=_size2shape(size))) + loc = jnp.asarray(loc) + scale = jnp.asarray(scale) + return JaxArray(jr.gumbel(DEFAULT.split_key(), shape=_size2shape(size)) * scale + loc) @wraps(np.random.laplace) def laplace(loc=0.0, scale=1.0, size=None): - assert loc == 0. - assert scale == 1. - return JaxArray(jr.laplace(DEFAULT.split_key(), shape=_size2shape(size))) + loc = jnp.asarray(loc) + scale = jnp.asarray(scale) + return JaxArray(jr.laplace(DEFAULT.split_key(), shape=_size2shape(size)) * scale + loc) @wraps(np.random.logistic) def logistic(loc=0.0, scale=1.0, size=None): - assert loc == 0. - assert scale == 1. - return JaxArray(jr.logistic(DEFAULT.split_key(), shape=_size2shape(size))) + loc = jnp.asarray(loc) + scale = jnp.asarray(scale) + return JaxArray(jr.logistic(DEFAULT.split_key(), shape=_size2shape(size)) * scale + loc) @wraps(np.random.normal) def normal(loc=0.0, scale=1.0, size=None): + loc = jnp.asarray(loc) + scale = jnp.asarray(scale) return JaxArray(jr.normal(DEFAULT.split_key(), shape=_size2shape(size)) * scale + loc) @wraps(np.random.pareto) def pareto(a, size=None): + a = jnp.asarray(a) return JaxArray(jr.pareto(DEFAULT.split_key(), b=a, shape=_size2shape(size))) @wraps(np.random.poisson) def poisson(lam=1.0, size=None): + lam = jnp.asarray(lam) return JaxArray(jr.poisson(DEFAULT.split_key(), lam=lam, shape=_size2shape(size))) @@ -345,6 +354,7 @@ def standard_exponential(size=None): @wraps(np.random.standard_gamma) def standard_gamma(shape, size=None): + shape = jnp.asarray(shape) return JaxArray(jr.gamma(DEFAULT.split_key(), a=shape, shape=_size2shape(size))) @@ -355,15 +365,21 @@ def standard_normal(size=None): @wraps(np.random.standard_t) def standard_t(df, size=None): + df = jnp.asarray(df) return JaxArray(jr.t(DEFAULT.split_key(), df=df, shape=_size2shape(size))) @wraps(np.random.uniform) def uniform(low=0.0, high=1.0, size=None): + low = jnp.asarray(low) + high = jnp.asarray(high) + if size is None: + size = jnp.shape(low) if len(jnp.shape(low)) >= len(jnp.shape(high)) else jnp.shape(high) + return JaxArray(jr.uniform(DEFAULT.split_key(), shape=_size2shape(size), minval=low, maxval=high)) -def truncated_normal(lower, upper, size, scale=1.): +def truncated_normal(lower, upper, size=None, scale=1.): """Sample truncated standard normal random values with given shape and dtype. Parameters @@ -390,6 +406,11 @@ def truncated_normal(lower, upper, size, scale=1.): ``shape`` is not None, or else by broadcasting ``lower`` and ``upper``. Returns values in the open interval ``(lower, upper)``. """ + lower = jnp.asarray(lower) + upper = jnp.asarray(upper) + if size is None: + size = jnp.shape(lower) if len(jnp.shape(lower)) >= len(jnp.shape(upper)) else jnp.shape(upper) + rands = jr.truncated_normal(DEFAULT.split_key(), lower=lower, upper=upper, @@ -397,12 +418,13 @@ def truncated_normal(lower, upper, size, scale=1.): return JaxArray(rands * scale) -def bernoulli(p, size=None): +def bernoulli(p=0.5, size=None): """Sample Bernoulli random values with given shape and mean. Args: p: optional, a float or array of floats for the mean of the random - variables. Must be broadcast-compatible with ``shape``. Default 0.5. + variables. Must be broadcast-compatible with ``shape`` and the values + should be within [0, 1]. Default 0.5. size: optional, a tuple of nonnegative integers representing the result shape. Must be broadcast-compatible with ``p.shape``. The default (None) produces a result shape equal to ``p.shape``. @@ -411,11 +433,20 @@ def bernoulli(p, size=None): A random array with boolean dtype and shape given by ``shape`` if ``shape`` is not None, or else ``p.shape``. """ + p = jnp.asarray(p) + if jnp.unique(jnp.logical_and(p >= 0, p <= 1)) != jnp.array([True]): + raise ValueError(r'Bernoulli parameter p should be within [0, 1], but we got {}'.format(p)) + + if size is None: + size = p.shape + return JaxArray(jr.bernoulli(DEFAULT.split_key(), p=p, shape=_size2shape(size))) @wraps(np.random.lognormal) def lognormal(mean=0.0, sigma=1.0, size=None): + mean = jnp.asarray(mean) + sigma = jnp.asarray(sigma) samples = jr.normal(DEFAULT.split_key(), shape=_size2shape(size)) samples = samples * sigma + mean samples = jnp.exp(samples) From e63096e24157b8cf0a468ce12f2f8effeb3d0899 Mon Sep 17 00:00:00 2001 From: Xiaoyu Chen <55552143+c-xy17@users.noreply.github.com> Date: Fri, 13 May 2022 17:59:34 +0800 Subject: [PATCH 6/9] add new random sampling functions in random.py --- brainpy/math/random.py | 116 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 112 insertions(+), 4 deletions(-) diff --git a/brainpy/math/random.py b/brainpy/math/random.py index 7daf24fe0..e900d849d 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- - +import jax.experimental.host_callback import numpy as np +import numpy.random from jax import numpy as jnp from jax import random as jr from jax.tree_util import register_pytree_node @@ -8,6 +9,7 @@ from brainpy.math.jaxarray import JaxArray, Variable from .utils import wraps +from jax.experimental.host_callback import call as hcb_call __all__ = [ @@ -237,7 +239,7 @@ def randint(low, high=None, size=None, dtype=int): high = jnp.asarray(high) low = jnp.asarray(low) if size is None: - size = jnp.shape(low) if len(jnp.shape(low)) >= len(jnp.shape(high)) else jnp.shape(high) + size = np.broadcast(low, high).shape return JaxArray(jr.randint(DEFAULT.split_key(), shape=_size2shape(size), minval=low, maxval=high, dtype=dtype)) @@ -374,7 +376,7 @@ def uniform(low=0.0, high=1.0, size=None): low = jnp.asarray(low) high = jnp.asarray(high) if size is None: - size = jnp.shape(low) if len(jnp.shape(low)) >= len(jnp.shape(high)) else jnp.shape(high) + size = np.broadcast(low, high).shape return JaxArray(jr.uniform(DEFAULT.split_key(), shape=_size2shape(size), minval=low, maxval=high)) @@ -409,7 +411,7 @@ def truncated_normal(lower, upper, size=None, scale=1.): lower = jnp.asarray(lower) upper = jnp.asarray(upper) if size is None: - size = jnp.shape(lower) if len(jnp.shape(lower)) >= len(jnp.shape(upper)) else jnp.shape(upper) + size = np.broadcast(lower, upper).shape rands = jr.truncated_normal(DEFAULT.split_key(), lower=lower, @@ -452,3 +454,109 @@ def lognormal(mean=0.0, sigma=1.0, size=None): samples = jnp.exp(samples) return JaxArray(samples) + +@wraps(np.random.binomial) +def binomial(n, p, size=None): + if size is None: + size = np.broadcast(n, p).shape + size = _size2shape(size) + d = {'n': n, 'p': p, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.binomial(n=x['n'], p=x['p'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, int))) + + +@wraps(np.random.chisquare) +def chisquare(df, size=None): + if size is None: + size = np.shape(df) + size = _size2shape(size) + d = {'df': df, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.chisquare(df=x['df'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, float))) + + +@wraps(np.random.dirichlet) +def dirichlet(alpha, size=None): + size = _size2shape(size) + d = {'alpha': alpha, 'size': size} + output_shape = size + np.shape(alpha) + return JaxArray(hcb_call(lambda x: np.random.dirichlet(alpha=x['alpha'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(output_shape, float))) + + +@wraps(np.random.f) +def f(dfnum, dfden, size=None): + if size is None: + size = np.broadcast(dfnum, dfden).shape + size = _size2shape(size) + d = {'dfnum': dfnum, 'dfden': dfden, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.f(dfnum=x['dfnum'], dfden=x['dfden'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, float))) + + +@wraps(np.random.geometric) +def geometric(p, size=None): + if size is None: + size = np.shape(p) + size = _size2shape(size) + d = {'p': p, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.geometric(p=x['p'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, int))) + + +@wraps(np.random.hypergeometric) +def hypergeometric(ngood, nbad, nsample, size=None): + if size is None: + size = np.broadcast(ngood, nbad, nsample).shape + size = _size2shape(size) + d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.hypergeometric(ngood=x['ngood'], nbad=x['nbad'], + nsample=x['nsample'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, int))) + + +@wraps(np.random.logseries) +def logseries(p, size=None): + if size is None: + size = np.shape(p) + size = _size2shape(size) + d = {'p': p, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.logseries(p=x['p'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, int))) + + +@wraps(np.random.multinomial) +def multinomial(n, pvals, size=None): + size = _size2shape(size) + d = {'n': n, 'pvals': pvals, 'size': size} + output_shape = size + np.shape(pvals) + return JaxArray(hcb_call(lambda x: np.random.multinomial(n=x['n'], pvals=x['pvals'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(output_shape, int))) + + +def _packed_multivariate_normal(d): + candidate_str = ['warn', 'raise', 'ignore'] + selected = np.array([d['warn'], d['raise'], d['ignore']]) + + return np.random.multivariate_normal(mean=d['mean'], cov=d['cov'], size=d['size'], + check_valid=candidate_str[np.arange(3)[selected][0]], + tol=d['tol']) + +@wraps(np.random.multivariate_normal) +def multivariate_normal(mean, cov, size=None, check_valid='warn', tol=1e-8): + size = _size2shape(size) + + if not (check_valid == 'warn' or check_valid == 'raise' or check_valid == 'ignore'): + raise ValueError(r'multivariate_normal argument check_valid should be "warn", "raise", ' + 'or "ignore", but we got {}'.format(check_valid)) + + d = {'mean': mean, 'cov': cov, 'size': size, + 'warn': True if check_valid == 'warn' else False, + 'raise': True if check_valid == 'raise' else False, + 'ignore': True if check_valid == 'ignore' else False, + 'tol': tol} + output_shape = size + np.shape(mean) + + return JaxArray(hcb_call(_packed_multivariate_normal, d, + result_shape=jax.ShapeDtypeStruct(output_shape, float))) + From 4f28c97bacab0d4c048e24779c792b3218c0fb3b Mon Sep 17 00:00:00 2001 From: Xiaoyu Chen <55552143+c-xy17@users.noreply.github.com> Date: Fri, 13 May 2022 18:00:02 +0800 Subject: [PATCH 7/9] add tests for new random sampling functions in random.py --- brainpy/math/tests/test_random.py | 82 +++++++++++++++++++++++++++++-- 1 file changed, 78 insertions(+), 4 deletions(-) diff --git a/brainpy/math/tests/test_random.py b/brainpy/math/tests/test_random.py index 256b291ef..51e9eac30 100644 --- a/brainpy/math/tests/test_random.py +++ b/brainpy/math/tests/test_random.py @@ -147,7 +147,7 @@ def test_normal3(self): self.assertTupleEqual(a.shape, (2, 3)) def test_pareto(self): - a = bm.random.pareto([1, 2, 2], size=3) + a = bm.random.pareto([1, 2, 2]) self.assertTupleEqual(a.shape, (3,)) def test_poisson(self): @@ -186,6 +186,10 @@ def test_uniform2(self): and (-a + bm.array([2., 6., 10.]) > 0).all()) def test_uniform3(self): + a = bm.random.uniform(low=-1., high=[2., 6., 10.], size=(2, 3)) + self.assertTupleEqual(a.shape, (2, 3)) + + def test_uniform4(self): a = bm.random.uniform(low=[-1., 5., 2.], high=[[2., 6., 10.], [10., 10., 10.]]) self.assertTupleEqual(a.shape, (2, 3)) @@ -195,9 +199,8 @@ def test_truncated_normal1(self): self.assertTrue(-1. <= a <= 1.) def test_truncated_normal2(self): - a = bm.random.truncated_normal(-1., 1., size=(4, 3)) + a = bm.random.truncated_normal(-1., [1., 2., 1.], size=(4, 3)) self.assertTupleEqual(a.shape, (4, 3)) - self.assertTrue((a >= -1.).all() and (a <= 1.).all()) def test_truncated_normal3(self): a = bm.random.truncated_normal([-1., 0., 1.], [[2., 2., 4.], [2., 2., 4.]]) @@ -216,7 +219,7 @@ def test_bernoulli2(self): self.assertTrue(bm.logical_xor(a == 1, a == 0).all()) def test_bernoulli3(self): - a = bm.random.bernoulli(size=(3, 2)) + a = bm.random.bernoulli([0.5, 0.6], size=(3, 2)) self.assertTupleEqual(a.shape, (3, 2)) self.assertTrue(bm.logical_xor(a == 1, a == 0).all()) @@ -231,3 +234,74 @@ def test_lognormal2(self): def test_lognormal3(self): a = bm.random.lognormal([2., 0.], [[2., 1.], [3., 1.2]]) self.assertTupleEqual(a.shape, (2, 2)) + + def test_binomial1(self): + a = bm.random.binomial(5, 0.5) + self.assertIsInstance(a, bm.JaxArray) + self.assertTupleEqual(a.shape, ()) + self.assertTrue(a.dtype, int) + + def test_binomial2(self): + a = bm.random.binomial(5, 0.5, size=(3, 2)) + self.assertTupleEqual(a.shape, (3, 2)) + self.assertTrue((a >= 0).all() and (a <= 5).all()) + + def test_binomial3(self): + a = bm.random.binomial(n=[2, 3, 4], p=[[0.5, 0.5, 0.5], [0.6, 0.6, 0.6]]) + self.assertTupleEqual(a.shape, (2, 3)) + + def test_chisquare1(self): + a = bm.random.chisquare(3) + self.assertIsInstance(a, bm.JaxArray) + self.assertTupleEqual(a.shape, ()) + self.assertTrue(a.dtype, float) + + def test_chisquare2(self): + a = bm.random.chisquare(df=[2, 3, 4]) + self.assertTupleEqual(a.shape, (3,)) + + def test_dirichlet1(self): + a = bm.random.dirichlet((10, 5, 3)) + self.assertTupleEqual(a.shape, (3,)) + + def test_dirichlet2(self): + a = bm.random.dirichlet((10, 5, 3), 20) + self.assertTupleEqual(a.shape, (20, 3)) + + def test_f(self): + a = bm.random.f(1., 48., 100) + self.assertTupleEqual(a.shape, (100,)) + + def test_geometric(self): + a = bm.random.geometric([0.7, 0.5, 0.2]) + self.assertTupleEqual(a.shape, (3,)) + + def test_hypergeometric1(self): + a = bm.random.hypergeometric(10, 10, 10, 20) + self.assertTupleEqual(a.shape, (20,)) + + def test_hypergeometric2(self): + a = bm.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]]) + self.assertTupleEqual(a.shape, (2, 2)) + + def test_hypergeometric3(self): + a = bm.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2)) + self.assertTupleEqual(a.shape, (3, 2, 2)) + + def test_logseries(self): + a = bm.random.logseries([0.7, 0.5, 0.2], size=[4, 3]) + self.assertTupleEqual(a.shape, (4, 3)) + + def test_multinominal1(self): + a = bm.random.multinomial(100, (0.5, 0.2, 0.3), size=[4, 2]) + self.assertTupleEqual(a.shape, (4, 2, 3)) + + def test_multinominal2(self): + a = bm.random.multinomial(100, (0.5, 0.2, 0.3)) + self.assertTupleEqual(a.shape, (3,)) + self.assertTrue(a.sum() == 100) + + def test_multivariate_normal(self): + a = bm.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3) + self.assertTupleEqual(a.shape, (3, 2)) + From 3349b53e1280dc472fdaddc75324ce278b4c74e9 Mon Sep 17 00:00:00 2001 From: Xiaoyu Chen <55552143+c-xy17@users.noreply.github.com> Date: Fri, 13 May 2022 20:59:15 +0800 Subject: [PATCH 8/9] add tests for new random sampling functions in random.py --- brainpy/math/tests/test_random.py | 45 ++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/brainpy/math/tests/test_random.py b/brainpy/math/tests/test_random.py index 51e9eac30..055615250 100644 --- a/brainpy/math/tests/test_random.py +++ b/brainpy/math/tests/test_random.py @@ -301,7 +301,50 @@ def test_multinominal2(self): self.assertTupleEqual(a.shape, (3,)) self.assertTrue(a.sum() == 100) - def test_multivariate_normal(self): + def test_multivariate_normal1(self): a = bm.random.multivariate_normal([1, 2], [[1, 0], [0, 1]], size=3) self.assertTupleEqual(a.shape, (3, 2)) + def test_multivariate_normal2(self): + a = bm.random.multivariate_normal([1, 2], [[1, 3], [3, 1]], check_valid='ignore') + self.assertTupleEqual(a.shape, (2,)) + + def test_negative_binomial(self): + a = bm.random.negative_binomial([3., 10.], 0.5) + self.assertTupleEqual(a.shape, (2,)) + + def test_noncentral_chisquare(self): + a = bm.random.noncentral_chisquare(3, [3., 2.], (4, 2)) + self.assertTupleEqual(a.shape, (4, 2)) + + def test_noncentral_f(self): + a = bm.random.noncentral_f(3, 20, 3., 100) + self.assertTupleEqual(a.shape, (100,)) + + def test_power(self): + a = bm.random.power(2, (4, 2)) + self.assertTupleEqual(a.shape, (4, 2)) + + def test_rayleigh(self): + a = bm.random.power(2., (4, 2)) + self.assertTupleEqual(a.shape, (4, 2)) + + def test_triangular(self): + a = bm.random.triangular([-1., 0.], 1., [[2., 5.], [3., 3.]]) + self.assertTupleEqual(a.shape, (2, 2)) + + def test_vonmises(self): + a = bm.random.vonmises(2., 2.) + self.assertTupleEqual(a.shape, ()) + + def test_wald(self): + a = bm.random.wald([2., 0.5], 2.) + self.assertTupleEqual(a.shape, (2,)) + + def test_weibull(self): + a = bm.random.weibull(2., (4, 2)) + self.assertTupleEqual(a.shape, (4, 2)) + + def test_zipf(self): + a = bm.random.zipf(2., (4, 2)) + self.assertTupleEqual(a.shape, (4, 2)) From 490a1dd3ba824d589b78f8630f69990c8a482d9f Mon Sep 17 00:00:00 2001 From: Xiaoyu Chen <55552143+c-xy17@users.noreply.github.com> Date: Fri, 13 May 2022 20:59:23 +0800 Subject: [PATCH 9/9] add new random sampling functions in random.py --- brainpy/math/random.py | 103 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/brainpy/math/random.py b/brainpy/math/random.py index e900d849d..10b1c4093 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -560,3 +560,106 @@ def multivariate_normal(mean, cov, size=None, check_valid='warn', tol=1e-8): return JaxArray(hcb_call(_packed_multivariate_normal, d, result_shape=jax.ShapeDtypeStruct(output_shape, float))) + +@wraps(np.random.negative_binomial) +def negative_binomial(n, p, size=None): + if size is None: + size = np.broadcast(n, p).shape + size = _size2shape(size) + d = {'n': n, 'p': p, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.negative_binomial(n=x['n'], p=x['p'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, int))) + + +@wraps(np.random.noncentral_chisquare) +def noncentral_chisquare(df, nonc, size=None): + if size is None: + size = np.broadcast(df, nonc).shape + size = _size2shape(size) + d = {'df': df, 'nonc': nonc, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.noncentral_chisquare(df=x['df'], nonc=x['nonc'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, float))) + + +@wraps(np.random.noncentral_f) +def noncentral_f(dfnum, dfden, nonc, size=None): + if size is None: + size = np.broadcast(dfnum, dfden, nonc).shape + size = _size2shape(size) + d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'], dfden=x['dfden'], + nonc=x['nonc'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, float))) + + +@wraps(np.random.power) +def power(a, size=None): + if size is None: + size = np.shape(a) + size = _size2shape(size) + d = {'a': a, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.power(a=x['a'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, float))) + + +@wraps(np.random.rayleigh) +def rayleigh(scale=1.0, size=None): + if size is None: + size = np.shape(scale) + size = _size2shape(size) + d = {'scale': scale, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.rayleigh(scale=x['scale'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, float))) + + +@wraps(np.random.triangular) +def triangular(left, mode, right, size=None): + if size is None: + size = np.broadcast(left, mode, right).shape + size = _size2shape(size) + d = {'left': left, 'mode': mode, 'right': right, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.triangular(left=x['left'], mode=x['mode'], + right=x['right'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, float))) + + +@wraps(np.random.vonmises) +def vonmises(mu, kappa, size=None): + if size is None: + size = np.broadcast(mu, kappa).shape + size = _size2shape(size) + d = {'mu': mu, 'kappa': kappa, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.vonmises(mu=x['mu'], kappa=x['kappa'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, float))) + + +@wraps(np.random.wald) +def wald(mean, scale, size=None): + if size is None: + size = np.broadcast(mean, scale).shape + size = _size2shape(size) + d = {'mean': mean, 'scale': scale, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.wald(mean=x['mean'], scale=x['scale'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, float))) + + +@wraps(np.random.weibull) +def weibull(a, size=None): + if size is None: + size = np.shape(a) + size = _size2shape(size) + d = {'a': a, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.weibull(a=x['a'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, float))) + + +@wraps(np.random.zipf) +def zipf(a, size=None): + if size is None: + size = np.shape(a) + size = _size2shape(size) + d = {'a': a, 'size': size} + return JaxArray(hcb_call(lambda x: np.random.zipf(a=x['a'], size=x['size']), + d, result_shape=jax.ShapeDtypeStruct(size, int))) + +