Skip to content

Commit

Permalink
Minor quality-of-life tweaks to cancelled state (#6701)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jul 11, 2022
1 parent d2912c6 commit a05cc38
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 46 deletions.
15 changes: 13 additions & 2 deletions distributed/tests/test_worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ def test_TaskState__to_dict():
]


def test_TaskState_repr():
ts = TaskState("x")
assert str(ts) == "<TaskState 'x' released>"
ts.state = "cancelled"
ts.previous = "flight"
assert str(ts) == "<TaskState 'x' cancelled(flight)>"
ts.state = "resumed"
ts.next = "waiting"
assert str(ts) == "<TaskState 'x' resumed(flight->waiting)>"


def test_WorkerState__to_dict(ws):
ws.handle_stimulus(
AcquireReplicasEvent(
Expand Down Expand Up @@ -1162,7 +1173,7 @@ def test_task_with_dependencies_acquires_resources(ws):
(ExecuteSuccessEvent, "memory"),
pytest.param(
ExecuteFailureEvent,
"error",
"flight",
marks=pytest.mark.xfail(
reason="distributed#6682,distributed#6689,distributed#6693"
),
Expand Down Expand Up @@ -1238,7 +1249,7 @@ def test_done_task_not_in_all_running_tasks(
(ExecuteSuccessEvent, "memory"),
pytest.param(
ExecuteFailureEvent,
"error",
"flight",
marks=pytest.mark.xfail(reason="distributed#6689"),
),
],
Expand Down
96 changes: 52 additions & 44 deletions distributed/worker_state_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,10 @@ class TaskState:

#: The current state of the task
state: TaskStateState = "released"
#: The previous state of the task. This is a state machine implementation detail.
_previous: TaskStateState | None = None
#: The next state of the task. This is a state machine implementation detail.
_next: TaskStateState | None = None
#: The previous state of the task. It is not None iff state in (cancelled, resumed).
previous: TaskStateState | None = None
#: The next state of the task. It is not None iff state == resumed.
next: TaskStateState | None = None

#: Expected duration of the task
duration: float | None = None
Expand Down Expand Up @@ -289,7 +289,13 @@ def __post_init__(self) -> None:
TaskState._instances.add(self)

def __repr__(self) -> str:
return f"<TaskState {self.key!r} {self.state}>"
if self.state == "cancelled":
state = f"cancelled({self.previous})"
elif self.state == "resumed":
state = f"resumed({self.previous}->{self.next})"
else:
state = self.state
return f"<TaskState {self.key!r} {state}>"

def __hash__(self) -> int:
"""Override dataclass __hash__, reverting to the default behaviour
Expand Down Expand Up @@ -1152,7 +1158,7 @@ class WorkerState:
available_resources: dict[str, float]

#: Set of tasks that are currently running.
#: See also :meth:`executing_count` and :attr:`long_runing`.
#: See also :meth:`executing_count` and :attr:`long_running`.
executing: set[TaskState]

#: Set of tasks that are currently running and have called
Expand Down Expand Up @@ -1294,7 +1300,7 @@ def all_running_tasks(self) -> set[TaskState]:
These are:
- ``ts.status in ("executing", "long-running", "cancelled")``
- ``ts.status == "resumed" and ts._previous in ("executing", "long-running")``
- ``ts.status == "resumed" and ts.previous in ("executing", "long-running")``
"""
# Note: cancelled and resumed tasks are still in either of these sets
return self.executing | self.long_running
Expand Down Expand Up @@ -1393,8 +1399,8 @@ def _purge_state(self, ts: TaskState) -> None:

ts.waiting_for_data.clear()
ts.nbytes = None
ts._previous = None
ts._next = None
ts.previous = None
ts.next = None
ts.done = False

self.executing.discard(ts)
Expand Down Expand Up @@ -1657,8 +1663,8 @@ def _put_key_in_memory(
Raises
------
Exception:
In case the data is put into the in memory buffer and a serialization error
occurs during spilling, this raises that error. This has to be handled by
In case the data is put into the in-memory buffer and a serialization error
occurs during spilling, this re-raises that error. This has to be handled by
the caller since most callers generate scheduler messages on success (see
comment above) but we need to signal that this was not successful.
Expand Down Expand Up @@ -1904,7 +1910,7 @@ def _transition_cancelled_error(
*,
stimulus_id: str,
) -> RecsInstrs:
assert ts._previous in ("executing", "long-running")
assert ts.previous in ("executing", "long-running")
recs, instructions = self._transition_executing_error(
ts,
exception,
Expand Down Expand Up @@ -2000,23 +2006,23 @@ def _transition_from_resumed(
recs: Recs = {}
instructions: Instructions = []

if ts._previous == finish:
if ts.previous == finish:
# We're back where we started. We should forget about the entire
# cancellation attempt
ts.state = finish
ts._next = None
ts._previous = None
ts.next = None
ts.previous = None
elif not ts.done:
# If we're not done, yet, just remember where we want to be next
ts._next = finish
ts.next = finish
else:
# Flight/executing finished unsuccessfully, i.e. not in memory
assert finish != "memory"
next_state = ts._next
next_state = ts.next
assert next_state in {"waiting", "fetch"}, next_state
assert ts._previous in {"executing", "long-running", "flight"}, ts._previous
assert ts.previous in {"executing", "long-running", "flight"}, ts.previous

if ts._previous in ("executing", "long-running"):
if ts.previous in ("executing", "long-running"):
self._release_resources(ts)
self.executing.discard(ts)
self.long_running.discard(ts)
Expand Down Expand Up @@ -2055,7 +2061,7 @@ def _transition_resumed_released(
) -> RecsInstrs:
if not ts.done:
ts.state = "cancelled"
ts._next = None
ts.next = None
return {}, []
else:
return self._transition_generic_released(ts, stimulus_id=stimulus_id)
Expand All @@ -2071,33 +2077,33 @@ def _transition_cancelled_fetch(
) -> RecsInstrs:
if ts.done:
return {ts: "released"}, []
elif ts._previous == "flight":
ts.state = ts._previous
elif ts.previous == "flight":
ts.state = ts.previous
return {}, []
else:
assert ts._previous in ("executing", "long-running")
assert ts.previous in ("executing", "long-running")
ts.state = "resumed"
ts._next = "fetch"
ts.next = "fetch"
return {}, []

def _transition_cancelled_waiting(
self, ts: TaskState, *, stimulus_id: str
) -> RecsInstrs:
if ts.done:
return {ts: "released"}, []
elif ts._previous in ("executing", "long-running"):
ts.state = ts._previous
elif ts.previous in ("executing", "long-running"):
ts.state = ts.previous
return {}, []
else:
assert ts._previous == "flight"
assert ts.previous == "flight"
ts.state = "resumed"
ts._next = "waiting"
ts.next = "waiting"
return {}, []

def _transition_cancelled_forgotten(
self, ts: TaskState, *, stimulus_id: str
) -> RecsInstrs:
ts._next = "forgotten"
ts.next = "forgotten"
if not ts.done:
return {}, []
return {ts: "released"}, []
Expand All @@ -2117,8 +2123,8 @@ def _transition_cancelled_released(
def _transition_executing_released(
self, ts: TaskState, *, stimulus_id: str
) -> RecsInstrs:
ts._previous = ts.state
ts._next = None
ts.previous = ts.state
ts.next = None
# See https://github.com/dask/distributed/pull/5046#discussion_r685093940
ts.state = "cancelled"
ts.done = False
Expand Down Expand Up @@ -2242,18 +2248,20 @@ def _transition_flight_released(
# sensible?
return self._transition_generic_released(ts, stimulus_id=stimulus_id)
else:
ts._previous = "flight"
ts._next = None
ts.previous = "flight"
ts.next = None
# See https://github.com/dask/distributed/pull/5046#discussion_r685093940
ts.state = "cancelled"
return {}, []

def _transition_cancelled_memory(
self, ts: TaskState, value: object, *, stimulus_id: str
) -> RecsInstrs:
# We only need this because the to-memory signatures require a value but
# we do not want to store a cancelled result and want to release
# immediately
"""We only need this because the to-memory signatures require a value but
we do not want to store a cancelled result and want to release immediately.
See also ``_transition_cancelled_error``
"""
assert ts.done
return self._transition_cancelled_released(ts, stimulus_id=stimulus_id)

Expand Down Expand Up @@ -3136,19 +3144,19 @@ def _validate_task_missing(self, ts: TaskState) -> None:

def _validate_task_cancelled(self, ts: TaskState) -> None:
assert ts.key not in self.data
assert ts._previous in {"long-running", "executing", "flight"}
assert ts.previous in {"long-running", "executing", "flight"}
# We'll always transition to released after it is done
assert ts._next is None, (ts.key, ts._next, self.story(ts))
assert ts.next is None

def _validate_task_resumed(self, ts: TaskState) -> None:
assert ts.key not in self.data
assert ts._next in {"fetch", "waiting"}
assert ts._previous in {"long-running", "executing", "flight"}
assert ts.next in {"fetch", "waiting"}
assert ts.previous in {"long-running", "executing", "flight"}

def _validate_task_released(self, ts: TaskState) -> None:
assert ts.key not in self.data
assert not ts._next
assert not ts._previous
assert not ts.next
assert not ts.previous
for tss in self.data_needed.values():
assert ts not in tss
assert ts not in self.executing
Expand Down Expand Up @@ -3237,11 +3245,11 @@ def validate_state(self) -> None:
# FIXME https://github.com/dask/distributed/issues/6689
# for ts in self.executing:
# assert ts.state == "executing" or (
# ts.state in ("cancelled", "resumed") and ts._previous == "executing"
# ts.state in ("cancelled", "resumed") and ts.previous == "executing"
# ), self.story(ts)
# for ts in self.long_running:
# assert ts.state == "long-running" or (
# ts.state in ("cancelled", "resumed") and ts._previous == "long-running"
# ts.state in ("cancelled", "resumed") and ts.previous == "long-running"
# ), self.story(ts)

# Test that there aren't multiple TaskState objects with the same key in any
Expand Down

0 comments on commit a05cc38

Please sign in to comment.