Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable basic p2p shuffle for dask-cudf #7743

Merged
Merged 33 commits on Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7210c38
basic support for cudf-backed collection in p2p shuffle
rjzamora Apr 3, 2023
b03c132
use schema metadata
rjzamora Apr 4, 2023
0e86785
avoid leaving worker_for as dict
rjzamora Apr 4, 2023
ce23c47
avoid updating metadata for pandas
rjzamora Apr 4, 2023
bdbf05a
Merge remote-tracking branch 'upstream/main' into p2p-cudf-support
rjzamora Apr 6, 2023
b60a1da
Merge remote-tracking branch 'upstream/main' into p2p-cudf-support
rjzamora May 16, 2023
cc0f0cc
Merge remote-tracking branch 'upstream/main' into p2p-cudf-support
rjzamora May 24, 2023
ad57395
use new get_meta_library utility
rjzamora May 24, 2023
964b57a
linting
rjzamora May 24, 2023
b088b25
save state
rjzamora May 31, 2023
0a57f59
Merge remote-tracking branch 'upstream/main' into p2p-cudf-support
rjzamora May 31, 2023
f8123fb
Merge remote-tracking branch 'upstream/main' into p2p-cudf-support
rjzamora Jun 2, 2023
dd61398
Merge remote-tracking branch 'upstream/main' into p2p-cudf-support
rjzamora Jun 8, 2023
20616bb
add test
rjzamora Jun 8, 2023
d5141d6
Merge remote-tracking branch 'upstream/main' into p2p-cudf-support
rjzamora Jun 9, 2023
d39a774
leverage meta instead of custom pyarrow metadata
rjzamora Jun 9, 2023
def693a
Merge remote-tracking branch 'upstream/main' into p2p-cudf-support
rjzamora Jun 13, 2023
f605ac9
use dispatch functions
rjzamora Jun 14, 2023
ef37b5b
clarify error message
rjzamora Jun 14, 2023
d87db60
check attr
rjzamora Jun 14, 2023
d171dc2
catch importerror in test
rjzamora Jun 14, 2023
cbd3026
assume latest version of dask
rjzamora Jun 14, 2023
dac7291
back to original imports
rjzamora Jun 14, 2023
cb6fce1
Merge remote-tracking branch 'upstream/main' into p2p-cudf-support
rjzamora Jun 20, 2023
aae562a
ignore dispatch warnings
rjzamora Jun 20, 2023
af14903
Merge remote-tracking branch 'upstream/main' into p2p-cudf-support
rjzamora Jun 21, 2023
338ecff
move imports
rjzamora Jun 21, 2023
3da3a05
use _constructor_sliced
rjzamora Jun 22, 2023
f48a8ca
Merge remote-tracking branch 'upstream/main' into p2p-cudf-support
rjzamora Jun 22, 2023
b049a84
Merge remote-tracking branch 'upstream/main' into p2p-cudf-support
rjzamora Jul 24, 2023
b221af5
DataFrame constructor bugfix
rjzamora Jul 24, 2023
99232e7
mypy workaround
rjzamora Jul 24, 2023
50beefc
Merge remote-tracking branch 'upstream/main' into p2p-cudf-support
rjzamora Aug 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion distributed/shuffle/_arrow.py
Expand Up @@ -49,6 +49,8 @@
import pandas as pd
import pyarrow as pa

from dask.dataframe.dispatch import from_pyarrow_table_dispatch

Check warning on line 52 in distributed/shuffle/_arrow.py

View check run for this annotation

Codecov / codecov/patch

distributed/shuffle/_arrow.py#L52

Added line #L52 was not covered by tests

file = BytesIO(data)
end = len(data)
shards = []
Expand All @@ -67,7 +69,9 @@
return pd.StringDtype("pyarrow")
return None

df = table.to_pandas(self_destruct=True, types_mapper=default_types_mapper)
df = from_pyarrow_table_dispatch(

Check warning on line 72 in distributed/shuffle/_arrow.py

View check run for this annotation

Codecov / codecov/patch

distributed/shuffle/_arrow.py#L72

Added line #L72 was not covered by tests
meta, table, self_destruct=True, types_mapper=default_types_mapper
)
return df.astype(meta.dtypes, copy=False)


Expand Down
13 changes: 9 additions & 4 deletions distributed/shuffle/_shuffle.py
Expand Up @@ -97,7 +97,7 @@
column: str,
npartitions: int | None = None,
) -> DataFrame:
from dask.dataframe import DataFrame
from dask.dataframe.core import new_dd_object

Check warning on line 100 in distributed/shuffle/_shuffle.py

View check run for this annotation

Codecov / codecov/patch

distributed/shuffle/_shuffle.py#L100

Added line #L100 was not covered by tests

meta = df._meta
check_dtype_support(meta)
Expand All @@ -119,7 +119,7 @@
name_input=df._name,
meta_input=meta,
)
return DataFrame(
return new_dd_object(

Check warning on line 122 in distributed/shuffle/_shuffle.py

View check run for this annotation

Codecov / codecov/patch

distributed/shuffle/_shuffle.py#L122

Added line #L122 was not covered by tests
HighLevelGraph.from_collections(name, layer, [df]),
name,
meta,
Expand Down Expand Up @@ -273,8 +273,13 @@
Split data into many arrow batches, partitioned by destination worker
"""
import numpy as np
import pyarrow as pa

from dask.dataframe.dispatch import to_pyarrow_table_dispatch

Check warning on line 277 in distributed/shuffle/_shuffle.py

View check run for this annotation

Codecov / codecov/patch

distributed/shuffle/_shuffle.py#L277

Added line #L277 was not covered by tests

# (cudf support) Avoid pd.Series
constructor = df._constructor_sliced
assert isinstance(constructor, type)
worker_for = constructor(worker_for)

Check warning on line 282 in distributed/shuffle/_shuffle.py

View check run for this annotation

Codecov / codecov/patch

distributed/shuffle/_shuffle.py#L280-L282

Added lines #L280 - L282 were not covered by tests
df = df.merge(
right=worker_for.cat.codes.rename("_worker"),
left_on=column,
Expand All @@ -287,7 +292,7 @@
# assert len(df) == nrows # Not true if some outputs aren't wanted
# FIXME: If we do not preserve the index something is corrupting the
# bytestream such that it cannot be deserialized anymore
t = pa.Table.from_pandas(df, preserve_index=True)
t = to_pyarrow_table_dispatch(df, preserve_index=True)

Check warning on line 295 in distributed/shuffle/_shuffle.py

View check run for this annotation

Codecov / codecov/patch

distributed/shuffle/_shuffle.py#L295

Added line #L295 was not covered by tests
t = t.sort_by("_worker")
codes = np.asarray(t["_worker"])
t = t.drop(["_worker"])
Expand Down
34 changes: 34 additions & 0 deletions distributed/shuffle/tests/test_shuffle.py
Expand Up @@ -121,6 +121,40 @@ async def test_minimal_version(c, s, a, b):
await c.compute(dd.shuffle.shuffle(df, "x", shuffle="p2p"))


@pytest.mark.gpu
@pytest.mark.filterwarnings(
    "ignore:Ignoring the following arguments to `from_pyarrow_table_dispatch`."
)
@gen_cluster(client=True)
async def test_basic_cudf_support(c, s, a, b):
    """Smoke-test that a cudf-backed collection round-trips through a p2p shuffle."""
    cudf = pytest.importorskip("cudf")
    pytest.importorskip("dask_cudf")

    # Probe the dispatch once up front: older dask_cudf releases do not
    # register a cudf implementation for to_pyarrow_table_dispatch, in which
    # case the call raises TypeError and the test is skipped.
    try:
        from dask.dataframe.dispatch import to_pyarrow_table_dispatch

        to_pyarrow_table_dispatch(cudf.DataFrame())
    except TypeError:
        pytest.skip(reason="Newer version of dask_cudf is required.")

    df = dask.datasets.timeseries(
        start="2000-01-01",
        end="2000-01-10",
        dtypes={"x": float, "y": float},
        freq="10 s",
    ).to_backend("cudf")
    shuffled = dd.shuffle.shuffle(df, "x", shuffle="p2p")
    assert shuffled.npartitions == df.npartitions

    # A shuffle must neither drop nor duplicate rows: compare row counts
    # before and after.
    size_before, size_after = c.compute([df.x.size, shuffled.x.size])
    size_before = await size_before
    size_after = await size_after
    assert size_before == size_after

    await check_worker_cleanup(a)
    await check_worker_cleanup(b)
    await check_scheduler_cleanup(s)


def get_shuffle_run_from_worker(shuffle_id: ShuffleId, worker: Worker) -> ShuffleRun:
plugin = worker.plugins["shuffle"]
assert isinstance(plugin, ShuffleWorkerPlugin)
Expand Down