Skip to content

Commit

Permalink
add random.triangular and random.lognormal
Browse files Browse the repository at this point in the history
add random.triangular and random.lognormal

add random.triangular and random.lognormal
  • Loading branch information
JiaYaobo committed Aug 5, 2023
1 parent 9adf319 commit 6d18445
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/jax.random.rst
Expand Up @@ -36,6 +36,7 @@ List of Available Functions
laplace
loggamma
logistic
lognormal
maxwell
multivariate_normal
normal
Expand All @@ -49,6 +50,7 @@ List of Available Functions
shuffle
split
t
triangular
truncated_normal
uniform
wald
Expand Down
111 changes: 111 additions & 0 deletions jax/_src/random.py
Expand Up @@ -2216,3 +2216,114 @@ def _geometric(key, p, shape, dtype) -> Array:
log_one_minus_p = jnp.broadcast_to(log_one_minus_p, shape)
g = lax.floor(lax.div(log_u, log_one_minus_p)) + 1
return g.astype(dtype)


def triangular(key: KeyArray,
left: RealArray,
mode: RealArray,
right: RealArray,
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
r"""Sample Triangular random values with given shape and float dtype.
The values are returned according to the probability density function:
.. math::
f(x; a, b, c) = \frac{2}{c-a} \left\{ \begin{array}{ll} \frac{x-a}{b-a} & a \leq x \leq b \\ \frac{c-x}{c-b} & b \leq x \leq c \end{array} \right.
on the domain :math:`a \leq x \leq c`.
Args:
key: a PRNG key used as the random key.
left: a float or array of floats broadcast-compatible with ``shape``
representing the lower limit parameter of the distribution.
mode: a float or array of floats broadcast-compatible with ``shape``
representing the peak value parameter of the distribution, value must
fulfill the condition ``left <= mode <= right``.
right: a float or array of floats broadcast-compatible with ``shape``
representing the upper limit parameter of the distribution, must be
larger than ``left``.
shape: optional, a tuple of nonnegative integers specifying the result
shape. Must be broadcast-compatible with ``left``,``mode`` and ``right``.
The default (None) produces a result shape equal to ``left.shape``, ``mode.shape``
and ``right.shape``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and with shape given by ``shape`` if
``shape`` is not None, or else by ``left.shape``, ``mode.shape`` and ``right.shape``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.floating):
raise ValueError("dtype argument to `triangular` must be a float "
f"dtype, got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = core.canonicalize_shape(shape)
return _triangular(key, left, mode, right, shape, dtype)

@partial(jit, static_argnums=(4, 5), inline=True)
def _triangular(key, left, mode, right, shape, dtype) -> Array:
# https://en.wikipedia.org/wiki/Triangular_distribution#Generating_triangular-distributed_random_variates
if shape is None:
shape = lax.broadcast_shapes(np.shape(left), np.shape(mode), np.shape(right))
else:
_check_shape("triangular", shape, np.shape(left), np.shape(mode), np.shape(right))
left = jnp.broadcast_to(left, shape)
mode = jnp.broadcast_to(mode, shape)
right = jnp.broadcast_to(right, shape)
fc = (mode - left) / (right - left)
u = uniform(key, shape, dtype)
out1 = left + lax.sqrt(u * (right - left) * (mode - left))
out2 = right - lax.sqrt((1 - u) * (right - left) * (right - mode))
tri = lax.select(u < fc, out1, out2)
return tri



def lognormal(key: KeyArray,
sigma: RealArray = np.float32(1),
shape: Optional[Shape] = None,
dtype: DTypeLikeFloat = dtypes.float_) -> Array:
r""" Sample lognormal random values with given shape and float dtype.
The values are distributed according to the probability density function:
.. math::
f(x) = \frac{1}{x\sqrt{2\pi\sigma^2}}\exp\left(-\frac{(\log x)^2}{2\sigma^2}\right)
on the domain :math:`x > 0`.
Args:
key: a PRNG key used as the random key.
sigma: a float or array of floats broadcast-compatible with ``shape`` representing
the standard deviation of the underlying normal distribution. Default 1.
shape: optional, a tuple of nonnegative integers specifying the result
shape. The default (None) produces a result shape equal to ``()``.
dtype: optional, a float dtype for the returned values (default float64 if
jax_enable_x64 is true, otherwise float32).
Returns:
A random array with the specified dtype and with shape given by ``shape``.
"""
key, _ = _check_prng_key(key)
dtypes.check_user_dtype_supported(dtype)
if not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to `lognormal` must be a float or complex dtype, "
f"got {dtype}")
dtype = dtypes.canonicalize_dtype(dtype)
if shape is not None:
shape = core.canonicalize_shape(shape)
return _lognormal(key, sigma, shape, dtype) # type: ignore

@partial(jit, static_argnums=(2, 3), inline=True)
def _lognormal(key, sigma, shape, dtype) -> Array:
if shape is None:
shape = np.shape(sigma)
else:
_check_shape("triangular", shape, np.shape(sigma))
sigma = jnp.broadcast_to(sigma, shape)
scaled_norm = normal(key, shape, dtype) * sigma
return lax.exp(scaled_norm)
2 changes: 2 additions & 0 deletions jax/random.py
Expand Up @@ -175,6 +175,7 @@
laplace as laplace,
logistic as logistic,
loggamma as loggamma,
lognormal as lognormal,
maxwell as maxwell,
multivariate_normal as multivariate_normal,
normal as normal,
Expand All @@ -193,6 +194,7 @@
threefry_2x32 as threefry_2x32,
threefry2x32_key as threefry2x32_key,
threefry2x32_p as threefry2x32_p,
triangular as triangular,
truncated_normal as truncated_normal,
uniform as uniform,
unsafe_rbg_key as unsafe_rbg_key,
Expand Down
30 changes: 30 additions & 0 deletions tests/random_test.py
Expand Up @@ -1653,6 +1653,36 @@ def testGeometric(self, p, dtype):
self.assertAllClose(samples.mean(), 1 / p, rtol=0.02, check_dtypes=False)
self.assertAllClose(samples.var(), (1 - p) / (p * p) , rtol=0.05, check_dtypes=False)

@jtu.sample_product(
left = [0.2, 0.5, 1., 2.],
mode = [3., 5., 8., 9.],
right= [10., 20., 30., 40.],
dtype= jtu.dtypes.floating)
def testTriangular(self, left, mode, right, dtype):
key = self.make_key(1)
rand = lambda key: random.triangular(key, left, mode, right, shape=(10000, ), dtype=dtype)
crand = jax.jit(rand)

uncompiled_samples = rand(key)
compiled_samples = crand(key)

for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.triang((mode - left) / (right - left), loc=left, scale=right - left).cdf)

@jtu.sample_product(
sigma = [0.2, 0.5, 1., 2.],
dtype=jtu.dtypes.floating)
def testLogNormal(self, sigma, dtype):
key = self.make_key(0)
rand = lambda key: random.lognormal(key, sigma, shape=(10000, ), dtype=dtype)
crand = jax.jit(rand)

uncompiled_samples = rand(key)
compiled_samples = crand(key)

for samples in [uncompiled_samples, compiled_samples]:
self._CheckKolmogorovSmirnovCDF(samples, scipy.stats.lognorm(s=sigma).cdf)

@parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS])
def test_copy(self, make_key):
key = make_key(8459302)
Expand Down

0 comments on commit 6d18445

Please sign in to comment.