[typing] add annotations to jax.numpy.linalg
jakevdp committed Oct 4, 2022
1 parent a60ca9f commit 78ed03c
Showing 2 changed files with 66 additions and 53 deletions.
8 changes: 7 additions & 1 deletion jax/_src/lax/lax.py
@@ -127,8 +127,14 @@ def _try_broadcast_shapes(
return None
return tuple(result_shape)

@overload
def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]: ...

@overload
def broadcast_shapes(*shapes: Tuple[Union[int, core.Tracer], ...]
) -> Tuple[Union[int, core.Tracer], ...]:
) -> Tuple[Union[int, core.Tracer], ...]: ...

def broadcast_shapes(*shapes):
"""Returns the shape that results from NumPy broadcasting of `shapes`."""
# NOTE: We have both cached and uncached versions to handle Tracers in shapes.
try:
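For readers unfamiliar with the pattern: the two @overload declarations above exist only for static type checkers, while the undecorated def that follows them is the single runtime implementation. A minimal self-contained sketch of how a checker resolves the two signatures (the stand-in Tracer class here is illustrative, not part of this commit):

    from typing import Tuple, Union, overload

    class Tracer: ...  # stand-in for jax.core.Tracer, illustration only

    @overload
    def broadcast_shapes(*shapes: Tuple[int, ...]) -> Tuple[int, ...]: ...
    @overload
    def broadcast_shapes(*shapes: Tuple[Union[int, Tracer], ...]
                         ) -> Tuple[Union[int, Tracer], ...]: ...
    def broadcast_shapes(*shapes):  # runtime implementation; checkers ignore its body
        ...

    # A checker picks the first overload that matches the arguments:
    #   broadcast_shapes((3, 1), (1, 4))    -> Tuple[int, ...]
    #   any tuple containing a Tracer       -> Tuple[Union[int, Tracer], ...]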
111 changes: 59 additions & 52 deletions jax/_src/numpy/linalg.py
@@ -28,27 +28,27 @@
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy.util import _wraps, _promote_dtypes_inexact
from jax._src.util import canonicalize_axis
from jax._src.typing import ArrayLike, Array


def _T(x):
def _T(x: ArrayLike) -> Array:
return jnp.swapaxes(x, -1, -2)


def _H(x):
def _H(x: ArrayLike) -> Array:
return jnp.conjugate(jnp.swapaxes(x, -1, -2))


@_wraps(np.linalg.cholesky)
@jit
def cholesky(a):
def cholesky(a: ArrayLike) -> Array:
a, = _promote_dtypes_inexact(jnp.asarray(a))
return lax_linalg.cholesky(a)


@_wraps(np.linalg.svd)
@partial(jit, static_argnames=('full_matrices', 'compute_uv', 'hermitian'))
def svd(a, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False):
def svd(a: ArrayLike, full_matrices: bool = True, compute_uv: bool = True,
hermitian: bool = False) -> Union[Array, Tuple[Array, Array, Array]]:
a, = _promote_dtypes_inexact(jnp.asarray(a))
if hermitian:
w, v = lax_linalg.eigh(a)
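The Union return annotation on svd mirrors a runtime fact: the result is a 3-tuple only when compute_uv is True (a static argument), and a single array of singular values otherwise. A minimal usage sketch, not from the diff:

    import jax.numpy as jnp

    a = jnp.ones((4, 3))
    u, s, vh = jnp.linalg.svd(a, full_matrices=False)  # Tuple[Array, Array, Array]
    s_only = jnp.linalg.svd(a, compute_uv=False)       # Array of singular values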
@@ -71,45 +71,46 @@ def svd(a, full_matrices: bool = True, compute_uv: bool = True,

@_wraps(np.linalg.matrix_power)
@partial(jit, static_argnames=('n',))
def matrix_power(a, n):
a, = _promote_dtypes_inexact(jnp.asarray(a))
def matrix_power(a: ArrayLike, n: int) -> Array:
# TODO(jakevdp): call _check_arraylike
arr, = _promote_dtypes_inexact(jnp.asarray(a))

if a.ndim < 2:
if arr.ndim < 2:
raise TypeError("{}-dimensional array given. Array must be at least "
"two-dimensional".format(a.ndim))
if a.shape[-2] != a.shape[-1]:
"two-dimensional".format(arr.ndim))
if arr.shape[-2] != arr.shape[-1]:
raise TypeError("Last 2 dimensions of the array must be square")
try:
n = operator.index(n)
except TypeError as err:
raise TypeError(f"exponent must be an integer, got {n}") from err

if n == 0:
return jnp.broadcast_to(jnp.eye(a.shape[-2], dtype=a.dtype), a.shape)
return jnp.broadcast_to(jnp.eye(arr.shape[-2], dtype=arr.dtype), arr.shape)
elif n < 0:
a = inv(a)
n = np.abs(n)
arr = inv(arr)
n = abs(n)

if n == 1:
return a
return arr
elif n == 2:
return a @ a
return arr @ arr
elif n == 3:
return (a @ a) @ a
return (arr @ arr) @ arr

z = result = None
while n > 0:
z = a if z is None else (z @ z)
z = arr if z is None else (z @ z) # type: ignore[operator]
n, bit = divmod(n, 2)
if bit:
result = z if result is None else (result @ z)

assert result is not None
return result


@_wraps(np.linalg.matrix_rank)
@jit
def matrix_rank(M, tol=None):
def matrix_rank(M: ArrayLike, tol: Optional[ArrayLike] = None) -> Array:
M, = _promote_dtypes_inexact(jnp.asarray(M))
if M.ndim < 2:
return jnp.any(M != 0).astype(jnp.int32)
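The while-loop in matrix_power above is exponentiation by squaring: z runs through a, a^2, a^4, ... and a factor is multiplied into result for each set bit of n, so the whole power costs O(log n) matrix products. An illustrative check, not from the diff:

    import jax.numpy as jnp

    a = jnp.array([[1., 1.],
                   [0., 1.]])   # a**n == [[1, n], [0, 1]], so results are exact
    a2 = a @ a
    a4 = a2 @ a2
    a8 = a4 @ a4
    # 13 = 0b1101 = 8 + 4 + 1, so only the z-values for those bits are used:
    assert jnp.allclose(jnp.linalg.matrix_power(a, 13), a @ a4 @ a8)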
@@ -121,7 +122,7 @@ def matrix_rank(M, tol=None):


@custom_jvp
def _slogdet_lu(a):
def _slogdet_lu(a: Array) -> Tuple[Array, Array]:
dtype = lax.dtype(a)
lu, pivot, _ = lax_linalg.lu(a)
diag = jnp.diagonal(lu, axis1=-2, axis2=-1)
@@ -143,7 +144,7 @@ def _slogdet_lu(a):
return sign, jnp.real(logdet)

@custom_jvp
def _slogdet_qr(a):
def _slogdet_qr(a: Array) -> Tuple[Array, Array]:
# Implementation of slogdet using QR decomposition. One reason we might prefer
# QR decomposition is that it is more amenable to a fast batched
# implementation on TPU because of the lack of row pivoting.
@@ -171,7 +172,7 @@ def _slogdet_qr(a):
LU decomposition if ``None``.
"""))
@partial(jit, static_argnames=('method',))
def slogdet(a, *, method: Optional[str] = None):
def slogdet(a: ArrayLike, *, method: Optional[str] = None) -> Tuple[Array, Array]:
a, = _promote_dtypes_inexact(jnp.asarray(a))
a_shape = jnp.shape(a)
if len(a_shape) < 2 or a_shape[-1] != a_shape[-2]:
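slogdet returns a (sign, logabsdet) pair rather than the determinant itself, which avoids the overflow and underflow det hits for large matrices; the LU path sums log|u_ii| and folds the pivot permutation into the sign. Usage sketch, not from the diff:

    import jax.numpy as jnp

    a = 10.0 * jnp.eye(100)
    sign, logdet = jnp.linalg.slogdet(a)
    # det(a) = 10**100 overflows float32, but the pair represents it stably:
    # sign == 1.0 and logdet == 100 * log(10) ≈ 230.26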
@@ -201,7 +202,7 @@ def _slogdet_jvp(primals, tangents):
_slogdet_lu.defjvp(_slogdet_jvp)
_slogdet_qr.defjvp(_slogdet_jvp)

def _cofactor_solve(a, b):
def _cofactor_solve(a: ArrayLike, b: ArrayLike) -> Tuple[Array, Array]:
"""Equivalent to det(a)*solve(a, b) for nonsingular mat.
Intermediate function used for jvp and vjp of det.
@@ -273,8 +274,8 @@ def _cofactor_solve(a, b):
# partial_det[:, -2] contains det(u) / u_{nn}.
partial_det = jnp.cumprod(diag, axis=-1) * sign[..., None]
lu = lu.at[..., -1, -1].set(1.0 / partial_det[..., -2])
permutation = jnp.broadcast_to(permutation, batch_dims + (a_shape[-1],))
iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,)))
permutation = jnp.broadcast_to(permutation, (*batch_dims, a_shape[-1]))
iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in (*batch_dims, 1)))
# filter out any matrices that are not full rank
d = jnp.ones(x.shape[:-1], x.dtype)
d = lax_linalg.triangular_solve(lu, d, left_side=True, lower=False)
@@ -292,12 +293,12 @@ def _cofactor_solve(a, b):
return partial_det[..., -1], x


def _det_2x2(a):
def _det_2x2(a: Array) -> Array:
return (a[..., 0, 0] * a[..., 1, 1] -
a[..., 0, 1] * a[..., 1, 0])


def _det_3x3(a):
def _det_3x3(a: Array) -> Array:
return (a[..., 0, 0] * a[..., 1, 1] * a[..., 2, 2] +
a[..., 0, 1] * a[..., 1, 2] * a[..., 2, 0] +
a[..., 0, 2] * a[..., 1, 0] * a[..., 2, 1] -
@@ -309,7 +310,7 @@ def _det_3x3(a):
@custom_jvp
@_wraps(np.linalg.det)
@jit
def det(a):
def det(a: ArrayLike) -> Array:
a, = _promote_dtypes_inexact(jnp.asarray(a))
a_shape = jnp.shape(a)
if len(a_shape) >= 2 and a_shape[-1] == 2 and a_shape[-2] == 2:
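Context for the _cofactor_solve machinery above: det's custom JVP (registered just after this hunk) follows Jacobi's formula, d det(A)[dA] = det(A) * tr(A^-1 dA). A small check of that identity against jax.jvp, not from the diff:

    import jax
    import jax.numpy as jnp

    a = jnp.array([[2., 1.],
                   [1., 3.]])
    da = jnp.eye(2)                      # an arbitrary tangent direction
    _, ddet = jax.jvp(jnp.linalg.det, (a,), (da,))
    expected = jnp.linalg.det(a) * jnp.trace(jnp.linalg.solve(a, da))
    assert jnp.allclose(ddet, expected)  # both equal 5.0 here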
@@ -341,21 +342,22 @@ def _det_jvp(primals, tangents):
backend. However eigendecomposition for symmetric/Hermitian matrices is
implemented more widely (see :func:`jax.numpy.linalg.eigh`).
""")
def eig(a):
def eig(a: ArrayLike) -> Tuple[Array, Array]:
a, = _promote_dtypes_inexact(jnp.asarray(a))
return lax_linalg.eig(a, compute_left_eigenvectors=False)


@_wraps(np.linalg.eigvals)
@jit
def eigvals(a):
def eigvals(a: ArrayLike) -> Array:
return lax_linalg.eig(a, compute_left_eigenvectors=False,
compute_right_eigenvectors=False)[0]


@_wraps(np.linalg.eigh)
@partial(jit, static_argnames=('UPLO', 'symmetrize_input'))
def eigh(a, UPLO=None, symmetrize_input=True):
def eigh(a: ArrayLike, UPLO: Optional[str] = None,
symmetrize_input: bool = True) -> Tuple[Array, Array]:
if UPLO is None or UPLO == "L":
lower = True
elif UPLO == "U":
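A quick reminder of the eigh contract the new annotations describe: it returns the pair (eigenvalues, eigenvectors), with eigenvalues in ascending order, and UPLO selecting which triangle of the input is read. Sketch, not from the diff:

    import jax.numpy as jnp

    a = jnp.array([[2., 1.],
                   [1., 2.]])
    w, v = jnp.linalg.eigh(a)          # w == [1., 3.], ascending
    assert jnp.allclose(a @ v, v * w)  # column-wise: a @ v[:, i] == w[i] * v[:, i]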
@@ -371,7 +373,7 @@ def eigh(a, UPLO=None, symmetrize_input=True):

@_wraps(np.linalg.eigvalsh)
@partial(jit, static_argnames=('UPLO',))
def eigvalsh(a, UPLO='L'):
def eigvalsh(a: ArrayLike, UPLO: Optional[str] = 'L') -> Array:
w, _ = eigh(a, UPLO)
return w

@@ -383,22 +385,22 @@ def eigvalsh(a, UPLO='L'):
`10. * max(num_rows, num_cols) * jnp.finfo(dtype).eps`.
"""))
@jit
def pinv(a, rcond=None):
def pinv(a: ArrayLike, rcond: Optional[ArrayLike] = None) -> Array:
# Uses same algorithm as
# https://github.com/numpy/numpy/blob/v1.17.0/numpy/linalg/linalg.py#L1890-L1979
a = jnp.conj(a)
arr = jnp.conj(a)
if rcond is None:
max_rows_cols = max(a.shape[-2:])
rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(a.dtype).eps)
max_rows_cols = max(arr.shape[-2:])
rcond = 10. * max_rows_cols * jnp.array(jnp.finfo(arr.dtype).eps)
rcond = jnp.asarray(rcond)
u, s, vh = svd(a, full_matrices=False)
u, s, vh = svd(arr, full_matrices=False)
# Singular values less than or equal to ``rcond * largest_singular_value``
# are set to zero.
rcond = lax.expand_dims(rcond[..., jnp.newaxis], range(s.ndim - rcond.ndim - 1))
cutoff = rcond * jnp.amax(s, axis=-1, keepdims=True, initial=-jnp.inf)
s = jnp.where(s > cutoff, s, jnp.inf).astype(u.dtype)
res = jnp.matmul(_T(vh), jnp.divide(_T(u), s[..., jnp.newaxis]))
return lax.convert_element_type(res, a.dtype)
return lax.convert_element_type(res, arr.dtype)


@pinv.defjvp
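The pinv implementation above is SVD-based: singular values at or below rcond * s_max are replaced by jnp.inf so that their reciprocals become zero, i.e. those directions are dropped from the pseudoinverse. A sketch of the effect (the default rcond for a float32 2x2 input works out to roughly 2.4e-6; the values below are illustrative):

    import jax.numpy as jnp

    a = jnp.array([[1., 0.],
                   [0., 1e-10]])   # second singular value falls below the cutoff
    p = jnp.linalg.pinv(a)         # ~ [[1, 0], [0, 0]]
    # Moore-Penrose identity holds up to the dropped direction:
    assert jnp.allclose(a @ p @ a, a, atol=1e-6)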
@@ -423,18 +425,21 @@ def _pinv_jvp(rcond, primals, tangents):

@_wraps(np.linalg.inv)
@jit
def inv(a):
if jnp.ndim(a) < 2 or a.shape[-1] != a.shape[-2]:
def inv(a: ArrayLike) -> Array:
# TODO(jakevdp): call _check_arraylike
arr = jnp.asarray(a)
if arr.ndim < 2 or arr.shape[-1] != arr.shape[-2]:
raise ValueError(
f"Argument to inv must have shape [..., n, n], got {a.shape}.")
f"Argument to inv must have shape [..., n, n], got {arr.shape}.")
return solve(
a, lax.broadcast(jnp.eye(a.shape[-1], dtype=lax.dtype(a)), a.shape[:-2]))
arr, lax.broadcast(jnp.eye(arr.shape[-1], dtype=arr.dtype), arr.shape[:-2]))


@_wraps(np.linalg.norm)
@partial(jit, static_argnames=('ord', 'axis', 'keepdims'))
def norm(x, ord=None, axis : Union[None, Tuple[int, ...], int] = None,
keepdims=False):
def norm(x: ArrayLike, ord: Union[int, str, None] = None,
axis: Union[None, Tuple[int, ...], int] = None,
keepdims: bool = False) -> Array:
x, = _promote_dtypes_inexact(jnp.asarray(x))
x_shape = jnp.shape(x)
ndim = len(x_shape)
@@ -477,9 +482,9 @@ def norm(x, ord=None, axis : Union[None, Tuple[int, ...], int] = None,
raise ValueError(msg)
else:
abs_x = jnp.abs(x)
ord = lax_internal._const(abs_x, ord)
ord_inv = lax_internal._const(abs_x, 1. / ord)
out = jnp.sum(abs_x ** ord, axis=axis, keepdims=keepdims)
ord_arr = lax_internal._const(abs_x, ord)
ord_inv = lax_internal._const(abs_x, 1. / ord_arr)
out = jnp.sum(abs_x ** ord_arr, axis=axis, keepdims=keepdims)
return jnp.power(out, ord_inv)

elif num_axes == 2:
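The branch above is the general vector p-norm, (sum_i |x_i|**ord) ** (1/ord); renaming the local to ord_arr keeps the static Python ord distinct from the weakly-typed array constant derived from it. For example, not from the diff:

    import jax.numpy as jnp

    x = jnp.array([3., 4.])
    n3 = jnp.linalg.norm(x, ord=3)   # (3**3 + 4**3) ** (1/3) ≈ 4.498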
@@ -529,7 +534,7 @@ def norm(x, ord=None, axis : Union[None, Tuple[int, ...], int] = None,

@_wraps(np.linalg.qr)
@partial(jit, static_argnames=('mode',))
def qr(a, mode="reduced"):
def qr(a: ArrayLike, mode: str = "reduced") -> Tuple[Array, Array]:
a, = _promote_dtypes_inexact(jnp.asarray(a))
if mode == "raw":
a, taus = lax_linalg.geqrf(a)
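For the qr annotation above, the mode argument fixes the result shapes: with an (m, n) input and m >= n, "reduced" (the default) returns q: (m, n) and r: (n, n), while "complete" returns q: (m, m) and r: (m, n). Sketch, not from the diff:

    import jax.numpy as jnp

    a = jnp.ones((5, 3))
    q, r = jnp.linalg.qr(a)                     # q: (5, 3), r: (3, 3)
    qf, rf = jnp.linalg.qr(a, mode="complete")  # qf: (5, 5), rf: (5, 3)
    assert jnp.allclose(q @ r, a, atol=1e-5)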
@@ -548,12 +553,13 @@ def qr(a, mode="reduced"):

@_wraps(np.linalg.solve)
@jit
def solve(a, b):
def solve(a: ArrayLike, b: ArrayLike) -> Array:
a, b = _promote_dtypes_inexact(jnp.asarray(a), jnp.asarray(b))
return lax_linalg._solve(a, b)


def _lstsq(a, b, rcond, *, numpy_resid=False):
def _lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float], *,
numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
# TODO: add lstsq to lax_linalg and implement this function via those wrappers.
# TODO: add custom jvp rule for more robust lstsq differentiation
a, b = _promote_dtypes_inexact(a, b)
@@ -607,7 +613,8 @@ def _lstsq(a, b, rcond, *, numpy_resid=False):
The lstsq function does not currently have a custom JVP rule, so the gradient is
poorly behaved for some inputs, particularly for low-rank `a`.
"""))
def lstsq(a, b, rcond=None, *, numpy_resid=False):
def lstsq(a: ArrayLike, b: ArrayLike, rcond: Optional[float] = None, *,
numpy_resid: bool = False) -> Tuple[Array, Array, Array, Array]:
if numpy_resid:
return _lstsq(a, b, rcond, numpy_resid=True)
return _jit_lstsq(a, b, rcond)
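Finally, a usage sketch for the annotated lstsq signature (illustrative, not from the diff): it returns the 4-tuple (solution, residuals, rank, singular values):

    import jax.numpy as jnp

    X = jnp.array([[1., 0.],
                   [1., 1.],
                   [1., 2.]])
    y = jnp.array([0., 1., 2.])
    coef, resid, rank, sv = jnp.linalg.lstsq(X, y)
    # coef ≈ [0., 1.]: the points lie exactly on the line y = x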
