outline jitted jax.random functions
We may want to continue to inline these in jaxprs somehow, but it's
useful to outline them in HLO for visualization and debugging.
froystig committed May 25, 2023
1 parent 16368bc commit 3238b62
Showing 1 changed file with 29 additions and 29 deletions.
58 changes: 29 additions & 29 deletions jax/_src/random.py
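The diff below removes `inline=True` from each sampler's `jit` decorator. As a rough sketch of what that flag controls (assuming current `jax.jit` semantics; the helper names are hypothetical, not from `jax/_src/random.py`): an inlined jitted function is merged into its caller's trace, while the default, outlined form survives as a separate named computation in the lowered HLO, which is what makes dumps easier to read.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, inline=True)
def _inlined_helper(x):
  # Merged into the caller's computation when traced.
  return jnp.sin(x) * 2.0

@jax.jit  # default inline=False: kept as a separate ("outlined") call
def _outlined_helper(x):
  return jnp.cos(x) * 2.0

@jax.jit
def caller(x):
  return _inlined_helper(x) + _outlined_helper(x)

# The lowered text should show a named sub-computation for the outlined
# helper and nothing comparable for the inlined one.
print(caller.lower(jnp.ones(3)).as_text())
```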
@@ -315,7 +315,7 @@ def uniform(key: KeyArray,
shape = core.as_named_shape(shape)
return _uniform(key, shape, dtype, minval, maxval) # type: ignore

-@partial(jit, static_argnums=(1, 2), inline=True)
+@partial(jit, static_argnums=(1, 2))
def _uniform(key, shape, dtype, minval, maxval) -> Array:
_check_shape("uniform", shape)
if not jnp.issubdtype(dtype, np.floating):
@@ -379,7 +379,7 @@ def randint(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _randint(key, shape, minval, maxval, dtype)

-@partial(jit, static_argnums=(1, 4), inline=True)
+@partial(jit, static_argnums=(1, 4))
def _randint(key, shape, minval, maxval, dtype) -> Array:
_check_shape("randint", shape, np.shape(minval), np.shape(maxval))
if not jnp.issubdtype(dtype, np.integer):
@@ -491,7 +491,7 @@ def permutation(key: KeyArray,
return jnp.take(x, ind, axis, unique_indices=True)


-@partial(jit, static_argnums=(2,), inline=True)
+@partial(jit, static_argnums=(2,))
def _shuffle(key, x, axis) -> Array:
# On parallel architectures, Fisher-Yates is more expensive than doing
# multiple sorts. This algorithm is based on one developed and analyzed by
@@ -626,7 +626,7 @@ def normal(key: KeyArray,
shape = core.as_named_shape(shape)
return _normal(key, shape, dtype) # type: ignore

-@partial(jit, static_argnums=(1, 2), inline=True)
+@partial(jit, static_argnums=(1, 2))
def _normal(key, shape, dtype) -> Array:
if dtypes.issubdtype(dtype, np.complexfloating):
sqrt2 = np.array(np.sqrt(2), dtype)
@@ -639,7 +639,7 @@ def _normal(key, shape, dtype) -> Array:
else:
return _normal_real(key, shape, dtype) # type: ignore

-@partial(jit, static_argnums=(1, 2), inline=True)
+@partial(jit, static_argnums=(1, 2))
def _normal_real(key, shape, dtype) -> Array:
_check_shape("normal", shape)
lo = np.nextafter(np.array(-1., dtype), np.array(0., dtype), dtype=dtype)
@@ -697,7 +697,7 @@ def multivariate_normal(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _multivariate_normal(key, mean, cov, shape, dtype, method) # type: ignore

-@partial(jit, static_argnums=(3, 4, 5), inline=True)
+@partial(jit, static_argnums=(3, 4, 5))
def _multivariate_normal(key, mean, cov, shape, dtype, method) -> Array:
if not np.ndim(mean) >= 1:
msg = "multivariate_normal requires mean.ndim >= 1, got mean.ndim == {}"
@@ -771,7 +771,7 @@ def truncated_normal(key: KeyArray,
shape = core.as_named_shape(shape)
return _truncated_normal(key, lower, upper, shape, dtype) # type: ignore

-@partial(jit, static_argnums=(3, 4), inline=True)
+@partial(jit, static_argnums=(3, 4))
def _truncated_normal(key, lower, upper, shape, dtype) -> Array:
if shape is None:
shape = lax.broadcast_shapes(np.shape(lower), np.shape(upper))
@@ -829,7 +829,7 @@ def bernoulli(key: KeyArray,
p = lax.convert_element_type(p, dtype)
return _bernoulli(key, p, shape) # type: ignore

-@partial(jit, static_argnums=(2,), inline=True)
+@partial(jit, static_argnums=(2,))
def _bernoulli(key, p, shape) -> Array:
if shape is None:
# TODO: Use the named part of `p` as well
@@ -930,7 +930,7 @@ def cauchy(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _cauchy(key, shape, dtype)

-@partial(jit, static_argnums=(1, 2), inline=True)
+@partial(jit, static_argnums=(1, 2))
def _cauchy(key, shape, dtype) -> Array:
_check_shape("cauchy", shape)
u = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.)
@@ -982,7 +982,7 @@ def dirichlet(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _dirichlet(key, alpha, shape, dtype)

-@partial(jit, static_argnums=(2, 3), inline=True)
+@partial(jit, static_argnums=(2, 3))
def _dirichlet(key, alpha, shape, dtype) -> Array:
if not np.ndim(alpha) >= 1:
msg = "dirichlet requires alpha.ndim >= 1, got alpha.ndim == {}"
@@ -1039,7 +1039,7 @@ def exponential(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _exponential(key, shape, dtype)

-@partial(jit, static_argnums=(1, 2), inline=True)
+@partial(jit, static_argnums=(1, 2))
def _exponential(key, shape, dtype) -> Array:
_check_shape("exponential", shape)
u = uniform(key, shape, dtype)
@@ -1266,7 +1266,7 @@ def loggamma(key: KeyArray,
return _gamma(key, a, shape=shape, dtype=dtype, log_space=True)


-@partial(jit, static_argnames=('shape', 'dtype', 'log_space'), inline=True)
+@partial(jit, static_argnames=('shape', 'dtype', 'log_space'))
def _gamma(key, a, shape, dtype, log_space=False) -> Array:
if shape is None:
shape = np.shape(a)
@@ -1279,7 +1279,7 @@ def _gamma(key, a, shape, dtype, log_space=False) -> Array:
return random_gamma_p.bind(key, a, log_space=log_space)


-@partial(jit, static_argnums=(2, 3, 4), inline=True)
+@partial(jit, static_argnums=(2, 3, 4))
def _poisson_knuth(key, lam, shape, dtype, max_iters) -> Array:
# Knuth's algorithm for generating Poisson random variates.
# Reference:
@@ -1302,7 +1302,7 @@ def cond_fn(carry):
return (k - 1).astype(dtype)


-@partial(jit, static_argnums=(2, 3, 4), inline=True)
+@partial(jit, static_argnums=(2, 3, 4))
def _poisson_rejection(key, lam, shape, dtype, max_iters) -> Array:
# Transformed rejection due to Hormann.
# Reference:
@@ -1345,7 +1345,7 @@ def cond_fn(carry):
return k.astype(dtype)


-@partial(jit, static_argnums=(2, 3), inline=True)
+@partial(jit, static_argnums=(2, 3))
def _poisson(key, lam, shape, dtype) -> Array:
# The implementation matches TensorFlow and NumPy:
# https://github.com/tensorflow/tensorflow/blob/v2.2.0-rc3/tensorflow/core/kernels/random_poisson_op.cc
@@ -1437,7 +1437,7 @@ def gumbel(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _gumbel(key, shape, dtype)

-@partial(jit, static_argnums=(1, 2), inline=True)
+@partial(jit, static_argnums=(1, 2))
def _gumbel(key, shape, dtype) -> Array:
_check_shape("gumbel", shape)
return -jnp.log(-jnp.log(
@@ -1514,7 +1514,7 @@ def laplace(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _laplace(key, shape, dtype)

-@partial(jit, static_argnums=(1, 2), inline=True)
+@partial(jit, static_argnums=(1, 2))
def _laplace(key, shape, dtype) -> Array:
_check_shape("laplace", shape)
u = uniform(
@@ -1550,7 +1550,7 @@ def logistic(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _logistic(key, shape, dtype)

-@partial(jit, static_argnums=(1, 2), inline=True)
+@partial(jit, static_argnums=(1, 2))
def _logistic(key, shape, dtype):
_check_shape("logistic", shape)
x = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.)
@@ -1593,7 +1593,7 @@ def pareto(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _pareto(key, b, shape, dtype)

-@partial(jit, static_argnums=(2, 3), inline=True)
+@partial(jit, static_argnums=(2, 3))
def _pareto(key, b, shape, dtype) -> Array:
if shape is None:
shape = np.shape(b)
@@ -1640,7 +1640,7 @@ def t(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _t(key, df, shape, dtype)

-@partial(jit, static_argnums=(2, 3), inline=True)
+@partial(jit, static_argnums=(2, 3))
def _t(key, df, shape, dtype) -> Array:
if shape is None:
shape = np.shape(df)
@@ -1693,7 +1693,7 @@ def chisquare(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _chisquare(key, df, shape, dtype)

-@partial(jit, static_argnums=(2, 3), inline=True)
+@partial(jit, static_argnums=(2, 3))
def _chisquare(key, df, shape, dtype) -> Array:
if shape is None:
shape = np.shape(df)
@@ -1750,7 +1750,7 @@ def f(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _f(key, dfnum, dfden, shape, dtype)

-@partial(jit, static_argnums=(3, 4), inline=True)
+@partial(jit, static_argnums=(3, 4))
def _f(key, dfnum, dfden, shape, dtype) -> Array:
if shape is None:
shape = lax.broadcast_shapes(np.shape(dfden), np.shape(dfnum))
@@ -1798,7 +1798,7 @@ def rademacher(key: KeyArray,
return _rademacher(key, shape, dtype)


-@partial(jit, static_argnums=(1, 2), inline=True)
+@partial(jit, static_argnums=(1, 2))
def _rademacher(key, shape, dtype) -> Array:
bernoulli_samples = bernoulli(key=key, p=0.5, shape=shape).astype(dtype)
return (2 * bernoulli_samples - 1).astype(dtype)
@@ -1836,7 +1836,7 @@ def maxwell(key: KeyArray,
return _maxwell(key, shape, dtype)


-@partial(jit, static_argnums=(1, 2), inline=True)
+@partial(jit, static_argnums=(1, 2))
def _maxwell(key, shape, dtype) -> Array:
shape = shape + (3,)
norm_rvs = normal(key=key, shape=shape, dtype=dtype)
@@ -1878,7 +1878,7 @@ def double_sided_maxwell(key: KeyArray,
return _double_sided_maxwell(key, loc, scale, shape, dtype)


-@partial(jit, static_argnums=(3, 4), inline=True)
+@partial(jit, static_argnums=(3, 4))
def _double_sided_maxwell(key, loc, scale, shape, dtype) -> Array:
params_shapes = lax.broadcast_shapes(np.shape(loc), np.shape(scale))
if not shape:
@@ -1929,7 +1929,7 @@ def weibull_min(key: KeyArray,
return _weibull_min(key, scale, concentration, shape, dtype)


-@partial(jit, static_argnums=(3, 4), inline=True)
+@partial(jit, static_argnums=(3, 4))
def _weibull_min(key, scale, concentration, shape, dtype) -> Array:
random_uniform = uniform(
key=key, shape=shape, minval=0, maxval=1, dtype=dtype)
@@ -2076,7 +2076,7 @@ def rayleigh(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _rayleigh(key, scale, shape, dtype)

-@partial(jit, static_argnums=(2, 3), inline=True)
+@partial(jit, static_argnums=(2, 3))
def _rayleigh(key, scale, shape, dtype) -> Array:
if shape is None:
shape = np.shape(scale)
@@ -2129,7 +2129,7 @@ def wald(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _wald(key, mean, shape, dtype)

-@partial(jit, static_argnums=(2, 3), inline=True)
+@partial(jit, static_argnums=(2, 3))
def _wald(key, mean, shape, dtype) -> Array:
if shape is None:
shape = np.shape(mean)
@@ -2184,7 +2184,7 @@ def geometric(key: KeyArray,
shape = core.canonicalize_shape(shape)
return _geometric(key, p, shape, dtype)

-@partial(jit, static_argnums=(2, 3), inline=True)
+@partial(jit, static_argnums=(2, 3))
def _geometric(key, p, shape, dtype) -> Array:
if shape is None:
shape = np.shape(p)
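
On the jaxpr side of the trade-off noted in the commit message, the outlined samplers now also appear as sub-calls when tracing. A quick check, hypothetical usage rather than part of this diff:

```python
import jax

key = jax.random.PRNGKey(0)
# After this change, the jaxpr should show uniform's outlined sampler as a
# nested call (e.g. a pjit sub-jaxpr) instead of its flattened primitives.
print(jax.make_jaxpr(lambda k: jax.random.uniform(k, (3,)))(key))
```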