Skip to content

Commit

Permalink
[typing] add types for jax.numpy.polynomial
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 5, 2022
1 parent dd0a455 commit 6a348f9
Showing 1 changed file with 47 additions and 44 deletions.
91 changes: 47 additions & 44 deletions jax/_src/numpy/polynomial.py
Expand Up @@ -15,6 +15,7 @@

from functools import partial
import operator
from typing import Optional, Tuple, Union

from jax import core
from jax import jit
Expand All @@ -26,11 +27,12 @@
vander, zeros)
from jax._src.numpy import linalg
from jax._src.numpy.util import _check_arraylike, _promote_dtypes, _promote_dtypes_inexact, _where, _wraps
from jax._src.typing import Array, ArrayLike
import numpy as np


@jit
def _roots_no_zeros(p):
def _roots_no_zeros(p: Array) -> Array:
# build companion matrix and find its eigenvalues (the roots)
if p.size < 2:
return array([], dtype=dtypes.to_complex_dtype(p.dtype))
Expand All @@ -40,7 +42,7 @@ def _roots_no_zeros(p):


@jit
def _roots_with_zeros(p, num_leading_zeros):
def _roots_with_zeros(p: Array, num_leading_zeros: int) -> Array:
# Avoid lapack errors when p is all zero
p = _where(len(p) == num_leading_zeros, 1.0, p)
# Roll any leading zeros to the end & compute the roots
Expand Down Expand Up @@ -77,23 +79,23 @@ def _roots_with_zeros(p, num_leading_zeros):
``strip_zeros`` must be set to ``False`` for the function to be compatible with
:func:`jax.jit` and other JAX transformations.
""")
def roots(p, *, strip_zeros=True):
def roots(p: ArrayLike, *, strip_zeros: bool = True) -> Array:
_check_arraylike("roots", p)
p = atleast_1d(*_promote_dtypes_inexact(p))
if p.ndim != 1:
p_arr = atleast_1d(*_promote_dtypes_inexact(p))
if p_arr.ndim != 1:
raise ValueError("Input must be a rank-1 array.")
if p.size < 2:
return array([], dtype=dtypes.to_complex_dtype(p.dtype))
num_leading_zeros = _where(all(p == 0), len(p), argmin(p == 0))
if p_arr.size < 2:
return array([], dtype=dtypes.to_complex_dtype(p_arr.dtype))
num_leading_zeros = _where(all(p_arr == 0), len(p_arr), argmin(p_arr == 0))

if strip_zeros:
num_leading_zeros = core.concrete_or_error(int, num_leading_zeros,
"The error occurred in the jnp.roots() function. To use this within a "
"JIT-compiled context, pass strip_zeros=False, but be aware that leading zeros "
"will be result in some returned roots being set to NaN.")
return _roots_no_zeros(p[num_leading_zeros:])
return _roots_no_zeros(p_arr[num_leading_zeros:])
else:
return _roots_with_zeros(p, num_leading_zeros)
return _roots_with_zeros(p_arr, num_leading_zeros)


_POLYFIT_DOC = """\
Expand All @@ -102,7 +104,9 @@ def roots(p, *, strip_zeros=True):
"""
@_wraps(np.polyfit, lax_description=_POLYFIT_DOC)
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
def polyfit(x: Array, y: Array, deg: int, rcond: Optional[float] = None,
full: bool = False, w: Optional[Array] = None, cov: bool = False
) -> Union[Array, Tuple[Array, ...]]:
_check_arraylike("polyfit", x, y)
deg = core.concrete_or_error(int, deg, "deg must be int")
order = deg + 1
Expand Down Expand Up @@ -147,7 +151,7 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
c = (c.T/scale).T # broadcast scale coefficients

if full:
return c, resids, rank, s, rcond
return c, resids, rank, s, asarray(rcond)
elif cov:
Vbase = linalg.inv(dot(lhs.T, lhs))
Vbase /= outer(scale, scale)
Expand Down Expand Up @@ -181,7 +185,7 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):

@_wraps(np.poly, lax_description=_POLY_DOC)
@jit
def poly(seq_of_zeros):
def poly(seq_of_zeros: Array) -> Array:
_check_arraylike('poly', seq_of_zeros)
seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
seq_of_zeros = atleast_1d(seq_of_zeros)
Expand Down Expand Up @@ -215,7 +219,7 @@ def poly(seq_of_zeros):
compilation time.
""")
@partial(jit, static_argnames=['unroll'])
def polyval(p, x, *, unroll=16):
def polyval(p: Array, x: Array, *, unroll: int = 16) -> Array:
_check_arraylike("polyval", p, x)
p, x = _promote_dtypes_inexact(p, x)
shape = lax.broadcast_shapes(p.shape[1:], x.shape)
Expand All @@ -225,7 +229,7 @@ def polyval(p, x, *, unroll=16):

@_wraps(np.polyadd)
@jit
def polyadd(a1, a2):
def polyadd(a1: Array, a2: Array) -> Array:
_check_arraylike("polyadd", a1, a2)
a1, a2 = _promote_dtypes(a1, a2)
if a2.shape[0] <= a1.shape[0]:
Expand All @@ -236,30 +240,30 @@ def polyadd(a1, a2):

@_wraps(np.polyint)
@partial(jit, static_argnames=('m',))
def polyint(p, m=1, k=None):
def polyint(p: Array, m: int = 1, k: Optional[int] = None) -> Array:
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
k = 0 if k is None else k
_check_arraylike("polyint", p, k)
p, k = _promote_dtypes_inexact(p, k)
p, k_arr = _promote_dtypes_inexact(p, k)
if m < 0:
raise ValueError("Order of integral must be positive (see polyder)")
k = atleast_1d(k)
if len(k) == 1:
k = full((m,), k[0])
if k.shape != (m,):
k_arr = atleast_1d(k_arr)
if len(k_arr) == 1:
k_arr = full((m,), k_arr[0])
if k_arr.shape != (m,):
raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
if m == 0:
return p
else:
grid = (arange(len(p) + m, dtype=p.dtype)[np.newaxis]
- arange(m, dtype=p.dtype)[:, np.newaxis])
coeff = maximum(1, grid).prod(0)[::-1]
return true_divide(concatenate((p, k)), coeff)
return true_divide(concatenate((p, k_arr)), coeff)


@_wraps(np.polyder)
@partial(jit, static_argnames=('m',))
def polyder(p, m=1):
def polyder(p: Array, m: int = 1) -> Array:
_check_arraylike("polyder", p)
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
p, = _promote_dtypes_inexact(p)
Expand All @@ -281,38 +285,37 @@ def polyder(p, m=1):
"""

@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
def polymul(a1, a2, *, trim_leading_zeros=False):
def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) -> Array:
_check_arraylike("polymul", a1, a2)
a1, a2 = _promote_dtypes_inexact(a1, a2)
if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1):
a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f')
if len(a1) == 0:
a1 = asarray([0], dtype=a2.dtype)
if len(a2) == 0:
a2 = asarray([0], dtype=a1.dtype)
return convolve(a1, a2, mode='full')
a1_arr, a2_arr = _promote_dtypes_inexact(a1, a2)
if trim_leading_zeros and (len(a1_arr) > 1 or len(a2_arr) > 1):
a1_arr, a2_arr = trim_zeros(a1_arr, trim='f'), trim_zeros(a2_arr, trim='f')
if len(a1_arr) == 0:
a1_arr = asarray([0], dtype=a2_arr.dtype)
if len(a2_arr) == 0:
a2_arr = asarray([0], dtype=a1_arr.dtype)
return convolve(a1_arr, a2_arr, mode='full')

@_wraps(np.polydiv, lax_description=_LEADING_ZEROS_DOC)
def polydiv(u, v, *, trim_leading_zeros=False):
def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> Tuple[Array, Array]:
_check_arraylike("polydiv", u, v)
u, v = _promote_dtypes_inexact(u, v)
m = len(u) - 1
n = len(v) - 1
scale = 1. / v[0]
q = zeros(max(m - n + 1, 1), dtype = u.dtype) # force same dtype
u_arr, v_arr = _promote_dtypes_inexact(u, v)
m = len(u_arr) - 1
n = len(v_arr) - 1
scale = 1. / v_arr[0]
q: Array = zeros(max(m - n + 1, 1), dtype = u_arr.dtype) # force same dtype
for k in range(0, m-n+1):
d = scale * u[k]
d = scale * u_arr[k]
q = q.at[k].set(d)
u = u.at[k:k+n+1].add(-d*v)
u_arr = u_arr.at[k:k+n+1].add(-d*v_arr)
if trim_leading_zeros:
# use the square root of finfo(dtype) to approximate the absolute tolerance used in numpy
return q, trim_zeros_tol(u, tol=sqrt(finfo(u.dtype).eps), trim='f')
else:
return q, u
u_arr = trim_zeros_tol(u_arr, tol=sqrt(finfo(u_arr.dtype).eps), trim='f')
return q, u_arr

@_wraps(np.polysub)
@jit
def polysub(a1, a2):
def polysub(a1: Array, a2: Array) -> Array:
_check_arraylike("polysub", a1, a2)
a1, a2 = _promote_dtypes(a1, a2)
return polyadd(a1, -a2)

0 comments on commit 6a348f9

Please sign in to comment.