From a0871eccdf7116a451e31a1d163b8b6d1dbf6246 Mon Sep 17 00:00:00 2001 From: Chris Gillum Date: Sat, 28 Oct 2023 23:19:52 +0000 Subject: [PATCH] Fix error handling in generator logic --- durabletask/worker.py | 60 +++++++++++----------- tests/test_orchestration_e2e.py | 76 +++++++++++++++++++++++----- tests/test_orchestration_executor.py | 31 +++++++----- 3 files changed, 109 insertions(+), 58 deletions(-) diff --git a/durabletask/worker.py b/durabletask/worker.py index bfc2603..8775efc 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -85,10 +85,10 @@ class TaskHubGrpcWorker: _response_stream: Optional[grpc.Future] = None def __init__(self, *, - host_address: Union[str, None] = None, - metadata: Union[List[Tuple[str, str]], None] = None, - log_handler = None, - log_formatter: Union[logging.Formatter, None] = None, + host_address: Optional[str] = None, + metadata: Optional[List[Tuple[str, str]]] = None, + log_handler=None, + log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False): self._registry = _Registry() self._host_address = host_address if host_address else shared.get_default_host_address() @@ -259,22 +259,20 @@ def resume(self): # has reached a completed state. The only time this won't be the # case is if the user yielded on a WhenAll task and there are still # outstanding child tasks that need to be completed. - if self._previous_task is not None: + while self._previous_task is not None and self._previous_task.is_complete: + next_task = None if self._previous_task.is_failed: - # Raise the failure as an exception to the generator. The orchestrator can then either - # handle the exception or allow it to fail the orchestration. - self._generator.throw(self._previous_task.get_exception()) - elif self._previous_task.is_complete: - while True: - # Resume the generator. This will either return a Task or raise StopIteration if it's done. - # CONSIDER: Should we check for possible infinite loops here? - next_task = self._generator.send(self._previous_task.get_result()) - if not isinstance(next_task, task.Task): - raise TypeError("The orchestrator generator yielded a non-Task object") - self._previous_task = next_task - # If a completed task was returned, then we can keep running the generator function. - if not self._previous_task.is_complete: - break + # Raise the failure as an exception to the generator. + # The orchestrator can then either handle the exception or allow it to fail the orchestration. + next_task = self._generator.throw(self._previous_task.get_exception()) + else: + # Resume the generator with the previous result. + # This will either return a Task or raise StopIteration if it's done. + next_task = self._generator.send(self._previous_task.get_result()) + + if not isinstance(next_task, task.Task): + raise TypeError("The orchestrator generator yielded a non-Task object") + self._previous_task = next_task def set_complete(self, result: Any, status: pb.OrchestrationStatus, is_result_encoded: bool = False): if self._is_complete: @@ -359,9 +357,9 @@ def current_utc_datetime(self, value: datetime): def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: return self.create_timer_internal(fire_at) - + def create_timer_internal(self, fire_at: Union[datetime, timedelta], - retryable_task: Optional[task.RetryableTask] = None) -> task.Task: + retryable_task: Optional[task.RetryableTask] = None) -> task.Task: id = self.next_sequence_number() if isinstance(fire_at, timedelta): fire_at = self.current_utc_datetime + fire_at @@ -390,9 +388,9 @@ def call_sub_orchestrator(self, orchestrator: task.Orchestrator[TInput, TOutput] id = self.next_sequence_number() orchestrator_name = task.get_name(orchestrator) self.call_activity_function_helper(id, orchestrator_name, input=input, retry_policy=retry_policy, - is_sub_orch=True, instance_id=instance_id) + is_sub_orch=True, instance_id=instance_id) return self._pending_tasks.get(id, task.CompletableTask()) - + def call_activity_function_helper(self, id: Optional[int], activity_function: Union[task.Activity[TInput, TOutput], str], *, input: Optional[TInput] = None, @@ -402,14 +400,14 @@ def call_activity_function_helper(self, id: Optional[int], fn_task: Optional[task.CompletableTask[TOutput]] = None): if id is None: id = self.next_sequence_number() - + if fn_task is None: encoded_input = shared.to_json(input) if input is not None else None else: # Here, we don't need to convert the input to JSON because it is already converted. # We just need to take string representation of it. encoded_input = str(input) - if is_sub_orch == False: + if not is_sub_orch: name = activity_function if isinstance(activity_function, str) else task.get_name(activity_function) action = ph.new_schedule_task_action(id, name, encoded_input) else: @@ -495,7 +493,7 @@ def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_e if not ctx._is_complete: task_count = len(ctx._pending_tasks) event_count = len(ctx._pending_events) - self._logger.info(f"{instance_id}: Waiting for {task_count} task(s) and {event_count} event(s).") + self._logger.info(f"{instance_id}: Orchestrator yielded with {task_count} task(s) and {event_count} event(s) outstanding.") elif ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW: completion_status_str = pbh.get_orchestration_status_str(ctx._completion_status) self._logger.info(f"{instance_id}: Orchestration completed with status: {completion_status_str}") @@ -556,8 +554,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven timer_task.complete(None) if timer_task._retryable_parent is not None: activity_action = timer_task._retryable_parent._action - - if timer_task._retryable_parent._is_sub_orch == False: + + if not timer_task._retryable_parent._is_sub_orch: cur_task = activity_action.scheduleTask instance_id = None else: @@ -612,11 +610,11 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven self._logger.warning( f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}.") return - + if isinstance(activity_task, task.RetryableTask): if activity_task._retry_policy is not None: next_delay = activity_task.compute_next_delay() - if next_delay == None: + if next_delay is None: activity_task.fail( f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", event.taskFailed.failureDetails) @@ -674,7 +672,7 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven if isinstance(sub_orch_task, task.RetryableTask): if sub_orch_task._retry_policy is not None: next_delay = sub_orch_task.compute_next_delay() - if next_delay == None: + if next_delay is None: sub_orch_task.fail( f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", failedEvent.failureDetails) diff --git a/tests/test_orchestration_e2e.py b/tests/test_orchestration_e2e.py index 9f8a901..ad329c5 100644 --- a/tests/test_orchestration_e2e.py +++ b/tests/test_orchestration_e2e.py @@ -77,6 +77,51 @@ def sequence(ctx: task.OrchestrationContext, start_val: int): assert state.serialized_custom_status is None +def test_activity_error_handling(): + + def throw(_: task.ActivityContext, input: int) -> int: + raise RuntimeError("Kah-BOOOOM!!!") + + compensation_counter = 0 + + def increment_counter(ctx, _): + nonlocal compensation_counter + compensation_counter += 1 + + def orchestrator(ctx: task.OrchestrationContext, input: int): + error_msg = "" + try: + yield ctx.call_activity(throw, input=input) + except task.TaskFailedError as e: + error_msg = e.details.message + + # compensating actions + yield ctx.call_activity(increment_counter) + yield ctx.call_activity(increment_counter) + + return error_msg + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.add_orchestrator(orchestrator) + w.add_activity(throw) + w.add_activity(increment_counter) + w.start() + + task_hub_client = client.TaskHubGrpcClient() + id = task_hub_client.schedule_new_orchestration(orchestrator, input=1) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.name == task.get_name(orchestrator) + assert state.instance_id == id + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps("Kah-BOOOOM!!!") + assert state.failure_details is None + assert state.serialized_custom_status is None + assert compensation_counter == 2 + + def test_sub_orchestration_fan_out(): threadLock = threading.Lock() activity_counter = 0 @@ -269,10 +314,14 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): assert state.serialized_input == json.dumps(4) assert all_results == [1, 2, 3, 4, 5] + +# NOTE: This test fails when running against durabletask-go with sqlite because the sqlite backend does not yet +# support orchestration ID reuse. This gap is being tracked here: +# https://github.com/microsoft/durabletask-go/issues/42 def test_retry_policies(): # This test verifies that the retry policies are working as expected. # It does this by creating an orchestration that calls a sub-orchestrator, - # which in turn calls an activity that always fails. + # which in turn calls an activity that always fails. # In this test, the retry policies are added, and the orchestration # should still fail. But, number of times the sub-orchestrator and activity # is called should increase as per the retry policies. @@ -281,12 +330,12 @@ def test_retry_policies(): throw_activity_counter = 0 # Second setup: With retry policies - retry_policy=task.RetryPolicy( - first_retry_interval=timedelta(seconds=1), - max_number_of_attempts=3, - backoff_coefficient=1, - max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=30)) + retry_policy = task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=3, + backoff_coefficient=1, + max_retry_interval=timedelta(seconds=10), + retry_timeout=timedelta(seconds=30)) def parent_orchestrator_with_retry(ctx: task.OrchestrationContext, _): yield ctx.call_sub_orchestrator(child_orchestrator_with_retry, retry_policy=retry_policy) @@ -323,18 +372,19 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _): assert throw_activity_counter == 9 assert child_orch_counter == 3 + def test_retry_timeout(): # This test verifies that the retry timeout is working as expected. # Max number of attempts is 5 and retry timeout is 14 seconds. # Total seconds consumed till 4th attempt is 1 + 2 + 4 + 8 = 15 seconds. # So, the 5th attempt should not be made and the orchestration should fail. throw_activity_counter = 0 - retry_policy=task.RetryPolicy( - first_retry_interval=timedelta(seconds=1), - max_number_of_attempts=5, - backoff_coefficient=2, - max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=14)) + retry_policy = task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=5, + backoff_coefficient=2, + max_retry_interval=timedelta(seconds=10), + retry_timeout=timedelta(seconds=14)) def mock_orchestrator(ctx: task.OrchestrationContext, _): yield ctx.call_activity(throw_activity, retry_policy=retry_policy) diff --git a/tests/test_orchestration_executor.py b/tests/test_orchestration_executor.py index 9c34f39..6d5fdac 100644 --- a/tests/test_orchestration_executor.py +++ b/tests/test_orchestration_executor.py @@ -256,7 +256,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), helpers.new_task_scheduled_event(1, task.get_name(dummy_activity))] expected_fire_at = current_timestamp + timedelta(seconds=1) - + new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] @@ -391,6 +391,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): assert actions[0].completeOrchestration.failureDetails.errorMessage.__contains__("Activity task #1 failed: Kah-BOOOOM!!!") assert actions[0].id == 7 + def test_nondeterminism_expected_timer(): """Tests the non-determinism detection logic when call_timer is expected but some other method (call_activity) is called instead""" def dummy_activity(ctx, _): @@ -964,13 +965,13 @@ def dummy_activity(_, inp: str): return f"Hello {inp}!" def orchestrator(ctx: task.OrchestrationContext, _): - t1 = ctx.call_activity(dummy_activity, + t1 = ctx.call_activity(dummy_activity, retry_policy=task.RetryPolicy( - first_retry_interval=timedelta(seconds=1), - max_number_of_attempts=6, - backoff_coefficient=2, - max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=50)), + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=6, + backoff_coefficient=2, + max_retry_interval=timedelta(seconds=10), + retry_timeout=timedelta(seconds=50)), input="Tokyo") t2 = ctx.call_activity(dummy_activity, input="Seattle") winner = yield task.when_any([t1, t2]) @@ -1036,6 +1037,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED assert complete_action.result.value == encoded_output + def test_when_all_with_retry(): """Tests that a when_all pattern works correctly with retries""" def dummy_activity(ctx, inp: str): @@ -1044,13 +1046,13 @@ def dummy_activity(ctx, inp: str): return f"Hello {inp}!" def orchestrator(ctx: task.OrchestrationContext, _): - t1 = ctx.call_activity(dummy_activity, + t1 = ctx.call_activity(dummy_activity, retry_policy=task.RetryPolicy( - first_retry_interval=timedelta(seconds=2), - max_number_of_attempts=3, - backoff_coefficient=4, - max_retry_interval=timedelta(seconds=5), - retry_timeout=timedelta(seconds=50)), + first_retry_interval=timedelta(seconds=2), + max_number_of_attempts=3, + backoff_coefficient=4, + max_retry_interval=timedelta(seconds=5), + retry_timeout=timedelta(seconds=50)), input="Tokyo") t2 = ctx.call_activity(dummy_activity, input="Seattle") results = yield task.when_all([t1, t2]) @@ -1117,7 +1119,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert actions[1].id == 1 ex = ValueError("Kah-BOOOOM!!!") - + # Simulate the task failing for the third time. Overall workflow should fail at this point. old_events = old_events + new_events new_events = [ @@ -1130,6 +1132,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert complete_action.failureDetails.errorType == 'TaskFailedError' # TODO: Should this be the specific error? assert str(ex) in complete_action.failureDetails.errorMessage + def get_and_validate_single_complete_orchestration_action(actions: List[pb.OrchestratorAction]) -> pb.CompleteOrchestrationAction: assert len(actions) == 1 assert type(actions[0]) is pb.OrchestratorAction