A common use case for many modeling problems (e.g., in machine learning or climate science) is to center data by subtracting an average of some kind over a given axis. The dask scheduler currently falls flat on its face when attempting to schedule these types of problems.
Here's a simple example of such a fail case:
import dask.array as da
x = da.ones((8, 200, 200), chunks=(1, 200, 200)) # e.g., a large stack of image data
mad = abs(x - x.mean(axis=0)).mean()
The scheduler will keep each of the initial chunks it uses to compute the mean in memory, because they will be used again later as arguments to sub. In contrast, the appropriate way to handle this graph without blowing up memory would be to compute the initial chunks twice.
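A minimal sketch of the two graph shapes, using plain dict-based task graphs (the format dask uses internally) with toy stand-in functions; the get here is a tiny recursive evaluator just to confirm both graphs produce the same result, not a model of scheduler memory use:

```python
def load(i):
    # stand-in for reading one chunk (e.g., one image) from disk
    return [float(i)] * 4

def mean_all(*chunks):
    flat = [v for chunk in chunks for v in chunk]
    return sum(flat) / len(flat)

def sub(chunk, m):
    return [v - m for v in chunk]

# Graph as written: every ("load", i) result must stay in memory until the
# matching ("sub", i) task runs, because "mean" needs all of them first.
graph = {
    ("load", 0): (load, 0),
    ("load", 1): (load, 1),
    "mean": (mean_all, ("load", 0), ("load", 1)),
    ("sub", 0): (sub, ("load", 0), "mean"),
    ("sub", 1): (sub, ("load", 1), "mean"),
}

# Graph with the loads inlined into their consumers: the chunks feeding
# "mean" can be released immediately, at the cost of loading each one twice.
inlined = {
    "mean": (mean_all, (load, 0), (load, 1)),
    ("sub", 0): (sub, (load, 0), "mean"),
    ("sub", 1): (sub, (load, 1), "mean"),
}

def get(dsk, key):
    # tiny recursive evaluator for this sketch
    if key in dsk:                                   # a named key
        return get(dsk, dsk[key])
    if isinstance(key, tuple) and callable(key[0]):  # an inline task
        func, *args = key
        return func(*(get(dsk, a) for a in args))
    return key                                       # a literal value
```

Both graphs yield identical results; they differ only in whether the load tasks are named (and therefore cached until their last consumer) or embedded in each consumer task.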
I know that in principle this could be avoided by using an on-disk cache. But this seems like a waste, because the initial values are often already sitting in a file on disk anyway.
This is a pretty typical use case for dask.array (one of the first things people try with xray), so it's worth seeing if we can come up with a solution that works by default.
If np.ones is added to the set of fast_functions, those tasks get inlined into their consumers, and the resulting graph gives the desired scheduler behavior. In general we can't assume that recomputing the initial chunks is fast, but for things like getarray this should be fine. We should also somehow forward keyword arguments to optimize: dask.array.optimization.optimize accepts fast_functions as a keyword, but the scheduler get doesn't forward it.
Currently threaded.get and multiprocessing.get share some keywords, and each scheduler also has keywords specific to it. array.optimize takes a few keywords, while dataframe.optimize takes none. However, all of them accept **kwargs, which means that excess keywords are simply ignored. Thus, we could forward all keywords from expr.get(...) to both the call to optimize and get, and everything would be fine. Not sure if this is the best way, but it would work.
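A sketch of the forwarding idea with hypothetical stand-in functions (these are illustrative, not dask's actual internals): because every callee accepts **kwargs, the collection-level get can pass the full keyword set to both the optimizer and the scheduler and let each pick out what it recognizes.

```python
def optimize(dsk, keys, fast_functions=None, **kwargs):
    # stand-in optimizer: uses fast_functions, ignores everything else
    return dsk

def scheduler_get(dsk, keys, num_workers=None, **kwargs):
    # stand-in scheduler: uses num_workers, ignores everything else
    return [dsk[k] for k in keys]

def collection_get(dsk, keys, **kwargs):
    # Forward all keywords to both steps; each callee's **kwargs
    # silently absorbs the keywords meant for the other.
    dsk2 = optimize(dsk, keys, **kwargs)
    return scheduler_get(dsk2, keys, **kwargs)

result = collection_get({"x": 1, "y": 2}, ["x"],
                        fast_functions={sum}, num_workers=4)
```

The downside of this design is that typos in keyword names are silently swallowed rather than raising a TypeError, which is the usual trade-off with tolerant **kwargs signatures.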
In general, it's probably not a good idea to assume that loading data from disk is "fast", although it's certainly a preferable alternative to exhausting memory.
It would be nice if we could set up dask to recompute chunks once they start to overflow some memory threshold, which might default to some fraction of the available system memory. The challenge then is figuring out which chunks to throw away. Cachey might have most of the appropriate logic for this.
It would be an interesting intellectual exercise on how to do this generally.
Any thoughts on how we could solve this problem if we tracked number of bytes of each output and computation times?
One thought would be to pass in a cache object to replace the dictionary that is used by default. Upon overflow, a decision could be made to drop a cheap result, with a callback on getitem set up to recompute it (based on the graph). A good metric might be dumping large things that would be quick to recompute from things currently in the cache (possibly min(C1*compute_time + C2/memory_used)). Could be done with a mix of callbacks and a MutableMapping object.
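A rough sketch of that MutableMapping idea (hypothetical; class name, scoring constants, and the use of sys.getsizeof as a byte estimate are all assumptions, not dask's API): it tracks compute time and size per key, evicts the min(C1*compute_time + C2/memory_used) entry on overflow, and recomputes evicted keys from the task graph on the next access.

```python
import sys
import time
from collections.abc import MutableMapping

class RecomputingCache(MutableMapping):
    # Hypothetical graph-backed cache: holds results up to a byte budget,
    # evicting the entry with the smallest c1*compute_time + c2/nbytes
    # score (large and cheap-to-recompute goes first), and recomputing
    # evicted keys from the task graph on a cache miss.

    def __init__(self, dsk, max_bytes, c1=1.0, c2=1.0):
        self.dsk, self.max_bytes = dsk, max_bytes
        self.c1, self.c2 = c1, c2
        self.data, self.cost, self.total = {}, {}, 0

    def _score(self, key):
        t, nbytes = self.cost[key]
        return self.c1 * t + self.c2 / max(nbytes, 1)

    def _compute(self, key):
        # rebuild a missing result from its task in the graph,
        # recording how long it took and how big it is
        func, *args = self.dsk[key]
        t0 = time.perf_counter()
        value = func(*(self[a] if a in self.dsk else a for a in args))
        self.cost[key] = (time.perf_counter() - t0, sys.getsizeof(value))
        return value

    def __setitem__(self, key, value):
        if key in self.data:
            del self[key]          # keep the byte count consistent
        self.data[key] = value
        self.cost.setdefault(key, (0.0, sys.getsizeof(value)))
        self.total += sys.getsizeof(value)
        while self.total > self.max_bytes and len(self.data) > 1:
            victim = min(self.data, key=self._score)
            del self[victim]       # evicted; recomputable from self.dsk

    def __getitem__(self, key):
        if key not in self.data:
            self[key] = self._compute(key)   # recompute on miss
        return self.data[key]

    def __delitem__(self, key):
        self.total -= sys.getsizeof(self.data.pop(key))

    def __iter__(self):
        return iter(self.data)

    def __len__(self):
        return len(self.data)

# usage: a two-task graph computed through the cache
import operator
dsk = {"a": (operator.add, 1, 2), "b": (operator.mul, "a", 10)}
cache = RecomputingCache(dsk, max_bytes=10 ** 6)
```

Keys inserted directly (not present in the graph) can't be recomputed after eviction, so a real implementation would need to pin those or spill them to disk instead.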
min(C1*compute_time + C2/memory_used)
Yes, I think dask.cache/cachey already uses a roughly appropriate metric.