diff --git a/.vscode/launch.json b/.vscode/launch.json
index 513e8a3..f9f1742 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -6,12 +6,14 @@
             "type": "python",
             "request": "launch",
             "program": "${file}",
+            "cwd": "${fileDirname}",
             "purpose": [
                 "debug-test"
             ],
             "env": {
                 // pytest-cov breaks debugging, so we have to disable it during debug sessions
-                "PYTEST_ADDOPTS": "--no-cov"
+                "PYTEST_ADDOPTS": "--no-cov",
+                "PYTHONPATH": "${workspaceFolder}"
             },
             "console": "integratedTerminal",
             "justMyCode": false
diff --git a/README.md b/README.md
index a93ee29..5816bad 100644
--- a/README.md
+++ b/README.md
@@ -71,6 +71,37 @@ def orchestrator(ctx: task.OrchestrationContext, _):
 
 You can find the full sample [here](./examples/fanout_fanin.py).
 
+### Human interaction and durable timers
+
+An orchestration can wait for a user-defined event, such as a human approval event, before proceeding to the next step. In addition, the orchestration can create a timer with an arbitrary duration that triggers an alternate action if the external event isn't received in time:
+
+```python
+def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order):
+    """Orchestrator function that represents a purchase order workflow"""
+    # Orders under $1000 are auto-approved
+    if order.Cost < 1000:
+        return "Auto-approved"
+
+    # Orders of $1000 or more require manager approval
+    yield ctx.call_activity(send_approval_request, input=order)
+
+    # Approvals must be received within 24 hours or they will be canceled.
+    approval_event = ctx.wait_for_external_event("approval_received")
+    timeout_event = ctx.create_timer(timedelta(hours=24))
+    winner = yield task.when_any([approval_event, timeout_event])
+    if winner == timeout_event:
+        return "Cancelled"
+
+    # The order was approved
+    ctx.call_activity(place_order, input=order)
+    approval_details = approval_event.get_result()
+    return f"Approved by '{approval_details.approver}'"
+```
+
+Note also that the example orchestration above works with custom business objects: custom classes, data classes, and named tuples are all supported, and the SDK serializes and deserializes them automatically.
+
+You can find the full sample [here](./examples/human_interaction.py).
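+
+For context, the client side of this interaction might look roughly like the following sketch, adapted from the linked sample (the `Order` dataclass and the approval payload shown here are illustrative and come from that sample):
+
+```python
+from collections import namedtuple
+
+from durabletask import client
+
+c = client.TaskHubGrpcClient()
+
+# Start the workflow, passing a custom business object as its input
+order = Order(Cost=2000, Product="MyProduct", Quantity=1)
+instance_id = c.schedule_new_orchestration(purchase_order_workflow, input=order)
+
+# Deliver the approval event; the payload is serialized automatically
+approval = namedtuple("Approval", ["approver"])("Me")
+c.raise_orchestration_event(instance_id, "approval_received", data=approval)
+
+# Wait for the workflow to run to completion
+state = c.wait_for_orchestration_completion(instance_id, timeout=60)
+```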
+ ## Getting Started ### Prerequisites diff --git a/durabletask/client.py b/durabletask/client.py index a8911df..8b6c0a6 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -6,10 +6,9 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import TypeVar +from typing import Any, TypeVar import grpc -import simplejson as json from google.protobuf import wrappers_pb2 import durabletask.internal.helpers as helpers @@ -46,7 +45,23 @@ class OrchestrationState: serialized_input: str | None serialized_output: str | None serialized_custom_status: str | None - failure_details: pb.TaskFailureDetails | None + failure_details: task.FailureDetails | None + + def raise_if_failed(self): + if self.failure_details is not None: + raise OrchestrationFailedError( + f"Orchestration '{self.instance_id}' failed: {self.failure_details.message}", + self.failure_details) + + +class OrchestrationFailedError(Exception): + def __init__(self, message: str, failure_details: task.FailureDetails): + super().__init__(message) + self._failure_details = failure_details + + @property + def failure_details(self): + return self._failure_details def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> OrchestrationState | None: @@ -54,6 +69,14 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Or return None state = res.orchestrationState + + failure_details = None + if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '': + failure_details = task.FailureDetails( + state.failureDetails.errorMessage, + state.failureDetails.errorType, + state.failureDetails.stackTrace.value if not helpers.is_empty(state.failureDetails.stackTrace) else None) + return OrchestrationState( instance_id, state.name, @@ -63,7 +86,7 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Or state.input.value if not helpers.is_empty(state.input) else None, state.output.value if not helpers.is_empty(state.output) else None, state.customStatus.value if not helpers.is_empty(state.customStatus) else None, - state.failureDetails if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '' else None) + failure_details) class TaskHubGrpcClient: @@ -86,7 +109,7 @@ def schedule_new_orchestration(self, orchestrator: task.Orchestrator[TInput, TOu req = pb.CreateInstanceRequest( name=name, instanceId=instance_id if instance_id else uuid.uuid4().hex, - input=wrappers_pb2.StringValue(value=json.dumps(input)) if input else None, + input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input else None, scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None) self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.") @@ -128,6 +151,16 @@ def wait_for_orchestration_completion(self, instance_id: str, *, else: raise + def raise_orchestration_event(self, instance_id: str, event_name: str, *, + data: Any | None = None): + req = pb.RaiseEventRequest( + instanceId=instance_id, + name=event_name, + input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None) + + self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") + self._stub.RaiseEvent(req) + def terminate_orchestration(self): pass @@ -136,6 +169,3 @@ def suspend_orchestration(self): def resume_orchestration(self): pass - - def raise_orchestration_event(self): - pass diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 
bbf76f1..5bc654d 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -3,9 +3,7 @@ import traceback from datetime import datetime -from typing import Any -import simplejson as json from google.protobuf import timestamp_pb2, wrappers_pb2 import durabletask.internal.orchestrator_service_pb2 as pb @@ -117,6 +115,14 @@ def new_failure_details(ex: Exception) -> pb.TaskFailureDetails: ) +def new_event_raised_event(name: str, encoded_input: str | None = None) -> pb.HistoryEvent: + return pb.HistoryEvent( + eventId=-1, + timestamp=timestamp_pb2.Timestamp(), + eventRaised=pb.EventRaisedEvent(name=name, input=get_string_value(encoded_input)) + ) + + def get_string_value(val: str | None) -> wrappers_pb2.StringValue | None: if val is None: return None @@ -146,8 +152,7 @@ def new_create_timer_action(id: int, fire_at: datetime) -> pb.OrchestratorAction return pb.OrchestratorAction(id=id, createTimer=pb.CreateTimerAction(fireAt=timestamp)) -def new_schedule_task_action(id: int, name: str, input: Any) -> pb.OrchestratorAction: - encoded_input = json.dumps(input) if input is not None else None +def new_schedule_task_action(id: int, name: str, encoded_input: str | None) -> pb.OrchestratorAction: return pb.OrchestratorAction(id=id, scheduleTask=pb.ScheduleTaskAction( name=name, input=get_string_value(encoded_input) @@ -164,8 +169,7 @@ def new_create_sub_orchestration_action( id: int, name: str, instance_id: str | None, - input: Any) -> pb.OrchestratorAction: - encoded_input = json.dumps(input) if input is not None else None + encoded_input: str | None) -> pb.OrchestratorAction: return pb.OrchestratorAction(id=id, createSubOrchestration=pb.CreateSubOrchestrationAction( name=name, instanceId=instance_id, diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index b42d1f2..1e65415 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -1,10 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
+import dataclasses
+import json
 import logging
+from types import SimpleNamespace
+from typing import Any, Dict
 
 import grpc
 
+# Field name used to indicate that an object was automatically serialized
+# and should be deserialized as a SimpleNamespace
+AUTO_SERIALIZED = "__durabletask_autoobject__"
+
 
 def get_default_host_address() -> str:
     return "localhost:4001"
@@ -35,3 +43,49 @@ def get_logger(
         datefmt='%Y-%m-%d %H:%M:%S')
     log_handler.setFormatter(log_formatter)
     return logger
+
+
+def to_json(obj):
+    return json.dumps(obj, cls=InternalJSONEncoder)
+
+
+def from_json(json_str):
+    return json.loads(json_str, cls=InternalJSONDecoder)
+
+
+class InternalJSONEncoder(json.JSONEncoder):
+    """JSON encoder that supports serializing specific Python types."""
+
+    def encode(self, obj: Any) -> str:
+        # if the object is a namedtuple, convert it to a dict with the AUTO_SERIALIZED key added
+        if isinstance(obj, tuple) and hasattr(obj, "_fields") and hasattr(obj, "_asdict"):
+            d = obj._asdict()  # type: ignore
+            d[AUTO_SERIALIZED] = True
+            obj = d
+        return super().encode(obj)
+
+    def default(self, obj):
+        if dataclasses.is_dataclass(obj):
+            # Dataclasses are not serializable by default, so we convert them to a dict and mark them for
+            # automatic deserialization by the receiver
+            d = dataclasses.asdict(obj)
+            d[AUTO_SERIALIZED] = True
+            return d
+        elif isinstance(obj, SimpleNamespace):
+            # Most commonly used for serializing custom objects that were previously serialized using our encoder
+            d = vars(obj)
+            d[AUTO_SERIALIZED] = True
+            return d
+        # This will typically raise a TypeError
+        return json.JSONEncoder.default(self, obj)
+
+
+class InternalJSONDecoder(json.JSONDecoder):
+    def __init__(self, *args, **kwargs):
+        super().__init__(object_hook=self.dict_to_object, *args, **kwargs)
+
+    def dict_to_object(self, d: Dict[str, Any]):
+        # If the object was serialized by the InternalJSONEncoder, deserialize it as a SimpleNamespace
+        if d.pop(AUTO_SERIALIZED, False):
+            return SimpleNamespace(**d)
+        return d
diff --git a/durabletask/task.py b/durabletask/task.py
index 1a23ae2..579302e 100644
--- a/durabletask/task.py
+++ b/durabletask/task.py
@@ -5,7 +5,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from datetime import datetime
+from datetime import datetime, timedelta
 from typing import Any, Callable, Generator, Generic, List, TypeVar
 
 import durabletask.internal.helpers as pbh
@@ -70,13 +70,13 @@ def is_replaying(self) -> bool:
         pass
 
     @abstractmethod
-    def create_timer(self, fire_at: datetime) -> Task:
+    def create_timer(self, fire_at: datetime | timedelta) -> Task:
         """Create a Timer Task to fire after at the specified deadline.
 
         Parameters
         ----------
-        fire_at: datetime.datetime
-            The time for the timer to trigger
+        fire_at: datetime.datetime | datetime.timedelta
+            The time for the timer to trigger or a time delta from now.
 
         Returns
         -------
@@ -129,12 +129,27 @@ def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
         """
         pass
 
+    # TODO: Add a timeout parameter, which allows the task to be canceled if the event is
+    # not received within the specified timeout. This requires support for task cancellation.
+    @abstractmethod
+    def wait_for_external_event(self, name: str) -> Task:
+        """Wait asynchronously for an event to be raised with the name `name`.
 
-class TaskFailedError(Exception):
-    """Exception type for all orchestration task failures."""
+        Parameters
+        ----------
+        name : str
+            The event name of the event that the task is waiting for.
+ Returns + ------- + Task[TOutput] + A Durable Task that completes when the event is received. + """ + pass + + +class FailureDetails: def __init__(self, message: str, error_type: str, stack_trace: str | None): - super().__init__(message) self._message = message self._error_type = error_type self._stack_trace = stack_trace @@ -152,6 +167,21 @@ def stack_trace(self) -> str | None: return self._stack_trace +class TaskFailedError(Exception): + """Exception type for all orchestration task failures.""" + + def __init__(self, message: str, details: pb.TaskFailureDetails): + super().__init__(message) + self._details = FailureDetails( + details.errorMessage, + details.errorType, + details.stackTrace.value if not pbh.is_empty(details.stackTrace) else None) + + @property + def details(self) -> FailureDetails: + return self._details + + class NonDeterminismError(Exception): pass @@ -208,6 +238,8 @@ def __init__(self, tasks: List[Task]): self._failed_tasks = 0 for task in tasks: task._parent = self + if task.is_complete: + self.on_child_completed(task) def get_tasks(self) -> List[Task]: return self._tasks @@ -230,13 +262,10 @@ def complete(self, result: T): if self._parent is not None: self._parent.on_child_completed(self) - def fail(self, details: pb.TaskFailureDetails): + def fail(self, message: str, details: pb.TaskFailureDetails): if self._is_complete: raise ValueError('The task has already completed.') - self._exception = TaskFailedError( - details.errorMessage, - details.errorType, - details.stackTrace.value if not pbh.is_empty(details.stackTrace) else None) + self._exception = TaskFailedError(message, details) self._is_complete = True if self._parent is not None: self._parent.on_child_completed(self) @@ -278,6 +307,7 @@ def __init__(self, tasks: List[Task]): super().__init__(tasks) def on_child_completed(self, task: Task): + # The first task to complete is the result of the WhenAnyTask. if not self.is_complete: self._is_complete = True self._result = task diff --git a/durabletask/worker.py b/durabletask/worker.py index 090ec43..419e919 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -3,13 +3,13 @@ import concurrent.futures import logging -from datetime import datetime +from dataclasses import dataclass +from datetime import datetime, timedelta from threading import Event, Thread from types import GeneratorType from typing import Any, Dict, Generator, List, Sequence, TypeVar import grpc -import simplejson as json from google.protobuf import empty_pb2 import durabletask.internal.helpers as ph @@ -207,6 +207,12 @@ def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarS f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}") +@dataclass +class _ExternalEvent: + name: str + data: Any + + class _RuntimeOrchestrationContext(task.OrchestrationContext): _generator: Generator[task.Task, Any, Any] | None _previous_task: task.Task | None @@ -222,6 +228,8 @@ def __init__(self, instance_id: str): self._current_utc_datetime = datetime(1000, 1, 1) self._instance_id = instance_id self._completion_status: pb.OrchestrationStatus | None = None + self._received_events: Dict[str, List[_ExternalEvent]] = {} + self._pending_events: Dict[str, List[task.CompletableTask]] = {} def run(self, generator: Generator[task.Task, Any, Any]): self._generator = generator @@ -245,17 +253,23 @@ def resume(self): # handle the exception or allow it to fail the orchestration. 
                self._generator.throw(self._previous_task.get_exception())
            elif self._previous_task.is_complete:
-                # Resume the generator. This will either return a Task or raise StopIteration if it's done.
-                next_task = self._generator.send(self._previous_task.get_result())
-                # TODO: Validate the return value
-                self._previous_task = next_task
+                while True:
+                    # Resume the generator. This will either return a Task or raise StopIteration if it's done.
+                    # CONSIDER: Should we check for possible infinite loops here?
+                    next_task = self._generator.send(self._previous_task.get_result())
+                    if not isinstance(next_task, task.Task):
+                        raise TypeError("The orchestrator generator yielded a non-Task object")
+                    self._previous_task = next_task
+                    # If a completed task was returned, then we can keep running the generator function.
+                    if not self._previous_task.is_complete:
+                        break
 
     def set_complete(self, result: Any):
         self._is_complete = True
         self._result = result
 
         result_json: str | None = None
         if result is not None:
-            result_json = json.dumps(result)
+            result_json = shared.to_json(result)
         action = ph.new_complete_orchestration_action(
             self.next_sequence_number(), pb.ORCHESTRATION_STATUS_COMPLETED, result_json)
         self._pending_actions[action.id] = action
@@ -291,8 +305,10 @@ def is_replaying(self) -> bool:
     def current_utc_datetime(self, value: datetime):
         self._current_utc_datetime = value
 
-    def create_timer(self, fire_at: datetime) -> task.Task:
+    def create_timer(self, fire_at: datetime | timedelta) -> task.Task:
         id = self.next_sequence_number()
+        if isinstance(fire_at, timedelta):
+            fire_at = self.current_utc_datetime + fire_at
         action = ph.new_create_timer_action(id, fire_at)
         self._pending_actions[id] = action
 
@@ -304,7 +320,8 @@ def call_activity(self, activity: task.Activity[TInput, TOutput], *,
                       input: TInput | None = None) -> task.Task[TOutput]:
         id = self.next_sequence_number()
         name = task.get_name(activity)
-        action = ph.new_schedule_task_action(id, name, input)
+        encoded_input = shared.to_json(input) if input else None
+        action = ph.new_schedule_task_action(id, name, encoded_input)
         self._pending_actions[id] = action
 
         activity_task = task.CompletableTask[TOutput]()
@@ -319,13 +336,36 @@ def call_sub_orchestrator(self, orchestrator: task.Orchestrator[TInput, TOutput]
         if instance_id is None:
             # Create a deteministic instance ID based on the parent instance ID
             instance_id = f"{self.instance_id}:{id:04x}"
-        action = ph.new_create_sub_orchestration_action(id, name, instance_id, input)
+        encoded_input = shared.to_json(input) if input else None
+        action = ph.new_create_sub_orchestration_action(id, name, instance_id, encoded_input)
         self._pending_actions[id] = action
 
         sub_orch_task = task.CompletableTask[TOutput]()
         self._pending_tasks[id] = sub_orch_task
         return sub_orch_task
 
+    def wait_for_external_event(self, name: str) -> task.Task:
+        # Check to see if this event has already been received, in which case we
+        # can return it immediately. Otherwise, record our intent to receive an
+        # event with the given name so that we can resume the generator when it
+        # arrives. If there are multiple events with the same name, we return
+        # them in the order they were received.
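+        # Note: event names are matched case-insensitively; both this method and the
+        # eventRaised handler in the executor normalize names with .upper() before
+        # pairing waiting tasks with buffered events.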
+ external_event_task = task.CompletableTask() + event_name = name.upper() + event_list = self._received_events.get(event_name, None) + if event_list: + event = event_list.pop(0) + if not event_list: + del self._received_events[event_name] + external_event_task.complete(event.data) + else: + task_list = self._pending_events.get(event_name, None) + if not task_list: + task_list = [] + self._pending_events[event_name] = task_list + task_list.append(external_event_task) + return external_event_task + class _OrchestrationExecutor: _generator: task.Orchestrator | None @@ -381,7 +421,7 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven # deserialize the input, if any input = None if event.executionStarted.input is not None and event.executionStarted.input.value != "": - input = json.loads(event.executionStarted.input.value) + input = shared.from_json(event.executionStarted.input.value) result = fn(ctx, input) # this does not execute the generator, only creates it if isinstance(result, GeneratorType): @@ -437,7 +477,7 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven return result = None if not ph.is_empty(event.taskCompleted.result): - result = json.loads(event.taskCompleted.result.value) + result = shared.from_json(event.taskCompleted.result.value) activity_task.complete(result) ctx.resume() elif event.HasField("taskFailed"): @@ -448,7 +488,9 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven self._logger.warning( f"Ignoring unexpected taskFailed event for '{ctx.instance_id}' with ID = {task_id}.") return - activity_task.fail(event.taskFailed.failureDetails) + activity_task.fail( + f"Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", + event.taskFailed.failureDetails) ctx.resume() elif event.HasField("subOrchestrationInstanceCreated"): # This history event confirms that the sub-orchestration execution was successfully scheduled. @@ -476,19 +518,48 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven return result = None if not ph.is_empty(event.subOrchestrationInstanceCompleted.result): - result = json.loads(event.subOrchestrationInstanceCompleted.result.value) + result = shared.from_json(event.subOrchestrationInstanceCompleted.result.value) sub_orch_task.complete(result) ctx.resume() elif event.HasField("subOrchestrationInstanceFailed"): - task_id = event.subOrchestrationInstanceFailed.taskScheduledId + failedEvent = event.subOrchestrationInstanceFailed + task_id = failedEvent.taskScheduledId sub_orch_task = ctx._pending_tasks.pop(task_id, None) if not sub_orch_task: # TODO: Should this be an error? When would it ever happen? 
self._logger.warning( f"Ignoring unexpected subOrchestrationInstanceFailed event for '{ctx.instance_id}' with ID = {task_id}.") return - sub_orch_task.fail(event.subOrchestrationInstanceFailed.failureDetails) + sub_orch_task.fail( + f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", + failedEvent.failureDetails) ctx.resume() + elif event.HasField("eventRaised"): + # event names are case-insensitive + event_name = event.eventRaised.name.upper() + if not ctx.is_replaying: + self._logger.info(f"Event raised: {event_name}") + task_list = ctx._pending_events.get(event_name, None) + decoded_result: Any | None = None + if task_list: + event_task = task_list.pop(0) + if not ph.is_empty(event.eventRaised.input): + decoded_result = shared.from_json(event.eventRaised.input.value) + event_task.complete(decoded_result) + if not task_list: + del ctx._pending_events[event_name] + ctx.resume() + else: + # buffer the event + event_list = ctx._received_events.get(event_name, None) + if not event_list: + event_list = [] + ctx._received_events[event_name] = event_list + if not ph.is_empty(event.eventRaised.input): + decoded_result = shared.from_json(event.eventRaised.input.value) + event_list.append(_ExternalEvent(event.eventRaised.name, decoded_result)) + if not ctx.is_replaying: + self._logger.info(f"Event '{event_name}' has been buffered as there are no tasks waiting for it.") else: eventType = event.WhichOneof("eventType") raise task.OrchestrationStateError(f"Don't know how to handle event of type '{eventType}'") @@ -509,13 +580,13 @@ def execute(self, orchestration_id: str, name: str, task_id: int, encoded_input: if not fn: raise ActivityNotRegisteredError(f"Activity function named '{name}' was not registered!") - activity_input = json.loads(encoded_input) if encoded_input else None + activity_input = shared.from_json(encoded_input) if encoded_input else None ctx = task.ActivityContext(orchestration_id, task_id) # Execute the activity function activity_output = fn(ctx, activity_input) - encoded_output = json.dumps(activity_output) if activity_output is not None else None + encoded_output = shared.to_json(activity_output) if activity_output is not None else None chars = len(encoded_output) if encoded_output else 0 self._logger.debug( f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output.") diff --git a/examples/README.md b/examples/README.md index ca918a3..ec9088f 100644 --- a/examples/README.md +++ b/examples/README.md @@ -18,7 +18,10 @@ With one of the sidecars running, you can simply execute any of the examples in python3 ./activity_sequence.py ``` +In some cases, the sample may require command-line parameters or user inputs. In these cases, the sample will print out instructions on how to proceed. + ## List of examples - [Activity sequence](./activity_sequence.py): Orchestration that schedules three activity calls in a sequence. -- [Fan-out/fan-in](./fanout_fanin.py): Orchestration that schedules a dynamic number of activity calls in parallel, waits for all of them to complete, and then performs an aggregation on the results. \ No newline at end of file +- [Fan-out/fan-in](./fanout_fanin.py): Orchestration that schedules a dynamic number of activity calls in parallel, waits for all of them to complete, and then performs an aggregation on the results. +- [Human interaction](./human_interaction.py): Orchestration that waits for a human to approve an order before continuing. 
\ No newline at end of file
diff --git a/examples/human_interaction.py b/examples/human_interaction.py
new file mode 100644
index 0000000..7aa4599
--- /dev/null
+++ b/examples/human_interaction.py
@@ -0,0 +1,99 @@
+"""End-to-end sample that demonstrates how to configure an orchestrator
+that waits for an "approval" event before proceeding to the next step. If
+the approval isn't received within a specified timeout, the order that is
+represented by the orchestration is automatically cancelled."""
+
+import threading
+import time
+from collections import namedtuple
+from dataclasses import dataclass
+from datetime import timedelta
+
+from durabletask import client, task, worker
+
+
+@dataclass
+class Order:
+    """Represents a purchase order"""
+    Cost: float
+    Product: str
+    Quantity: int
+
+    def __str__(self):
+        return f'{self.Product} ({self.Quantity})'
+
+
+def send_approval_request(_: task.ActivityContext, order: Order) -> None:
+    """Activity function that sends an approval request to the manager"""
+    time.sleep(5)
+    print(f'*** Sending approval request for order: {order}')
+
+
+def place_order(_: task.ActivityContext, order: Order) -> None:
+    """Activity function that places an order"""
+    print(f'*** Placing order: {order}')
+
+
+def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order):
+    """Orchestrator function that represents a purchase order workflow"""
+    # Orders under $1000 are auto-approved
+    if order.Cost < 1000:
+        return "Auto-approved"
+
+    # Orders of $1000 or more require manager approval
+    yield ctx.call_activity(send_approval_request, input=order)
+
+    # Approvals must be received within 24 hours or they will be canceled.
+    approval_event = ctx.wait_for_external_event("approval_received")
+    timeout_event = ctx.create_timer(timedelta(hours=24))
+    winner = yield task.when_any([approval_event, timeout_event])
+    if winner == timeout_event:
+        return "Cancelled"
+
+    # The order was approved
+    ctx.call_activity(place_order, input=order)
+    approval_details = approval_event.get_result()
+    return f"Approved by '{approval_details.approver}'"
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Order purchasing workflow demo.")
+    parser.add_argument("--cost", type=int, default=2000, help="Cost of the order")
+    parser.add_argument("--approver", type=str, default="Me", help="Approver name")
+    parser.add_argument("--timeout", type=int, default=60, help="Timeout in seconds")
+    args = parser.parse_args()
+
+    # configure and start the worker
+    with worker.TaskHubGrpcWorker() as w:
+        w.add_orchestrator(purchase_order_workflow)
+        w.add_activity(send_approval_request)
+        w.add_activity(place_order)
+        w.start()
+
+        c = client.TaskHubGrpcClient()
+
+        # Start a purchase order workflow using the user input
+        order = Order(args.cost, "MyProduct", 1)
+        instance_id = c.schedule_new_orchestration(purchase_order_workflow, input=order)
+
+        def prompt_for_approval():
+            input("Press [ENTER] to approve the order...\n")
+            approval_event = namedtuple("Approval", ["approver"])(args.approver)
+            c.raise_orchestration_event(instance_id, "approval_received", data=approval_event)
+
+        # Prompt the user for approval on a background thread
+        threading.Thread(target=prompt_for_approval, daemon=True).start()
+
+        # Wait for the orchestration to complete
+        try:
+            state = c.wait_for_orchestration_completion(instance_id, timeout=args.timeout + 2)
+            if not state:
+                print("Workflow not found!")  # not expected
+            elif state.runtime_status ==
client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! Result: {state.serialized_output}') + else: + state.raise_if_failed() # raises an exception + except TimeoutError: + print("*** Orchestration timed out!") diff --git a/pyproject.toml b/pyproject.toml index 39984f5..dece6c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,6 @@ license = {file = "LICENSE"} readme = "README.md" dependencies = [ "grpcio", - "simplejson", ] [project.urls] diff --git a/requirements.txt b/requirements.txt index 5891b90..641cee7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,4 @@ autopep8 grpcio grpcio-tools pytest -pytest-cov -simplejson \ No newline at end of file +pytest-cov \ No newline at end of file diff --git a/tests/test_activity_executor.py b/tests/test_activity_executor.py index 645aee1..fb34a1e 100644 --- a/tests/test_activity_executor.py +++ b/tests/test_activity_executor.py @@ -1,10 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import json from typing import Any, Tuple -import simplejson as json - import durabletask.internal.shared as shared from durabletask import task, worker diff --git a/tests/test_orchestration_e2e.py b/tests/test_orchestration_e2e.py index a9d9f65..f32dd5c 100644 --- a/tests/test_orchestration_e2e.py +++ b/tests/test_orchestration_e2e.py @@ -1,10 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import json import threading +from datetime import timedelta import pytest -import simplejson as json from durabletask import client, task, worker @@ -112,3 +113,59 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): assert state.runtime_status == client.OrchestrationStatus.COMPLETED assert state.failure_details is None assert activity_counter == 30 + + +def test_wait_for_multiple_external_events(): + def orchestrator(ctx: task.OrchestrationContext, _): + a = yield ctx.wait_for_external_event('A') + b = yield ctx.wait_for_external_event('B') + c = yield ctx.wait_for_external_event('C') + return [a, b, c] + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.start() + w.add_orchestrator(orchestrator) + + # Start the orchestration and immediately raise events to it. + task_hub_client = client.TaskHubGrpcClient() + id = task_hub_client.schedule_new_orchestration(orchestrator) + task_hub_client.raise_orchestration_event(id, 'A', data='a') + task_hub_client.raise_orchestration_event(id, 'B', data='b') + task_hub_client.raise_orchestration_event(id, 'C', data='c') + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.serialized_output == json.dumps(['a', 'b', 'c']) + + +@pytest.mark.parametrize("raise_event", [True, False]) +def test_wait_for_external_event_timeout(raise_event: bool): + def orchestrator(ctx: task.OrchestrationContext, _): + approval: task.Task[bool] = ctx.wait_for_external_event('Approval') + timeout = ctx.create_timer(timedelta(seconds=3)) + winner = yield task.when_any([approval, timeout]) + if winner == approval: + return "approved" + else: + return "timed out" + + # Start a worker, which will connect to the sidecar in a background thread + with worker.TaskHubGrpcWorker() as w: + w.start() + w.add_orchestrator(orchestrator) + + # Start the orchestration and immediately raise events to it. 
+ task_hub_client = client.TaskHubGrpcClient() + id = task_hub_client.schedule_new_orchestration(orchestrator) + if raise_event: + task_hub_client.raise_orchestration_event(id, 'Approval') + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + if raise_event: + assert state.serialized_output == json.dumps("approved") + else: + assert state.serialized_output == json.dumps("timed out") diff --git a/tests/test_orchestration_executor.py b/tests/test_orchestration_executor.py index c5cec2d..2933860 100644 --- a/tests/test_orchestration_executor.py +++ b/tests/test_orchestration_executor.py @@ -1,12 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import json import logging from datetime import datetime, timedelta from typing import List -import simplejson as json - import durabletask.internal.helpers as helpers import durabletask.internal.orchestrator_service_pb2 as pb from durabletask import task, worker @@ -220,7 +219,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): complete_action = get_and_validate_single_complete_orchestration_action(actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == 'TaskFailedError' # TODO: Should this be the specific error? - assert complete_action.failureDetails.errorMessage == str(ex) + assert str(ex) in complete_action.failureDetails.errorMessage # Make sure the line of code where the exception was raised is included in the stack trace user_code_statement = "ctx.call_activity(dummy_activity, input=orchestrator_input)" @@ -400,7 +399,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): complete_action = get_and_validate_single_complete_orchestration_action(actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == 'TaskFailedError' # TODO: Should this be the specific error? - assert complete_action.failureDetails.errorMessage == str(ex) + assert str(ex) in complete_action.failureDetails.errorMessage # Make sure the line of code where the exception was raised is included in the stack trace user_code_statement = "ctx.call_sub_orchestrator(suborchestrator)" @@ -463,6 +462,69 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert "call_sub_orchestrator" in complete_action.failureDetails.errorMessage # expected method name +def test_raise_event(): + """Tests that an orchestration can wait for and process an external event sent by a client""" + def orchestrator(ctx: task.OrchestrationContext, _): + result = yield ctx.wait_for_external_event("my_event") + return result + + registry = worker._Registry() + orchestrator_name = registry.add_orchestrator(orchestrator) + + old_events = [] + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID)] + + # Execute the orchestration until it is waiting for an external event. The result + # should be an empty list of actions because the orchestration didn't schedule any work. + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + actions = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(actions) == 0 + + # Now send an external event to the orchestration and execute it again. This time + # the orchestration should complete. 
+ old_events = new_events + new_events = [helpers.new_event_raised_event("my_event", encoded_input="42")] + actions = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + complete_action = get_and_validate_single_complete_orchestration_action(actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED + assert complete_action.result.value == "42" + + +def test_raise_event_buffered(): + """Tests that an orchestration can receive an event that arrives earlier than expected""" + def orchestrator(ctx: task.OrchestrationContext, _): + yield ctx.create_timer(ctx.current_utc_datetime + timedelta(days=1)) + result = yield ctx.wait_for_external_event("my_event") + return result + + registry = worker._Registry() + orchestrator_name = registry.add_orchestrator(orchestrator) + + old_events = [] + new_events = [ + helpers.new_orchestrator_started_event(), + helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID), + helpers.new_event_raised_event("my_event", encoded_input="42")] + + # Execute the orchestration. It should be in a running state waiting for the timer to fire + executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) + actions = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + assert len(actions) == 1 + assert actions[0].HasField("createTimer") + + # Complete the timer task. The orchestration should move to the wait_for_external_event step, which + # should then complete immediately because the event was buffered in the old event history. + timer_due_time = datetime.utcnow() + timedelta(days=1) + old_events = new_events + [helpers.new_timer_created_event(1, timer_due_time)] + new_events = [helpers.new_timer_fired_event(1, timer_due_time)] + actions = executor.execute(TEST_INSTANCE_ID, old_events, new_events) + complete_action = get_and_validate_single_complete_orchestration_action(actions) + assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_COMPLETED + assert complete_action.result.value == "42" + + def test_fan_out(): """Tests that a fan-out pattern correctly schedules N tasks""" def hello(_, name: str): @@ -577,7 +639,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): complete_action = get_and_validate_single_complete_orchestration_action(actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED assert complete_action.failureDetails.errorType == 'TaskFailedError' # TODO: Is this the right error type? - assert complete_action.failureDetails.errorMessage == str(ex) + assert str(ex) in complete_action.failureDetails.errorMessage def test_when_any():