diff --git a/durabletask/task.py b/durabletask/task.py index 2650bfd..66abc28 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -322,6 +322,8 @@ def on_child_completed(self, task: Task[T]): # The order of the result MUST match the order of the tasks provided to the constructor. self._result = [task.get_result() for task in self._tasks] self._is_complete = True + if self._parent is not None: + self._parent.on_child_completed(self) def get_completed_tasks(self) -> int: return self._completed_tasks @@ -423,6 +425,8 @@ def on_child_completed(self, task: Task): if not self.is_complete: self._is_complete = True self._result = task + if self._parent is not None: + self._parent.on_child_completed(self) def when_all(tasks: list[Task[T]]) -> WhenAllTask[T]: diff --git a/tests/durabletask/test_task.py b/tests/durabletask/test_task.py index 81cc8a2..d8ec88e 100644 --- a/tests/durabletask/test_task.py +++ b/tests/durabletask/test_task.py @@ -46,6 +46,51 @@ def test_when_all_happy_path_returns_ordered_results_and_completes_last(): assert all_task.get_result() == ["one", "two", "three"] +def test_when_all_is_composable_with_when_any(): + c1 = task.CompletableTask() + c2 = task.CompletableTask() + + any_task = task.when_any([c1, c2]) + all_task = task.when_all([any_task]) + + assert not any_task.is_complete + assert not all_task.is_complete + + c2.complete("two") + + assert any_task.is_complete + assert all_task.is_complete + assert all_task.get_result() == [c2] + + +def test_when_any_is_composable_with_when_all(): + c1 = task.CompletableTask() + c2 = task.CompletableTask() + c3 = task.CompletableTask() + + all_task1 = task.when_all([c1, c2]) + all_task2 = task.when_all([c3]) + any_task = task.when_any([all_task1, all_task2]) + + assert not any_task.is_complete + assert not all_task1.is_complete + assert not all_task2.is_complete + + c1.complete("one") + + assert not any_task.is_complete + assert not all_task1.is_complete + assert not all_task2.is_complete + + c2.complete("two") + + assert any_task.is_complete + assert all_task1.is_complete + assert not all_task2.is_complete + + assert any_task.get_result() == all_task1 + + def test_when_any_happy_path_returns_winner_task_and_completes_on_first(): a = task.CompletableTask() b = task.CompletableTask()