
Performance improvements for estimate_shift_2D() and other FFTs #2358

Merged: 27 commits merged into hyperspy:RELEASE_next_minor on Apr 15, 2020
Conversation

@tjof2 (Contributor) commented Apr 9, 2020

Description of the change

As part of this pyxem issue, I noticed that s.align2D() spends a long time estimating the shifts between images. I've tracked this down to two bottlenecks:

  • Switched scipy.signal.medfilt for scipy.ndimage.median_filter.
    The scipy documentation for the old function itself notes that "The more general function scipy.ndimage.median_filter has a more efficient implementation of a median filter and therefore runs much faster.", so this seems like a sensible change (see the sketch after this list).
  • Changed how we calculate the optimal FFT size to be consistent with scipy.signal.fftconvolve.
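For illustration, here is a minimal standalone sketch of the first change (the test image is hypothetical; with matching boundary handling the two filters should agree, since medfilt zero-pads):

import numpy as np
from scipy.ndimage import median_filter
from scipy.signal import medfilt

image = np.random.random((512, 512))

# Old: scipy.signal.medfilt, which zero-pads at the boundaries
slow = medfilt(image, kernel_size=3)

# New: scipy.ndimage.median_filter with mode="constant" to reproduce the
# same zero-padding; the implementation is much faster
fast = median_filter(image, size=3, mode="constant", cval=0.0)

np.testing.assert_allclose(slow, fast)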

On a related note, I see that in #2295, the code was switched from the scipy FFT to the numpy FFT. On my machine at least, the numpy FFT correlation implementation is 2x slower than scipy's fftconvolve for the examples below. I now know why this is; it is fixed below by using rfft() where possible.
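To illustrate the point (a standalone sketch, not hyperspy's code): for real input the spectrum is conjugate-symmetric, so rfftn only computes roughly half of what fftn does, and the cross-correlation comes out the same:

import numpy as np

a = np.random.random((256, 256))
b = np.random.random((256, 256))

# Complex-to-complex: computes the full, redundant spectrum of real input
corr_c2c = np.fft.ifftn(np.fft.fftn(a) * np.conj(np.fft.fftn(b))).real

# Real-to-complex: exploits conjugate symmetry, roughly half the work
corr_r2c = np.fft.irfftn(np.fft.rfftn(a) * np.conj(np.fft.rfftn(b)), s=a.shape)

np.testing.assert_allclose(corr_c2c, corr_r2c)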

Progress of the PR

  • Change implemented
  • Ready for review

Minimal example of the bug fix or the new feature

>>> s
<Signal2D, title: , dimensions: (100|200, 200)>

>>> %timeit s.estimate_shift2D(show_progressbar=False)
# Using sp.signal.medfilt
4.80 s ± 127 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Using sp.ndimage.median_filter
2.77 s ± 29.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

>>> %timeit s.estimate_shift2D(show_progressbar=False, medfilter=False)
# Using old FFT size calculation
2.22 s ± 30 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Using new FFT size calculation
1.21 s ± 34.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

If we combine the two changes:

>>> %timeit s.estimate_shift2D(show_progressbar=False)
# Before
4.80 s ± 127 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# After
1.92 s ± 40.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

@dnjohnstone I think this solves most of the issue discussed in pyxem for center_direct_beam(), aside from the threading already mentioned.

@tjof2 (Contributor, Author) commented Apr 9, 2020

I think that if the input data is real we should use np.fft.rfftn(), and that might be the speed-up needed to match scipy. I'll test and update tomorrow.

@ericpre (Member) commented Apr 9, 2020

What are the numpy and scipy versions on your machine? Do you have numpy linked against the MKL libraries?
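(For reference, a quick generic way to check both:)

import numpy as np
import scipy

print(np.__version__, scipy.__version__)
np.show_config()  # lists the BLAS/LAPACK libraries numpy was built against, e.g. MKL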

@tjof2 (Contributor, Author) commented Apr 9, 2020

@ericpre I do. As I write this, I think the difference is just due to scipy being "smarter" with the input than numpy. I'll confirm tomorrow.

Either way, calculating the optimal size as I have here is the more obvious improvement according to my profiling. Will see what I can do tomorrow.

Not sure why CI is failing btw.

Looking at https://github.com/scipy/scipy/blob/v1.4.1/scipy/signal/signaltools.py#L377, this is almost certainly the reason for the discrepancy: scipy is smarter about rfftn vs fftn for real data, while we rely only on fftn.

I'll test making this change tomorrow to see how it works.

@ericpre (Member) commented Apr 9, 2020

@ericpre looking at https://github.com/scipy/scipy/blob/v1.4.1/scipy/signal/signaltools.py#L377 this is almost certainly the reason for the discrepancy - scipy is smarter with rfftn vs fftn for real data, while we rely only on fftn.

Yes, this sounds sensible: this recent benchmark (https://github.com/project-gemmi/benchmarking-fft/) shows the same thing - there is a consistent factor of 2 between real-to-complex and complex-to-complex transforms.

@tjof2 (Contributor, Author) commented Apr 9, 2020

Awesome thanks! I'll get it tidied up if it works then :-)

@tjof2 (Contributor, Author) commented Apr 10, 2020

OK, I've tested this now: rfft is noticeably faster, but it looks like the tests only pass if we use rfft() when not doing sub-pixel alignment.

So I've added a check to use the faster version when we can, i.e. when sub_pixel_factor=1.
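As a standalone illustration of this kind of dispatch (function name hypothetical, not the actual hyperspy implementation):

import numpy as np

def correlate_images(im1, im2, sub_pixel_factor=1):
    """Circular cross-correlation via FFT, using the cheaper real
    transform when sub-pixel interpolation is not requested."""
    if sub_pixel_factor == 1 and np.isrealobj(im1) and np.isrealobj(im2):
        # Real input, integer-pixel shifts only: rfftn is safe and faster
        prod = np.fft.rfftn(im1) * np.conj(np.fft.rfftn(im2))
        return np.fft.irfftn(prod, s=im1.shape)
    # Sub-pixel refinement works on the full complex spectrum
    prod = np.fft.fftn(im1) * np.conj(np.fft.fftn(im2))
    return np.fft.ifftn(prod).real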

>>> %timeit s.estimate_shift2D(show_progressbar=False)
# Before
4.80 s ± 127 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# After (using fftn())
1.92 s ± 40.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# After (using rfftn())
1.44 s ± 40.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Here is an example test error for sub-pixel alignment if we were to use rfftn(). As you can see, there are some differences in the shifts when using the faster function. Perhaps these differences are acceptable?

tests/signal/test_2D_tools.py:76: AssertionError
______________ TestSubPixelAlign.test_estimate_subpix[True-stat] _______________

self = <hyperspy.tests.signal.test_2D_tools.TestSubPixelAlign object at 0x7f3be0035a10>
normalize_corr = True, reference = 'stat'

    @pytest.mark.parametrize(("normalize_corr", "reference"),
                             _generate_parameters())
    def test_estimate_subpix(self, normalize_corr, reference):
        s = self.signal
        shifts = s.estimate_shift2D(sub_pixel_factor=200,
                                    normalize_corr=normalize_corr)
        np.testing.assert_allclose(shifts, self.shifts, rtol=0.2, atol=0.2,
>                                  verbose=True)
E       AssertionError: 
E       Not equal to tolerance rtol=0.2, atol=0.2
E       
E       Mismatched elements: 2 / 20 (10%)
E       Max absolute difference: 0.875
E       Max relative difference: 0.41079812
E        x: array([[-0.   , -0.   ],
E              [ 4.07 ,  1.255],
E              [ 1.93 ,  3.255],...
E        y: array([[ 0.  ,  0.  ],
E              [ 4.3 ,  2.13],
E              [ 1.65,  3.58],...

To be fair, the tests are created by applying shifts using np.fft.fftn(), so perhaps it's no surprise that it matches fine when we undo it with fftn() but doesn't match so closely when we use rfftn(). Maybe using scipy.ndimage.interpolation.shift in the tests would avoid any such bias.
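For example, the test fixtures could be built along these lines (a sketch with hypothetical values; scipy.ndimage.shift is the same function as scipy.ndimage.interpolation.shift):

import numpy as np
from scipy.ndimage import shift

ref = np.random.random((64, 64))
known_shift = (1.25, -0.5)

# Spline interpolation rather than an FFT phase ramp, so the ground
# truth does not share the estimator's FFT-based bias
shifted = shift(ref, known_shift, order=3, mode="wrap")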

@dnjohnstone any view on this behaviour and which is "better"?

@ericpre (Member) commented Apr 10, 2020

To be fair, the tests are created by applying shifts using np.fft.fftn(), so perhaps it's no surprise that it matches fine when we undo it with fftn() but doesn't match so closely when we use rfftn().

Indeed, this is a fairly biased ground truth...

@tjof2 mentioned this pull request on Apr 12, 2020
@tjof2 changed the title from "Performance improvements for estimate_shift_2D() and align2D()" to "Performance improvements for estimate_shift_2D() and other FFTs" on Apr 13, 2020
@tjof2 (Contributor, Author) commented Apr 13, 2020

I went through and looked at other places where we used the conservative power-of-two calculation for FFT sizes, and replaced them with a wrapper around scipy's next_fast_len().
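A minimal sketch of such a wrapper (the name optimal_fft_size is hypothetical):

from scipy import fft as sp_fft

def optimal_fft_size(target, real=False):
    # Smallest "fast" composite FFT length >= target. Rounding up to the
    # next power of two can nearly double the transform size; next_fast_len
    # finds a much closer length built from small prime factors.
    return sp_fft.next_fast_len(target, real)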

CI failures are unrelated - it's only broken on macOS.

@ericpre (Member) commented Apr 13, 2020

Would it make sense to also use rfft (when suitable) in

def fft(self, shift=False, apodization=False, **kwargs):

On a similar topic, one thing I found not great is that the ifft method of BaseSignal returns only the real part, while we should be able to choose between returning the real part or the full complex result...

@tjof2 (Contributor, Author) commented Apr 13, 2020

@ericpre I think you're probably right, but I'm not too sure about the potential impact it could have more generally. Inside estimate_shift_2D I'm pretty confident about what the change does, and the tests cover it in an acceptable manner, but I'm worried it might break things for other users in the general case.

@ericpre (Member) commented Apr 13, 2020

Yes, good point, maybe add an option to choose one of the two?

@tjof2 (Contributor, Author) commented Apr 14, 2020

Yes, good point, maybe add an option to choose one of the two?

@ericpre I added an option to return either just the real part from the IFFT (the default behaviour) or the full complex result.
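A sketch of what such an option can look like (standalone; the parameter name return_real is hypothetical, not necessarily the PR's API):

import numpy as np

def ifft(data, return_real=True):
    # Default keeps the old behaviour (real part only); passing
    # return_real=False hands back the full complex result
    result = np.fft.ifftn(data)
    return result.real if return_real else result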

@ericpre (Member) commented Apr 14, 2020

Great, thanks, this looks good to me. Is there a reason the commit adding the option to use rfft in fft was reverted?

@tjof2 (Contributor, Author) commented Apr 14, 2020

@ericpre I'll unrevert it, then I think this is ready for review. Done.

@tjof2 (Contributor, Author) commented Apr 15, 2020

To summarise: I don't really know why my changes have affected the memory usage in Python 3.8, but here are the steps I've taken:

  1. Made all SamfirePools have workers=1.
  2. Clarified that test_multiprocessed() uses multiprocessing rather than ipyparallel.
  3. Added an explicit teardown_method() to do garbage collection and try to delete the SamfirePools where possible (see the sketch after this list).
  4. Shrunk the dataset size from (15, 15, 1024) to (15, 15, 400).
  5. Changed the dataset type to float32.
  6. Investigated why we xfail on test_multiprocessed(). This one worries me a bit: if I run pytest ./hyperspy/hyperspy/samfire/test* it always passes, but if I do pytest . it fails half the time.
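A sketch of the kind of explicit teardown described in point 3 (class and attribute names hypothetical):

import gc

class TestSamfireMain:
    def teardown_method(self, method):
        # Drop the pool explicitly and force a collection so worker
        # processes and their memory are reclaimed between tests
        if hasattr(self, "samfire"):
            del self.samfire
        gc.collect()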

Point (1) seems to have helped it get past the first sticking point (1.0.4485), but I'll let AppVeyor catch up and test the other changes I've made to see if it gets all the way through.

My worry is that this is actually a much deeper issue with parallel pools, since the correct way to use them is the following, and we don't do that:

import multiprocessing

# The context manager terminates the pool and cleans up its workers on exit
with multiprocessing.Pool() as pool:
    pool.map(...)

This lets the context manager properly close and clean up the pool. An interesting Python 3.8 issue where this is discussed is bpo-39360.

@ericpre (Member) commented Apr 15, 2020

@ericpre, you report test_samfire.py::TestSamfireMain::()::test_multiprocessed uses 1695.2 MB, I only see it using 210 MB on my machine. I'm using Python 3.7 - are you using 3.8? That might be it.

It fluctuates a lot, and the one I posted here was possibly the first I ran (which was quite surprising); it is one of the worst in terms of memory usage.

@tjof2 (Contributor, Author) commented Apr 15, 2020

Yeah I'm looking at it still - might have a fix.

Commit: samfire tests work with float16 dtype also
@tjof2 (Contributor, Author) commented Apr 15, 2020

@ericpre why does from hyperspy.signals import BaseSignal cause a 270 MB increase in memory on my machine?

@tjof2 (Contributor, Author) commented Apr 15, 2020

FYI it looks like switching the dtype to float16 and improving how we close the multiprocessing.Pool() instances may have fixed AppVeyor - see here.

@tjof2 (Contributor, Author) commented Apr 15, 2020

This is now ready for review.

FWIW, I believe the AppVeyor issues with samfire have nothing to do with the original FFT changes in this PR.

I actually think there is something inherently wrong with how multiprocessing is used in samfire and the test suite. If I install pytest-random and randomise the order of the tests, I get random pass/fail results for test_multiprocessed, and the estimated memory usage varies drastically with the test order as well.

@ericpre (Member) left a comment

This looks very good to me! See my comment about putting back a previous change.
It may be worth opening an issue about your observations on samfire, to keep a record.

@ericpre (Member) commented Apr 15, 2020

@ericpre why does from hyperspy.signals import BaseSignal cause a 270 MB increase in memory on my machine?

As it is currently, BaseSignal will import almost everything, so this doesn't sound unexpected. Whether that is desirable behaviour is a different question...

@tjof2 (Contributor, Author) commented Apr 15, 2020

Thanks @ericpre - I've made the change you requested, so I think this is good. I'll open separate issues about samfire and the BaseSignal import.

@ericpre (Member) commented Apr 15, 2020

Great, thanks! To be on the safe side, I will merge once AppVeyor has gone through the backlog... but at least now it doesn't hang for 1 h on a single build.

@tjof2 (Contributor, Author) commented Apr 15, 2020

Sounds sensible, I'm happy with that!

but at least now it doesn't hang for 1 h on a single build.

Indeed, the last three commits that have run have worked fine.

@tjof2 (Contributor, Author) commented Apr 15, 2020

Passed on all of the last 5 commits! Success :)

@ericpre merged commit c5f2d2e into hyperspy:RELEASE_next_minor on Apr 15, 2020
@ericpre added this to the v1.6 milestone on Apr 15, 2020
@tjof2 deleted the improve-align-2d branch on April 15, 2020 at 22:28
@ericpre mentioned this pull request on Apr 20, 2020