Skip to content

Commit

Permalink
Updating the test suite so the timers all work on the correct dtypes.
Browse files Browse the repository at this point in the history
  • Loading branch information
hgomersall committed Feb 26, 2012
1 parent 3c1ee59 commit 2539216
Showing 1 changed file with 108 additions and 14 deletions.
122 changes: 108 additions & 14 deletions test/test_pyfftw.py
Expand Up @@ -18,6 +18,7 @@

from pyfftw import FFTW, n_byte_align, n_byte_align_empty
import numpy
from timeit import Timer

import unittest

Expand Down Expand Up @@ -82,6 +83,7 @@ def setUp(self):

self.input_dtype = numpy.complex64
self.output_dtype = numpy.complex64
self.np_fft_comparison = numpy.fft.fft
return

def create_test_arrays(self, input_shape, output_shape):
Expand Down Expand Up @@ -111,12 +113,49 @@ def make_shapes(self):
def reference_fftn(self, a, axes):
return numpy.fft.fftn(a, axes=axes)

def timer_routine(self, pyfftw_callable, numpy_fft_callable):

N = 100

t = Timer(stmt=pyfftw_callable)
t_numpy_fft = Timer(stmt=numpy_fft_callable)

t_str = ("%.2f" % (1000.0/N*t.timeit(N)))+' ms'
t_numpy_str = ("%.2f" % (1000.0/N*t_numpy_fft.timeit(N)))+' ms'

print ('One run: '+ t_str + \
' (versus ' + t_numpy_str + ' for numpy.fft)')

def test_time(self):
timer()

in_shape = self.input_shapes['2d']
out_shape = self.output_shapes['2d']

axes=(-1,)
a, b = self.create_test_arrays(in_shape, out_shape)

fft, ifft = self.run_validate_fft(a, b, axes)

self.timer_routine(fft.execute,
lambda: self.np_fft_comparison(a))
self.assertTrue(True)

def test_time_with_array_update(self):
timer_with_array_update()
in_shape = self.input_shapes['2d']
out_shape = self.output_shapes['2d']

axes=(-1,)
a, b = self.create_test_arrays(in_shape, out_shape)

fft, ifft = self.run_validate_fft(a, b, axes)

def fftw_callable():
fft.update_arrays(a,b)
fft.execute()

self.timer_routine(fftw_callable,
lambda: self.np_fft_comparison(a))

self.assertTrue(True)

def run_validate_fft(self, a, b, axes, fft=None, ifft=None,
Expand All @@ -132,6 +171,7 @@ def run_validate_fft(self, a, b, axes, fft=None, ifft=None,
If force_unaligned_data is True, the flag FFTW_UNALIGNED
will be passed to the fftw routines.
'''

if create_array_copies:
# Don't corrupt the original mutable arrays
a = a.copy()
Expand All @@ -143,7 +183,7 @@ def run_validate_fft(self, a, b, axes, fft=None, ifft=None,

if force_unaligned_data:
flags.append('FFTW_UNALIGNED')

if fft == None:
fft = FFTW(a,b,axes=axes,
direction='FFTW_FORWARD',flags=flags)
Expand Down Expand Up @@ -497,15 +537,17 @@ class Complex128FFTWTest(Complex64FFTWTest):
def setUp(self):

self.input_dtype = numpy.complex128
self.output_dtype = numpy.complex128
self.output_dtype = numpy.complex128
self.np_fft_comparison = numpy.fft.fft
return

class ComplexLongDoubleFFTWTest(Complex64FFTWTest):

def setUp(self):

self.input_dtype = numpy.clongdouble
self.output_dtype = numpy.clongdouble
self.output_dtype = numpy.clongdouble
self.np_fft_comparison = self.reference_fftn
return

def reference_fftn(self, a, axes):
Expand All @@ -515,12 +557,21 @@ def reference_fftn(self, a, axes):
a = numpy.complex128(a)
return numpy.fft.fftn(a, axes=axes)

@unittest.skip('numpy.fft has issues with this dtype.')
def test_time(self):
pass

@unittest.skip('numpy.fft has issues with this dtype.')
def test_time_with_array_update(self):
pass

class RealForwardDoubleFFTWTest(Complex64FFTWTest):

def setUp(self):

self.input_dtype = numpy.float64
self.output_dtype = numpy.complex128
self.output_dtype = numpy.complex128
self.np_fft_comparison = numpy.fft.rfft
return

def make_shapes(self):
Expand Down Expand Up @@ -588,17 +639,27 @@ class RealForwardSingleFFTWTest(RealForwardDoubleFFTWTest):
def setUp(self):

self.input_dtype = numpy.float32
self.output_dtype = numpy.complex64
self.output_dtype = numpy.complex64
self.np_fft_comparison = numpy.fft.rfft
return

class RealForwardLongDoubleFFTWTest(RealForwardDoubleFFTWTest):

def setUp(self):

self.input_dtype = numpy.longdouble
self.output_dtype = numpy.clongdouble
self.output_dtype = numpy.clongdouble
self.np_fft_comparison = numpy.fft.rfft
return

@unittest.skip('numpy.fft has issues with this dtype.')
def test_time(self):
pass

@unittest.skip('numpy.fft has issues with this dtype.')
def test_time_with_array_update(self):
pass

def reference_fftn(self, a, axes):

a = numpy.float64(a)
Expand All @@ -609,7 +670,8 @@ class RealBackwardDoubleFFTWTest(Complex64FFTWTest):
def setUp(self):

self.input_dtype = numpy.complex128
self.output_dtype = numpy.float64
self.output_dtype = numpy.float64
self.np_fft_comparison = numpy.fft.irfft
return

def make_shapes(self):
Expand All @@ -630,7 +692,7 @@ def create_test_arrays(self, input_shape, output_shape, axes=None):
+1j*numpy.random.randn(*input_shape))

b = self.output_dtype(numpy.random.randn(*output_shape))

# We fill a by doing the forward FFT from b.
# This means that the relevant bits that should be purely
# real will be (for example the zero freq component).
Expand All @@ -646,15 +708,15 @@ def create_test_arrays(self, input_shape, output_shape, axes=None):
fft.execute()

scaling = numpy.prod(numpy.array(a.shape))
a = a/scaling
a = self.input_dtype(a/scaling)

except ValueError:
# In this case, we assume that it was meant to error,
# so we can return what we want.
pass

b = self.output_dtype(numpy.random.randn(*output_shape))

return a, b

def run_validate_fft(self, a, b, axes, fft=None, ifft=None,
Expand Down Expand Up @@ -720,6 +782,24 @@ def run_validate_fft(self, a, b, axes, fft=None, ifft=None,

self.assertTrue(numpy.allclose(a/scaling, a_orig, rtol=1e-2, atol=1e-3))
return fft, ifft

def test_time_with_array_update(self):
in_shape = self.input_shapes['2d']
out_shape = self.output_shapes['2d']

axes=(-1,)
a, b = self.create_test_arrays(in_shape, out_shape)

fft, ifft = self.run_validate_fft(a, b, axes)

def fftw_callable():
fft.update_arrays(b,a)
fft.execute()

self.timer_routine(fftw_callable,
lambda: self.np_fft_comparison(a))

self.assertTrue(True)

def reference_fftn(self, a, axes):
# This needs to be an inverse
Expand Down Expand Up @@ -759,6 +839,8 @@ def test_non_contiguous_2d(self):

self.run_validate_fft(a_sliced, b_sliced, axes, create_array_copies=False)

@unittest.skipIf(numpy.version.version <= '1.6.1',
'numpy.fft <= 1.6.1 has a bug that causes this test to fail.')
def test_non_contiguous_2d_in_3d(self):
in_shape = (256, 4, 1025)
out_shape = (256, 4, 2048)
Expand All @@ -785,22 +867,34 @@ class RealBackwardSingleFFTWTest(RealBackwardDoubleFFTWTest):
def setUp(self):

self.input_dtype = numpy.complex64
self.output_dtype = numpy.float32
self.output_dtype = numpy.float32
self.np_fft_comparison = numpy.fft.irfft

return

class RealBackwardLongDoubleFFTWTest(RealBackwardDoubleFFTWTest):

def setUp(self):

self.input_dtype = numpy.clongdouble
self.output_dtype = numpy.longdouble
self.output_dtype = numpy.longdouble
self.np_fft_comparison = numpy.fft.irfft
return

def reference_fftn(self, a, axes):

a = numpy.complex128(a)
return numpy.fft.irfftn(a, axes=axes)

@unittest.skip('numpy.fft has issues with this dtype.')
def test_time(self):
pass

@unittest.skip('numpy.fft has issues with this dtype.')
def test_time_with_array_update(self):
pass


class NByteAlignTest(unittest.TestCase):

def setUp(self):
Expand Down

0 comments on commit 2539216

Please sign in to comment.