Skip to content

Commit

Permalink
Merge pull request #8061 from asi1024/wavelet-deprecation
Browse files Browse the repository at this point in the history
Deprecate `cupyx.scipy` wavelet functions
  • Loading branch information
asi1024 committed Jan 24, 2024
2 parents 710257a + 17c4df1 commit 37602b2
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 17 deletions.
21 changes: 20 additions & 1 deletion cupyx/scipy/signal/_wavelets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
Wavelet-generating functions.
Expand All @@ -25,12 +24,22 @@
DEALINGS IN THE SOFTWARE.
"""

import warnings

import cupy
import numpy as np

from cupyx.scipy.signal._signaltools import convolve


_deprecate_msg = (
"Following the change in SciPy 1.12, all wavelet functions have been "
"deprecated in CuPy v14 and are planned to be removed in the future. "
"To request continued support of the features, "
"please leave a comment at https://github.com/cupy/cupy/pull/8061."
)


_qmf_kernel = cupy.ElementwiseKernel(
"raw T coef",
"T output",
Expand All @@ -53,6 +62,8 @@ def qmf(hk):
Coefficients of high-pass filter.
"""
warnings.warn(_deprecate_msg, DeprecationWarning)

hk = cupy.asarray(hk)
return _qmf_kernel(hk, size=len(hk))

Expand Down Expand Up @@ -131,6 +142,8 @@ def morlet(M, w=5.0, s=1.0, complete=True):
with it.
"""
warnings.warn(_deprecate_msg, DeprecationWarning)

return _morlet_kernel(w, s, complete, size=M)


Expand Down Expand Up @@ -190,6 +203,8 @@ def ricker(points, a):
>>> plt.show()
"""
warnings.warn(_deprecate_msg, DeprecationWarning)

return _ricker_kernel(a, size=int(points))


Expand Down Expand Up @@ -279,6 +294,8 @@ def morlet2(M, s, w=5):
cmap='viridis', shading='gouraud')
>>> plt.show()
"""
warnings.warn(_deprecate_msg, DeprecationWarning)

return _morlet2_kernel(w, s, size=int(M))


Expand Down Expand Up @@ -333,6 +350,8 @@ def cwt(data, wavelet, widths):
>>> plt.show()
""" # NOQA
warnings.warn(_deprecate_msg, DeprecationWarning)

if cupy.asarray(wavelet(1, 1)).dtype.char in "FDG":
dtype = cupy.complex128
else:
Expand Down
58 changes: 42 additions & 16 deletions tests/cupyx_tests/scipy_tests/signal_tests/test_wavelets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings

import pytest

Expand All @@ -10,25 +11,30 @@
pass


@testing.with_requires('scipy')
class TestWavelets:
@testing.numpy_cupy_allclose(scipy_name="scp")
def test_qmf(self, xp, scp):
return scp.signal.qmf([1, 1])
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
return scp.signal.qmf([1, 1])

@pytest.mark.skip(reason='daub is not available on cupyx.scipy.signal')
@pytest.mark.parametrize('p', list(range(1, 15)))
@testing.numpy_cupy_allclose(scipy_name="scp")
def test_daub(self, p, xp, scp):
return scp.signal.daub(p)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
return scp.signal.daub(p)

@pytest.mark.skip(reason='cascade is not available on cupyx.scipy.signal')
@pytest.mark.parametrize('J', list(range(1, 7)))
@pytest.mark.parametrize('i', list(range(1, 5)))
@testing.numpy_cupy_allclose(scipy_name="scp")
def test_cascade(self, J, i, xp, scp):
lpcoef = scp.signal.daub(i)
x, phi, psi = scp.signal.cascade(lpcoef, J)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
lpcoef = scp.signal.daub(i)
x, phi, psi = scp.signal.cascade(lpcoef, J)
return x, phi, psi

@pytest.mark.parametrize('args,kwargs', [
Expand All @@ -46,44 +52,60 @@ def test_cascade(self, J, i, xp, scp):
])
@testing.numpy_cupy_allclose(scipy_name="scp")
def test_morlet(self, args, kwargs, xp, scp):
return scp.signal.morlet(*args, **kwargs)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
return scp.signal.morlet(*args, **kwargs)

@testing.numpy_cupy_allclose(scipy_name="scp")
def test_morlet2(self, xp, scp):
return scp.signal.morlet2(1.0, 0.5)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
return scp.signal.morlet2(1.0, 0.5)

@pytest.mark.parametrize('length', [5, 11, 15, 51, 101])
@testing.numpy_cupy_allclose(scipy_name="scp")
def test_morlet2_length(self, length, xp, scp):
return scp.signal.morlet2(length, 1.0)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
return scp.signal.morlet2(length, 1.0)

@testing.numpy_cupy_allclose(scipy_name="scp")
def test_morlet2_points(self, xp, scp):
points = 100
w = scp.signal.morlet2(points, 2.0)
y = scp.signal.morlet2(3, s=1/(2*xp.pi), w=2)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
w = scp.signal.morlet2(points, 2.0)
y = scp.signal.morlet2(3, s=1/(2*xp.pi), w=2)
return w, y

@testing.numpy_cupy_allclose(scipy_name="scp")
def test_ricker(self, xp, scp):
return scp.signal.ricker(1.0, 1)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
return scp.signal.ricker(1.0, 1)

@pytest.mark.parametrize('length', [5, 11, 15, 51, 101])
@testing.numpy_cupy_allclose(scipy_name="scp")
def test_ricker_length(self, length, xp, scp):
return scp.signal.ricker(length, 1.0)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
return scp.signal.ricker(length, 1.0)

@testing.numpy_cupy_allclose(scipy_name="scp")
def test_ricker_points(self, xp, scp):
points = 100
return scp.signal.ricker(points, 2.0)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
return scp.signal.ricker(points, 2.0)

@pytest.mark.parametrize('a', [5, 10, 15, 20, 30])
@testing.numpy_cupy_allclose(scipy_name="scp")
def test_ricker_zeros(self, a, xp, scp):
# Check zeros
points = 99
return scp.signal.ricker(points, a)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
return scp.signal.ricker(points, a)

@testing.numpy_cupy_allclose(scipy_name="scp")
def test_cwt_delta(self, xp, scp):
Expand All @@ -94,12 +116,16 @@ def test_cwt_delta(self, xp, scp):
def delta_wavelet(s, t):
return xp.array([1])

return scp.signal.cwt(test_data, delta_wavelet, widths)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
return scp.signal.cwt(test_data, delta_wavelet, widths)

@testing.numpy_cupy_allclose(scipy_name="scp")
def test_cwt_ricker(self, xp, scp):
len_data = 100
test_data = xp.sin(xp.pi * xp.arange(0, len_data) / 10.0)
# Check proper shape on output
widths = [1, 3, 4, 5, 10]
return scp.signal.cwt(test_data, scp.signal.ricker, widths)
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
return scp.signal.cwt(test_data, scp.signal.ricker, widths)

0 comments on commit 37602b2

Please sign in to comment.