Skip to content

Commit

Permalink
lax_numpy: move poly functions into numpy.polynomial
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Mar 17, 2022
1 parent 2d79a64 commit 603bb3c
Show file tree
Hide file tree
Showing 3 changed files with 233 additions and 231 deletions.
198 changes: 0 additions & 198 deletions jax/_src/numpy/lax_numpy.py
Expand Up @@ -702,76 +702,6 @@ def gradient_along_axis(a, h, axis):
def isrealobj(x):
return not iscomplexobj(x)

_POLYFIT_DOC = """\
Unlike NumPy's implementation of polyfit, :py:func:`jax.numpy.polyfit` will not warn on rank reduction, which indicates an ill conditioned matrix
Also, it works best on rcond <= 10e-3 values.
"""
@_wraps(np.polyfit, lax_description=_POLYFIT_DOC)
@partial(jit, static_argnames=('deg', 'rcond', 'full', 'cov'))
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
_check_arraylike("polyfit", x, y)
deg = core.concrete_or_error(int, deg, "deg must be int")
order = deg + 1
# check arguments
if deg < 0:
raise ValueError("expected deg >= 0")
if x.ndim != 1:
raise TypeError("expected 1D vector for x")
if x.size == 0:
raise TypeError("expected non-empty vector for x")
if y.ndim < 1 or y.ndim > 2:
raise TypeError("expected 1D or 2D array for y")
if x.shape[0] != y.shape[0]:
raise TypeError("expected x and y to have same length")

# set rcond
if rcond is None:
rcond = len(x)*finfo(x.dtype).eps
rcond = core.concrete_or_error(float, rcond, "rcond must be float")
# set up least squares equation for powers of x
lhs = vander(x, order)
rhs = y

# apply weighting
if w is not None:
_check_arraylike("polyfit", w)
w, = _promote_dtypes_inexact(w)
if w.ndim != 1:
raise TypeError("expected a 1-d array for weights")
if w.shape[0] != y.shape[0]:
raise TypeError("expected w and y to have the same length")
lhs *= w[:, newaxis]
if rhs.ndim == 2:
rhs *= w[:, newaxis]
else:
rhs *= w

# scale lhs to improve condition number and solve
scale = sqrt((lhs*lhs).sum(axis=0))
lhs /= scale[newaxis,:]
from jax._src.numpy import linalg
c, resids, rank, s = linalg.lstsq(lhs, rhs, rcond)
c = (c.T/scale).T # broadcast scale coefficients

if full:
return c, resids, rank, s, rcond
elif cov:
Vbase = linalg.inv(dot(lhs.T, lhs))
Vbase /= outer(scale, scale)
if cov == "unscaled":
fac = 1
else:
if len(x) <= order:
raise ValueError("the number of data points must exceed order "
"to scale the covariance matrix")
fac = resids / (len(x) - order)
fac = fac[0] #making np.array() of shape (1,) to int
if y.ndim == 1:
return c, Vbase * fac
else:
return c, Vbase[:,:, newaxis] * fac
else:
return c


@_wraps(np.reshape, lax_description=_ARRAY_VIEW_DOC)
Expand Down Expand Up @@ -3242,107 +3172,6 @@ def diagflat(v, k=0):
res = res.reshape(adj_length, adj_length)
return res

_POLY_DOC = """\
This differs from np.poly when an integer array is given.
np.poly returns a result with dtype float64 in this case.
jax returns a result with an inexact type, but not necessarily
float64.
This also differs from np.poly when the input array strictly
contains pairs of complex conjugates, e.g. [1j, -1j, 1-1j, 1+1j].
np.poly returns an array with a real dtype in such cases.
jax returns an array with a complex dtype in such cases.
"""

@_wraps(np.poly, lax_description=_POLY_DOC)
@jit
def poly(seq_of_zeros):
_check_arraylike('poly', seq_of_zeros)
seq_of_zeros, = _promote_dtypes_inexact(seq_of_zeros)
seq_of_zeros = atleast_1d(seq_of_zeros)

sh = seq_of_zeros.shape
if len(sh) == 2 and sh[0] == sh[1] and sh[0] != 0:
# import at runtime to avoid circular import
from jax._src.numpy import linalg
seq_of_zeros = linalg.eigvals(seq_of_zeros)

if seq_of_zeros.ndim != 1:
raise ValueError("input must be 1d or non-empty square 2d array.")

dt = seq_of_zeros.dtype
if len(seq_of_zeros) == 0:
return ones((), dtype=dt)

a = ones((1,), dtype=dt)
for k in range(len(seq_of_zeros)):
a = convolve(a, array([1, -seq_of_zeros[k]], dtype=dt), mode='full')

return a


@_wraps(np.polyval, lax_description="""\
The ``unroll`` parameter is JAX specific. It does not effect correctness but can
have a major impact on performance for evaluating high-order polynomials. The
parameter controls the number of unrolled steps with ``lax.scan`` inside the
``polyval`` implementation. Consider setting ``unroll=128`` (or even higher) to
improve runtime performance on accelerators, at the cost of increased
compilation time.
""")
@partial(jax.jit, static_argnames=['unroll'])
def polyval(p, x, *, unroll=16):
_check_arraylike("polyval", p, x)
p, x = _promote_dtypes_inexact(p, x)
shape = lax.broadcast_shapes(p.shape[1:], x.shape)
y = lax.full_like(x, 0, shape=shape, dtype=x.dtype)
y, _ = lax.scan(lambda y, p: (y * x + p, None), y, p, unroll=unroll)
return y

@_wraps(np.polyadd)
@jit
def polyadd(a1, a2):
_check_arraylike("polyadd", a1, a2)
a1, a2 = _promote_dtypes(a1, a2)
if a2.shape[0] <= a1.shape[0]:
return a1.at[-a2.shape[0]:].add(a2)
else:
return a2.at[-a1.shape[0]:].add(a1)


@_wraps(np.polyint)
@partial(jit, static_argnames=('m',))
def polyint(p, m=1, k=None):
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyint")
k = 0 if k is None else k
_check_arraylike("polyint", p, k)
p, k = _promote_dtypes_inexact(p, k)
if m < 0:
raise ValueError("Order of integral must be positive (see polyder)")
k = atleast_1d(k)
if len(k) == 1:
k = full((m,), k[0])
if k.shape != (m,):
raise ValueError("k must be a scalar or a rank-1 array of length 1 or m.")
if m == 0:
return p
else:
coeff = maximum(1, arange(len(p) + m, 0, -1)[newaxis, :] - 1 - arange(m)[:, newaxis]).prod(0)
return true_divide(concatenate((p, k)), coeff)


@_wraps(np.polyder)
@partial(jit, static_argnames=('m',))
def polyder(p, m=1):
_check_arraylike("polyder", p)
m = core.concrete_or_error(operator.index, m, "'m' argument of jnp.polyder")
p, = _promote_dtypes_inexact(p)
if m < 0:
raise ValueError("Order of derivative must be positive")
if m == 0:
return p
coeff = (arange(len(p), m, -1)[newaxis, :] - 1 - arange(m)[:, newaxis]).prod(0)
return p[:-m] * coeff


@_wraps(np.trim_zeros)
def trim_zeros(filt, trim='fb'):
Expand All @@ -3356,33 +3185,6 @@ def trim_zeros(filt, trim='fb'):
return filt[start:len(filt) - end]


_LEADING_ZEROS_DOC = """\
Setting trim_leading_zeros=True makes the output match that of numpy.
But prevents the function from being able to be used in compiled code.
"""

@_wraps(np.polymul, lax_description=_LEADING_ZEROS_DOC)
def polymul(a1, a2, *, trim_leading_zeros=False):
_check_arraylike("polymul", a1, a2)
a1, a2 = _promote_dtypes_inexact(a1, a2)
if trim_leading_zeros and (len(a1) > 1 or len(a2) > 1):
a1, a2 = trim_zeros(a1, trim='f'), trim_zeros(a2, trim='f')
if len(a1) == 0:
a1 = asarray([0.])
if len(a2) == 0:
a2 = asarray([0.])
val = convolve(a1, a2, mode='full')
return val


@_wraps(np.polysub)
@jit
def polysub(a1, a2):
_check_arraylike("polysub", a1, a2)
a1, a2 = _promote_dtypes(a1, a2)
return polyadd(a1, -a2)


@_wraps(np.append)
@partial(jit, static_argnames=('axis',))
def append(arr, values, axis: Optional[int] = None):
Expand Down

0 comments on commit 603bb3c

Please sign in to comment.