Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Updating the test suite so the timers all work on the correct dtypes.

  • Loading branch information...
commit 25392165c28845659f7c0ee647b070a4f21965f9 1 parent 3c1ee59
@hgomersall authored
Showing with 108 additions and 14 deletions.
  1. +108 −14 test/test_pyfftw.py
View
122 test/test_pyfftw.py
@@ -18,6 +18,7 @@
from pyfftw import FFTW, n_byte_align, n_byte_align_empty
import numpy
+from timeit import Timer
import unittest
@@ -82,6 +83,7 @@ def setUp(self):
self.input_dtype = numpy.complex64
self.output_dtype = numpy.complex64
+ self.np_fft_comparison = numpy.fft.fft
return
def create_test_arrays(self, input_shape, output_shape):
@@ -111,12 +113,49 @@ def make_shapes(self):
def reference_fftn(self, a, axes):
return numpy.fft.fftn(a, axes=axes)
+ def timer_routine(self, pyfftw_callable, numpy_fft_callable):
+
+ N = 100
+
+ t = Timer(stmt=pyfftw_callable)
+ t_numpy_fft = Timer(stmt=numpy_fft_callable)
+
+ t_str = ("%.2f" % (1000.0/N*t.timeit(N)))+' ms'
+ t_numpy_str = ("%.2f" % (1000.0/N*t_numpy_fft.timeit(N)))+' ms'
+
+ print ('One run: '+ t_str + \
+ ' (versus ' + t_numpy_str + ' for numpy.fft)')
+
def test_time(self):
- timer()
+
+ in_shape = self.input_shapes['2d']
+ out_shape = self.output_shapes['2d']
+
+ axes=(-1,)
+ a, b = self.create_test_arrays(in_shape, out_shape)
+
+ fft, ifft = self.run_validate_fft(a, b, axes)
+
+ self.timer_routine(fft.execute,
+ lambda: self.np_fft_comparison(a))
self.assertTrue(True)
def test_time_with_array_update(self):
- timer_with_array_update()
+ in_shape = self.input_shapes['2d']
+ out_shape = self.output_shapes['2d']
+
+ axes=(-1,)
+ a, b = self.create_test_arrays(in_shape, out_shape)
+
+ fft, ifft = self.run_validate_fft(a, b, axes)
+
+ def fftw_callable():
+ fft.update_arrays(a,b)
+ fft.execute()
+
+ self.timer_routine(fftw_callable,
+ lambda: self.np_fft_comparison(a))
+
self.assertTrue(True)
def run_validate_fft(self, a, b, axes, fft=None, ifft=None,
@@ -132,6 +171,7 @@ def run_validate_fft(self, a, b, axes, fft=None, ifft=None,
If force_unaligned_data is True, the flag FFTW_UNALIGNED
will be passed to the fftw routines.
'''
+
if create_array_copies:
# Don't corrupt the original mutable arrays
a = a.copy()
@@ -143,7 +183,7 @@ def run_validate_fft(self, a, b, axes, fft=None, ifft=None,
if force_unaligned_data:
flags.append('FFTW_UNALIGNED')
-
+
if fft == None:
fft = FFTW(a,b,axes=axes,
direction='FFTW_FORWARD',flags=flags)
@@ -497,7 +537,8 @@ class Complex128FFTWTest(Complex64FFTWTest):
def setUp(self):
self.input_dtype = numpy.complex128
- self.output_dtype = numpy.complex128
+ self.output_dtype = numpy.complex128
+ self.np_fft_comparison = numpy.fft.fft
return
class ComplexLongDoubleFFTWTest(Complex64FFTWTest):
@@ -505,7 +546,8 @@ class ComplexLongDoubleFFTWTest(Complex64FFTWTest):
def setUp(self):
self.input_dtype = numpy.clongdouble
- self.output_dtype = numpy.clongdouble
+ self.output_dtype = numpy.clongdouble
+ self.np_fft_comparison = self.reference_fftn
return
def reference_fftn(self, a, axes):
@@ -515,12 +557,21 @@ def reference_fftn(self, a, axes):
a = numpy.complex128(a)
return numpy.fft.fftn(a, axes=axes)
+ @unittest.skip('numpy.fft has issues with this dtype.')
+ def test_time(self):
+ pass
+
+ @unittest.skip('numpy.fft has issues with this dtype.')
+ def test_time_with_array_update(self):
+ pass
+
class RealForwardDoubleFFTWTest(Complex64FFTWTest):
def setUp(self):
self.input_dtype = numpy.float64
- self.output_dtype = numpy.complex128
+ self.output_dtype = numpy.complex128
+ self.np_fft_comparison = numpy.fft.rfft
return
def make_shapes(self):
@@ -588,7 +639,8 @@ class RealForwardSingleFFTWTest(RealForwardDoubleFFTWTest):
def setUp(self):
self.input_dtype = numpy.float32
- self.output_dtype = numpy.complex64
+ self.output_dtype = numpy.complex64
+ self.np_fft_comparison = numpy.fft.rfft
return
class RealForwardLongDoubleFFTWTest(RealForwardDoubleFFTWTest):
@@ -596,9 +648,18 @@ class RealForwardLongDoubleFFTWTest(RealForwardDoubleFFTWTest):
def setUp(self):
self.input_dtype = numpy.longdouble
- self.output_dtype = numpy.clongdouble
+ self.output_dtype = numpy.clongdouble
+ self.np_fft_comparison = numpy.fft.rfft
return
+ @unittest.skip('numpy.fft has issues with this dtype.')
+ def test_time(self):
+ pass
+
+ @unittest.skip('numpy.fft has issues with this dtype.')
+ def test_time_with_array_update(self):
+ pass
+
def reference_fftn(self, a, axes):
a = numpy.float64(a)
@@ -609,7 +670,8 @@ class RealBackwardDoubleFFTWTest(Complex64FFTWTest):
def setUp(self):
self.input_dtype = numpy.complex128
- self.output_dtype = numpy.float64
+ self.output_dtype = numpy.float64
+ self.np_fft_comparison = numpy.fft.irfft
return
def make_shapes(self):
@@ -630,7 +692,7 @@ def create_test_arrays(self, input_shape, output_shape, axes=None):
+1j*numpy.random.randn(*input_shape))
b = self.output_dtype(numpy.random.randn(*output_shape))
-
+
# We fill a by doing the forward FFT from b.
# This means that the relevant bits that should be purely
# real will be (for example the zero freq component).
@@ -646,7 +708,7 @@ def create_test_arrays(self, input_shape, output_shape, axes=None):
fft.execute()
scaling = numpy.prod(numpy.array(a.shape))
- a = a/scaling
+ a = self.input_dtype(a/scaling)
except ValueError:
# In this case, we assume that it was meant to error,
@@ -654,7 +716,7 @@ def create_test_arrays(self, input_shape, output_shape, axes=None):
pass
b = self.output_dtype(numpy.random.randn(*output_shape))
-
+
return a, b
def run_validate_fft(self, a, b, axes, fft=None, ifft=None,
@@ -720,6 +782,24 @@ def run_validate_fft(self, a, b, axes, fft=None, ifft=None,
self.assertTrue(numpy.allclose(a/scaling, a_orig, rtol=1e-2, atol=1e-3))
return fft, ifft
+
+ def test_time_with_array_update(self):
+ in_shape = self.input_shapes['2d']
+ out_shape = self.output_shapes['2d']
+
+ axes=(-1,)
+ a, b = self.create_test_arrays(in_shape, out_shape)
+
+ fft, ifft = self.run_validate_fft(a, b, axes)
+
+ def fftw_callable():
+ fft.update_arrays(b,a)
+ fft.execute()
+
+ self.timer_routine(fftw_callable,
+ lambda: self.np_fft_comparison(a))
+
+ self.assertTrue(True)
def reference_fftn(self, a, axes):
# This needs to be an inverse
@@ -759,6 +839,8 @@ def test_non_contiguous_2d(self):
self.run_validate_fft(a_sliced, b_sliced, axes, create_array_copies=False)
+ @unittest.skipIf(numpy.version.version <= '1.6.1',
+ 'numpy.fft <= 1.6.1 has a bug that causes this test to fail.')
def test_non_contiguous_2d_in_3d(self):
in_shape = (256, 4, 1025)
out_shape = (256, 4, 2048)
@@ -785,7 +867,9 @@ class RealBackwardSingleFFTWTest(RealBackwardDoubleFFTWTest):
def setUp(self):
self.input_dtype = numpy.complex64
- self.output_dtype = numpy.float32
+ self.output_dtype = numpy.float32
+ self.np_fft_comparison = numpy.fft.irfft
+
return
class RealBackwardLongDoubleFFTWTest(RealBackwardDoubleFFTWTest):
@@ -793,7 +877,8 @@ class RealBackwardLongDoubleFFTWTest(RealBackwardDoubleFFTWTest):
def setUp(self):
self.input_dtype = numpy.clongdouble
- self.output_dtype = numpy.longdouble
+ self.output_dtype = numpy.longdouble
+ self.np_fft_comparison = numpy.fft.irfft
return
def reference_fftn(self, a, axes):
@@ -801,6 +886,15 @@ def reference_fftn(self, a, axes):
a = numpy.complex128(a)
return numpy.fft.irfftn(a, axes=axes)
+ @unittest.skip('numpy.fft has issues with this dtype.')
+ def test_time(self):
+ pass
+
+ @unittest.skip('numpy.fft has issues with this dtype.')
+ def test_time_with_array_update(self):
+ pass
+
+
class NByteAlignTest(unittest.TestCase):
def setUp(self):
Please sign in to comment.
Something went wrong with that request. Please try again.