From f7731bf95990865380f84e14400d4cbdb5d51f2f Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Mon, 7 Mar 2022 12:25:01 -0800 Subject: [PATCH] remove `_const` from public `jax.lax` module Modify all internal call sites to use `jax._src.lax.lax._const`. --- jax/_src/lax/linalg.py | 10 ++- jax/_src/numpy/lax_numpy.py | 119 +++++++++++++++--------------- jax/_src/numpy/linalg.py | 3 +- jax/_src/prng.py | 4 +- jax/_src/random.py | 31 ++++---- jax/_src/scipy/special.py | 37 +++++----- jax/_src/scipy/stats/bernoulli.py | 5 +- jax/_src/scipy/stats/beta.py | 3 +- jax/_src/scipy/stats/betabinom.py | 5 +- jax/_src/scipy/stats/cauchy.py | 3 +- jax/_src/scipy/stats/chi2.py | 5 +- jax/_src/scipy/stats/dirichlet.py | 3 +- jax/_src/scipy/stats/gamma.py | 3 +- jax/_src/scipy/stats/geom.py | 5 +- jax/_src/scipy/stats/laplace.py | 9 ++- jax/_src/scipy/stats/nbinom.py | 3 +- jax/_src/scipy/stats/norm.py | 5 +- jax/_src/scipy/stats/pareto.py | 3 +- jax/_src/scipy/stats/poisson.py | 5 +- jax/_src/scipy/stats/t.py | 7 +- jax/experimental/jet.py | 24 +++--- jax/lax/__init__.py | 2 +- tests/lax_test.py | 15 ++-- 23 files changed, 171 insertions(+), 138 deletions(-) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 83b14a53d5de..897b2fd990e3 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -323,8 +323,9 @@ def cholesky_jvp_rule(primals, tangents): # Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf def phi(X): l = jnp.tril(X) - return l / lax.expand_dims(lax._const(X, 1) + jnp.eye(X.shape[-1], dtype=X.dtype), - range(l.ndim - 2)) + return l / lax.expand_dims( + lax_internal._const(X, 1) + jnp.eye(X.shape[-1], dtype=X.dtype), + range(l.ndim - 2)) tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True, conjugate_a=True, lower=True) @@ -991,7 +992,7 @@ def _lu_jvp_rule(primals, tangents): ndims = len(a_shape) l_padding = [(0, 0, 0)] * ndims l_padding[-1] = (0, m - k, 0) - zero = lax._const(lu, 0) + zero = lax_internal._const(lu, 0) l = lax.pad(jnp.tril(lu[..., :, :k], -1), zero, l_padding) l = l + lax.expand_dims(jnp.eye(m, m, dtype=dtype), range(l.ndim - 2)) @@ -999,7 +1000,8 @@ def _lu_jvp_rule(primals, tangents): ((k, 0, 0), (k, 0, 0))) u_padding = [(0, 0, 0)] * ndims u_padding[-2] = (0, n - k, 0) - u = lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) + lax.expand_dims(u_eye, range(lu.ndim - 2)) + u = (lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) + + lax.expand_dims(u_eye, range(lu.ndim - 2))) la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True, unit_diagonal=True) diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index a203d2c0e866..49c6a339d6f9 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -51,6 +51,7 @@ from jax import lax from jax._src import device_array from jax._src.lax.lax import _array_copy, _sort_lt_comparator, _sort_le_comparator +from jax._src.lax import lax as lax_internal from jax._src.ops import scatter from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio, canonicalize_axis as _canonicalize_axis, maybe_named_axis) @@ -202,6 +203,8 @@ def _jnp_dtype(obj, align=False, copy=False): 64: np.int64, } +_lax_const = lax_internal._const + def _promote_shapes(fun_name, *args): """Apply NumPy-style broadcasting, making args shape-compatible for lax.py.""" if len(args) < 2: @@ -356,8 +359,8 @@ def _convert_and_clip_integer(val, dtype): # This happens in X32 mode and can either come from a jax value created in another # context, or a Python integer converted to int64. pass - min_val = lax._const(val, _max(iinfo(dtype).min, iinfo(val_dtype).min)) - max_val = lax._const(val, _min(iinfo(dtype).max, iinfo(val_dtype).max)) + min_val = _lax_const(val, _max(iinfo(dtype).min, iinfo(val_dtype).min)) + max_val = _lax_const(val, _min(iinfo(dtype).max, iinfo(val_dtype).max)) return clip(val, min_val, max_val).astype(dtype) @@ -567,7 +570,7 @@ def sign(x): if issubdtype(dtype, complexfloating): re = lax.real(x) return lax.complex( - lax.sign(where(re != 0, re, lax.imag(x))), lax._const(re, 0)) + lax.sign(where(re != 0, re, lax.imag(x))), _lax_const(re, 0)) return lax.sign(x) @@ -604,8 +607,8 @@ def floor_divide(x1, x2): x2r = lax.real(x2) x2i = lax.imag(x2) which = lax.ge(lax.abs(x2r), lax.abs(x2i)) - rat1 = where(which, lax._const(x2i, 1), lax.div(x2r, x2i)) - rat2 = where(which, lax.div(x2i, x2r), lax._const(x2i, 1)) + rat1 = where(which, _lax_const(x2i, 1), lax.div(x2r, x2i)) + rat2 = where(which, lax.div(x2i, x2r), _lax_const(x2i, 1)) out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)), lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2)))) return lax.convert_element_type(out, dtype) @@ -630,7 +633,7 @@ def _float_divmod(x1, x2): ind = lax.bitwise_and(mod != 0, lax.sign(x2) != lax.sign(mod)) mod = lax.select(ind, mod + x2, mod) - div = lax.select(ind, div - lax._const(div, 1), div) + div = lax.select(ind, div - _lax_const(div, 1), div) return lax.round(div), mod @@ -646,8 +649,8 @@ def _power(x1, x2): # TODO(phawkins): add integer pow support to XLA. bits = 6 # Anything more would overflow for any x1 > 1 - zero = lax._const(x2, 0) - one = lax._const(x2, 1) + zero = _lax_const(x2, 0) + one = _lax_const(x2, 1) # Initialize acc carefully such that pow(0, x2) is zero for x2 != 0 acc = where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one) for _ in range(bits): @@ -683,15 +686,15 @@ def logaddexp(x1, x2): lax.add(x1, x2), # NaNs or infinities of the same sign. lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))) else: - delta = lax.sub(lax.add(x1, x2), lax.mul(amax, lax._const(amax, 2))) + delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _lax_const(amax, 2))) out = lax.add(amax, lax.log1p(lax.exp(delta))) return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi)) def _wrap_between(x, _a): """Wraps `x` between `[-a, a]`.""" - a = lax._const(x, _a) - two_a = lax._const(x, 2 * _a) - zero = lax._const(x, 0) + a = _lax_const(x, _a) + two_a = _lax_const(x, 2 * _a) + zero = _lax_const(x, 0) rem = lax.rem(lax.add(x, a), two_a) rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem) return lax.sub(rem, a) @@ -721,10 +724,10 @@ def logaddexp2(x1, x2): return lax.select(isnan(delta), lax.add(x1, x2), # NaNs or infinities of the same sign. lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))), - lax._const(x1, np.log(2))))) + _lax_const(x1, np.log(2))))) else: - delta = lax.sub(lax.add(x1, x2), lax.mul(amax, lax._const(amax, 2))) - out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), lax._const(x1, np.log(2)))) + delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _lax_const(amax, 2))) + out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _lax_const(x1, np.log(2)))) return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2))) @logaddexp2.defjvp @@ -742,21 +745,21 @@ def _logaddexp2_jvp(primals, tangents): @partial(jit, inline=True) def log2(x): x, = _promote_args_inexact("log2", x) - return lax.div(lax.log(x), lax.log(lax._const(x, 2))) + return lax.div(lax.log(x), lax.log(_lax_const(x, 2))) @_wraps(np.log10) @partial(jit, inline=True) def log10(x): x, = _promote_args_inexact("log10", x) - return lax.div(lax.log(x), lax.log(lax._const(x, 10))) + return lax.div(lax.log(x), lax.log(_lax_const(x, 10))) @_wraps(np.exp2) @partial(jit, inline=True) def exp2(x): x, = _promote_args_inexact("exp2", x) - return lax.exp(lax.mul(lax.log(lax._const(x, 2)), x)) + return lax.exp(lax.mul(lax.log(_lax_const(x, 2)), x)) @_wraps(np.signbit) @jit @@ -764,7 +767,7 @@ def signbit(x): x, = _promote_args("signbit", x) dtype = _dtype(x) if issubdtype(dtype, integer): - return lax.lt(x, lax._const(x, 0)) + return lax.lt(x, _lax_const(x, 0)) elif issubdtype(dtype, bool_): return full_like(x, False, dtype=bool_) elif not issubdtype(dtype, floating): @@ -803,7 +806,7 @@ def trapz(y, x=None, dx=1.0, axis: int = -1): @jit def trunc(x): _check_arraylike('trunc', x) - return where(lax.lt(x, lax._const(x, 0)), ceil(x), floor(x)) + return where(lax.lt(x, _lax_const(x, 0)), ceil(x), floor(x)) @partial(jit, static_argnums=(2, 3, 4)) @@ -856,8 +859,8 @@ def correlate(a, v, mode='valid', *, precision=None): def _normalize_float(x): info = finfo(_dtype(x)) cond = lax.abs(x) < info.tiny - x1 = where(cond, x * lax._const(x, 1 << info.nmant), x) - x2 = where(cond, lax._const(np.int32, -info.nmant), lax._const(np.int32, 0)) + x1 = where(cond, x * _lax_const(x, 1 << info.nmant), x) + x2 = where(cond, _lax_const(np.int32, -info.nmant), _lax_const(np.int32, 0)) int_type = _INT_DTYPES[info.bits] return lax.bitcast_convert_type(x1, int_type), x2 @@ -934,7 +937,7 @@ def frexp(x): @jit def remainder(x1, x2): x1, x2 = _promote_args("remainder", x1, x2) - zero = lax._const(x1, 0) + zero = _lax_const(x1, 0) trunc_mod = lax.rem(x1, x2) trunc_mod_not_zero = lax.ne(trunc_mod, zero) do_plus = lax.bitwise_and( @@ -963,14 +966,14 @@ def square(x): @partial(jit, inline=True) def deg2rad(x): x, = _promote_args_inexact("deg2rad", x) - return lax.mul(x, lax._const(x, pi / 180)) + return lax.mul(x, _lax_const(x, pi / 180)) @_wraps(np.rad2deg) @partial(jit, inline=True) def rad2deg(x): x, = _promote_args_inexact("rad2deg", x) - return lax.mul(x, lax._const(x, 180 / pi)) + return lax.mul(x, _lax_const(x, 180 / pi)) degrees = rad2deg @@ -1087,9 +1090,9 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=None): def heaviside(x1, x2): _check_arraylike("heaviside", x1, x2) x1, x2 = _promote_dtypes_inexact(x1, x2) - zero = lax._const(x1, 0) + zero = _lax_const(x1, 0) return where(lax.lt(x1, zero), zero, - where(lax.gt(x1, zero), lax._const(x1, 1), x2)) + where(lax.gt(x1, zero), _lax_const(x1, 1), x2)) @_wraps(np.hypot) @@ -1116,9 +1119,9 @@ def reciprocal(x): def sinc(x): _check_arraylike("sinc", x) x, = _promote_dtypes_inexact(x) - eq_zero = lax.eq(x, lax._const(x, 0)) - pi_x = lax.mul(lax._const(x, pi), x) - safe_pi_x = where(eq_zero, lax._const(x, 1), pi_x) + eq_zero = lax.eq(x, _lax_const(x, 0)) + pi_x = lax.mul(_lax_const(x, pi), x) + safe_pi_x = where(eq_zero, _lax_const(x, 1), pi_x) return where(eq_zero, _sinc_maclaurin(0, pi_x), lax.div(lax.sin(safe_pi_x), safe_pi_x)) @@ -1220,13 +1223,13 @@ def real(val): @jit def iscomplex(x): i = imag(x) - return lax.ne(i, lax._const(i, 0)) + return lax.ne(i, _lax_const(i, 0)) @_wraps(np.isreal) @jit def isreal(x): i = imag(x) - return lax.eq(i, lax._const(i, 0)) + return lax.eq(i, _lax_const(i, 0)) @_wraps(np.angle) @partial(jit, static_argnames=['deg']) @@ -2095,7 +2098,7 @@ def _round_float(x): # end due to precision problems. As a workaround for float16, convert to # float32, x = lax.convert_element_type(x, np.float32) if dtype == np.float16 else x - factor = lax._const(x, 10 ** decimals) + factor = _lax_const(x, 10 ** decimals) out = lax.div(lax.round(lax.mul(x, factor), lax.RoundingMethod.TO_NEAREST_EVEN), factor) return lax.convert_element_type(out, dtype) if dtype == np.float16 else out @@ -2114,7 +2117,7 @@ def fix(x, out=None): _check_arraylike("fix", x) if out is not None: raise NotImplementedError("The 'out' argument to jnp.fix is not supported.") - zero = lax._const(x, 0) + zero = _lax_const(x, 0) return where(lax.ge(x, zero), floor(x), ceil(x)) @@ -2146,12 +2149,12 @@ def isinf(x): _check_arraylike("isinf", x) dtype = _dtype(x) if issubdtype(dtype, floating): - return lax.eq(lax.abs(x), lax._const(x, inf)) + return lax.eq(lax.abs(x), _lax_const(x, inf)) elif issubdtype(dtype, complexfloating): re = lax.real(x) im = lax.imag(x) - return lax.bitwise_or(lax.eq(lax.abs(re), lax._const(re, inf)), - lax.eq(lax.abs(im), lax._const(im, inf))) + return lax.bitwise_or(lax.eq(lax.abs(re), _lax_const(re, inf)), + lax.eq(lax.abs(im), _lax_const(im, inf))) else: return full_like(x, False, dtype=bool_) @@ -2160,7 +2163,7 @@ def _isposneginf(infinity, x, out): raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.") dtype = _dtype(x) if issubdtype(dtype, floating): - return lax.eq(x, lax._const(x, infinity)) + return lax.eq(x, _lax_const(x, infinity)) elif issubdtype(dtype, complexfloating): raise ValueError("isposinf/isneginf are not well defined for complex types") else: @@ -2580,7 +2583,7 @@ def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, keepdims=False): _check_arraylike("count_nonzero", a) - return sum(lax.ne(a, lax._const(a, 0)), axis=axis, + return sum(lax.ne(a, _lax_const(a, 0)), axis=axis, dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims) @@ -2637,7 +2640,7 @@ def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan, axis=axis, keepdims=keepdims, **kwargs) if nan_if_all_nan: return where(all(isnan(a), axis=axis, keepdims=keepdims), - lax._const(a, nan), out) + _lax_const(a, nan), out) else: return out @@ -2762,7 +2765,7 @@ def _cumulative_reduction(a, axis = _canonicalize_axis(axis, num_dims) if fill_nan: - a = where(isnan(a), lax._const(a, fill_value), a) + a = where(isnan(a), _lax_const(a, fill_value), a) if not dtype and _dtype(a) == bool_: dtype = int_ @@ -4973,7 +4976,7 @@ def vander(x, N=None, increasing=False): iota = lax.iota(x.dtype, N) if not increasing: - iota = lax.sub(lax._const(iota, N - 1), iota) + iota = lax.sub(_lax_const(iota, N - 1), iota) return power(x[..., None], expand_dims(iota, tuple(range(x.ndim)))) @@ -5277,7 +5280,7 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None): # TODO(phawkins): we have no way to report out of bounds errors yet. raise NotImplementedError("The 'raise' mode to jnp.take is not supported.") elif mode == "wrap": - indices = mod(indices, lax._const(indices, a.shape[axis_idx])) + indices = mod(indices, _lax_const(indices, a.shape[axis_idx])) gather_mode = lax.GatherScatterMode.PROMISE_IN_BOUNDS elif mode == "fill": # Undocumented non-standard mode corresponding to the fill_or_drop mode on @@ -5317,12 +5320,12 @@ def _take(a, indices, axis: Optional[int] = None, out=None, mode=None): def _normalize_index(index, axis_size): """Normalizes an index value in the range [-N, N) to the range [0, N).""" if core.is_constant_dim(axis_size): - axis_size_val = lax._const(index, axis_size) + axis_size_val = _lax_const(index, axis_size) else: axis_size_val = lax.convert_element_type(core.dimension_as_value(axis_size), _dtype(index)) return lax.select( - lax.lt(index, lax._const(index, 0)), + lax.lt(index, _lax_const(index, 0)), lax.add(index, axis_size_val), index) @@ -5407,7 +5410,7 @@ def _unique_sorted_mask(ar, axis): # Work around issue in sorting of complex numbers with Nan only in the # imaginary component. This can be removed if sorting in this situation # is fixed to match numpy. - aux = where(isnan(aux), lax._const(aux, nan), aux) + aux = where(isnan(aux), _lax_const(aux, nan), aux) size, *out_shape = aux.shape if _prod(out_shape) == 0: size = 1 @@ -5998,7 +6001,7 @@ def _gcd_cond_fn(xs): def _gcd_body_fn(xs): x1, x2 = xs x1, x2 = (where(x2 != 0, x2, x1), - where(x2 != 0, lax.rem(x1, x2), lax._const(x2, 0))) + where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0))) return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2)) @_wraps(np.gcd) @@ -6020,7 +6023,7 @@ def lcm(x1, x2): _check_arraylike("lcm", x1, x2) x1, x2 = _promote_dtypes(x1, x2) d = gcd(x1, x2) - return where(d == 0, lax._const(d, 0), + return where(d == 0, _lax_const(d, 0), abs(multiply(x1, floor_divide(x2, d)))) @@ -6218,14 +6221,14 @@ def _quantile(a, q, axis, interpolation, keepdims, squash_nans): q = lax.expand_dims( q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim))) counts = lax.expand_dims(counts, tuple(range(q_ndim))) - q = lax.mul(q, lax.sub(counts, lax._const(q, 1))) + q = lax.mul(q, lax.sub(counts, _lax_const(q, 1))) low = lax.floor(q) high = lax.ceil(q) high_weight = lax.sub(q, low) - low_weight = lax.sub(lax._const(high_weight, 1), high_weight) + low_weight = lax.sub(_lax_const(high_weight, 1), high_weight) - low = lax.max(lax._const(low, 0), lax.min(low, counts - 1)) - high = lax.max(lax._const(high, 0), lax.min(high, counts - 1)) + low = lax.max(_lax_const(low, 0), lax.min(low, counts - 1)) + high = lax.max(_lax_const(high, 0), lax.min(high, counts - 1)) low = lax.convert_element_type(low, int64) high = lax.convert_element_type(high, int64) out_shape = q_shape + shape_after_reduction @@ -6242,14 +6245,14 @@ def _quantile(a, q, axis, interpolation, keepdims, squash_nans): a = where(any(isnan(a), axis=axis, keepdims=True), nan, a) a = lax.sort(a, dimension=axis) n = a_shape[axis] - q = lax.mul(q, lax._const(q, n - 1)) + q = lax.mul(q, _lax_const(q, n - 1)) low = lax.floor(q) high = lax.ceil(q) high_weight = lax.sub(q, low) - low_weight = lax.sub(lax._const(high_weight, 1), high_weight) + low_weight = lax.sub(_lax_const(high_weight, 1), high_weight) - low = lax.clamp(lax._const(low, 0), low, lax._const(low, n - 1)) - high = lax.clamp(lax._const(high, 0), high, lax._const(high, n - 1)) + low = lax.clamp(_lax_const(low, 0), low, _lax_const(low, n - 1)) + high = lax.clamp(_lax_const(high, 0), high, _lax_const(high, n - 1)) low = lax.convert_element_type(low, int64) high = lax.convert_element_type(high, int64) @@ -6279,10 +6282,10 @@ def _quantile(a, q, axis, interpolation, keepdims, squash_nans): elif interpolation == "higher": result = high_value elif interpolation == "nearest": - pred = lax.le(high_weight, lax._const(high_weight, 0.5)) + pred = lax.le(high_weight, _lax_const(high_weight, 0.5)) result = lax.select(pred, low_value, high_value) elif interpolation == "midpoint": - result = lax.mul(lax.add(low_value, high_value), lax._const(low_value, 0.5)) + result = lax.mul(lax.add(low_value, high_value), _lax_const(low_value, 0.5)) else: raise ValueError(f"interpolation={interpolation!r} not recognized") if keepdims and keepdim: diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 6c5e3989d141..2005d91d6dd9 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -22,6 +22,7 @@ from jax import jit, custom_jvp from jax import lax +from jax._src.lax import lax as lax_internal from jax._src.lax import linalg as lax_linalg from jax._src import dtypes from jax._src.numpy.util import _wraps @@ -434,7 +435,7 @@ def norm(x, ord=None, axis : Union[None, Tuple[int, ...], int] = None, return jnp.sum(jnp.abs(x), axis=axis, keepdims=keepdims) else: abs_x = jnp.abs(x) - ord = lax._const(abs_x, ord) + ord = lax_internal._const(abs_x, ord) out = jnp.sum(abs_x ** ord, axis=axis, keepdims=keepdims) return jnp.power(out, 1. / ord) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index e1922195f9a5..65870f45140b 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -28,6 +28,7 @@ from jax.interpreters import batching from jax.interpreters import xla from jax._src.api import jit, vmap +from jax._src.lax import lax as lax_internal from jax._src.lib import xla_client from jax._src.lib import cuda_prng from jax._src.numpy.lax_numpy import ( @@ -264,7 +265,8 @@ def threefry_seed(seed: int) -> jnp.ndarray: raise TypeError(f"PRNG key seed must be an integer; got {seed!r}") convert = lambda k: lax.reshape(lax.convert_element_type(k, np.uint32), [1]) - k1 = convert(lax.shift_right_logical(seed_arr, lax._const(seed_arr, 32))) + k1 = convert( + lax.shift_right_logical(seed_arr, lax_internal._const(seed_arr, 32))) k2 = convert(jnp.bitwise_and(seed_arr, np.uint32(0xFFFFFFFF))) return lax.concatenate([k1, k2], 0) diff --git a/jax/_src/random.py b/jax/_src/random.py index 9da425f70627..171eef524e8e 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -27,8 +27,9 @@ from jax.config import config from jax.core import NamedShape from jax._src.api import jit, vmap -from jax._src.numpy.lax_numpy import _arraylike, _check_arraylike, _convert_and_clip_integer +from jax._src.lax import lax as lax_internal from jax._src.lib import xla_bridge +from jax._src.numpy.lax_numpy import _arraylike, _check_arraylike, _convert_and_clip_integer from jax.numpy.linalg import cholesky, svd, eigh from jax.interpreters import ad from jax.interpreters import batching @@ -52,6 +53,8 @@ ### utilities +_lax_const = lax_internal._const + def _isnan(x): return lax.ne(x, x) @@ -336,14 +339,14 @@ def _randint(key, shape, minval, maxval, dtype): # causing remainders below to have no effect, which is the correct semantics. span = lax.select( maxval_out_of_range & (maxval > minval), - lax.add(span, lax._const(span, 1)), + lax.add(span, _lax_const(span, 1)), span) # To compute a remainder operation on an integer that might have twice as many # bits as we can represent in the native unsigned dtype, we compute a # multiplier equal to 2**nbits % span. To avoid overflow, we use the identity: # (a * b) % N = [(a % N) * (b % N)] % N - multiplier = lax.rem(lax._const(span, 2 ** (nbits // 2)), span) + multiplier = lax.rem(_lax_const(span, 2 ** (nbits // 2)), span) multiplier = lax.rem(lax.mul(multiplier, multiplier), span) random_offset = lax.add(lax.mul(lax.rem(higher_bits, span), multiplier), @@ -789,8 +792,8 @@ def cauchy(key: KeyArray, def _cauchy(key, shape, dtype): _check_shape("cauchy", shape) u = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.) - pi = lax._const(u, np.pi) - return lax.tan(lax.mul(pi, lax.sub(u, lax._const(u, 0.5)))) + pi = _lax_const(u, np.pi) + return lax.tan(lax.mul(pi, lax.sub(u, _lax_const(u, 0.5)))) def dirichlet(key: KeyArray, @@ -876,12 +879,12 @@ def _gamma_one(key: KeyArray, alpha): # Ref: A simple method for generating gamma variables, George Marsaglia and Wai Wan Tsang # The algorithm can also be founded in: # https://en.wikipedia.org/wiki/Gamma_distribution#Generating_gamma-distributed_random_variables - zero = lax._const(alpha, 0) - one = lax._const(alpha, 1) - minus_one = lax._const(alpha, -1) - one_over_two = lax._const(alpha, 0.5) - one_over_three = lax._const(alpha, 1. / 3.) - squeeze_const = lax._const(alpha, 0.0331) + zero = _lax_const(alpha, 0) + one = _lax_const(alpha, 1) + minus_one = _lax_const(alpha, -1) + one_over_two = _lax_const(alpha, 0.5) + one_over_three = _lax_const(alpha, 1. / 3.) + squeeze_const = _lax_const(alpha, 0.0331) dtype = lax.dtype(alpha) key, subkey = _split(key) @@ -923,7 +926,7 @@ def _next_kxv(kxv): return key, X, V, U # initial state is chosen such that _cond_fn will return True - _, _, V, _ = lax.while_loop(_cond_fn, _body_fn, (key, zero, one, lax._const(alpha, 2))) + _, _, V, _ = lax.while_loop(_cond_fn, _body_fn, (key, zero, one, _lax_const(alpha, 2))) z = lax.mul(lax.mul(d, V), boost) return lax.select(lax.eq(z, zero), jnp.finfo(z.dtype).tiny, z) @@ -1271,7 +1274,7 @@ def logistic(key: KeyArray, def _logistic(key, shape, dtype): _check_shape("logistic", shape) x = uniform(key, shape, dtype, minval=jnp.finfo(dtype).eps, maxval=1.) - return lax.log(lax.div(x, lax.sub(lax._const(x, 1), x))) + return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x))) def pareto(key: KeyArray, @@ -1353,7 +1356,7 @@ def _t(key, df, shape, dtype): df = lax.convert_element_type(df, dtype) key_n, key_g = _split(key) n = normal(key_n, shape, dtype) - two = lax._const(n, 2) + two = _lax_const(n, 2) half_df = lax.div(df, two) g = gamma(key_n, half_df, shape, dtype) return n * jnp.sqrt(half_df / g) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 324c7b570a45..26e5e8b4565c 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -22,6 +22,7 @@ from jax import jit from jax import lax, core from jax.interpreters import ad +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy import lax_numpy as jnp from jax._src.numpy.lax_numpy import asarray, _reduction_dims, _promote_args_inexact from jax._src.numpy.util import _wraps @@ -89,18 +90,18 @@ def erfinv(x): @_wraps(osp_special.logit, update_doc=False) def logit(x): x = asarray(x) - return lax.log(lax.div(x, lax.sub(lax._const(x, 1), x))) + return lax.log(lax.div(x, lax.sub(_lax_const(x, 1), x))) logit.defjvps( - lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(lax._const(x, 1), x)))) + lambda g, ans, x: lax.div(g, lax.mul(x, lax.sub(_lax_const(x, 1), x)))) @api.custom_jvp @_wraps(osp_special.expit, update_doc=False) def expit(x): x = asarray(x) - one = lax._const(x, 1) + one = _lax_const(x, 1) return lax.div(one, lax.add(one, lax.exp(lax.neg(x)))) -expit.defjvps(lambda g, ans, x: g * ans * (lax._const(ans, 1) - ans)) +expit.defjvps(lambda g, ans, x: g * ans * (_lax_const(ans, 1) - ans)) @_wraps(osp_special.logsumexp) @@ -164,7 +165,7 @@ def xlog1py(x, y): @_wraps(osp_special.entr) def entr(x): x, = _promote_args_inexact("entr", x) - return lax.select(lax.lt(x, lax._const(x, 0)), + return lax.select(lax.lt(x, _lax_const(x, 0)), lax.full_like(x, -np.inf), lax.neg(xlogy(x, x))) @@ -174,10 +175,10 @@ def multigammaln(a, d): d = core.concrete_or_error(int, d, "d argument of multigammaln") a, d_ = _promote_args_inexact("multigammaln", a, d) - constant = lax.mul(lax.mul(lax.mul(lax._const(a, 0.25), d_), - lax.sub(d_, lax._const(a, 1))), - lax.log(lax._const(a, np.pi))) - b = lax.div(jnp.arange(d, dtype=d_.dtype), lax._const(a, 2)) + constant = lax.mul(lax.mul(lax.mul(_lax_const(a, 0.25), d_), + lax.sub(d_, _lax_const(a, 1))), + lax.log(_lax_const(a, np.pi))) + b = lax.div(jnp.arange(d, dtype=d_.dtype), _lax_const(a, 2)) res = jnp.sum(gammaln(jnp.expand_dims(a, axis=-1) - jnp.expand_dims(b, axis=tuple(range(a.ndim)))), axis=-1) @@ -651,8 +652,8 @@ def _double_factorial(n): _norm_logpdf_constant = np.log(np.sqrt(2 * np.pi)) def _norm_logpdf(x): - neg_half = lax._const(x, -0.5) - log_normalizer = lax._const(x, _norm_logpdf_constant) + neg_half = _lax_const(x, -0.5) + log_normalizer = _lax_const(x, _norm_logpdf_constant) return lax.sub(lax.mul(neg_half, lax.square(x)), log_normalizer) @_wraps(osp_special.i0e) @@ -1124,7 +1125,7 @@ def _expint1(x): def _eval_expint_k(A, B, x): # helper function for all subsequent intervals A, B = [jnp.array(U, dtype=x.dtype) for U in [A, B]] - one = lax._const(x, 1.0) + one = _lax_const(x, 1.0) w = one / x f = jnp.polyval(A, w) / jnp.polyval(B, w) f = w * f + one @@ -1288,7 +1289,7 @@ def _expint7(x): def _expi_pos(x): # x > 0 - _c = lax._const + _c = _lax_const conds = [(_c(x, 0) < x) & (x <= _c(x, 2))] + [ (_c(x, 2 ** i) < x) & (x <= _c(x, 2 ** (i + 1))) for i in range(1, 6) ] @@ -1318,7 +1319,7 @@ def expi_jvp(primals, tangents): def _expn1(n, x): # exponential integral En - _c = lax._const + _c = _lax_const x = jnp.array(x) MACHEP = jnp.finfo(x.dtype).eps @@ -1356,7 +1357,7 @@ def cond(d): def _expn2(n, x): # x > 1. - _c = lax._const + _c = _lax_const BIG = _c(x, 1.44115188075855872e17) MACHEP = jnp.finfo(BIG.dtype).eps # ? zero = _c(x, 0.0) @@ -1407,7 +1408,7 @@ def cond(d): def _expn3(n, x): # n >= 5000 - _c = lax._const + _c = _lax_const one = _c(x, 1.0) xk = x + n yk = one / (xk * xk) @@ -1424,7 +1425,7 @@ def _expn3(n, x): @jit def expn(n, x): n, x = _promote_args_inexact("expn", n, x) - _c = lax._const + _c = _lax_const zero = _c(x, 0) one = _c(x, 1) conds = [ @@ -1454,7 +1455,7 @@ def expn(n, x): def expn_jvp(n, primals, tangents): (x,), (x_dot,) = primals, tangents return expn(n, x), lax.mul( - lax.neg(x_dot), expn(lax.sub(n, lax._const(n, 1)), x) + lax.neg(x_dot), expn(lax.sub(n, _lax_const(n, 1)), x) ) diff --git a/jax/_src/scipy/stats/bernoulli.py b/jax/_src/scipy/stats/bernoulli.py index b47e8c6969f1..877c997817d7 100644 --- a/jax/_src/scipy/stats/bernoulli.py +++ b/jax/_src/scipy/stats/bernoulli.py @@ -16,6 +16,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy import lax_numpy as jnp from jax._src.numpy.util import _wraps from jax.scipy.special import xlogy, xlog1py @@ -24,8 +25,8 @@ @_wraps(osp_stats.bernoulli.logpmf, update_doc=False) def logpmf(k, p, loc=0): k, p, loc = jnp._promote_args_inexact("bernoulli.logpmf", k, p, loc) - zero = lax._const(k, 0) - one = lax._const(k, 1) + zero = _lax_const(k, 0) + one = _lax_const(k, 1) x = lax.sub(k, loc) log_probs = xlogy(x, p) + xlog1py(lax.sub(one, x), -p) return jnp.where(jnp.logical_or(lax.lt(x, zero), lax.gt(x, one)), diff --git a/jax/_src/scipy/stats/beta.py b/jax/_src/scipy/stats/beta.py index 2e67329de55b..b62f108a4018 100644 --- a/jax/_src/scipy/stats/beta.py +++ b/jax/_src/scipy/stats/beta.py @@ -15,6 +15,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or from jax.scipy.special import betaln, xlogy, xlog1py @@ -23,7 +24,7 @@ @_wraps(osp_stats.beta.logpdf, update_doc=False) def logpdf(x, a, b, loc=0, scale=1): x, a, b, loc, scale = _promote_args_inexact("beta.logpdf", x, a, b, loc, scale) - one = lax._const(x, 1) + one = _lax_const(x, 1) shape_term = lax.neg(betaln(a, b)) y = lax.div(lax.sub(x, loc), scale) log_linear_term = lax.add(xlogy(lax.sub(a, one), y), diff --git a/jax/_src/scipy/stats/betabinom.py b/jax/_src/scipy/stats/betabinom.py index 18d4ada6ed37..e93e6117ca80 100644 --- a/jax/_src/scipy/stats/betabinom.py +++ b/jax/_src/scipy/stats/betabinom.py @@ -17,6 +17,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf, logical_or, nan from jax._src.scipy.special import betaln @@ -28,8 +29,8 @@ def logpmf(k, n, a, b, loc=0): """JAX implementation of scipy.stats.betabinom.logpmf.""" k, n, a, b, loc = _promote_args_inexact("betabinom.logpmf", k, n, a, b, loc) y = lax.sub(lax.floor(k), loc) - one = lax._const(y, 1) - zero = lax._const(y, 0) + one = _lax_const(y, 1) + zero = _lax_const(y, 0) combiln = lax.neg(lax.add(lax.log1p(n), betaln(lax.add(lax.sub(n,y), one), lax.add(y,one)))) beta_lns = lax.sub(betaln(lax.add(y,a), lax.add(lax.sub(n,y),b)), betaln(a,b)) log_probs = lax.add(combiln, beta_lns) diff --git a/jax/_src/scipy/stats/cauchy.py b/jax/_src/scipy/stats/cauchy.py index 4880b4facb50..85179c858032 100644 --- a/jax/_src/scipy/stats/cauchy.py +++ b/jax/_src/scipy/stats/cauchy.py @@ -17,6 +17,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact @@ -24,7 +25,7 @@ @_wraps(osp_stats.cauchy.logpdf, update_doc=False) def logpdf(x, loc=0, scale=1): x, loc, scale = _promote_args_inexact("cauchy.logpdf", x, loc, scale) - pi = lax._const(x, np.pi) + pi = _lax_const(x, np.pi) scaled_x = lax.div(lax.sub(x, loc), scale) normalize_term = lax.log(lax.mul(pi, scale)) return lax.neg(lax.add(normalize_term, lax.log1p(lax.mul(scaled_x, scaled_x)))) diff --git a/jax/_src/scipy/stats/chi2.py b/jax/_src/scipy/stats/chi2.py index a4734c680b38..f74ffd43e281 100644 --- a/jax/_src/scipy/stats/chi2.py +++ b/jax/_src/scipy/stats/chi2.py @@ -16,6 +16,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf @@ -23,8 +24,8 @@ @_wraps(osp_stats.chi2.logpdf, update_doc=False) def logpdf(x, df, loc=0, scale=1): x, df, loc, scale = _promote_args_inexact("chi2.logpdf", x, df, loc, scale) - one = lax._const(x, 1) - two = lax._const(x, 2) + one = _lax_const(x, 1) + two = _lax_const(x, 2) y = lax.div(lax.sub(x, loc), scale) df_on_two = lax.div(df, two) diff --git a/jax/_src/scipy/stats/dirichlet.py b/jax/_src/scipy/stats/dirichlet.py index ac653075ba04..cbdfece02458 100644 --- a/jax/_src/scipy/stats/dirichlet.py +++ b/jax/_src/scipy/stats/dirichlet.py @@ -16,6 +16,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy import lax_numpy as jnp from jax._src.numpy.util import _wraps from jax.scipy.special import gammaln, xlogy @@ -39,7 +40,7 @@ def logpdf(x, alpha): "`x` must have either the same number of entries as `alpha` " f"or one entry fewer; got x.shape={x.shape}, alpha.shape={alpha.shape}" ) - one = lax._const(x, 1) + one = _lax_const(x, 1) if x.shape[0] != alpha.shape[0]: x = jnp.concatenate([x, lax.sub(one, x.sum(0, keepdims=True))], axis=0) normalize_term = jnp.sum(gammaln(alpha)) - gammaln(jnp.sum(alpha)) diff --git a/jax/_src/scipy/stats/gamma.py b/jax/_src/scipy/stats/gamma.py index da958694c917..52e96753ab07 100644 --- a/jax/_src/scipy/stats/gamma.py +++ b/jax/_src/scipy/stats/gamma.py @@ -15,6 +15,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf from jax.scipy.special import gammaln, xlogy @@ -23,7 +24,7 @@ @_wraps(osp_stats.gamma.logpdf, update_doc=False) def logpdf(x, a, loc=0, scale=1): x, a, loc, scale = _promote_args_inexact("gamma.logpdf", x, a, loc, scale) - one = lax._const(x, 1) + one = _lax_const(x, 1) y = lax.div(lax.sub(x, loc), scale) log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y) shape_terms = lax.add(gammaln(a), lax.log(scale)) diff --git a/jax/_src/scipy/stats/geom.py b/jax/_src/scipy/stats/geom.py index 71afffdbb94b..b51c381fc172 100644 --- a/jax/_src/scipy/stats/geom.py +++ b/jax/_src/scipy/stats/geom.py @@ -15,6 +15,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy import lax_numpy as jnp from jax._src.numpy.util import _wraps from jax.scipy.special import xlog1py @@ -22,8 +23,8 @@ @_wraps(osp_stats.geom.logpmf, update_doc=False) def logpmf(k, p, loc=0): k, p, loc = jnp._promote_args_inexact("geom.logpmf", k, p, loc) - zero = lax._const(k, 0) - one = lax._const(k, 1) + zero = _lax_const(k, 0) + one = _lax_const(k, 1) x = lax.sub(k, loc) log_probs = xlog1py(lax.sub(x, one), -p) + lax.log(p) return jnp.where(lax.le(x, zero), -jnp.inf, log_probs) diff --git a/jax/_src/scipy/stats/laplace.py b/jax/_src/scipy/stats/laplace.py index e61fca34dceb..827ffd3f2bb0 100644 --- a/jax/_src/scipy/stats/laplace.py +++ b/jax/_src/scipy/stats/laplace.py @@ -15,6 +15,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact @@ -22,7 +23,7 @@ @_wraps(osp_stats.laplace.logpdf, update_doc=False) def logpdf(x, loc=0, scale=1): x, loc, scale = _promote_args_inexact("laplace.logpdf", x, loc, scale) - two = lax._const(x, 2) + two = _lax_const(x, 2) linear_term = lax.div(lax.abs(lax.sub(x, loc)), scale) return lax.neg(lax.add(linear_term, lax.log(lax.mul(two, scale)))) @@ -33,9 +34,9 @@ def pdf(x, loc=0, scale=1): @_wraps(osp_stats.laplace.cdf, update_doc=False) def cdf(x, loc=0, scale=1): x, loc, scale = _promote_args_inexact("laplace.cdf", x, loc, scale) - half = lax._const(x, 0.5) - one = lax._const(x, 1) - zero = lax._const(x, 0) + half = _lax_const(x, 0.5) + one = _lax_const(x, 1) + zero = _lax_const(x, 0) diff = lax.div(lax.sub(x, loc), scale) return lax.select(lax.le(diff, zero), lax.mul(half, lax.exp(diff)), diff --git a/jax/_src/scipy/stats/nbinom.py b/jax/_src/scipy/stats/nbinom.py index 2c3b7b884336..8b4da9dc7dcc 100644 --- a/jax/_src/scipy/stats/nbinom.py +++ b/jax/_src/scipy/stats/nbinom.py @@ -16,6 +16,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.lax_numpy import _promote_args_inexact, where, inf from jax._src.numpy.util import _wraps from jax._src.scipy.special import gammaln, xlogy @@ -25,7 +26,7 @@ def logpmf(k, n, p, loc=0): """JAX implementation of scipy.stats.nbinom.logpmf.""" k, n, p, loc = _promote_args_inexact("nbinom.logpmf", k, n, p, loc) - one = lax._const(k, 1) + one = _lax_const(k, 1) y = lax.sub(k, loc) comb_term = lax.sub( lax.sub(gammaln(lax.add(y, n)), gammaln(n)), gammaln(lax.add(y, one)) diff --git a/jax/_src/scipy/stats/norm.py b/jax/_src/scipy/stats/norm.py index ea6e26621567..952adc082370 100644 --- a/jax/_src/scipy/stats/norm.py +++ b/jax/_src/scipy/stats/norm.py @@ -17,6 +17,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy import lax_numpy as jnp from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact @@ -26,9 +27,9 @@ def logpdf(x, loc=0, scale=1): x, loc, scale = _promote_args_inexact("norm.logpdf", x, loc, scale) scale_sqrd = lax.square(scale) - log_normalizer = lax.log(lax.mul(lax._const(x, 2 * np.pi), scale_sqrd)) + log_normalizer = lax.log(lax.mul(_lax_const(x, 2 * np.pi), scale_sqrd)) quadratic = lax.div(lax.square(lax.sub(x, loc)), scale_sqrd) - return lax.div(lax.add(log_normalizer, quadratic), lax._const(x, -2)) + return lax.div(lax.add(log_normalizer, quadratic), _lax_const(x, -2)) @_wraps(osp_stats.norm.pdf, update_doc=False) diff --git a/jax/_src/scipy/stats/pareto.py b/jax/_src/scipy/stats/pareto.py index f4bc855fa8bc..6e86049816bb 100644 --- a/jax/_src/scipy/stats/pareto.py +++ b/jax/_src/scipy/stats/pareto.py @@ -16,6 +16,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact, inf, where @@ -23,7 +24,7 @@ @_wraps(osp_stats.pareto.logpdf, update_doc=False) def logpdf(x, b, loc=0, scale=1): x, b, loc, scale = _promote_args_inexact("pareto.logpdf", x, b, loc, scale) - one = lax._const(x, 1) + one = _lax_const(x, 1) scaled_x = lax.div(lax.sub(x, loc), scale) normalize_term = lax.log(lax.div(scale, b)) log_probs = lax.neg(lax.add(normalize_term, lax.mul(lax.add(b, one), lax.log(scaled_x)))) diff --git a/jax/_src/scipy/stats/poisson.py b/jax/_src/scipy/stats/poisson.py index 26da81cd766d..c509ab5210b4 100644 --- a/jax/_src/scipy/stats/poisson.py +++ b/jax/_src/scipy/stats/poisson.py @@ -16,6 +16,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy import lax_numpy as jnp from jax.scipy.special import xlogy, gammaln, gammaincc @@ -24,7 +25,7 @@ @_wraps(osp_stats.poisson.logpmf, update_doc=False) def logpmf(k, mu, loc=0): k, mu, loc = jnp._promote_args_inexact("poisson.logpmf", k, mu, loc) - zero = lax._const(k, 0) + zero = _lax_const(k, 0) x = lax.sub(k, loc) log_probs = xlogy(x, mu) - gammaln(x + 1) - mu return jnp.where(lax.lt(x, zero), -jnp.inf, log_probs) @@ -36,7 +37,7 @@ def pmf(k, mu, loc=0): @_wraps(osp_stats.poisson.cdf, update_doc=False) def cdf(k, mu, loc=0): k, mu, loc = jnp._promote_args_inexact("poisson.logpmf", k, mu, loc) - zero = lax._const(k, 0) + zero = _lax_const(k, 0) x = lax.sub(k, loc) p = gammaincc(jnp.floor(1 + x), mu) return jnp.where(lax.lt(x, zero), zero, p) diff --git a/jax/_src/scipy/stats/t.py b/jax/_src/scipy/stats/t.py index f7496aeb16dd..4056ebe78144 100644 --- a/jax/_src/scipy/stats/t.py +++ b/jax/_src/scipy/stats/t.py @@ -17,6 +17,7 @@ import scipy.stats as osp_stats from jax import lax +from jax._src.lax.lax import _const as _lax_const from jax._src.numpy.util import _wraps from jax._src.numpy.lax_numpy import _promote_args_inexact @@ -24,11 +25,11 @@ @_wraps(osp_stats.t.logpdf, update_doc=False) def logpdf(x, df, loc=0, scale=1): x, df, loc, scale = _promote_args_inexact("t.logpdf", x, df, loc, scale) - two = lax._const(x, 2) + two = _lax_const(x, 2) scaled_x = lax.div(lax.sub(x, loc), scale) df_over_two = lax.div(df, two) - df_plus_one_over_two = lax.add(df_over_two, lax._const(x, 0.5)) - normalize_term_const = lax.mul(lax.mul(scale, scale), lax._const(x, np.pi)) + df_plus_one_over_two = lax.add(df_over_two, _lax_const(x, 0.5)) + normalize_term_const = lax.mul(lax.mul(scale, scale), _lax_const(x, np.pi)) normalize_term_tmp = lax.div(lax.log(lax.mul(normalize_term_const, df)), two) normalize_term = lax.sub(lax.add(lax.lgamma(df_over_two), normalize_term_tmp), lax.lgamma(df_plus_one_over_two)) diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index c6d4d2555a33..0b8388fa6eb4 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -59,17 +59,20 @@ import numpy as np import jax -import jax.numpy as jnp from jax import core -from jax._src.util import unzip2 -from jax._src import ad_util -from jax._src import dispatch +from jax import lax +from jax.custom_derivatives import custom_jvp_call_jaxpr_p +from jax.interpreters import xla +import jax.linear_util as lu +import jax.numpy as jnp from jax.tree_util import (register_pytree_node, tree_structure, treedef_is_leaf, tree_flatten, tree_unflatten) -import jax.linear_util as lu -from jax.interpreters import xla -from jax.custom_derivatives import custom_jvp_call_jaxpr_p -from jax import lax + +from jax._src import ad_util +from jax._src import dispatch +from jax._src.lax import lax as lax_internal +from jax._src.util import unzip2 + def jet(fun, primals, series): r"""Taylor-mode higher-order automatic differentiation. @@ -371,7 +374,10 @@ def deriv_prop(prim, deriv, primals_in, series_in): return primal_out, series_out -def_deriv(lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)), lax.exp(lax.neg(lax.square(x))))) +def_deriv(lax.erf_p, + lambda x: lax.mul( + lax_internal._const(x, 2. / np.sqrt(np.pi)), + lax.exp(lax.neg(lax.square(x))))) def def_comp(prim, comp): diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index e8acdc76f020..3f981f3aa878 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -250,7 +250,7 @@ from jax._src.lax.lax import ( _reduce_sum, _reduce_max, _reduce_min, _reduce_or, _reduce_and, _float, _complex, _input_dtype, - _const, _eq_meet, _broadcasting_select, + _eq_meet, _broadcasting_select, _check_user_dtype_supported, _one, _zero, _upcast_fp16_for_computation, _broadcasting_shape_rule, _eye, _tri, _delta, _ones, _zeros, _dilate_shape) diff --git a/tests/lax_test.py b/tests/lax_test.py index 83ac6e29fe2c..4bffa25f5e7f 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -26,17 +26,18 @@ import numpy as np import jax -import jax.numpy as jnp from jax import core -from jax._src import dtypes from jax import lax -from jax._src import test_util as jtu -from jax import tree_util -from jax._src import lax_reference +import jax.numpy as jnp from jax.test_util import check_grads +from jax import tree_util import jax.util -from jax._src.util import prod +from jax._src import dtypes +from jax._src import test_util as jtu +from jax._src import lax_reference +from jax._src.util import prod +from jax._src.lax import lax as lax_internal from jax._src.lax.lax import _device_put_raw @@ -2626,7 +2627,7 @@ def test_const(self, dtype, weak_type): else: val = lax._convert_element_type(0, dtype, weak_type=weak_type) - const = lax._const(val, 0) + const = lax_internal._const(val, 0) self.assertEqual(dtypes.dtype(val, canonicalize=True), dtypes.dtype(const, canonicalize=True))