Skip to content

Commit

Permalink
Style and python2.7 fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbell10 committed Aug 5, 2019
1 parent 6d6b630 commit b4a8cf1
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 22 deletions.
4 changes: 2 additions & 2 deletions cupyx/scipy/fft/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from cupyx.scipy.fft.fft import *
from cupyx.scipy.fft.fft import __all__
from cupyx.scipy.fft.fft import * # NOQA
from cupyx.scipy.fft.fft import __all__ # NOQA
6 changes: 4 additions & 2 deletions cupyx/scipy/fft/fft.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import cupy
from cupy.cuda import cufft
from cupy.fft.fft import (_fft, _default_fft_func, _convert_fft_type,
hfft as _hfft, ihfft as _ihfft)
from cupy.fft.fft import (_fft, _default_fft_func, hfft as _hfft,
ihfft as _ihfft)
from cupy.fft.fft import fftshift, ifftshift, fftfreq, rfftfreq
import numpy as np


__all__ = ['fft', 'ifft', 'fft2', 'ifft2', 'fftn', 'ifftn',
Expand All @@ -25,6 +26,7 @@ def __getattr__(cls, name):
__ua_domain__ = 'numpy.scipy.fft'
_implemented = {}


def __ua_convert__(dispatchables, coerce):
if not all(d.dispatch_type == np.ndarray for d in dispatchables):
return NotImplemented
Expand Down
33 changes: 15 additions & 18 deletions tests/cupyx_tests/scipy_tests/fft_tests/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def _fft_module(xp):
else:
return np.fft


def _correct_np_dtype(xp, dtype, out):
# NumPy always transforms in double precision, cast output to correct type
if xp == np:
Expand Down Expand Up @@ -74,20 +75,19 @@ def test_ifft_overwrite(self, xp, dtype):
return _correct_np_dtype(xp, dtype, out)


@testing.parameterize(
*testing.product({
@testing.parameterize(*(
testing.product({
'shape': [(3, 4)],
's': [None, (1, 5)],
'axes': [None, (-2, -1), (-1, -2), (0,)],
'norm': [None, 'ortho']
}),
*testing.product({
})
+ testing.product({
'shape': [(2, 3, 4)],
's': [None, (1, 5), (1, 4, 10)],
'axes': [None, (-2, -1), (-1, -2, -3)],
'norm': [None, 'ortho']
})
)
})))
@testing.gpu
@testing.with_requires('numpy>=1.10.0')
class TestFft2(unittest.TestCase):
Expand Down Expand Up @@ -118,7 +118,8 @@ def test_fft2_overwrite(self, xp, dtype):
def test_ifft2(self, xp, dtype):
x = testing.shaped_random(self.shape, xp, dtype)
x_orig = x.copy()
out = _fft_module(xp).ifft2(x, s=self.s, axes=self.axes, norm=self.norm)
out = _fft_module(xp).ifft2(
x, s=self.s, axes=self.axes, norm=self.norm)
testing.assert_array_equal(x, x_orig)
return _correct_np_dtype(xp, dtype, out)

Expand All @@ -133,26 +134,25 @@ def test_ifft2_overwrite(self, xp, dtype):
return _correct_np_dtype(xp, dtype, out)


@testing.parameterize(
*testing.product({
@testing.parameterize(*(
testing.product({
'shape': [(3, 4)],
's': [None, (1, 5)],
'axes': [None, (-2, -1), (-1, -2), (0,)],
'norm': [None, 'ortho']
}),
*testing.product({
})
+ testing.product({
'shape': [(2, 3, 4)],
's': [None, (1, 5), (1, 4, 10)],
'axes': [None, (-2, -1), (-1, -2, -3)],
'norm': [None, 'ortho']
}),
*testing.product({
})
+ testing.product({
'shape': [(2, 3, 4, 5)],
's': [None],
'axes': [None, (0, 1, 2, 3)],
'norm': [None, 'ortho']
})
)
})))
@testing.gpu
@testing.with_requires('numpy>=1.10.0')
class TestFftn(unittest.TestCase):
Expand All @@ -178,7 +178,6 @@ def test_fftn_overwrite(self, xp, dtype):
norm=self.norm, **overwrite_kw)
return _correct_np_dtype(xp, dtype, out)


@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(rtol=1e-4, atol=1e-7, accept_error=ValueError,
contiguous_check=False)
Expand Down Expand Up @@ -349,7 +348,6 @@ def test_rfftn(self, xp, dtype):
testing.assert_array_equal(x, x_orig)
return _correct_np_dtype(xp, dtype, out)


@testing.for_all_dtypes(no_complex=True)
@testing.numpy_cupy_allclose(rtol=1e-4, atol=1e-7, accept_error=ValueError,
contiguous_check=False)
Expand All @@ -360,7 +358,6 @@ def test_rfftn_overwrite(self, xp, dtype):
norm=self.norm, **overwrite_kw)
return _correct_np_dtype(xp, dtype, out)


@testing.for_all_dtypes()
@testing.numpy_cupy_allclose(rtol=1e-4, atol=1e-7, accept_error=ValueError,
contiguous_check=False)
Expand Down

0 comments on commit b4a8cf1

Please sign in to comment.