
Improve lazy operations and Chunking #2617

Merged

Conversation

CSSFrancis
Member

@CSSFrancis CSSFrancis commented Jan 8, 2021

Description of the change

I have been looking into the map function on lazy datasets and have realized that it isn't very well optimized. Currently it creates a large number of Dask delayed objects from the Dask array. As a result it either computes each chunk many times or operates on only one signal at a time, which is quite slow.

This is the line that slows things down considerably:

all_delayed = [dd(func)(data) for data in zip(*iterators)]

  • I changed the function so that it uses dask's map_blocks function (see the sketch below)
  • Constants are no longer passed as arrays of the same size as the input signal
  • Mapped arrays keep the same chunking structure as the input array
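A minimal sketch of the difference (illustrative dask-only code, not the PR's actual implementation):

import dask.array as da

# 100 x 100 navigation positions, 128 x 128 signal, chunked 25 x 25 in navigation.
arr = da.ones((100, 100, 128, 128), chunks=(25, 25, 128, 128))

def func(block):
    # Operates on a whole chunk of signals at once.
    return block * 2

# map_blocks builds one task per chunk (4 * 4 = 16 here) instead of one
# delayed call per navigation position (100 * 100 = 10,000 with the old code).
out = da.map_blocks(func, arr, dtype=arr.dtype)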

Progress of the PR

  • Insert map_blocks
  • Map chunks when using only constants
  • Map chunks with passed signals
  • Allow for Ragged shapes
  • Allow for Changing signal shapes
  • update docstring (if appropriate),
  • update user guide (if appropriate),
  • add entry to CHANGES.rst (if appropriate),
  • add tests,
  • ready for review.

I still need to work on the ragged examples and changing signal shapes. A lot of this code is modified from @magnunor's, so any input he has would be useful. I just figured it was worth putting in the effort to fix this now rather than later.

Minimal example of the bug fix or the new feature

import hyperspy.api as hs
import numpy as np
import time

test_data = hs.load("testdata.hspy", lazy=True)  # (100, 100, 128, 128), ~1 GB

def multiply(data, data2, extra=1):
    # `extra` is unused; it demonstrates passing a plain constant through map.
    return np.multiply(data, data2)

mult = hs.signals.Signal2D(np.ones((100, 100, 1, 1))) * 2

# `mapold` and `mapnew` are the old and new map implementations,
# renamed here so both can be benchmarked side by side.
mapped = test_data.mapold(multiply, data2=mult, extra=2, ragged=False, inplace=False)
tic = time.time()
data = mapped.compute()
toc = time.time()
print("old time:", toc-tic)

mapped = test_data.mapnew(multiply, data2=mult, extra=2, ragged=False, inplace=False)
tic = time.time()
data = mapped.compute()
toc = time.time()
print("new time:", toc-tic)
old time: 375.81656885147095
new time: 8.501422882080078

@codecov

codecov bot commented Jan 8, 2021

Codecov Report

Merging #2617 (f8ab0ce) into RELEASE_next_patch (97a2075) will increase coverage by 0.03%.
The diff coverage is 98.58%.

Impacted file tree graph

@@                  Coverage Diff                   @@
##           RELEASE_next_patch    #2617      +/-   ##
======================================================
+ Coverage               76.94%   76.97%   +0.03%     
======================================================
  Files                     201      201              
  Lines                   29711    29782      +71     
  Branches                 6515     6536      +21     
======================================================
+ Hits                    22860    22926      +66     
- Misses                   5105     5107       +2     
- Partials                 1746     1749       +3     
Impacted Files Coverage Δ
hyperspy/_signals/hologram_image.py 87.89% <ø> (-4.22%) ⬇️
hyperspy/_signals/lazy.py 90.92% <97.43%> (+0.97%) ⬆️
hyperspy/_signals/signal1d.py 72.89% <100.00%> (-0.06%) ⬇️
hyperspy/_signals/signal2d.py 80.27% <100.00%> (+0.41%) ⬆️
hyperspy/misc/utils.py 85.94% <100.00%> (+0.49%) ⬆️
hyperspy/signal.py 76.05% <100.00%> (+0.16%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 97a2075...f8ab0ce. Read the comment docs.

@CSSFrancis
Member Author

Some of these tests are still failing because this function is now a little less flexible. This is mostly because the output signal size has to be explicit (or ragged), and the output data type has to be explicit as well. It is possible to guess at each of these, but since map is more of a back-end function it might be better to let the user set them themselves?

I also allowed for ragged signals to be returned as lazy signals.
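For context, this mirrors how dask's map_blocks works: it cannot infer the output chunk shape or dtype without computing a chunk, so both are supplied up front. A minimal sketch using dask's own parameter names (not hyperspy's):

import dask.array as da

arr = da.ones((100, 100, 128, 128), chunks=(25, 25, 128, 128))

def crop(block):
    # Changes the signal shape, so the output chunks must be declared.
    return block[..., :64, :64]

out = da.map_blocks(
    crop,
    arr,
    dtype=arr.dtype,          # explicit output dtype
    chunks=(25, 25, 64, 64),  # explicit output chunk shape
)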

@ericpre
Member

ericpre commented Jan 8, 2021

Can you please rebase on RELEASE_next_patch, so that it can be released more quickly? See https://hyperspy.readthedocs.io/en/latest/dev_guide/git.html#changing-base-branch for how to rebase.

On a slightly different topic, some files seem to have been deleted by mistake!

@CSSFrancis CSSFrancis changed the base branch from RELEASE_next_minor to RELEASE_next_patch January 9, 2021 15:34
@CSSFrancis
Member Author

Okay, I'm still failing on a couple of tests. I can slog through them, but I think we need to be a little stricter about how we use the map function.

There are a couple of things that came up again and again while writing this and made what should have been a pretty easy fix much more complicated than it needed to be. I think some of this is older legacy code, but some of it might be newer as well. I was going to change the documentation to make things more explicit, and if someone wants to comment to make sure that I understand how the map function should work in hyperspy, that would be nice.

1 - In particular, the _map_iterate function should be protected and only called by the map function. This helps catch cases where people pass improper arguments to _map_iterate. The biggest error being made was passing an iterating kwarg to the map function not as a BaseSignal but as an array, which breaks the new code and is a poor workaround for not really understanding how the map function operates (see point 3).

2 - Something else I found frustrating was passing in a BaseSignal that doesn't have the same size as the navigation axes. I handled that case by converting it to a numpy array if its navigation size is 0, and throwing an error if that doesn't work (see the sketch after these points). Maybe this should be handled by the map function, but it works well now so I wouldn't really touch it.

3 - This is related to the first two issues, but in general the map function has two different operating principles. If a BaseSignal with the same navigation shape is passed, it is iterated alongside the signal. If a numpy array is passed, it is assumed to be a constant and applied at every navigation position.
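A minimal sketch of the fallback described in point 2 (the helper name is hypothetical, not the PR's actual code):

import hyperspy.api as hs

def normalize_kwarg(sig, kwarg):
    # Hypothetical helper: a BaseSignal with an empty navigation shape is
    # unwrapped to a plain array and treated as a constant; a mismatched
    # navigation shape is an error.
    if isinstance(kwarg, hs.signals.BaseSignal):
        if kwarg.axes_manager.navigation_shape == ():
            return kwarg.data
        if kwarg.axes_manager.navigation_shape != sig.axes_manager.navigation_shape:
            raise ValueError("navigation shapes must match to iterate together")
    return kwarg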

@ericpre
Member

ericpre commented Jan 16, 2021

There are a couple of things that came up again and again while writing this and made what should have been a pretty easy fix much more complicated than it needed to be.

Indeed, this may need a bit of work, but if we clearly understand what the issue is, then it shouldn't be that much work. It will pay off easily in the long term, instead of using workarounds here and there. Thanks for taking a stab at it!

A few comments without looking too much into details.

I have been looking into the map function on lazy datasets and have realized that it isn't very well optimized. Currently it creates a large number of Dask delayed objects from the Dask array. As a result it either computes each chunk many times or operates on only one signal at a time, which is quite slow.

Just to make sure that I don't misunderstand the issue: this applies to a subset of uses of the map method? When the axis keyword is provided, the _map_all method is used, and in that case there is no room for improvement?

It is possible to guess at each of these, but since map is more of a back-end function it might be better to let the user set them themselves?

Indeed, it makes sense to allow the user to set the output dtype (optionally); otherwise it should be inferred.

1 - In particular, the _map_iterate function should be protected and only called by the map function. This helps catch cases where people pass improper arguments to _map_iterate. The biggest error being made was passing an iterating kwarg to the map function not as a BaseSignal but as an array, which breaks the new code and is a poor workaround for not really understanding how the map function operates (see point 3).

If I understand correctly, the only thing that matters is that the length of the arguments to iterate over is the same as the navigation axes. That means it could be a BaseSignal, a numpy array, a list, etc., as long as the length is correct.

2 - Something else I found frustrating was passing in a BaseSignal that doesn't have the same size as the navigation axes. I handled that case by converting it to a numpy array if its navigation size is 0, and throwing an error if that doesn't work. Maybe this should be handled by the map function, but it works well now so I wouldn't really touch it.

This sounds like the flow control needs to be improved!

3 - This is related to the first two issues, but in general the map function has two different operating principles. If a BaseSignal with the same navigation shape is passed, it is iterated alongside the signal. If a numpy array is passed, it is assumed to be a constant and applied at every navigation position.

To avoid confusion, what do you mean by "a BaseSignal is passed to the signal"?

You started to add documentation to the dev guide, which is good! If there are comments or explanations, it would be very good to add them to the user guide too. For example, how to use these efficiently, and highlighting the fact that some operations will be more efficient than others - as in the situation where the axis parameter can be provided. I am mentioning this in case you come across anything along these lines.
On a similar topic, adding examples of how chunking is done, how chunks are processed and what is happening would not only be useful in this PR to make sure we understand the issue, it would also help users understand how it works!

@CSSFrancis
Member Author

Just to make sure that I don't misunderstand the issue: this applies to a subset of uses of the map method? When the axis keyword is provided, the _map_all method is used, and in that case there is no room for improvement?

Honestly, I'm not really sure what the _map_all function does. It seems to just apply the function to the full dataset, which is perfectly acceptable as long as the function properly handles lazy datasets. For the most part I would hope that it passes a lot of the dask processing to the function, so maybe there shouldn't even be a lazy implementation of the _map_all method?

If I understand correctly, the only thing that matters is that the length of the arguments to iterate over is the same as the navigation axes. That means it could be a BaseSignal, a numpy array, a list, etc., as long as the length is correct.

So for the map function this isn't correct. The map function requires that if you want a signal to iterate alongside your data, it must be a BaseSignal with the same navigation axes. This is the best way to do things (in my opinion) because it lets you be very specific about how the datasets align, and it reduces errors caused by how numpy, matplotlib and hyperspy order arrays.
What that BaseSignal contains can be as flexible as numpy and hyperspy allow... which is really flexible.

For the _map_iterate function you can pass an iterating_kwargs variable, which has been the cause of a lot of headaches for me over the last couple of days. The old behaviour was to take every argument and copy it so that it was the same size as the dataset. You can see why this might be a problem: with a large dataset you end up creating a set of non-lazy arguments that might be just as big as the dataset itself! Copying everything was a catch-all fix - it allowed you to pass numpy arrays to iterating_kwargs, didn't care about navigation axes, and so on. In the end, though, I could see it becoming a huge problem, because it is hard to be explicit about what should and shouldn't iterate and about how to align the navigation axes between a numpy array and a hyperspy signal. A rough sketch of the memory cost is below.
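A rough sketch of the memory cost of the old copying behaviour (simplified, not the actual implementation):

import numpy as np

nav_shape = (100, 100)  # navigation shape of the mapped signal
value = 2.0             # a scalar keyword argument

# Old behaviour (roughly): every keyword argument was tiled out to the
# navigation shape, materialising one copy per navigation position.
tiled = np.full(nav_shape, value)  # 10,000 eager copies of the constant
print(tiled.nbytes)                # 80000 bytes for a single scalar kwarg

# New behaviour: constants are passed through unchanged and applied per
# chunk, so nothing is duplicated.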

The solution is just to tighten the restrictions we place on the map function and bring it more into line with the function's intended behaviour. I think it has just drifted, mostly as a result of more people using the _map_iterate function when they probably shouldn't be.

To avoid confusion, what do you mean with BaseSignal is passed to the signal?

Good point, I'm not great at including examples sometimes.

import hyperspy.api as hs
import numpy as np

def multiply(data1, value):
    # This function multiplies an array by some value.
    return np.multiply(data1, value)

s = hs.signals.Signal2D(np.ones((2, 3, 2, 3)))  # 4D dataset
s2 = np.reshape(np.arange(6), (2, 3))
s2_signal = hs.signals.BaseSignal(s2)  # 2D dataset, all in the signal axes
s2_navigation = s2_signal.T  # transposed: all in the navigation axes

s.map(multiply, value=s2_navigation)  # iterates s2 alongside s
s.map(multiply, value=s2_signal)  # multiplies s by the whole array s2 at every nav position
s.map(multiply, value=s2)  # same: s2 is treated as a constant at every nav position

# Now, if we wanted to break things a little and use ``_map_iterate`` directly
# (which I wouldn't recommend), we could do the following.
# This currently works but won't in the new code...
s._map_iterate(multiply, iterating_kwargs=(('value', s2),))  # iterates s2 alongside s
s._map_iterate(multiply, value=s2)  # multiplies s by the whole array s2 at every nav position

It might not be clear from this example why the first one is better, but from a consistency standpoint it is just better to always be dealing with signals rather than arrays and lists: there is less ambiguity about what is being applied where. I could allow the second set of _map_iterate calls, but I question whether forcing people to work with signals might be the better option. Maybe there is a case for some object that doesn't fit into a BaseSignal, but I don't really know if that case exists.

if self.axes_manager.navigation_shape == () and self._lazy:
    print("Converting signal to a non-lazy signal because there are no nav dimensions")
    self.compute()
# Separate ndkwargs depending on whether they are BaseSignals.
Contributor
It should be noted that this actually works with all signals which inherit from BaseSignal. Thus, the current implementation works fine with Signal2D and Signal1D as well.

hyperspy/_signals/lazy.py
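To illustrate the point above (a minimal check, not the PR's code): the subclass relationship means a single isinstance test covers every signal type.

import hyperspy.api as hs
import numpy as np

s1d = hs.signals.Signal1D(np.ones((4, 10)))
s2d = hs.signals.Signal2D(np.ones((4, 10, 10)))

# Signal1D and Signal2D both inherit from BaseSignal, so one isinstance
# check against BaseSignal catches them all.
print(isinstance(s1d, hs.signals.BaseSignal))  # True
print(isinstance(s2d, hs.signals.BaseSignal))  # True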
@CSSFrancis
Member Author

@magnunor

Testing some of this functionality now, and the dask array in the signal is changed when running map. Example:

import dask.array as da
import hyperspy.api as hs
def test_function(image):
    return image
dask_array = da.zeros((100, 110, 120, 130), chunks=(32, 32, 32, 32))
s = hs.signals.Signal2D(dask_array).as_lazy()
print(s.data.chunksize) # (32, 32, 32, 32)
s_out = s.map(function=test_function, inplace=False)
print(s.data.chunksize) # (32, 32, 120, 130)

I would expect the original dask array to be unchanged with inplace=False.

I'll probably have more feedback over the next few days, as I spend some time testing this.

I have fixed this, but it involves making a copy of the signal when rechunking so that each chunk spans the entire signal dimensions. Hopefully that works well enough and the copying doesn't require too much memory or time.
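A minimal sketch of the rechunking idea (dask-only, not the PR's actual code): keep the navigation-axis chunking and merge each signal axis into a single chunk.

import dask.array as da

data = da.zeros((100, 110, 120, 130), chunks=(32, 32, 32, 32))
nav_dims = 2  # the first two axes are navigation axes

# -1 means "one chunk spanning the whole axis" in dask.
new_chunks = data.chunks[:nav_dims] + (-1,) * (data.ndim - nav_dims)
rechunked = data.rechunk(new_chunks)
print(rechunked.chunksize)  # (32, 32, 120, 130)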

@ericpre
Member

ericpre commented Apr 2, 2021

This looks great to me, two things are left:

@CSSFrancis, any chance to address the comments above? You can see the missing coverage at https://github.com/hyperspy/hyperspy/pull/2617/checks?check_run_id=2249986920 or in the PR diff https://github.com/hyperspy/hyperspy/pull/2617/files

@ericpre
Member

ericpre commented Apr 2, 2021

The failure on azure pipeline is not related to this PR and should be sorted soon - see #2694 (comment).

@CSSFrancis
Member Author

@ericpre I think that this should be good now. Let me know if there are any more changes that need to be made.

Member

@ericpre ericpre left a comment
Thanks @CSSFrancis, this looks good to me.

The failure on azure pipeline is due to the tifffile package being broken on the anaconda defaults channel - see AnacondaRecipes/tifffile-feedstock#2.

Member

@francisco-dlp francisco-dlp left a comment
LGTM!
