Skip to content

Commit

Permalink
Add types to jax/_src/numpy/util.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 4, 2022
1 parent ae49d2e commit 069866e
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 38 deletions.
2 changes: 1 addition & 1 deletion jax/_src/lax/lax.py
Expand Up @@ -4353,7 +4353,7 @@ def _rng_bit_generator_named_shape_rule(key, *, shape, dtype, algorithm):
_rng_bit_generator_lowering)


def _array_copy(arr):
def _array_copy(arr: ArrayLike) -> Array:
return copy_p.bind(arr)

# The copy_p primitive exists for expressing making copies of runtime arrays.
Expand Down
16 changes: 10 additions & 6 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -76,6 +76,7 @@
_register_stackable, _stackable, _where, _wraps)
from jax._src.numpy.vectorize import vectorize
from jax._src.ops import scatter
from jax._src.typing import Array, ArrayLike, DTypeLike
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
canonicalize_axis as _canonicalize_axis)
from jax._src.array import ArrayImpl
Expand Down Expand Up @@ -1838,7 +1839,8 @@ def atleast_3d(*arys):
"""

@_wraps(np.array, lax_description=_ARRAY_DOC)
def array(object, dtype=None, copy=True, order="K", ndmin=0):
def array(object: Any, dtype: Optional[DTypeLike] = None, copy: bool = True,
order: str = "K", ndmin: int = 0) -> Array:
if order is not None and order != "K":
raise NotImplementedError("Only implemented for order='K'")

Expand Down Expand Up @@ -1878,6 +1880,8 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
# (See https://github.com/google/jax/issues/8950)
ndarray_types = (device_array.DeviceArray, core.Tracer, ArrayImpl)

out: ArrayLike

if not _any(isinstance(leaf, ndarray_types) for leaf in leaves):
# TODO(jakevdp): falling back to numpy here fails to overflow for lists
# containing large integers; see discussion in
Expand All @@ -1902,10 +1906,10 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):

raise TypeError(f"Unexpected input type for array: {type(object)}")

out = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
if ndmin > ndim(out):
out = lax.expand_dims(out, range(ndmin - ndim(out)))
return out
out_array: Array = lax_internal._convert_element_type(out, dtype, weak_type=weak_type)
if ndmin > ndim(out_array):
out_array = lax.expand_dims(out_array, range(ndmin - ndim(out_array)))
return out_array


def _convert_to_array_if_dtype_fails(x):
Expand All @@ -1918,7 +1922,7 @@ def _convert_to_array_if_dtype_fails(x):


@_wraps(np.asarray, lax_description=_ARRAY_DOC)
def asarray(a, dtype=None, order=None):
def asarray(a: Any, dtype: Optional[DTypeLike] = None, order: Any = None) -> Array:
lax_internal._check_user_dtype_supported(dtype, "asarray")
dtype = dtypes.canonicalize_dtype(dtype) if dtype is not None else dtype
return array(a, dtype=dtype, copy=False, order=order)
Expand Down
64 changes: 33 additions & 31 deletions jax/_src/numpy/util.py
Expand Up @@ -16,7 +16,7 @@
import re
import textwrap
from typing import (
Any, Callable, NamedTuple, Optional, Dict, Sequence, Set, Type, TypeVar
Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Type, TypeVar
)
import warnings

Expand All @@ -28,6 +28,7 @@
from jax._src import api
from jax import core
from jax._src.lax import lax
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, Shape

import numpy as np

Expand Down Expand Up @@ -215,7 +216,7 @@ def wrap(op):

_dtype = partial(dtypes.dtype, canonicalize=True)

def _asarray(arr):
def _asarray(arr: ArrayLike) -> Array:
"""
Pared-down utility to convert object to a DeviceArray.
Note this will not correctly handle lists or tuples.
Expand All @@ -224,10 +225,10 @@ def _asarray(arr):
dtype, weak_type = dtypes._lattice_result_type(arr)
return lax_internal._convert_element_type(arr, dtype, weak_type)

def _promote_shapes(fun_name, *args):
def _promote_shapes(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Apply NumPy-style broadcasting, making args shape-compatible for lax.py."""
if len(args) < 2:
return args
return [_asarray(arg) for arg in args]
else:
shapes = [np.shape(arg) for arg in args]
if config.jax_dynamic_shapes:
Expand All @@ -238,10 +239,10 @@ def _promote_shapes(fun_name, *args):
return [_broadcast_to(arg, res_shape) for arg, shp in zip(args, shapes)]
else:
if all(len(shapes[0]) == len(s) for s in shapes[1:]):
return args # no need for rank promotion, so rely on lax promotion
return [_asarray(arg) for arg in args] # no need for rank promotion, so rely on lax promotion
nonscalar_ranks = {len(shp) for shp in shapes if shp}
if len(nonscalar_ranks) < 2:
return args # rely on lax scalar promotion
return [_asarray(arg) for arg in args] # rely on lax scalar promotion
else:
if config.jax_numpy_rank_promotion != "allow":
_rank_promotion_warning_or_error(fun_name, shapes)
Expand All @@ -250,7 +251,7 @@ def _promote_shapes(fun_name, *args):
for arg, shp in zip(args, shapes)]


def _rank_promotion_warning_or_error(fun_name, shapes):
def _rank_promotion_warning_or_error(fun_name: str, shapes: Sequence[Shape]):
if config.jax_numpy_rank_promotion == "warn":
msg = ("Following NumPy automatic rank promotion for {} on shapes {}. "
"Set the jax_numpy_rank_promotion config option to 'allow' to "
Expand All @@ -265,18 +266,18 @@ def _rank_promotion_warning_or_error(fun_name, shapes):
raise ValueError(msg.format(fun_name, ' '.join(map(str, shapes))))


def _promote_dtypes(*args):
def _promote_dtypes(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion."""
# TODO(dougalm,mattjj): This is a performance bottleneck. Consider memoizing.
if len(args) < 2:
return args
return [_asarray(arg) for arg in args]
else:
to_dtype, weak_type = dtypes._lattice_result_type(*args)
to_dtype = dtypes.canonicalize_dtype(to_dtype)
return [lax_internal._convert_element_type(x, to_dtype, weak_type) for x in args]


def _promote_dtypes_inexact(*args):
def _promote_dtypes_inexact(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to an inexact type."""
Expand All @@ -287,7 +288,7 @@ def _promote_dtypes_inexact(*args):
for x in args]


def _promote_dtypes_numeric(*args):
def _promote_dtypes_numeric(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to a numeric (non-bool) type."""
Expand All @@ -298,7 +299,7 @@ def _promote_dtypes_numeric(*args):
for x in args]


def _promote_dtypes_complex(*args):
def _promote_dtypes_complex(*args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument dtype promotion.
Promotes arguments to a complex type."""
Expand All @@ -309,23 +310,23 @@ def _promote_dtypes_complex(*args):
for x in args]


def _complex_elem_type(dtype):
def _complex_elem_type(dtype: DTypeLike) -> DType:
"""Returns the float type of the real/imaginary parts of a complex dtype."""
return np.abs(np.zeros((), dtype)).dtype


def _arraylike(x):
def _arraylike(x: ArrayLike) -> bool:
return (isinstance(x, np.ndarray) or isinstance(x, ndarray) or
hasattr(x, '__jax_array__') or np.isscalar(x))


def _stackable(*args):
def _stackable(*args: Any) -> bool:
return all(type(arg) in stackables for arg in args)
stackables: Set[Type] = set()
_register_stackable: Callable[[Type], None] = stackables.add


def _check_arraylike(fun_name, *args):
def _check_arraylike(fun_name: str, *args: Any):
"""Check if all args fit JAX's definition of arraylike."""
assert isinstance(fun_name, str), f"fun_name must be a string. Got {fun_name}"
if any(not _arraylike(arg) for arg in args):
Expand All @@ -335,7 +336,7 @@ def _check_arraylike(fun_name, *args):
raise TypeError(msg.format(fun_name, type(arg), pos))


def _check_no_float0s(fun_name, *args):
def _check_no_float0s(fun_name: str, *args: Any):
"""Check if none of the args have dtype float0."""
if any(dtypes.dtype(arg) == dtypes.float0 for arg in args):
raise TypeError(
Expand All @@ -348,20 +349,20 @@ def _check_no_float0s(fun_name, *args):
"taken a gradient with respect to an integer argument.")


def _promote_args(fun_name, *args):
def _promote_args(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument shape and dtype promotion."""
_check_arraylike(fun_name, *args)
_check_no_float0s(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes(*args))


def _promote_args_numeric(fun_name, *args):
def _promote_args_numeric(fun_name: str, *args: ArrayLike) -> List[Array]:
_check_arraylike(fun_name, *args)
_check_no_float0s(fun_name, *args)
return _promote_shapes(fun_name, *_promote_dtypes_numeric(*args))


def _promote_args_inexact(fun_name, *args):
def _promote_args_inexact(fun_name: str, *args: ArrayLike) -> List[Array]:
"""Convenience function to apply Numpy argument shape and dtype promotion.
Promotes non-inexact types to an inexact type."""
Expand All @@ -371,20 +372,18 @@ def _promote_args_inexact(fun_name, *args):


@partial(api.jit, inline=True)
def _broadcast_arrays(*args):
def _broadcast_arrays(*args: ArrayLike) -> List[Array]:
"""Like Numpy's broadcast_arrays but doesn't return views."""
shapes = [np.shape(arg) for arg in args]
if not shapes or all(core.symbolic_equal_shape(shapes[0], s) for s in shapes):
# TODO(mattjj): remove the array(arg) here
return [arg if isinstance(arg, ndarray) or np.isscalar(arg) else _asarray(arg)
for arg in args]
return [_asarray(arg) for arg in args]
result_shape = lax.broadcast_shapes(*shapes)
return [_broadcast_to(arg, result_shape) for arg in args]


def _broadcast_to(arr, shape):
def _broadcast_to(arr: ArrayLike, shape: Shape) -> Array:
if hasattr(arr, "broadcast_to"):
return arr.broadcast_to(shape)
return arr.broadcast_to(shape) # type: ignore[union-attr]
_check_arraylike("broadcast_to", arr)
arr = arr if isinstance(arr, ndarray) else _asarray(arr)
if not isinstance(shape, tuple) and np.ndim(shape) == 0:
Expand Down Expand Up @@ -412,15 +411,18 @@ def _broadcast_to(arr, shape):
# `np.where(np.zeros(1000), 7, 4)`. In op-by-op mode, we don't want to
# materialize the broadcast forms of scalar arguments.
@api.jit
def _where(condition, x=None, y=None):
def _where(condition: ArrayLike, x: Optional[ArrayLike] = None,
y: Optional[ArrayLike] = None) -> Array:
if x is None or y is None:
raise ValueError("Either both or neither of the x and y arguments should "
"be provided to jax.numpy.where, got {} and {}."
.format(x, y))
if not np.issubdtype(_dtype(condition), np.bool_):
condition = lax.ne(condition, lax_internal._zero(condition))
x, y = _promote_dtypes(x, y)
condition, x, y = _broadcast_arrays(condition, x, y)
try: is_always_empty = core.is_empty_shape(np.shape(x))
except: is_always_empty = False # can fail with dynamic shapes
return lax.select(condition, x, y) if not is_always_empty else x
condition_arr, x_arr, y_arr = _broadcast_arrays(condition, x, y)
try:
is_always_empty = core.is_empty_shape(x_arr.shape)
except:
is_always_empty = False # can fail with dynamic shapes
return lax.select(condition_arr, x_arr, y_arr) if not is_always_empty else x_arr

0 comments on commit 069866e

Please sign in to comment.