From 6355fac8822bced4bfa657187a7284477f373c52 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 14 Mar 2022 19:14:02 -0700 Subject: [PATCH] lax_numpy.py: factor ufuncs into their own private submodule Re-lands part of #9724 PiperOrigin-RevId: 434629548 --- jax/_src/numpy/lax_numpy.py | 660 +-------------------------------- jax/_src/numpy/ufuncs.py | 711 ++++++++++++++++++++++++++++++++++++ 2 files changed, 725 insertions(+), 646 deletions(-) create mode 100644 jax/_src/numpy/ufuncs.py diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index ba595018d14a..e704917790a2 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -37,7 +37,7 @@ import opt_einsum import jax -from jax import custom_jvp, jit +from jax import jit from jax import core from jax import errors from jax import lax @@ -51,6 +51,17 @@ from jax._src.lax.lax import _array_copy, _sort_lt_comparator, _sort_le_comparator from jax._src.lax import lax as lax_internal from jax._src.numpy.ndarray import ndarray +from jax._src.numpy.ufuncs import ( # noqa: F401 + abs, absolute, add, arccos, arccosh, arcsin, arcsinh, arctan, arctan2, arctanh, + bitwise_and, bitwise_not, bitwise_or, bitwise_xor, cbrt, ceil, conj, conjugate, + copysign, cos, cosh, deg2rad, degrees, divide, divmod, equal, exp, exp2, expm1, + fabs, float_power, floor, floor_divide, fmod, frexp, greater, greater_equal, + heaviside, hypot, imag, invert, isfinite, isinf, isnan, isneginf, isposinf, + ldexp, left_shift, less, less_equal, log, log10, log1p, log2, logaddexp, logaddexp2, + logical_and, logical_not, logical_or, logical_xor, maximum, minimum, mod, modf, + multiply, negative, nextafter, not_equal, positive, power, rad2deg, radians, real, + reciprocal, remainder, right_shift, rint, sign, signbit, sin, sinc, sinh, sqrt, + square, subtract, tan, tanh, true_divide) from jax._src.numpy.util import ( # noqa: F401 _arraylike, _broadcast_arrays, _broadcast_to, _check_arraylike, _complex_elem_type, _promote_args, _promote_args_inexact, _promote_dtypes, _promote_dtypes_inexact, _promote_shapes, _register_stackable, @@ -299,387 +310,6 @@ def isscalar(element): def result_type(*args): return dtypes.result_type(*args) -def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False): - if promote_to_inexact: - fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x)) - else: - fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x)) - fn = jit(fn, inline=True) - if lax_doc: - doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() - return _wraps(numpy_fn, lax_description=doc)(fn) - else: - return _wraps(numpy_fn)(fn) - -def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False): - if promote_to_inexact: - fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2)) - else: - fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2)) - fn = jit(fn, inline=True) - if lax_doc: - doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() - return _wraps(numpy_fn, lax_description=doc)(fn) - else: - return _wraps(numpy_fn)(fn) - -def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False): - def fn(x1, x2): - x1, x2 = _promote_args(numpy_fn.__name__, x1, x2) - return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2) - fn = jit(fn, inline=True) - if lax_doc: - doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() - return _wraps(numpy_fn, lax_description=doc)(fn) - else: - return _wraps(numpy_fn)(fn) - 
-fabs = _one_to_one_unop(np.fabs, lax.abs, True) -bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not) -invert = _one_to_one_unop(np.invert, lax.bitwise_not) -negative = _one_to_one_unop(np.negative, lax.neg) -positive = _one_to_one_unop(np.positive, lambda x: x) - -floor = _one_to_one_unop(np.floor, lax.floor, True) -ceil = _one_to_one_unop(np.ceil, lax.ceil, True) -exp = _one_to_one_unop(np.exp, lax.exp, True) -log = _one_to_one_unop(np.log, lax.log, True) -expm1 = _one_to_one_unop(np.expm1, lax.expm1, True) -log1p = _one_to_one_unop(np.log1p, lax.log1p, True) -sin = _one_to_one_unop(np.sin, lax.sin, True) -cos = _one_to_one_unop(np.cos, lax.cos, True) -tan = _one_to_one_unop(np.tan, lax.tan, True) -arcsin = _one_to_one_unop(np.arcsin, lax.asin, True) -arccos = _one_to_one_unop(np.arccos, lax.acos, True) -arctan = _one_to_one_unop(np.arctan, lax.atan, True) -sinh = _one_to_one_unop(np.sinh, lax.sinh, True) -cosh = _one_to_one_unop(np.cosh, lax.cosh, True) -arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True) -tanh = _one_to_one_unop(np.tanh, lax.tanh, True) -arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True) -arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True) -sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True) -cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True) - - -add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or) -bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and) -bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or) -bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor) -left_shift = _one_to_one_binop(np.left_shift, lax.shift_left) -equal = _one_to_one_binop(np.equal, lax.eq) -multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and) -not_equal = _one_to_one_binop(np.not_equal, lax.ne) -subtract = _one_to_one_binop(np.subtract, lax.sub) -arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, True) -minimum = _one_to_one_binop(np.minimum, lax.min) -maximum = _one_to_one_binop(np.maximum, lax.max) -float_power = _one_to_one_binop(np.float_power, lax.pow, True) -nextafter = _one_to_one_binop(np.nextafter, lax.nextafter, True, True) - -@_wraps(np.arccosh) -@jit -def arccosh(x): - # Note: arccosh is multi-valued for complex input, and lax.acosh uses a different - # convention than np.arccosh. - out = lax.acosh(*_promote_args_inexact("arccosh", x)) - if issubdtype(out.dtype, np.complexfloating): - out = where(real(out) < 0, lax.neg(out), out) - return out - -def _comparison_op(numpy_fn, lax_fn): - # TODO(https://github.com/google/jax/issues/6713): decorate this function with - # jit, after fixing a surprising interaction with remat(..., concrete=True). - def fn(x1, x2): - x1, x2 = _promote_args(numpy_fn.__name__, x1, x2) - # Comparison on complex types are defined as a lexicographic ordering on - # the (real, imag) pair. 
- if issubdtype(_dtype(x1), complexfloating): - rx = lax.real(x1) - ry = lax.real(x2) - return lax.select(lax.eq(rx, ry), lax_fn(lax.imag(x1), lax.imag(x2)), - lax_fn(rx, ry)) - return lax_fn(x1, x2) - return _wraps(numpy_fn)(fn) - -greater_equal = _comparison_op(np.greater_equal, lax.ge) -greater = _comparison_op(np.greater, lax.gt) -less_equal = _comparison_op(np.less_equal, lax.le) -less = _comparison_op(np.less, lax.lt) - - -def _logical_op(np_op, bitwise_op): - @_wraps(np_op, update_doc=False) - @partial(jit, inline=True) - def op(*args): - zero = lambda x: lax.full_like(x, shape=(), fill_value=0) - args = (x if issubdtype(_dtype(x), bool_) else lax.ne(x, zero(x)) - for x in args) - return bitwise_op(*_promote_args(np_op.__name__, *args)) - return op - -logical_and = _logical_op(np.logical_and, lax.bitwise_and) -logical_not = _logical_op(np.logical_not, lax.bitwise_not) -logical_or = _logical_op(np.logical_or, lax.bitwise_or) -logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor) - - -@_wraps(np.right_shift) -@partial(jit, inline=True) -def right_shift(x1, x2): - x1, x2 = _promote_args(np.right_shift.__name__, x1, x2) - lax_fn = lax.shift_right_logical if \ - np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic - return lax_fn(x1, x2) - - -@_wraps(np.absolute) -@partial(jit, inline=True) -def absolute(x): - _check_arraylike('absolute', x) - dt = _dtype(x) - return x if dt == bool_ or issubdtype(dt, unsignedinteger) else lax.abs(x) -abs = _wraps(np.abs)(absolute) - - -@_wraps(np.rint) -@jit -def rint(x): - _check_arraylike('rint', x) - dtype = _dtype(x) - if issubdtype(dtype, integer): - return lax.convert_element_type(x, float_) - if issubdtype(dtype, complexfloating): - return lax.complex(rint(lax.real(x)), rint(lax.imag(x))) - return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) - - -@_wraps(np.sign) -@jit -def sign(x): - _check_arraylike('sign', x) - dtype = _dtype(x) - if issubdtype(dtype, complexfloating): - re = lax.real(x) - return lax.complex( - lax.sign(where(re != 0, re, lax.imag(x))), _lax_const(re, 0)) - return lax.sign(x) - - -@_wraps(np.copysign) -@jit -def copysign(x1, x2): - x1, x2 = _promote_args_inexact("copysign", x1, x2) - if issubdtype(_dtype(x1), complexfloating): - raise TypeError("copysign does not support complex-valued inputs") - return where(signbit(x2), -lax.abs(x1), lax.abs(x1)) - - -@_wraps(np.true_divide) -@partial(jit, inline=True) -def true_divide(x1, x2): - x1, x2 = _promote_args_inexact("true_divide", x1, x2) - return lax.div(x1, x2) - -divide = true_divide - -@_wraps(np.floor_divide) -@jit -def floor_divide(x1, x2): - x1, x2 = _promote_args("floor_divide", x1, x2) - dtype = _dtype(x1) - if issubdtype(dtype, integer): - quotient = lax.div(x1, x2) - select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0) - # TODO(mattjj): investigate why subtracting a scalar was causing promotion - return where(select, quotient - np.array(1, _dtype(quotient)), quotient) - elif issubdtype(dtype, complexfloating): - x1r = lax.real(x1) - x1i = lax.imag(x1) - x2r = lax.real(x2) - x2i = lax.imag(x2) - which = lax.ge(lax.abs(x2r), lax.abs(x2i)) - rat1 = where(which, _lax_const(x2i, 1), lax.div(x2r, x2i)) - rat2 = where(which, lax.div(x2i, x2r), _lax_const(x2i, 1)) - out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)), - lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2)))) - return lax.convert_element_type(out, dtype) - else: - return _float_divmod(x1, x2)[0] - - -@_wraps(np.divmod) -@jit -def divmod(x1, x2): - 
x1, x2 = _promote_args("divmod", x1, x2) - if issubdtype(_dtype(x1), integer): - return floor_divide(x1, x2), remainder(x1, x2) - else: - return _float_divmod(x1, x2) - - -def _float_divmod(x1, x2): - # see float_divmod in floatobject.c of CPython - mod = lax.rem(x1, x2) - div = lax.div(lax.sub(x1, mod), x2) - - ind = lax.bitwise_and(mod != 0, lax.sign(x2) != lax.sign(mod)) - mod = lax.select(ind, mod + x2, mod) - div = lax.select(ind, div - _lax_const(div, 1), div) - - return lax.round(div), mod - - -@partial(jit, inline=True) -def _power(x1, x2): - x1, x2 = _promote_args("power", x1, x2) - dtype = _dtype(x1) - if not issubdtype(dtype, integer): - return lax.pow(x1, x2) - - # Integer power => use binary exponentiation. - - # TODO(phawkins): add integer pow support to XLA. - bits = 6 # Anything more would overflow for any x1 > 1 - zero = _lax_const(x2, 0) - one = _lax_const(x2, 1) - # Initialize acc carefully such that pow(0, x2) is zero for x2 != 0 - acc = where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one) - for _ in range(bits): - acc = where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc) - x1 = lax.mul(x1, x1) - x2 = lax.shift_right_logical(x2, one) - return acc - -@_wraps(np.power) -def power(x1, x2): - # Special case for concrete integer scalars: use binary exponentiation. - # Using lax.pow may be imprecise for floating-point values; the goal of this - # code path is to make sure we end up with a precise output for the common - # pattern ``x ** 2`` or similar. - if isinstance(core.get_aval(x2), ConcreteArray): - try: - x2 = operator.index(x2) - except TypeError: - pass - else: - return lax.integer_pow(x1, x2) - return _power(x1, x2) - -@custom_jvp -@_wraps(np.logaddexp) -@jit -def logaddexp(x1, x2): - x1, x2 = _promote_args_inexact("logaddexp", x1, x2) - amax = lax.max(x1, x2) - if issubdtype(x1.dtype, np.floating): - delta = lax.sub(x1, x2) - return lax.select(isnan(delta), - lax.add(x1, x2), # NaNs or infinities of the same sign. - lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))) - else: - delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _lax_const(amax, 2))) - out = lax.add(amax, lax.log1p(lax.exp(delta))) - return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi)) - -def _wrap_between(x, _a): - """Wraps `x` between `[-a, a]`.""" - a = _lax_const(x, _a) - two_a = _lax_const(x, 2 * _a) - zero = _lax_const(x, 0) - rem = lax.rem(lax.add(x, a), two_a) - rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem) - return lax.sub(rem, a) - -@logaddexp.defjvp -def _logaddexp_jvp(primals, tangents): - x1, x2 = primals - t1, t2 = tangents - x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2) - primal_out = logaddexp(x1, x2) - tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), - lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) - return primal_out, tangent_out - -def _replace_inf(x): - return lax.select(isposinf(real(x)), zeros_like(x), x) - - -@custom_jvp -@_wraps(np.logaddexp2) -@jit -def logaddexp2(x1, x2): - x1, x2 = _promote_args_inexact("logaddexp2", x1, x2) - amax = lax.max(x1, x2) - if issubdtype(x1.dtype, np.floating): - delta = lax.sub(x1, x2) - return lax.select(isnan(delta), - lax.add(x1, x2), # NaNs or infinities of the same sign. 
- lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))), - _lax_const(x1, np.log(2))))) - else: - delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _lax_const(amax, 2))) - out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _lax_const(x1, np.log(2)))) - return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2))) - -@logaddexp2.defjvp -def _logaddexp2_jvp(primals, tangents): - x1, x2 = primals - t1, t2 = tangents - x1, x2, t1, t2 = _promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2) - primal_out = logaddexp2(x1, x2) - tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), - lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) - return primal_out, tangent_out - - -@_wraps(np.log2) -@partial(jit, inline=True) -def log2(x): - x, = _promote_args_inexact("log2", x) - return lax.div(lax.log(x), lax.log(_lax_const(x, 2))) - - -@_wraps(np.log10) -@partial(jit, inline=True) -def log10(x): - x, = _promote_args_inexact("log10", x) - return lax.div(lax.log(x), lax.log(_lax_const(x, 10))) - - -@_wraps(np.exp2) -@partial(jit, inline=True) -def exp2(x): - x, = _promote_args_inexact("exp2", x) - return lax.exp(lax.mul(lax.log(_lax_const(x, 2)), x)) - -@_wraps(np.signbit) -@jit -def signbit(x): - x, = _promote_args("signbit", x) - dtype = _dtype(x) - if issubdtype(dtype, integer): - return lax.lt(x, _lax_const(x, 0)) - elif issubdtype(dtype, bool_): - return full_like(x, False, dtype=bool_) - elif not issubdtype(dtype, floating): - raise ValueError( - "jax.numpy.signbit is not well defined for %s" % dtype) - - # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to - # F32. - if dtype == bfloat16: - dtype = float32 - x = lax.convert_element_type(x, float32) - - info = finfo(dtype) - if info.bits not in _INT_DTYPES: - raise NotImplementedError( - "jax.numpy.signbit only supports 16, 32, and 64-bit types.") - int_type = _INT_DTYPES[info.bits] - x = lax.bitcast_convert_type(x, int_type) - return lax.convert_element_type(x >> (info.nexp + info.nmant), np.bool_) - @_wraps(np.trapz) @partial(jit, static_argnames=('axis',)) @@ -748,130 +378,6 @@ def correlate(a, v, mode='valid', *, precision=None): return _conv(a, v, mode, 'correlate', precision) -def _normalize_float(x): - info = finfo(_dtype(x)) - cond = lax.abs(x) < info.tiny - x1 = where(cond, x * _lax_const(x, 1 << info.nmant), x) - x2 = where(cond, _lax_const(np.int32, -info.nmant), _lax_const(np.int32, 0)) - int_type = _INT_DTYPES[info.bits] - return lax.bitcast_convert_type(x1, int_type), x2 - - -@_wraps(np.ldexp) -@jit -def ldexp(x1, x2): - _check_arraylike("ldexp", x1, x2) - dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2)) - x1, x2 = _promote_shapes("ldexp", x1, x2) - x1 = lax.convert_element_type(x1, dtype) - - info = finfo(dtype) - mask = (1 << info.nexp) - 1 - bias = ((1 << info.nexp) - 1) >> 1 - - int_type = _INT_DTYPES[info.bits] - - x, e = _normalize_float(x1) - x2 += e + ((x >> info.nmant) & mask) - bias - - # find underflow/overflow before denormalization - underflow_cond = x2 < -(bias + info.nmant) - overflow_cond = x2 > bias - - m = ones_like(x, dtype=dtype) - - # denormals - cond = x2 < -bias + 1 - x2 = where(cond, x2 + info.nmant, x2) - m = where(cond, m / (1 << info.nmant), m) - - x2 = lax.convert_element_type(x2, np.int32) - x &= ~(mask << info.nmant) - x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant) - - x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype) - - # underflow 
- x = where(underflow_cond, zeros_like(x, dtype=dtype), x) - # overflow - x = where(overflow_cond, lax.sign(x1) * full_like(x, np.inf), x) - # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0 - return where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x) - - -@_wraps(np.frexp) -@jit -def frexp(x): - _check_arraylike("frexp", x) - x = asarray(x) - if issubdtype(x.dtype, complexfloating): - raise TypeError("frexp does not support complex-valued inputs") - elif not issubdtype(x.dtype, floating): - x = lax.convert_element_type(x, float_) - - dtype = _dtype(x) - info = finfo(dtype) - mask = (1 << info.nexp) - 1 - bias = ((1 << info.nexp) - 1) >> 1 - - x1, x2 = _normalize_float(x) - x2 += ((x1 >> info.nmant) & mask) - bias + 1 - x1 &= ~(mask << info.nmant) - x1 |= (bias - 1) << info.nmant - x1 = lax.bitcast_convert_type(x1, dtype) - - cond = isinf(x) | isnan(x) | (x == 0) - x2 = where(cond, zeros_like(x2), x2) - return where(cond, x, x1), lax.convert_element_type(x2, int32) - - -@_wraps(np.remainder) -@jit -def remainder(x1, x2): - x1, x2 = _promote_args("remainder", x1, x2) - zero = _lax_const(x1, 0) - trunc_mod = lax.rem(x1, x2) - trunc_mod_not_zero = lax.ne(trunc_mod, zero) - do_plus = lax.bitwise_and( - lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero) - return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) -mod = _wraps(np.mod)(remainder) - - -@_wraps(np.fmod) -@jit -def fmod(x1, x2): - _check_arraylike("fmod", x1, x2) - if issubdtype(result_type(x1, x2), integer): - x2 = where(x2 == 0, 1, x2) - return lax.rem(*_promote_args("fmod", x1, x2)) - - -@_wraps(np.square) -@partial(jit, inline=True) -def square(x): - _check_arraylike("square", x) - return lax.integer_pow(x, 2) - - -@_wraps(np.deg2rad) -@partial(jit, inline=True) -def deg2rad(x): - x, = _promote_args_inexact("deg2rad", x) - return lax.mul(x, _lax_const(x, pi / 180)) - - -@_wraps(np.rad2deg) -@partial(jit, inline=True) -def rad2deg(x): - x, = _promote_args_inexact("rad2deg", x) - return lax.mul(x, _lax_const(x, 180 / pi)) - - -degrees = rad2deg -radians = deg2rad - - @_wraps(np.histogram_bin_edges) def histogram_bin_edges(a, bins=10, range=None, weights=None): if isinstance(bins, str): @@ -977,59 +483,6 @@ def histogramdd(sample, bins=10, range=None, weights=None, density=None): return hist, bin_edges_by_dim -@_wraps(np.heaviside) -@jit -def heaviside(x1, x2): - _check_arraylike("heaviside", x1, x2) - x1, x2 = _promote_dtypes_inexact(x1, x2) - zero = _lax_const(x1, 0) - return where(lax.lt(x1, zero), zero, - where(lax.gt(x1, zero), _lax_const(x1, 1), x2)) - - -@_wraps(np.hypot) -@jit -def hypot(x1, x2): - _check_arraylike("hypot", x1, x2) - x1, x2 = _promote_dtypes_inexact(x1, x2) - x1 = lax.abs(x1) - x2 = lax.abs(x2) - x1, x2 = maximum(x1, x2), minimum(x1, x2) - return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, ones_like(x1), x1))))) - - -@_wraps(np.reciprocal) -@partial(jit, inline=True) -def reciprocal(x): - _check_arraylike("reciprocal", x) - x, = _promote_dtypes_inexact(x) - return lax.integer_pow(x, -1) - - -@_wraps(np.sinc, update_doc=False) -@jit -def sinc(x): - _check_arraylike("sinc", x) - x, = _promote_dtypes_inexact(x) - eq_zero = lax.eq(x, _lax_const(x, 0)) - pi_x = lax.mul(_lax_const(x, pi), x) - safe_pi_x = where(eq_zero, _lax_const(x, 1), pi_x) - return where(eq_zero, _sinc_maclaurin(0, pi_x), - lax.div(lax.sin(safe_pi_x), safe_pi_x)) - -@partial(custom_jvp, nondiff_argnums=(0,)) -def _sinc_maclaurin(k, x): - # compute the kth derivative of x -> sin(x)/x 
evaluated at zero (since we - # compute the monomial term in the jvp rule) - if k % 2: - return lax.full_like(x, 0) - else: - return lax.full_like(x, (-1) ** (k // 2) / (k + 1)) - -@_sinc_maclaurin.defjvp -def _sinc_maclaurin_jvp(k, primals, tangents): - (x,), (t,) = primals, tangents - return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t _ARRAY_VIEW_DOC = """ The JAX version of this function may in some cases return a copy rather than a @@ -1088,29 +541,6 @@ def fliplr(m): def flipud(m): return _flip(m, 0) - -@_wraps(np.conjugate) -@partial(jit, inline=True) -def conjugate(x): - _check_arraylike("conjugate", x) - return lax.conj(x) if iscomplexobj(x) else x -conj = conjugate - - -@_wraps(np.imag) -@partial(jit, inline=True) -def imag(val): - _check_arraylike("imag", val) - return lax.imag(val) if iscomplexobj(val) else zeros_like(val) - - -@_wraps(np.real) -@partial(jit, inline=True) -def real(val): - _check_arraylike("real", val) - return lax.real(val) if iscomplexobj(val) else val - - @_wraps(np.iscomplex) @jit def iscomplex(x): @@ -1591,7 +1021,7 @@ def interp(x, xp, fp, left=None, right=None, period=None): In the JAX version, the `assume_unique` argument is not referenced. """) @partial(jit, static_argnames=('assume_unique', 'invert',)) -def in1d(ar1, ar2, assume_unique=False, invert=False): +def in1d(ar1, ar2, assume_unique=False, invert=False): # noqa: F811 _check_arraylike("in1d", ar1, ar2) ar1 = ravel(ar1) ar2 = ravel(ar2) @@ -1752,7 +1182,7 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False): @_wraps(np.isin, lax_description=""" In the JAX version, the `assume_unique` argument is not referenced. """) -def isin(element, test_elements, assume_unique=False, invert=False): +def isin(element, test_elements, assume_unique=False, invert=False): # noqa: F811 result = in1d(element, test_elements, assume_unique=assume_unique, invert=invert) return result.reshape(shape(element)) @@ -1963,68 +1393,6 @@ def fix(x, out=None): return where(lax.ge(x, zero), floor(x), ceil(x)) -@_wraps(np.modf, skip_params=['out']) -@jit -def modf(x, out=None): - _check_arraylike("modf", x) - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.modf is not supported.") - whole = fix(x) - return x - whole, whole - - -@_wraps(np.isfinite) -@jit -def isfinite(x): - _check_arraylike("isfinite", x) - dtype = _dtype(x) - if issubdtype(dtype, floating): - return lax.is_finite(x) - elif issubdtype(dtype, complexfloating): - return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x))) - else: - return full_like(x, True, dtype=bool_) - -@_wraps(np.isinf) -@jit -def isinf(x): - _check_arraylike("isinf", x) - dtype = _dtype(x) - if issubdtype(dtype, floating): - return lax.eq(lax.abs(x), _lax_const(x, inf)) - elif issubdtype(dtype, complexfloating): - re = lax.real(x) - im = lax.imag(x) - return lax.bitwise_or(lax.eq(lax.abs(re), _lax_const(re, inf)), - lax.eq(lax.abs(im), _lax_const(im, inf))) - else: - return full_like(x, False, dtype=bool_) - -def _isposneginf(infinity, x, out): - if out is not None: - raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.") - dtype = _dtype(x) - if issubdtype(dtype, floating): - return lax.eq(x, _lax_const(x, infinity)) - elif issubdtype(dtype, complexfloating): - raise ValueError("isposinf/isneginf are not well defined for complex types") - else: - return full_like(x, False, dtype=bool_) - -isposinf = _wraps(np.isposinf, skip_params=['out'])( - lambda x, out=None: _isposneginf(inf, x, out) -) - 
-isneginf = _wraps(np.isneginf, skip_params=['out'])( - lambda x, out=None: _isposneginf(-inf, x, out) -) - -@_wraps(np.isnan) -@jit -def isnan(x): - _check_arraylike("isnan", x) - return lax.ne(x, x) - @_wraps(np.nan_to_num) @jit def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py new file mode 100644 index 000000000000..bc11eaac0243 --- /dev/null +++ b/jax/_src/numpy/ufuncs.py @@ -0,0 +1,711 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pytype: skip-file +""" +Implements ufuncs for jax.numpy. +""" + +from functools import partial +import operator +from textwrap import dedent + +import numpy as np + +from jax._src.api import jit, custom_jvp +from jax._src import dtypes +from jax._src.lax import lax as lax_internal +from jax._src.numpy.util import ( + _check_arraylike, _promote_args, _promote_args_inexact, + _promote_dtypes_inexact, _promote_shapes, _where, _wraps) +from jax import core +from jax import lax + +_lax_const = lax_internal._const + +_INT_DTYPES = { + 16: np.int16, + 32: np.int32, + 64: np.int64, +} + + +def _constant_like(x, const): + return np.array(const, dtype=dtypes.dtype(x)) + + +def _result_dtype(op, *args): + """Compute result dtype of applying op to arguments with given dtypes.""" + args = [np.ones((0,) * np.ndim(arg), dtypes.dtype(arg)) for arg in args] + return dtypes.dtype(op(*args)) + + +def _replace_inf(x): + return lax.select(isposinf(real(x)), lax_internal._zeros(x), x) + + +def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False): + if promote_to_inexact: + fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x)) + else: + fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x)) + fn = jit(fn, inline=True) + if lax_doc: + doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() + return _wraps(numpy_fn, lax_description=doc)(fn) + else: + return _wraps(numpy_fn)(fn) + + +def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False): + if promote_to_inexact: + fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2)) + else: + fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2)) + fn = jit(fn, inline=True) + if lax_doc: + doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() + return _wraps(numpy_fn, lax_description=doc)(fn) + else: + return _wraps(numpy_fn)(fn) + + +def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False): + def fn(x1, x2): + x1, x2 = _promote_args(numpy_fn.__name__, x1, x2) + return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2) + fn = jit(fn, inline=True) + if lax_doc: + doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() + return _wraps(numpy_fn, lax_description=doc)(fn) + else: + return _wraps(numpy_fn)(fn) + + +def _comparison_op(numpy_fn, lax_fn): + # TODO(https://github.com/google/jax/issues/6713): decorate this function with + # jit, after fixing a surprising 
interaction with remat(..., concrete=True). + def fn(x1, x2): + x1, x2 = _promote_args(numpy_fn.__name__, x1, x2) + # Comparison on complex types are defined as a lexicographic ordering on + # the (real, imag) pair. + if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating): + rx = lax.real(x1) + ry = lax.real(x2) + return lax.select(lax.eq(rx, ry), lax_fn(lax.imag(x1), lax.imag(x2)), + lax_fn(rx, ry)) + return lax_fn(x1, x2) + return _wraps(numpy_fn)(fn) + + +def _logical_op(np_op, bitwise_op): + @_wraps(np_op, update_doc=False) + @partial(jit, inline=True) + def op(*args): + zero = lambda x: lax.full_like(x, shape=(), fill_value=0) + args = (x if dtypes.issubdtype(dtypes.dtype(x), np.bool_) else lax.ne(x, zero(x)) + for x in args) + return bitwise_op(*_promote_args(np_op.__name__, *args)) + return op + + +fabs = _one_to_one_unop(np.fabs, lax.abs, True) +bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not) +invert = _one_to_one_unop(np.invert, lax.bitwise_not) +negative = _one_to_one_unop(np.negative, lax.neg) +positive = _one_to_one_unop(np.positive, lambda x: x) +floor = _one_to_one_unop(np.floor, lax.floor, True) +ceil = _one_to_one_unop(np.ceil, lax.ceil, True) +exp = _one_to_one_unop(np.exp, lax.exp, True) +log = _one_to_one_unop(np.log, lax.log, True) +expm1 = _one_to_one_unop(np.expm1, lax.expm1, True) +log1p = _one_to_one_unop(np.log1p, lax.log1p, True) +sin = _one_to_one_unop(np.sin, lax.sin, True) +cos = _one_to_one_unop(np.cos, lax.cos, True) +tan = _one_to_one_unop(np.tan, lax.tan, True) +arcsin = _one_to_one_unop(np.arcsin, lax.asin, True) +arccos = _one_to_one_unop(np.arccos, lax.acos, True) +arctan = _one_to_one_unop(np.arctan, lax.atan, True) +sinh = _one_to_one_unop(np.sinh, lax.sinh, True) +cosh = _one_to_one_unop(np.cosh, lax.cosh, True) +arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True) +tanh = _one_to_one_unop(np.tanh, lax.tanh, True) +arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True) +sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True) +cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True) + +add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or) +bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and) +bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or) +bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor) +left_shift = _one_to_one_binop(np.left_shift, lax.shift_left) +equal = _one_to_one_binop(np.equal, lax.eq) +multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and) +not_equal = _one_to_one_binop(np.not_equal, lax.ne) +subtract = _one_to_one_binop(np.subtract, lax.sub) +arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, True) +minimum = _one_to_one_binop(np.minimum, lax.min) +maximum = _one_to_one_binop(np.maximum, lax.max) +float_power = _one_to_one_binop(np.float_power, lax.pow, True) +nextafter = _one_to_one_binop(np.nextafter, lax.nextafter, True, True) + +greater_equal = _comparison_op(np.greater_equal, lax.ge) +greater = _comparison_op(np.greater, lax.gt) +less_equal = _comparison_op(np.less_equal, lax.le) +less = _comparison_op(np.less, lax.lt) + +logical_and = _logical_op(np.logical_and, lax.bitwise_and) +logical_not = _logical_op(np.logical_not, lax.bitwise_not) +logical_or = _logical_op(np.logical_or, lax.bitwise_or) +logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor) + + +@_wraps(np.arccosh) +@jit +def arccosh(x): + # Note: arccosh is multi-valued for complex input, and lax.acosh uses a different + # convention than np.arccosh. 
+ out = lax.acosh(*_promote_args_inexact("arccosh", x)) + if dtypes.issubdtype(out.dtype, np.complexfloating): + out = _where(real(out) < 0, lax.neg(out), out) + return out + + +@_wraps(np.right_shift) +@partial(jit, inline=True) +def right_shift(x1, x2): + x1, x2 = _promote_args(np.right_shift.__name__, x1, x2) + lax_fn = lax.shift_right_logical if \ + np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic + return lax_fn(x1, x2) + + +@_wraps(np.absolute) +@partial(jit, inline=True) +def absolute(x): + _check_arraylike('absolute', x) + dt = dtypes.dtype(x) + return x if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) +abs = _wraps(np.abs)(absolute) + + +@_wraps(np.rint) +@jit +def rint(x): + _check_arraylike('rint', x) + dtype = dtypes.dtype(x) + if dtypes.issubdtype(dtype, np.integer): + return lax.convert_element_type(x, dtypes.float_) + if dtypes.issubdtype(dtype, np.complexfloating): + return lax.complex(rint(lax.real(x)), rint(lax.imag(x))) + return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) + + +@_wraps(np.sign) +@jit +def sign(x): + _check_arraylike('sign', x) + dtype = dtypes.dtype(x) + if dtypes.issubdtype(dtype, np.complexfloating): + re = lax.real(x) + return lax.complex( + lax.sign(_where(re != 0, re, lax.imag(x))), _constant_like(re, 0)) + return lax.sign(x) + + +@_wraps(np.copysign) +@jit +def copysign(x1, x2): + x1, x2 = _promote_args_inexact("copysign", x1, x2) + if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating): + raise TypeError("copysign does not support complex-valued inputs") + return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1)) + + +@_wraps(np.true_divide) +@partial(jit, inline=True) +def true_divide(x1, x2): + x1, x2 = _promote_args_inexact("true_divide", x1, x2) + return lax.div(x1, x2) + +divide = true_divide + + +@_wraps(np.floor_divide) +@jit +def floor_divide(x1, x2): + x1, x2 = _promote_args("floor_divide", x1, x2) + dtype = dtypes.dtype(x1) + if dtypes.issubdtype(dtype, np.integer): + quotient = lax.div(x1, x2) + select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0) + # TODO(mattjj): investigate why subtracting a scalar was causing promotion + return _where(select, quotient - 1, quotient) + elif dtypes.issubdtype(dtype, np.complexfloating): + x1r = lax.real(x1) + x1i = lax.imag(x1) + x2r = lax.real(x2) + x2i = lax.imag(x2) + which = lax.ge(lax.abs(x2r), lax.abs(x2i)) + rat1 = _where(which, lax.full_like(x2i, 1), lax.div(x2r, x2i)) + rat2 = _where(which, lax.div(x2i, x2r), _lax_const(x2i, 1)) + out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)), + lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2)))) + return lax.convert_element_type(out, dtype) + else: + return _float_divmod(x1, x2)[0] + + +@_wraps(np.divmod) +@jit +def divmod(x1, x2): + x1, x2 = _promote_args("divmod", x1, x2) + if dtypes.issubdtype(dtypes.dtype(x1), np.integer): + return floor_divide(x1, x2), remainder(x1, x2) + else: + return _float_divmod(x1, x2) + + +def _float_divmod(x1, x2): + # see float_divmod in floatobject.c of CPython + mod = lax.rem(x1, x2) + div = lax.div(lax.sub(x1, mod), x2) + + ind = lax.bitwise_and(mod != 0, lax.sign(x2) != lax.sign(mod)) + mod = lax.select(ind, mod + x2, mod) + div = lax.select(ind, div - _constant_like(div, 1), div) + + return lax.round(div), mod + + +@partial(jit, inline=True) +def _power(x1, x2): + x1, x2 = _promote_args("power", x1, x2) + dtype = dtypes.dtype(x1) + if not dtypes.issubdtype(dtype, np.integer): + return lax.pow(x1, x2) + + # 
Integer power => use binary exponentiation. + + # TODO(phawkins): add integer pow support to XLA. + bits = 6 # Anything more would overflow for any x1 > 1 + zero = _constant_like(x2, 0) + one = _constant_like(x2, 1) + # Initialize acc carefully such that pow(0, x2) is zero for x2 != 0 + acc = _where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one) + for _ in range(bits): + acc = _where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc) + x1 = lax.mul(x1, x1) + x2 = lax.shift_right_logical(x2, one) + return acc + + +@_wraps(np.power) +def power(x1, x2): + # Special case for concrete integer scalars: use binary exponentiation. + # Using lax.pow may be imprecise for floating-point values; the goal of this + # code path is to make sure we end up with a precise output for the common + # pattern ``x ** 2`` or similar. + if isinstance(core.get_aval(x2), core.ConcreteArray): + try: + x2 = operator.index(x2) + except TypeError: + pass + else: + return lax.integer_pow(x1, x2) + return _power(x1, x2) + + +@custom_jvp +@_wraps(np.logaddexp) +@jit +def logaddexp(x1, x2): + x1, x2 = _promote_args_inexact("logaddexp", x1, x2) + amax = lax.max(x1, x2) + if dtypes.issubdtype(x1.dtype, np.floating): + delta = lax.sub(x1, x2) + return lax.select(lax_internal._isnan(delta), + lax.add(x1, x2), # NaNs or infinities of the same sign. + lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))) + else: + delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2))) + out = lax.add(amax, lax.log1p(lax.exp(delta))) + return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi)) + + +def _wrap_between(x, _a): + """Wraps `x` between `[-a, a]`.""" + a = _constant_like(x, _a) + two_a = _constant_like(x, 2 * _a) + zero = _constant_like(x, 0) + rem = lax.rem(lax.add(x, a), two_a) + rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem) + return lax.sub(rem, a) + + +@logaddexp.defjvp +def _logaddexp_jvp(primals, tangents): + x1, x2 = primals + t1, t2 = tangents + x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2) + primal_out = logaddexp(x1, x2) + tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), + lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) + return primal_out, tangent_out + + +@custom_jvp +@_wraps(np.logaddexp2) +@jit +def logaddexp2(x1, x2): + x1, x2 = _promote_args_inexact("logaddexp2", x1, x2) + amax = lax.max(x1, x2) + if dtypes.issubdtype(x1.dtype, np.floating): + delta = lax.sub(x1, x2) + return lax.select(lax_internal._isnan(delta), + lax.add(x1, x2), # NaNs or infinities of the same sign. 
+ lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))), + _constant_like(x1, np.log(2))))) + else: + delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2))) + out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2)))) + return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2))) + + +@logaddexp2.defjvp +def _logaddexp2_jvp(primals, tangents): + x1, x2 = primals + t1, t2 = tangents + x1, x2, t1, t2 = _promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2) + primal_out = logaddexp2(x1, x2) + tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), + lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) + return primal_out, tangent_out + + +@_wraps(np.log2) +@partial(jit, inline=True) +def log2(x): + x, = _promote_args_inexact("log2", x) + return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) + + +@_wraps(np.log10) +@partial(jit, inline=True) +def log10(x): + x, = _promote_args_inexact("log10", x) + return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) + + +@_wraps(np.exp2) +@partial(jit, inline=True) +def exp2(x): + x, = _promote_args_inexact("exp2", x) + return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x)) + + +@_wraps(np.signbit) +@jit +def signbit(x): + x, = _promote_args("signbit", x) + dtype = dtypes.dtype(x) + if dtypes.issubdtype(dtype, np.integer): + return lax.lt(x, _constant_like(x, 0)) + elif dtypes.issubdtype(dtype, np.bool_): + return lax.full_like(x, False, dtype=np.bool_) + elif not dtypes.issubdtype(dtype, np.floating): + raise ValueError( + "jax.numpy.signbit is not well defined for %s" % dtype) + + # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to + # F32. + if dtype == dtypes.bfloat16: + dtype = np.float32 + x = lax.convert_element_type(x, np.float32) + + info = dtypes.finfo(dtype) + if info.bits not in _INT_DTYPES: + raise NotImplementedError( + "jax.numpy.signbit only supports 16, 32, and 64-bit types.") + int_type = _INT_DTYPES[info.bits] + x = lax.bitcast_convert_type(x, int_type) + return lax.convert_element_type(x >> (info.nexp + info.nmant), np.bool_) + + +def _normalize_float(x): + info = dtypes.finfo(dtypes.dtype(x)) + cond = lax.abs(x) < info.tiny + x1 = _where(cond, x * _lax_const(x, 1 << info.nmant), x) + x2 = _where(cond, lax.full_like(x, -info.nmant, dtype=np.int32), lax.full_like(x, 0, dtype=np.int32)) + int_type = _INT_DTYPES[info.bits] + return lax.bitcast_convert_type(x1, int_type), x2 + + +@_wraps(np.ldexp) +@jit +def ldexp(x1, x2): + _check_arraylike("ldexp", x1, x2) + dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2)) + x1, x2 = _promote_shapes("ldexp", x1, x2) + x1 = lax.convert_element_type(x1, dtype) + + info = dtypes.finfo(dtype) + mask = (1 << info.nexp) - 1 + bias = ((1 << info.nexp) - 1) >> 1 + + int_type = _INT_DTYPES[info.bits] + + x, e = _normalize_float(x1) + x2 += e + ((x >> info.nmant) & mask) - bias + + # find underflow/overflow before denormalization + underflow_cond = x2 < -(bias + info.nmant) + overflow_cond = x2 > bias + + m = lax.full_like(x, 1, dtype=dtype) + + # denormals + cond = x2 < -bias + 1 + x2 = _where(cond, x2 + info.nmant, x2) + m = _where(cond, m / (1 << info.nmant), m) + + x2 = lax.convert_element_type(x2, np.int32) + x &= ~(mask << info.nmant) + x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant) + + x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype) + + # underflow + x = _where(underflow_cond, 
lax.full_like(x, 0, dtype=dtype), x) + # overflow + x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x) + # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0 + return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x) + + +@_wraps(np.frexp) +@jit +def frexp(x): + _check_arraylike("frexp", x) + if dtypes.issubdtype(x.dtype, np.complexfloating): + raise TypeError("frexp does not support complex-valued inputs") + elif not dtypes.issubdtype(dtypes.dtype(x), np.floating): + x = lax.convert_element_type(x, np.float_) + + dtype = dtypes.dtype(x) + info = dtypes.finfo(dtype) + mask = (1 << info.nexp) - 1 + bias = ((1 << info.nexp) - 1) >> 1 + + x1, x2 = _normalize_float(x) + x2 += ((x1 >> info.nmant) & mask) - bias + 1 + x1 &= ~(mask << info.nmant) + x1 |= (bias - 1) << info.nmant + x1 = lax.bitcast_convert_type(x1, dtype) + + cond = isinf(x) | isnan(x) | (x == 0) + x2 = _where(cond, lax_internal._zeros(x2), x2) + return _where(cond, x, x1), lax.convert_element_type(x2, np.int32) + + +@_wraps(np.remainder) +@jit +def remainder(x1, x2): + x1, x2 = _promote_args("remainder", x1, x2) + zero = _constant_like(x1, 0) + trunc_mod = lax.rem(x1, x2) + trunc_mod_not_zero = lax.ne(trunc_mod, zero) + do_plus = lax.bitwise_and( + lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero) + return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) +mod = _wraps(np.mod)(remainder) + + +@_wraps(np.fmod) +@jit +def fmod(x1, x2): + _check_arraylike("fmod", x1, x2) + if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer): + x2 = _where(x2 == 0, lax_internal._ones(x2), x2) + return lax.rem(*_promote_args("fmod", x1, x2)) + + +@_wraps(np.square) +@partial(jit, inline=True) +def square(x): + _check_arraylike("square", x) + return lax.integer_pow(x, 2) + + +@_wraps(np.deg2rad) +@partial(jit, inline=True) +def deg2rad(x): + x, = _promote_args_inexact("deg2rad", x) + return lax.mul(x, _lax_const(x, np.pi / 180)) + + +@_wraps(np.rad2deg) +@partial(jit, inline=True) +def rad2deg(x): + x, = _promote_args_inexact("rad2deg", x) + return lax.mul(x, _lax_const(x, 180 / np.pi)) + + +degrees = rad2deg +radians = deg2rad + + +@_wraps(np.conjugate) +@partial(jit, inline=True) +def conjugate(x): + _check_arraylike("conjugate", x) + return lax.conj(x) if np.iscomplexobj(x) else x +conj = conjugate + + +@_wraps(np.imag) +@partial(jit, inline=True) +def imag(val): + _check_arraylike("imag", val) + return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0) + + +@_wraps(np.real) +@partial(jit, inline=True) +def real(val): + _check_arraylike("real", val) + return lax.real(val) if np.iscomplexobj(val) else val + +@_wraps(np.modf, skip_params=['out']) +@jit +def modf(x, out=None): + _check_arraylike("modf", x) + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.modf is not supported.") + whole = _where(lax.ge(x, lax_internal._zero(x)), floor(x), ceil(x)) + return x - whole, whole + + +@_wraps(np.isfinite) +@jit +def isfinite(x): + _check_arraylike("isfinite", x) + dtype = dtypes.dtype(x) + if dtypes.issubdtype(dtype, np.floating): + return lax.is_finite(x) + elif dtypes.issubdtype(dtype, np.complexfloating): + return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x))) + else: + return lax.full_like(x, True, dtype=np.bool_) + + +@_wraps(np.isinf) +@jit +def isinf(x): + _check_arraylike("isinf", x) + dtype = dtypes.dtype(x) + if dtypes.issubdtype(dtype, np.floating): + return lax.eq(lax.abs(x), _constant_like(x, np.inf)) + elif dtypes.issubdtype(dtype, 
np.complexfloating): + re = lax.real(x) + im = lax.imag(x) + return lax.bitwise_or(lax.eq(lax.abs(re), _constant_like(re, np.inf)), + lax.eq(lax.abs(im), _constant_like(im, np.inf))) + else: + return lax.full_like(x, False, dtype=np.bool_) + + +def _isposneginf(infinity, x, out): + if out is not None: + raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.") + dtype = dtypes.dtype(x) + if dtypes.issubdtype(dtype, np.floating): + return lax.eq(x, _constant_like(x, infinity)) + elif dtypes.issubdtype(dtype, np.complexfloating): + raise ValueError("isposinf/isneginf are not well defined for complex types") + else: + return lax.full_like(x, False, dtype=np.bool_) + + +isposinf = _wraps(np.isposinf, skip_params=['out'])( + lambda x, out=None: _isposneginf(np.inf, x, out) +) + + +isneginf = _wraps(np.isneginf, skip_params=['out'])( + lambda x, out=None: _isposneginf(-np.inf, x, out) +) + + +@_wraps(np.isnan) +@jit +def isnan(x): + _check_arraylike("isnan", x) + return lax.ne(x, x) + + +@_wraps(np.heaviside) +@jit +def heaviside(x1, x2): + _check_arraylike("heaviside", x1, x2) + x1, x2 = _promote_dtypes_inexact(x1, x2) + zero = _lax_const(x1, 0) + return _where(lax.lt(x1, zero), zero, + _where(lax.gt(x1, zero), _lax_const(x1, 1), x2)) + + +@_wraps(np.hypot) +@jit +def hypot(x1, x2): + _check_arraylike("hypot", x1, x2) + x1, x2 = _promote_dtypes_inexact(x1, x2) + x1 = lax.abs(x1) + x2 = lax.abs(x2) + x1, x2 = maximum(x1, x2), minimum(x1, x2) + return lax.select(x1 == 0, x1, x1 * lax.sqrt(1 + lax.square(lax.div(x2, lax.select(x1 == 0, lax_internal._ones(x1), x1))))) + + +@_wraps(np.reciprocal) +@partial(jit, inline=True) +def reciprocal(x): + _check_arraylike("reciprocal", x) + x, = _promote_dtypes_inexact(x) + return lax.integer_pow(x, -1) + + +@_wraps(np.sinc, update_doc=False) +@jit +def sinc(x): + _check_arraylike("sinc", x) + x, = _promote_dtypes_inexact(x) + eq_zero = lax.eq(x, _lax_const(x, 0)) + pi_x = lax.mul(_lax_const(x, np.pi), x) + safe_pi_x = _where(eq_zero, _lax_const(x, 1), pi_x) + return _where(eq_zero, _sinc_maclaurin(0, pi_x), + lax.div(lax.sin(safe_pi_x), safe_pi_x)) + + +@partial(custom_jvp, nondiff_argnums=(0,)) +def _sinc_maclaurin(k, x): + # compute the kth derivative of x -> sin(x)/x evaluated at zero (since we + # compute the monomial term in the jvp rule) + if k % 2: + return lax.full_like(x, 0) + else: + return lax.full_like(x, (-1) ** (k // 2) / (k + 1)) + +@_sinc_maclaurin.defjvp +def _sinc_maclaurin_jvp(k, primals, tangents): + (x,), (t,) = primals, tangents + return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t
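
Usage note (illustrative, not part of the patch): the new jax/_src/numpy/ufuncs.py builds most NumPy-style ufuncs by pairing a NumPy function with its lax primitive through small private factories (_one_to_one_unop, _one_to_one_binop, _maybe_bool_binop), and lax_numpy.py simply re-imports the resulting names, so the public jax.numpy API is unchanged by the move. The sketch below imitates that wrapping pattern with a hypothetical make_unop helper written against public JAX APIs only; it is a simplified stand-in for the private factories, not the actual implementation.

    import functools

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax import lax

    def make_unop(numpy_fn, lax_fn):
        # Simplified stand-in for the private _one_to_one_unop factory:
        # promote the input to an inexact dtype, dispatch to the lax
        # primitive, and borrow the NumPy name/docstring via functools.wraps
        # (loosely playing the role of the internal _wraps helper).
        @functools.wraps(numpy_fn)
        def fn(x):
            x = jnp.asarray(x, dtype=jnp.result_type(float, x))
            return lax_fn(x)
        # inline=True mirrors the factories above: the jitted wrapper is
        # inlined into enclosing traces instead of staged as a separate call.
        return jax.jit(fn, inline=True)

    sin = make_unop(np.sin, lax.sin)
    x = jnp.linspace(0.0, 1.0, 5)
    print(sin(x))      # same values as jnp.sin(x)
    print(jnp.sin(x))  # the re-exported jax.numpy name keeps working as before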