60 changes: 29 additions & 31 deletions durabletask/worker.py
@@ -85,10 +85,10 @@ class TaskHubGrpcWorker:
_response_stream: Optional[grpc.Future] = None

def __init__(self, *,
host_address: Union[str, None] = None,
metadata: Union[List[Tuple[str, str]], None] = None,
log_handler = None,
log_formatter: Union[logging.Formatter, None] = None,
host_address: Optional[str] = None,
metadata: Optional[List[Tuple[str, str]]] = None,
log_handler=None,
log_formatter: Optional[logging.Formatter] = None,
secure_channel: bool = False):
self._registry = _Registry()
self._host_address = host_address if host_address else shared.get_default_host_address()
@@ -259,22 +259,20 @@ def resume(self):
# has reached a completed state. The only time this won't be the
# case is if the user yielded on a WhenAll task and there are still
# outstanding child tasks that need to be completed.
if self._previous_task is not None:
while self._previous_task is not None and self._previous_task.is_complete:
next_task = None
if self._previous_task.is_failed:
# Raise the failure as an exception to the generator. The orchestrator can then either
# handle the exception or allow it to fail the orchestration.
self._generator.throw(self._previous_task.get_exception())
elif self._previous_task.is_complete:
while True:
# Resume the generator. This will either return a Task or raise StopIteration if it's done.
# CONSIDER: Should we check for possible infinite loops here?
next_task = self._generator.send(self._previous_task.get_result())
if not isinstance(next_task, task.Task):
raise TypeError("The orchestrator generator yielded a non-Task object")
self._previous_task = next_task
# If a completed task was returned, then we can keep running the generator function.
if not self._previous_task.is_complete:
break
# Raise the failure as an exception to the generator.
# The orchestrator can then either handle the exception or allow it to fail the orchestration.
next_task = self._generator.throw(self._previous_task.get_exception())
else:
# Resume the generator with the previous result.
# This will either return a Task or raise StopIteration if it's done.
next_task = self._generator.send(self._previous_task.get_result())

if not isinstance(next_task, task.Task):
raise TypeError("The orchestrator generator yielded a non-Task object")
self._previous_task = next_task

def set_complete(self, result: Any, status: pb.OrchestrationStatus, is_result_encoded: bool = False):
if self._is_complete:
@@ -359,9 +357,9 @@ def current_utc_datetime(self, value: datetime):

def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task:
return self.create_timer_internal(fire_at)

def create_timer_internal(self, fire_at: Union[datetime, timedelta],
retryable_task: Optional[task.RetryableTask] = None) -> task.Task:
retryable_task: Optional[task.RetryableTask] = None) -> task.Task:
id = self.next_sequence_number()
if isinstance(fire_at, timedelta):
fire_at = self.current_utc_datetime + fire_at
@@ -390,9 +388,9 @@ def call_sub_orchestrator(self, orchestrator: task.Orchestrator[TInput, TOutput]
id = self.next_sequence_number()
orchestrator_name = task.get_name(orchestrator)
self.call_activity_function_helper(id, orchestrator_name, input=input, retry_policy=retry_policy,
is_sub_orch=True, instance_id=instance_id)
is_sub_orch=True, instance_id=instance_id)
return self._pending_tasks.get(id, task.CompletableTask())

def call_activity_function_helper(self, id: Optional[int],
activity_function: Union[task.Activity[TInput, TOutput], str], *,
input: Optional[TInput] = None,
@@ -402,14 +400,14 @@ def call_activity_function_helper(self, id: Optional[int],
fn_task: Optional[task.CompletableTask[TOutput]] = None):
if id is None:
id = self.next_sequence_number()

if fn_task is None:
encoded_input = shared.to_json(input) if input is not None else None
else:
# Here, we don't need to convert the input to JSON because it is already converted.
# We just need to take string representation of it.
encoded_input = str(input)
if is_sub_orch == False:
if not is_sub_orch:
name = activity_function if isinstance(activity_function, str) else task.get_name(activity_function)
action = ph.new_schedule_task_action(id, name, encoded_input)
else:
@@ -495,7 +493,7 @@ def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_e
if not ctx._is_complete:
task_count = len(ctx._pending_tasks)
event_count = len(ctx._pending_events)
self._logger.info(f"{instance_id}: Waiting for {task_count} task(s) and {event_count} event(s).")
self._logger.info(f"{instance_id}: Orchestrator yielded with {task_count} task(s) and {event_count} event(s) outstanding.")
elif ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW:
completion_status_str = pbh.get_orchestration_status_str(ctx._completion_status)
self._logger.info(f"{instance_id}: Orchestration completed with status: {completion_status_str}")
@@ -556,8 +554,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
timer_task.complete(None)
if timer_task._retryable_parent is not None:
activity_action = timer_task._retryable_parent._action
if timer_task._retryable_parent._is_sub_orch == False:

if not timer_task._retryable_parent._is_sub_orch:
cur_task = activity_action.scheduleTask
instance_id = None
else:
@@ -612,11 +610,11 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
self._logger.warning(
f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}.")
return

if isinstance(activity_task, task.RetryableTask):
if activity_task._retry_policy is not None:
next_delay = activity_task.compute_next_delay()
if next_delay == None:
if next_delay is None:
activity_task.fail(
f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}",
event.taskFailed.failureDetails)
@@ -674,7 +672,7 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
if isinstance(sub_orch_task, task.RetryableTask):
if sub_orch_task._retry_policy is not None:
next_delay = sub_orch_task.compute_next_delay()
if next_delay == None:
if next_delay is None:
sub_orch_task.fail(
f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}",
failedEvent.failureDetails)
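For readability, here is roughly how the rewritten resume() logic reads once the removed and added lines in the @@ -259,22 +259,20 @@ hunk above are untangled. This is a sketch reconstructed from the new side of that hunk; the indentation is assumed from context rather than copied from the repository.

    # Sketch (assumed indentation) of the new resume() loop:
    while self._previous_task is not None and self._previous_task.is_complete:
        next_task = None
        if self._previous_task.is_failed:
            # Raise the failure as an exception to the generator. The orchestrator
            # can then either handle the exception or allow it to fail the orchestration.
            next_task = self._generator.throw(self._previous_task.get_exception())
        else:
            # Resume the generator with the previous result. This will either
            # return a Task or raise StopIteration if it's done.
            next_task = self._generator.send(self._previous_task.get_result())

        if not isinstance(next_task, task.Task):
            raise TypeError("The orchestrator generator yielded a non-Task object")
        self._previous_task = next_task

The notable behavioral change is that the return value of generator.throw() is now captured and fed back into the loop: an orchestrator that catches the TaskFailedError and yields further tasks (for example, compensation activities) keeps executing, which is what the new test_activity_error_handling test below exercises.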
76 changes: 63 additions & 13 deletions tests/test_orchestration_e2e.py
@@ -77,6 +77,51 @@ def sequence(ctx: task.OrchestrationContext, start_val: int):
assert state.serialized_custom_status is None


def test_activity_error_handling():

def throw(_: task.ActivityContext, input: int) -> int:
raise RuntimeError("Kah-BOOOOM!!!")

compensation_counter = 0

def increment_counter(ctx, _):
nonlocal compensation_counter
compensation_counter += 1

def orchestrator(ctx: task.OrchestrationContext, input: int):
error_msg = ""
try:
yield ctx.call_activity(throw, input=input)
except task.TaskFailedError as e:
error_msg = e.details.message

# compensating actions
yield ctx.call_activity(increment_counter)
yield ctx.call_activity(increment_counter)

return error_msg

# Start a worker, which will connect to the sidecar in a background thread
with worker.TaskHubGrpcWorker() as w:
w.add_orchestrator(orchestrator)
w.add_activity(throw)
w.add_activity(increment_counter)
w.start()

task_hub_client = client.TaskHubGrpcClient()
id = task_hub_client.schedule_new_orchestration(orchestrator, input=1)
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)

assert state is not None
assert state.name == task.get_name(orchestrator)
assert state.instance_id == id
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
assert state.serialized_output == json.dumps("Kah-BOOOOM!!!")
assert state.failure_details is None
assert state.serialized_custom_status is None
assert compensation_counter == 2


def test_sub_orchestration_fan_out():
threadLock = threading.Lock()
activity_counter = 0
@@ -269,10 +314,14 @@ def orchestrator(ctx: task.OrchestrationContext, input: int):
assert state.serialized_input == json.dumps(4)
assert all_results == [1, 2, 3, 4, 5]


# NOTE: This test fails when running against durabletask-go with sqlite because the sqlite backend does not yet
# support orchestration ID reuse. This gap is being tracked here:
# https://github.com/microsoft/durabletask-go/issues/42
def test_retry_policies():
# This test verifies that the retry policies are working as expected.
# It does this by creating an orchestration that calls a sub-orchestrator,
# which in turn calls an activity that always fails.
# which in turn calls an activity that always fails.
# In this test, the retry policies are added, and the orchestration
# should still fail. But the number of times the sub-orchestrator and activity
# are called should increase as per the retry policies.
@@ -281,12 +330,12 @@ def test_retry_policies():
throw_activity_counter = 0

# Second setup: With retry policies
retry_policy=task.RetryPolicy(
first_retry_interval=timedelta(seconds=1),
max_number_of_attempts=3,
backoff_coefficient=1,
max_retry_interval=timedelta(seconds=10),
retry_timeout=timedelta(seconds=30))
retry_policy = task.RetryPolicy(
first_retry_interval=timedelta(seconds=1),
max_number_of_attempts=3,
backoff_coefficient=1,
max_retry_interval=timedelta(seconds=10),
retry_timeout=timedelta(seconds=30))

def parent_orchestrator_with_retry(ctx: task.OrchestrationContext, _):
yield ctx.call_sub_orchestrator(child_orchestrator_with_retry, retry_policy=retry_policy)
@@ -323,18 +372,19 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _):
assert throw_activity_counter == 9
assert child_orch_counter == 3


def test_retry_timeout():
# This test verifies that the retry timeout is working as expected.
# Max number of attempts is 5 and retry timeout is 14 seconds.
# Total seconds consumed till 4th attempt is 1 + 2 + 4 + 8 = 15 seconds.
# So, the 5th attempt should not be made and the orchestration should fail.
throw_activity_counter = 0
retry_policy=task.RetryPolicy(
first_retry_interval=timedelta(seconds=1),
max_number_of_attempts=5,
backoff_coefficient=2,
max_retry_interval=timedelta(seconds=10),
retry_timeout=timedelta(seconds=14))
retry_policy = task.RetryPolicy(
first_retry_interval=timedelta(seconds=1),
max_number_of_attempts=5,
backoff_coefficient=2,
max_retry_interval=timedelta(seconds=10),
retry_timeout=timedelta(seconds=14))

def mock_orchestrator(ctx: task.OrchestrationContext, _):
yield ctx.call_activity(throw_activity, retry_policy=retry_policy)
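As an aside on the arithmetic in the test_retry_timeout comment above: assuming the delay before retry n is first_retry_interval * backoff_coefficient ** (n - 1), capped at max_retry_interval (an assumption consistent with the 1 + 2 + 4 + 8 sequence the comment describes, not a statement of the SDK's exact formula), a small sketch shows why the fifth attempt is never made:

    from datetime import timedelta

    # Hypothetical illustration of the retry_timeout reasoning; not SDK code.
    first_retry_interval = timedelta(seconds=1)
    backoff_coefficient = 2
    max_number_of_attempts = 5
    max_retry_interval = timedelta(seconds=10)
    retry_timeout = timedelta(seconds=14)

    elapsed = timedelta(0)
    attempts_made = 1  # the initial attempt happens before any retry delay
    for n in range(1, max_number_of_attempts):
        delay = min(first_retry_interval * backoff_coefficient ** (n - 1), max_retry_interval)
        elapsed += delay  # cumulative delay: 1s, 3s, 7s, 15s
        if elapsed > retry_timeout:
            break  # 1 + 2 + 4 + 8 = 15s > 14s, so the 5th attempt is skipped
        attempts_made += 1

    print(attempts_made)  # 4 attempts in total; the orchestration then fails

This only illustrates the comment's reasoning; the worker's actual timeout check may also account for activity execution time.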
31 changes: 17 additions & 14 deletions tests/test_orchestration_executor.py
@@ -256,7 +256,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input):
helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None),
helpers.new_task_scheduled_event(1, task.get_name(dummy_activity))]
expected_fire_at = current_timestamp + timedelta(seconds=1)

new_events = [
helpers.new_orchestrator_started_event(timestamp=current_timestamp),
helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))]
@@ -391,6 +391,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input):
assert actions[0].completeOrchestration.failureDetails.errorMessage.__contains__("Activity task #1 failed: Kah-BOOOOM!!!")
assert actions[0].id == 7


def test_nondeterminism_expected_timer():
"""Tests the non-determinism detection logic when call_timer is expected but some other method (call_activity) is called instead"""
def dummy_activity(ctx, _):
@@ -964,13 +965,13 @@ def dummy_activity(_, inp: str):
return f"Hello {inp}!"

def orchestrator(ctx: task.OrchestrationContext, _):
t1 = ctx.call_activity(dummy_activity,
t1 = ctx.call_activity(dummy_activity,
retry_policy=task.RetryPolicy(
first_retry_interval=timedelta(seconds=1),
max_number_of_attempts=6,
backoff_coefficient=2,
max_retry_interval=timedelta(seconds=10),
retry_timeout=timedelta(seconds=50)),
first_retry_interval=timedelta(seconds=1),
max_number_of_attempts=6,
backoff_coefficient=2,
max_retry_interval=timedelta(seconds=10),
retry_timeout=timedelta(seconds=50)),
input="Tokyo")
t2 = ctx.call_activity(dummy_activity, input="Seattle")
winner = yield task.when_any([t1, t2])
@@ -1036,6 +1037,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED
assert complete_action.result.value == encoded_output


def test_when_all_with_retry():
"""Tests that a when_all pattern works correctly with retries"""
def dummy_activity(ctx, inp: str):
@@ -1044,13 +1046,13 @@ def dummy_activity(ctx, inp: str):
return f"Hello {inp}!"

def orchestrator(ctx: task.OrchestrationContext, _):
t1 = ctx.call_activity(dummy_activity,
t1 = ctx.call_activity(dummy_activity,
retry_policy=task.RetryPolicy(
first_retry_interval=timedelta(seconds=2),
max_number_of_attempts=3,
backoff_coefficient=4,
max_retry_interval=timedelta(seconds=5),
retry_timeout=timedelta(seconds=50)),
first_retry_interval=timedelta(seconds=2),
max_number_of_attempts=3,
backoff_coefficient=4,
max_retry_interval=timedelta(seconds=5),
retry_timeout=timedelta(seconds=50)),
input="Tokyo")
t2 = ctx.call_activity(dummy_activity, input="Seattle")
results = yield task.when_all([t1, t2])
@@ -1117,7 +1119,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
assert actions[1].id == 1

ex = ValueError("Kah-BOOOOM!!!")

# Simulate the task failing for the third time. Overall workflow should fail at this point.
old_events = old_events + new_events
new_events = [
@@ -1130,6 +1132,7 @@ def orchestrator(ctx: task.OrchestrationContext, _):
assert complete_action.failureDetails.errorType == 'TaskFailedError' # TODO: Should this be the specific error?
assert str(ex) in complete_action.failureDetails.errorMessage


def get_and_validate_single_complete_orchestration_action(actions: List[pb.OrchestratorAction]) -> pb.CompleteOrchestrationAction:
assert len(actions) == 1
assert type(actions[0]) is pb.OrchestratorAction