diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 73bf62662ac1..4db341ff5351 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -24,7 +24,6 @@ rules for the underlying :code:`lax` primitives. """ -import abc import builtins import collections from functools import partial @@ -40,14 +39,24 @@ import jax from jax import jit, custom_jvp from jax._src.numpy.vectorize import vectorize -from jax._src.numpy.util import _wraps +from jax._src.numpy.ndarray import ndarray +from jax._src.numpy.ufuncs import ( # noqa: F401 + fabs, bitwise_not, invert, negative, positive, floor, ceil, exp, log, expm1, log1p, sin, + cos, tan, arcsin, arccos, arctan, sinh, cosh, arcsinh, tanh, arctanh, sqrt, cbrt, + add, bitwise_and, bitwise_or, bitwise_xor, left_shift, equal, multiply, not_equal, subtract, + arctan2, minimum, maximum, float_power, nextafter, greater_equal, greater, less_equal, less, + logical_and, logical_not, logical_or, logical_xor, arccosh, right_shift, absolute, abs, rint, + sign, copysign, true_divide, divide, floor_divide, divmod, power, logaddexp, logaddexp2, log2, + log10, exp2, signbit, conjugate, conj, real, imag, ldexp, frexp, remainder, mod, fmod, square, deg2rad, + rad2deg, degrees, radians, modf, isfinite, isinf, isposinf, isneginf, isnan) +from jax._src.numpy.util import ( # noqa: F401 + _arraylike, _broadcast_arrays, _broadcast_to, _check_arraylike, _promote_args, _promote_args_inexact, + _promote_dtypes, _promote_dtypes_inexact, _promote_shapes, _where, _wraps) from jax import core from jax._src import dtypes from jax._src.api_util import _ensure_index_tuple from jax import errors -from jax.core import (UnshapedArray, ShapedArray, DShapedArray, ConcreteArray, - canonicalize_shape) -from jax.config import config +from jax.core import ShapedArray, DShapedArray, ConcreteArray, canonicalize_shape from jax.interpreters import pxla from jax import lax from jax._src import device_array @@ -95,281 +104,6 @@ get_printoptions = np.get_printoptions printoptions = np.printoptions set_printoptions = np.set_printoptions - -# ndarray is defined as an virtual abstract base class. - -class ArrayMeta(abc.ABCMeta): - """Metaclass for overriding ndarray isinstance checks.""" - - def __instancecheck__(self, instance): - # Allow tracer instances with avals that are instances of UnshapedArray. - # We could instead just declare Tracer an instance of the ndarray type, but - # there can be traced values that are not arrays. The main downside here is - # that isinstance(x, ndarray) might return true but - # issubclass(type(x), ndarray) might return false for an array tracer. - try: - return (hasattr(instance, "aval") and - isinstance(instance.aval, UnshapedArray)) - except AttributeError: - super().__instancecheck__(instance) - - -class ndarray(metaclass=ArrayMeta): - dtype: np.dtype - ndim: int - shape: Tuple[int, ...] - size: int - - def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None, - order=None): - raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly." - " Use jax.numpy.array, or jax.numpy.zeros instead.") - - @abc.abstractmethod - def __getitem__(self, key, indices_are_sorted=False, - unique_indices=False) -> Any: ... - @abc.abstractmethod - def __setitem__(self, key, value) -> Any: ... - @abc.abstractmethod - def __len__(self) -> Any: ... - @abc.abstractmethod - def __iter__(self) -> Any: ... - @abc.abstractmethod - def __reversed__(self) -> Any: ... - - # Comparisons - @abc.abstractmethod - def __lt__(self, other) -> Any: ... - @abc.abstractmethod - def __le__(self, other) -> Any: ... - @abc.abstractmethod - def __eq__(self, other) -> Any: ... - @abc.abstractmethod - def __ne__(self, other) -> Any: ... - @abc.abstractmethod - def __gt__(self, other) -> Any: ... - @abc.abstractmethod - def __ge__(self, other) -> Any: ... - - # Unary arithmetic - - @abc.abstractmethod - def __neg__(self) -> Any: ... - @abc.abstractmethod - def __pos__(self) -> Any: ... - @abc.abstractmethod - def __abs__(self) -> Any: ... - @abc.abstractmethod - def __invert__(self) -> Any: ... - - # Binary arithmetic - - @abc.abstractmethod - def __add__(self, other) -> Any: ... - @abc.abstractmethod - def __sub__(self, other) -> Any: ... - @abc.abstractmethod - def __mul__(self, other) -> Any: ... - @abc.abstractmethod - def __matmul__(self, other) -> Any: ... - @abc.abstractmethod - def __truediv__(self, other) -> Any: ... - @abc.abstractmethod - def __floordiv__(self, other) -> Any: ... - @abc.abstractmethod - def __mod__(self, other) -> Any: ... - @abc.abstractmethod - def __divmod__(self, other) -> Any: ... - @abc.abstractmethod - def __pow__(self, other) -> Any: ... - @abc.abstractmethod - def __lshift__(self, other) -> Any: ... - @abc.abstractmethod - def __rshift__(self, other) -> Any: ... - @abc.abstractmethod - def __and__(self, other) -> Any: ... - @abc.abstractmethod - def __xor__(self, other) -> Any: ... - @abc.abstractmethod - def __or__(self, other) -> Any: ... - - @abc.abstractmethod - def __radd__(self, other) -> Any: ... - @abc.abstractmethod - def __rsub__(self, other) -> Any: ... - @abc.abstractmethod - def __rmul__(self, other) -> Any: ... - @abc.abstractmethod - def __rmatmul__(self, other) -> Any: ... - @abc.abstractmethod - def __rtruediv__(self, other) -> Any: ... - @abc.abstractmethod - def __rfloordiv__(self, other) -> Any: ... - @abc.abstractmethod - def __rmod__(self, other) -> Any: ... - @abc.abstractmethod - def __rdivmod__(self, other) -> Any: ... - @abc.abstractmethod - def __rpow__(self, other) -> Any: ... - @abc.abstractmethod - def __rlshift__(self, other) -> Any: ... - @abc.abstractmethod - def __rrshift__(self, other) -> Any: ... - @abc.abstractmethod - def __rand__(self, other) -> Any: ... - @abc.abstractmethod - def __rxor__(self, other) -> Any: ... - @abc.abstractmethod - def __ror__(self, other) -> Any: ... - - @abc.abstractmethod - def __bool__(self) -> Any: ... - @abc.abstractmethod - def __complex__(self) -> Any: ... - @abc.abstractmethod - def __int__(self) -> Any: ... - @abc.abstractmethod - def __float__(self) -> Any: ... - @abc.abstractmethod - def __round__(self, ndigits=None) -> Any: ... - - @abc.abstractmethod - def __index__(self) -> Any: ... - - # np.ndarray methods: - @abc.abstractmethod - def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None) -> Any: ... - @abc.abstractmethod - def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None) -> Any: ... - @abc.abstractmethod - def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ... - @abc.abstractmethod - def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ... - @abc.abstractmethod - def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Any: ... - @abc.abstractmethod - def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ... - @abc.abstractmethod - def astype(self, dtype) -> Any: ... - @abc.abstractmethod - def choose(self, choices, out=None, mode='raise') -> Any: ... - @abc.abstractmethod - def clip(self, a_min=None, a_max=None, out=None) -> Any: ... - @abc.abstractmethod - def compress(self, condition, axis: Optional[int] = None, out=None) -> Any: ... - @abc.abstractmethod - def conj(self) -> Any: ... - @abc.abstractmethod - def conjugate(self) -> Any: ... - @abc.abstractmethod - def copy(self) -> Any: ... - @abc.abstractmethod - def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype=None, out=None) -> Any: ... - @abc.abstractmethod - def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype=None, out=None) -> Any: ... - @abc.abstractmethod - def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Any: ... - @abc.abstractmethod - def dot(self, b, *, precision=None) -> Any: ... - @abc.abstractmethod - def flatten(self) -> Any: ... - @property - @abc.abstractmethod - def imag(self) -> Any: ... - @abc.abstractmethod - def item(self, *args) -> Any: ... - @abc.abstractmethod - def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None) -> Any: ... - @abc.abstractmethod - def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=False, *, where=None,) -> Any: ... - @abc.abstractmethod - def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None) -> Any: ... - @property - @abc.abstractmethod - def nbytes(self) -> Any: ... - @abc.abstractmethod - def nonzero(self, *, size=None, fill_value=None) -> Any: ... - @abc.abstractmethod - def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=None, initial=None, where=None) -> Any: ... - @abc.abstractmethod - def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=False,) -> Any: ... - @abc.abstractmethod - def ravel(self, order='C') -> Any: ... - @property - @abc.abstractmethod - def real(self) -> Any: ... - @abc.abstractmethod - def repeat(self, repeats, axis: Optional[int] = None, *, - total_repeat_length=None) -> Any: ... - @abc.abstractmethod - def reshape(self, *args, order='C') -> Any: ... - @abc.abstractmethod - def round(self, decimals=0, out=None) -> Any: ... - @abc.abstractmethod - def searchsorted(self, v, side='left', sorter=None) -> Any: ... - @abc.abstractmethod - def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ... - @abc.abstractmethod - def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Any: ... - @abc.abstractmethod - def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ... - @abc.abstractmethod - def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=None, initial=None, where=None) -> Any: ... - @abc.abstractmethod - def swapaxes(self, axis1: int, axis2: int) -> Any: ... - @abc.abstractmethod - def take(self, indices, axis: Optional[int] = None, out=None, - mode=None) -> Any: ... - @abc.abstractmethod - def tobytes(self, order='C') -> Any: ... - @abc.abstractmethod - def tolist(self) -> Any: ... - @abc.abstractmethod - def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, - out=None) -> Any: ... - @abc.abstractmethod - def transpose(self, *args) -> Any: ... - @abc.abstractmethod - def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ... - @abc.abstractmethod - def view(self, dtype=None, type=None) -> Any: ... - - # Even though we don't always support the NumPy array protocol, e.g., for - # tracer types, for type checking purposes we must declare support so we - # implement the NumPy ArrayLike protocol. - def __array__(self) -> Any: ... - - # JAX extensions - @property - @abc.abstractmethod - def at(self) -> Any: ... - @property - @abc.abstractmethod - def aval(self) -> Any: ... - @property - @abc.abstractmethod - def weak_type(self) -> bool: ... - - -ndarray.register(device_array.DeviceArray) -for t in device_array.device_array_types: - ndarray.register(t) -ndarray.register(pxla._SDA_BASE_CLASS) - - - iscomplexobj = np.iscomplexobj shape = _shape = np.shape @@ -471,130 +205,6 @@ def _jnp_dtype(obj, align=False, copy=False): np.complex_: complex_ } -_INT_DTYPES = { - 16: np.int16, - 32: np.int32, - 64: np.int64, -} - -def _promote_shapes(fun_name, *args): - """Apply NumPy-style broadcasting, making args shape-compatible for lax.py.""" - if len(args) < 2: - return args - else: - shapes = [shape(arg) for arg in args] - if _all(len(shapes[0]) == len(s) for s in shapes[1:]): - return args # no need for rank promotion, so rely on lax promotion - nonscalar_ranks = {len(shp) for shp in shapes if shp} - if len(nonscalar_ranks) < 2: - return args - else: - if config.jax_numpy_rank_promotion != "allow": - _rank_promotion_warning_or_error(fun_name, shapes) - if config.jax_dynamic_shapes: - # With dynamic shapes we don't support singleton-dimension broadcasting; - # we instead broadcast out to the full shape as a temporary workaround. - res_shape = lax.broadcast_shapes(*shapes) - return [broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)] - else: - result_rank = len(lax.broadcast_shapes(*shapes)) - return [broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp) - for arg, shp in zip(args, shapes)] - - -def _rank_promotion_warning_or_error(fun_name, shapes): - if config.jax_numpy_rank_promotion == "warn": - msg = ("Following NumPy automatic rank promotion for {} on shapes {}. " - "Set the jax_numpy_rank_promotion config option to 'allow' to " - "disable this warning; for more information, see " - "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") - warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes)))) - elif config.jax_numpy_rank_promotion == "raise": - msg = ("Operands could not be broadcast together for {} on shapes {} " - "and with the config option jax_numpy_rank_promotion='raise'. " - "For more information, see " - "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") - raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes)))) - -def _promote_dtypes(*args): - """Convenience function to apply Numpy argument dtype promotion.""" - # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing. - if len(args) < 2: - return args - else: - to_dtype, weak_type = dtypes._lattice_result_type(*args) - to_dtype = dtypes.canonicalize_dtype(to_dtype) - return [lax._convert_element_type(x, to_dtype, weak_type) for x in args] - -def _promote_dtypes_inexact(*args): - """Convenience function to apply Numpy argument dtype promotion. - - Promotes arguments to an inexact type.""" - to_dtype, weak_type = dtypes._lattice_result_type(*args) - to_dtype = dtypes.canonicalize_dtype(to_dtype) - to_dtype_inexact = _to_inexact_dtype(to_dtype) - weak_type = (weak_type and to_dtype == to_dtype_inexact) - return [lax._convert_element_type(x, to_dtype_inexact, weak_type) for x in args] - -def _to_inexact_dtype(dtype): - """Promotes a dtype into an inexact dtype, if it is not already one.""" - return dtype if issubdtype(dtype, inexact) else promote_types(dtype, float_) - -def _complex_elem_type(dtype): - """Returns the float type of the real/imaginary parts of a complex dtype.""" - return np.abs(np.zeros((), dtype)).dtype - -def _result_dtype(op, *args): - """Compute result dtype of applying op to arguments with given dtypes.""" - args = [np.ones((0,) * ndim(arg), _dtype(arg)) for arg in args] - return _dtype(op(*args)) - - -def _arraylike(x): - return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or - hasattr(x, '__jax_array__') or isscalar(x)) - - -def _stackable(*args): - return _all(type(arg) in stackables for arg in args) -stackables: Set[Type] = set() -_register_stackable: Callable[[Type], None] = stackables.add - -def _check_arraylike(fun_name, *args): - """Check if all args fit JAX's definition of arraylike.""" - assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}" - if _any(not _arraylike(arg) for arg in args): - pos, arg = next((i, arg) for i, arg in enumerate(args) - if not _arraylike(arg)) - msg = "{} requires ndarray or scalar arguments, got {} at position {}." - raise TypeError(msg.format(fun_name, type(arg), pos)) - -def _check_no_float0s(fun_name, *args): - """Check if none of the args have dtype float0.""" - if _any(dtypes.dtype(arg) is dtypes.float0 for arg in args): - raise TypeError( - f"Called {fun_name} with a float0 array. " - "float0s do not support any operations by design because they " - "are not compatible with non-trivial vector spaces. No implicit dtype " - "conversion is done. You can use np.zeros_like(arr, dtype=np.float) " - "to cast a float0 array to a regular zeros array. \n" - "If you didn't expect to get a float0 you might have accidentally " - "taken a gradient with respect to an integer argument.") - -def _promote_args(fun_name, *args): - """Convenience function to apply Numpy argument shape and dtype promotion.""" - _check_arraylike(fun_name, *args) - _check_no_float0s(fun_name, *args) - return _promote_shapes(fun_name, *_promote_dtypes(*args)) - -def _promote_args_inexact(fun_name, *args): - """Convenience function to apply Numpy argument shape and dtype promotion. - - Promotes non-inexact types to an inexact type.""" - _check_arraylike(fun_name, *args) - _check_no_float0s(fun_name, *args) - return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args)) - def _convert_and_clip_integer(val, dtype): """ Convert integer-typed val to specified integer dtype, clipping to dtype @@ -635,6 +245,16 @@ def _convert_and_clip_integer(val, dtype): max_val = _constant_like(val, _min(iinfo(dtype).max, iinfo(val_dtype).max)) return clip(val, min_val, max_val).astype(dtype) +def _complex_elem_type(dtype): + """Returns the float type of the real/imaginary parts of a complex dtype.""" + return np.abs(np.zeros((), dtype)).dtype + + +def _stackable(*args): + return _all(type(arg) in stackables for arg in args) +stackables: Set[Type] = set() +_register_stackable: Callable[[Type], None] = stackables.add + def _constant_like(x, const): return np.array(const, dtype=_dtype(x)) @@ -682,388 +302,6 @@ def isscalar(element): def result_type(*args): return dtypes.result_type(*args) -def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False): - if promote_to_inexact: - fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x)) - else: - fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x)) - fn = jit(fn, inline=True) - if lax_doc: - doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() - return _wraps(numpy_fn, lax_description=doc)(fn) - else: - return _wraps(numpy_fn)(fn) - -def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False): - if promote_to_inexact: - fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2)) - else: - fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2)) - fn = jit(fn, inline=True) - if lax_doc: - doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() - return _wraps(numpy_fn, lax_description=doc)(fn) - else: - return _wraps(numpy_fn)(fn) - -def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False): - def fn(x1, x2): - x1, x2 = _promote_args(numpy_fn.__name__, x1, x2) - return lax_fn(x1, x2) if x1.dtype != bool_ else bool_lax_fn(x1, x2) - fn = jit(fn, inline=True) - if lax_doc: - doc = _dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() - return _wraps(numpy_fn, lax_description=doc)(fn) - else: - return _wraps(numpy_fn)(fn) - -fabs = _one_to_one_unop(np.fabs, lax.abs, True) -bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not) -invert = _one_to_one_unop(np.invert, lax.bitwise_not) -negative = _one_to_one_unop(np.negative, lax.neg) -positive = _one_to_one_unop(np.positive, lambda x: x) - -floor = _one_to_one_unop(np.floor, lax.floor, True) -ceil = _one_to_one_unop(np.ceil, lax.ceil, True) -exp = _one_to_one_unop(np.exp, lax.exp, True) -log = _one_to_one_unop(np.log, lax.log, True) -expm1 = _one_to_one_unop(np.expm1, lax.expm1, True) -log1p = _one_to_one_unop(np.log1p, lax.log1p, True) -sin = _one_to_one_unop(np.sin, lax.sin, True) -cos = _one_to_one_unop(np.cos, lax.cos, True) -tan = _one_to_one_unop(np.tan, lax.tan, True) -arcsin = _one_to_one_unop(np.arcsin, lax.asin, True) -arccos = _one_to_one_unop(np.arccos, lax.acos, True) -arctan = _one_to_one_unop(np.arctan, lax.atan, True) -sinh = _one_to_one_unop(np.sinh, lax.sinh, True) -cosh = _one_to_one_unop(np.cosh, lax.cosh, True) -arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True) -tanh = _one_to_one_unop(np.tanh, lax.tanh, True) -arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True) -arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True) -sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True) -cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True) - - -add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or) -bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and) -bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or) -bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor) -left_shift = _one_to_one_binop(np.left_shift, lax.shift_left) -equal = _one_to_one_binop(np.equal, lax.eq) -multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and) -not_equal = _one_to_one_binop(np.not_equal, lax.ne) -subtract = _one_to_one_binop(np.subtract, lax.sub) -arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, True) -minimum = _one_to_one_binop(np.minimum, lax.min) -maximum = _one_to_one_binop(np.maximum, lax.max) -float_power = _one_to_one_binop(np.float_power, lax.pow, True) -nextafter = _one_to_one_binop(np.nextafter, lax.nextafter, True, True) - -@_wraps(np.arccosh) -@jit -def arccosh(x): - # Note: arccosh is multi-valued for complex input, and lax.acosh uses a different - # convention than np.arccosh. - out = lax.acosh(*_promote_args_inexact("arccosh", x)) - if issubdtype(out.dtype, np.complexfloating): - out = where(real(out) < 0, lax.neg(out), out) - return out - -def _comparison_op(numpy_fn, lax_fn): - # TODO(https://github.com/google/jax/issues/6713): decorate this function with - # jit, after fixing a surprising interaction with remat(..., concrete=True). - def fn(x1, x2): - x1, x2 = _promote_args(numpy_fn.__name__, x1, x2) - # Comparison on complex types are defined as a lexicographic ordering on - # the (real, imag) pair. - if issubdtype(_dtype(x1), complexfloating): - rx = lax.real(x1) - ry = lax.real(x2) - return lax.select(lax.eq(rx, ry), lax_fn(lax.imag(x1), lax.imag(x2)), - lax_fn(rx, ry)) - return lax_fn(x1, x2) - return _wraps(numpy_fn)(fn) - -greater_equal = _comparison_op(np.greater_equal, lax.ge) -greater = _comparison_op(np.greater, lax.gt) -less_equal = _comparison_op(np.less_equal, lax.le) -less = _comparison_op(np.less, lax.lt) - - -def _logical_op(np_op, bitwise_op): - @_wraps(np_op, update_doc=False) - @partial(jit, inline=True) - def op(*args): - zero = lambda x: lax.full_like(x, shape=(), fill_value=0) - args = (x if issubdtype(_dtype(x), bool_) else lax.ne(x, zero(x)) - for x in args) - return bitwise_op(*_promote_args(np_op.__name__, *args)) - return op - -logical_and = _logical_op(np.logical_and, lax.bitwise_and) -logical_not = _logical_op(np.logical_not, lax.bitwise_not) -logical_or = _logical_op(np.logical_or, lax.bitwise_or) -logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor) - - -@_wraps(np.right_shift) -@partial(jit, inline=True) -def right_shift(x1, x2): - x1, x2 = _promote_args(np.right_shift.__name__, x1, x2) - lax_fn = lax.shift_right_logical if \ - np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic - return lax_fn(x1, x2) - - -@_wraps(np.absolute) -@partial(jit, inline=True) -def absolute(x): - _check_arraylike('absolute', x) - dt = _dtype(x) - return x if dt == bool_ or issubdtype(dt, unsignedinteger) else lax.abs(x) -abs = _wraps(np.abs)(absolute) - - -@_wraps(np.rint) -@jit -def rint(x): - _check_arraylike('rint', x) - dtype = _dtype(x) - if issubdtype(dtype, integer): - return lax.convert_element_type(x, float_) - if issubdtype(dtype, complexfloating): - return lax.complex(rint(lax.real(x)), rint(lax.imag(x))) - return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) - - -@_wraps(np.sign) -@jit -def sign(x): - _check_arraylike('sign', x) - dtype = _dtype(x) - if issubdtype(dtype, complexfloating): - re = lax.real(x) - return lax.complex( - lax.sign(where(re != 0, re, lax.imag(x))), _constant_like(re, 0)) - return lax.sign(x) - - -@_wraps(np.copysign) -@jit -def copysign(x1, x2): - x1, x2 = _promote_args_inexact("copysign", x1, x2) - if issubdtype(_dtype(x1), complexfloating): - raise TypeError("copysign does not support complex-valued inputs") - return where(signbit(x2), -lax.abs(x1), lax.abs(x1)) - - -@_wraps(np.true_divide) -@partial(jit, inline=True) -def true_divide(x1, x2): - x1, x2 = _promote_args_inexact("true_divide", x1, x2) - return lax.div(x1, x2) - -divide = true_divide - -@_wraps(np.floor_divide) -@jit -def floor_divide(x1, x2): - x1, x2 = _promote_args("floor_divide", x1, x2) - dtype = _dtype(x1) - if issubdtype(dtype, integer): - quotient = lax.div(x1, x2) - select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0) - # TODO(mattjj): investigate why subtracting a scalar was causing promotion - return where(select, quotient - np.array(1, _dtype(quotient)), quotient) - elif issubdtype(dtype, complexfloating): - x1r = lax.real(x1) - x1i = lax.imag(x1) - x2r = lax.real(x2) - x2i = lax.imag(x2) - which = lax.ge(lax.abs(x2r), lax.abs(x2i)) - rat1 = where(which, lax._const(x2i, 1), lax.div(x2r, x2i)) - rat2 = where(which, lax.div(x2i, x2r), lax._const(x2i, 1)) - out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)), - lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2)))) - return lax.convert_element_type(out, dtype) - else: - return _float_divmod(x1, x2)[0] - - -@_wraps(np.divmod) -@jit -def divmod(x1, x2): - x1, x2 = _promote_args("divmod", x1, x2) - if issubdtype(_dtype(x1), integer): - return floor_divide(x1, x2), remainder(x1, x2) - else: - return _float_divmod(x1, x2) - - -def _float_divmod(x1, x2): - # see float_divmod in floatobject.c of CPython - mod = lax.rem(x1, x2) - div = lax.div(lax.sub(x1, mod), x2) - - ind = lax.bitwise_and(mod != 0, lax.sign(x2) != lax.sign(mod)) - mod = lax.select(ind, mod + x2, mod) - div = lax.select(ind, div - _constant_like(div, 1), div) - - return lax.round(div), mod - - -@partial(jit, inline=True) -def _power(x1, x2): - x1, x2 = _promote_args("power", x1, x2) - dtype = _dtype(x1) - if not issubdtype(dtype, integer): - return lax.pow(x1, x2) - - # Integer power => use binary exponentiation. - - # TODO(phawkins): add integer pow support to XLA. - bits = 6 # Anything more would overflow for any x1 > 1 - zero = _constant_like(x2, 0) - one = _constant_like(x2, 1) - # Initialize acc carefully such that pow(0, x2) is zero for x2 != 0 - acc = where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one) - for _ in range(bits): - acc = where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc) - x1 = lax.mul(x1, x1) - x2 = lax.shift_right_logical(x2, one) - return acc - -@_wraps(np.power) -def power(x1, x2): - # Special case for concrete integer scalars: use binary exponentiation. - # Using lax.pow may be imprecise for floating-point values; the goal of this - # code path is to make sure we end up with a precise output for the common - # pattern ``x ** 2`` or similar. - if isinstance(core.get_aval(x2), ConcreteArray): - try: - x2 = operator.index(x2) - except TypeError: - pass - else: - return lax.integer_pow(x1, x2) - return _power(x1, x2) - -@custom_jvp -@_wraps(np.logaddexp) -@jit -def logaddexp(x1, x2): - x1, x2 = _promote_args_inexact("logaddexp", x1, x2) - amax = lax.max(x1, x2) - if issubdtype(x1.dtype, np.floating): - delta = lax.sub(x1, x2) - return lax.select(isnan(delta), - lax.add(x1, x2), # NaNs or infinities of the same sign. - lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))) - else: - delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2))) - out = lax.add(amax, lax.log1p(lax.exp(delta))) - return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi)) - -def _wrap_between(x, _a): - """Wraps `x` between `[-a, a]`.""" - a = _constant_like(x, _a) - two_a = _constant_like(x, 2 * _a) - zero = _constant_like(x, 0) - rem = lax.rem(lax.add(x, a), two_a) - rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem) - return lax.sub(rem, a) - -@logaddexp.defjvp -def _logaddexp_jvp(primals, tangents): - x1, x2 = primals - t1, t2 = tangents - x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2) - primal_out = logaddexp(x1, x2) - tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), - lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) - return primal_out, tangent_out - -def _replace_inf(x): - return lax.select(isposinf(real(x)), zeros_like(x), x) - - -@custom_jvp -@_wraps(np.logaddexp2) -@jit -def logaddexp2(x1, x2): - x1, x2 = _promote_args_inexact("logaddexp2", x1, x2) - amax = lax.max(x1, x2) - if issubdtype(x1.dtype, np.floating): - delta = lax.sub(x1, x2) - return lax.select(isnan(delta), - lax.add(x1, x2), # NaNs or infinities of the same sign. - lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))), - _constant_like(x1, np.log(2))))) - else: - delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2))) - out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2)))) - return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2))) - -@logaddexp2.defjvp -def _logaddexp2_jvp(primals, tangents): - x1, x2 = primals - t1, t2 = tangents - x1, x2, t1, t2 = _promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2) - primal_out = logaddexp2(x1, x2) - tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), - lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) - return primal_out, tangent_out - - -@_wraps(np.log2) -@partial(jit, inline=True) -def log2(x): - x, = _promote_args_inexact("log2", x) - return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) - - -@_wraps(np.log10) -@partial(jit, inline=True) -def log10(x): - x, = _promote_args_inexact("log10", x) - return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) - - -@_wraps(np.exp2) -@partial(jit, inline=True) -def exp2(x): - x, = _promote_args_inexact("exp2", x) - return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x)) - -@_wraps(np.signbit) -@jit -def signbit(x): - x, = _promote_args("signbit", x) - dtype = _dtype(x) - if issubdtype(dtype, integer): - return lax.lt(x, _constant_like(x, 0)) - elif issubdtype(dtype, bool_): - return full_like(x, False, dtype=bool_) - elif not issubdtype(dtype, floating): - raise ValueError( - "jax.numpy.signbit is not well defined for %s" % dtype) - - # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to - # F32. - if dtype == bfloat16: - dtype = float32 - x = lax.convert_element_type(x, float32) - - info = finfo(dtype) - if info.bits not in _INT_DTYPES: - raise NotImplementedError( - "jax.numpy.signbit only supports 16, 32, and 64-bit types.") - int_type = _INT_DTYPES[info.bits] - x = lax.bitcast_convert_type(x, int_type) - return lax.convert_element_type(x >> (info.nexp + info.nmant), np.bool_) - - @_wraps(np.trapz) @partial(jit, static_argnames=('axis',)) def trapz(y, x=None, dx=1.0, axis: int = -1): @@ -1131,129 +369,6 @@ def correlate(a, v, mode='valid', *, precision=None): return _conv(a, v, mode, 'correlate', precision) -def _normalize_float(x): - info = finfo(_dtype(x)) - cond = lax.abs(x) < info.tiny - x1 = where(cond, x * lax._const(x, 1 << info.nmant), x) - x2 = where(cond, lax._const(np.int32, -info.nmant), lax._const(np.int32, 0)) - int_type = _INT_DTYPES[info.bits] - return lax.bitcast_convert_type(x1, int_type), x2 - - -@_wraps(np.ldexp) -@jit -def ldexp(x1, x2): - _check_arraylike("ldexp", x1, x2) - dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2)) - x1, x2 = _promote_shapes("ldexp", x1, x2) - x1 = lax.convert_element_type(x1, dtype) - - info = finfo(dtype) - mask = (1 << info.nexp) - 1 - bias = ((1 << info.nexp) - 1) >> 1 - - int_type = _INT_DTYPES[info.bits] - - x, e = _normalize_float(x1) - x2 += e + ((x >> info.nmant) & mask) - bias - - # find underflow/overflow before denormalization - underflow_cond = x2 < -(bias + info.nmant) - overflow_cond = x2 > bias - - m = ones_like(x, dtype=dtype) - - # denormals - cond = x2 < -bias + 1 - x2 = where(cond, x2 + info.nmant, x2) - m = where(cond, m / (1 << info.nmant), m) - - x2 = lax.convert_element_type(x2, np.int32) - x &= ~(mask << info.nmant) - x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant) - - x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype) - - # underflow - x = where(underflow_cond, zeros_like(x, dtype=dtype), x) - # overflow - x = where(overflow_cond, lax.sign(x1) * full_like(x, np.inf), x) - # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0 - return where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x) - - -@_wraps(np.frexp) -@jit -def frexp(x): - _check_arraylike("frexp", x) - x = asarray(x) - if issubdtype(x.dtype, complexfloating): - raise TypeError("frexp does not support complex-valued inputs") - elif not issubdtype(x.dtype, floating): - x = lax.convert_element_type(x, float_) - - dtype = _dtype(x) - info = finfo(dtype) - mask = (1 << info.nexp) - 1 - bias = ((1 << info.nexp) - 1) >> 1 - - x1, x2 = _normalize_float(x) - x2 += ((x1 >> info.nmant) & mask) - bias + 1 - x1 &= ~(mask << info.nmant) - x1 |= (bias - 1) << info.nmant - x1 = lax.bitcast_convert_type(x1, dtype) - - cond = isinf(x) | isnan(x) | (x == 0) - x2 = where(cond, zeros_like(x2), x2) - return where(cond, x, x1), lax.convert_element_type(x2, int32) - - -@_wraps(np.remainder) -@jit -def remainder(x1, x2): - x1, x2 = _promote_args("remainder", x1, x2) - zero = _constant_like(x1, 0) - trunc_mod = lax.rem(x1, x2) - trunc_mod_not_zero = lax.ne(trunc_mod, zero) - do_plus = lax.bitwise_and( - lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero) - return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) -mod = _wraps(np.mod)(remainder) - - -@_wraps(np.fmod) -@jit -def fmod(x1, x2): - _check_arraylike("fmod", x1, x2) - if issubdtype(result_type(x1, x2), integer): - x2 = where(x2 == 0, 1, x2) - return lax.rem(*_promote_args("fmod", x1, x2)) - - -@_wraps(np.square) -@partial(jit, inline=True) -def square(x): - _check_arraylike("square", x) - return lax.integer_pow(x, 2) - - -@_wraps(np.deg2rad) -@partial(jit, inline=True) -def deg2rad(x): - x, = _promote_args_inexact("deg2rad", x) - return lax.mul(x, lax._const(x, pi / 180)) - - -@_wraps(np.rad2deg) -@partial(jit, inline=True) -def rad2deg(x): - x, = _promote_args_inexact("rad2deg", x) - return lax.mul(x, lax._const(x, 180 / pi)) - - -degrees = rad2deg -radians = deg2rad - @_wraps(np.histogram_bin_edges) def histogram_bin_edges(a, bins=10, range=None, weights=None): @@ -1472,28 +587,6 @@ def flipud(m): return _flip(m, 0) -@_wraps(np.conjugate) -@partial(jit, inline=True) -def conjugate(x): - _check_arraylike("conjugate", x) - return lax.conj(x) if iscomplexobj(x) else x -conj = conjugate - - -@_wraps(np.imag) -@partial(jit, inline=True) -def imag(val): - _check_arraylike("imag", val) - return lax.imag(val) if iscomplexobj(val) else zeros_like(val) - - -@_wraps(np.real) -@partial(jit, inline=True) -def real(val): - _check_arraylike("real", val) - return lax.real(val) if iscomplexobj(val) else val - - @_wraps(np.iscomplex) @jit def iscomplex(x): @@ -1975,7 +1068,7 @@ def interp(x, xp, fp, left=None, right=None, period=None): In the JAX version, the `assume_unique` argument is not referenced. """) @partial(jit, static_argnames=('assume_unique', 'invert',)) -def in1d(ar1, ar2, assume_unique=False, invert=False): +def in1d(ar1, ar2, assume_unique=False, invert=False): # noqa: F811 _check_arraylike("in1d", ar1, ar2) ar1 = ravel(ar1) ar2 = ravel(ar2) @@ -2136,29 +1229,11 @@ def intersect1d(ar1, ar2, assume_unique=False, return_indices=False): @_wraps(np.isin, lax_description=""" In the JAX version, the `assume_unique` argument is not referenced. """) -def isin(element, test_elements, assume_unique=False, invert=False): +def isin(element, test_elements, assume_unique=False, invert=False): # noqa: F811 result = in1d(element, test_elements, assume_unique=assume_unique, invert=invert) return result.reshape(shape(element)) -# The `jit` on `where` exists to avoid materializing constants in cases like -# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to -# materialize the broadcast forms of scalar arguments. -@jit -def _where(condition, x=None, y=None): - if x is None or y is None: - raise ValueError("Either both or neither of the x and y arguments should " - "be provided to jax.numpy.where, got {} and {}." - .format(x, y)) - if not issubdtype(_dtype(condition), bool_): - condition = lax.ne(condition, zeros_like(condition)) - x, y = _promote_dtypes(x, y) - condition, x, y = broadcast_arrays(condition, x, y) - try: is_always_empty = core.is_empty_shape(np.shape(x)) - except: is_always_empty = False # can fail with dynamic shapes - return lax.select(condition, x, y) if not is_always_empty else x - - @_wraps(np.where, lax_description=_dedent(""" At present, JAX does not support JIT-compilation of the single-argument form @@ -2232,6 +1307,7 @@ def bincount(x, weights=None, minlength=0, *, length=None): raise ValueError("shape of weights must match shape of x.") return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights) + @_wraps(getattr(np, "broadcast_shapes", None)) def broadcast_shapes(*shapes): if not shapes: @@ -2239,44 +1315,15 @@ def broadcast_shapes(*shapes): shapes = [(shape,) if np.ndim(shape) == 0 else tuple(shape) for shape in shapes] return lax.broadcast_shapes(*shapes) -@partial(jit, inline=True) -def broadcast_arrays(*args): - """Like Numpy's broadcast_arrays but doesn't return views.""" - shapes = [shape(arg) for arg in args] - if not shapes or _all(core.symbolic_equal_shape(shapes[0], s) for s in shapes): - # TODO(mattjj): remove the array(arg) here - return [arg if isinstance(arg, ndarray) or isscalar(arg) else array(arg) - for arg in args] - result_shape = lax.broadcast_shapes(*shapes) - return [broadcast_to(arg, result_shape) for arg in args] - - -@_wraps(np.broadcast_to, lax_description="""\ + +broadcast_arrays = _wraps(np.broadcast_arrays, lax_description="""\ The JAX version does not necessarily return a view of the input. -""") -def broadcast_to(arr, shape): - if hasattr(arr, "broadcast_to"): - return arr.broadcast_to(shape) - arr = arr if isinstance(arr, ndarray) else array(arr) - if not isinstance(shape, tuple) and ndim(shape) == 0: - shape = (shape,) - shape = canonicalize_shape(shape) # check that shape is concrete - arr_shape = _shape(arr) - if core.symbolic_equal_shape(arr_shape, shape): - return arr - else: - nlead = len(shape) - len(arr_shape) - shape_tail = shape[nlead:] - compatible = _all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d]) - for arr_d, shape_d in safe_zip(arr_shape, shape_tail)) - if nlead < 0 or not compatible: - msg = "Incompatible shapes for broadcasting: {} and requested shape {}" - raise ValueError(msg.format(arr_shape, shape)) - diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d) - for arr_d, shape_d in safe_zip(arr_shape, shape_tail))) - new_dims = tuple(range(nlead)) + tuple(nlead + diff) - kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims)) - return lax.broadcast_in_dim(squeeze(arr, tuple(diff)), shape, kept_dims) +""")(_broadcast_arrays) + + +broadcast_to = _wraps(np.broadcast_to, lax_description="""\ +The JAX version does not necessarily return a view of the input. +""")(_broadcast_to) def _split(op, ary, indices_or_sections, axis=0): @@ -2394,68 +1441,6 @@ def fix(x, out=None): return where(lax.ge(x, zero), floor(x), ceil(x)) -@_wraps(np.modf, skip_params=['out']) -@jit -def modf(x, out=None): - _check_arraylike("modf", x) - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.modf is not supported.") - whole = fix(x) - return x - whole, whole - - -@_wraps(np.isfinite) -@jit -def isfinite(x): - _check_arraylike("isfinite", x) - dtype = _dtype(x) - if issubdtype(dtype, floating): - return lax.is_finite(x) - elif issubdtype(dtype, complexfloating): - return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x))) - else: - return full_like(x, True, dtype=bool_) - -@_wraps(np.isinf) -@jit -def isinf(x): - _check_arraylike("isinf", x) - dtype = _dtype(x) - if issubdtype(dtype, floating): - return lax.eq(lax.abs(x), _constant_like(x, inf)) - elif issubdtype(dtype, complexfloating): - re = lax.real(x) - im = lax.imag(x) - return lax.bitwise_or(lax.eq(lax.abs(re), _constant_like(re, inf)), - lax.eq(lax.abs(im), _constant_like(im, inf))) - else: - return full_like(x, False, dtype=bool_) - -def _isposneginf(infinity, x, out): - if out is not None: - raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.") - dtype = _dtype(x) - if issubdtype(dtype, floating): - return lax.eq(x, _constant_like(x, infinity)) - elif issubdtype(dtype, complexfloating): - raise ValueError("isposinf/isneginf are not well defined for complex types") - else: - return full_like(x, False, dtype=bool_) - -isposinf = _wraps(np.isposinf, skip_params=['out'])( - lambda x, out=None: _isposneginf(inf, x, out) -) - -isneginf = _wraps(np.isneginf, skip_params=['out'])( - lambda x, out=None: _isposneginf(-inf, x, out) -) - -@_wraps(np.isnan) -@jit -def isnan(x): - _check_arraylike("isnan", x) - return lax.ne(x, x) - @_wraps(np.nan_to_num) @jit def nan_to_num(x, copy=True, nan=0.0, posinf=None, neginf=None): diff --git a/jax/_src/numpy/ndarray.py b/jax/_src/numpy/ndarray.py new file mode 100644 index 000000000000..2da6673de1f4 --- /dev/null +++ b/jax/_src/numpy/ndarray.py @@ -0,0 +1,295 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ndarray is defined as an virtual abstract base class. + +import abc +from typing import Any, Optional, Tuple, Union + +from jax import core +from jax.interpreters import pxla +from jax._src import device_array + +import numpy as np + + +class ArrayMeta(abc.ABCMeta): + """Metaclass for overriding ndarray isinstance checks.""" + + def __instancecheck__(self, instance): + # Allow tracer instances with avals that are instances of UnshapedArray. + # We could instead just declare Tracer an instance of the ndarray type, but + # there can be traced values that are not arrays. The main downside here is + # that isinstance(x, ndarray) might return true but + # issubclass(type(x), ndarray) might return false for an array tracer. + try: + return (hasattr(instance, "aval") and + isinstance(instance.aval, core.UnshapedArray)) + except AttributeError: + super().__instancecheck__(instance) + + +class ndarray(metaclass=ArrayMeta): + dtype: np.dtype + ndim: int + shape: Tuple[int, ...] + size: int + + def __init__(self, shape, dtype=None, buffer=None, offset=0, strides=None, + order=None): + raise TypeError("jax.numpy.ndarray() should not be instantiated explicitly." + " Use jax.numpy.array, or jax.numpy.zeros instead.") + + @abc.abstractmethod + def __getitem__(self, key, indices_are_sorted=False, + unique_indices=False) -> Any: ... + @abc.abstractmethod + def __setitem__(self, key, value) -> Any: ... + @abc.abstractmethod + def __len__(self) -> Any: ... + @abc.abstractmethod + def __iter__(self) -> Any: ... + @abc.abstractmethod + def __reversed__(self) -> Any: ... + + # Comparisons + @abc.abstractmethod + def __lt__(self, other) -> Any: ... + @abc.abstractmethod + def __le__(self, other) -> Any: ... + @abc.abstractmethod + def __eq__(self, other) -> Any: ... + @abc.abstractmethod + def __ne__(self, other) -> Any: ... + @abc.abstractmethod + def __gt__(self, other) -> Any: ... + @abc.abstractmethod + def __ge__(self, other) -> Any: ... + + # Unary arithmetic + + @abc.abstractmethod + def __neg__(self) -> Any: ... + @abc.abstractmethod + def __pos__(self) -> Any: ... + @abc.abstractmethod + def __abs__(self) -> Any: ... + @abc.abstractmethod + def __invert__(self) -> Any: ... + + # Binary arithmetic + + @abc.abstractmethod + def __add__(self, other) -> Any: ... + @abc.abstractmethod + def __sub__(self, other) -> Any: ... + @abc.abstractmethod + def __mul__(self, other) -> Any: ... + @abc.abstractmethod + def __matmul__(self, other) -> Any: ... + @abc.abstractmethod + def __truediv__(self, other) -> Any: ... + @abc.abstractmethod + def __floordiv__(self, other) -> Any: ... + @abc.abstractmethod + def __mod__(self, other) -> Any: ... + @abc.abstractmethod + def __divmod__(self, other) -> Any: ... + @abc.abstractmethod + def __pow__(self, other) -> Any: ... + @abc.abstractmethod + def __lshift__(self, other) -> Any: ... + @abc.abstractmethod + def __rshift__(self, other) -> Any: ... + @abc.abstractmethod + def __and__(self, other) -> Any: ... + @abc.abstractmethod + def __xor__(self, other) -> Any: ... + @abc.abstractmethod + def __or__(self, other) -> Any: ... + + @abc.abstractmethod + def __radd__(self, other) -> Any: ... + @abc.abstractmethod + def __rsub__(self, other) -> Any: ... + @abc.abstractmethod + def __rmul__(self, other) -> Any: ... + @abc.abstractmethod + def __rmatmul__(self, other) -> Any: ... + @abc.abstractmethod + def __rtruediv__(self, other) -> Any: ... + @abc.abstractmethod + def __rfloordiv__(self, other) -> Any: ... + @abc.abstractmethod + def __rmod__(self, other) -> Any: ... + @abc.abstractmethod + def __rdivmod__(self, other) -> Any: ... + @abc.abstractmethod + def __rpow__(self, other) -> Any: ... + @abc.abstractmethod + def __rlshift__(self, other) -> Any: ... + @abc.abstractmethod + def __rrshift__(self, other) -> Any: ... + @abc.abstractmethod + def __rand__(self, other) -> Any: ... + @abc.abstractmethod + def __rxor__(self, other) -> Any: ... + @abc.abstractmethod + def __ror__(self, other) -> Any: ... + + @abc.abstractmethod + def __bool__(self) -> Any: ... + @abc.abstractmethod + def __complex__(self) -> Any: ... + @abc.abstractmethod + def __int__(self) -> Any: ... + @abc.abstractmethod + def __float__(self) -> Any: ... + @abc.abstractmethod + def __round__(self, ndigits=None) -> Any: ... + + @abc.abstractmethod + def __index__(self) -> Any: ... + + # np.ndarray methods: + @abc.abstractmethod + def all(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None) -> Any: ... + @abc.abstractmethod + def any(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None) -> Any: ... + @abc.abstractmethod + def argmax(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ... + @abc.abstractmethod + def argmin(self, axis: Optional[int] = None, out=None, keepdims=None) -> Any: ... + @abc.abstractmethod + def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Any: ... + @abc.abstractmethod + def argsort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ... + @abc.abstractmethod + def astype(self, dtype) -> Any: ... + @abc.abstractmethod + def choose(self, choices, out=None, mode='raise') -> Any: ... + @abc.abstractmethod + def clip(self, a_min=None, a_max=None, out=None) -> Any: ... + @abc.abstractmethod + def compress(self, condition, axis: Optional[int] = None, out=None) -> Any: ... + @abc.abstractmethod + def conj(self) -> Any: ... + @abc.abstractmethod + def conjugate(self) -> Any: ... + @abc.abstractmethod + def copy(self) -> Any: ... + @abc.abstractmethod + def cumprod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype=None, out=None) -> Any: ... + @abc.abstractmethod + def cumsum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype=None, out=None) -> Any: ... + @abc.abstractmethod + def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Any: ... + @abc.abstractmethod + def dot(self, b, *, precision=None) -> Any: ... + @abc.abstractmethod + def flatten(self) -> Any: ... + @property + @abc.abstractmethod + def imag(self) -> Any: ... + @abc.abstractmethod + def item(self, *args) -> Any: ... + @abc.abstractmethod + def max(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None, initial=None, where=None) -> Any: ... + @abc.abstractmethod + def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, keepdims=False, *, where=None,) -> Any: ... + @abc.abstractmethod + def min(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=None, initial=None, where=None) -> Any: ... + @property + @abc.abstractmethod + def nbytes(self) -> Any: ... + @abc.abstractmethod + def nonzero(self, *, size=None, fill_value=None) -> Any: ... + @abc.abstractmethod + def prod(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, keepdims=None, initial=None, where=None) -> Any: ... + @abc.abstractmethod + def ptp(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, + keepdims=False,) -> Any: ... + @abc.abstractmethod + def ravel(self, order='C') -> Any: ... + @property + @abc.abstractmethod + def real(self) -> Any: ... + @abc.abstractmethod + def repeat(self, repeats, axis: Optional[int] = None, *, + total_repeat_length=None) -> Any: ... + @abc.abstractmethod + def reshape(self, *args, order='C') -> Any: ... + @abc.abstractmethod + def round(self, decimals=0, out=None) -> Any: ... + @abc.abstractmethod + def searchsorted(self, v, side='left', sorter=None) -> Any: ... + @abc.abstractmethod + def sort(self, axis: Optional[int] = -1, kind='quicksort', order=None) -> Any: ... + @abc.abstractmethod + def squeeze(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Any: ... + @abc.abstractmethod + def std(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ... + @abc.abstractmethod + def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, + out=None, keepdims=None, initial=None, where=None) -> Any: ... + @abc.abstractmethod + def swapaxes(self, axis1: int, axis2: int) -> Any: ... + @abc.abstractmethod + def take(self, indices, axis: Optional[int] = None, out=None, + mode=None) -> Any: ... + @abc.abstractmethod + def tobytes(self, order='C') -> Any: ... + @abc.abstractmethod + def tolist(self) -> Any: ... + @abc.abstractmethod + def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, + out=None) -> Any: ... + @abc.abstractmethod + def transpose(self, *args) -> Any: ... + @abc.abstractmethod + def var(self, axis: Optional[Union[int, Tuple[int, ...]]] = None, + dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Any: ... + @abc.abstractmethod + def view(self, dtype=None, type=None) -> Any: ... + + # Even though we don't always support the NumPy array protocol, e.g., for + # tracer types, for type checking purposes we must declare support so we + # implement the NumPy ArrayLike protocol. + def __array__(self) -> Any: ... + + # JAX extensions + @property + @abc.abstractmethod + def at(self) -> Any: ... + @property + @abc.abstractmethod + def aval(self) -> Any: ... + @property + @abc.abstractmethod + def weak_type(self) -> bool: ... + + +ndarray.register(device_array.DeviceArray) +for t in device_array.device_array_types: + ndarray.register(t) +ndarray.register(pxla._SDA_BASE_CLASS) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py new file mode 100644 index 000000000000..11d21c44491c --- /dev/null +++ b/jax/_src/numpy/ufuncs.py @@ -0,0 +1,654 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pytype: skip-file +""" +Implements ufuncs for jax.numpy. +""" + +from functools import partial +import operator +from textwrap import dedent + +import numpy as np + +from jax._src.api import jit, custom_jvp +from jax._src import dtypes +from jax._src.lax import lax +from jax._src.numpy.util import ( + _check_arraylike, _promote_args, _promote_args_inexact, + _promote_shapes, _where, _wraps) +from jax import core + + +_INT_DTYPES = { + 16: np.int16, + 32: np.int32, + 64: np.int64, +} + + +def _constant_like(x, const): + return np.array(const, dtype=dtypes.dtype(x)) + + +def _result_dtype(op, *args): + """Compute result dtype of applying op to arguments with given dtypes.""" + args = [np.ones((0,) * np.ndim(arg), dtypes.dtype(arg)) for arg in args] + return dtypes.dtype(op(*args)) + + +def _replace_inf(x): + return _where(isposinf(real(x)), lax._zeros(x), x) + + +def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False): + if promote_to_inexact: + fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x)) + else: + fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x)) + fn = jit(fn, inline=True) + if lax_doc: + doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() + return _wraps(numpy_fn, lax_description=doc)(fn) + else: + return _wraps(numpy_fn)(fn) + + +def _one_to_one_binop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False): + if promote_to_inexact: + fn = lambda x1, x2: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x1, x2)) + else: + fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2)) + fn = jit(fn, inline=True) + if lax_doc: + doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() + return _wraps(numpy_fn, lax_description=doc)(fn) + else: + return _wraps(numpy_fn)(fn) + + +def _maybe_bool_binop(numpy_fn, lax_fn, bool_lax_fn, lax_doc=False): + def fn(x1, x2): + x1, x2 = _promote_args(numpy_fn.__name__, x1, x2) + return lax_fn(x1, x2) if x1.dtype != np.bool_ else bool_lax_fn(x1, x2) + fn = jit(fn, inline=True) + if lax_doc: + doc = dedent('\n\n'.join(lax_fn.__doc__.split('\n\n')[1:])).strip() + return _wraps(numpy_fn, lax_description=doc)(fn) + else: + return _wraps(numpy_fn)(fn) + + +def _comparison_op(numpy_fn, lax_fn): + # TODO(https://github.com/google/jax/issues/6713): decorate this function with + # jit, after fixing a surprising interaction with remat(..., concrete=True). + def fn(x1, x2): + x1, x2 = _promote_args(numpy_fn.__name__, x1, x2) + # Comparison on complex types are defined as a lexicographic ordering on + # the (real, imag) pair. + if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating): + rx = lax.real(x1) + ry = lax.real(x2) + return _where(lax.eq(rx, ry), lax_fn(lax.imag(x1), lax.imag(x2)), + lax_fn(rx, ry)) + return lax_fn(x1, x2) + return _wraps(numpy_fn)(fn) + + +def _logical_op(np_op, bitwise_op): + @_wraps(np_op, update_doc=False) + @partial(jit, inline=True) + def op(*args): + zero = lambda x: lax.full_like(x, shape=(), fill_value=0) + args = (x if dtypes.issubdtype(dtypes.dtype(x), np.bool_) else lax.ne(x, zero(x)) + for x in args) + return bitwise_op(*_promote_args(np_op.__name__, *args)) + return op + + +fabs = _one_to_one_unop(np.fabs, lax.abs, True) +bitwise_not = _one_to_one_unop(np.bitwise_not, lax.bitwise_not) +invert = _one_to_one_unop(np.invert, lax.bitwise_not) +negative = _one_to_one_unop(np.negative, lax.neg) +positive = _one_to_one_unop(np.positive, lambda x: x) +floor = _one_to_one_unop(np.floor, lax.floor, True) +ceil = _one_to_one_unop(np.ceil, lax.ceil, True) +exp = _one_to_one_unop(np.exp, lax.exp, True) +log = _one_to_one_unop(np.log, lax.log, True) +expm1 = _one_to_one_unop(np.expm1, lax.expm1, True) +log1p = _one_to_one_unop(np.log1p, lax.log1p, True) +sin = _one_to_one_unop(np.sin, lax.sin, True) +cos = _one_to_one_unop(np.cos, lax.cos, True) +tan = _one_to_one_unop(np.tan, lax.tan, True) +arcsin = _one_to_one_unop(np.arcsin, lax.asin, True) +arccos = _one_to_one_unop(np.arccos, lax.acos, True) +arctan = _one_to_one_unop(np.arctan, lax.atan, True) +sinh = _one_to_one_unop(np.sinh, lax.sinh, True) +cosh = _one_to_one_unop(np.cosh, lax.cosh, True) +arcsinh = _one_to_one_unop(np.arcsinh, lax.asinh, True) +tanh = _one_to_one_unop(np.tanh, lax.tanh, True) +arctanh = _one_to_one_unop(np.arctanh, lax.atanh, True) +sqrt = _one_to_one_unop(np.sqrt, lax.sqrt, True) +cbrt = _one_to_one_unop(np.cbrt, lax.cbrt, True) + +add = _maybe_bool_binop(np.add, lax.add, lax.bitwise_or) +bitwise_and = _one_to_one_binop(np.bitwise_and, lax.bitwise_and) +bitwise_or = _one_to_one_binop(np.bitwise_or, lax.bitwise_or) +bitwise_xor = _one_to_one_binop(np.bitwise_xor, lax.bitwise_xor) +left_shift = _one_to_one_binop(np.left_shift, lax.shift_left) +equal = _one_to_one_binop(np.equal, lax.eq) +multiply = _maybe_bool_binop(np.multiply, lax.mul, lax.bitwise_and) +not_equal = _one_to_one_binop(np.not_equal, lax.ne) +subtract = _one_to_one_binop(np.subtract, lax.sub) +arctan2 = _one_to_one_binop(np.arctan2, lax.atan2, True) +minimum = _one_to_one_binop(np.minimum, lax.min) +maximum = _one_to_one_binop(np.maximum, lax.max) +float_power = _one_to_one_binop(np.float_power, lax.pow, True) +nextafter = _one_to_one_binop(np.nextafter, lax.nextafter, True, True) + +greater_equal = _comparison_op(np.greater_equal, lax.ge) +greater = _comparison_op(np.greater, lax.gt) +less_equal = _comparison_op(np.less_equal, lax.le) +less = _comparison_op(np.less, lax.lt) + +logical_and = _logical_op(np.logical_and, lax.bitwise_and) +logical_not = _logical_op(np.logical_not, lax.bitwise_not) +logical_or = _logical_op(np.logical_or, lax.bitwise_or) +logical_xor = _logical_op(np.logical_xor, lax.bitwise_xor) + + +@_wraps(np.arccosh) +@jit +def arccosh(x): + # Note: arccosh is multi-valued for complex input, and lax.acosh uses a different + # convention than np.arccosh. + out = lax.acosh(*_promote_args_inexact("arccosh", x)) + if dtypes.issubdtype(out.dtype, np.complexfloating): + out = _where(real(out) < 0, lax.neg(out), out) + return out + + +@_wraps(np.right_shift) +@partial(jit, inline=True) +def right_shift(x1, x2): + x1, x2 = _promote_args(np.right_shift.__name__, x1, x2) + lax_fn = lax.shift_right_logical if \ + np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic + return lax_fn(x1, x2) + + +@_wraps(np.absolute) +@partial(jit, inline=True) +def absolute(x): + _check_arraylike('absolute', x) + dt = dtypes.dtype(x) + return x if dt == np.bool_ or dtypes.issubdtype(dt, np.unsignedinteger) else lax.abs(x) +abs = _wraps(np.abs)(absolute) + + +@_wraps(np.rint) +@jit +def rint(x): + _check_arraylike('rint', x) + dtype = dtypes.dtype(x) + if dtypes.issubdtype(dtype, np.integer): + return lax.convert_element_type(x, dtypes.float_) + if dtypes.issubdtype(dtype, np.complexfloating): + return lax.complex(rint(lax.real(x)), rint(lax.imag(x))) + return lax.round(x, lax.RoundingMethod.TO_NEAREST_EVEN) + + +@_wraps(np.sign) +@jit +def sign(x): + _check_arraylike('sign', x) + dtype = dtypes.dtype(x) + if dtypes.issubdtype(dtype, np.complexfloating): + re = lax.real(x) + return lax.complex( + lax.sign(_where(re != 0, re, lax.imag(x))), _constant_like(re, 0)) + return lax.sign(x) + + +@_wraps(np.copysign) +@jit +def copysign(x1, x2): + x1, x2 = _promote_args_inexact("copysign", x1, x2) + if dtypes.issubdtype(dtypes.dtype(x1), np.complexfloating): + raise TypeError("copysign does not support complex-valued inputs") + return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1)) + + +@_wraps(np.true_divide) +@partial(jit, inline=True) +def true_divide(x1, x2): + x1, x2 = _promote_args_inexact("true_divide", x1, x2) + return lax.div(x1, x2) + +divide = true_divide + + +@_wraps(np.floor_divide) +@jit +def floor_divide(x1, x2): + x1, x2 = _promote_args("floor_divide", x1, x2) + dtype = dtypes.dtype(x1) + if dtypes.issubdtype(dtype, np.integer): + quotient = lax.div(x1, x2) + select = logical_and(lax.sign(x1) != lax.sign(x2), lax.rem(x1, x2) != 0) + # TODO(mattjj): investigate why subtracting a scalar was causing promotion + return _where(select, quotient - 1, quotient) + elif dtypes.issubdtype(dtype, np.complexfloating): + x1r = lax.real(x1) + x1i = lax.imag(x1) + x2r = lax.real(x2) + x2i = lax.imag(x2) + which = lax.ge(lax.abs(x2r), lax.abs(x2i)) + rat1 = _where(which, lax.full_like(x2i, 1), lax.div(x2r, x2i)) + rat2 = _where(which, lax.div(x2i, x2r), lax._const(x2i, 1)) + out = lax.floor(lax.div(lax.add(lax.mul(x1r, rat1), lax.mul(x1i, rat2)), + lax.add(lax.mul(x2r, rat1), lax.mul(x2i, rat2)))) + return lax.convert_element_type(out, dtype) + else: + return _float_divmod(x1, x2)[0] + + +@_wraps(np.divmod) +@jit +def divmod(x1, x2): + x1, x2 = _promote_args("divmod", x1, x2) + if dtypes.issubdtype(dtypes.dtype(x1), np.integer): + return floor_divide(x1, x2), remainder(x1, x2) + else: + return _float_divmod(x1, x2) + + +def _float_divmod(x1, x2): + # see float_divmod in floatobject.c of CPython + mod = lax.rem(x1, x2) + div = lax.div(lax.sub(x1, mod), x2) + + ind = lax.bitwise_and(mod != 0, lax.sign(x2) != lax.sign(mod)) + mod = _where(ind, mod + x2, mod) + div = _where(ind, div - _constant_like(div, 1), div) + + return lax.round(div), mod + + +@partial(jit, inline=True) +def _power(x1, x2): + x1, x2 = _promote_args("power", x1, x2) + dtype = dtypes.dtype(x1) + if not dtypes.issubdtype(dtype, np.integer): + return lax.pow(x1, x2) + + # Integer power => use binary exponentiation. + + # TODO(phawkins): add integer pow support to XLA. + bits = 6 # Anything more would overflow for any x1 > 1 + zero = _constant_like(x2, 0) + one = _constant_like(x2, 1) + # Initialize acc carefully such that pow(0, x2) is zero for x2 != 0 + acc = _where(lax.bitwise_and(lax.eq(x1, zero), lax.ne(x2, zero)), zero, one) + for _ in range(bits): + acc = _where(lax.bitwise_and(x2, one), lax.mul(acc, x1), acc) + x1 = lax.mul(x1, x1) + x2 = lax.shift_right_logical(x2, one) + return acc + + +@_wraps(np.power) +def power(x1, x2): + # Special case for concrete integer scalars: use binary exponentiation. + # Using lax.pow may be imprecise for floating-point values; the goal of this + # code path is to make sure we end up with a precise output for the common + # pattern ``x ** 2`` or similar. + if isinstance(core.get_aval(x2), core.ConcreteArray): + try: + x2 = operator.index(x2) + except TypeError: + pass + else: + return lax.integer_pow(x1, x2) + return _power(x1, x2) + + +@custom_jvp +@_wraps(np.logaddexp) +@jit +def logaddexp(x1, x2): + x1, x2 = _promote_args_inexact("logaddexp", x1, x2) + amax = lax.max(x1, x2) + if dtypes.issubdtype(x1.dtype, np.floating): + delta = lax.sub(x1, x2) + return _where(lax._isnan(delta), + lax.add(x1, x2), # NaNs or infinities of the same sign. + lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))) + else: + delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2))) + out = lax.add(amax, lax.log1p(lax.exp(delta))) + return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi)) + + +def _wrap_between(x, _a): + """Wraps `x` between `[-a, a]`.""" + a = _constant_like(x, _a) + two_a = _constant_like(x, 2 * _a) + zero = _constant_like(x, 0) + rem = lax.rem(lax.add(x, a), two_a) + rem = _where(lax.lt(rem, zero), lax.add(rem, two_a), rem) + return lax.sub(rem, a) + + +@logaddexp.defjvp +def _logaddexp_jvp(primals, tangents): + x1, x2 = primals + t1, t2 = tangents + x1, x2, t1, t2 = _promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2) + primal_out = logaddexp(x1, x2) + tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), + lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) + return primal_out, tangent_out + + +@custom_jvp +@_wraps(np.logaddexp2) +@jit +def logaddexp2(x1, x2): + x1, x2 = _promote_args_inexact("logaddexp2", x1, x2) + amax = lax.max(x1, x2) + if dtypes.issubdtype(x1.dtype, np.floating): + delta = lax.sub(x1, x2) + return _where(lax._isnan(delta), + lax.add(x1, x2), # NaNs or infinities of the same sign. + lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))), + _constant_like(x1, np.log(2))))) + else: + delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2))) + out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2)))) + return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2))) + + +@logaddexp2.defjvp +def _logaddexp2_jvp(primals, tangents): + x1, x2 = primals + t1, t2 = tangents + x1, x2, t1, t2 = _promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2) + primal_out = logaddexp2(x1, x2) + tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), + lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) + return primal_out, tangent_out + + +@_wraps(np.log2) +@partial(jit, inline=True) +def log2(x): + x, = _promote_args_inexact("log2", x) + return lax.div(lax.log(x), lax.log(_constant_like(x, 2))) + + +@_wraps(np.log10) +@partial(jit, inline=True) +def log10(x): + x, = _promote_args_inexact("log10", x) + return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) + + +@_wraps(np.exp2) +@partial(jit, inline=True) +def exp2(x): + x, = _promote_args_inexact("exp2", x) + return lax.exp(lax.mul(lax.log(_constant_like(x, 2)), x)) + + +@_wraps(np.signbit) +@jit +def signbit(x): + x, = _promote_args("signbit", x) + dtype = dtypes.dtype(x) + if dtypes.issubdtype(dtype, np.integer): + return lax.lt(x, _constant_like(x, 0)) + elif dtypes.issubdtype(dtype, np.bool_): + return lax.full_like(x, False, dtype=np.bool_) + elif not dtypes.issubdtype(dtype, np.floating): + raise ValueError( + "jax.numpy.signbit is not well defined for %s" % dtype) + + # TPU supports BF16 but not S16 types, so as a workaround, convert BF16 to + # F32. + if dtype == dtypes.bfloat16: + dtype = np.float32 + x = lax.convert_element_type(x, np.float32) + + info = dtypes.finfo(dtype) + if info.bits not in _INT_DTYPES: + raise NotImplementedError( + "jax.numpy.signbit only supports 16, 32, and 64-bit types.") + int_type = _INT_DTYPES[info.bits] + x = lax.bitcast_convert_type(x, int_type) + return lax.convert_element_type(x >> (info.nexp + info.nmant), np.bool_) + + +@_wraps(np.conjugate) +@partial(jit, inline=True) +def conjugate(x): + _check_arraylike("conjugate", x) + return lax.conj(x) if np.iscomplexobj(x) else x +conj = conjugate + + +@_wraps(np.imag) +@partial(jit, inline=True) +def imag(val): + _check_arraylike("imag", val) + return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0) + + +@_wraps(np.real) +@partial(jit, inline=True) +def real(val): + _check_arraylike("real", val) + return lax.real(val) if np.iscomplexobj(val) else val + + +def _normalize_float(x): + info = dtypes.finfo(dtypes.dtype(x)) + cond = lax.abs(x) < info.tiny + x1 = _where(cond, x * lax._const(x, 1 << info.nmant), x) + x2 = _where(cond, lax.full_like(x, -info.nmant, dtype=np.int32), lax.full_like(x, 0, dtype=np.int32)) + int_type = _INT_DTYPES[info.bits] + return lax.bitcast_convert_type(x1, int_type), x2 + + +@_wraps(np.ldexp) +@jit +def ldexp(x1, x2): + _check_arraylike("ldexp", x1, x2) + dtype = dtypes.canonicalize_dtype(_result_dtype(np.ldexp, x1, x2)) + x1, x2 = _promote_shapes("ldexp", x1, x2) + x1 = lax.convert_element_type(x1, dtype) + + info = dtypes.finfo(dtype) + mask = (1 << info.nexp) - 1 + bias = ((1 << info.nexp) - 1) >> 1 + + int_type = _INT_DTYPES[info.bits] + + x, e = _normalize_float(x1) + x2 += e + ((x >> info.nmant) & mask) - bias + + # find underflow/overflow before denormalization + underflow_cond = x2 < -(bias + info.nmant) + overflow_cond = x2 > bias + + m = lax.full_like(x, 1, dtype=dtype) + + # denormals + cond = x2 < -bias + 1 + x2 = _where(cond, x2 + info.nmant, x2) + m = _where(cond, m / (1 << info.nmant), m) + + x2 = lax.convert_element_type(x2, np.int32) + x &= ~(mask << info.nmant) + x |= ((lax.convert_element_type(x2, int_type) + bias) << info.nmant) + + x = lax.convert_element_type(m, dtype) * lax.bitcast_convert_type(x, dtype) + + # underflow + x = _where(underflow_cond, lax.full_like(x, 0, dtype=dtype), x) + # overflow + x = _where(overflow_cond, lax.sign(x1) * lax.full_like(x, np.inf), x) + # ldexp(x1, x2) = x1 for x1 = inf, -inf, nan, 0 + return _where(isinf(x1) | isnan(x1) | (x1 == 0), x1, x) + + +@_wraps(np.frexp) +@jit +def frexp(x): + _check_arraylike("frexp", x) + if dtypes.issubdtype(x.dtype, np.complexfloating): + raise TypeError("frexp does not support complex-valued inputs") + elif not dtypes.issubdtype(dtypes.dtype(x), np.floating): + x = lax.convert_element_type(x, np.float_) + + dtype = dtypes.dtype(x) + info = dtypes.finfo(dtype) + mask = (1 << info.nexp) - 1 + bias = ((1 << info.nexp) - 1) >> 1 + + x1, x2 = _normalize_float(x) + x2 += ((x1 >> info.nmant) & mask) - bias + 1 + x1 &= ~(mask << info.nmant) + x1 |= (bias - 1) << info.nmant + x1 = lax.bitcast_convert_type(x1, dtype) + + cond = isinf(x) | isnan(x) | (x == 0) + x2 = _where(cond, lax._zeros(x2), x2) + return _where(cond, x, x1), lax.convert_element_type(x2, np.int32) + + +@_wraps(np.remainder) +@jit +def remainder(x1, x2): + x1, x2 = _promote_args("remainder", x1, x2) + zero = _constant_like(x1, 0) + trunc_mod = lax.rem(x1, x2) + trunc_mod_not_zero = lax.ne(trunc_mod, zero) + do_plus = lax.bitwise_and( + lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero) + return _where(do_plus, lax.add(trunc_mod, x2), trunc_mod) +mod = _wraps(np.mod)(remainder) + + +@_wraps(np.fmod) +@jit +def fmod(x1, x2): + _check_arraylike("fmod", x1, x2) + if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer): + x2 = _where(x2 == 0, lax._ones(x2), x2) + return lax.rem(*_promote_args("fmod", x1, x2)) + + +@_wraps(np.square) +@partial(jit, inline=True) +def square(x): + _check_arraylike("square", x) + return lax.integer_pow(x, 2) + + +@_wraps(np.deg2rad) +@partial(jit, inline=True) +def deg2rad(x): + x, = _promote_args_inexact("deg2rad", x) + return lax.mul(x, lax._const(x, np.pi / 180)) + + +@_wraps(np.rad2deg) +@partial(jit, inline=True) +def rad2deg(x): + x, = _promote_args_inexact("rad2deg", x) + return lax.mul(x, lax._const(x, 180 / np.pi)) + + +degrees = rad2deg +radians = deg2rad + + +@_wraps(np.modf, skip_params=['out']) +@jit +def modf(x, out=None): + _check_arraylike("modf", x) + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.modf is not supported.") + whole = _where(lax.ge(x, lax._zero(x)), floor(x), ceil(x)) + return x - whole, whole + + +@_wraps(np.isfinite) +@jit +def isfinite(x): + _check_arraylike("isfinite", x) + dtype = dtypes.dtype(x) + if dtypes.issubdtype(dtype, np.floating): + return lax.is_finite(x) + elif dtypes.issubdtype(dtype, np.complexfloating): + return lax.bitwise_and(lax.is_finite(real(x)), lax.is_finite(imag(x))) + else: + return lax.full_like(x, True, dtype=np.bool_) + + +@_wraps(np.isinf) +@jit +def isinf(x): + _check_arraylike("isinf", x) + dtype = dtypes.dtype(x) + if dtypes.issubdtype(dtype, np.floating): + return lax.eq(lax.abs(x), _constant_like(x, np.inf)) + elif dtypes.issubdtype(dtype, np.complexfloating): + re = lax.real(x) + im = lax.imag(x) + return lax.bitwise_or(lax.eq(lax.abs(re), _constant_like(re, np.inf)), + lax.eq(lax.abs(im), _constant_like(im, np.inf))) + else: + return lax.full_like(x, False, dtype=np.bool_) + + +def _isposneginf(infinity, x, out): + if out is not None: + raise NotImplementedError("The 'out' argument to isneginf/isposinf is not supported.") + dtype = dtypes.dtype(x) + if dtypes.issubdtype(dtype, np.floating): + return lax.eq(x, _constant_like(x, infinity)) + elif dtypes.issubdtype(dtype, np.complexfloating): + raise ValueError("isposinf/isneginf are not well defined for complex types") + else: + return lax.full_like(x, False, dtype=np.bool_) + + +isposinf = _wraps(np.isposinf, skip_params=['out'])( + lambda x, out=None: _isposneginf(np.inf, x, out) +) + + +isneginf = _wraps(np.isneginf, skip_params=['out'])( + lambda x, out=None: _isposneginf(-np.inf, x, out) +) + + +@_wraps(np.isnan) +@jit +def isnan(x): + _check_arraylike("isnan", x) + return lax.ne(x, x) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 365769d6544f..9a74047b516d 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -12,11 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import re import textwrap from typing import Callable, NamedTuple, Optional, Dict, Sequence +import warnings from jax._src.config import config +from jax._src import dtypes +from jax._src.numpy.ndarray import ndarray +from jax._src.util import safe_zip +from jax._src import api +from jax import core +from jax import lax + +import numpy as np _parameter_break = re.compile("\n(?=[A-Za-z_])") _section_break = re.compile(r"\n(?=[^\n]{3,15}\n-{3,15})", re.MULTILINE) @@ -178,3 +188,159 @@ def wrap(op): setattr(op, attr, value) return op return wrap + +def _promote_dtypes(*args): + """Convenience function to apply Numpy argument dtype promotion.""" + # TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing. + if len(args) < 2: + return args + else: + to_dtype, weak_type = dtypes._lattice_result_type(*args) + to_dtype = dtypes.canonicalize_dtype(to_dtype) + return [lax._convert_element_type(x, to_dtype, weak_type) for x in args] + +def _promote_dtypes_inexact(*args): + """Convenience function to apply Numpy argument dtype promotion. + + Promotes arguments to an inexact type.""" + to_dtype, weak_type = dtypes._lattice_result_type(*args) + to_dtype = dtypes.canonicalize_dtype(to_dtype) + to_dtype_inexact = _to_inexact_dtype(to_dtype) + weak_type = (weak_type and to_dtype == to_dtype_inexact) + return [lax._convert_element_type(x, to_dtype_inexact, weak_type) for x in args] + +def _to_inexact_dtype(dtype): + """Promotes a dtype into an inexact dtype, if it is not already one.""" + return dtype if dtypes.issubdtype(dtype, np.inexact) else dtypes.promote_types(dtype, dtypes.float_) + +def _arraylike(x): + return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or + hasattr(x, '__jax_array__') or dtypes.is_python_scalar(x) or np.isscalar(x)) + +def _check_arraylike(fun_name, *args): + """Check if all args fit JAX's definition of arraylike.""" + assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}" + if any(not _arraylike(arg) for arg in args): + pos, arg = next((i, arg) for i, arg in enumerate(args) + if not _arraylike(arg)) + msg = "{} requires ndarray or scalar arguments, got {} at position {}." + raise TypeError(msg.format(fun_name, type(arg), pos)) + +def _check_no_float0s(fun_name, *args): + """Check if none of the args have dtype float0.""" + if any(dtypes.dtype(arg) is dtypes.float0 for arg in args): + raise TypeError( + f"Called {fun_name} with a float0 array. " + "float0s do not support any operations by design because they " + "are not compatible with non-trivial vector spaces. No implicit dtype " + "conversion is done. You can use np.zeros_like(arr, dtype=np.float) " + "to cast a float0 array to a regular zeros array. \n" + "If you didn't expect to get a float0 you might have accidentally " + "taken a gradient with respect to an integer argument.") + +def _promote_args(fun_name, *args): + """Convenience function to apply Numpy argument shape and dtype promotion.""" + _check_arraylike(fun_name, *args) + _check_no_float0s(fun_name, *args) + return _promote_shapes(fun_name, *_promote_dtypes(*args)) + +def _promote_args_inexact(fun_name, *args): + """Convenience function to apply Numpy argument shape and dtype promotion. + + Promotes non-inexact types to an inexact type.""" + _check_arraylike(fun_name, *args) + _check_no_float0s(fun_name, *args) + return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args)) + +def _promote_shapes(fun_name, *args): + """Apply NumPy-style broadcasting, making args shape-compatible for lax.py.""" + if len(args) < 2: + return args + else: + shapes = [np.shape(arg) for arg in args] + if all(len(shapes[0]) == len(s) for s in shapes[1:]): + return args # no need for rank promotion, so rely on lax promotion + nonscalar_ranks = {len(shp) for shp in shapes if shp} + if len(nonscalar_ranks) < 2: + return args + else: + if config.jax_numpy_rank_promotion != "allow": + _rank_promotion_warning_or_error(fun_name, shapes) + if config.jax_dynamic_shapes: + # With dynamic shapes we don't support singleton-dimension broadcasting; + # we instead broadcast out to the full shape as a temporary workaround. + res_shape = lax.broadcast_shapes(*shapes) + return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)] + else: + result_rank = len(lax.broadcast_shapes(*shapes)) + return [_broadcast_to(arg, (1,) * (result_rank - len(shp)) + shp) + for arg, shp in zip(args, shapes)] + +def _rank_promotion_warning_or_error(fun_name, shapes): + if config.jax_numpy_rank_promotion == "warn": + msg = ("Following NumPy automatic rank promotion for {} on shapes {}. " + "Set the jax_numpy_rank_promotion config option to 'allow' to " + "disable this warning; for more information, see " + "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") + warnings.warn(msg.format(fun_name, ' '.join(map(str, shapes)))) + elif config.jax_numpy_rank_promotion == "raise": + msg = ("Operands could not be broadcast together for {} on shapes {} " + "and with the config option jax_numpy_rank_promotion='raise'. " + "For more information, see " + "https://jax.readthedocs.io/en/latest/rank_promotion_warning.html.") + raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes)))) + + +def _broadcast_to(arr, shape): + if hasattr(arr, "broadcast_to"): + return arr.broadcast_to(shape) + arr = arr if isinstance(arr, ndarray) else api.device_put(arr) + if not isinstance(shape, tuple) and np.ndim(shape) == 0: + shape = (shape,) + shape = core.canonicalize_shape(shape) # check that shape is concrete + arr_shape = np.shape(arr) + if core.symbolic_equal_shape(arr_shape, shape): + return arr + else: + nlead = len(shape) - len(arr_shape) + shape_tail = shape[nlead:] + compatible = all(core.symbolic_equal_one_of_dim(arr_d, [1, shape_d]) + for arr_d, shape_d in safe_zip(arr_shape, shape_tail)) + if nlead < 0 or not compatible: + msg = "Incompatible shapes for broadcasting: {} and requested shape {}" + raise ValueError(msg.format(arr_shape, shape)) + diff, = np.where(tuple(not core.symbolic_equal_dim(arr_d, shape_d) + for arr_d, shape_d in safe_zip(arr_shape, shape_tail))) + new_dims = tuple(range(nlead)) + tuple(nlead + diff) + kept_dims = tuple(np.delete(np.arange(len(shape)), new_dims)) + return lax.broadcast_in_dim(lax.squeeze(arr, tuple(diff)), shape, kept_dims) + + +@partial(api.jit, inline=True) +def _broadcast_arrays(*args): + """Like Numpy's broadcast_arrays but doesn't return views.""" + shapes = [np.shape(arg) for arg in args] + if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes): + # TODO(mattjj): remove the array(arg) here + return [arg if isinstance(arg, ndarray) or np.isscalar(arg) + else api.device_put(arg) for arg in args] + result_shape = lax.broadcast_shapes(*shapes) + return [_broadcast_to(arg, result_shape) for arg in args] + + +# The `jit` on `where` exists to avoid materializing constants in cases like +# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to +# materialize the broadcast forms of scalar arguments. +@api.jit +def _where(condition, x=None, y=None): + if x is None or y is None: + raise ValueError("Either both or neither of the x and y arguments should " + "be provided to jax.numpy.where, got {} and {}." + .format(x, y)) + if not dtypes.issubdtype(dtypes.dtype(condition), np.bool_): + condition = lax.ne(condition, lax._const(condition, 0)) + x, y = _promote_dtypes(x, y) + condition, x, y = _broadcast_arrays(condition, x, y) + try: is_always_empty = core.is_empty_shape(np.shape(x)) + except: is_always_empty = False # can fail with dynamic shapes + return lax.select(condition, x, y) if not is_always_empty else x