diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index f85af1e2dc7c..607604d9f010 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -51,6 +51,12 @@ from jax._src.lax.lax import _array_copy, _sort_lt_comparator, _sort_le_comparator from jax._src.lax import lax as lax_internal from jax._src.numpy.ndarray import ndarray +from jax._src.numpy.reductions import ( # noqa: F401 + _ensure_optional_axes, _reduction_dims, + alltrue, amin, amax, any, all, average, count_nonzero, cumsum, cumprod, cumproduct, + max, mean, min, nancumsum, nancumprod, nanmax, nanmean, nanmin, nanprod, nanstd, + nansum, nanvar, prod, product, ptp, sometrue, std, sum, var, +) from jax._src.numpy.ufuncs import ( # noqa: F401 abs, absolute, add, arccos, arccosh, arcsin, arcsinh, arctan, arctan2, arctanh, bitwise_and, bitwise_not, bitwise_or, bitwise_xor, cbrt, ceil, conj, conjugate, @@ -69,7 +75,7 @@ from jax._src.numpy.vectorize import vectorize from jax._src.ops import scatter from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio, - canonicalize_axis as _canonicalize_axis, maybe_named_axis) + canonicalize_axis as _canonicalize_axis) newaxis = None @@ -1341,375 +1347,6 @@ def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): x = where(isneginf(x), array(neginf, dtype=x.dtype), x) return x -### Reducers - -def _reduction(a, name, np_fun, op, init_val, has_identity=True, - preproc=None, bool_op=None, upcast_f16_for_computation=False, - axis=None, dtype=None, out=None, keepdims=False, initial=None, - where_=None, parallel_reduce=None): - bool_op = bool_op or op - # Note: we must accept out=None as an argument, because numpy reductions delegate to - # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method - # exists, passing along all its arguments. - if out is not None: - raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.") - _check_arraylike(name, a) - lax_internal._check_user_dtype_supported(dtype, name) - axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().") - - if initial is None and not has_identity: - if not _all(core.greater_equal_dim(d, 1) for d in np.shape(a)): - raise ValueError(f"zero-size array to reduction operation {name} which has no identity") - if where_ is not None: - raise ValueError(f"reduction operation {name} does not have an identity, so to use a " - f"where mask one has to specify 'initial'") - - a = a if isinstance(a, ndarray) else asarray(a) - a = preproc(a) if preproc else a - pos_dims, dims = _reduction_dims(a, axis) - result_dtype = dtypes.canonicalize_dtype(dtype or _dtype(np_fun(np.ones((), dtype=_dtype(a))))) - if upcast_f16_for_computation and issubdtype(result_dtype, inexact): - computation_dtype = promote_types(result_dtype, float32) - else: - computation_dtype = result_dtype - a = lax.convert_element_type(a, computation_dtype) - op = op if computation_dtype != np.bool_ else bool_op - # NB: in XLA, init_val must be an identity for the op, so the user-specified - # initial value must be applied afterward. 
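  # A minimal sketch of the note above, using only public jax.numpy calls
  # (illustrative values, assumed behavior rather than code from this patch):
  #   x = jnp.array([1.0, 5.0, 3.0])
  #   jnp.max(x)                  # reduce is seeded with the identity (-inf) -> 5.0
  #   jnp.max(x, initial=10.0)    # `initial` is folded in after the reduce  -> 10.0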
- init_val = _reduction_init_val(a, init_val) - if where_ is not None: - a = where(where_, a, init_val) - if pos_dims is not dims: - if parallel_reduce is None: - raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()") - result = parallel_reduce(a, dims) - else: - result = lax.reduce(a, init_val, op, dims) - if initial is not None: - result = op(lax.convert_element_type(initial, a.dtype), result) - if keepdims: - result = expand_dims(result, pos_dims) - return lax.convert_element_type(result, dtype or result_dtype) - -def _canonicalize_axis_allow_named(x, rank): - return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name) - -def _reduction_dims(a, axis): - if axis is None: - return (tuple(range(ndim(a))),) * 2 - elif not isinstance(axis, (np.ndarray, tuple, list)): - axis = (axis,) - canon_axis = tuple(_canonicalize_axis_allow_named(x, ndim(a)) - for x in axis) - if len(canon_axis) != len(set(canon_axis)): - raise ValueError(f"duplicate value in 'axis': {axis}") - canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int)) - if len(canon_pos_axis) != len(canon_axis): - return canon_pos_axis, canon_axis - else: - return canon_axis, canon_axis - -def _reduction_init_val(a, init_val): - # This function uses np.* functions because lax pattern matches against the - # specific concrete values of the reduction inputs. - a_dtype = dtypes.canonicalize_dtype(_dtype(a)) - if a_dtype == 'bool': - return np.array(init_val > 0, dtype=a_dtype) - try: - return np.array(init_val, dtype=a_dtype) - except OverflowError: - assert issubdtype(a_dtype, integer) - sign, info = np.sign(init_val), iinfo(a_dtype) - return np.array(info.min if sign < 0 else info.max, dtype=a_dtype) - -def _cast_to_bool(operand): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=np.ComplexWarning) - return lax.convert_element_type(operand, bool_) - - -def _ensure_optional_axes(x): - def force(x): - if x is None: - return None - try: - return operator.index(x) - except TypeError: - return tuple(i if isinstance(i, str) else operator.index(i) for i in x) - return core.concrete_or_error( - force, x, "The axis argument must be known statically.") - - -@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True) -def _reduce_sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype=None, out=None, keepdims=None, initial=None, where=None): - return _reduction(a, "sum", np.sum, lax.add, 0, - bool_op=lax.bitwise_or, upcast_f16_for_computation=True, - axis=axis, dtype=dtype, out=out, keepdims=keepdims, - initial=initial, where_=where, parallel_reduce=lax.psum) - -@_wraps(np.sum, skip_params=['out']) -def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=None, initial=None, where=None): - return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, - keepdims=keepdims, initial=initial, where=where) - - -@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True) -def _reduce_prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype=None, out=None, keepdims=None, initial=None, where=None): - return _reduction(a, "prod", np.prod, lax.mul, 1, - bool_op=lax.bitwise_and, upcast_f16_for_computation=True, - axis=axis, dtype=dtype, out=out, keepdims=keepdims, - initial=initial, where_=where) - -@_wraps(np.prod, skip_params=['out']) -def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=None, initial=None, where=None): - 
return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype, - out=out, keepdims=keepdims, initial=initial, where=where) - - -@partial(jit, static_argnames=('axis', 'keepdims'), inline=True) -def _reduce_max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None): - return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False, - axis=axis, out=out, keepdims=keepdims, - initial=initial, where_=where, parallel_reduce=lax.pmax) - -@_wraps(np.max, skip_params=['out']) -def max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None): - return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out, - keepdims=keepdims, initial=initial, where=where) - -@partial(jit, static_argnames=('axis', 'keepdims'), inline=True) -def _reduce_min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None): - return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False, - axis=axis, out=out, keepdims=keepdims, - initial=initial, where_=where, parallel_reduce=lax.pmin) - -@_wraps(np.min, skip_params=['out']) -def min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None): - return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out, - keepdims=keepdims, initial=initial, where=where) - -@partial(jit, static_argnames=('axis', 'keepdims'), inline=True) -def _reduce_all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, *, where=None): - return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool, - axis=axis, out=out, keepdims=keepdims, where_=where) - -@_wraps(np.all, skip_params=['out']) -def all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, *, where=None): - return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out, - keepdims=keepdims, where=where) - -@partial(jit, static_argnames=('axis', 'keepdims'), inline=True) -def _reduce_any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, *, where=None): - return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool, - axis=axis, out=out, keepdims=keepdims, where_=where) - -@_wraps(np.any, skip_params=['out']) -def any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, *, where=None): - return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out, - keepdims=keepdims, where=where) - -product = prod -amin = min -amax = max -alltrue = all -sometrue = any - -def _axis_size(a, axis): - if not isinstance(axis, (tuple, list)): - axis = (axis,) - size = 1 - a_shape = shape(a) - for a in axis: - size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name)) - return size - -@_wraps(np.mean, skip_params=['out']) -def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=False, *, where=None): - return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims, - where=where) - -@partial(jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True) -def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=False, *, where=None): - _check_arraylike("mean", a) - lax_internal._check_user_dtype_supported(dtype, "mean") - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.mean is not supported.") - - if where is None: - if axis is None: - 
normalizer = core.dimension_as_value(size(a)) - else: - normalizer = core.dimension_as_value(_axis_size(a, axis)) - else: - normalizer = sum(broadcast_to(where, shape(a)), axis, dtype=dtype, keepdims=keepdims) - - if dtype is None: - if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer): - dtype = float_ - else: - dtype = _dtype(a) - dtype = dtypes.canonicalize_dtype(dtype) - - return lax.div( - sum(a, axis, dtype=dtype, keepdims=keepdims, where=where), - lax.convert_element_type(normalizer, dtype)) - -@_wraps(np.average) -def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None, - returned=False): - return _average(a, _ensure_optional_axes(axis), weights, returned) - -@partial(jit, static_argnames=('axis', 'returned'), inline=True) -def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None, - returned=False): - a = asarray(a) - - if weights is None: # Treat all weights as 1 - avg = mean(a, axis=axis) - if axis is None: - weights_sum = full((), core.dimension_as_value(size(a)), dtype=avg.dtype) - else: - weights_sum = full_like(avg, core.dimension_as_value(a.shape[axis]), dtype=avg.dtype) - else: - weights = asarray(weights) - - if issubdtype(a.dtype, inexact): - out_dtype = result_type(a.dtype, weights.dtype) - else: - out_dtype = result_type(a.dtype, weights.dtype, float_) - out_dtype = dtypes.canonicalize_dtype(out_dtype) - - a_shape = shape(a) - a_ndim = len(a_shape) - weights_shape = shape(weights) - axis = None if axis is None else _canonicalize_axis(axis, a_ndim) - - if a_shape != weights_shape: - # Make sure the dimensions work out - if axis is None: - raise ValueError("Axis must be specified when shapes of a and " - "weights differ.") - if len(weights_shape) != 1: - raise ValueError("1D weights expected when shapes of a and " - "weights differ.") - if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]): - raise ValueError("Length of weights not " - "compatible with specified axis.") - - weights = broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape) - weights = moveaxis(weights, -1, axis) - - weights_sum = sum(weights, axis=axis, dtype=out_dtype) - avg = sum(multiply(a, weights), axis=axis, dtype=out_dtype) / weights_sum - - if returned: - if avg.shape != weights_sum.shape: - weights_sum = broadcast_to(weights_sum, avg.shape) - return avg, weights_sum - return avg - - -@_wraps(np.var, skip_params=['out']) -def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, ddof=0, keepdims=False, *, where=None): - return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, - where=where) - -@partial(jit, static_argnames=('axis', 'dtype', 'keepdims')) -def _var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, ddof=0, keepdims=False, *, where=None): - _check_arraylike("var", a) - lax_internal._check_user_dtype_supported(dtype, "var") - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.var is not supported.") - - a_dtype, dtype = _var_promote_types(_dtype(a), dtype) - a_mean = mean(a, axis, dtype=a_dtype, keepdims=True, where=where) - centered = a - a_mean - if issubdtype(centered.dtype, complexfloating): - centered = lax.real(lax.mul(centered, lax.conj(centered))) - else: - centered = lax.square(centered) - - if where is None: - if axis is None: - normalizer = core.dimension_as_value(size(a)) - else: - normalizer = core.dimension_as_value(_axis_size(a, axis)) - else: - normalizer = sum(broadcast_to(where, shape(a)), axis, 
dtype=dtype, keepdims=keepdims) - normalizer = normalizer - ddof - - result = sum(centered, axis, keepdims=keepdims, where=where) - out = lax.div(result, lax.convert_element_type(normalizer, result.dtype)) - return lax.convert_element_type(out, dtype) - - -def _var_promote_types(a_dtype, dtype): - if dtype: - if (not issubdtype(dtype, complexfloating) and - issubdtype(a_dtype, complexfloating)): - msg = ("jax.numpy.var does not yet support real dtype parameters when " - "computing the variance of an array of complex values. The " - "semantics of numpy.var seem unclear in this case. Please comment " - "on https://github.com/google/jax/issues/2283 if this behavior is " - "important to you.") - raise ValueError(msg) - a_dtype = promote_types(a_dtype, dtype) - else: - if not issubdtype(a_dtype, inexact): - dtype = a_dtype = dtypes.canonicalize_dtype(float_) - else: - dtype = _complex_elem_type(a_dtype) - a_dtype = promote_types(a_dtype, float32) - return a_dtype, dtype - - -@_wraps(np.std, skip_params=['out']) -def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, ddof=0, keepdims=False, *, where=None): - return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, - where=where) - -@partial(jit, static_argnames=('axis', 'dtype', 'keepdims')) -def _std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, ddof=0, keepdims=False, *, where=None): - _check_arraylike("std", a) - lax_internal._check_user_dtype_supported(dtype, "std") - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.std is not supported.") - return sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where)) - - -@_wraps(np.ptp, skip_params=['out']) -def ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=False): - return _ptp(a, _ensure_optional_axes(axis), out, keepdims) - -@partial(jit, static_argnames=('axis', 'keepdims')) -def _ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=False): - _check_arraylike("ptp", a) - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.") - x = amax(a, axis=axis, keepdims=keepdims) - y = amin(a, axis=axis, keepdims=keepdims) - return lax.sub(x, y) - @_wraps(np.allclose) @partial(jit, static_argnames=('equal_nan',)) @@ -1718,15 +1355,6 @@ def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False): return all(isclose(a, b, rtol, atol, equal_nan)) -@_wraps(np.count_nonzero) -@partial(jit, static_argnames=('axis', 'keepdims')) -def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, - keepdims=False): - _check_arraylike("count_nonzero", a) - return sum(lax.ne(a, _lax_const(a, 0)), axis=axis, - dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims) - - _NONZERO_DOC = """\ Because the size of the output of ``nonzero`` is data-dependent, the function is not typically compatible with JIT. 
The JAX version adds the optional ``size`` argument which @@ -1770,162 +1398,6 @@ def flatnonzero(a, *, size=None, fill_value=None): return nonzero(ravel(a), size=size, fill_value=fill_value)[0] -def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan, - axis=None, keepdims=None, **kwargs): - _check_arraylike(name, a) - if not issubdtype(_dtype(a), inexact): - return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs) - - out = jnp_reduction(where(isnan(a), _reduction_init_val(a, init_val), a), - axis=axis, keepdims=keepdims, **kwargs) - if nan_if_all_nan: - return where(all(isnan(a), axis=axis, keepdims=keepdims), - _lax_const(a, nan), out) - else: - return out - -@_wraps(np.nanmin, skip_params=['out']) -@partial(jit, static_argnames=('axis', 'keepdims')) -def nanmin(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None): - return _nan_reduction(a, 'nanmin', min, inf, nan_if_all_nan=initial is None, - axis=axis, out=out, keepdims=keepdims, - initial=initial, where=where) - -@_wraps(np.nanmax, skip_params=['out']) -@partial(jit, static_argnames=('axis', 'keepdims')) -def nanmax(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None): - return _nan_reduction(a, 'nanmax', max, -inf, nan_if_all_nan=initial is None, - axis=axis, out=out, keepdims=keepdims, - initial=initial, where=where) - -@_wraps(np.nansum, skip_params=['out']) -@partial(jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=None, initial=None, where=None): - lax_internal._check_user_dtype_supported(dtype, "nanprod") - return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False, - axis=axis, dtype=dtype, out=out, keepdims=keepdims, - initial=initial, where=where) - -# Work around a sphinx documentation warning in NumPy 1.22. 
-nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n") - -@_wraps(np.nanprod, skip_params=['out']) -@partial(jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=None, initial=None, where=None): - lax_internal._check_user_dtype_supported(dtype, "nanprod") - return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False, - axis=axis, dtype=dtype, out=out, keepdims=keepdims, - initial=initial, where=where) - -@_wraps(np.nanmean, skip_params=['out']) -@partial(jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=False, where=None): - _check_arraylike("nanmean", a) - lax_internal._check_user_dtype_supported(dtype, "nanmean") - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.") - if issubdtype(_dtype(a), bool_) or issubdtype(_dtype(a), integer): - return mean(a, axis, dtype, out, keepdims, where=where) - if dtype is None: - dtype = _dtype(a) - nan_mask = logical_not(isnan(a)) - normalizer = sum(nan_mask, axis=axis, dtype=int32, keepdims=keepdims, where=where) - normalizer = lax.convert_element_type(normalizer, dtype) - td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims, where=where), normalizer) - return td - - -@_wraps(np.nanvar, skip_params=['out']) -@partial(jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, ddof=0, keepdims=False, where=None): - _check_arraylike("nanvar", a) - lax_internal._check_user_dtype_supported(dtype, "nanvar") - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.") - - a_dtype, dtype = _var_promote_types(_dtype(a), dtype) - a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True, where=where) - - centered = _where(isnan(a), 0, a - a_mean) # double-where trick for gradients. 
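  # The "double-where" pattern referenced above, in generic form (a sketch with
  # assumed inputs, not this function's exact code): an inner where sanitizes the
  # operand so reverse-mode AD never differentiates a NaN-producing branch.
  #   def safe_log_sum(x):
  #     mask = x > 0
  #     safe_x = jnp.where(mask, x, 1.0)                       # inner where: sanitize
  #     return jnp.sum(jnp.where(mask, jnp.log(safe_x), 0.0))  # outer where: select
  #   jax.grad(safe_log_sum)(jnp.array([2.0, -1.0]))           # -> [0.5, 0.], no NaNs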
- if issubdtype(centered.dtype, complexfloating): - centered = lax.real(lax.mul(centered, lax.conj(centered))) - else: - centered = lax.square(centered) - - normalizer = sum(logical_not(isnan(a)), axis=axis, keepdims=keepdims, where=where) - normalizer = normalizer - ddof - normalizer_mask = lax.le(normalizer, 0) - result = sum(centered, axis, keepdims=keepdims, where=where) - result = _where(normalizer_mask, nan, result) - divisor = _where(normalizer_mask, 1, normalizer) - out = lax.div(result, lax.convert_element_type(divisor, result.dtype)) - return lax.convert_element_type(out, dtype) - - -@_wraps(np.nanstd, skip_params=['out']) -@partial(jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nanstd(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, ddof=0, keepdims=False, where=None): - _check_arraylike("nanstd", a) - lax_internal._check_user_dtype_supported(dtype, "nanstd") - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.") - return sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where)) - - -def _make_cumulative_reduction(np_reduction, reduction, fill_nan=False, fill_value=0): - @_wraps(np_reduction, skip_params=['out']) - def cumulative_reduction(a, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype=None, out=None): - return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out) - - @partial(jit, static_argnames=('axis', 'dtype')) - def _cumulative_reduction(a, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype=None, out=None): - _check_arraylike(np_reduction.__name__, a) - if out is not None: - raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} " - f"is not supported.") - lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__) - - if axis is None or isscalar(a): - a = ravel(a) - axis = 0 - - a_shape = list(shape(a)) - num_dims = len(a_shape) - axis = _canonicalize_axis(axis, num_dims) - - if fill_nan: - a = where(isnan(a), _lax_const(a, fill_value), a) - - if not dtype and _dtype(a) == bool_: - dtype = int_ - if dtype: - a = lax.convert_element_type(a, dtype) - - return reduction(a, axis) - - return cumulative_reduction - - -cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False) -cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False) -cumproduct = cumprod -nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum, - fill_nan=True, fill_value=0) -nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod, - fill_nan=True, fill_value=1) - - @_wraps(np.unwrap) @partial(jit, static_argnames=('axis',)) def unwrap(p, discont=pi, axis: int = -1): @@ -5163,7 +4635,7 @@ def _itemsize(arr): return _dtype(arr).itemsize -def _clip(number, min=None, max=None, out=None, *, a_min=None, a_max=None): +def _clip(number, min=None, max=None, out=None, *, a_min=None, a_max=None): # noqa: F811 # ndarray.clip has a slightly different API from clip (min -> a_min, max -> a_max) # TODO: remove after deprecation window if a_min is not None or a_max is not None: @@ -5609,7 +5081,7 @@ def power(self, values, indices_are_sorted=False, unique_indices=False, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode)) - def min(self, values, indices_are_sorted=False, unique_indices=False, + def min(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811 mode=None): """Pure equivalent of ``x[idx] = minimum(x[idx], y)``. 
@@ -5624,7 +5096,7 @@ def min(self, values, indices_are_sorted=False, unique_indices=False, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) - def max(self, values, indices_are_sorted=False, unique_indices=False, + def max(self, values, indices_are_sorted=False, unique_indices=False, # noqa: F811 mode=None): """Pure equivalent of ``x[idx] = maximum(x[idx], y)``. diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py new file mode 100644 index 000000000000..2aba56add775 --- /dev/null +++ b/jax/_src/numpy/reductions.py @@ -0,0 +1,588 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +from functools import partial +import operator +from typing import Optional, Tuple, Union +import warnings + +import numpy as np + +from jax import core +from jax import lax +from jax._src import api +from jax._src import dtypes +from jax._src.numpy.ndarray import ndarray +from jax._src.numpy.util import _broadcast_to, _check_arraylike, _complex_elem_type, _where, _wraps +from jax._src.numpy.ufuncs import isnan, logical_not +from jax._src.lax import lax as lax_internal +from jax._src.util import canonicalize_axis as _canonicalize_axis, maybe_named_axis + + +_all = builtins.all +_lax_const = lax_internal._const + + +def _asarray(a): + # simplified version of jnp.asarray() for local use. + return a if isinstance(a, ndarray) else api.device_put(a) + +def _isscalar(element): + if hasattr(element, '__jax_array__'): + element = element.__jax_array__() + return dtypes.is_python_scalar(element) or np.isscalar(element) + +def _moveaxis(a, source: int, destination: int): + # simplified version of jnp.moveaxis() for local use. + _check_arraylike("moveaxis", a) + a = _asarray(a) + source = _canonicalize_axis(source, np.ndim(a)) + destination = _canonicalize_axis(destination, np.ndim(a)) + perm = [i for i in range(np.ndim(a)) if i != source] + perm.insert(destination, source) + return lax.transpose(a, perm) + + +def _reduction(a, name, np_fun, op, init_val, has_identity=True, + preproc=None, bool_op=None, upcast_f16_for_computation=False, + axis=None, dtype=None, out=None, keepdims=False, initial=None, + where_=None, parallel_reduce=None): + bool_op = bool_op or op + # Note: we must accept out=None as an argument, because numpy reductions delegate to + # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method + # exists, passing along all its arguments. 
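  # A concrete instance of the delegation described above (assumed NumPy dispatch
  # behavior, not exercised in this patch):
  #   x = jnp.arange(4.0)
  #   np.sum(x)                      # delegates to x.sum(..., out=None) -> 6.0
  #   jnp.sum(x, out=np.empty(()))   # raises NotImplementedError via the check below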
+ if out is not None: + raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported.") + _check_arraylike(name, a) + lax_internal._check_user_dtype_supported(dtype, name) + axis = core.concrete_or_error(None, axis, f"axis argument to jnp.{name}().") + + if initial is None and not has_identity: + if not _all(core.greater_equal_dim(d, 1) for d in np.shape(a)): + raise ValueError(f"zero-size array to reduction operation {name} which has no identity") + if where_ is not None: + raise ValueError(f"reduction operation {name} does not have an identity, so to use a " + f"where mask one has to specify 'initial'") + + a = a if isinstance(a, ndarray) else _asarray(a) + a = preproc(a) if preproc else a + pos_dims, dims = _reduction_dims(a, axis) + result_dtype = dtypes.canonicalize_dtype(dtype or dtypes.dtype(np_fun(np.ones((), dtype=dtypes.dtype(a))))) + if upcast_f16_for_computation and dtypes.issubdtype(result_dtype, np.inexact): + computation_dtype = dtypes.promote_types(result_dtype, np.float32) + else: + computation_dtype = result_dtype + a = lax.convert_element_type(a, computation_dtype) + op = op if computation_dtype != np.bool_ else bool_op + # NB: in XLA, init_val must be an identity for the op, so the user-specified + # initial value must be applied afterward. + init_val = _reduction_init_val(a, init_val) + if where_ is not None: + a = _where(where_, a, init_val) + if pos_dims is not dims: + if parallel_reduce is None: + raise NotImplementedError(f"Named reductions not implemented for jnp.{name}()") + result = parallel_reduce(a, dims) + else: + result = lax.reduce(a, init_val, op, dims) + if initial is not None: + result = op(lax.convert_element_type(initial, a.dtype), result) + if keepdims: + result = lax.expand_dims(result, pos_dims) + return lax.convert_element_type(result, dtype or result_dtype) + +def _canonicalize_axis_allow_named(x, rank): + return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name) + +def _reduction_dims(a, axis): + if axis is None: + return (tuple(range(np.ndim(a))),) * 2 + elif not isinstance(axis, (np.ndarray, tuple, list)): + axis = (axis,) + canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a)) + for x in axis) + if len(canon_axis) != len(set(canon_axis)): + raise ValueError(f"duplicate value in 'axis': {axis}") + canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int)) + if len(canon_pos_axis) != len(canon_axis): + return canon_pos_axis, canon_axis + else: + return canon_axis, canon_axis + +def _reduction_init_val(a, init_val): + # This function uses np.* functions because lax pattern matches against the + # specific concrete values of the reduction inputs. 
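  # Illustrative values for the saturation handled below (assumed int32/bool
  # inputs; not exercised in this patch):
  #   _reduction_init_val(jnp.ones(3, 'int32'), -np.inf)  # -> array(-2147483648, dtype=int32)
  #   _reduction_init_val(jnp.ones(3, 'bool'), -np.inf)   # -> array(False)
  # i.e. an identity that the dtype cannot represent saturates to its extreme value.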
+ a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a)) + if a_dtype == 'bool': + return np.array(init_val > 0, dtype=a_dtype) + try: + return np.array(init_val, dtype=a_dtype) + except OverflowError: + assert dtypes.issubdtype(a_dtype, np.integer) + sign, info = np.sign(init_val), dtypes.iinfo(a_dtype) + return np.array(info.min if sign < 0 else info.max, dtype=a_dtype) + +def _cast_to_bool(operand): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=np.ComplexWarning) + return lax.convert_element_type(operand, np.bool_) + + +def _ensure_optional_axes(x): + def force(x): + if x is None: + return None + try: + return operator.index(x) + except TypeError: + return tuple(i if isinstance(i, str) else operator.index(i) for i in x) + return core.concrete_or_error( + force, x, "The axis argument must be known statically.") + + +@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True) +def _reduce_sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype=None, out=None, keepdims=None, initial=None, where=None): + return _reduction(a, "sum", np.sum, lax.add, 0, + bool_op=lax.bitwise_or, upcast_f16_for_computation=True, + axis=axis, dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where, parallel_reduce=lax.psum) + +@_wraps(np.sum, skip_params=['out']) +def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, keepdims=None, initial=None, where=None): + return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, + keepdims=keepdims, initial=initial, where=where) + + +@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True) +def _reduce_prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype=None, out=None, keepdims=None, initial=None, where=None): + return _reduction(a, "prod", np.prod, lax.mul, 1, + bool_op=lax.bitwise_and, upcast_f16_for_computation=True, + axis=axis, dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where_=where) + +@_wraps(np.prod, skip_params=['out']) +def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, keepdims=None, initial=None, where=None): + return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype, + out=out, keepdims=keepdims, initial=initial, where=where) + + +@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) +def _reduce_max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None, initial=None, where=None): + return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False, + axis=axis, out=out, keepdims=keepdims, + initial=initial, where_=where, parallel_reduce=lax.pmax) + +@_wraps(np.max, skip_params=['out']) +def max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None, initial=None, where=None): + return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out, + keepdims=keepdims, initial=initial, where=where) + +@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) +def _reduce_min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None, initial=None, where=None): + return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False, + axis=axis, out=out, keepdims=keepdims, + initial=initial, where_=where, parallel_reduce=lax.pmin) + +@_wraps(np.min, skip_params=['out']) +def min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None, initial=None, where=None): + return _reduce_min(a, 
axis=_ensure_optional_axes(axis), out=out, + keepdims=keepdims, initial=initial, where=where) + +@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) +def _reduce_all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None, *, where=None): + return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool, + axis=axis, out=out, keepdims=keepdims, where_=where) + +@_wraps(np.all, skip_params=['out']) +def all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None, *, where=None): + return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out, + keepdims=keepdims, where=where) + +@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) +def _reduce_any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None, *, where=None): + return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool, + axis=axis, out=out, keepdims=keepdims, where_=where) + +@_wraps(np.any, skip_params=['out']) +def any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None, *, where=None): + return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out, + keepdims=keepdims, where=where) + +product = prod +amin = min +amax = max +alltrue = all +sometrue = any + +def _axis_size(a, axis): + if not isinstance(axis, (tuple, list)): + axis = (axis,) + size = 1 + a_shape = np.shape(a) + for a in axis: + size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name)) + return size + +@_wraps(np.mean, skip_params=['out']) +def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, keepdims=False, *, where=None): + return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims, + where=where) + +@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True) +def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, keepdims=False, *, where=None): + _check_arraylike("mean", a) + lax_internal._check_user_dtype_supported(dtype, "mean") + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.mean is not supported.") + + if where is None: + if axis is None: + normalizer = core.dimension_as_value(np.size(a)) + else: + normalizer = core.dimension_as_value(_axis_size(a, axis)) + else: + normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims) + + if dtype is None: + if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer): + dtype = dtypes.float_ + else: + dtype = dtypes.dtype(a) + dtype = dtypes.canonicalize_dtype(dtype) + + return lax.div( + sum(a, axis, dtype=dtype, keepdims=keepdims, where=where), + lax.convert_element_type(normalizer, dtype)) + +@_wraps(np.average) +def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None, + returned=False): + return _average(a, _ensure_optional_axes(axis), weights, returned) + +@partial(api.jit, static_argnames=('axis', 'returned'), inline=True) +def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None, + returned=False): + a = _asarray(a) + + if weights is None: # Treat all weights as 1 + avg = mean(a, axis=axis) + if axis is None: + weights_sum = lax.full((), core.dimension_as_value(np.size(a)), dtype=avg.dtype) + else: + weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis]), dtype=avg.dtype) + else: + weights = _asarray(weights) + + if dtypes.issubdtype(a.dtype, np.inexact): + 
out_dtype = dtypes.result_type(a.dtype, weights.dtype) + else: + out_dtype = dtypes.result_type(a.dtype, weights.dtype, dtypes.float_) + out_dtype = dtypes.canonicalize_dtype(out_dtype) + + a_shape = np.shape(a) + a_ndim = len(a_shape) + weights_shape = np.shape(weights) + axis = None if axis is None else _canonicalize_axis(axis, a_ndim) + + if a_shape != weights_shape: + # Make sure the dimensions work out + if axis is None: + raise ValueError("Axis must be specified when shapes of a and " + "weights differ.") + if len(weights_shape) != 1: + raise ValueError("1D weights expected when shapes of a and " + "weights differ.") + if not core.symbolic_equal_dim(weights_shape[0], a_shape[axis]): + raise ValueError("Length of weights not " + "compatible with specified axis.") + + weights = _broadcast_to(weights, (a_ndim - 1) * (1,) + weights_shape) + weights = _moveaxis(weights, -1, axis) + + weights_sum = sum(weights, axis=axis, dtype=out_dtype) + avg = sum(a * weights, axis=axis, dtype=out_dtype) / weights_sum + + if returned: + if avg.shape != weights_sum.shape: + weights_sum = _broadcast_to(weights_sum, avg.shape) + return avg, weights_sum + return avg + + +@_wraps(np.var, skip_params=['out']) +def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, ddof=0, keepdims=False, *, where=None): + return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, + where=where) + +@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) +def _var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, ddof=0, keepdims=False, *, where=None): + _check_arraylike("var", a) + lax_internal._check_user_dtype_supported(dtype, "var") + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.var is not supported.") + + a_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype) + a_mean = mean(a, axis, dtype=a_dtype, keepdims=True, where=where) + centered = a - a_mean + if dtypes.issubdtype(centered.dtype, np.complexfloating): + centered = lax.real(lax.mul(centered, lax.conj(centered))) + else: + centered = lax.square(centered) + + if where is None: + if axis is None: + normalizer = core.dimension_as_value(np.size(a)) + else: + normalizer = core.dimension_as_value(_axis_size(a, axis)) + else: + normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims) + normalizer = normalizer - ddof + + result = sum(centered, axis, keepdims=keepdims, where=where) + out = lax.div(result, lax.convert_element_type(normalizer, result.dtype)) + return lax.convert_element_type(out, dtype) + + +def _var_promote_types(a_dtype, dtype): + if dtype: + if (not dtypes.issubdtype(dtype, np.complexfloating) and + dtypes.issubdtype(a_dtype, np.complexfloating)): + msg = ("jax.numpy.var does not yet support real dtype parameters when " + "computing the variance of an array of complex values. The " + "semantics of numpy.var seem unclear in this case. 
Please comment " + "on https://github.com/google/jax/issues/2283 if this behavior is " + "important to you.") + raise ValueError(msg) + a_dtype = dtypes.promote_types(a_dtype, dtype) + else: + if not dtypes.issubdtype(a_dtype, np.inexact): + dtype = a_dtype = dtypes.canonicalize_dtype(dtypes.float_) + else: + dtype = _complex_elem_type(a_dtype) + a_dtype = dtypes.promote_types(a_dtype, np.float32) + return a_dtype, dtype + + +@_wraps(np.std, skip_params=['out']) +def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, ddof=0, keepdims=False, *, where=None): + return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, + where=where) + +@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) +def _std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, ddof=0, keepdims=False, *, where=None): + _check_arraylike("std", a) + lax_internal._check_user_dtype_supported(dtype, "std") + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.std is not supported.") + return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where)) + + +@_wraps(np.ptp, skip_params=['out']) +def ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=False): + return _ptp(a, _ensure_optional_axes(axis), out, keepdims) + +@partial(api.jit, static_argnames=('axis', 'keepdims')) +def _ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=False): + _check_arraylike("ptp", a) + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.") + x = amax(a, axis=axis, keepdims=keepdims) + y = amin(a, axis=axis, keepdims=keepdims) + return lax.sub(x, y) + + +@_wraps(np.count_nonzero) +@partial(api.jit, static_argnames=('axis', 'keepdims')) +def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, + keepdims=False): + _check_arraylike("count_nonzero", a) + return sum(lax.ne(a, _lax_const(a, 0)), axis=axis, + dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims) + + +def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan, + axis=None, keepdims=None, **kwargs): + _check_arraylike(name, a) + if not dtypes.issubdtype(dtypes.dtype(a), np.inexact): + return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs) + + out = jnp_reduction(_where(isnan(a), _reduction_init_val(a, init_val), a), + axis=axis, keepdims=keepdims, **kwargs) + if nan_if_all_nan: + return _where(all(isnan(a), axis=axis, keepdims=keepdims), + _lax_const(a, np.nan), out) + else: + return out + +@_wraps(np.nanmin, skip_params=['out']) +@partial(api.jit, static_argnames=('axis', 'keepdims')) +def nanmin(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None, initial=None, where=None): + return _nan_reduction(a, 'nanmin', min, np.inf, nan_if_all_nan=initial is None, + axis=axis, out=out, keepdims=keepdims, + initial=initial, where=where) + +@_wraps(np.nanmax, skip_params=['out']) +@partial(api.jit, static_argnames=('axis', 'keepdims')) +def nanmax(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None, initial=None, where=None): + return _nan_reduction(a, 'nanmax', max, -np.inf, nan_if_all_nan=initial is None, + axis=axis, out=out, keepdims=keepdims, + initial=initial, where=where) + +@_wraps(np.nansum, skip_params=['out']) +@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) +def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + 
out=None, keepdims=None, initial=None, where=None): + lax_internal._check_user_dtype_supported(dtype, "nanprod") + return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False, + axis=axis, dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where=where) + +# Work around a sphinx documentation warning in NumPy 1.22. +nansum.__doc__ = nansum.__doc__.replace("\n\n\n", "\n\n") + +@_wraps(np.nanprod, skip_params=['out']) +@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) +def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, keepdims=None, initial=None, where=None): + lax_internal._check_user_dtype_supported(dtype, "nanprod") + return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False, + axis=axis, dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where=where) + +@_wraps(np.nanmean, skip_params=['out']) +@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) +def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, keepdims=False, where=None): + _check_arraylike("nanmean", a) + lax_internal._check_user_dtype_supported(dtype, "nanmean") + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.nanmean is not supported.") + if dtypes.issubdtype(dtypes.dtype(a), np.bool_) or dtypes.issubdtype(dtypes.dtype(a), np.integer): + return mean(a, axis, dtype, out, keepdims, where=where) + if dtype is None: + dtype = dtypes.dtype(a) + nan_mask = logical_not(isnan(a)) + normalizer = sum(nan_mask, axis=axis, dtype=np.int32, keepdims=keepdims, where=where) + normalizer = lax.convert_element_type(normalizer, dtype) + td = lax.div(nansum(a, axis, dtype=dtype, keepdims=keepdims, where=where), normalizer) + return td + + +@_wraps(np.nanvar, skip_params=['out']) +@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) +def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, ddof=0, keepdims=False, where=None): + _check_arraylike("nanvar", a) + lax_internal._check_user_dtype_supported(dtype, "nanvar") + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.") + + a_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype) + a_mean = nanmean(a, axis, dtype=a_dtype, keepdims=True, where=where) + + centered = _where(isnan(a), 0, a - a_mean) # double-where trick for gradients. 
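  # How the normalizer handling below plays out (illustrative values, assumed
  # public-API behavior): a slice where count - ddof <= 0 yields NaN instead of
  # dividing by a non-positive normalizer.
  #   a = jnp.array([[1.0, 2.0], [3.0, jnp.nan]])
  #   jnp.nanvar(a, axis=1, ddof=1)   # -> [0.5, nan]  (second row has one sample)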
+ if dtypes.issubdtype(centered.dtype, np.complexfloating): + centered = lax.real(lax.mul(centered, lax.conj(centered))) + else: + centered = lax.square(centered) + + normalizer = sum(logical_not(isnan(a)), axis=axis, keepdims=keepdims, where=where) + normalizer = normalizer - ddof + normalizer_mask = lax.le(normalizer, 0) + result = sum(centered, axis, keepdims=keepdims, where=where) + result = _where(normalizer_mask, np.nan, result) + divisor = _where(normalizer_mask, 1, normalizer) + out = lax.div(result, lax.convert_element_type(divisor, result.dtype)) + return lax.convert_element_type(out, dtype) + + +@_wraps(np.nanstd, skip_params=['out']) +@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) +def nanstd(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, ddof=0, keepdims=False, where=None): + _check_arraylike("nanstd", a) + lax_internal._check_user_dtype_supported(dtype, "nanstd") + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.nanstd is not supported.") + return lax.sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where)) + + +def _make_cumulative_reduction(np_reduction, reduction, fill_nan=False, fill_value=0): + @_wraps(np_reduction, skip_params=['out']) + def cumulative_reduction(a, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype=None, out=None): + return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out) + + @partial(api.jit, static_argnames=('axis', 'dtype')) + def _cumulative_reduction(a, + axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype=None, out=None): + _check_arraylike(np_reduction.__name__, a) + if out is not None: + raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} " + f"is not supported.") + lax_internal._check_user_dtype_supported(dtype, np_reduction.__name__) + + if axis is None or _isscalar(a): + a = lax.reshape(a, (np.size(a),)) + axis = 0 + + a_shape = list(np.shape(a)) + num_dims = len(a_shape) + axis = _canonicalize_axis(axis, num_dims) + + if fill_nan: + a = _where(isnan(a), _lax_const(a, fill_value), a) + + if not dtype and dtypes.dtype(a) == np.bool_: + dtype = dtypes.canonicalize_dtype(dtypes.int_) + if dtype: + a = lax.convert_element_type(a, dtype) + + return reduction(a, axis) + + return cumulative_reduction + + +cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False) +cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False) +cumproduct = cumprod +nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum, + fill_nan=True, fill_value=0) +nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod, + fill_nan=True, fill_value=1) diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index c1a2b8025c95..a7714c328144 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -308,6 +308,38 @@ roots as roots, ) +from jax._src.numpy.reductions import ( + alltrue as alltrue, + amin as amin, + amax as amax, + any as any, + all as all, + average as average, + count_nonzero as count_nonzero, + cumsum as cumsum, + cumprod as cumprod, + cumproduct as cumproduct, + max as max, + mean as mean, + min as min, + nancumsum as nancumsum, + nancumprod as nancumprod, + nanmax as nanmax, + nanmean as nanmean, + nanmin as nanmin, + nanprod as nanprod, + nanstd as nanstd, + nansum as nansum, + nanvar as nanvar, + prod as prod, + product as product, + ptp as ptp, + sometrue as sometrue, + std as std, + sum as sum, + var as var, +) + from 
jax._src.numpy.ufuncs import ( abs as abs, absolute as absolute,
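# A quick sanity check on the re-export wiring (an assumed usage snippet, not
# part of the patch): the public jax.numpy names should resolve to the new module.
#   import jax.numpy as jnp
#   from jax._src.numpy import reductions
#   assert jnp.sum is reductions.sum
#   assert jnp.nanmean is reductions.nanmean
#   assert jnp.cumsum is reductions.cumsum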