diff --git a/test/test_pyfftw.py b/test/test_pyfftw.py index fba387db..a22a0568 100644 --- a/test/test_pyfftw.py +++ b/test/test_pyfftw.py @@ -18,6 +18,7 @@ from pyfftw import FFTW, n_byte_align, n_byte_align_empty import numpy +from timeit import Timer import unittest @@ -82,6 +83,7 @@ def setUp(self): self.input_dtype = numpy.complex64 self.output_dtype = numpy.complex64 + self.np_fft_comparison = numpy.fft.fft return def create_test_arrays(self, input_shape, output_shape): @@ -111,12 +113,49 @@ def make_shapes(self): def reference_fftn(self, a, axes): return numpy.fft.fftn(a, axes=axes) + def timer_routine(self, pyfftw_callable, numpy_fft_callable): + + N = 100 + + t = Timer(stmt=pyfftw_callable) + t_numpy_fft = Timer(stmt=numpy_fft_callable) + + t_str = ("%.2f" % (1000.0/N*t.timeit(N)))+' ms' + t_numpy_str = ("%.2f" % (1000.0/N*t_numpy_fft.timeit(N)))+' ms' + + print ('One run: '+ t_str + \ + ' (versus ' + t_numpy_str + ' for numpy.fft)') + def test_time(self): - timer() + + in_shape = self.input_shapes['2d'] + out_shape = self.output_shapes['2d'] + + axes=(-1,) + a, b = self.create_test_arrays(in_shape, out_shape) + + fft, ifft = self.run_validate_fft(a, b, axes) + + self.timer_routine(fft.execute, + lambda: self.np_fft_comparison(a)) self.assertTrue(True) def test_time_with_array_update(self): - timer_with_array_update() + in_shape = self.input_shapes['2d'] + out_shape = self.output_shapes['2d'] + + axes=(-1,) + a, b = self.create_test_arrays(in_shape, out_shape) + + fft, ifft = self.run_validate_fft(a, b, axes) + + def fftw_callable(): + fft.update_arrays(a,b) + fft.execute() + + self.timer_routine(fftw_callable, + lambda: self.np_fft_comparison(a)) + self.assertTrue(True) def run_validate_fft(self, a, b, axes, fft=None, ifft=None, @@ -132,6 +171,7 @@ def run_validate_fft(self, a, b, axes, fft=None, ifft=None, If force_unaligned_data is True, the flag FFTW_UNALIGNED will be passed to the fftw routines. ''' + if create_array_copies: # Don't corrupt the original mutable arrays a = a.copy() @@ -143,7 +183,7 @@ def run_validate_fft(self, a, b, axes, fft=None, ifft=None, if force_unaligned_data: flags.append('FFTW_UNALIGNED') - + if fft == None: fft = FFTW(a,b,axes=axes, direction='FFTW_FORWARD',flags=flags) @@ -497,7 +537,8 @@ class Complex128FFTWTest(Complex64FFTWTest): def setUp(self): self.input_dtype = numpy.complex128 - self.output_dtype = numpy.complex128 + self.output_dtype = numpy.complex128 + self.np_fft_comparison = numpy.fft.fft return class ComplexLongDoubleFFTWTest(Complex64FFTWTest): @@ -505,7 +546,8 @@ class ComplexLongDoubleFFTWTest(Complex64FFTWTest): def setUp(self): self.input_dtype = numpy.clongdouble - self.output_dtype = numpy.clongdouble + self.output_dtype = numpy.clongdouble + self.np_fft_comparison = self.reference_fftn return def reference_fftn(self, a, axes): @@ -515,12 +557,21 @@ def reference_fftn(self, a, axes): a = numpy.complex128(a) return numpy.fft.fftn(a, axes=axes) + @unittest.skip('numpy.fft has issues with this dtype.') + def test_time(self): + pass + + @unittest.skip('numpy.fft has issues with this dtype.') + def test_time_with_array_update(self): + pass + class RealForwardDoubleFFTWTest(Complex64FFTWTest): def setUp(self): self.input_dtype = numpy.float64 - self.output_dtype = numpy.complex128 + self.output_dtype = numpy.complex128 + self.np_fft_comparison = numpy.fft.rfft return def make_shapes(self): @@ -588,7 +639,8 @@ class RealForwardSingleFFTWTest(RealForwardDoubleFFTWTest): def setUp(self): self.input_dtype = numpy.float32 - self.output_dtype = numpy.complex64 + self.output_dtype = numpy.complex64 + self.np_fft_comparison = numpy.fft.rfft return class RealForwardLongDoubleFFTWTest(RealForwardDoubleFFTWTest): @@ -596,9 +648,18 @@ class RealForwardLongDoubleFFTWTest(RealForwardDoubleFFTWTest): def setUp(self): self.input_dtype = numpy.longdouble - self.output_dtype = numpy.clongdouble + self.output_dtype = numpy.clongdouble + self.np_fft_comparison = numpy.fft.rfft return + @unittest.skip('numpy.fft has issues with this dtype.') + def test_time(self): + pass + + @unittest.skip('numpy.fft has issues with this dtype.') + def test_time_with_array_update(self): + pass + def reference_fftn(self, a, axes): a = numpy.float64(a) @@ -609,7 +670,8 @@ class RealBackwardDoubleFFTWTest(Complex64FFTWTest): def setUp(self): self.input_dtype = numpy.complex128 - self.output_dtype = numpy.float64 + self.output_dtype = numpy.float64 + self.np_fft_comparison = numpy.fft.irfft return def make_shapes(self): @@ -630,7 +692,7 @@ def create_test_arrays(self, input_shape, output_shape, axes=None): +1j*numpy.random.randn(*input_shape)) b = self.output_dtype(numpy.random.randn(*output_shape)) - + # We fill a by doing the forward FFT from b. # This means that the relevant bits that should be purely # real will be (for example the zero freq component). @@ -646,7 +708,7 @@ def create_test_arrays(self, input_shape, output_shape, axes=None): fft.execute() scaling = numpy.prod(numpy.array(a.shape)) - a = a/scaling + a = self.input_dtype(a/scaling) except ValueError: # In this case, we assume that it was meant to error, @@ -654,7 +716,7 @@ def create_test_arrays(self, input_shape, output_shape, axes=None): pass b = self.output_dtype(numpy.random.randn(*output_shape)) - + return a, b def run_validate_fft(self, a, b, axes, fft=None, ifft=None, @@ -720,6 +782,24 @@ def run_validate_fft(self, a, b, axes, fft=None, ifft=None, self.assertTrue(numpy.allclose(a/scaling, a_orig, rtol=1e-2, atol=1e-3)) return fft, ifft + + def test_time_with_array_update(self): + in_shape = self.input_shapes['2d'] + out_shape = self.output_shapes['2d'] + + axes=(-1,) + a, b = self.create_test_arrays(in_shape, out_shape) + + fft, ifft = self.run_validate_fft(a, b, axes) + + def fftw_callable(): + fft.update_arrays(b,a) + fft.execute() + + self.timer_routine(fftw_callable, + lambda: self.np_fft_comparison(a)) + + self.assertTrue(True) def reference_fftn(self, a, axes): # This needs to be an inverse @@ -759,6 +839,8 @@ def test_non_contiguous_2d(self): self.run_validate_fft(a_sliced, b_sliced, axes, create_array_copies=False) + @unittest.skipIf(numpy.version.version <= '1.6.1', + 'numpy.fft <= 1.6.1 has a bug that causes this test to fail.') def test_non_contiguous_2d_in_3d(self): in_shape = (256, 4, 1025) out_shape = (256, 4, 2048) @@ -785,7 +867,9 @@ class RealBackwardSingleFFTWTest(RealBackwardDoubleFFTWTest): def setUp(self): self.input_dtype = numpy.complex64 - self.output_dtype = numpy.float32 + self.output_dtype = numpy.float32 + self.np_fft_comparison = numpy.fft.irfft + return class RealBackwardLongDoubleFFTWTest(RealBackwardDoubleFFTWTest): @@ -793,7 +877,8 @@ class RealBackwardLongDoubleFFTWTest(RealBackwardDoubleFFTWTest): def setUp(self): self.input_dtype = numpy.clongdouble - self.output_dtype = numpy.longdouble + self.output_dtype = numpy.longdouble + self.np_fft_comparison = numpy.fft.irfft return def reference_fftn(self, a, axes): @@ -801,6 +886,15 @@ def reference_fftn(self, a, axes): a = numpy.complex128(a) return numpy.fft.irfftn(a, axes=axes) + @unittest.skip('numpy.fft has issues with this dtype.') + def test_time(self): + pass + + @unittest.skip('numpy.fft has issues with this dtype.') + def test_time_with_array_update(self): + pass + + class NByteAlignTest(unittest.TestCase): def setUp(self):