New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make dask.array.utils functions more generic to other Dask Arrays #10676
Conversation
This will be helpful for dask_expr.array, if we do that soon.
d0fc005
to
15c7c0f
Compare
dask/base.py
Outdated
if "xarray" in type(x).__module__: | ||
return x.__dask_graph__() is not None | ||
else: | ||
return hasattr(x, "__dask_graph__") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@phofl this change is actually kinda important. We were generating the task graph before whenever calling is_dask_collection
. Maybe caching was saving us here, but this could be unpleasantly expensive for dask-expr.
The history here is that xarray collections have a __dask_graph__
attribute, but are not always dask collections (they sometimes hold dask arrays but sometimes just have numpy arrays).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I remember that we discussed this a while back and ended at revisiting if necessary because caching worked fine back then. That change is nice anyway.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems pretty expensive for the xarray code path which really should just be something like
any(variable._data for variable in xarray_obj.variables)
Can we add something like __dask_is_present__
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can add a new protocol. We'll have to do this everywhere, but it seems like an ok thing to do. In the meantime, let's just special-case something for Xarray. I think that it's the only project today where this matters.
I started doing this:
if "xarray" in type(x).__module__:
import xarray
if isinstance(x, xarray.Dataset):
return any(variable._data for variable in x.variables)
if isinstance(x, xarray.DataArray):
return bool(x._data)
But I'm not sure that _data
exists? Maybe we want any(is_dask_collection(variable.data) ...)
? Any suggestions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect this is it:
from dask.base import is_dask_collection
def is_dask_xarray(x):
import xarray
if isinstance(x, xarray.Dataset):
return any(
is_dask_collection(variable._data) for _, variable in x.variables.items()
)
elif isinstance(x, xarray.DataArray):
return is_dask_collection(x.variable._data)
elif isinstance(x, xarray.Variable):
return is_dask_collection(x._data)
import xarray as xr
ds = xr.tutorial.open_dataset("air_temperature", chunks="auto")
computed = ds.compute()
assert is_dask_xarray(ds)
assert is_dask_xarray(ds.air)
assert is_dask_xarray(ds.air.variable)
assert not is_dask_xarray(computed)
assert not is_dask_xarray(computed.air)
assert not is_dask_xarray(computed.air.variable)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also ran into the issue of materialization on __dask_graph__
when playing with the scheduler integration and I strongly recommend moving dask-expr away from this in favor of an explicit method doing that (that's what I am proposing in dask/dask-expr#294)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dcherian I've pushed up your changes. Thanks. Using the air_temperature
dataset requires an optional dependency pooch
, which I was hesitant to add to CI just for this (although maybe I should if it's very lightweight). Alternatively, can you recommend a way to construct a dataset using maybe da.random.random
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@fjetter I'm not sure yet what you mean in the comment above. We need a way to determine if an object is a dask collection. I looked through the PR you mentioned and didn't immediately something there that would solve this problem. Can you help me to understand?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mrocklin: ds = xr.Dataset({"air": (("time", "lat", "lon"), dask.array.ones((120, 120, 120)))})
@fjetter can I ask your team to handle this sometime next week? |
can you state what the expected outcome should be? I'm missing some context about what this change is enabling |
These changes enable code sharing between They are, I think, innocuous and should be easy to merge. |
dask/base.py
Outdated
elif isinstance(x, xarray.Variable): | ||
return is_dask_collection(x._data) | ||
else: | ||
raise TypeError("Unfamiliar with xarray type", type(x)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is suppressed a few lines below.
Can we reduce the scope of the try...except block to just the type(x).__module__
line?
dask/base.py
Outdated
and callable(x.__dask_graph__) | ||
and not isinstance(x, type) | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This PR breaks pint, because just like xarray it violates the dask collection protocol by defining a __dask_graph__()
method that may return None:
>>> import dask.array as da, pint
>>> q = pint.Quantity(1, "m")
>>> q.__dask_graph__()
(None)
>>> q = pint.Quantity(da.zeros(5), "m")
>>> q.__dask_graph__()
HighLevelGraph with 1 layers.
<dask.highlevelgraph.HighLevelGraph object at 0x7f043c3889a0>
0. zeros_like-902bf3bcabaabb3fedce3f347fb26081
There may be more third-party libraries in the wild that do the same.
A blocklist approach would be more robust:
if (
isinstance(x, type)
or not hasattr(x, "__dask_graph__")
or not callable(x.__dask_graph__)
):
return False
pkg_name = getattr(type(x), "__module__", "").split(".")[0]
# Add here other third-party packages where you don't want
# to call the `__dask_graph__` method - typically because it's expensive
if pkg_name == "dask_expr":
return True
# xarray, pint, and possibly other libraries return None when they wrap a non-dask object
return x.__dask_graph__() is not None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The above, however, will mean that we call the __dask_graph__()
method of dask_expr.array objects when they're wrapped by xarray or pint. So I would suggest either:
- hybrid allowlist+blocklist, with final fallback on calling
__dask_graph__()
- change the
__dask_graph__()
method in dask_expr to be trivial (e.g. write a Mapping subclass that is trivial to initialize and is lazy until you actually try to access its contents)
dask/tests/test_base.py
Outdated
@@ -677,7 +677,6 @@ def __dask_graph__(self): | |||
|
|||
x = delayed(1) + 2 | |||
assert is_dask_collection(x) | |||
assert not is_dask_collection(2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why was this removed?
dask/base.py
Outdated
if isinstance(x, xarray.Dataset): | ||
return any(is_dask_collection(v._data) for _, v in x.variables.items()) | ||
elif isinstance(x, xarray.DataArray): | ||
return is_dask_collection(x.variable._data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is returning a false negative when you have an eager variable and lazy non-index coordinates
>>> import dask.array as da, xarray
>>> xarray.DataArray([1,2], dims=["x"], coords={"x": [10, 20], "x2": ("x", da.zeros(2))})
<xarray.DataArray (x: 2)>
array([1, 2])
Coordinates:
* x (x) int64 10 20
x2 (x) float64 dask.array<chunksize=(2,), meta=np.ndarray>
dask/base.py
Outdated
import xarray | ||
|
||
if isinstance(x, xarray.Dataset): | ||
return any(is_dask_collection(v._data) for _, v in x.variables.items()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return any(is_dask_collection(v._data) for _, v in x.variables.items()) | |
return any(is_dask_collection(v._data) for v in x.variables.values()) |
dask/base.py
Outdated
elif isinstance(x, xarray.DataArray): | ||
return is_dask_collection(x.variable._data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
elif isinstance(x, xarray.DataArray): | |
return is_dask_collection(x.variable._data) | |
elif isinstance(x, xarray.DataArray): | |
return is_dask_collection(x.variable._data) or any(is_dask_collection(var._data) for var in x._coords.values()) |
Can we go back to the design table for a bit? What this PR tries to achieveChiefly, this PR tries to future-proof In main, For xarray, today it's a cheap-ish layer merge: The problems start if you call it on a dask_expr collection, which triggers a very expensive materialization of the whole graph: https://github.com/dask-contrib/dask-expr/blob/0fab3c19de0e54085119e182de92cbe5ec25f479/dask_expr/_core.py#L429-L445 This PR wants to avoid materialization in dask-expr by changing Why it's not trivialThe problem lies within wrapper libraries - xarray and pint are two, but there may be more I'm not aware of and potentially bespoke, unpublished ones too. Both xarray and pint define a As of today, the only ways for a. call their b. special-case them in I suggest two alternative designs: d. define a new API endpoint in the dask protocol, For example, for xarray.Dataset: def __dask_collections__(self) -> list[DaskCollection]:
from dask.base import is_dask_collection
return [v._data for v in self.variables.values() if is_dask_collection(v._data)] |
Offline comment from @fjetter:
|
In light of the above PRs, I'm switching from allowlist to blocklist, with the expectation that the PRs will be merged before wrapping dask_expr.Array inside xarray or pint becomes important. We can always re-add the allowlist later on. @mrocklin just wanted to make sure you're happy with this. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ideally you shouldn't reach optimise either, that's also expensive if you read from remote storage for example
lgtm otherwise
@mrocklin I'm merging this. If you have comments happy to open a follow-up PR to address them. |
Grand. Thanks!
…On Thu, Jan 4, 2024 at 9:00 AM crusaderky ***@***.***> wrote:
Merged #10676 <#10676> into main.
—
Reply to this email directly, view it on GitHub
<#10676 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AACKZTBKC6FOL2KDNJATLL3YM3N4TAVCNFSM6AAAAABAIS2SLSVHI2DSMVQWIX3LMV45UABCJFZXG5LFIV3GK3TUJZXXI2LGNFRWC5DJN5XDWMJRGM4TENJZG42DMMA>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
This will be helpful for dask_expr.array, if we do that soon.
I also ran
py.test dask/arrays/tests/test_array_core.py
on both this branch and main and notice no slowdown