Enforce dtypes in P2P shuffle (#7879)
hendrikmakait committed Jun 2, 2023
1 parent 8301cb7 commit 57639c1
Showing 6 changed files with 76 additions and 86 deletions.
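The gist of the change: rather than trusting whatever dtypes fall out of the Arrow round-trip, the shuffle now casts each reconstructed partition back to the dtypes recorded in the input's meta. A minimal sketch of that pattern, assuming pandas and pyarrow are installed; the frame, column names, and dtypes below are illustrative and not taken from the commit:

import pandas as pd
import pyarrow as pa

df = pd.DataFrame({
    "x": pd.array([1, None, 3], dtype="Int64"),  # pandas nullable integer dtype
    "y": ["a", "b", "c"],
})
meta = df.head(0)  # empty frame that carries the expected dtypes

# Build an Arrow table from the raw columns (no pandas metadata attached),
# roughly what a partition can look like after travelling through the shuffle.
table = pa.table({name: pa.array(df[name]) for name in df.columns})
restored = table.to_pandas()  # "x" comes back as float64 here, not Int64

enforced = restored.astype(meta.dtypes)  # the cast this commit adds
assert (enforced.dtypes == df.dtypes).all()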
31 changes: 4 additions & 27 deletions distributed/shuffle/_arrow.py
@@ -45,7 +45,7 @@ def check_minimal_arrow_version() -> None:
)


def convert_partition(data: bytes) -> pa.Table:
def convert_partition(data: bytes, meta: pd.DataFrame) -> pd.DataFrame:
import pyarrow as pa

file = BytesIO(data)
@@ -54,7 +54,9 @@ def convert_partition(data: bytes) -> pa.Table:
while file.tell() < end:
sr = pa.RecordBatchStreamReader(file)
shards.append(sr.read_all())
return pa.concat_tables(shards)
table = pa.concat_tables(shards)
df = table.to_pandas(self_destruct=True)
return df.astype(meta.dtypes)


def list_of_buffers_to_table(data: list[bytes]) -> pa.Table:
@@ -64,31 +66,6 @@ def list_of_buffers_to_table(data: list[bytes]) -> pa.Table:
return pa.concat_tables(deserialize_table(buffer) for buffer in data)


def deserialize_schema(data: bytes) -> pa.Schema:
"""Deserialize an arrow schema
Examples
--------
>>> b = schema.serialize() # doctest: +SKIP
>>> deserialize_schema(b) # doctest: +SKIP
See also
--------
pa.Schema.serialize
"""
import io

import pyarrow as pa

bio = io.BytesIO()
bio.write(data)
bio.seek(0)
sr = pa.RecordBatchStreamReader(bio)
table = sr.read_all()
bio.close()
return table.schema


def serialize_table(table: pa.Table) -> bytes:
import io

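For context on the new convert_partition: each shard is written as its own Arrow IPC stream, so the reader loops until the buffer is exhausted, concatenates the shards, converts to pandas, and casts to meta's dtypes. A standalone sketch of that read path, assuming pandas and pyarrow; the helper names write_shard and read_shards and the sample data are illustrative, not part of the commit:

import io

import pandas as pd
import pyarrow as pa

def write_shard(sink: io.BytesIO, df: pd.DataFrame) -> None:
    # Append one DataFrame as a self-contained Arrow IPC stream.
    table = pa.Table.from_pandas(df, preserve_index=False)
    with pa.ipc.new_stream(sink, table.schema) as writer:
        writer.write_table(table)

def read_shards(data: bytes, meta: pd.DataFrame) -> pd.DataFrame:
    # Mirrors convert_partition: keep reading streams until the buffer ends.
    buf = io.BytesIO(data)
    end = len(data)
    shards = []
    while buf.tell() < end:
        reader = pa.RecordBatchStreamReader(buf)
        shards.append(reader.read_all())
    table = pa.concat_tables(shards)
    return table.to_pandas(self_destruct=True).astype(meta.dtypes)

sink = io.BytesIO()
meta = pd.DataFrame({"x": pd.Series(dtype="int64"), "y": pd.Series(dtype="object")})
write_shard(sink, pd.DataFrame({"x": [1, 2], "y": ["a", "b"]}))
write_shard(sink, pd.DataFrame({"x": [3], "y": ["c"]}))
print(read_shards(sink.getvalue(), meta))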
17 changes: 15 additions & 2 deletions distributed/shuffle/_merge.py
@@ -108,9 +108,11 @@ def hash_join_p2p(
join_layer = HashJoinP2PLayer(
name=merge_name,
name_input_left=lhs._name,
meta_input_left=lhs._meta,
left_on=left_on,
n_partitions_left=lhs.npartitions,
name_input_right=rhs._name,
meta_input_right=rhs._meta,
right_on=right_on,
n_partitions_right=rhs.npartitions,
meta_output=meta,
@@ -138,6 +140,7 @@ def merge_transfer(
input_partition: int,
npartitions: int,
parts_out: set[int],
meta: pd.DataFrame,
):
return shuffle_transfer(
input=input,
@@ -146,6 +149,7 @@ def merge_transfer(
npartitions=npartitions,
column=_HASH_COLUMN_NAME,
parts_out=parts_out,
meta=meta,
)


@@ -164,12 +168,13 @@ def merge_unpack(
from dask.dataframe.multi import merge_chunk

ext = _get_worker_extension()
# If the partition is empty, it doesn't contain the hash column name
left = ext.get_output_partition(
shuffle_id_left, barrier_left, output_partition
).drop(columns=_HASH_COLUMN_NAME)
).drop(columns=_HASH_COLUMN_NAME, errors="ignore")
right = ext.get_output_partition(
shuffle_id_right, barrier_right, output_partition
).drop(columns=_HASH_COLUMN_NAME)
).drop(columns=_HASH_COLUMN_NAME, errors="ignore")
return merge_chunk(
left,
right,
@@ -186,10 +191,12 @@ def __init__(
self,
name: str,
name_input_left: str,
meta_input_left: pd.DataFrame,
left_on,
n_partitions_left: int,
n_partitions_right: int,
name_input_right: str,
meta_input_right: pd.DataFrame,
right_on,
meta_output: pd.DataFrame,
left_index: bool,
@@ -203,8 +210,10 @@
) -> None:
self.name = name
self.name_input_left = name_input_left
self.meta_input_left = meta_input_left
self.left_on = left_on
self.name_input_right = name_input_right
self.meta_input_right = meta_input_right
self.right_on = right_on
self.how = how
self.npartitions = npartitions
@@ -285,8 +294,10 @@ def _cull(self, parts_out: Sequence[str]):
return HashJoinP2PLayer(
name=self.name,
name_input_left=self.name_input_left,
meta_input_left=self.meta_input_left,
left_on=self.left_on,
name_input_right=self.name_input_right,
meta_input_right=self.meta_input_right,
right_on=self.right_on,
how=self.how,
npartitions=self.npartitions,
@@ -344,6 +355,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
i,
self.npartitions,
self.parts_out,
self.meta_input_left,
)
for i in range(self.n_partitions_right):
transfer_keys_right.append((name_right, i))
@@ -354,6 +366,7 @@ def _construct_graph(self) -> dict[tuple | str, tuple]:
i,
self.npartitions,
self.parts_out,
self.meta_input_right,
)

_barrier_key_left = barrier_key(ShuffleId(token_left))
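The errors="ignore" added to the drop matters because an empty output partition never had the hash column attached, so an unconditional drop would raise a KeyError. A small pandas illustration; the column name used here is a stand-in for distributed's internal _HASH_COLUMN_NAME:

import pandas as pd

HASH_COLUMN = "__hash_partition"  # stand-in for distributed's _HASH_COLUMN_NAME

full = pd.DataFrame({"k": [1, 2], HASH_COLUMN: [0, 1]})
empty = pd.DataFrame({"k": pd.Series(dtype="int64")})  # empty partition: no hash column

print(full.drop(columns=HASH_COLUMN, errors="ignore"))   # hash column removed
print(empty.drop(columns=HASH_COLUMN, errors="ignore"))  # unchanged, no KeyError
# empty.drop(columns=HASH_COLUMN) would raise KeyError instead.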
13 changes: 8 additions & 5 deletions distributed/shuffle/_scheduler_extension.py
@@ -12,6 +12,7 @@
from typing import TYPE_CHECKING, Any, ClassVar

from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.protocol import to_serialize
from distributed.shuffle._rechunk import ChunkedAxes, NIndex
from distributed.shuffle._shuffle import (
ShuffleId,
@@ -21,6 +22,8 @@
)

if TYPE_CHECKING:
import pandas as pd

from distributed.scheduler import (
Recs,
Scheduler,
@@ -50,7 +53,7 @@ def to_msg(self) -> dict[str, Any]:
class DataFrameShuffleState(ShuffleState):
type: ClassVar[ShuffleType] = ShuffleType.DATAFRAME
worker_for: dict[int, str]
schema: bytes
meta: pd.DataFrame
column: str

def to_msg(self) -> dict[str, Any]:
@@ -60,7 +63,7 @@ def to_msg(self) -> dict[str, Any]:
"run_id": self.run_id,
"worker_for": self.worker_for,
"column": self.column,
"schema": self.schema,
"meta": to_serialize(self.meta),
"output_workers": self.output_workers,
}

@@ -186,11 +189,11 @@ def _raise_if_barrier_unknown(self, id: ShuffleId) -> None:
def _create_dataframe_shuffle_state(
self, id: ShuffleId, spec: dict[str, Any]
) -> DataFrameShuffleState:
schema = spec["schema"]
meta = spec["meta"]
column = spec["column"]
npartitions = spec["npartitions"]
parts_out = spec["parts_out"]
assert schema is not None
assert meta is not None
assert column is not None
assert npartitions is not None
assert parts_out is not None
@@ -204,7 +207,7 @@ def _create_dataframe_shuffle_state(
id=id,
run_id=next(ShuffleState._run_id_iterator),
worker_for=mapping,
schema=schema,
meta=meta,
column=column,
output_workers=output_workers,
participating_workers=output_workers.copy(),
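The scheduler now stores and ships the empty meta DataFrame itself (wrapped with to_serialize) instead of a serialized Arrow schema, which does not necessarily map back onto the same pandas dtypes. A rough sketch of the round-trip this relies on, using distributed.protocol.serialize and deserialize directly; the example frame is illustrative:

import pandas as pd
from distributed.protocol import deserialize, serialize

meta = pd.DataFrame({
    "a": pd.Series(dtype="Int64"),     # pandas nullable extension dtype
    "b": pd.Series(dtype="category"),
})

# Round-tripping through distributed's serialization keeps the pandas dtypes as-is.
header, frames = serialize(meta)
restored = deserialize(header, frames)
assert (restored.dtypes == meta.dtypes).all()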
17 changes: 12 additions & 5 deletions distributed/shuffle/_shuffle.py
@@ -58,6 +58,7 @@ def shuffle_transfer(
npartitions: int,
column: str,
parts_out: set[int],
meta: pd.DataFrame,
) -> int:
try:
return _get_worker_extension().add_partition(
@@ -68,6 +69,7 @@
npartitions=npartitions,
column=column,
parts_out=parts_out,
meta=meta,
)
except Exception as e:
raise RuntimeError(f"shuffle_transfer failed during shuffle {id}") from e
@@ -100,13 +102,13 @@ def rearrange_by_column_p2p(
) -> DataFrame:
from dask.dataframe import DataFrame

check_dtype_support(df._meta)
meta = df._meta
check_dtype_support(meta)
npartitions = npartitions or df.npartitions
token = tokenize(df, column, npartitions)

empty = df._meta.copy()
if any(not isinstance(c, str) for c in empty.columns):
unsupported = {c: type(c) for c in empty.columns if not isinstance(c, str)}
if any(not isinstance(c, str) for c in meta.columns):
unsupported = {c: type(c) for c in meta.columns if not isinstance(c, str)}
raise TypeError(
f"p2p requires all column names to be str, found: {unsupported}",
)
@@ -118,11 +120,12 @@
npartitions,
npartitions_input=df.npartitions,
name_input=df._name,
meta_input=meta,
)
return DataFrame(
HighLevelGraph.from_collections(name, layer, [df]),
name,
empty,
meta,
[None] * (npartitions + 1),
)

@@ -139,6 +142,7 @@ def __init__(
npartitions: int,
npartitions_input: int,
name_input: str,
meta_input: pd.DataFrame,
parts_out: Iterable | None = None,
annotations: dict | None = None,
):
@@ -147,6 +151,7 @@
self.column = column
self.npartitions = npartitions
self.name_input = name_input
self.meta_input = meta_input
if parts_out:
self.parts_out = set(parts_out)
else:
@@ -195,6 +200,7 @@ def _cull(self, parts_out: Iterable[int]) -> P2PShuffleLayer:
self.npartitions,
self.npartitions_input,
self.name_input,
self.meta_input,
parts_out=parts_out,
)

@@ -245,6 +251,7 @@ def _construct_graph(self) -> _T_LowLevelGraph:
self.npartitions,
self.column,
self.parts_out,
self.meta_input,
)

dsk[_barrier_key] = (shuffle_barrier, token, transfer_keys)
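rearrange_by_column_p2p now reads df._meta once and reuses it both for the column-name check and as the meta_input handed to the layer. The check itself is unchanged in spirit; a tiny standalone illustration of what it rejects, assuming pandas (the helper name check_column_names is hypothetical):

import pandas as pd

def check_column_names(meta: pd.DataFrame) -> None:
    # Same shape of check as rearrange_by_column_p2p: every column label must be a str.
    unsupported = {c: type(c) for c in meta.columns if not isinstance(c, str)}
    if unsupported:
        raise TypeError(f"p2p requires all column names to be str, found: {unsupported}")

check_column_names(pd.DataFrame({"name": []}))  # fine
try:
    check_column_names(pd.DataFrame({0: [], "name": []}))
except TypeError as exc:
    print(exc)  # rejects the integer column label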
32 changes: 10 additions & 22 deletions distributed/shuffle/_worker_extension.py
@@ -23,7 +23,6 @@
from distributed.protocol import to_serialize
from distributed.shuffle._arrow import (
convert_partition,
deserialize_schema,
list_of_buffers_to_table,
serialize_table,
)
@@ -421,8 +420,8 @@ class DataFrameShuffleRun(ShuffleRun[int, int, "pd.DataFrame"]):
A set of all participating worker (addresses).
column:
The data column we split the input partition by.
schema:
The schema of the payload data.
meta:
Empty metadata of the input.
id:
A unique `ShuffleID` this belongs to.
run_id:
@@ -451,7 +450,7 @@ def __init__(
worker_for: dict[int, str],
output_workers: set,
column: str,
schema: pa.Schema,
meta: pd.DataFrame,
id: ShuffleId,
run_id: int,
local_address: str,
@@ -477,7 +476,7 @@ def __init__(
memory_limiter_disk=memory_limiter_disk,
)
self.column = column
self.schema = schema
self.meta = meta
partitions_of = defaultdict(list)
for part, addr in worker_for.items():
partitions_of[addr].append(part)
@@ -543,12 +542,11 @@ async def get_output_partition(self, i: int, key: str) -> pd.DataFrame:
data = self._read_from_disk((i,))

def _() -> pd.DataFrame:
df = convert_partition(data)
return df.to_pandas()
return convert_partition(data, self.meta)

out = await self.offload(_)
except KeyError:
out = self.schema.empty_table().to_pandas()
out = self.meta.copy()
return out

def _get_assigned_worker(self, i: int) -> str:
@@ -649,8 +647,6 @@ def add_partition(
type: ShuffleType,
**kwargs: Any,
) -> int:
if type == ShuffleType.DATAFRAME:
kwargs["empty"] = data
shuffle = self.get_or_create_shuffle(shuffle_id, type=type, **kwargs)
return sync(
self.worker.loop,
@@ -723,12 +719,8 @@ async def _get_or_create_shuffle(
----------
shuffle_id
Unique identifier of the shuffle
empty
Empty metadata of input collection
column
Column to be used to map rows to output partitions (by hashing)
npartitions
Number of output partitions
type:
Type of the shuffle operation
"""
shuffle = self.shuffles.get(shuffle_id, None)
if shuffle is None:
@@ -774,16 +766,12 @@ async def _refresh_shuffle(
worker=self.worker.address,
)
elif type == ShuffleType.DATAFRAME:
import pyarrow as pa

assert kwargs is not None
result = await self.worker.scheduler.shuffle_get_or_create(
id=shuffle_id,
type=type,
spec={
"schema": pa.Schema.from_pandas(kwargs["empty"])
.serialize()
.to_pybytes(),
"meta": to_serialize(kwargs["meta"]),
"npartitions": kwargs["npartitions"],
"column": kwargs["column"],
"parts_out": kwargs["parts_out"],
@@ -829,7 +817,7 @@ async def _(
column=result["column"],
worker_for=result["worker_for"],
output_workers=result["output_workers"],
schema=deserialize_schema(result["schema"]),
meta=result["meta"],
id=shuffle_id,
run_id=result["run_id"],
directory=os.path.join(
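When no shards arrived for an output partition, the run used to synthesize an empty frame via schema.empty_table().to_pandas(); it now returns meta.copy(), so even empty partitions carry exactly the requested dtypes. A hedged comparison of the two fallbacks, assuming pandas and pyarrow; how far the schema-based frame drifts depends on the dtypes and library versions:

import pandas as pd
import pyarrow as pa

meta = pd.DataFrame({
    "x": pd.Series(dtype="Int64"),
    "t": pd.Series(dtype="datetime64[ns]"),
})

old_fallback = pa.Schema.from_pandas(meta).empty_table().to_pandas()
new_fallback = meta.copy()

print(old_fallback.dtypes)  # may differ from meta.dtypes (e.g. Int64 degrading to int64)
print(new_fallback.dtypes)  # always identical to meta.dtypes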
