From 6d184459ef91bbf0ae1d63c12db855e3359c387f Mon Sep 17 00:00:00 2001 From: jiayaobo Date: Fri, 4 Aug 2023 11:06:32 +0800 Subject: [PATCH] add random.triangular and random.lognormal add random.triangular and random.lognormal add random.triangular and random.lognormal --- docs/jax.random.rst | 2 + jax/_src/random.py | 111 +++++++++++++++++++++++++++++++++++++++++++ jax/random.py | 2 + tests/random_test.py | 30 ++++++++++++ 4 files changed, 145 insertions(+) diff --git a/docs/jax.random.rst b/docs/jax.random.rst index d1878ceb56b0..24793f97abe2 100644 --- a/docs/jax.random.rst +++ b/docs/jax.random.rst @@ -36,6 +36,7 @@ List of Available Functions laplace loggamma logistic + lognormal maxwell multivariate_normal normal @@ -49,6 +50,7 @@ List of Available Functions shuffle split t + triangular truncated_normal uniform wald diff --git a/jax/_src/random.py b/jax/_src/random.py index 6928fb34ced9..663c18a194a9 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -2216,3 +2216,114 @@ def _geometric(key, p, shape, dtype) -> Array: log_one_minus_p = jnp.broadcast_to(log_one_minus_p, shape) g = lax.floor(lax.div(log_u, log_one_minus_p)) + 1 return g.astype(dtype) + + +def triangular(key: KeyArray, + left: RealArray, + mode: RealArray, + right: RealArray, + shape: Optional[Shape] = None, + dtype: DTypeLikeFloat = dtypes.float_) -> Array: + r"""Sample Triangular random values with given shape and float dtype. + + The values are returned according to the probability density function: + + .. math:: + f(x; a, b, c) = \frac{2}{c-a} \left\{ \begin{array}{ll} \frac{x-a}{b-a} & a \leq x \leq b \\ \frac{c-x}{c-b} & b \leq x \leq c \end{array} \right. + + on the domain :math:`a \leq x \leq c`. + + Args: + key: a PRNG key used as the random key. + left: a float or array of floats broadcast-compatible with ``shape`` + representing the lower limit parameter of the distribution. + mode: a float or array of floats broadcast-compatible with ``shape`` + representing the peak value parameter of the distribution, value must + fulfill the condition ``left <= mode <= right``. + right: a float or array of floats broadcast-compatible with ``shape`` + representing the upper limit parameter of the distribution, must be + larger than ``left``. + shape: optional, a tuple of nonnegative integers specifying the result + shape. Must be broadcast-compatible with ``left``,``mode`` and ``right``. + The default (None) produces a result shape equal to ``left.shape``, ``mode.shape`` + and ``right.shape``. + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). + + Returns: + A random array with the specified dtype and with shape given by ``shape`` if + ``shape`` is not None, or else by ``left.shape``, ``mode.shape`` and ``right.shape``. + """ + key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) + if not dtypes.issubdtype(dtype, np.floating): + raise ValueError("dtype argument to `triangular` must be a float " + f"dtype, got {dtype}") + dtype = dtypes.canonicalize_dtype(dtype) + if shape is not None: + shape = core.canonicalize_shape(shape) + return _triangular(key, left, mode, right, shape, dtype) + +@partial(jit, static_argnums=(4, 5), inline=True) +def _triangular(key, left, mode, right, shape, dtype) -> Array: + # https://en.wikipedia.org/wiki/Triangular_distribution#Generating_triangular-distributed_random_variates + if shape is None: + shape = lax.broadcast_shapes(np.shape(left), np.shape(mode), np.shape(right)) + else: + _check_shape("triangular", shape, np.shape(left), np.shape(mode), np.shape(right)) + left = jnp.broadcast_to(left, shape) + mode = jnp.broadcast_to(mode, shape) + right = jnp.broadcast_to(right, shape) + fc = (mode - left) / (right - left) + u = uniform(key, shape, dtype) + out1 = left + lax.sqrt(u * (right - left) * (mode - left)) + out2 = right - lax.sqrt((1 - u) * (right - left) * (right - mode)) + tri = lax.select(u < fc, out1, out2) + return tri + + + +def lognormal(key: KeyArray, + sigma: RealArray = np.float32(1), + shape: Optional[Shape] = None, + dtype: DTypeLikeFloat = dtypes.float_) -> Array: + r""" Sample lognormal random values with given shape and float dtype. + + The values are distributed according to the probability density function: + + .. math:: + f(x) = \frac{1}{x\sqrt{2\pi\sigma^2}}\exp\left(-\frac{(\log x)^2}{2\sigma^2}\right) + + on the domain :math:`x > 0`. + + Args: + key: a PRNG key used as the random key. + sigma: a float or array of floats broadcast-compatible with ``shape`` representing + the standard deviation of the underlying normal distribution. Default 1. + shape: optional, a tuple of nonnegative integers specifying the result + shape. The default (None) produces a result shape equal to ``()``. + dtype: optional, a float dtype for the returned values (default float64 if + jax_enable_x64 is true, otherwise float32). + + Returns: + A random array with the specified dtype and with shape given by ``shape``. + """ + key, _ = _check_prng_key(key) + dtypes.check_user_dtype_supported(dtype) + if not dtypes.issubdtype(dtype, np.inexact): + raise ValueError(f"dtype argument to `lognormal` must be a float or complex dtype, " + f"got {dtype}") + dtype = dtypes.canonicalize_dtype(dtype) + if shape is not None: + shape = core.canonicalize_shape(shape) + return _lognormal(key, sigma, shape, dtype) # type: ignore + +@partial(jit, static_argnums=(2, 3), inline=True) +def _lognormal(key, sigma, shape, dtype) -> Array: + if shape is None: + shape = np.shape(sigma) + else: + _check_shape("triangular", shape, np.shape(sigma)) + sigma = jnp.broadcast_to(sigma, shape) + scaled_norm = normal(key, shape, dtype) * sigma + return lax.exp(scaled_norm) diff --git a/jax/random.py b/jax/random.py index e86e9b276dd5..c64db163f3a6 100644 --- a/jax/random.py +++ b/jax/random.py @@ -175,6 +175,7 @@ laplace as laplace, logistic as logistic, loggamma as loggamma, + lognormal as lognormal, maxwell as maxwell, multivariate_normal as multivariate_normal, normal as normal, @@ -193,6 +194,7 @@ threefry_2x32 as threefry_2x32, threefry2x32_key as threefry2x32_key, threefry2x32_p as threefry2x32_p, + triangular as triangular, truncated_normal as truncated_normal, uniform as uniform, unsafe_rbg_key as unsafe_rbg_key, diff --git a/tests/random_test.py b/tests/random_test.py index bb74e0b9d2ca..5c52780ce59d 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -1653,6 +1653,36 @@ def testGeometric(self, p, dtype): self.assertAllClose(samples.mean(), 1 / p, rtol=0.02, check_dtypes=False) self.assertAllClose(samples.var(), (1 - p) / (p * p) , rtol=0.05, check_dtypes=False) + @jtu.sample_product( + left = [0.2, 0.5, 1., 2.], + mode = [3., 5., 8., 9.], + right= [10., 20., 30., 40.], + dtype= jtu.dtypes.floating) + def testTriangular(self, left, mode, right, dtype): + key = self.make_key(1) + rand = lambda key: random.triangular(key, left, mode, right, shape=(10000, ), dtype=dtype) + crand = jax.jit(rand) + + uncompiled_samples = rand(key) + compiled_samples = crand(key) + + for samples in [uncompiled_samples, compiled_samples]: + self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.triang((mode - left) / (right - left), loc=left, scale=right - left).cdf) + + @jtu.sample_product( + sigma = [0.2, 0.5, 1., 2.], + dtype=jtu.dtypes.floating) + def testLogNormal(self, sigma, dtype): + key = self.make_key(0) + rand = lambda key: random.lognormal(key, sigma, shape=(10000, ), dtype=dtype) + crand = jax.jit(rand) + + uncompiled_samples = rand(key) + compiled_samples = crand(key) + + for samples in [uncompiled_samples, compiled_samples]: + self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.lognorm(s=sigma).cdf) + @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def test_copy(self, make_key): key = make_key(8459302)