Skip to content

Commit

Permalink
Remove last scipy imports
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 3, 2024
1 parent c0cfc7a commit ff67e51
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 31 deletions.
4 changes: 2 additions & 2 deletions jax/_src/scipy/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
>>> x = jnp.array([1., 2., 3., 2., 1., 0., 1., 2.])
>>> f, t, Zxx = jax.scipy.signal.stft(x, nperseg=4)
>>> print(Zxx)
>>> print(Zxx) # doctest: +SKIP
[[ 1. +0.j 2.5+0.j 1. +0.j 1. +0.j 0.5+0.j ]
[-0.5+0.5j -1.5+0.j -0.5-0.5j -0.5+0.5j 0. -0.5j]
[ 0. +0.j 0.5+0.j 0. +0.j 0. +0.j -0.5+0.j ]]
Expand Down Expand Up @@ -1060,7 +1060,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann',
# Get window as array
if window == 'hann':
# Implement the default case without scipy
win = jnp.sin(jnp.linspace(0, jnp.pi, nperseg_int, endpoint=False)) ** 2
win = jnp.array([1.0]) if nperseg_int == 1 else jnp.sin(jnp.linspace(0, jnp.pi, nperseg_int, endpoint=False)) ** 2
win = win.astype(xsubs.dtype)
elif isinstance(window, (str, tuple)):
# TODO(jakevdp): implement get_window() in JAX to remove optional scipy dependency
Expand Down
35 changes: 24 additions & 11 deletions jax/_src/third_party/scipy/interpolate.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from itertools import product
import scipy.interpolate as osp_interpolate

from jax.numpy import (asarray, broadcast_arrays, can_cast,
empty, nan, searchsorted, where, zeros)
from jax._src.tree_util import register_pytree_node
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact, implements
from jax._src.numpy.util import check_arraylike, promote_dtypes_inexact


def _ndim_coords_from_arrays(points, ndim=None):
Expand All @@ -31,15 +30,30 @@ def _ndim_coords_from_arrays(points, ndim=None):
return points


@implements(
osp_interpolate.RegularGridInterpolator,
lax_description="""
In the JAX version, `bounds_error` defaults to and must always be `False` since no
bound error may be raised under JIT.
Furthermore, in contrast to SciPy no input validation is performed.
""")
class RegularGridInterpolator:
"""Interpolate points on a regular rectangular grid.
JAX implementation of :func:`scipy.interpolate.RegularGridInterpolator`.
Args:
points: length-N sequence of arrays specifying the grid coordinates.
values: N-dimensional array specifying the grid values.
method: interpolation method, either ``"linear"`` or ``"nearest"``.
bounds_error: not implemented by JAX
fill_value: value returned for points outside the grid, defaults to NaN.
Returns:
interpolator: callable interpolation object.
Example:
>>> points = (jnp.array([1, 2, 3]), jnp.array([4, 5, 6]))
>>> values = jnp.array([[10, 20, 30], [40, 50, 60], [70, 80, 90]])
>>> interpolate = RegularGridInterpolator(points, values, method='linear')
>>> query_points = jnp.array([[1.5, 4.5], [2.2, 5.8]])
>>> interpolate(query_points)
Array([30., 64.], dtype=float32)
"""
# Based on SciPy's implementation which in turn is originally based on an
# implementation by Johannes Buchner

Expand Down Expand Up @@ -76,7 +90,6 @@ def __init__(self,
self.grid = tuple(asarray(p) for p in points)
self.values = values

@implements(osp_interpolate.RegularGridInterpolator.__call__, update_doc=False)
def __call__(self, xi, method=None):
method = self.method if method is None else method
if method not in ("linear", "nearest"):
Expand Down
56 changes: 42 additions & 14 deletions jax/_src/third_party/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@

from typing import Callable

import scipy.linalg

from jax import jit, lax
import jax.numpy as jnp
from jax._src.numpy.linalg import norm
from jax._src.numpy.util import implements
from jax._src.scipy.linalg import rsf2csf, schur
from jax._src.typing import ArrayLike, Array

Expand Down Expand Up @@ -40,20 +37,51 @@ def _inner_loop(i, p_F_minden):

return lax.fori_loop(1, N, _outer_loop, (F, minden))

_FUNM_LAX_DESCRIPTION = """\
The array returned by :py:func:`jax.scipy.linalg.funm` may differ in dtype
from the array returned by py:func:`scipy.linalg.funm`. Specifically, in cases
where all imaginary parts of the array values are close to zero, the SciPy
function may return a real-valued array, whereas the JAX implementation will
return a complex-valued array.
Additionally, unlike the SciPy implementation, when ``disp=True`` no warning
will be printed if the error in the array output is estimated to be large.
"""

@implements(scipy.linalg.funm, lax_description=_FUNM_LAX_DESCRIPTION)
def funm(A: ArrayLike, func: Callable[[Array], Array],
disp: bool = True) -> Array | tuple[Array, Array]:
"""Evaluate a matrix-valued function
JAX implementation of :func:`scipy.linalg.funm`.
Args:
A: array of shape ``(N, N)`` for which the function is to be computed.
func: Callable object that takes a scalar argument and returns a scalar result.
Represents the function to be evaluated over the eigenvalues of A.
disp: If true (default), error information is not returned. Unlike scipy's version JAX
does not attempt to display information at runtime.
compute_expm: (N, N) array_like or None, optional.
If provided, the matrix exponential of A. This is used for improving efficiency when `func`
is the exponential function. If not provided, it is computed internally.
Defaults to None.
Returns:
Array of same shape as ``A``, containing the result of ``func`` evaluated on the
eigenvalues of ``A``.
Notes:
The returned dtype of JAX's implementation may differ from that of scipy;
specifically, in cases where all imaginary parts of the array values are
close to zero, the SciPy function may return a real-valued array, whereas
the JAX implementation will return a complex-valued array.
Example:
Applying an arbitrary matrix function:
>>> A = jnp.array([[1., 2.], [3., 4.]])
>>> def func(x):
... return jnp.sin(x) + 2 * jnp.cos(x)
>>> jax.scipy.linalg.funm(A, func) # doctest: +SKIP
Array([[ 1.2452652 +0.j, -0.3701772 +0.j],
[-0.55526584+0.j, 0.6899995 +0.j]], dtype=complex64)
Comparing two ways of computing the matrix exponent:
>>> expA_1 = jax.scipy.linalg.funm(A, jnp.exp)
>>> expA_2 = jax.scipy.linalg.expm(A)
>>> jnp.allclose(expA_1, expA_2, rtol=1E-4)
Array(True, dtype=bool)
"""
A_arr = jnp.asarray(A)
if A_arr.ndim != 2 or A_arr.shape[0] != A_arr.shape[1]:
raise ValueError('expected square array_like input')
Expand Down
17 changes: 13 additions & 4 deletions jax/_src/third_party/scipy/signal_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import scipy.signal as osp_signal
from typing import Any
import warnings

Expand Down Expand Up @@ -43,10 +42,20 @@ def _triage_segments(window: ArrayLike | str | tuple[Any, ...], nperseg: int | N
if isinstance(window, (str, tuple)):
nperseg_int = input_length if nperseg is None else int(nperseg)
if nperseg_int > input_length:
warnings.warn(f'nperseg = {nperseg_int} is greater than input length '
f' = {input_length}, using nperseg = {input_length}')
warnings.warn(f'nperseg={nperseg_int} is greater than {input_length=},'
f' using nperseg={input_length}')
nperseg_int = input_length
win = jnp.array(osp_signal.get_window(window, nperseg_int), dtype=dtype)
if window == 'hann':
# Implement the default case without scipy
win = jnp.array([1.0]) if nperseg_int == 1 else jnp.sin(jnp.linspace(0, jnp.pi, nperseg_int, endpoint=False)) ** 2
else:
# TODO(jakevdp): implement get_window() in JAX to remove optional scipy dependency
try:
from scipy.signal import get_window
except ImportError as err:
raise ImportError(f"scipy must be available to use {window=}") from err
win = get_window(window, nperseg_int)
win = jnp.array(win, dtype=dtype)
else:
win = jnp.asarray(window)
nperseg_int = win.size if nperseg is None else int(nperseg)
Expand Down

0 comments on commit ff67e51

Please sign in to comment.