
Dask.order causing high memory pressure for multi array compute calls (commonly used in xarray) #10384

Closed · fjetter opened this issue Jun 29, 2023 · 2 comments · Fixed by #10535

fjetter commented Jun 29, 2023

There is a simple xarray example that shows how memory pressure builds up due to improper dask ordering. See also pangeo-data/distributed-array-examples#2:

import xarray as xr
import dask.array as da

ds = xr.Dataset(
    dict(
        anom_u=(["time", "face", "j", "i"], da.random.random((5000, 1, 987, 1920), chunks=(10, 1, -1, -1))),
        anom_v=(["time", "face", "j", "i"], da.random.random((5000, 1, 987, 1920), chunks=(10, 1, -1, -1))),
    )
)

quad = ds**2
quad["uv"] = ds.anom_u * ds.anom_v
mean = quad.mean("time")

This example generates roughly 140 GiB of random data, but the mean at the end reduces this to a couple of KiB. This is done with an ordinary tree reduction; that's easy and cheap and could almost be done on a Raspberry Pi. However, this computation blows up in memory and it is almost impossible to run on real data.
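For reference, a rough back-of-the-envelope (assuming float64): each source array holds 5000 · 1 · 987 · 1920 ≈ 9.5e9 elements, i.e. about 71 GiB, so anom_u and anom_v together account for the ~140 GiB, and quad produces three more arrays of the same size as intermediates.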

(Disclaimer: I'm not an xarray expert, so if one drives by, please forgive my ignorance if I'm saying something wrong.)

The mean object above is an xarray.Dataset, which is essentially a collection that holds multiple dask arrays (in this example). When calling compute, all of those arrays are computed simultaneously; effectively, this is similar to a dask.compute(array1, array2, array3) call.
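As a rough sketch of what that simultaneous compute corresponds to (my illustration of the multi-collection call, not the exact code path xarray takes):

import dask

# All data variables of the Dataset are handed to dask in one call,
# so dask.order has to order the union of the three reduction graphs.
dask.compute(mean.anom_u.data, mean.anom_v.data, mean.uv.data)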

In fact, when computing just a single array (e.g. mean['uv'].compute()), this runs wonderfully: the result is reduced immediately and no data-generating task is held in memory for any significant amount of time.

However, when executing all of the above at once, the ordering breaks down and the random-data generation tasks are held in memory for a very long time.

Running dask.order.diagnostics on this shows a pressure (i.e. how long certain tasks are held in memory) of 267 on average with a maximum of 511, compared to 7 and 14, respectively, for the single-array reduction.
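For reference, this is roughly how those numbers can be obtained (a sketch assuming dask.order.diagnostics returns a per-key info mapping plus a pressure list; the exact return layout may differ between dask versions):

from dask.order import diagnostics

# xarray Datasets implement the dask collection protocol, so the full
# task graph of the combined reduction can be extracted directly.
dsk = dict(mean.__dask_graph__())
info, pressure = diagnostics(dsk)
print(sum(pressure) / len(pressure), max(pressure))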

The entire graph is too large to render, but a scaled-down version of it shows the same effect when colored by order-age (albeit much smaller, since the graph is smaller, of course):

[Figure: quadratic_mean — scaled-down task graph colored by order-age]

(In the single-array tree reduction, the age of the data generators is somewhere between one and three)

I believe I was able to reduce this to a minimal dask.order example:

from dask.base import visualize

a, b, c, d, e = list("abcde")

def f(*args):
    ...

dsk = {}
for ix in range(3):
    part = {
        # Data generators
        (a, 0, ix): (f,),
        (a, 1, ix): (f,),
        # First reduction layer
        (b, 0, ix): (f, (a, 0, ix)),
        (b, 1, ix): (f, (a, 0, ix), (a, 1, ix)),
        (b, 2, ix): (f, (a, 1, ix)),
        # Second reduction layer
        (c, 0, ix): (f, (b, 0, ix)),
        (c, 1, ix): (f, (b, 1, ix)),
        (c, 2, ix): (f, (b, 2, ix)),
    }
    dsk.update(part)
# Final aggregation tasks, each combining c results across the three parts
for ix in range(3):
    dsk.update({
        (d, ix): (f, (c, ix, 0), (c, ix, 1), (c, ix, 2)),
    })

[Figure: raw-graph — task graph of the minimal reproducer]
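To quantify the effect on this toy graph, one can order it and inspect the pressure directly (a sketch, with the same caveat about the diagnostics return layout as above):

from dask.order import diagnostics, order

priorities = order(dsk)      # key -> execution priority chosen by dask.order
info, pressure = diagnostics(dsk)
print(max(pressure))         # number of results held in memory at the worst point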

Computing the result via mean.to_dask_dataframe() generates a slightly different version of the graph that dask.order can handle better; the DataFrame version also reduces quite well.
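Concretely, the workaround looks roughly like this (a sketch; it routes the already-reduced Dataset through the dask.dataframe graph instead of computing the arrays directly):

# Same reduction, but expressed as a dask DataFrame graph.
df = mean.to_dask_dataframe()
result = df.compute()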

cc @eriknw

@hendrikmakait changed the title from "Dask.order causing high memory pressure for for xarray datasets" to "Dask.order causing high memory pressure for xarray datasets" on Jul 4, 2023
fjetter commented Sep 12, 2023

Interesting observation: when I alter the size of the 0th dimension (5000 in the example, i.e. the task graph size), I see linear growth in the pressure (i.e. the minimal number of tasks that have to be held in memory if we were to compute the graph sequentially).

(What is currently proposed in #10505 does not change a thing)
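The data points come from a sweep over the time dimension, roughly like this (my reconstruction of the experiment; the exact sizes and the diagnostics return layout are assumptions):

import dask.array as da
import xarray as xr
from dask.order import diagnostics

def max_pressure(ntime):
    # Same Dataset as in the original example, with a variable time dimension.
    shape, chunks = (ntime, 1, 987, 1920), (10, 1, -1, -1)
    ds = xr.Dataset(
        dict(
            anom_u=(["time", "face", "j", "i"], da.random.random(shape, chunks=chunks)),
            anom_v=(["time", "face", "j", "i"], da.random.random(shape, chunks=chunks)),
        )
    )
    quad = ds**2
    quad["uv"] = ds.anom_u * ds.anom_v
    mean = quad.mean("time")
    # Only the graph is inspected; nothing is computed.
    _, pressure = diagnostics(dict(mean.__dask_graph__()))
    return max(pressure)

for ntime in (500, 1000, 2000, 5000):
    print(ntime, max_pressure(ntime))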

[Figure: maximum pressure vs. size of the time dimension — linear growth]

If we convert to a dask DataFrame using mean.to_dask_dataframe(), this shows a nice N log N growth, which is just fine (note the scale is an order of magnitude smaller):

[Figure: maximum pressure vs. size of the time dimension for the DataFrame version — roughly N log N growth]

fjetter commented Sep 25, 2023

Pure dask reproducer without xarray:

import dask
import dask.array as da

anom_u = da.random.random((50, 1, 987, 1920), chunks=(10, 1, -1, -1))
anom_v = da.random.random((50, 1, 987, 1920), chunks=(10, 1, -1, -1))
anom_u *= anom_u
anom_v *= anom_v
uv = anom_u * anom_v
dask.visualize(anom_u.mean(), anom_v.mean(), uv.mean(), color='order-age')

@fjetter changed the title from "Dask.order causing high memory pressure for xarray datasets" to "Dask.order causing high memory pressure for multi array compute calls (commonly used in xarray)" on Sep 25, 2023