Skip to content

Commit

Permalink
Add jax.scipy.signal.detrend (#3516)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jun 23, 2020
1 parent ca5b0b1 commit 33c455a
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
33 changes: 33 additions & 0 deletions jax/scipy/signal.py
Expand Up @@ -15,8 +15,11 @@
import scipy.signal as osp_signal
import warnings

import numpy as np

from .. import lax
from ..numpy import lax_numpy as jnp
from ..numpy import linalg
from ..numpy.lax_numpy import _promote_dtypes_inexact
from ..numpy._util import _wraps

Expand Down Expand Up @@ -104,3 +107,33 @@ def correlate2d(in1, in2, mode='full', boundary='fill', fillvalue=0,
if jnp.ndim(in1) != 2 or jnp.ndim(in2) != 2:
raise ValueError("correlate2d() only supports {ndim}-dimensional inputs.")
return _convolve_nd(in1[::-1, ::-1], in2, mode, precision=precision)[::-1, ::-1]


@_wraps(osp_signal.detrend)
def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=None):
if overwrite_data is not None:
raise NotImplementedError("overwrite_data argument not implemented.")
if type not in ['constant', 'linear']:
raise ValueError("Trend type must be 'linear' or 'constant'.")
data, = _promote_dtypes_inexact(jnp.asarray(data))
if type == 'constant':
return data - data.mean(axis, keepdims=True)
else:
N = data.shape[axis]
# bp is static, so we use np operations to avoid pushing to device.
bp = np.sort(np.unique(np.r_[0, bp, N]))
if bp[0] < 0 or bp[-1] > N:
raise ValueError("Breakpoints must be non-negative and less than length of data along given axis.")
data = jnp.moveaxis(data, axis, 0)
shape = data.shape
data = data.reshape(N, -1)
for m in range(len(bp) - 1):
Npts = bp[m + 1] - bp[m]
A = jnp.vstack([
jnp.ones(Npts, dtype=data.dtype),
jnp.arange(1, Npts + 1, dtype=data.dtype) / Npts
]).T
sl = slice(bp[m], bp[m + 1])
coef, *_ = linalg.lstsq(A, data[sl])
data = data.at[sl].add(-jnp.matmul(A, coef, precision=lax.Precision.HIGHEST))
return jnp.moveaxis(data.reshape(shape), 0, axis)
22 changes: 20 additions & 2 deletions tests/scipy_signal_test.py
Expand Up @@ -44,7 +44,7 @@ class LaxBackedScipySignalTests(jtu.JaxTestCase):
"""Tests for LAX-backed scipy.stats implementations"""

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_op={}_xshape=[{}]_yshape=[{}]_mode={}".format(
{"testcase_name": "_op={}_xshape={}_yshape={}_mode={}".format(
op,
jtu.format_shape_dtype_string(xshape, dtype),
jtu.format_shape_dtype_string(yshape, dtype),
Expand All @@ -67,7 +67,7 @@ def testConvolutions(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
self._CompileAndCheck(jsp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "op={}_xshape=[{}]_yshape=[{}]_mode={}".format(
{"testcase_name": "op={}_xshape={}_yshape={}_mode={}".format(
op,
jtu.format_shape_dtype_string(xshape, dtype),
jtu.format_shape_dtype_string(yshape, dtype),
Expand All @@ -89,6 +89,24 @@ def testConvolutions2D(self, xshape, yshape, dtype, mode, jsp_op, osp_op):
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol)
self._CompileAndCheck(jsp_fun, args_maker)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_shape={}_axis={}_type={}_bp={}".format(
jtu.format_shape_dtype_string(shape, dtype), axis, type, bp),
"shape": shape, "dtype": dtype, "axis": axis, "type": type, "bp": bp}
for shape in [(5,), (4, 5), (3, 4, 5)]
for dtype in default_dtypes
for axis in [0, -1]
for type in ['constant', 'linear']
for bp in [0, [0, 2]]))
def testDetrend(self, shape, dtype, axis, type, bp):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, dtype)]
osp_fun = partial(osp_signal.detrend, axis=axis, type=type, bp=bp)
jsp_fun = partial(jsp_signal.detrend, axis=axis, type=type, bp=bp)
tol = {onp.float32: 1e-5, onp.float64: 1e-12}
self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, tol=tol)
self._CompileAndCheck(jsp_fun, args_maker, rtol=tol, atol=tol)


if __name__ == "__main__":
absltest.main()

0 comments on commit 33c455a

Please sign in to comment.