
Commit 2836c75

Fix error handling in generator logic (#21)
1 parent 80de6c2 commit 2836c75

File tree

3 files changed: +109 -58 lines

    durabletask/worker.py
    tests/test_orchestration_e2e.py
    tests/test_orchestration_executor.py

durabletask/worker.py

Lines changed: 29 additions & 31 deletions
@@ -85,10 +85,10 @@ class TaskHubGrpcWorker:
     _response_stream: Optional[grpc.Future] = None
 
     def __init__(self, *,
-                 host_address: Union[str, None] = None,
-                 metadata: Union[List[Tuple[str, str]], None] = None,
-                 log_handler = None,
-                 log_formatter: Union[logging.Formatter, None] = None,
+                 host_address: Optional[str] = None,
+                 metadata: Optional[List[Tuple[str, str]]] = None,
+                 log_handler=None,
+                 log_formatter: Optional[logging.Formatter] = None,
                  secure_channel: bool = False):
         self._registry = _Registry()
         self._host_address = host_address if host_address else shared.get_default_host_address()
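(A note on the typing change above, not part of the diff: the two spellings are interchangeable, since Optional[X] is defined as shorthand for Union[X, None]. A quick check:)

    from typing import List, Optional, Tuple, Union

    # Optional[X] and Union[X, None] compare equal because they are the same type.
    assert Optional[str] == Union[str, None]
    assert Optional[List[Tuple[str, str]]] == Union[List[Tuple[str, str]], None]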
@@ -259,22 +259,20 @@ def resume(self):
         # has reached a completed state. The only time this won't be the
         # case is if the user yielded on a WhenAll task and there are still
         # outstanding child tasks that need to be completed.
-        if self._previous_task is not None:
+        while self._previous_task is not None and self._previous_task.is_complete:
+            next_task = None
             if self._previous_task.is_failed:
-                # Raise the failure as an exception to the generator. The orchestrator can then either
-                # handle the exception or allow it to fail the orchestration.
-                self._generator.throw(self._previous_task.get_exception())
-            elif self._previous_task.is_complete:
-                while True:
-                    # Resume the generator. This will either return a Task or raise StopIteration if it's done.
-                    # CONSIDER: Should we check for possible infinite loops here?
-                    next_task = self._generator.send(self._previous_task.get_result())
-                    if not isinstance(next_task, task.Task):
-                        raise TypeError("The orchestrator generator yielded a non-Task object")
-                    self._previous_task = next_task
-                    # If a completed task was returned, then we can keep running the generator function.
-                    if not self._previous_task.is_complete:
-                        break
+                # Raise the failure as an exception to the generator.
+                # The orchestrator can then either handle the exception or allow it to fail the orchestration.
+                next_task = self._generator.throw(self._previous_task.get_exception())
+            else:
+                # Resume the generator with the previous result.
+                # This will either return a Task or raise StopIteration if it's done.
+                next_task = self._generator.send(self._previous_task.get_result())
+
+            if not isinstance(next_task, task.Task):
+                raise TypeError("The orchestrator generator yielded a non-Task object")
+            self._previous_task = next_task
 
     def set_complete(self, result: Any, status: pb.OrchestrationStatus, is_result_encoded: bool = False):
         if self._is_complete:
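This hunk is the heart of the commit. The old code called self._generator.throw() but discarded its return value, so a task yielded from inside an orchestrator's except block was silently lost. When a generator catches a thrown exception and keeps running, throw() behaves exactly like send(): it returns the next value the generator yields. A minimal sketch of that semantics in plain Python (no durabletask types involved):

    def orchestrator():
        try:
            yield "risky-task"          # stand-in for: yield ctx.call_activity(throw)
        except RuntimeError as e:
            result = f"recovered from {e}"
        yield "compensation-task"       # work yielded from *after* the except block
        return result

    gen = orchestrator()
    print(next(gen))                    # "risky-task"
    # throw() injects the failure at the paused yield; because the generator
    # handles it and keeps going, throw() returns the next yielded value.
    print(gen.throw(RuntimeError("Kah-BOOOOM!!!")))   # "compensation-task"
    try:
        gen.send(None)                  # resume past the final yield
    except StopIteration as stop:
        print(stop.value)               # "recovered from Kah-BOOOOM!!!"

The new while loop drives exactly this cycle, treating the result of throw() and send() uniformly as the next task to await.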
@@ -359,9 +357,9 @@ def current_utc_datetime(self, value: datetime):
 
     def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task:
         return self.create_timer_internal(fire_at)
-    
+
     def create_timer_internal(self, fire_at: Union[datetime, timedelta],
-                             retryable_task: Optional[task.RetryableTask] = None) -> task.Task:
+                              retryable_task: Optional[task.RetryableTask] = None) -> task.Task:
         id = self.next_sequence_number()
         if isinstance(fire_at, timedelta):
             fire_at = self.current_utc_datetime + fire_at
@@ -390,9 +388,9 @@ def call_sub_orchestrator(self, orchestrator: task.Orchestrator[TInput, TOutput]
         id = self.next_sequence_number()
         orchestrator_name = task.get_name(orchestrator)
         self.call_activity_function_helper(id, orchestrator_name, input=input, retry_policy=retry_policy,
-                                          is_sub_orch=True, instance_id=instance_id)
+                                           is_sub_orch=True, instance_id=instance_id)
         return self._pending_tasks.get(id, task.CompletableTask())
-    
+
     def call_activity_function_helper(self, id: Optional[int],
                                       activity_function: Union[task.Activity[TInput, TOutput], str], *,
                                       input: Optional[TInput] = None,
@@ -402,14 +400,14 @@ def call_activity_function_helper(self, id: Optional[int],
                                       fn_task: Optional[task.CompletableTask[TOutput]] = None):
         if id is None:
             id = self.next_sequence_number()
-        
+
         if fn_task is None:
             encoded_input = shared.to_json(input) if input is not None else None
         else:
             # Here, we don't need to convert the input to JSON because it is already converted.
             # We just need to take string representation of it.
             encoded_input = str(input)
-        if is_sub_orch == False:
+        if not is_sub_orch:
             name = activity_function if isinstance(activity_function, str) else task.get_name(activity_function)
             action = ph.new_schedule_task_action(id, name, encoded_input)
         else:
@@ -495,7 +493,7 @@ def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_e
         if not ctx._is_complete:
             task_count = len(ctx._pending_tasks)
             event_count = len(ctx._pending_events)
-            self._logger.info(f"{instance_id}: Waiting for {task_count} task(s) and {event_count} event(s).")
+            self._logger.info(f"{instance_id}: Orchestrator yielded with {task_count} task(s) and {event_count} event(s) outstanding.")
         elif ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW:
             completion_status_str = pbh.get_orchestration_status_str(ctx._completion_status)
             self._logger.info(f"{instance_id}: Orchestration completed with status: {completion_status_str}")
@@ -556,8 +554,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
                 timer_task.complete(None)
                 if timer_task._retryable_parent is not None:
                     activity_action = timer_task._retryable_parent._action
-    
-                    if timer_task._retryable_parent._is_sub_orch == False:
+
+                    if not timer_task._retryable_parent._is_sub_orch:
                         cur_task = activity_action.scheduleTask
                         instance_id = None
                     else:
@@ -612,11 +610,11 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
                 self._logger.warning(
                     f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}.")
                 return
-            
+
             if isinstance(activity_task, task.RetryableTask):
                 if activity_task._retry_policy is not None:
                     next_delay = activity_task.compute_next_delay()
-                    if next_delay == None:
+                    if next_delay is None:
                         activity_task.fail(
                             f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
                             event.taskFailed.failureDetails)
@@ -674,7 +672,7 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
             if isinstance(sub_orch_task, task.RetryableTask):
                 if sub_orch_task._retry_policy is not None:
                     next_delay = sub_orch_task.compute_next_delay()
-                    if next_delay == None:
+                    if next_delay is None:
                         sub_orch_task.fail(
                             f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",
                             failedEvent.failureDetails)
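The remaining changes in this file swap == None comparisons for is None. PEP 8 calls for identity checks against singletons like None because __eq__ can be overridden; a small illustration (not from this commit):

    class AlwaysEqual:
        # A pathological __eq__ makes equality against None unreliable.
        def __eq__(self, other):
            return True

    obj = AlwaysEqual()
    print(obj == None)   # True  -- misleading, __eq__ decided the answer
    print(obj is None)   # False -- identity is unambiguous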

tests/test_orchestration_e2e.py

Lines changed: 63 additions & 13 deletions
@@ -77,6 +77,51 @@ def sequence(ctx: task.OrchestrationContext, start_val: int):
     assert state.serialized_custom_status is None
 
 
+def test_activity_error_handling():
+
+    def throw(_: task.ActivityContext, input: int) -> int:
+        raise RuntimeError("Kah-BOOOOM!!!")
+
+    compensation_counter = 0
+
+    def increment_counter(ctx, _):
+        nonlocal compensation_counter
+        compensation_counter += 1
+
+    def orchestrator(ctx: task.OrchestrationContext, input: int):
+        error_msg = ""
+        try:
+            yield ctx.call_activity(throw, input=input)
+        except task.TaskFailedError as e:
+            error_msg = e.details.message
+
+        # compensating actions
+        yield ctx.call_activity(increment_counter)
+        yield ctx.call_activity(increment_counter)
+
+        return error_msg
+
+    # Start a worker, which will connect to the sidecar in a background thread
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(orchestrator)
+        w.add_activity(throw)
+        w.add_activity(increment_counter)
+        w.start()
+
+        task_hub_client = client.TaskHubGrpcClient()
+        id = task_hub_client.schedule_new_orchestration(orchestrator, input=1)
+        state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)
+
+        assert state is not None
+        assert state.name == task.get_name(orchestrator)
+        assert state.instance_id == id
+        assert state.runtime_status == client.OrchestrationStatus.COMPLETED
+        assert state.serialized_output == json.dumps("Kah-BOOOOM!!!")
+        assert state.failure_details is None
+        assert state.serialized_custom_status is None
+        assert compensation_counter == 2
+
+
 def test_sub_orchestration_fan_out():
     threadLock = threading.Lock()
     activity_counter = 0
@@ -269,10 +314,14 @@ def orchestrator(ctx: task.OrchestrationContext, input: int):
     assert state.serialized_input == json.dumps(4)
     assert all_results == [1, 2, 3, 4, 5]
 
+
+# NOTE: This test fails when running against durabletask-go with sqlite because the sqlite backend does not yet
+# support orchestration ID reuse. This gap is being tracked here:
+# https://github.com/microsoft/durabletask-go/issues/42
 def test_retry_policies():
     # This test verifies that the retry policies are working as expected.
     # It does this by creating an orchestration that calls a sub-orchestrator,
-    # which in turn calls an activity that always fails. 
+    # which in turn calls an activity that always fails.
     # In this test, the retry policies are added, and the orchestration
     # should still fail. But, number of times the sub-orchestrator and activity
     # is called should increase as per the retry policies.
@@ -281,12 +330,12 @@ def test_retry_policies():
     throw_activity_counter = 0
 
     # Second setup: With retry policies
-    retry_policy=task.RetryPolicy(
-            first_retry_interval=timedelta(seconds=1),
-            max_number_of_attempts=3,
-            backoff_coefficient=1,
-            max_retry_interval=timedelta(seconds=10),
-            retry_timeout=timedelta(seconds=30))
+    retry_policy = task.RetryPolicy(
+        first_retry_interval=timedelta(seconds=1),
+        max_number_of_attempts=3,
+        backoff_coefficient=1,
+        max_retry_interval=timedelta(seconds=10),
+        retry_timeout=timedelta(seconds=30))
 
     def parent_orchestrator_with_retry(ctx: task.OrchestrationContext, _):
         yield ctx.call_sub_orchestrator(child_orchestrator_with_retry, retry_policy=retry_policy)
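Given this policy, and assuming the child orchestrator applies the same 3-attempt policy to its failing activity call (that part of the test sits outside this diff's context lines), the counts asserted in the next hunk follow from simple multiplication:

    max_sub_orch_attempts = 3   # parent's retry_policy: max_number_of_attempts=3
    max_activity_attempts = 3   # child's activity retry_policy (assumed identical)
    assert max_sub_orch_attempts == 3                          # child_orch_counter
    assert max_sub_orch_attempts * max_activity_attempts == 9  # throw_activity_counter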
@@ -323,18 +372,19 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _):
     assert throw_activity_counter == 9
     assert child_orch_counter == 3
 
+
 def test_retry_timeout():
     # This test verifies that the retry timeout is working as expected.
     # Max number of attempts is 5 and retry timeout is 14 seconds.
     # Total seconds consumed till 4th attempt is 1 + 2 + 4 + 8 = 15 seconds.
     # So, the 5th attempt should not be made and the orchestration should fail.
     throw_activity_counter = 0
-    retry_policy=task.RetryPolicy(
-            first_retry_interval=timedelta(seconds=1),
-            max_number_of_attempts=5,
-            backoff_coefficient=2,
-            max_retry_interval=timedelta(seconds=10),
-            retry_timeout=timedelta(seconds=14))
+    retry_policy = task.RetryPolicy(
+        first_retry_interval=timedelta(seconds=1),
+        max_number_of_attempts=5,
+        backoff_coefficient=2,
+        max_retry_interval=timedelta(seconds=10),
+        retry_timeout=timedelta(seconds=14))
 
     def mock_orchestrator(ctx: task.OrchestrationContext, _):
         yield ctx.call_activity(throw_activity, retry_policy=retry_policy)
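The timing arithmetic in the comments above is easy to verify. Here is a sketch of the delay schedule implied by such a policy, assuming the delay grows as first_retry_interval * backoff_coefficient^(n-1), is capped at max_retry_interval, and retries stop once the running total would exceed retry_timeout. That is a plausible reading of the test comments, not necessarily the exact formula in task.RetryableTask.compute_next_delay:

    from datetime import timedelta

    def delay_schedule(first_retry_interval, backoff_coefficient,
                       max_number_of_attempts, max_retry_interval, retry_timeout):
        # Hypothetical helper: the delays waited before each retry under the policy.
        delays, elapsed = [], timedelta(0)
        for n in range(1, max_number_of_attempts):  # retries, excluding the first attempt
            delay = min(first_retry_interval * backoff_coefficient ** (n - 1),
                        max_retry_interval)
            if elapsed + delay > retry_timeout:
                break  # retry_timeout exhausted; no further attempts
            delays.append(delay)
            elapsed += delay
        return delays

    schedule = delay_schedule(timedelta(seconds=1), 2, 5,
                              timedelta(seconds=10), timedelta(seconds=14))
    print([d.total_seconds() for d in schedule])  # [1.0, 2.0, 4.0]
    # The 4th retry would need an 8s delay (total 15s > 14s), so only
    # 4 attempts are made (1 original + 3 retries) and the 5th is skipped.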

tests/test_orchestration_executor.py

Lines changed: 17 additions & 14 deletions
@@ -256,7 +256,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input):
         helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None),
         helpers.new_task_scheduled_event(1, task.get_name(dummy_activity))]
     expected_fire_at = current_timestamp + timedelta(seconds=1)
-    
+
     new_events = [
         helpers.new_orchestrator_started_event(timestamp=current_timestamp),
         helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))]
@@ -391,6 +391,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input):
     assert actions[0].completeOrchestration.failureDetails.errorMessage.__contains__("Activity task #1 failed: Kah-BOOOOM!!!")
     assert actions[0].id == 7
 
+
 def test_nondeterminism_expected_timer():
     """Tests the non-determinism detection logic when call_timer is expected but some other method (call_activity) is called instead"""
     def dummy_activity(ctx, _):
@@ -964,13 +965,13 @@ def dummy_activity(_, inp: str):
         return f"Hello {inp}!"
 
     def orchestrator(ctx: task.OrchestrationContext, _):
-        t1 = ctx.call_activity(dummy_activity, 
+        t1 = ctx.call_activity(dummy_activity,
                                retry_policy=task.RetryPolicy(
-                                    first_retry_interval=timedelta(seconds=1),
-                                    max_number_of_attempts=6,
-                                    backoff_coefficient=2,
-                                    max_retry_interval=timedelta(seconds=10),
-                                    retry_timeout=timedelta(seconds=50)),
+                                   first_retry_interval=timedelta(seconds=1),
+                                   max_number_of_attempts=6,
+                                   backoff_coefficient=2,
+                                   max_retry_interval=timedelta(seconds=10),
+                                   retry_timeout=timedelta(seconds=50)),
                                input="Tokyo")
         t2 = ctx.call_activity(dummy_activity, input="Seattle")
         winner = yield task.when_any([t1, t2])
@@ -1036,6 +1037,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
     assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED
     assert complete_action.result.value == encoded_output
 
+
 def test_when_all_with_retry():
     """Tests that a when_all pattern works correctly with retries"""
     def dummy_activity(ctx, inp: str):
@@ -1044,13 +1046,13 @@ def dummy_activity(ctx, inp: str):
         return f"Hello {inp}!"
 
     def orchestrator(ctx: task.OrchestrationContext, _):
-        t1 = ctx.call_activity(dummy_activity, 
+        t1 = ctx.call_activity(dummy_activity,
                                retry_policy=task.RetryPolicy(
-                                    first_retry_interval=timedelta(seconds=2),
-                                    max_number_of_attempts=3,
-                                    backoff_coefficient=4,
-                                    max_retry_interval=timedelta(seconds=5),
-                                    retry_timeout=timedelta(seconds=50)),
+                                   first_retry_interval=timedelta(seconds=2),
+                                   max_number_of_attempts=3,
+                                   backoff_coefficient=4,
+                                   max_retry_interval=timedelta(seconds=5),
+                                   retry_timeout=timedelta(seconds=50)),
                                input="Tokyo")
         t2 = ctx.call_activity(dummy_activity, input="Seattle")
         results = yield task.when_all([t1, t2])
@@ -1117,7 +1119,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
     assert actions[1].id == 1
 
     ex = ValueError("Kah-BOOOOM!!!")
-    
+
     # Simulate the task failing for the third time. Overall workflow should fail at this point.
     old_events = old_events + new_events
     new_events = [
@@ -1130,6 +1132,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
     assert complete_action.failureDetails.errorType == 'TaskFailedError'  # TODO: Should this be the specific error?
     assert str(ex) in complete_action.failureDetails.errorMessage
 
+
 def get_and_validate_single_complete_orchestration_action(actions: List[pb.OrchestratorAction]) -> pb.CompleteOrchestrationAction:
     assert len(actions) == 1
     assert type(actions[0]) is pb.OrchestratorAction
