Scheduler fail case: centering data with dask.array #874

Open
shoyer opened this issue Dec 9, 2015 · 42 comments

Comments

@shoyer
Member

@shoyer shoyer commented Dec 9, 2015

A common use case for many modeling problems (e.g., in machine learning or climate science) is to center data by subtracting an average of some kind over a given axis. The dask scheduler currently falls flat on its face when attempting to schedule these types of problems.

Here's a simple example of such a fail case:

import dask.array as da
x = da.ones((8, 200, 200), chunks=(1, 200, 200))  # e.g., a large stack of image data
mad = abs(x - x.mean(axis=0)).mean()  # mean absolute deviation from the mean over axis 0
mad.visualize()  # render the task graph (requires graphviz)

[Image: task graph rendered by mad.visualize()]

The scheduler will keep in memory each of the initial chunks that it uses to compute the mean, because they will be used later as arguments to sub. In contrast, the appropriate way to handle this graph without blowing up memory would be to compute the initial chunks twice.
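A minimal NumPy sketch of what "compute the initial chunks twice" means, assuming a hypothetical load_chunk(i) that stands in for reading one (200, 200) slab from a file. The mean over axis 0 is built as a running sum in one pass, and each slab is reloaded in a second pass, so only one slab plus the running mean are held at any time:

```python
import numpy as np

N_CHUNKS = 8
CHUNK_SHAPE = (200, 200)
loads = {"count": 0}  # how many times a chunk has been (re)loaded


def load_chunk(i):
    """Hypothetical loader standing in for reading chunk i from a file."""
    loads["count"] += 1
    return np.ones(CHUNK_SHAPE)


def mad_two_pass(n_chunks=N_CHUNKS):
    # Pass 1: running sum along axis 0; only one chunk is alive at a time.
    total = np.zeros(CHUNK_SHAPE)
    for i in range(n_chunks):
        total += load_chunk(i)
    mean0 = total / n_chunks

    # Pass 2: reload each chunk instead of keeping it around from pass 1.
    abs_total = 0.0
    for i in range(n_chunks):
        abs_total += np.abs(load_chunk(i) - mean0).sum()
    return abs_total / (n_chunks * mean0.size)


result = mad_two_pass()
print(result, loads["count"])  # 0.0 16
```

The cost is reading every chunk twice (16 loads for 8 chunks); whether that beats holding all chunks depends on how expensive a read is.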

I know that in principle this could be avoided by using an on-disk cache. But this seems like a waste, because the initial values are often sitting in a file on disk anyway.

This is a pretty typical use case for dask.array (one of the first things people try with xray), so it's worth seeing if we can come up with a solution that works by default.

@jcrist
Member

@jcrist jcrist commented Dec 9, 2015

If np.ones is added to the set of fast_functions, then the graph looks like:

[Image: resulting task graph (mydask)]

This results in the desired scheduler behavior. In general, we can't assume that recomputing the initial chunks is fast, but for things like getarray, this should be fine. We should also somehow forward kwargs to optimize. dask.array.optimization.optimize accepts fast_functions as a keyword, but the scheduler get doesn't forward it.
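A toy illustration of why inlining helps, not dask's actual optimizer: the dict-of-tuples graph, is_task, inline_fast and run are hypothetical helpers written for this sketch. Tasks whose function is listed as fast are copied into each task that uses them, so the source chunks stop being shared keys that must stay in memory until every consumer has run (the naive runner itself does no caching; the point is the shape of the graph):

```python
import numpy as np


def is_task(obj):
    # A task is a tuple whose first element is callable.
    return isinstance(obj, tuple) and len(obj) > 0 and callable(obj[0])


def inline_fast(dsk, fast_functions):
    """Copy every task whose function is 'fast' into the tasks that use it."""
    fast_keys = {k for k, v in dsk.items() if is_task(v) and v[0] in fast_functions}

    def subs(arg):
        return dsk[arg] if isinstance(arg, str) and arg in fast_keys else arg

    return {
        k: (v[0], *map(subs, v[1:])) if is_task(v) else v
        for k, v in dsk.items()
        if k not in fast_keys
    }


def run(dsk, key):
    """Naive recursive evaluation of one key (no caching of shared results)."""

    def ev(arg):
        if is_task(arg):
            func, *args = arg
            return func(*[ev(a) for a in args])
        if isinstance(arg, str) and arg in dsk:
            return ev(dsk[arg])
        return arg

    return ev(dsk[key])


def average(a, b):
    return (a + b) / 2


# Two "chunks", their mean, and each chunk minus the mean.
dsk = {
    "x-0": (np.ones, (2, 2)),
    "x-1": (np.ones, (2, 2)),
    "mean": (average, "x-0", "x-1"),
    "sub-0": (np.subtract, "x-0", "mean"),
    "sub-1": (np.subtract, "x-1", "mean"),
}
inlined = inline_fast(dsk, {np.ones})
print(sorted(inlined))  # ['mean', 'sub-0', 'sub-1']
```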

@jcrist
Member

@jcrist jcrist commented Dec 9, 2015

Currently threaded.get and multiprocessing.get share some keywords and have others that are specific to each scheduler. array.optimize takes a few keywords, while dataframe.optimize takes none. However, all of them accept **kwargs, which means that excess keywords are simply ignored. Thus, we could forward all keywords from expr.get(...) to both the call to optimize and the call to get, and everything would be fine. Not sure if this is the best way, but it would work.
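A sketch of that forwarding pattern with hypothetical stand-ins for optimize and get (these are not dask's real functions): compute passes every keyword to both stages, and each stage picks out the ones it understands while **kwargs absorbs the rest:

```python
def optimize(dsk, keys, fast_functions=None, **kwargs):
    """Hypothetical optimizer: uses fast_functions, ignores anything else."""
    return dsk, {"fast_functions": fast_functions, "ignored": sorted(kwargs)}


def get(dsk, keys, num_workers=None, **kwargs):
    """Hypothetical scheduler: uses num_workers, ignores anything else."""
    return {"num_workers": num_workers, "ignored": sorted(kwargs)}


def compute(dsk, keys, **kwargs):
    # Forward *all* keywords to both stages; each stage takes what it knows.
    dsk2, opt_info = optimize(dsk, keys, **kwargs)
    get_info = get(dsk2, keys, **kwargs)
    return opt_info, get_info


opt_info, get_info = compute({}, [], fast_functions={"ones"}, num_workers=4)
print(opt_info["ignored"], get_info["ignored"])  # ['num_workers'] ['fast_functions']
```

The downside of this design is that a misspelled keyword is silently ignored by both stages rather than raising an error.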

@shoyer
Member Author

@shoyer shoyer commented Dec 9, 2015

In general, it's probably not a good idea to assume that loading data from disk is "fast", although it's certainly a preferable alternative to exhausting memory.

It would be nice if we could set up dask to recompute chunks once they start to overflow some memory threshold, which might default to some fraction of the available system memory. The challenge then is figuring out which chunks to throw away. Cachey might have most of the appropriate logic for this.

@mrocklin
Member

@mrocklin mrocklin commented Dec 9, 2015

It would be an interesting intellectual exercise to work out how to do this generally.

Any thoughts on how we could solve this problem if we tracked the number of bytes of each output and the computation times?

@jcrist
Member

@jcrist jcrist commented Dec 9, 2015

One thought would be to pass in a cache object to replace the dictionary that is used by default. Upon overflow, a decision could be made to drop a cheap result, with a callback on getitem set up to recompute it (based on the graph). A good metric might favor dumping large things that would be quick to recompute from things currently in the cache (possibly evicting whichever entry minimizes C1*compute_time + C2/memory_used). This could be done with a mix of callbacks and a MutableMapping object.
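A rough sketch of that idea under stated assumptions: RecomputingCache is a hypothetical MutableMapping, each evictable key has a hypothetical recipe of (recompute_fn, compute_time), values are NumPy arrays so .nbytes gives their size, C1 and C2 default to arbitrary placeholder weights of 1.0, and recomputation ignores dependencies on other cached entries. On overflow it evicts the entry with the lowest C1*compute_time + C2/nbytes, i.e., large results that are quick to rebuild, and rebuilds them on the next lookup:

```python
from collections.abc import MutableMapping

import numpy as np


class RecomputingCache(MutableMapping):
    """Hypothetical cache that evicts large, cheap results and rebuilds them on demand."""

    def __init__(self, limit_bytes, recipes, c1=1.0, c2=1.0):
        self.limit = limit_bytes
        self.recipes = recipes   # key -> (recompute_fn, compute_time_seconds)
        self.c1, self.c2 = c1, c2
        self.data = {}           # key -> value currently in memory
        self.nbytes = {}         # key -> size of that value in bytes
        self.evicted = set()     # keys that can be rebuilt from their recipe

    def score(self, key):
        # Lower score = better eviction candidate (quick to redo, large).
        return self.c1 * self.recipes[key][1] + self.c2 / self.nbytes[key]

    def _shrink(self):
        while sum(self.nbytes.values()) > self.limit:
            candidates = [k for k in self.data if k in self.recipes]
            if not candidates:
                break
            victim = min(candidates, key=self.score)
            del self.data[victim], self.nbytes[victim]
            self.evicted.add(victim)

    def __setitem__(self, key, value):
        self.data[key] = value
        self.nbytes[key] = value.nbytes
        self.evicted.discard(key)
        self._shrink()

    def __getitem__(self, key):
        if key in self.data:
            return self.data[key]
        if key in self.evicted:
            value = self.recipes[key][0]()  # rebuild instead of failing
            self[key] = value
            return value
        raise KeyError(key)

    def __delitem__(self, key):
        if key in self.data:
            del self.data[key], self.nbytes[key]
        elif key in self.evicted:
            self.evicted.remove(key)
        else:
            raise KeyError(key)

    def __iter__(self):
        return iter(list(self.data) + list(self.evicted))

    def __len__(self):
        return len(self.data) + len(self.evicted)


recipes = {
    "big": (lambda: np.ones(1000), 0.001),  # 8000 bytes, quick to rebuild
    "small": (lambda: np.zeros(10), 10.0),  # 80 bytes, slow to rebuild
}
cache = RecomputingCache(limit_bytes=5000, recipes=recipes)
cache["small"] = recipes["small"][0]()
cache["big"] = recipes["big"][0]()  # 8080 bytes > 5000, so "big" is evicted
print(sorted(cache.data), sorted(cache.evicted))  # ['small'] ['big']
```

In this example the 8 kB array exceeds the 5 kB limit on its own, so it is rebuilt on every access; a real policy would need to handle that case, for example by not caching it at all.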

@shoyer
Member Author

@shoyer shoyer commented Dec 9, 2015

Yes, I think dask.cache/cachey already uses a roughly appropriate metric.

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

From http://xarray.pydata.org/en/stable/dask.html

in practice this is a fail case for the dask scheduler, because it tries to keep every chunk of an array that it computes in memory

@shoyer, does this imply that memory usage for computing the mean depends linearly on the length of the unlimited dimension? We are having challenging memory issues computing the mean for datasets with high spatial resolution and both short and long time records.

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

In xarray-space, is there a way I can get the same behavior as
ds.to_netcdf(tmpfile)
ds = xr.open_dataset(tmpfile)
without actually resorting to this type of trick? E.g., can we just explicitly force dask to flush its memory cache or something?

@mrocklin
Member

@mrocklin mrocklin commented Apr 7, 2017

Computing the mean shouldn't be hard. The problem arises when a large array interacts with its own mean, like the following:

(x - x.mean()).sum().compute()

The current workaround is to explicitly compute the mean beforehand:

(x - x.mean().compute()).sum().compute()

This will require two passes over the data.
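For the axis-wise centering in the original report, the same two-pass workaround looks roughly like this (a sketch; da.random.random stands in for data read from disk):

```python
import dask.array as da

x = da.random.random((8, 200, 200), chunks=(1, 200, 200))

# Pass 1: reduce to a concrete NumPy array (1/8 the size of x here).
mean0 = x.mean(axis=0).compute()

# Pass 2: mean0 is now a constant embedded in the graph, so each chunk of x
# can be generated, used and released without waiting on a shared reduction.
mad = abs(x - mean0).mean().compute()
print(mean0.shape)  # (200, 200)
```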

Note that this problem should only be an issue if you don't have enough memory. If you were on a distributed system this might not be an issue.

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

@mrocklin, the problem is exactly as you imply, but we are trying really hard to keep compute cost down, so dask.distributed isn't ideal in this application. It certainly is ideal, and arguably necessary, for analysis with particles rather than Eulerian means; I'm not quite there yet, but I should be working on that in the next several months (likely June or July).

For (x - x.mean().compute()).sum().compute() above, this means we would need enough memory to store x.mean() twice plus, say, 2-3X a chunk size, correct?

@mrocklin
Member

@mrocklin mrocklin commented Apr 7, 2017

Storing x.mean() twice is 16 bytes. Given your question, I'm guessing that you're storing x.mean(axis=...). Are you having trouble computing just that value, just the reduction, or a compound expression like (x - x.mean().compute()).sum().compute()?

You should always have significantly more memory than your chunk size. How much RAM do you have? What are your chunk sizes?

@mrocklin
Member

@mrocklin mrocklin commented Apr 7, 2017

For reference, my notebook has 16GB of memory and I probably shoot for chunk sizes in the 100 MB range. There is about a 100x gap there.

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

Thanks @mrocklin for the quick reply. We are storing x.mean(axis=...) along the time axis, whose length varies widely, from O(10) to O(1000). The other dimensions are spatial and fixed.

We have had both problems, but in its simplest form we are having trouble computing just the mean without exceeding memory. Fundamentally, I think I'm confused about the rules relating chunk size to total available memory. Does the computation performed determine the ratio of chunk size to total memory required? It seems like this is the case, but it is not obvious how to estimate this ratio. Ultimately, it would be good to have a rule of thumb that we can add to http://xarray.pydata.org/en/stable/dask.html.

To clarify, how many chunks get loaded into memory for the computation of (mixed xarray and dask-esque pseudocode):

x.shape == (Nx, Ny, Nt)
x = x.chunk({'Nx': Nx, 'Ny': Ny, 'Nt': Nt // 10})
x.mean('Nt')

Does this mean that 10 chunks are loaded into memory to compute the mean, because the chunk size along Nt is Nt/10? So if Nx*Ny*Nt/10 = O(RAM), this likely causes a memory error, correct? The metric you give above would imply this is the case, and also that we really need Nt/(10*100) to be the chunk size here. Obviously the scaling is non-optimal, because as Nt grows we need smaller and smaller chunks and we pay more and more in chunking overhead.

I suspect I'm wrong in many places here, so please set me on the right conceptual path.

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

For reference. My notebook has 16GB of memory and I probably shoot for chunk sizes in the 100 MB range. There is about a 100x gap there.

@mrocklin, this is just a good general heuristic, right? Or is this specific to the (x - x.mean().compute()).sum().compute() case?

@mrocklin
Member

@mrocklin mrocklin commented Apr 7, 2017

Pessimistically, dask will use something like ncores * chunksize + reduced_chunksize * nchunks, where in your case reduced_chunksize is probably Nx * Ny. Dask should do better than this and should reduce the intermediates as they arrive. You might play with the split_every= keyword here.
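Plugging hypothetical numbers into that pessimistic estimate (float64 data, Nx = Ny = Nt = 1000, 10 chunks along time, 4 cores; none of these come from the thread):

```python
def pessimistic_peak_bytes(ncores, chunk_bytes, reduced_chunk_bytes, nchunks):
    """Rough upper estimate: ncores * chunksize + reduced_chunksize * nchunks."""
    return ncores * chunk_bytes + reduced_chunk_bytes * nchunks


itemsize = 8                 # float64
Nx = Ny = Nt = 1000          # hypothetical sizes
nchunks = 10                 # chunks of Nt/10 along time
chunk_bytes = Nx * Ny * (Nt // nchunks) * itemsize  # 800 MB per chunk
reduced_chunk_bytes = Nx * Ny * itemsize            # 8 MB per partial result
peak = pessimistic_peak_bytes(4, chunk_bytes, reduced_chunk_bytes, nchunks)
print(peak / 1e9)  # 3.28
```

Under these assumptions the ncores * chunksize term dominates, so it is the size of each chunk, rather than Nt itself, that mostly sets the peak; the second term grows by only one Nx * Ny slab per extra chunk.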

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

What I'm getting at is that there are two ways to compute the mean that I can foresee:

  1. Compute it with a self-reducing tree: average pairs, then average the resulting pairs, etc.
  2. Compute it incrementally (like the naive formula): accumulate $x_{tot} = \sum_{i=1}^{n} x_i$ one term at a time, then set $x_{tot} \leftarrow x_{tot} / n$.
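A NumPy sketch of the two strategies on a list of per-time-step arrays (illustrative only; it is not how dask implements its reduction). The tree version sums neighbours level by level and divides once at the end, which avoids having to weight groups of unequal size:

```python
import numpy as np


def tree_mean(chunks):
    """Strategy 1: sum neighbouring pairs level by level, then divide once."""
    partials = [np.asarray(c, dtype=float) for c in chunks]
    while len(partials) > 1:
        paired = [partials[i] + partials[i + 1] for i in range(0, len(partials) - 1, 2)]
        if len(partials) % 2:
            paired.append(partials[-1])  # odd leftover moves up a level unchanged
        partials = paired
    return partials[0] / len(chunks)


def incremental_mean(chunks):
    """Strategy 2: running total in time order, then divide."""
    xtot = np.zeros(np.shape(chunks[0]))
    for x_i in chunks:
        xtot += x_i
    return xtot / len(chunks)


rng = np.random.default_rng(0)
chunks = [rng.random((4, 4)) for _ in range(7)]  # 7 time steps of a 4x4 field
print(np.allclose(tree_mean(chunks), incremental_mean(chunks)))  # True
```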

Does dask do 1 or 2 or some combination or alternative strategy?

@mrocklin
Member

@mrocklin mrocklin commented Apr 7, 2017

Strategy 1. However, it tries to walk that tree in a depth-first way, rather than producing all of the leaves first.
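A small counting sketch of why the walk order matters, assuming one thread and a binary tree (as with split_every=2): it tracks how many leaf chunks or partial results are alive at once. Walking depth-first keeps about log2(n) + 1 alive, while producing every leaf first keeps all n:

```python
def peak_live_depth_first(n):
    """Peak number of live chunks/partials when walking a binary
    reduction tree over n leaves depth-first."""
    live = peak = 0

    def walk(lo, hi):
        nonlocal live, peak
        if hi - lo == 1:
            live += 1  # load one leaf chunk
            peak = max(peak, live)
            return
        mid = (lo + hi) // 2
        walk(lo, mid)
        walk(mid, hi)
        live -= 1  # two partials combined into one

    walk(0, n)
    return peak


def peak_live_leaves_first(n):
    """Producing every leaf before combining holds all n at once."""
    return n


for n in (8, 1024):
    print(n, peak_live_depth_first(n), peak_live_leaves_first(n))
# 8 4 8
# 1024 11 1024
```

With several threads the scheduler runs branches in parallel, so the real peak can be higher than the depth-first count, up to n.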

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

Strategy 1. However it tries to walk that tree in a depth-first way, rather than producing all of the leaves first.

That is just for performance because it is the optimal way to compute on the tree, right? That way some leaves are effectively never needed and we can average a leaf into some averaged-average, correct?

OK, this is starting to make sense now. nco probably does 2, which is why it sometimes seems to work better: if data access is really expensive at a particular point (e.g., getting a data chunk from a file), approach 2 will actually make more sense, because there will be fewer cache misses overall and the primary cost is getting data from disk into memory.

Do we have the freedom or recourse to force dask to minimize data accesses somehow (ideally forcing strategy 2 so that adjacent times are computed first)? The issue is that we never get close to using all the cores because we are so disk/memory limited. Maybe this is a setup problem on our end.

I fully recognize I'm asking a lot that is probably not in the dask design here because there is probably a latent assumption that most of the data is already in memory because dask is supposed to be more of a threaded/distributed numpy, where it would be unwise to do all the loading and unloading of data in and out of the data structures anyway.

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

Would one performance solution be to artificially limit ncores and force the mean computation to span Nx and Ny before spanning Nt? I think this access pattern would help reduce cache misses on file reads and so improve performance, but I don't know if this is a practical solution.

@mrocklin
Member

@mrocklin mrocklin commented Apr 7, 2017

If you were to set split_every=2 and use the synchronous scheduler then it would probably mimic option 1
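In current dask the same idea is spelled with the split_every= keyword on the reduction and scheduler="synchronous" at compute time (the get= keyword used later in this thread has since been removed, and dask.async.get_sync now lives at dask.local.get_sync). A sketch with made-up sizes:

```python
import dask.array as da

x = da.ones((16, 100, 100), chunks=(1, 100, 100))  # 16 time steps, one per chunk

# split_every=2: combine partial results two at a time (a deep binary tree).
m = x.mean(axis=0, split_every=2)

# Synchronous scheduler: tasks run one at a time in the calling thread.
result = m.compute(scheduler="synchronous")
print(result.shape, float(result.max()))  # (100, 100) 1.0
```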

@mrocklin
Member

@mrocklin mrocklin commented Apr 7, 2017

If you're also computing the mean across Nx and Ny then I wouldn't anticipate a problem. The intermediates would be very small. I don't think that this is what you're asking though.

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

No, the problem is that we just need a time mean. If data access were free, we could do this with some type of SIMD kernel, e.g., at each point just average over the time dimension. We just have slow drives on HPC. If we had more compute time, dask.distributed would make sense and we could effectively get this to scale like on a laptop, but we can't afford the cycles for the analysis, so this obvious solution is inapplicable in our case.

@mrocklin
Member

@mrocklin mrocklin commented Apr 7, 2017

Does this work for you?

If you were to set split_every=2 and use the synchronous scheduler then it would probably mimic option 1

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

I haven't tried it yet. To be 100% honest, I'm not sure how to set this option in xarray for this computation. Advice on this would be greatly appreciated. Is this one of those times I need to convert from xarray to dask first?

Also, is the split going to be over the non-reduced direction first? If so, this is the best solution but I can hack it by setting chunk size to be over the entire spatial dimension. The problem is that we could have more performance hurdles at larger scales if we don't compute over Nx and Ny first, and then over adjacent time steps. So, this fixes today's problem but we could still have a latent issue.

@shoyer
Member Author

@shoyer shoyer commented Apr 7, 2017

Xarray actually passes **kwargs on to the dask function from aggregation methods like mean(), though it isn't documented. So ds.mean(split_every=2) should work.

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

Xarray actually passes **kwargs on to the dask function from aggregations like method mean(), though it isn't documented. So ds.mean(split_every=2) should work.

@shoyer do you think that this is a feature that will disappear sometime soon? E.g., is it safe to use?

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

Thanks for all the help with good syntax here @shoyer and @mrocklin, it is really appreciated.

@shoyer
Member Author

@shoyer shoyer commented Apr 7, 2017

do you think that this is a feature that will disappear sometime soon? E.g., is it safe to use?

Yes, I think this is safe, though we should certainly document it (and test it, if it isn't tested already). See pydata/xarray#1360

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

I would use ds.mean(split_every=2, get=dask.async.get_sync), right? Even with the default multiprocessing scheduler it looks like memory is much better constrained with ds.mean(split_every=2).

@mrocklin
Member

@mrocklin mrocklin commented Apr 7, 2017

@shoyer
Member Author

@shoyer shoyer commented Apr 7, 2017

@pwolfram
Contributor

@pwolfram pwolfram commented Apr 7, 2017

It looks like we might be able to get by without these more complex approaches. We had some non-cached code (https://github.com/MPAS-Dev/MPAS-Analysis/blob/develop/mpas_analysis/shared/climatology/climatology.py#L595) that appears to result in a non-optimal xarray/dask lazy computation that overruns memory. I'm clearly not as well versed in dask as I should be, because it wasn't obvious to me that this code would cause problems when I merged it, but after today it seems reasonable to suspect that combinations of lazy reductions are dangerous and should be "barriered" via .compute(), .persist(), or .load().
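A sketch of such a barrier in plain dask.array, assuming the anomaly example from earlier in the thread: persist() materializes the reduction but, unlike compute(), hands back a dask array, so the downstream lazy code does not change (for xarray objects, .persist() and .load() offer the same two choices):

```python
import dask.array as da

x = da.random.random((8, 200, 200), chunks=(1, 200, 200))

# Barrier: materialize the reduction before x is combined with it again.
mean0 = x.mean(axis=0).persist()  # still a dask array, chunks already in memory

anomaly_total = (x - mean0).sum().compute()
print(abs(float(anomaly_total)) < 1e-6)  # True
```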

The computation with a simple mean, in contrast, appears to be readily solved by chunking, which is consistent with previous explanations and usage of xarray. So the issue appears to thankfully be on our side.

Many thanks to @mrocklin and @shoyer for the special help with commands to keep the computations in check. split_every is a great way to keep memory in check but ultimately I don't think we can pay the performance penalty of not using the full optimized dask reduction algorithm in the mean or a sum.

bmerry added a commit to ska-sa/katsdppipelines that referenced this issue Jul 24, 2017
The main reason is to parallelise execution of stefcal. It could still
be a bit of a memory hog: see dask/dask#874

The phase normalisation is changed a bit to avoid using a median, which
is not available in dask. Instead, the angles are wrapped such that the
branch cut is on the side opposite an arbitrary data point, and then the
mean is taken. Provided all the phases are clustered within pi of each
other this should be just as good, but it is less robust to a bad
outlier (particularly since the "arbitrary" point chosen is the minimum
initial angle).
@rabernat
Contributor

@rabernat rabernat commented Jan 16, 2018

Looks like I'm late to the party on this important issue (I discovered it via pydata/xarray#1832).

Is this considered resolved with the split_every solution? Or is there still work to be done to make the default behavior less error-prone?

@jakirkham
Member

@jakirkham jakirkham commented Jul 20, 2018

In xarray-space, is there a way I can get the same behavior as
ds.to_netcdf(tmpfile)
ds = xr.open_dataset(tmpfile)
without actually resorting to this type of trick? E.g., can we just explicitly force dask to flush its memory cache or something?

Not sure if this is still interesting to you, @pwolfram or other, but figured I'd leave this note in case it was.

It's definitely possible to check-in a result on disk. Dask Array's store has a return_stored option (available in Dask 0.16.1+). By default it is False, but if it is set to True, it returns a Dask Array that can load the stored data back from disk. This uses persist behind the scenes, which is blocking for the single machine scheduler, but is non-blocking when using Distributed.

Below is some sample code demonstrating its usage with Zarr. HTH

In [1]: import dask.array as da

In [2]: import zarr

In [3]: a = da.random.random((10, 10), chunks=(4, 4))

In [4]: z = zarr.open_array("data.zarr", shape=a.shape, dtype=a.dtype, chunks=(4, 4))

In [5]: a2 = a.store(z, return_stored=True)

In [6]: a2
Out[6]: dask.array<load-store-50dda186-8bd9-11e8-970a, shape=(10, 10), dtype=float64, chunksize=(4, 4)>

Edit: This is also now documented in this doc section.

ludwigschwardt pushed a commit to ska-sa/katsdpcal that referenced this issue Oct 25, 2019
The main reason is to parallelise execution of stefcal. It could still
be a bit of a memory hog: see dask/dask#874

The phase normalisation is changed a bit to avoid using a median, which
is not available in dask. Instead, the angles are wrapped such that the
branch cut is on the side opposite an arbitrary data point, and then the
mean is taken. Provided all the phases are clustered within pi of each
other this should be just as good, but it is less robust to a bad
outlier (particularly since the "arbitrary" point chosen is the minimum
initial angle).
@shoyer
Member Author

@shoyer shoyer commented Sep 25, 2020

I know that persisting data on disk (or in memory) is the currently recommended solution here, but it occurred to me that another good tool for fixing these sorts of issues would be a way to explicitly "duplicate" all of the nodes in the graph underlying a dask collection. That way, there couldn't be any undesirable caching, because they would be totally separate tasks.

I opened a new issue to discuss this proposal: #6674

@jakirkham
Member

@jakirkham jakirkham commented Sep 25, 2020

Something else people here might want to play with is graphchain. Admittedly this is still storing results in memory, but it does so without one needing to specify what to persist.

@jakirkham
Member

@jakirkham jakirkham commented Mar 8, 2021

People following this issue may be interested in the advanced graph manipulation functionality that was added somewhat recently.
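For example, dask.graph_manipulation.clone can produce the "duplicated" graph proposed above. The sketch below assumes a dask version recent enough to include dask.graph_manipulation; the clone gives the subtraction its own copy of the source chunks, so they can be recomputed rather than held in memory for the mean:

```python
import dask.array as da
from dask.graph_manipulation import clone

x = da.ones((8, 200, 200), chunks=(1, 200, 200))

# clone() returns an equivalent collection whose tasks have new keys, so the
# chunks feeding the subtraction are not shared with the mean's inputs.
x_again = clone(x)
mad = abs(x_again - x.mean(axis=0)).mean()
print(float(mad.compute()))  # 0.0
```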
