Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions durabletask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ def on_child_completed(self, task: Task[T]):
# The order of the result MUST match the order of the tasks provided to the constructor.
self._result = [task.get_result() for task in self._tasks]
self._is_complete = True
if self._parent is not None:
self._parent.on_child_completed(self)

def get_completed_tasks(self) -> int:
return self._completed_tasks
Expand Down Expand Up @@ -423,6 +425,8 @@ def on_child_completed(self, task: Task):
if not self.is_complete:
self._is_complete = True
self._result = task
if self._parent is not None:
self._parent.on_child_completed(self)


def when_all(tasks: list[Task[T]]) -> WhenAllTask[T]:
Expand Down
45 changes: 45 additions & 0 deletions tests/durabletask/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,51 @@ def test_when_all_happy_path_returns_ordered_results_and_completes_last():
assert all_task.get_result() == ["one", "two", "three"]


def test_when_all_is_composable_with_when_any():
c1 = task.CompletableTask()
c2 = task.CompletableTask()

any_task = task.when_any([c1, c2])
all_task = task.when_all([any_task])

assert not any_task.is_complete
assert not all_task.is_complete

c2.complete("two")

assert any_task.is_complete
assert all_task.is_complete
assert all_task.get_result() == [c2]


def test_when_any_is_composable_with_when_all():
c1 = task.CompletableTask()
c2 = task.CompletableTask()
c3 = task.CompletableTask()

all_task1 = task.when_all([c1, c2])
all_task2 = task.when_all([c3])
any_task = task.when_any([all_task1, all_task2])

assert not any_task.is_complete
assert not all_task1.is_complete
assert not all_task2.is_complete

c1.complete("one")

assert not any_task.is_complete
assert not all_task1.is_complete
assert not all_task2.is_complete

c2.complete("two")

assert any_task.is_complete
assert all_task1.is_complete
assert not all_task2.is_complete

assert any_task.get_result() == all_task1


def test_when_any_happy_path_returns_winner_task_and_completes_on_first():
a = task.CompletableTask()
b = task.CompletableTask()
Expand Down
Loading