
Make dask.array.utils functions more generic to other Dask Arrays #10676

Merged
merged 12 commits into dask:main on Jan 4, 2024

Conversation

mrocklin
Member

@mrocklin mrocklin commented Dec 6, 2023

This will be helpful for dask_expr.array, if we do that soon.

I also ran py.test dask/array/tests/test_array_core.py on both this branch and main and noticed no slowdown.

@github-actions github-actions bot added the array label Dec 6, 2023
@mrocklin mrocklin mentioned this pull request Dec 6, 2023
dask/base.py Outdated
Comment on lines 214 to 217
if "xarray" in type(x).__module__:
return x.__dask_graph__() is not None
else:
return hasattr(x, "__dask_graph__")
Member Author

@phofl this change is actually kinda important. Previously, we were generating the task graph every time is_dask_collection was called. Maybe caching was saving us here, but this could be unpleasantly expensive for dask-expr.

The history here is that xarray collections have a __dask_graph__ attribute, but are not always dask collections (they sometimes hold dask arrays but sometimes just have numpy arrays).

Collaborator

Yeah, I remember that we discussed this a while back and agreed to revisit it if necessary, because caching worked fine back then. This change is nice anyway.

Contributor

This seems pretty expensive for the xarray code path, which really should just be something like

any(variable._data for variable in xarray_obj.variables)

Can we add something like __dask_is_present__?
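
A sketch of roughly what that cheap check could boil down to for a Dataset, using .values() and is_dask_collection in the spirit of the corrected snippet later in this thread. __dask_is_present__ itself remains a hypothetical protocol here, and dask_is_present is only an illustrative name:

from dask.base import is_dask_collection


def dask_is_present(xarray_obj) -> bool:
    # What a hypothetical Dataset.__dask_is_present__ could amount to:
    # a cheap any() over the wrapped variables, never touching __dask_graph__.
    return any(
        is_dask_collection(variable._data)
        for variable in xarray_obj.variables.values()
    )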

Member Author

We can add a new protocol. We'll have to do this everywhere, but it seems like an ok thing to do. In the meantime, let's just special-case something for Xarray. I think that it's the only project today where this matters.

I started doing this:

        if "xarray" in type(x).__module__:
            import xarray
            if isinstance(x, xarray.Dataset):
                return any(variable._data for variable in x.variables)
            if isinstance(x, xarray.DataArray):
                return bool(x._data)

But I'm not sure that _data exists? Maybe we want any(is_dask_collection(variable.data) ...)? Any suggestions?

Contributor

I suspect this is it:

from dask.base import is_dask_collection


def is_dask_xarray(x):
    import xarray

    if isinstance(x, xarray.Dataset):
        return any(
            is_dask_collection(variable._data) for _, variable in x.variables.items()
        )
    elif isinstance(x, xarray.DataArray):
        return is_dask_collection(x.variable._data)
    elif isinstance(x, xarray.Variable):
        return is_dask_collection(x._data)


import xarray as xr

ds = xr.tutorial.open_dataset("air_temperature", chunks="auto")
computed = ds.compute()

assert is_dask_xarray(ds)
assert is_dask_xarray(ds.air)
assert is_dask_xarray(ds.air.variable)
assert not is_dask_xarray(computed)
assert not is_dask_xarray(computed.air)
assert not is_dask_xarray(computed.air.variable)

Member

I also ran into the issue of materialization on __dask_graph__ when playing with the scheduler integration and I strongly recommend moving dask-expr away from this in favor of an explicit method doing that (that's what I am proposing in dask/dask-expr#294)

Member Author

@dcherian I've pushed up your changes. Thanks. Using the air_temperature dataset requires an optional dependency pooch, which I was hesitant to add to CI just for this (although maybe I should if it's very lightweight). Alternatively, can you recommend a way to construct a dataset using maybe da.random.random?

Member Author

@fjetter I'm not sure yet what you mean in the comment above. We need a way to determine if an object is a dask collection. I looked through the PR you mentioned and didn't immediately see anything there that would solve this problem. Can you help me understand?

Contributor

@mrocklin: ds = xr.Dataset({"air": (("time", "lat", "lon"), dask.array.ones((120, 120, 120)))})
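
For reference, a pooch-free version of the earlier assertions built on that synthetic dataset might look like this (a sketch, not the test that was added in the PR):

import dask.array
import xarray as xr

from dask.base import is_dask_collection

lazy = xr.Dataset(
    {"air": (("time", "lat", "lon"), dask.array.ones((120, 120, 120)))}
)
eager = lazy.compute()

assert is_dask_collection(lazy)
assert is_dask_collection(lazy.air)
assert not is_dask_collection(eager)
assert not is_dask_collection(eager.air)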

@mrocklin
Member Author

@fjetter can I ask your team to handle this sometime next week?

@fjetter
Member

fjetter commented Dec 18, 2023

can I ask your team to handle this sometime next week?

can you state what the expected outcome should be? I'm missing some context about what this change is enabling

@mrocklin
Member Author

I'm missing some context about what this change is enabling

These changes enable code sharing between dask.array utilities and dask_expr.array (currently in a branch/PR)

They are, I think, innocuous and should be easy to merge.

dask/base.py Outdated
elif isinstance(x, xarray.Variable):
return is_dask_collection(x._data)
else:
raise TypeError("Unfamiliar with xarray type", type(x))
Collaborator

This is suppressed a few lines below.
Can we reduce the scope of the try...except block to just the type(x).__module__ line?
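
Something like the following is presumably what's being asked for, pulled out into a standalone helper for illustration. This is only a sketch of the narrower scoping, not the code that ended up in the PR:

def _is_dask_xarray(x) -> bool:
    # Guard just the module lookup, so errors raised inside the xarray
    # branch (like the TypeError above) are no longer silently swallowed.
    try:
        module = type(x).__module__
    except Exception:
        return False

    if module.split(".")[0] != "xarray":
        return False

    import xarray

    from dask.base import is_dask_collection

    if isinstance(x, xarray.Dataset):
        return any(is_dask_collection(v._data) for v in x.variables.values())
    if isinstance(x, xarray.DataArray):
        return is_dask_collection(x.variable._data)
    if isinstance(x, xarray.Variable):
        return is_dask_collection(x._data)
    raise TypeError("Unfamiliar with xarray type", type(x))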

dask/base.py Outdated
and callable(x.__dask_graph__)
and not isinstance(x, type)
)

Collaborator

This PR breaks pint, because just like xarray it violates the dask collection protocol by defining a __dask_graph__() method that may return None:

>>> import dask.array as da, pint
>>> q = pint.Quantity(1, "m")
>>> q.__dask_graph__()
(None)
>>> q = pint.Quantity(da.zeros(5), "m")
>>> q.__dask_graph__()
HighLevelGraph with 1 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7f043c3889a0>
 0. zeros_like-902bf3bcabaabb3fedce3f347fb26081

There may be more third-party libraries in the wild that do the same.
A blocklist approach would be more robust:

if (
    isinstance(x, type)
    or not hasattr(x, "__dask_graph__")
    or not callable(x.__dask_graph__)
):
    return False

pkg_name = getattr(type(x), "__module__", "").split(".")[0]
# Add here other third-party packages where you don't want
# to call the `__dask_graph__` method - typically because it's expensive
if pkg_name == "dask_expr":
    return True

# xarray, pint, and possibly other libraries return None when they wrap a non-dask object
return x.__dask_graph__() is not None

Collaborator

The above, however, will mean that we call the __dask_graph__() method of dask_expr.array objects when they're wrapped by xarray or pint. So I would suggest either:

  1. hybrid allowlist+blocklist, with final fallback on calling __dask_graph__()
  2. change the __dask_graph__() method in dask_expr to be trivial (e.g. write a Mapping subclass that is trivial to initialize and is lazy until you actually try to access its contents)
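
A minimal sketch of what option 2 could look like. Here materialize stands in for whatever expensive graph-building routine dask-expr would use; both the class name and that callable are illustrative, not existing API:

from collections.abc import Mapping


class LazyGraph(Mapping):
    """Graph mapping that is cheap to construct and only materializes the
    underlying task graph on first real access."""

    def __init__(self, materialize):
        self._materialize = materialize  # zero-argument callable returning a dict
        self._graph = None

    def _ensure(self):
        if self._graph is None:
            self._graph = self._materialize()
        return self._graph

    def __getitem__(self, key):
        return self._ensure()[key]

    def __iter__(self):
        return iter(self._ensure())

    def __len__(self):
        return len(self._ensure())

With something like this, is_dask_collection could keep calling __dask_graph__() cheaply, at the cost of wrapper libraries needing to merge such mappings lazily.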

@@ -677,7 +677,6 @@ def __dask_graph__(self):

x = delayed(1) + 2
assert is_dask_collection(x)
assert not is_dask_collection(2)
Collaborator

Why was this removed?

dask/base.py Outdated
if isinstance(x, xarray.Dataset):
return any(is_dask_collection(v._data) for _, v in x.variables.items())
elif isinstance(x, xarray.DataArray):
return is_dask_collection(x.variable._data)
Collaborator

This returns a false negative when you have an eager variable and lazy non-index coordinates:

>>> import dask.array as da, xarray
>>> xarray.DataArray([1,2], dims=["x"], coords={"x": [10, 20], "x2": ("x", da.zeros(2))})
<xarray.DataArray (x: 2)>
array([1, 2])
Coordinates:
  * x        (x) int64 10 20
    x2       (x) float64 dask.array<chunksize=(2,), meta=np.ndarray>

dask/base.py Outdated
import xarray

if isinstance(x, xarray.Dataset):
return any(is_dask_collection(v._data) for _, v in x.variables.items())
Collaborator

Suggested change
return any(is_dask_collection(v._data) for _, v in x.variables.items())
return any(is_dask_collection(v._data) for v in x.variables.values())

dask/base.py Outdated
Comment on lines 219 to 220
elif isinstance(x, xarray.DataArray):
return is_dask_collection(x.variable._data)
Contributor

Suggested change
elif isinstance(x, xarray.DataArray):
return is_dask_collection(x.variable._data)
elif isinstance(x, xarray.DataArray):
return is_dask_collection(x.variable._data) or any(is_dask_collection(var._data) for var in x._coords.values())
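
If that suggestion is applied, the false-negative example from the previous thread would presumably pass. A quick sketch of such a check, not a test from this PR:

import dask.array as da
import xarray

from dask.base import is_dask_collection

arr = xarray.DataArray(
    [1, 2], dims=["x"], coords={"x": [10, 20], "x2": ("x", da.zeros(2))}
)
assert is_dask_collection(arr)  # eager data, but a lazy non-index coordinate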

@crusaderky
Collaborator

Can we go back to the design table for a bit?

What this PR tries to achieve

Chiefly, this PR tries to future-proof is_dask_collection for when we'll migrate dask.array to dask-expr.

In main, is_dask_collection(x) calls x.__dask_graph__().
For dask.array, dask.dataframe, dask.bag, and dask.delayed this is a trivial attribute fetch.

For xarray, today it's a cheap-ish layer merge:
https://github.com/pydata/xarray/blob/b4444388cb0647c4375d6a364290e4fa5e5f94ba/xarray/core/dataset.py#L872-L881

The problems start if you call it on a dask_expr collection, which triggers a very expensive materialization of the whole graph: https://github.com/dask-contrib/dask-expr/blob/0fab3c19de0e54085119e182de92cbe5ec25f479/dask_expr/_core.py#L429-L445

This PR wants to avoid materialization in dask-expr by changing is_dask_collection not to call __dask_graph__, but just check for its existence.

Why it's not trivial

The problem lies within wrapper libraries - xarray and pint are two, but there may be more I'm not aware of and potentially bespoke, unpublished ones too.

Both xarray and pint define a __dask_graph__ method; however it may return None if they wrap non-dask objects exclusively. Code for xarray is linked above; pint is easier as it wraps around zero or one dask collections: https://github.com/hgrecco/pint/blob/83bffe1df0f18acd451ec5d4622442bdfd2d10f5/pint/facets/dask/__init__.py#L43C1-L48

As of today, the only ways for is_dask_collection to deal with wrapper collections are either:

a. call their __dask_graph__ method, which would cause an expensive call to dask_expr.array.Array.__dask_graph__. This is clearly not a viable option.

b. special-case them in dask.base.is_dask_collection, as shown in this PR for xarray only.
I don't like this option because dask should not know about the existence of pint, and it would simply not work with more obscure, possibly closed source, packages that wrap around dask.

I suggest two alternative designs:
c. dask_expr.Expr.__dask_graph__() returns a lazy Mapping subclass, which can be instantiated trivially and only materializes the graph when you call a method such as __getitem__ or __len__. Third party collections that wrap around multiple dask collections, such as xarray, would need to write ad-hoc code to merge multiple such graphs without materializing - much like it already happens with HighLevelGraph: https://github.com/pydata/xarray/blob/b4444388cb0647c4375d6a364290e4fa5e5f94ba/xarray/core/dataset.py#L872-L881

d. define a new API endpoint in the dask protocol, __dask_collections__, which in the trivial case returns [self] but in case of xarray it may return an empty list or multiple underlying dask arrays.

For example, for xarray.Dataset:

def __dask_collections__(self) -> list[DaskCollection]:
    from dask.base import is_dask_collection

    return [v._data for v in self.variables.values() if is_dask_collection(v._data)]
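
On the dask side, is_dask_collection could then avoid __dask_graph__ entirely for wrappers. Again, this is only a sketch of the proposal, not existing API:

def is_dask_collection_sketch(x) -> bool:
    # Hypothetical protocol: wrappers such as xarray report the dask
    # collections they hold; an empty list means "not a dask collection".
    collections = getattr(x, "__dask_collections__", None)
    if callable(collections):
        return len(collections()) > 0
    # Plain dask collections keep the cheap attribute check.
    return (
        hasattr(x, "__dask_graph__")
        and callable(x.__dask_graph__)
        and not isinstance(x, type)
    )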

@crusaderky
Collaborator

Offline comment from @fjetter:

I will nuke all of this with

xarray, pint will have to migrate to this so we can adjust this accordingly. If xarray truly needs a dynamic switch we probably want to add a dedicated method for this.
Either way, I’m not overly concerned about this PR as long as it doesn’t break anything publicly. The behavior for dask-expr will change anyhow

@crusaderky
Collaborator

crusaderky commented Dec 22, 2023

In light of the above PRs, I'm switching from allowlist to blocklist, with the expectation that the PRs will be merged before wrapping dask_expr.Array inside xarray or pint becomes important. We can always re-add the allowlist later on.

@mrocklin just wanted to make sure you're happy with this.

@crusaderky crusaderky self-requested a review December 22, 2023 16:06
dask/tests/test_base.py Outdated
Collaborator

@phofl phofl left a comment

Ideally you shouldn't reach optimise either; that's also expensive if you read from remote storage, for example.

lgtm otherwise

@crusaderky
Collaborator

@mrocklin I'm merging this. If you have comments, I'm happy to open a follow-up PR to address them.

@crusaderky crusaderky merged commit 6d5b994 into dask:main Jan 4, 2024
23 of 26 checks passed
@mrocklin
Member Author

mrocklin commented Jan 4, 2024 via email
