Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved error messages for P2P shuffling #7979

Merged
merged 8 commits on Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 16 additions & 6 deletions distributed/shuffle/_scheduler_extension.py
Expand Up @@ -45,6 +45,9 @@ class ShuffleState(abc.ABC):
def to_msg(self) -> dict[str, Any]:
"""Transform the shuffle state into a JSON-serializable message"""

def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]>"


@dataclass
class DataFrameShuffleState(ShuffleState):
Expand Down Expand Up @@ -119,15 +122,24 @@ def shuffle_ids(self) -> set[ShuffleId]:

async def barrier(self, id: ShuffleId, run_id: int) -> None:
shuffle = self.states[id]
assert shuffle.run_id == run_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert shuffle.run_id == run_id
assert shuffle.run_id == run_id, "Shuffle barrier ID does not match requested run_id"

Perhaps add an assertion message, or something along these lines?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

msg = {"op": "shuffle_inputs_done", "shuffle_id": id, "run_id": run_id}
await self.scheduler.broadcast(
msg=msg, workers=list(shuffle.participating_workers)
)

def restrict_task(self, id: ShuffleId, run_id: int, key: str, worker: str) -> dict:
shuffle = self.states[id]
if shuffle.run_id != run_id:
return {"status": "error", "message": "Stale shuffle"}
if shuffle.run_id > run_id:
return {
"status": "error",
"message": f"Request stale, expected run_id=={run_id} for {shuffle}",
hendrikmakait marked this conversation as resolved.
Show resolved Hide resolved
}
elif shuffle.run_id < run_id:
return {
"status": "error",
"message": f"Request invalid, expected run_id=={run_id} for {shuffle}",
hendrikmakait marked this conversation as resolved.
Show resolved Hide resolved
}
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
return {"status": "OK"}
Expand Down Expand Up @@ -298,9 +310,7 @@ def remove_worker(self, scheduler: Scheduler, worker: str) -> None:
for shuffle_id, shuffle in self.states.items():
if worker not in shuffle.participating_workers:
continue
exception = RuntimeError(
f"Worker {worker} left during active shuffle {shuffle_id}"
)
exception = RuntimeError(f"Worker {worker} left during active {shuffle_id}")
self.erred_shuffles[shuffle_id] = exception
self._fail_on_workers(shuffle, str(exception))

Expand Down Expand Up @@ -335,7 +345,7 @@ def transition(
shuffle = self.states[shuffle_id]
except KeyError:
return
self._fail_on_workers(shuffle, message=f"Shuffle {shuffle_id} forgotten")
self._fail_on_workers(shuffle, message=f"{shuffle} forgotten")
self._clean_on_scheduler(shuffle_id)

def _fail_on_workers(self, shuffle: ShuffleState, message: str) -> None:
Expand Down
36 changes: 19 additions & 17 deletions distributed/shuffle/_worker_extension.py
Expand Up @@ -99,7 +99,10 @@ def __init__(
self._closed_event = asyncio.Event()

def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self.id}[{self.run_id}] on {self.local_address}>"
return f"<{self.__class__.__name__}: id={self.id!r}, run_id={self.run_id!r}, local_address={self.local_address!r}, closed={self.closed!r}, transferred={self.transferred!r}>"

def __str__(self) -> str:
return f"{self.__class__.__name__}<{self.id}[{self.run_id}]> on {self.local_address}"

def __hash__(self) -> int:
return self.run_id
Expand Down Expand Up @@ -162,9 +165,7 @@ def raise_if_closed(self) -> None:
if self.closed:
if self._exception:
raise self._exception
raise ShuffleClosedError(
f"Shuffle {self.id} has been closed on {self.local_address}"
)
raise ShuffleClosedError(f"{self} has already been closed")

async def inputs_done(self) -> None:
self.raise_if_closed()
Expand Down Expand Up @@ -346,7 +347,7 @@ async def _receive(self, data: list[tuple[ArrayRechunkShardID, bytes]]) -> None:
async def add_partition(self, data: np.ndarray, partition_id: NDIndex) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to shuffle {self}")
raise RuntimeError(f"Cannot add more partitions to {self}")

def _() -> dict[str, list[tuple[ArrayRechunkShardID, bytes]]]:
"""Return a mapping of worker addresses to a list of tuples of shard IDs
Expand Down Expand Up @@ -511,7 +512,7 @@ def _repartition_buffers(self, data: list[bytes]) -> dict[NDIndex, list[bytes]]:
async def add_partition(self, data: pd.DataFrame, partition_id: int) -> int:
self.raise_if_closed()
if self.transferred:
raise RuntimeError(f"Cannot add more partitions to shuffle {self}")
raise RuntimeError(f"Cannot add more partitions to {self}")

def _() -> dict[str, list[tuple[int, bytes]]]:
out = split_by_worker(
Expand Down Expand Up @@ -586,6 +587,12 @@ def __init__(self, worker: Worker) -> None:
self.closed = False
self._executor = ThreadPoolExecutor(self.worker.state.nthreads)

def __str__(self) -> str:
return f"ShuffleWorkerExtension on {self.worker.address}"

def __repr__(self) -> str:
return f"<ShuffleWorkerExtension, worker={self.worker.address_safe!r}, closed={self.closed}>"

# Handlers
##########
# NOTE: handlers are not threadsafe, but they're called from async comms, so that's okay
Expand Down Expand Up @@ -695,11 +702,10 @@ async def _get_shuffle_run(
shuffle = await self._refresh_shuffle(
shuffle_id=shuffle_id,
)
if run_id < shuffle.run_id:
raise RuntimeError("Stale shuffle")
elif run_id > shuffle.run_id:
# This should never happen
raise RuntimeError("Invalid shuffle state")
if run_id > shuffle.run_id:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor: flip order of operands to align with restrict_task in the scheduler extension?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, done.

raise RuntimeError(f"run_id invalid, got {shuffle}")
hendrikmakait marked this conversation as resolved.
Show resolved Hide resolved
elif run_id < shuffle.run_id:
raise RuntimeError(f"{run_id=} stale, got {shuffle}")

if shuffle._exception:
raise shuffle._exception
Expand Down Expand Up @@ -729,9 +735,7 @@ async def _get_or_create_shuffle(
)

if self.closed:
raise ShuffleClosedError(
f"{self.__class__.__name__} already closed on {self.worker.address}"
)
raise ShuffleClosedError(f"{self} has already been closed")
if shuffle._exception:
raise shuffle._exception
return shuffle
Expand Down Expand Up @@ -790,9 +794,7 @@ async def _refresh_shuffle(
assert result["status"] == "OK"

if self.closed:
raise ShuffleClosedError(
f"{self.__class__.__name__} already closed on {self.worker.address}"
)
raise ShuffleClosedError(f"{self} has already been closed")
if shuffle_id in self.shuffles:
existing = self.shuffles[shuffle_id]
if existing.run_id >= result["run_id"]:
Expand Down
4 changes: 2 additions & 2 deletions distributed/shuffle/tests/test_shuffle.py
Expand Up @@ -1601,7 +1601,7 @@ async def test_shuffle_run_consistency(c, s, a):

# This should never occur, but fetching an ID larger than the ID available on
# the scheduler should result in an error.
with pytest.raises(RuntimeError, match="Invalid shuffle state"):
with pytest.raises(RuntimeError, match="invalid"):
await worker_ext._get_shuffle_run(shuffle_id, shuffle_dict["run_id"] + 1)

# Finish first execution
Expand All @@ -1628,7 +1628,7 @@ async def test_shuffle_run_consistency(c, s, a):
assert await worker_ext._get_shuffle_run(shuffle_id, new_shuffle_dict["run_id"])

# Fetching a stale run from a worker aware of the new run raises an error
with pytest.raises(RuntimeError, match="Stale shuffle"):
with pytest.raises(RuntimeError, match="stale"):
await worker_ext._get_shuffle_run(shuffle_id, shuffle_dict["run_id"])

worker_ext.block_barrier.set()
Expand Down