From b4b83d30635908b1fe5866ca80ca386f5d2c231a Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Thu, 9 Oct 2025 19:43:55 -0500 Subject: [PATCH 1/3] - Fix a critical bug in `continue_as_new` where the router was missing on the complete-orchestration action. Linting: - Normalize strings to use double quotes consistently across tests and source files, matching the convention used in the Microsoft durabletask SDK. - Run ruff format. - Add `test_continue_as_new_with_activity_e2e` for better coverage of `continue_as_new` functionality. - Modify `.flake8` with extended exclusions and per-file ignores. - Introduce `tox.ini` for test environment configuration and to streamline linting, typing, and example validation. Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- .flake8 | 11 +- .github/workflows/pr-validation.yml | 3 +- dev-requirements.txt | 9 +- durabletask/__init__.py | 1 - durabletask/client.py | 116 ++++-- durabletask/internal/grpc_interceptor.py | 34 +- durabletask/internal/helpers.py | 115 +++--- durabletask/internal/shared.py | 45 +- durabletask/task.py | 105 +++-- durabletask/worker.py | 258 +++++------- examples/activity_sequence.py | 13 +- examples/fanout_fanin.py | 17 +- examples/human_interaction.py | 9 +- pyproject.toml | 30 +- tests/durabletask/test_activity_executor.py | 11 +- tests/durabletask/test_client.py | 31 +- tests/durabletask/test_concurrency_options.py | 8 +- tests/durabletask/test_orchestration_e2e.py | 84 ++-- .../test_orchestration_executor.py | 384 ++++++++++++------ tests/durabletask/test_orchestration_wait.py | 32 +- .../test_worker_concurrency_loop.py | 43 +- .../test_worker_concurrency_loop_async.py | 42 +- tox.ini | 103 +++++ 23 files changed, 955 insertions(+), 549 deletions(-) create mode 100644 tox.ini diff --git a/.flake8 b/.flake8 index ecc399c..1b01528 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,7 @@ [flake8] -ignore = E501,C901 -exclude = - .git - *_pb2* - __pycache__ \ No newline at end of file +ignore = E203,E501,W503,E701,E704,F821,C901 +extend-exclude = .tox,venv,.venv,build,**/.venv,**/venv,*pb2_grpc.py,*pb2.py +per-file-ignores= + examples/**:F541 setup.py:E121 + tests/**:F541,E712 +max-line-length = 100 \ No newline at end of file diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 63540ac..3c3313a 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -28,8 +28,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install flake8 pytest - pip install -r requirements.txt + pip install -r dev-requirements.txt - name: Lint with flake8 run: | flake8 . 
--count --show-source --statistics --exit-zero diff --git a/dev-requirements.txt b/dev-requirements.txt index 119f072..bccde37 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1 +1,8 @@ -grpcio-tools==1.62.3 # 1.62.X is the latest version before protobuf 1.26.X is used which has breaking changes for Python +# TODO: move to pyproject optional-dependencies +pytest-asyncio>=0.23 +flake8 +tox>=4.0.0 +pytest +pytest-cov +grpcio-tools==1.75.1 +protobuf>=6.31.1 \ No newline at end of file diff --git a/durabletask/__init__.py b/durabletask/__init__.py index a37823c..78ea7ca 100644 --- a/durabletask/__init__.py +++ b/durabletask/__init__.py @@ -3,5 +3,4 @@ """Durable Task SDK for Python""" - PACKAGE_NAME = "durabletask" diff --git a/durabletask/client.py b/durabletask/client.py index 7a72e1a..1fe8688 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -18,12 +18,13 @@ from durabletask import task from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl -TInput = TypeVar('TInput') -TOutput = TypeVar('TOutput') +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") class OrchestrationStatus(Enum): """The status of an orchestration instance.""" + RUNNING = pb.ORCHESTRATION_STATUS_RUNNING COMPLETED = pb.ORCHESTRATION_STATUS_COMPLETED FAILED = pb.ORCHESTRATION_STATUS_FAILED @@ -52,7 +53,8 @@ def raise_if_failed(self): if self.failure_details is not None: raise OrchestrationFailedError( f"Orchestration '{self.instance_id}' failed: {self.failure_details.message}", - self.failure_details) + self.failure_details, + ) class OrchestrationFailedError(Exception): @@ -65,18 +67,23 @@ def failure_details(self): return self._failure_details -def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Optional[OrchestrationState]: +def new_orchestration_state( + instance_id: str, res: pb.GetInstanceResponse +) -> Optional[OrchestrationState]: if not res.exists: return None state = res.orchestrationState failure_details = None - if state.failureDetails.errorMessage != '' or state.failureDetails.errorType != '': + if state.failureDetails.errorMessage != "" or state.failureDetails.errorType != "": failure_details = task.FailureDetails( state.failureDetails.errorMessage, state.failureDetails.errorType, - state.failureDetails.stackTrace.value if not helpers.is_empty(state.failureDetails.stackTrace) else None) + state.failureDetails.stackTrace.value + if not helpers.is_empty(state.failureDetails.stackTrace) + else None, + ) return OrchestrationState( instance_id, @@ -87,19 +94,21 @@ def new_orchestration_state(instance_id: str, res: pb.GetInstanceResponse) -> Op state.input.value if not helpers.is_empty(state.input) else None, state.output.value if not helpers.is_empty(state.output) else None, state.customStatus.value if not helpers.is_empty(state.customStatus) else None, - failure_details) + failure_details, + ) class TaskHubGrpcClient: - - def __init__(self, *, - host_address: Optional[str] = None, - metadata: Optional[list[tuple[str, str]]] = None, - log_handler: Optional[logging.Handler] = None, - log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False, - interceptors: Optional[Sequence[shared.ClientInterceptor]] = None): - + def __init__( + self, + *, + host_address: Optional[str] = None, + metadata: Optional[list[tuple[str, str]]] = None, + log_handler: Optional[logging.Handler] = None, + log_formatter: Optional[logging.Formatter] = None, + secure_channel: bool = False, + interceptors: 
Optional[Sequence[shared.ClientInterceptor]] = None, + ): # If the caller provided metadata, we need to create a new interceptor for it and # add it to the list of interceptors. if interceptors is not None: @@ -112,25 +121,28 @@ def __init__(self, *, interceptors = None channel = shared.get_grpc_channel( - host_address=host_address, - secure_channel=secure_channel, - interceptors=interceptors + host_address=host_address, secure_channel=secure_channel, interceptors=interceptors ) self._stub = stubs.TaskHubSidecarServiceStub(channel) self._logger = shared.get_logger("client", log_handler, log_formatter) - def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *, - input: Optional[TInput] = None, - instance_id: Optional[str] = None, - start_at: Optional[datetime] = None, - reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None) -> str: - + def schedule_new_orchestration( + self, + orchestrator: Union[task.Orchestrator[TInput, TOutput], str], + *, + input: Optional[TInput] = None, + instance_id: Optional[str] = None, + start_at: Optional[datetime] = None, + reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None, + ) -> str: name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator) req = pb.CreateInstanceRequest( name=name, instanceId=instance_id if instance_id else uuid.uuid4().hex, - input=wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None, + input=wrappers_pb2.StringValue(value=shared.to_json(input)) + if input is not None + else None, scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, version=wrappers_pb2.StringValue(value=""), orchestrationIdReusePolicy=reuse_id_policy, @@ -140,19 +152,22 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu res: pb.CreateInstanceResponse = self._stub.StartInstance(req) return res.instanceId - def get_orchestration_state(self, instance_id: str, *, fetch_payloads: bool = True) -> Optional[OrchestrationState]: + def get_orchestration_state( + self, instance_id: str, *, fetch_payloads: bool = True + ) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) res: pb.GetInstanceResponse = self._stub.GetInstance(req) return new_orchestration_state(req.instanceId, res) - def wait_for_orchestration_start(self, instance_id: str, *, - fetch_payloads: bool = False, - timeout: int = 0) -> Optional[OrchestrationState]: + def wait_for_orchestration_start( + self, instance_id: str, *, fetch_payloads: bool = False, timeout: int = 0 + ) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: grpc_timeout = None if timeout == 0 else timeout self._logger.info( - f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start.") + f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start." 
+ ) res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=grpc_timeout) return new_orchestration_state(req.instanceId, res) except grpc.RpcError as rpc_error: @@ -162,22 +177,30 @@ def wait_for_orchestration_start(self, instance_id: str, *, else: raise - def wait_for_orchestration_completion(self, instance_id: str, *, - fetch_payloads: bool = True, - timeout: int = 0) -> Optional[OrchestrationState]: + def wait_for_orchestration_completion( + self, instance_id: str, *, fetch_payloads: bool = True, timeout: int = 0 + ) -> Optional[OrchestrationState]: req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) try: grpc_timeout = None if timeout == 0 else timeout self._logger.info( - f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete.") - res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion(req, timeout=grpc_timeout) + f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete." + ) + res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion( + req, timeout=grpc_timeout + ) state = new_orchestration_state(req.instanceId, res) if not state: return None - if state.runtime_status == OrchestrationStatus.FAILED and state.failure_details is not None: + if ( + state.runtime_status == OrchestrationStatus.FAILED + and state.failure_details is not None + ): details = state.failure_details - self._logger.info(f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}") + self._logger.info( + f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}" + ) elif state.runtime_status == OrchestrationStatus.TERMINATED: self._logger.info(f"Instance '{instance_id}' was terminated.") elif state.runtime_status == OrchestrationStatus.COMPLETED: @@ -191,23 +214,26 @@ def wait_for_orchestration_completion(self, instance_id: str, *, else: raise - def raise_orchestration_event(self, instance_id: str, event_name: str, *, - data: Optional[Any] = None): + def raise_orchestration_event( + self, instance_id: str, event_name: str, *, data: Optional[Any] = None + ): req = pb.RaiseEventRequest( instanceId=instance_id, name=event_name, - input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None) + input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None, + ) self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") self._stub.RaiseEvent(req) - def terminate_orchestration(self, instance_id: str, *, - output: Optional[Any] = None, - recursive: bool = True): + def terminate_orchestration( + self, instance_id: str, *, output: Optional[Any] = None, recursive: bool = True + ): req = pb.TerminateRequest( instanceId=instance_id, output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None, - recursive=recursive) + recursive=recursive, + ) self._logger.info(f"Terminating instance '{instance_id}'.") self._stub.TerminateInstance(req) diff --git a/durabletask/internal/grpc_interceptor.py b/durabletask/internal/grpc_interceptor.py index 69db3c5..f9e8fb5 100644 --- a/durabletask/internal/grpc_interceptor.py +++ b/durabletask/internal/grpc_interceptor.py @@ -7,20 +7,26 @@ class _ClientCallDetails( - namedtuple( - '_ClientCallDetails', - ['method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression']), - grpc.ClientCallDetails): + namedtuple( + "_ClientCallDetails", + ["method", "timeout", "metadata", "credentials", 
"wait_for_ready", "compression"], + ), + grpc.ClientCallDetails, +): """This is an implementation of the ClientCallDetails interface needed for interceptors. This class takes six named values and inherits the ClientCallDetails from grpc package. This class encloses the values that describe a RPC to be invoked. """ + pass -class DefaultClientInterceptorImpl ( - grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): +class DefaultClientInterceptorImpl( + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor, +): """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an interceptor to add additional headers to all calls as needed.""" @@ -29,10 +35,9 @@ def __init__(self, metadata: list[tuple[str, str]]): super().__init__() self._metadata = metadata - def _intercept_call( - self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: + def _intercept_call(self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC - call details.""" + call details.""" if self._metadata is None: return client_call_details @@ -43,8 +48,13 @@ def _intercept_call( metadata.extend(self._metadata) client_call_details = _ClientCallDetails( - client_call_details.method, client_call_details.timeout, metadata, - client_call_details.credentials, client_call_details.wait_for_ready, client_call_details.compression) + client_call_details.method, + client_call_details.timeout, + metadata, + client_call_details.credentials, + client_call_details.wait_for_ready, + client_call_details.compression, + ) return client_call_details diff --git a/durabletask/internal/helpers.py b/durabletask/internal/helpers.py index 48ab14b..2a632bd 100644 --- a/durabletask/internal/helpers.py +++ b/durabletask/internal/helpers.py @@ -12,21 +12,29 @@ # TODO: The new_xxx_event methods are only used by test code and should be moved elsewhere -def new_orchestrator_started_event(timestamp: Optional[datetime] = None) -> pb.HistoryEvent: +def new_orchestrator_started_event( + timestamp: Optional[datetime] = None, +) -> pb.HistoryEvent: ts = timestamp_pb2.Timestamp() if timestamp is not None: ts.FromDatetime(timestamp) - return pb.HistoryEvent(eventId=-1, timestamp=ts, orchestratorStarted=pb.OrchestratorStartedEvent()) + return pb.HistoryEvent( + eventId=-1, timestamp=ts, orchestratorStarted=pb.OrchestratorStartedEvent() + ) -def new_execution_started_event(name: str, instance_id: str, encoded_input: Optional[str] = None) -> pb.HistoryEvent: +def new_execution_started_event( + name: str, instance_id: str, encoded_input: Optional[str] = None +) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), executionStarted=pb.ExecutionStartedEvent( name=name, input=get_string_value(encoded_input), - orchestrationInstance=pb.OrchestrationInstance(instanceId=instance_id))) + orchestrationInstance=pb.OrchestrationInstance(instanceId=instance_id), + ), + ) def new_timer_created_event(timer_id: int, fire_at: datetime) -> pb.HistoryEvent: @@ -35,7 +43,7 @@ def new_timer_created_event(timer_id: int, fire_at: datetime) -> pb.HistoryEvent return pb.HistoryEvent( eventId=timer_id, timestamp=timestamp_pb2.Timestamp(), - 
timerCreated=pb.TimerCreatedEvent(fireAt=ts) + timerCreated=pb.TimerCreatedEvent(fireAt=ts), ) @@ -45,23 +53,29 @@ def new_timer_fired_event(timer_id: int, fire_at: datetime) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), - timerFired=pb.TimerFiredEvent(fireAt=ts, timerId=timer_id) + timerFired=pb.TimerFiredEvent(fireAt=ts, timerId=timer_id), ) -def new_task_scheduled_event(event_id: int, name: str, encoded_input: Optional[str] = None) -> pb.HistoryEvent: +def new_task_scheduled_event( + event_id: int, name: str, encoded_input: Optional[str] = None +) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=event_id, timestamp=timestamp_pb2.Timestamp(), - taskScheduled=pb.TaskScheduledEvent(name=name, input=get_string_value(encoded_input)) + taskScheduled=pb.TaskScheduledEvent(name=name, input=get_string_value(encoded_input)), ) -def new_task_completed_event(event_id: int, encoded_output: Optional[str] = None) -> pb.HistoryEvent: +def new_task_completed_event( + event_id: int, encoded_output: Optional[str] = None +) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), - taskCompleted=pb.TaskCompletedEvent(taskScheduledId=event_id, result=get_string_value(encoded_output)) + taskCompleted=pb.TaskCompletedEvent( + taskScheduledId=event_id, result=get_string_value(encoded_output) + ), ) @@ -69,32 +83,33 @@ def new_task_failed_event(event_id: int, ex: Exception) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), - taskFailed=pb.TaskFailedEvent(taskScheduledId=event_id, failureDetails=new_failure_details(ex)) + taskFailed=pb.TaskFailedEvent( + taskScheduledId=event_id, failureDetails=new_failure_details(ex) + ), ) def new_sub_orchestration_created_event( - event_id: int, - name: str, - instance_id: str, - encoded_input: Optional[str] = None) -> pb.HistoryEvent: + event_id: int, name: str, instance_id: str, encoded_input: Optional[str] = None +) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=event_id, timestamp=timestamp_pb2.Timestamp(), subOrchestrationInstanceCreated=pb.SubOrchestrationInstanceCreatedEvent( - name=name, - input=get_string_value(encoded_input), - instanceId=instance_id) + name=name, input=get_string_value(encoded_input), instanceId=instance_id + ), ) -def new_sub_orchestration_completed_event(event_id: int, encoded_output: Optional[str] = None) -> pb.HistoryEvent: +def new_sub_orchestration_completed_event( + event_id: int, encoded_output: Optional[str] = None +) -> pb.HistoryEvent: return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), subOrchestrationInstanceCompleted=pb.SubOrchestrationInstanceCompletedEvent( - result=get_string_value(encoded_output), - taskScheduledId=event_id) + result=get_string_value(encoded_output), taskScheduledId=event_id + ), ) @@ -103,8 +118,8 @@ def new_sub_orchestration_failed_event(event_id: int, ex: Exception) -> pb.Histo eventId=-1, timestamp=timestamp_pb2.Timestamp(), subOrchestrationInstanceFailed=pb.SubOrchestrationInstanceFailedEvent( - failureDetails=new_failure_details(ex), - taskScheduledId=event_id) + failureDetails=new_failure_details(ex), taskScheduledId=event_id + ), ) @@ -112,7 +127,7 @@ def new_failure_details(ex: Exception) -> pb.TaskFailureDetails: return pb.TaskFailureDetails( errorType=type(ex).__name__, errorMessage=str(ex), - stackTrace=wrappers_pb2.StringValue(value=''.join(traceback.format_tb(ex.__traceback__))) + 
stackTrace=wrappers_pb2.StringValue(value="".join(traceback.format_tb(ex.__traceback__))), ) @@ -120,7 +135,7 @@ def new_event_raised_event(name: str, encoded_input: Optional[str] = None) -> pb return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), - eventRaised=pb.EventRaisedEvent(name=name, input=get_string_value(encoded_input)) + eventRaised=pb.EventRaisedEvent(name=name, input=get_string_value(encoded_input)), ) @@ -128,7 +143,7 @@ def new_suspend_event() -> pb.HistoryEvent: return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), - executionSuspended=pb.ExecutionSuspendedEvent() + executionSuspended=pb.ExecutionSuspendedEvent(), ) @@ -136,7 +151,7 @@ def new_resume_event() -> pb.HistoryEvent: return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), - executionResumed=pb.ExecutionResumedEvent() + executionResumed=pb.ExecutionResumedEvent(), ) @@ -144,9 +159,7 @@ def new_terminated_event(*, encoded_output: Optional[str] = None) -> pb.HistoryE return pb.HistoryEvent( eventId=-1, timestamp=timestamp_pb2.Timestamp(), - executionTerminated=pb.ExecutionTerminatedEvent( - input=get_string_value(encoded_output) - ) + executionTerminated=pb.ExecutionTerminatedEvent(input=get_string_value(encoded_output)), ) @@ -158,18 +171,25 @@ def get_string_value(val: Optional[str]) -> Optional[wrappers_pb2.StringValue]: def new_complete_orchestration_action( - id: int, - status: pb.OrchestrationStatus, - result: Optional[str] = None, - failure_details: Optional[pb.TaskFailureDetails] = None, - carryover_events: Optional[list[pb.HistoryEvent]] = None) -> pb.OrchestratorAction: + id: int, + status: pb.OrchestrationStatus, + result: Optional[str] = None, + failure_details: Optional[pb.TaskFailureDetails] = None, + carryover_events: Optional[list[pb.HistoryEvent]] = None, + router: Optional[pb.TaskRouter] = None, +) -> pb.OrchestratorAction: completeOrchestrationAction = pb.CompleteOrchestrationAction( orchestrationStatus=status, result=get_string_value(result), failureDetails=failure_details, - carryoverEvents=carryover_events) + carryoverEvents=carryover_events, + ) - return pb.OrchestratorAction(id=id, completeOrchestration=completeOrchestrationAction) + return pb.OrchestratorAction( + id=id, + completeOrchestration=completeOrchestrationAction, + router=router, + ) def new_create_timer_action(id: int, fire_at: datetime) -> pb.OrchestratorAction: @@ -178,7 +198,9 @@ def new_create_timer_action(id: int, fire_at: datetime) -> pb.OrchestratorAction return pb.OrchestratorAction(id=id, createTimer=pb.CreateTimerAction(fireAt=timestamp)) -def new_schedule_task_action(id: int, name: str, encoded_input: Optional[str], router: Optional[pb.TaskRouter] = None) -> pb.OrchestratorAction: +def new_schedule_task_action( + id: int, name: str, encoded_input: Optional[str], router: Optional[pb.TaskRouter] = None +) -> pb.OrchestratorAction: return pb.OrchestratorAction( id=id, scheduleTask=pb.ScheduleTaskAction( @@ -197,11 +219,12 @@ def new_timestamp(dt: datetime) -> timestamp_pb2.Timestamp: def new_create_sub_orchestration_action( - id: int, - name: str, - instance_id: Optional[str], - encoded_input: Optional[str], - router: Optional[pb.TaskRouter] = None) -> pb.OrchestratorAction: + id: int, + name: str, + instance_id: Optional[str], + encoded_input: Optional[str], + router: Optional[pb.TaskRouter] = None, +) -> pb.OrchestratorAction: return pb.OrchestratorAction( id=id, createSubOrchestration=pb.CreateSubOrchestrationAction( @@ -215,13 +238,13 @@ def 
new_create_sub_orchestration_action( def is_empty(v: wrappers_pb2.StringValue): - return v is None or v.value == '' + return v is None or v.value == "" def get_orchestration_status_str(status: pb.OrchestrationStatus): try: const_name = pb.OrchestrationStatus.Name(status) - if const_name.startswith('ORCHESTRATION_STATUS_'): - return const_name[len('ORCHESTRATION_STATUS_'):] + if const_name.startswith("ORCHESTRATION_STATUS_"): + return const_name[len("ORCHESTRATION_STATUS_") :] except Exception: return "UNKNOWN" diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index c0fbe74..2f63f10 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -13,7 +13,7 @@ grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, grpc.StreamUnaryClientInterceptor, - grpc.StreamStreamClientInterceptor + grpc.StreamStreamClientInterceptor, ] # Field name used to indicate that an object was automatically serialized @@ -25,13 +25,36 @@ def get_default_host_address() -> str: + """Resolve the default Durable Task sidecar address. + + Honors environment variables if present; otherwise defaults to localhost:4001. + + Supported environment variables (checked in order): + - DAPR_GRPC_ENDPOINT (e.g., "localhost:4001", "grpcs://host:443") + - DAPR_GRPC_HOST (or DAPR_RUNTIME_HOST) and DAPR_GRPC_PORT + """ + import os + + # Full endpoint overrides + endpoint = os.environ.get("DAPR_GRPC_ENDPOINT") + if endpoint: + return endpoint + + # Host/port split overrides + host = os.environ.get("DAPR_GRPC_HOST") or os.environ.get("DAPR_RUNTIME_HOST") + port = os.environ.get("DAPR_GRPC_PORT") + if host and port: + return f"{host}:{port}" + + # Default to durabletask-go default port return "localhost:4001" def get_grpc_channel( - host_address: Optional[str], - secure_channel: bool = False, - interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc.Channel: + host_address: Optional[str], + secure_channel: bool = False, + interceptors: Optional[Sequence[ClientInterceptor]] = None, +) -> grpc.Channel: if host_address is None: host_address = get_default_host_address() @@ -39,14 +62,14 @@ def get_grpc_channel( if host_address.lower().startswith(protocol): secure_channel = True # remove the protocol from the host name - host_address = host_address[len(protocol):] + host_address = host_address[len(protocol) :] break for protocol in INSECURE_PROTOCOLS: if host_address.lower().startswith(protocol): secure_channel = False # remove the protocol from the host name - host_address = host_address[len(protocol):] + host_address = host_address[len(protocol) :] break # Create the base channel @@ -62,9 +85,10 @@ def get_grpc_channel( def get_logger( - name_suffix: str, - log_handler: Optional[logging.Handler] = None, - log_formatter: Optional[logging.Formatter] = None) -> logging.Logger: + name_suffix: str, + log_handler: Optional[logging.Handler] = None, + log_formatter: Optional[logging.Formatter] = None, +) -> logging.Logger: logger = logging.Logger(f"durabletask-{name_suffix}") # Add a default log handler if none is provided @@ -77,7 +101,8 @@ def get_logger( if log_formatter is None: log_formatter = logging.Formatter( fmt="%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s", - datefmt='%Y-%m-%d %H:%M:%S') + datefmt="%Y-%m-%d %H:%M:%S", + ) log_handler.setFormatter(log_formatter) return logger diff --git a/durabletask/task.py b/durabletask/task.py index 29af2c5..4bcd060 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -12,13 +12,12 @@ import 
durabletask.internal.helpers as pbh import durabletask.internal.orchestrator_service_pb2 as pb -T = TypeVar('T') -TInput = TypeVar('TInput') -TOutput = TypeVar('TOutput') +T = TypeVar("T") +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") class OrchestrationContext(ABC): - @property @abstractmethod def instance_id(self) -> str: @@ -98,10 +97,14 @@ def create_timer(self, fire_at: Union[datetime, timedelta]) -> Task: pass @abstractmethod - def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *, - input: Optional[TInput] = None, - retry_policy: Optional[RetryPolicy] = None, - app_id: Optional[str] = None) -> Task[TOutput]: + def call_activity( + self, + activity: Union[Activity[TInput, TOutput], str], + *, + input: Optional[TInput] = None, + retry_policy: Optional[RetryPolicy] = None, + app_id: Optional[str] = None, + ) -> Task[TOutput]: """Schedule an activity for execution. Parameters @@ -123,11 +126,15 @@ def call_activity(self, activity: Union[Activity[TInput, TOutput], str], *, pass @abstractmethod - def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *, - input: Optional[TInput] = None, - instance_id: Optional[str] = None, - retry_policy: Optional[RetryPolicy] = None, - app_id: Optional[str] = None) -> Task[TOutput]: + def call_sub_orchestrator( + self, + orchestrator: Orchestrator[TInput, TOutput], + *, + input: Optional[TInput] = None, + instance_id: Optional[str] = None, + retry_policy: Optional[RetryPolicy] = None, + app_id: Optional[str] = None, + ) -> Task[TOutput]: """Schedule sub-orchestrator function for execution. Parameters @@ -210,7 +217,8 @@ def __init__(self, message: str, details: pb.TaskFailureDetails): self._details = FailureDetails( details.errorMessage, details.errorType, - details.stackTrace.value if not pbh.is_empty(details.stackTrace) else None) + details.stackTrace.value if not pbh.is_empty(details.stackTrace) else None, + ) @property def details(self) -> FailureDetails: @@ -227,6 +235,7 @@ class OrchestrationStateError(Exception): class Task(ABC, Generic[T]): """Abstract base class for asynchronous tasks in a durable orchestration.""" + _result: T _exception: Optional[TaskFailedError] _parent: Optional[CompositeTask[T]] @@ -250,7 +259,7 @@ def is_failed(self) -> bool: def get_result(self) -> T: """Returns the result of the task.""" if not self._is_complete: - raise ValueError('The task has not completed.') + raise ValueError("The task has not completed.") elif self._exception is not None: raise self._exception return self._result @@ -258,12 +267,13 @@ def get_result(self) -> T: def get_exception(self) -> TaskFailedError: """Returns the exception that caused the task to fail.""" if self._exception is None: - raise ValueError('The task has not failed.') + raise ValueError("The task has not failed.") return self._exception class CompositeTask(Task[T]): """A task that is composed of other tasks.""" + _tasks: list[Task] def __init__(self, tasks: list[Task]): @@ -283,6 +293,7 @@ def get_tasks(self) -> list[Task]: def on_child_completed(self, task: Task[T]): pass + class WhenAllTask(CompositeTask[list[T]]): """A task that completes when all of its child tasks complete.""" @@ -298,7 +309,7 @@ def pending_tasks(self) -> int: def on_child_completed(self, task: Task[T]): if self.is_complete: - raise ValueError('The task has already completed.') + raise ValueError("The task has already completed.") self._completed_tasks += 1 if task.is_failed and self._exception is None: self._exception = task.get_exception() @@ -313,14 
+324,13 @@ def get_completed_tasks(self) -> int: class CompletableTask(Task[T]): - def __init__(self): super().__init__() self._retryable_parent = None def complete(self, result: T): if self._is_complete: - raise ValueError('The task has already completed.') + raise ValueError("The task has already completed.") self._result = result self._is_complete = True if self._parent is not None: @@ -328,7 +338,7 @@ def complete(self, result: T): def fail(self, message: str, details: pb.TaskFailureDetails): if self._is_complete: - raise ValueError('The task has already completed.') + raise ValueError("The task has already completed.") self._exception = TaskFailedError(message, details) self._is_complete = True if self._parent is not None: @@ -338,8 +348,13 @@ def fail(self, message: str, details: pb.TaskFailureDetails): class RetryableTask(CompletableTask[T]): """A task that can be retried according to a retry policy.""" - def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction, - start_time: datetime, is_sub_orch: bool) -> None: + def __init__( + self, + retry_policy: RetryPolicy, + action: pb.OrchestratorAction, + start_time: datetime, + is_sub_orch: bool, + ) -> None: super().__init__() self._action = action self._retry_policy = retry_policy @@ -355,7 +370,10 @@ def compute_next_delay(self) -> Optional[timedelta]: return None retry_expiration: datetime = datetime.max - if self._retry_policy.retry_timeout is not None and self._retry_policy.retry_timeout != datetime.max: + if ( + self._retry_policy.retry_timeout is not None + and self._retry_policy.retry_timeout != datetime.max + ): retry_expiration = self._start_time + self._retry_policy.retry_timeout if self._retry_policy.backoff_coefficient is None: @@ -364,17 +382,21 @@ def compute_next_delay(self) -> Optional[timedelta]: backoff_coefficient = self._retry_policy.backoff_coefficient if datetime.utcnow() < retry_expiration: - next_delay_f = math.pow(backoff_coefficient, self._attempt_count - 1) * self._retry_policy.first_retry_interval.total_seconds() + next_delay_f = ( + math.pow(backoff_coefficient, self._attempt_count - 1) + * self._retry_policy.first_retry_interval.total_seconds() + ) if self._retry_policy.max_retry_interval is not None: - next_delay_f = min(next_delay_f, self._retry_policy.max_retry_interval.total_seconds()) + next_delay_f = min( + next_delay_f, self._retry_policy.max_retry_interval.total_seconds() + ) return timedelta(seconds=next_delay_f) return None class TimerTask(CompletableTask[T]): - def __init__(self) -> None: super().__init__() @@ -448,12 +470,15 @@ def task_id(self) -> int: class RetryPolicy: """Represents the retry policy for an orchestration or activity function.""" - def __init__(self, *, - first_retry_interval: timedelta, - max_number_of_attempts: int, - backoff_coefficient: Optional[float] = 1.0, - max_retry_interval: Optional[timedelta] = None, - retry_timeout: Optional[timedelta] = None): + def __init__( + self, + *, + first_retry_interval: timedelta, + max_number_of_attempts: int, + backoff_coefficient: Optional[float] = 1.0, + max_retry_interval: Optional[timedelta] = None, + retry_timeout: Optional[timedelta] = None, + ): """Creates a new RetryPolicy instance. 
Parameters @@ -471,15 +496,15 @@ def __init__(self, *, """ # validate inputs if first_retry_interval < timedelta(seconds=0): - raise ValueError('first_retry_interval must be >= 0') + raise ValueError("first_retry_interval must be >= 0") if max_number_of_attempts < 1: - raise ValueError('max_number_of_attempts must be >= 1') + raise ValueError("max_number_of_attempts must be >= 1") if backoff_coefficient is not None and backoff_coefficient < 1: - raise ValueError('backoff_coefficient must be >= 1') + raise ValueError("backoff_coefficient must be >= 1") if max_retry_interval is not None and max_retry_interval < timedelta(seconds=0): - raise ValueError('max_retry_interval must be >= 0') + raise ValueError("max_retry_interval must be >= 0") if retry_timeout is not None and retry_timeout < timedelta(seconds=0): - raise ValueError('retry_timeout must be >= 0') + raise ValueError("retry_timeout must be >= 0") self._first_retry_interval = first_retry_interval self._max_number_of_attempts = max_number_of_attempts @@ -516,7 +541,9 @@ def retry_timeout(self) -> Optional[timedelta]: def get_name(fn: Callable) -> str: """Returns the name of the provided function""" name = fn.__name__ - if name == '': - raise ValueError('Cannot infer a name from a lambda function. Please provide a name explicitly.') + if name == "": + raise ValueError( + "Cannot infer a name from a lambda function. Please provide a name explicitly." + ) return name diff --git a/durabletask/worker.py b/durabletask/worker.py index 7a04649..2d057e1 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -34,10 +34,10 @@ class ConcurrencyOptions: """ def __init__( - self, - maximum_concurrent_activity_work_items: Optional[int] = None, - maximum_concurrent_orchestration_work_items: Optional[int] = None, - maximum_thread_pool_workers: Optional[int] = None, + self, + maximum_concurrent_activity_work_items: Optional[int] = None, + maximum_concurrent_orchestration_work_items: Optional[int] = None, + maximum_thread_pool_workers: Optional[int] = None, ): """Initialize concurrency options. 
@@ -214,20 +214,18 @@ class TaskHubGrpcWorker: _interceptors: Optional[list[shared.ClientInterceptor]] = None def __init__( - self, - *, - host_address: Optional[str] = None, - metadata: Optional[list[tuple[str, str]]] = None, - log_handler=None, - log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False, - interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, - concurrency_options: Optional[ConcurrencyOptions] = None, + self, + *, + host_address: Optional[str] = None, + metadata: Optional[list[tuple[str, str]]] = None, + log_handler=None, + log_formatter: Optional[logging.Formatter] = None, + secure_channel: bool = False, + interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, + concurrency_options: Optional[ConcurrencyOptions] = None, ): self._registry = _Registry() - self._host_address = ( - host_address if host_address else shared.get_default_host_address() - ) + self._host_address = host_address if host_address else shared.get_default_host_address() self._logger = shared.get_logger("worker", log_handler, log_formatter) self._shutdown = Event() self._is_running = False @@ -235,9 +233,7 @@ def __init__( # Use provided concurrency options or create default ones self._concurrency_options = ( - concurrency_options - if concurrency_options is not None - else ConcurrencyOptions() + concurrency_options if concurrency_options is not None else ConcurrencyOptions() ) # Determine the interceptors to use @@ -266,17 +262,13 @@ def __exit__(self, type, value, traceback): def add_orchestrator(self, fn: task.Orchestrator) -> str: """Registers an orchestrator function with the worker.""" if self._is_running: - raise RuntimeError( - "Orchestrators cannot be added while the worker is running." - ) + raise RuntimeError("Orchestrators cannot be added while the worker is running.") return self._registry.add_orchestrator(fn) def add_activity(self, fn: task.Activity) -> str: """Registers an activity function with the worker.""" if self._is_running: - raise RuntimeError( - "Activities cannot be added while the worker is running." 
- ) + raise RuntimeError("Activities cannot be added while the worker is running.") return self._registry.add_activity(fn) def start(self): @@ -413,9 +405,7 @@ def stream_reader(): loop = asyncio.get_running_loop() while not self._shutdown.is_set(): try: - work_item = await loop.run_in_executor( - None, work_item_queue.get - ) + work_item = await loop.run_in_executor(None, work_item_queue.get) if isinstance(work_item, Exception): raise work_item request_type = work_item.WhichOneof("request") @@ -437,9 +427,7 @@ def stream_reader(): elif work_item.HasField("healthPing"): pass else: - self._logger.warning( - f"Unexpected work item type: {request_type}" - ) + self._logger.warning(f"Unexpected work item type: {request_type}") except Exception as e: self._logger.warning(f"Error in work item stream: {e}") raise e @@ -457,7 +445,10 @@ def stream_reader(): break elif error_code == grpc.StatusCode.UNAVAILABLE: # Check if this is a connection timeout scenario - if "Timeout occurred" in error_details or "Failed to connect to remote host" in error_details: + if ( + "Timeout occurred" in error_details + or "Failed to connect to remote host" in error_details + ): self._logger.warning( f"Connection timeout to {self._host_address}: {error_details} - will retry with fresh connection" ) @@ -499,10 +490,10 @@ def stop(self): self._is_running = False def _execute_orchestrator( - self, - req: pb.OrchestratorRequest, - stub: stubs.TaskHubSidecarServiceStub, - completionToken, + self, + req: pb.OrchestratorRequest, + stub: stubs.TaskHubSidecarServiceStub, + completionToken, ): try: executor = _OrchestrationExecutor(self._registry, self._logger) @@ -537,17 +528,15 @@ def _execute_orchestrator( ) def _execute_activity( - self, - req: pb.ActivityRequest, - stub: stubs.TaskHubSidecarServiceStub, - completionToken, + self, + req: pb.ActivityRequest, + stub: stubs.TaskHubSidecarServiceStub, + completionToken, ): instance_id = req.orchestrationInstance.instanceId try: executor = _ActivityExecutor(self._registry, self._logger) - result = executor.execute( - instance_id, req.name, req.taskId, req.input.value - ) + result = executor.execute(instance_id, req.name, req.taskId, req.input.value) res = pb.ActivityResponse( instanceId=instance_id, taskId=req.taskId, @@ -626,10 +615,10 @@ def resume(self): self._previous_task = next_task def set_complete( - self, - result: Any, - status: pb.OrchestrationStatus, - is_result_encoded: bool = False, + self, + result: Any, + status: pb.OrchestrationStatus, + is_result_encoded: bool = False, ): if self._is_complete: return @@ -643,7 +632,10 @@ def set_complete( if result is not None: result_json = result if is_result_encoded else shared.to_json(result) action = ph.new_complete_orchestration_action( - self.next_sequence_number(), status, result_json + self.next_sequence_number(), + status, + result_json, + router=pb.TaskRouter(sourceAppID=self._app_id) if self._app_id else None, ) self._pending_actions[action.id] = action @@ -660,6 +652,7 @@ def set_failed(self, ex: Exception): pb.ORCHESTRATION_STATUS_FAILED, None, ph.new_failure_details(ex), + router=pb.TaskRouter(sourceAppID=self._app_id) if self._app_id else None, ) self._pending_actions[action.id] = action @@ -683,20 +676,17 @@ def get_actions(self) -> list[pb.OrchestratorAction]: # replayed when the new instance starts. 
for event_name, values in self._received_events.items(): for event_value in values: - encoded_value = ( - shared.to_json(event_value) if event_value else None - ) + encoded_value = shared.to_json(event_value) if event_value else None carryover_events.append( ph.new_event_raised_event(event_name, encoded_value) ) action = ph.new_complete_orchestration_action( self.next_sequence_number(), pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW, - result=shared.to_json(self._new_input) - if self._new_input is not None - else None, + result=shared.to_json(self._new_input) if self._new_input is not None else None, failure_details=None, carryover_events=carryover_events, + router=pb.TaskRouter(sourceAppID=self._app_id) if self._app_id else None, ) return [action] else: @@ -735,9 +725,9 @@ def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: return self.create_timer_internal(fire_at) def create_timer_internal( - self, - fire_at: Union[datetime, timedelta], - retryable_task: Optional[task.RetryableTask] = None, + self, + fire_at: Union[datetime, timedelta], + retryable_task: Optional[task.RetryableTask] = None, ) -> task.Task: id = self.next_sequence_number() if isinstance(fire_at, timedelta): @@ -752,12 +742,12 @@ def create_timer_internal( return timer_task def call_activity( - self, - activity: Union[task.Activity[TInput, TOutput], str], - *, - input: Optional[TInput] = None, - retry_policy: Optional[task.RetryPolicy] = None, - app_id: Optional[str] = None, + self, + activity: Union[task.Activity[TInput, TOutput], str], + *, + input: Optional[TInput] = None, + retry_policy: Optional[task.RetryPolicy] = None, + app_id: Optional[str] = None, ) -> task.Task[TOutput]: id = self.next_sequence_number() @@ -767,13 +757,13 @@ def call_activity( return self._pending_tasks.get(id, task.CompletableTask()) def call_sub_orchestrator( - self, - orchestrator: Union[task.Orchestrator[TInput, TOutput], str], - *, - input: Optional[TInput] = None, - instance_id: Optional[str] = None, - retry_policy: Optional[task.RetryPolicy] = None, - app_id: Optional[str] = None, + self, + orchestrator: Union[task.Orchestrator[TInput, TOutput], str], + *, + input: Optional[TInput] = None, + instance_id: Optional[str] = None, + retry_policy: Optional[task.RetryPolicy] = None, + app_id: Optional[str] = None, ) -> task.Task[TOutput]: id = self.next_sequence_number() if isinstance(orchestrator, str): @@ -792,16 +782,16 @@ def call_sub_orchestrator( return self._pending_tasks.get(id, task.CompletableTask()) def call_activity_function_helper( - self, - id: Optional[int], - activity_function: Union[task.Activity[TInput, TOutput], str], - *, - input: Optional[TInput] = None, - retry_policy: Optional[task.RetryPolicy] = None, - is_sub_orch: bool = False, - instance_id: Optional[str] = None, - fn_task: Optional[task.CompletableTask[TOutput]] = None, - app_id: Optional[str] = None, + self, + id: Optional[int], + activity_function: Union[task.Activity[TInput, TOutput], str], + *, + input: Optional[TInput] = None, + retry_policy: Optional[task.RetryPolicy] = None, + is_sub_orch: bool = False, + instance_id: Optional[str] = None, + fn_task: Optional[task.CompletableTask[TOutput]] = None, + app_id: Optional[str] = None, ): if id is None: id = self.next_sequence_number() @@ -880,13 +870,11 @@ class ExecutionResults: actions: list[pb.OrchestratorAction] encoded_custom_status: Optional[str] - - def __init__( - self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str] - ): + def __init__(self, actions: 
list[pb.OrchestratorAction], encoded_custom_status: Optional[str]): self.actions = actions self.encoded_custom_status = encoded_custom_status + class _OrchestrationExecutor: _generator: Optional[task.Orchestrator] = None @@ -897,10 +885,10 @@ def __init__(self, registry: _Registry, logger: logging.Logger): self._suspended_events: list[pb.HistoryEvent] = [] def execute( - self, - instance_id: str, - old_events: Sequence[pb.HistoryEvent], - new_events: Sequence[pb.HistoryEvent], + self, + instance_id: str, + old_events: Sequence[pb.HistoryEvent], + new_events: Sequence[pb.HistoryEvent], ) -> ExecutionResults: if not new_events: raise task.OrchestrationStateError( @@ -938,28 +926,22 @@ def execute( f"{instance_id}: Orchestrator yielded with {task_count} task(s) and {event_count} event(s) outstanding." ) elif ( - ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW + ctx._completion_status + and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW ): - completion_status_str = ph.get_orchestration_status_str( - ctx._completion_status - ) + completion_status_str = ph.get_orchestration_status_str(ctx._completion_status) self._logger.info( f"{instance_id}: Orchestration completed with status: {completion_status_str}" ) actions = ctx.get_actions() if self._logger.level <= logging.DEBUG: - self._logger.debug( f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}" ) - return ExecutionResults( - actions=actions, encoded_custom_status=ctx._encoded_custom_status - ) + return ExecutionResults(actions=actions, encoded_custom_status=ctx._encoded_custom_status) - def process_event( - self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent - ) -> None: + def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent) -> None: if self._is_suspended and _is_suspendable(event): # We are suspended, so we need to buffer this event until we are resumed self._suspended_events.append(event) @@ -985,13 +967,12 @@ def process_event( # deserialize the input, if any input = None if ( - event.executionStarted.input is not None and event.executionStarted.input.value != "" + event.executionStarted.input is not None + and event.executionStarted.input.value != "" ): input = shared.from_json(event.executionStarted.input.value) - result = fn( - ctx, input - ) # this does not execute the generator, only creates it + result = fn(ctx, input) # this does not execute the generator, only creates it if isinstance(result, GeneratorType): # Start the orchestrator's generator function ctx.run(result) @@ -1004,14 +985,10 @@ def process_event( timer_id = event.eventId action = ctx._pending_actions.pop(timer_id, None) if not action: - raise _get_non_determinism_error( - timer_id, task.get_name(ctx.create_timer) - ) + raise _get_non_determinism_error(timer_id, task.get_name(ctx.create_timer)) elif not action.HasField("createTimer"): expected_method_name = task.get_name(ctx.create_timer) - raise _get_wrong_action_type_error( - timer_id, expected_method_name, action - ) + raise _get_wrong_action_type_error(timer_id, expected_method_name, action) elif event.HasField("timerFired"): timer_id = event.timerFired.timerId timer_task = ctx._pending_tasks.pop(timer_id, None) @@ -1056,14 +1033,10 @@ def process_event( action = ctx._pending_actions.pop(task_id, None) activity_task = ctx._pending_tasks.get(task_id, None) if not action: - raise _get_non_determinism_error( - task_id, task.get_name(ctx.call_activity) - ) + raise 
_get_non_determinism_error(task_id, task.get_name(ctx.call_activity)) elif not action.HasField("scheduleTask"): expected_method_name = task.get_name(ctx.call_activity) - raise _get_wrong_action_type_error( - task_id, expected_method_name, action - ) + raise _get_wrong_action_type_error(task_id, expected_method_name, action) elif action.scheduleTask.name != event.taskScheduled.name: raise _get_wrong_action_name_error( task_id, @@ -1129,11 +1102,9 @@ def process_event( ) elif not action.HasField("createSubOrchestration"): expected_method_name = task.get_name(ctx.call_sub_orchestrator) - raise _get_wrong_action_type_error( - task_id, expected_method_name, action - ) + raise _get_wrong_action_type_error(task_id, expected_method_name, action) elif ( - action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name + action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name ): raise _get_wrong_action_name_error( task_id, @@ -1153,9 +1124,7 @@ def process_event( return result = None if not ph.is_empty(event.subOrchestrationInstanceCompleted.result): - result = shared.from_json( - event.subOrchestrationInstanceCompleted.result.value - ) + result = shared.from_json(event.subOrchestrationInstanceCompleted.result.value) sub_orch_task.complete(result) ctx.resume() elif event.HasField("subOrchestrationInstanceFailed"): @@ -1257,16 +1226,14 @@ def __init__(self, registry: _Registry, logger: logging.Logger): self._logger = logger def execute( - self, - orchestration_id: str, - name: str, - task_id: int, - encoded_input: Optional[str], + self, + orchestration_id: str, + name: str, + task_id: int, + encoded_input: Optional[str], ) -> Optional[str]: """Executes an activity function and returns the serialized result, if any.""" - self._logger.debug( - f"{orchestration_id}/{task_id}: Executing activity '{name}'..." - ) + self._logger.debug(f"{orchestration_id}/{task_id}: Executing activity '{name}'...") fn = self._registry.get_activity(name) if not fn: raise ActivityNotRegisteredError( @@ -1279,9 +1246,7 @@ def execute( # Execute the activity function activity_output = fn(ctx, activity_input) - encoded_output = ( - shared.to_json(activity_output) if activity_output is not None else None - ) + encoded_output = shared.to_json(activity_output) if activity_output is not None else None chars = len(encoded_output) if encoded_output else 0 self._logger.debug( f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output." @@ -1289,9 +1254,7 @@ def execute( return encoded_output -def _get_non_determinism_error( - task_id: int, action_name: str -) -> task.NonDeterminismError: +def _get_non_determinism_error(task_id: int, action_name: str) -> task.NonDeterminismError: return task.NonDeterminismError( f"A previous execution called {action_name} with ID={task_id}, but the current " f"execution doesn't have this action with this ID. 
This problem occurs when either " @@ -1301,7 +1264,7 @@ def _get_non_determinism_error( def _get_wrong_action_type_error( - task_id: int, expected_method_name: str, action: pb.OrchestratorAction + task_id: int, expected_method_name: str, action: pb.OrchestratorAction ) -> task.NonDeterminismError: unexpected_method_name = _get_method_name_for_action(action) return task.NonDeterminismError( @@ -1314,7 +1277,7 @@ def _get_wrong_action_type_error( def _get_wrong_action_name_error( - task_id: int, method_name: str, expected_task_name: str, actual_task_name: str + task_id: int, method_name: str, expected_task_name: str, actual_task_name: str ) -> task.NonDeterminismError: return task.NonDeterminismError( f"Failed to restore orchestration state due to a history mismatch: A previous execution called " @@ -1422,9 +1385,7 @@ def _ensure_queues_for_current_loop(self): if self.orchestration_queue is not None: try: while not self.orchestration_queue.empty(): - existing_orchestration_items.append( - self.orchestration_queue.get_nowait() - ) + existing_orchestration_items.append(self.orchestration_queue.get_nowait()) except Exception: pass @@ -1468,9 +1429,7 @@ async def run(self): if self.activity_queue is not None and self.orchestration_queue is not None: await asyncio.gather( self._consume_queue(self.activity_queue, self.activity_semaphore), - self._consume_queue( - self.orchestration_queue, self.orchestration_semaphore - ), + self._consume_queue(self.orchestration_queue, self.orchestration_semaphore), ) async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore): @@ -1499,7 +1458,7 @@ async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphor running_tasks.add(task) async def _process_work_item( - self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs + self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs ): async with semaphore: try: @@ -1514,13 +1473,12 @@ async def _run_func(self, func, *args, **kwargs): loop = asyncio.get_running_loop() # Avoid submitting to executor after shutdown if ( - getattr(self, "_shutdown", False) and getattr(self, "thread_pool", None) and getattr( - self.thread_pool, "_shutdown", False) + getattr(self, "_shutdown", False) + and getattr(self, "thread_pool", None) + and getattr(self.thread_pool, "_shutdown", False) ): return None - return await loop.run_in_executor( - self.thread_pool, lambda: func(*args, **kwargs) - ) + return await loop.run_in_executor(self.thread_pool, lambda: func(*args, **kwargs)) def submit_activity(self, func, *args, **kwargs): work_item = (func, args, kwargs) diff --git a/examples/activity_sequence.py b/examples/activity_sequence.py index 066a733..fa88363 100644 --- a/examples/activity_sequence.py +++ b/examples/activity_sequence.py @@ -1,19 +1,20 @@ """End-to-end sample that demonstrates how to configure an orchestrator that calls an activity function in a sequence and prints the outputs.""" + from durabletask import client, task, worker def hello(ctx: task.ActivityContext, name: str) -> str: """Activity function that returns a greeting""" - return f'Hello {name}!' + return f"Hello {name}!" 
def sequence(ctx: task.OrchestrationContext, _): """Orchestrator function that calls the 'hello' activity function in a sequence""" # call "hello" activity function in a sequence - result1 = yield ctx.call_activity(hello, input='Tokyo') - result2 = yield ctx.call_activity(hello, input='Seattle') - result3 = yield ctx.call_activity(hello, input='London') + result1 = yield ctx.call_activity(hello, input="Tokyo") + result2 = yield ctx.call_activity(hello, input="Seattle") + result3 = yield ctx.call_activity(hello, input="London") # return an array of results return [result1, result2, result3] @@ -30,6 +31,6 @@ def sequence(ctx: task.OrchestrationContext, _): instance_id = c.schedule_new_orchestration(sequence) state = c.wait_for_orchestration_completion(instance_id, timeout=10) if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: - print(f'Orchestration completed! Result: {state.serialized_output}') + print(f"Orchestration completed! Result: {state.serialized_output}") elif state: - print(f'Orchestration failed: {state.failure_details}') + print(f"Orchestration failed: {state.failure_details}") diff --git a/examples/fanout_fanin.py b/examples/fanout_fanin.py index c53744f..30339b7 100644 --- a/examples/fanout_fanin.py +++ b/examples/fanout_fanin.py @@ -1,6 +1,7 @@ """End-to-end sample that demonstrates how to configure an orchestrator that a dynamic number activity functions in parallel, waits for them all to complete, and prints an aggregate summary of the outputs.""" + import random import time @@ -11,13 +12,13 @@ def get_work_items(ctx: task.ActivityContext, _) -> list[str]: """Activity function that returns a list of work items""" # return a random number of work items count = random.randint(2, 10) - print(f'generating {count} work items...') - return [f'work item {i}' for i in range(count)] + print(f"generating {count} work items...") + return [f"work item {i}" for i in range(count)] def process_work_item(ctx: task.ActivityContext, item: str) -> int: """Activity function that returns a result for a given work item""" - print(f'processing work item: {item}') + print(f"processing work item: {item}") # simulate some work that takes a variable amount of time time.sleep(random.random() * 5) @@ -39,9 +40,9 @@ def orchestrator(ctx: task.OrchestrationContext, _): # return an aggregate summary of the results return { - 'work_items': work_items, - 'results': results, - 'total': sum(results), + "work_items": work_items, + "results": results, + "total": sum(results), } @@ -57,6 +58,6 @@ def orchestrator(ctx: task.OrchestrationContext, _): instance_id = c.schedule_new_orchestration(orchestrator) state = c.wait_for_orchestration_completion(instance_id, timeout=30) if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: - print(f'Orchestration completed! Result: {state.serialized_output}') + print(f"Orchestration completed! 
Result: {state.serialized_output}") elif state: - print(f'Orchestration failed: {state.failure_details}') + print(f"Orchestration failed: {state.failure_details}") diff --git a/examples/human_interaction.py b/examples/human_interaction.py index 2a01897..9773055 100644 --- a/examples/human_interaction.py +++ b/examples/human_interaction.py @@ -15,23 +15,24 @@ @dataclass class Order: """Represents a purchase order""" + Cost: float Product: str Quantity: int def __str__(self): - return f'{self.Product} ({self.Quantity})' + return f"{self.Product} ({self.Quantity})" def send_approval_request(_: task.ActivityContext, order: Order) -> None: """Activity function that sends an approval request to the manager""" time.sleep(5) - print(f'*** Sending approval request for order: {order}') + print(f"*** Sending approval request for order: {order}") def place_order(_: task.ActivityContext, order: Order) -> None: """Activity function that places an order""" - print(f'*** Placing order: {order}') + print(f"*** Placing order: {order}") def purchase_order_workflow(ctx: task.OrchestrationContext, order: Order): @@ -92,7 +93,7 @@ def prompt_for_approval(): if not state: print("Workflow not found!") # not expected elif state.runtime_status == client.OrchestrationStatus.COMPLETED: - print(f'Orchestration completed! Result: {state.serialized_output}') + print(f"Orchestration completed! Result: {state.serialized_output}") else: state.raise_if_failed() # raises an exception except TimeoutError: diff --git a/pyproject.toml b/pyproject.toml index 8c4d1e4..f4e83d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,9 +17,9 @@ keywords = [ "workflow" ] classifiers = [ - "Development Status :: 3 - Alpha", - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", ] requires-python = ">=3.9" license = {file = "LICENSE"} @@ -48,3 +48,27 @@ pythonpath = ["."] markers = [ "e2e: mark a test as an end-to-end test that requires a running sidecar" ] + +[tool.ruff] +target-version = "py39" # TODO: update to py310 when we drop support for py39 +line-length = 100 +fix = true +extend-exclude = [".github", "durabletask/internal/orchestrator_service_*.*"] +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "C", # flake8-comprehensions + "B", # flake8-bugbear + "UP", # pyupgrade +] +ignore = [ + # Undefined name {name} + "F821", +] +[tool.ruff.format] +# follow upstream quote-style instead of python-sdk to reduce diff +quote-style = "double" + diff --git a/tests/durabletask/test_activity_executor.py b/tests/durabletask/test_activity_executor.py index bfc8eaf..996ae44 100644 --- a/tests/durabletask/test_activity_executor.py +++ b/tests/durabletask/test_activity_executor.py @@ -8,16 +8,18 @@ from durabletask import task, worker logging.basicConfig( - format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - level=logging.DEBUG) + format="%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.DEBUG, +) TEST_LOGGER = logging.getLogger("tests") -TEST_INSTANCE_ID = 'abc123' +TEST_INSTANCE_ID = "abc123" TEST_TASK_ID = 42 def test_activity_inputs(): """Validates activity function input population""" + def test_activity(ctx: task.ActivityContext, test_input: Any): # return all activity inputs back as the output return 
test_input, ctx.orchestration_id, ctx.task_id @@ -34,7 +36,6 @@ def test_activity(ctx: task.ActivityContext, test_input: Any): def test_activity_not_registered(): - def test_activity(ctx: task.ActivityContext, _): pass # not used diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py index e5a8e9b..d55e0e0 100644 --- a/tests/durabletask/test_client.py +++ b/tests/durabletask/test_client.py @@ -1,35 +1,39 @@ from unittest.mock import ANY, patch from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl -from durabletask.internal.shared import (get_default_host_address, - get_grpc_channel) +from durabletask.internal.shared import get_default_host_address, get_grpc_channel -HOST_ADDRESS = 'localhost:50051' -METADATA = [('key1', 'value1'), ('key2', 'value2')] +HOST_ADDRESS = "localhost:50051" +METADATA = [("key1", "value1"), ("key2", "value2")] INTERCEPTORS = [DefaultClientInterceptorImpl(METADATA)] def test_get_grpc_channel_insecure(): - with patch('grpc.insecure_channel') as mock_channel: + with patch("grpc.insecure_channel") as mock_channel: get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) mock_channel.assert_called_once_with(HOST_ADDRESS) def test_get_grpc_channel_secure(): - with patch('grpc.secure_channel') as mock_channel, patch( - 'grpc.ssl_channel_credentials') as mock_credentials: + with ( + patch("grpc.secure_channel") as mock_channel, + patch("grpc.ssl_channel_credentials") as mock_credentials, + ): get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS) mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value) + def test_get_grpc_channel_default_host_address(): - with patch('grpc.insecure_channel') as mock_channel: + with patch("grpc.insecure_channel") as mock_channel: get_grpc_channel(None, False, interceptors=INTERCEPTORS) mock_channel.assert_called_once_with(get_default_host_address()) def test_get_grpc_channel_with_metadata(): - with patch('grpc.insecure_channel') as mock_channel, patch( - 'grpc.intercept_channel') as mock_intercept_channel: + with ( + patch("grpc.insecure_channel") as mock_channel, + patch("grpc.intercept_channel") as mock_intercept_channel, + ): get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) mock_channel.assert_called_once_with(HOST_ADDRESS) mock_intercept_channel.assert_called_once() @@ -42,9 +46,10 @@ def test_get_grpc_channel_with_metadata(): def test_grpc_channel_with_host_name_protocol_stripping(): - with patch('grpc.insecure_channel') as mock_insecure_channel, patch( - 'grpc.secure_channel') as mock_secure_channel: - + with ( + patch("grpc.insecure_channel") as mock_insecure_channel, + patch("grpc.secure_channel") as mock_secure_channel, + ): host_name = "myserver.com:1234" prefix = "grpc://" diff --git a/tests/durabletask/test_concurrency_options.py b/tests/durabletask/test_concurrency_options.py index b49b7ec..a923383 100644 --- a/tests/durabletask/test_concurrency_options.py +++ b/tests/durabletask/test_concurrency_options.py @@ -37,9 +37,7 @@ def test_partial_custom_options(): expected_default = 100 * processor_count expected_workers = processor_count + 4 - options = ConcurrencyOptions( - maximum_concurrent_activity_work_items=30 - ) + options = ConcurrencyOptions(maximum_concurrent_activity_work_items=30) assert options.maximum_concurrent_activity_work_items == 30 assert options.maximum_concurrent_orchestration_work_items == expected_default @@ -67,9 +65,7 @@ def test_worker_default_options(): expected_default = 100 * processor_count 
expected_workers = processor_count + 4 - assert ( - worker.concurrency_options.maximum_concurrent_activity_work_items == expected_default - ) + assert worker.concurrency_options.maximum_concurrent_activity_work_items == expected_default assert ( worker.concurrency_options.maximum_concurrent_orchestration_work_items == expected_default ) diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index 2343184..3bd394d 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -16,7 +16,6 @@ def test_empty_orchestration(): - invoked = False def empty_orchestrator(ctx: task.OrchestrationContext, _): @@ -44,7 +43,6 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): def test_activity_sequence(): - def plus_one(_: task.ActivityContext, input: int) -> int: return input + 1 @@ -64,8 +62,7 @@ def sequence(ctx: task.OrchestrationContext, start_val: int): task_hub_client = client.TaskHubGrpcClient() id = task_hub_client.schedule_new_orchestration(sequence, input=1) - state = task_hub_client.wait_for_orchestration_completion( - id, timeout=30) + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.name == task.get_name(sequence) @@ -78,7 +75,6 @@ def sequence(ctx: task.OrchestrationContext, start_val: int): def test_activity_error_handling(): - def throw(_: task.ActivityContext, input: int) -> int: raise RuntimeError("Kah-BOOOOM!!!") @@ -139,8 +135,7 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): # Fan out to multiple sub-orchestrations tasks = [] for _ in range(count): - tasks.append(ctx.call_sub_orchestrator( - orchestrator_child, input=3)) + tasks.append(ctx.call_sub_orchestrator(orchestrator_child, input=3)) # Wait for all sub-orchestrations to complete yield task.when_all(tasks) @@ -163,9 +158,9 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): def test_wait_for_multiple_external_events(): def orchestrator(ctx: task.OrchestrationContext, _): - a = yield ctx.wait_for_external_event('A') - b = yield ctx.wait_for_external_event('B') - c = yield ctx.wait_for_external_event('C') + a = yield ctx.wait_for_external_event("A") + b = yield ctx.wait_for_external_event("B") + c = yield ctx.wait_for_external_event("C") return [a, b, c] # Start a worker, which will connect to the sidecar in a background thread @@ -176,20 +171,20 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Start the orchestration and immediately raise events to it. 
task_hub_client = client.TaskHubGrpcClient() id = task_hub_client.schedule_new_orchestration(orchestrator) - task_hub_client.raise_orchestration_event(id, 'A', data='a') - task_hub_client.raise_orchestration_event(id, 'B', data='b') - task_hub_client.raise_orchestration_event(id, 'C', data='c') + task_hub_client.raise_orchestration_event(id, "A", data="a") + task_hub_client.raise_orchestration_event(id, "B", data="b") + task_hub_client.raise_orchestration_event(id, "C", data="c") state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None assert state.runtime_status == client.OrchestrationStatus.COMPLETED - assert state.serialized_output == json.dumps(['a', 'b', 'c']) + assert state.serialized_output == json.dumps(["a", "b", "c"]) @pytest.mark.parametrize("raise_event", [True, False]) def test_wait_for_external_event_timeout(raise_event: bool): def orchestrator(ctx: task.OrchestrationContext, _): - approval: task.Task[bool] = ctx.wait_for_external_event('Approval') + approval: task.Task[bool] = ctx.wait_for_external_event("Approval") timeout = ctx.create_timer(timedelta(seconds=3)) winner = yield task.when_any([approval, timeout]) if winner == approval: @@ -206,7 +201,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): task_hub_client = client.TaskHubGrpcClient() id = task_hub_client.schedule_new_orchestration(orchestrator) if raise_event: - task_hub_client.raise_orchestration_event(id, 'Approval') + task_hub_client.raise_orchestration_event(id, "Approval") state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None @@ -316,7 +311,6 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): output = "Recursive termination = {recurse}" task_hub_client.terminate_orchestration(instance_id, output=output, recursive=recurse) - metadata = task_hub_client.wait_for_orchestration_completion(instance_id, timeout=30) assert metadata is not None @@ -326,9 +320,13 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): time.sleep(delay_time) if recurse: - assert activity_counter == 0, "Activity should not have executed with recursive termination" + assert activity_counter == 0, ( + "Activity should not have executed with recursive termination" + ) else: - assert activity_counter == 5, "Activity should have executed without recursive termination" + assert activity_counter == 5, ( + "Activity should have executed without recursive termination" + ) def test_continue_as_new(): @@ -367,6 +365,45 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): assert all_results == [1, 2, 3, 4, 5] +def test_continue_as_new_with_activity_e2e(): + """E2E test for continue_as_new with activities (generator-based).""" + activity_results = [] + + def double_activity(ctx: task.ActivityContext, value: int) -> int: + """Activity that doubles the value.""" + result = value * 2 + activity_results.append(result) + return result + + def orchestrator(ctx: task.OrchestrationContext, counter: int): + # Call activity to process the counter + processed = yield ctx.call_activity(double_activity, input=counter) + + # Continue as new up to 3 times + if counter < 3: + ctx.continue_as_new(counter + 1, save_events=False) + else: + return {"counter": counter, "processed": processed, "all_results": activity_results} + + with worker.TaskHubGrpcWorker() as w: + w.add_activity(double_activity) + w.add_orchestrator(orchestrator) + w.start() + + task_hub_client = client.TaskHubGrpcClient() + id = 
task_hub_client.schedule_new_orchestration(orchestrator, input=1) + + state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) + assert state is not None + assert state.runtime_status == client.OrchestrationStatus.COMPLETED + + output = json.loads(state.serialized_output) + # Should have called activity 3 times with input values 1, 2, 3 + assert activity_results == [2, 4, 6] + assert output["counter"] == 3 + assert output["processed"] == 6 + + # NOTE: This test fails when running against durabletask-go with sqlite because the sqlite backend does not yet # support orchestration ID reuse. This gap is being tracked here: # https://github.com/microsoft/durabletask-go/issues/42 @@ -387,7 +424,8 @@ def test_retry_policies(): max_number_of_attempts=3, backoff_coefficient=1, max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=30)) + retry_timeout=timedelta(seconds=30), + ) def parent_orchestrator_with_retry(ctx: task.OrchestrationContext, _): yield ctx.call_sub_orchestrator(child_orchestrator_with_retry, retry_policy=retry_policy) @@ -436,7 +474,8 @@ def test_retry_timeout(): max_number_of_attempts=5, backoff_coefficient=2, max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=14)) + retry_timeout=timedelta(seconds=14), + ) def mock_orchestrator(ctx: task.OrchestrationContext, _): yield ctx.call_activity(throw_activity, retry_policy=retry_policy) @@ -464,7 +503,6 @@ def throw_activity(ctx: task.ActivityContext, _): def test_custom_status(): - def empty_orchestrator(ctx: task.OrchestrationContext, _): ctx.set_custom_status("foobaz") @@ -484,4 +522,4 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): assert state.runtime_status == client.OrchestrationStatus.COMPLETED assert state.serialized_input is None assert state.serialized_output is None - assert state.serialized_custom_status == "\"foobaz\"" + assert state.serialized_custom_status == '"foobaz"' diff --git a/tests/durabletask/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py index 21f6c6c..393ca49 100644 --- a/tests/durabletask/test_orchestration_executor.py +++ b/tests/durabletask/test_orchestration_executor.py @@ -12,9 +12,10 @@ from durabletask import task, worker logging.basicConfig( - format='%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - level=logging.DEBUG) + format="%(asctime)s.%(msecs)03d %(name)s %(levelname)s: %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.DEBUG, +) TEST_LOGGER = logging.getLogger("tests") TEST_INSTANCE_ID = "abc123" @@ -34,7 +35,9 @@ def orchestrator(ctx: task.OrchestrationContext, my_input: int): start_time = datetime.now() new_events = [ helpers.new_orchestrator_started_event(start_time), - helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=json.dumps(test_input)), + helpers.new_execution_started_event( + name, TEST_INSTANCE_ID, encoded_input=json.dumps(test_input) + ), ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, [], new_events) @@ -99,7 +102,8 @@ def delay_orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(start_time), - helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None)] + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, [], new_events) 
actions = result.actions @@ -129,9 +133,9 @@ def delay_orchestrator(ctx: task.OrchestrationContext, _): old_events = [ helpers.new_orchestrator_started_event(start_time), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), - helpers.new_timer_created_event(1, expected_fire_at)] - new_events = [ - helpers.new_timer_fired_event(1, expected_fire_at)] + helpers.new_timer_created_event(1, expected_fire_at), + ] + new_events = [helpers.new_timer_fired_event(1, expected_fire_at)] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) @@ -145,6 +149,7 @@ def delay_orchestrator(ctx: task.OrchestrationContext, _): def test_schedule_activity_actions(): """Test the actions output for the call_activity orchestrator method""" + def dummy_activity(ctx, _): pass @@ -158,7 +163,8 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): encoded_input = json.dumps(42) new_events = [ helpers.new_orchestrator_started_event(), - helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input)] + helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, [], new_events) actions = result.actions @@ -173,6 +179,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): def test_schedule_activity_actions_router_without_app_id(): """Tests that scheduleTask action contains correct router fields when app_id is specified""" + def dummy_activity(ctx, _): pass @@ -198,13 +205,14 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert len(actions) == 1 action = actions[0] assert action.router.sourceAppID == "source-app" - assert action.router.targetAppID == '' + assert action.router.targetAppID == "" assert action.scheduleTask.router.sourceAppID == "source-app" - assert action.scheduleTask.router.targetAppID == '' + assert action.scheduleTask.router.targetAppID == "" def test_schedule_activity_actions_router_with_app_id(): """Tests that scheduleTask action contains correct router fields when app_id is specified""" + def dummy_activity(ctx, _): pass @@ -251,7 +259,8 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): old_events = [ helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), - helpers.new_task_scheduled_event(1, task.get_name(dummy_activity))] + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), + ] encoded_output = json.dumps("done!") new_events = [helpers.new_task_completed_event(1, encoded_output)] @@ -267,6 +276,7 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): def test_activity_task_failed(): """Tests the failure of an activity task""" + def dummy_activity(ctx, _): pass @@ -280,7 +290,8 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): old_events = [ helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), - helpers.new_task_scheduled_event(1, task.get_name(dummy_activity))] + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), + ] ex = Exception("Kah-BOOOOM!!!") new_events = [helpers.new_task_failed_event(1, ex)] @@ -291,7 +302,9 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): complete_action = get_and_validate_single_complete_orchestration_action(actions) assert 
complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED - assert complete_action.failureDetails.errorType == 'TaskFailedError' # TODO: Should this be the specific error? + assert ( + complete_action.failureDetails.errorType == "TaskFailedError" + ) # TODO: Should this be the specific error? assert str(ex) in complete_action.failureDetails.errorMessage # Make sure the line of code where the exception was raised is included in the stack trace @@ -313,8 +326,10 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): max_number_of_attempts=6, backoff_coefficient=2, max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=50)), - input=orchestrator_input) + retry_timeout=timedelta(seconds=50), + ), + input=orchestrator_input, + ) return result registry = worker._Registry() @@ -325,12 +340,14 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): old_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), - helpers.new_task_scheduled_event(1, task.get_name(dummy_activity))] + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), + ] expected_fire_at = current_timestamp + timedelta(seconds=1) new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), - helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] + helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!")), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -344,7 +361,8 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): old_events = old_events + new_events new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_timer_fired_event(2, current_timestamp)] + helpers.new_timer_fired_event(2, current_timestamp), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -357,7 +375,8 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): expected_fire_at = current_timestamp + timedelta(seconds=2) new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] + helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!")), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -371,7 +390,8 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): old_events = old_events + new_events new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_timer_fired_event(3, current_timestamp)] + helpers.new_timer_fired_event(3, current_timestamp), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -384,7 +404,8 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): old_events = old_events + new_events new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] + helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!")), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = 
executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -398,7 +419,8 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): old_events = old_events + new_events new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_timer_fired_event(4, current_timestamp)] + helpers.new_timer_fired_event(4, current_timestamp), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -411,7 +433,8 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): old_events = old_events + new_events new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] + helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!")), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -425,7 +448,8 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): old_events = old_events + new_events new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_timer_fired_event(5, current_timestamp)] + helpers.new_timer_fired_event(5, current_timestamp), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -439,7 +463,8 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): old_events = old_events + new_events new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] + helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!")), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -453,7 +478,8 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): old_events = old_events + new_events new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_timer_fired_event(6, current_timestamp)] + helpers.new_timer_fired_event(6, current_timestamp), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -465,17 +491,21 @@ def orchestrator(ctx: task.OrchestrationContext, orchestrator_input): old_events = old_events + new_events new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] + helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!")), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 1 - assert actions[0].completeOrchestration.failureDetails.errorMessage.__contains__("Activity task #1 failed: Kah-BOOOOM!!!") + assert actions[0].completeOrchestration.failureDetails.errorMessage.__contains__( + "Activity task #1 failed: Kah-BOOOOM!!!" 
+ ) assert actions[0].id == 7 def test_nondeterminism_expected_timer(): """Tests the non-determinism detection logic when call_timer is expected but some other method (call_activity) is called instead""" + def dummy_activity(ctx, _): pass @@ -490,7 +520,8 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = [ helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), - helpers.new_timer_created_event(1, fire_at)] + helpers.new_timer_created_event(1, fire_at), + ] new_events = [helpers.new_timer_fired_event(timer_id=1, fire_at=fire_at)] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) @@ -499,7 +530,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): complete_action = get_and_validate_single_complete_orchestration_action(actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED - assert complete_action.failureDetails.errorType == 'NonDeterminismError' + assert complete_action.failureDetails.errorType == "NonDeterminismError" assert "1" in complete_action.failureDetails.errorMessage # task ID assert "create_timer" in complete_action.failureDetails.errorMessage # expected method name assert "call_activity" in complete_action.failureDetails.errorMessage # actual method name @@ -507,6 +538,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): def test_nondeterminism_expected_activity_call_no_task_id(): """Tests the non-determinism detection logic when invoking activity functions""" + def orchestrator(ctx: task.OrchestrationContext, _): result = yield task.CompletableTask() # dummy task return result @@ -517,7 +549,8 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = [ helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), - helpers.new_task_scheduled_event(1, "bogus_activity")] + helpers.new_task_scheduled_event(1, "bogus_activity"), + ] new_events = [helpers.new_task_completed_event(1)] @@ -527,13 +560,14 @@ def orchestrator(ctx: task.OrchestrationContext, _): complete_action = get_and_validate_single_complete_orchestration_action(actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED - assert complete_action.failureDetails.errorType == 'NonDeterminismError' + assert complete_action.failureDetails.errorType == "NonDeterminismError" assert "1" in complete_action.failureDetails.errorMessage # task ID assert "call_activity" in complete_action.failureDetails.errorMessage # expected method name def test_nondeterminism_expected_activity_call_wrong_task_type(): """Tests the non-determinism detection when an activity exists in the history but a non-activity is in the code""" + def dummy_activity(ctx, _): pass @@ -547,7 +581,8 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = [ helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), - helpers.new_task_scheduled_event(1, task.get_name(dummy_activity))] + helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), + ] new_events = [helpers.new_task_completed_event(1)] @@ -557,7 +592,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): complete_action = get_and_validate_single_complete_orchestration_action(actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED - assert complete_action.failureDetails.errorType == 'NonDeterminismError' + assert complete_action.failureDetails.errorType 
== "NonDeterminismError" assert "1" in complete_action.failureDetails.errorMessage # task ID assert "call_activity" in complete_action.failureDetails.errorMessage # expected method name assert "create_timer" in complete_action.failureDetails.errorMessage # unexpected method name @@ -565,6 +600,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): def test_nondeterminism_wrong_activity_name(): """Tests the non-determinism detection when calling an activity with a name that differs from the name in the history""" + def dummy_activity(ctx, _): pass @@ -578,7 +614,8 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = [ helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(name, TEST_INSTANCE_ID, encoded_input=None), - helpers.new_task_scheduled_event(1, "original_activity")] + helpers.new_task_scheduled_event(1, "original_activity"), + ] new_events = [helpers.new_task_completed_event(1)] @@ -588,15 +625,20 @@ def orchestrator(ctx: task.OrchestrationContext, _): complete_action = get_and_validate_single_complete_orchestration_action(actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED - assert complete_action.failureDetails.errorType == 'NonDeterminismError' + assert complete_action.failureDetails.errorType == "NonDeterminismError" assert "1" in complete_action.failureDetails.errorMessage # task ID assert "call_activity" in complete_action.failureDetails.errorMessage # expected method name - assert "original_activity" in complete_action.failureDetails.errorMessage # expected activity name - assert "dummy_activity" in complete_action.failureDetails.errorMessage # unexpected activity name + assert ( + "original_activity" in complete_action.failureDetails.errorMessage + ) # expected activity name + assert ( + "dummy_activity" in complete_action.failureDetails.errorMessage + ) # unexpected activity name def test_sub_orchestration_task_completion(): """Tests that a sub-orchestration task is completed when the sub-orchestration completes""" + def suborchestrator(ctx: task.OrchestrationContext, _): pass @@ -610,11 +652,15 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = [ helpers.new_orchestrator_started_event(), - helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None), - helpers.new_sub_orchestration_created_event(1, suborchestrator_name, "sub-orch-123", encoded_input=None)] + helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input=None + ), + helpers.new_sub_orchestration_created_event( + 1, suborchestrator_name, "sub-orch-123", encoded_input=None + ), + ] - new_events = [ - helpers.new_sub_orchestration_completed_event(1, encoded_output="42")] + new_events = [helpers.new_sub_orchestration_completed_event(1, encoded_output="42")] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) @@ -627,6 +673,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): def test_create_sub_orchestration_actions_router_without_app_id(): """Tests that createSubOrchestration action contains correct router fields when app_id is specified""" + def suborchestrator(ctx: task.OrchestrationContext, _): pass @@ -634,10 +681,12 @@ def orchestrator(ctx: task.OrchestrationContext, _): yield ctx.call_sub_orchestrator(suborchestrator, input=None) registry = worker._Registry() - suborchestrator_name = registry.add_orchestrator(suborchestrator) + _ = 
registry.add_orchestrator(suborchestrator) orchestrator_name = registry.add_orchestrator(orchestrator) - exec_evt = helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None) + exec_evt = helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input=None + ) exec_evt.router.sourceAppID = "source-app" new_events = [ @@ -652,13 +701,14 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert len(actions) == 1 action = actions[0] assert action.router.sourceAppID == "source-app" - assert action.router.targetAppID == '' + assert action.router.targetAppID == "" assert action.createSubOrchestration.router.sourceAppID == "source-app" - assert action.createSubOrchestration.router.targetAppID == '' + assert action.createSubOrchestration.router.targetAppID == "" def test_create_sub_orchestration_actions_router_with_app_id(): """Tests that createSubOrchestration action contains correct router fields when app_id is specified""" + def suborchestrator(ctx: task.OrchestrationContext, _): pass @@ -666,10 +716,12 @@ def orchestrator(ctx: task.OrchestrationContext, _): yield ctx.call_sub_orchestrator(suborchestrator, input=None, app_id="target-app") registry = worker._Registry() - suborchestrator_name = registry.add_orchestrator(suborchestrator) + _ = registry.add_orchestrator(suborchestrator) orchestrator_name = registry.add_orchestrator(orchestrator) - exec_evt = helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None) + exec_evt = helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input=None + ) exec_evt.router.sourceAppID = "source-app" new_events = [ @@ -691,6 +743,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): def test_sub_orchestration_task_failed(): """Tests that a sub-orchestration task is completed when the sub-orchestration fails""" + def suborchestrator(ctx: task.OrchestrationContext, _): pass @@ -704,8 +757,13 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = [ helpers.new_orchestrator_started_event(), - helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None), - helpers.new_sub_orchestration_created_event(1, suborchestrator_name, "sub-orch-123", encoded_input=None)] + helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input=None + ), + helpers.new_sub_orchestration_created_event( + 1, suborchestrator_name, "sub-orch-123", encoded_input=None + ), + ] ex = Exception("Kah-BOOOOM!!!") new_events = [helpers.new_sub_orchestration_failed_event(1, ex)] @@ -716,7 +774,9 @@ def orchestrator(ctx: task.OrchestrationContext, _): complete_action = get_and_validate_single_complete_orchestration_action(actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED - assert complete_action.failureDetails.errorType == 'TaskFailedError' # TODO: Should this be the specific error? + assert ( + complete_action.failureDetails.errorType == "TaskFailedError" + ) # TODO: Should this be the specific error? 
assert str(ex) in complete_action.failureDetails.errorMessage # Make sure the line of code where the exception was raised is included in the stack trace @@ -726,6 +786,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): def test_nondeterminism_expected_sub_orchestration_task_completion_no_task(): """Tests the non-determinism detection when a sub-orchestration action is encounteed when it shouldn't be""" + def orchestrator(ctx: task.OrchestrationContext, _): result = yield task.CompletableTask() # dummy task return result @@ -735,11 +796,15 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = [ helpers.new_orchestrator_started_event(), - helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None), - helpers.new_sub_orchestration_created_event(1, "some_sub_orchestration", "sub-orch-123", encoded_input=None)] + helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input=None + ), + helpers.new_sub_orchestration_created_event( + 1, "some_sub_orchestration", "sub-orch-123", encoded_input=None + ), + ] - new_events = [ - helpers.new_sub_orchestration_completed_event(1, encoded_output="42")] + new_events = [helpers.new_sub_orchestration_completed_event(1, encoded_output="42")] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) @@ -747,17 +812,22 @@ def orchestrator(ctx: task.OrchestrationContext, _): complete_action = get_and_validate_single_complete_orchestration_action(actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED - assert complete_action.failureDetails.errorType == 'NonDeterminismError' + assert complete_action.failureDetails.errorType == "NonDeterminismError" assert "1" in complete_action.failureDetails.errorMessage # task ID - assert "call_sub_orchestrator" in complete_action.failureDetails.errorMessage # expected method name + assert ( + "call_sub_orchestrator" in complete_action.failureDetails.errorMessage + ) # expected method name def test_nondeterminism_expected_sub_orchestration_task_completion_wrong_task_type(): """Tests the non-determinism detection when a sub-orchestration action is encounteed when it shouldn't be. This variation tests the case where the expected task type is wrong (e.g. 
the code schedules a timer task but the history contains a sub-orchestration completed task).""" + def orchestrator(ctx: task.OrchestrationContext, _): - result = yield ctx.create_timer(datetime.utcnow()) # created timer but history expects sub-orchestration + result = yield ctx.create_timer( + datetime.utcnow() + ) # created timer but history expects sub-orchestration return result registry = worker._Registry() @@ -765,11 +835,15 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = [ helpers.new_orchestrator_started_event(), - helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None), - helpers.new_sub_orchestration_created_event(1, "some_sub_orchestration", "sub-orch-123", encoded_input=None)] + helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input=None + ), + helpers.new_sub_orchestration_created_event( + 1, "some_sub_orchestration", "sub-orch-123", encoded_input=None + ), + ] - new_events = [ - helpers.new_sub_orchestration_completed_event(1, encoded_output="42")] + new_events = [helpers.new_sub_orchestration_completed_event(1, encoded_output="42")] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) @@ -777,13 +851,16 @@ def orchestrator(ctx: task.OrchestrationContext, _): complete_action = get_and_validate_single_complete_orchestration_action(actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED - assert complete_action.failureDetails.errorType == 'NonDeterminismError' + assert complete_action.failureDetails.errorType == "NonDeterminismError" assert "1" in complete_action.failureDetails.errorMessage # task ID - assert "call_sub_orchestrator" in complete_action.failureDetails.errorMessage # expected method name + assert ( + "call_sub_orchestrator" in complete_action.failureDetails.errorMessage + ) # expected method name def test_raise_event(): """Tests that an orchestration can wait for and process an external event sent by a client""" + def orchestrator(ctx: task.OrchestrationContext, _): result = yield ctx.wait_for_external_event("my_event") return result @@ -794,7 +871,8 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = [] new_events = [ helpers.new_orchestrator_started_event(), - helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID)] + helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID), + ] # Execute the orchestration until it is waiting for an external event. The result # should be an empty list of actions because the orchestration didn't schedule any work. @@ -817,6 +895,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): def test_raise_event_buffered(): """Tests that an orchestration can receive an event that arrives earlier than expected""" + def orchestrator(ctx: task.OrchestrationContext, _): yield ctx.create_timer(ctx.current_utc_datetime + timedelta(days=1)) result = yield ctx.wait_for_external_event("my_event") @@ -829,7 +908,8 @@ def orchestrator(ctx: task.OrchestrationContext, _): new_events = [ helpers.new_orchestrator_started_event(), helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID), - helpers.new_event_raised_event("my_event", encoded_input="42")] + helpers.new_event_raised_event("my_event", encoded_input="42"), + ] # Execute the orchestration. 
It should be in a running state waiting for the timer to fire executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) @@ -863,10 +943,12 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = [ helpers.new_orchestrator_started_event(), - helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID)] + helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID), + ] new_events = [ helpers.new_suspend_event(), - helpers.new_event_raised_event("my_event", encoded_input="42")] + helpers.new_event_raised_event("my_event", encoded_input="42"), + ] # Execute the orchestration. It should remain in a running state because it was suspended prior # to processing the event raised event. @@ -898,10 +980,12 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = [ helpers.new_orchestrator_started_event(), - helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID)] + helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID), + ] new_events = [ helpers.new_terminated_event(encoded_output=json.dumps("terminated!")), - helpers.new_event_raised_event("my_event", encoded_input="42")] + helpers.new_event_raised_event("my_event", encoded_input="42"), + ] # Execute the orchestration. It should be in a running state waiting for an external event executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) @@ -915,6 +999,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): @pytest.mark.parametrize("save_events", [True, False]) def test_continue_as_new(save_events: bool): """Tests the behavior of the continue-as-new API""" + def orchestrator(ctx: task.OrchestrationContext, input: int): yield ctx.create_timer(ctx.current_utc_datetime + timedelta(days=1)) ctx.continue_as_new(input + 1, save_events=save_events) @@ -928,9 +1013,9 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): helpers.new_event_raised_event("my_event", encoded_input="42"), helpers.new_event_raised_event("my_event", encoded_input="43"), helpers.new_event_raised_event("my_event", encoded_input="44"), - helpers.new_timer_created_event(1, datetime.utcnow() + timedelta(days=1))] - new_events = [ - helpers.new_timer_fired_event(1, datetime.utcnow() + timedelta(days=1))] + helpers.new_timer_created_event(1, datetime.utcnow() + timedelta(days=1)), + ] + new_events = [helpers.new_timer_fired_event(1, datetime.utcnow() + timedelta(days=1))] # Execute the orchestration. It should be in a running state waiting for the timer to fire executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) @@ -944,12 +1029,15 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): event = complete_action.carryoverEvents[i] assert type(event) is pb.HistoryEvent assert event.HasField("eventRaised") - assert event.eventRaised.name.casefold() == "my_event".casefold() # event names are case-insensitive + assert ( + event.eventRaised.name.casefold() == "my_event".casefold() + ) # event names are case-insensitive assert event.eventRaised.input.value == json.dumps(42 + i) def test_fan_out(): """Tests that a fan-out pattern correctly schedules N tasks""" + def hello(_, name: str): return f"Hello {name}!" 
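# --- Illustrative sketch (hypothetical, not part of this patch): the carryover
# semantics asserted by the continue-as-new test above. With save_events=True,
# external events buffered before continue_as_new become carryoverEvents and are
# redelivered to the next execution; with save_events=False they are dropped.
def _sketch_continue_as_new_carryover(buffered_events, save_events):
    return list(buffered_events) if save_events else []

assert _sketch_continue_as_new_carryover(["42", "43", "44"], save_events=True) == ["42", "43", "44"]
assert _sketch_continue_as_new_carryover(["42", "43", "44"], save_events=False) == []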
@@ -967,7 +1055,10 @@ def orchestrator(ctx: task.OrchestrationContext, count: int): old_events = [] new_events = [ helpers.new_orchestrator_started_event(), - helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input="10")] + helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input="10" + ), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) @@ -983,6 +1074,7 @@ def orchestrator(ctx: task.OrchestrationContext, count: int): def test_fan_in(): """Tests that a fan-in pattern works correctly""" + def print_int(_, val: int): return str(val) @@ -999,15 +1091,20 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = [ helpers.new_orchestrator_started_event(), - helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None)] + helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input=None + ), + ] for i in range(10): - old_events.append(helpers.new_task_scheduled_event( - i + 1, activity_name, encoded_input=str(i))) + old_events.append( + helpers.new_task_scheduled_event(i + 1, activity_name, encoded_input=str(i)) + ) new_events = [] for i in range(10): - new_events.append(helpers.new_task_completed_event( - i + 1, encoded_output=print_int(None, i))) + new_events.append( + helpers.new_task_completed_event(i + 1, encoded_output=print_int(None, i)) + ) # First, test with only the first 5 events. We expect the orchestration to be running # but return zero actions since its still waiting for the other 5 tasks to complete. @@ -1028,6 +1125,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): def test_fan_in_with_single_failure(): """Tests that a fan-in pattern works correctly when one of the tasks fails""" + def print_int(_, val: int): return str(val) @@ -1044,17 +1142,22 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = [ helpers.new_orchestrator_started_event(), - helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None)] + helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input=None + ), + ] for i in range(10): - old_events.append(helpers.new_task_scheduled_event( - i + 1, activity_name, encoded_input=str(i))) + old_events.append( + helpers.new_task_scheduled_event(i + 1, activity_name, encoded_input=str(i)) + ) # 5 of the tasks complete successfully, 1 fails, and 4 are still running. # The expectation is that the orchestration will fail immediately. new_events = [] for i in range(5): - new_events.append(helpers.new_task_completed_event( - i + 1, encoded_output=print_int(None, i))) + new_events.append( + helpers.new_task_completed_event(i + 1, encoded_output=print_int(None, i)) + ) ex = Exception("Kah-BOOOOM!!!") new_events.append(helpers.new_task_failed_event(6, ex)) @@ -1065,12 +1168,15 @@ def orchestrator(ctx: task.OrchestrationContext, _): complete_action = get_and_validate_single_complete_orchestration_action(actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED - assert complete_action.failureDetails.errorType == 'TaskFailedError' # TODO: Is this the right error type? + assert ( + complete_action.failureDetails.errorType == "TaskFailedError" + ) # TODO: Is this the right error type? 
assert str(ex) in complete_action.failureDetails.errorMessage def test_when_any(): """Tests that a when_any pattern works correctly""" + def hello(_, name: str): return f"Hello {name}!" @@ -1090,20 +1196,25 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Test 1: Start the orchestration and let it yield on the when_any. We expect the orchestration # to return two actions: one to schedule the "Tokyo" task and one to schedule the "Seattle" task. old_events = [] - new_events = [helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None)] + new_events = [ + helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None) + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions assert len(actions) == 2 - assert actions[0].HasField('scheduleTask') - assert actions[1].HasField('scheduleTask') + assert actions[0].HasField("scheduleTask") + assert actions[1].HasField("scheduleTask") # The next tests assume that the orchestration has already awaited at the task.when_any() old_events = [ helpers.new_orchestrator_started_event(), - helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input=None + ), helpers.new_task_scheduled_event(1, activity_name, encoded_input=json.dumps("Tokyo")), - helpers.new_task_scheduled_event(2, activity_name, encoded_input=json.dumps("Seattle"))] + helpers.new_task_scheduled_event(2, activity_name, encoded_input=json.dumps("Seattle")), + ] # Test 2: Complete the "Tokyo" task. We expect the orchestration to complete with output "Hello, Tokyo!" encoded_output = json.dumps(hello(None, "Tokyo")) @@ -1128,20 +1239,24 @@ def orchestrator(ctx: task.OrchestrationContext, _): def test_when_any_with_retry(): """Tests that a when_any pattern works correctly with retries""" + def dummy_activity(_, inp: str): if inp == "Tokyo": raise ValueError("Kah-BOOOOM!!!") return f"Hello {inp}!" 
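# --- Illustrative sketch (hypothetical helper, not SDK code): the retry cadence the
# assertions below expect. The delay before retry n appears to be
# first_retry_interval * backoff_coefficient ** (n - 1), capped at max_retry_interval.
from datetime import timedelta

def _sketch_retry_delays(first, coefficient, cap, max_attempts):
    return [min(first * coefficient**n, cap) for n in range(max_attempts - 1)]

# With first_retry_interval=1s, backoff_coefficient=2 and max_retry_interval=10s,
# retry timers fire 1s, 2s, 4s, 8s, then 10s (capped) after successive failures.
assert _sketch_retry_delays(timedelta(seconds=1), 2, timedelta(seconds=10), 6) == [
    timedelta(seconds=s) for s in (1, 2, 4, 8, 10)
]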
def orchestrator(ctx: task.OrchestrationContext, _): - t1 = ctx.call_activity(dummy_activity, - retry_policy=task.RetryPolicy( - first_retry_interval=timedelta(seconds=1), - max_number_of_attempts=6, - backoff_coefficient=2, - max_retry_interval=timedelta(seconds=10), - retry_timeout=timedelta(seconds=50)), - input="Tokyo") + t1 = ctx.call_activity( + dummy_activity, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(seconds=1), + max_number_of_attempts=6, + backoff_coefficient=2, + max_retry_interval=timedelta(seconds=10), + retry_timeout=timedelta(seconds=50), + ), + input="Tokyo", + ) t2 = ctx.call_activity(dummy_activity, input="Seattle") winner = yield task.when_any([t1, t2]) if winner == t1: @@ -1157,14 +1272,18 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Simulate the task failing for the first time and confirm that a timer is scheduled for 1 second in the future old_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), - helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input=None + ), helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), - helpers.new_task_scheduled_event(2, task.get_name(dummy_activity))] + helpers.new_task_scheduled_event(2, task.get_name(dummy_activity)), + ] expected_fire_at = current_timestamp + timedelta(seconds=1) new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), - helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] + helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!")), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1178,7 +1297,8 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = old_events + new_events new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_timer_fired_event(3, current_timestamp)] + helpers.new_timer_fired_event(3, current_timestamp), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1191,7 +1311,8 @@ def orchestrator(ctx: task.OrchestrationContext, _): expected_fire_at = current_timestamp + timedelta(seconds=2) new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] + helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!")), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1213,20 +1334,24 @@ def orchestrator(ctx: task.OrchestrationContext, _): def test_when_all_with_retry(): """Tests that a when_all pattern works correctly with retries""" + def dummy_activity(ctx, inp: str): if inp == "Tokyo": raise ValueError("Kah-BOOOOM!!!") return f"Hello {inp}!" 
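# --- Illustrative sketch (hypothetical, not part of this patch): the interaction the
# when_all test below exercises. A retried task only surfaces its failure once
# max_number_of_attempts is exhausted; when_all then fails with that error even
# though the sibling task already completed successfully.
def _sketch_when_all_with_retry(attempt_failures, max_attempts):
    # attempt_failures: True means the attempt failed; one entry per attempt made
    for failed in attempt_failures[:max_attempts]:
        if not failed:
            return "completed"
    return "failed"  # retries exhausted -> when_all propagates TaskFailedError

assert _sketch_when_all_with_retry([True, True, True], max_attempts=3) == "failed"
assert _sketch_when_all_with_retry([True, False], max_attempts=3) == "completed"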
def orchestrator(ctx: task.OrchestrationContext, _): - t1 = ctx.call_activity(dummy_activity, - retry_policy=task.RetryPolicy( - first_retry_interval=timedelta(seconds=2), - max_number_of_attempts=3, - backoff_coefficient=4, - max_retry_interval=timedelta(seconds=5), - retry_timeout=timedelta(seconds=50)), - input="Tokyo") + t1 = ctx.call_activity( + dummy_activity, + retry_policy=task.RetryPolicy( + first_retry_interval=timedelta(seconds=2), + max_number_of_attempts=3, + backoff_coefficient=4, + max_retry_interval=timedelta(seconds=5), + retry_timeout=timedelta(seconds=50), + ), + input="Tokyo", + ) t2 = ctx.call_activity(dummy_activity, input="Seattle") results = yield task.when_all([t1, t2]) return results @@ -1239,14 +1364,18 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Simulate the task failing for the first time and confirm that a timer is scheduled for 2 seconds in the future old_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), - helpers.new_execution_started_event(orchestrator_name, TEST_INSTANCE_ID, encoded_input=None), + helpers.new_execution_started_event( + orchestrator_name, TEST_INSTANCE_ID, encoded_input=None + ), helpers.new_task_scheduled_event(1, task.get_name(dummy_activity)), - helpers.new_task_scheduled_event(2, task.get_name(dummy_activity))] + helpers.new_task_scheduled_event(2, task.get_name(dummy_activity)), + ] expected_fire_at = current_timestamp + timedelta(seconds=2) new_events = [ helpers.new_orchestrator_started_event(timestamp=current_timestamp), - helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] + helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!")), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1260,7 +1389,8 @@ def orchestrator(ctx: task.OrchestrationContext, _): old_events = old_events + new_events new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_timer_fired_event(3, current_timestamp)] + helpers.new_timer_fired_event(3, current_timestamp), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1273,7 +1403,8 @@ def orchestrator(ctx: task.OrchestrationContext, _): expected_fire_at = current_timestamp + timedelta(seconds=5) new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] + helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!")), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1286,8 +1417,10 @@ def orchestrator(ctx: task.OrchestrationContext, _): # And, Simulate the timer firing at the expected time and confirm that another activity task is scheduled encoded_output = json.dumps(dummy_activity(None, "Seattle")) old_events = old_events + new_events - new_events = [helpers.new_task_completed_event(2, encoded_output), - helpers.new_timer_fired_event(4, current_timestamp)] + new_events = [ + helpers.new_task_completed_event(2, encoded_output), + helpers.new_timer_fired_event(4, current_timestamp), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions @@ -1301,17 +1434,22 @@ def orchestrator(ctx: 
task.OrchestrationContext, _): old_events = old_events + new_events new_events = [ helpers.new_orchestrator_started_event(current_timestamp), - helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!"))] + helpers.new_task_failed_event(1, ValueError("Kah-BOOOOM!!!")), + ] executor = worker._OrchestrationExecutor(registry, TEST_LOGGER) result = executor.execute(TEST_INSTANCE_ID, old_events, new_events) actions = result.actions complete_action = get_and_validate_single_complete_orchestration_action(actions) assert complete_action.orchestrationStatus == pb.ORCHESTRATION_STATUS_FAILED - assert complete_action.failureDetails.errorType == 'TaskFailedError' # TODO: Should this be the specific error? + assert ( + complete_action.failureDetails.errorType == "TaskFailedError" + ) # TODO: Should this be the specific error? assert str(ex) in complete_action.failureDetails.errorMessage -def get_and_validate_single_complete_orchestration_action(actions: list[pb.OrchestratorAction]) -> pb.CompleteOrchestrationAction: +def get_and_validate_single_complete_orchestration_action( + actions: list[pb.OrchestratorAction], +) -> pb.CompleteOrchestrationAction: assert len(actions) == 1 assert type(actions[0]) is pb.OrchestratorAction assert actions[0].HasField("completeOrchestration") diff --git a/tests/durabletask/test_orchestration_wait.py b/tests/durabletask/test_orchestration_wait.py index 03f7e30..49eab0e 100644 --- a/tests/durabletask/test_orchestration_wait.py +++ b/tests/durabletask/test_orchestration_wait.py @@ -1,17 +1,19 @@ -from unittest.mock import patch, ANY, Mock +from unittest.mock import Mock -from durabletask.client import TaskHubGrpcClient -from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl -from durabletask.internal.shared import (get_default_host_address, - get_grpc_channel) import pytest +from durabletask.client import TaskHubGrpcClient + + @pytest.mark.parametrize("timeout", [None, 0, 5]) def test_wait_for_orchestration_start_timeout(timeout): instance_id = "test-instance" - from durabletask.internal.orchestrator_service_pb2 import GetInstanceResponse, \ - OrchestrationState, ORCHESTRATION_STATUS_RUNNING + from durabletask.internal.orchestrator_service_pb2 import ( + ORCHESTRATION_STATUS_RUNNING, + GetInstanceResponse, + OrchestrationState, + ) response = GetInstanceResponse() state = OrchestrationState() @@ -30,16 +32,20 @@ def test_wait_for_orchestration_start_timeout(timeout): c._stub.WaitForInstanceStart.assert_called_once() _, kwargs = c._stub.WaitForInstanceStart.call_args if timeout is None or timeout == 0: - assert kwargs.get('timeout') is None + assert kwargs.get("timeout") is None else: - assert kwargs.get('timeout') == timeout + assert kwargs.get("timeout") == timeout + @pytest.mark.parametrize("timeout", [None, 0, 5]) def test_wait_for_orchestration_completion_timeout(timeout): instance_id = "test-instance" - from durabletask.internal.orchestrator_service_pb2 import GetInstanceResponse, \ - OrchestrationState, ORCHESTRATION_STATUS_COMPLETED + from durabletask.internal.orchestrator_service_pb2 import ( + ORCHESTRATION_STATUS_COMPLETED, + GetInstanceResponse, + OrchestrationState, + ) response = GetInstanceResponse() state = OrchestrationState() @@ -58,6 +64,6 @@ def test_wait_for_orchestration_completion_timeout(timeout): c._stub.WaitForInstanceCompletion.assert_called_once() _, kwargs = c._stub.WaitForInstanceCompletion.call_args if timeout is None or timeout == 0: - assert kwargs.get('timeout') is None + assert kwargs.get("timeout") is None else: 
- assert kwargs.get('timeout') == timeout + assert kwargs.get("timeout") == timeout diff --git a/tests/durabletask/test_worker_concurrency_loop.py b/tests/durabletask/test_worker_concurrency_loop.py index de6753b..53b6c9a 100644 --- a/tests/durabletask/test_worker_concurrency_loop.py +++ b/tests/durabletask/test_worker_concurrency_loop.py @@ -10,29 +10,30 @@ def __init__(self): self.completed = [] def CompleteOrchestratorTask(self, res): - self.completed.append(('orchestrator', res)) + self.completed.append(("orchestrator", res)) def CompleteActivityTask(self, res): - self.completed.append(('activity', res)) + self.completed.append(("activity", res)) class DummyRequest: def __init__(self, kind, instance_id): self.kind = kind self.instanceId = instance_id - self.orchestrationInstance = type('O', (), {'instanceId': instance_id}) - self.name = 'dummy' + self.orchestrationInstance = type("O", (), {"instanceId": instance_id}) + self.name = "dummy" self.taskId = 1 - self.input = type('I', (), {'value': ''}) + self.input = type("I", (), {"value": ""}) self.pastEvents = [] self.newEvents = [] def HasField(self, field): - return (field == 'orchestratorRequest' and self.kind == 'orchestrator') or \ - (field == 'activityRequest' and self.kind == 'activity') + return (field == "orchestratorRequest" and self.kind == "orchestrator") or ( + field == "activityRequest" and self.kind == "activity" + ) def WhichOneof(self, _): - return f'{self.kind}Request' + return f"{self.kind}Request" class DummyCompletionToken: @@ -50,33 +51,40 @@ def test_worker_concurrency_loop_sync(): def dummy_orchestrator(req, stub, completionToken): time.sleep(0.1) - stub.CompleteOrchestratorTask('ok') + stub.CompleteOrchestratorTask("ok") def dummy_activity(req, stub, completionToken): time.sleep(0.1) - stub.CompleteActivityTask('ok') + stub.CompleteActivityTask("ok") # Patch the worker's _execute_orchestrator and _execute_activity worker._execute_orchestrator = dummy_orchestrator worker._execute_activity = dummy_activity - orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)] - activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)] + orchestrator_requests = [DummyRequest("orchestrator", f"orch{i}") for i in range(3)] + activity_requests = [DummyRequest("activity", f"act{i}") for i in range(4)] async def run_test(): # Start the worker manager's run loop in the background worker_task = asyncio.create_task(worker._async_worker_manager.run()) for req in orchestrator_requests: - worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) + worker._async_worker_manager.submit_orchestration( + dummy_orchestrator, req, stub, DummyCompletionToken() + ) for req in activity_requests: - worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) + worker._async_worker_manager.submit_activity( + dummy_activity, req, stub, DummyCompletionToken() + ) await asyncio.sleep(1.0) - orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator') - activity_count = sum(1 for t, _ in stub.completed if t == 'activity') - assert orchestrator_count == 3, f"Expected 3 orchestrator completions, got {orchestrator_count}" + orchestrator_count = sum(1 for t, _ in stub.completed if t == "orchestrator") + activity_count = sum(1 for t, _ in stub.completed if t == "activity") + assert orchestrator_count == 3, ( + f"Expected 3 orchestrator completions, got {orchestrator_count}" + ) assert activity_count == 4, f"Expected 
4 activity completions, got {activity_count}" worker._async_worker_manager._shutdown = True await worker_task + asyncio.run(run_test()) @@ -116,6 +124,7 @@ def fn(*args, **kwargs): with lock: results.append((kind, idx)) return f"{kind}-{idx}-done" + return fn # Submit more work than concurrency allows diff --git a/tests/durabletask/test_worker_concurrency_loop_async.py b/tests/durabletask/test_worker_concurrency_loop_async.py index c7ba238..a88e3e3 100644 --- a/tests/durabletask/test_worker_concurrency_loop_async.py +++ b/tests/durabletask/test_worker_concurrency_loop_async.py @@ -8,29 +8,30 @@ def __init__(self): self.completed = [] def CompleteOrchestratorTask(self, res): - self.completed.append(('orchestrator', res)) + self.completed.append(("orchestrator", res)) def CompleteActivityTask(self, res): - self.completed.append(('activity', res)) + self.completed.append(("activity", res)) class DummyRequest: def __init__(self, kind, instance_id): self.kind = kind self.instanceId = instance_id - self.orchestrationInstance = type('O', (), {'instanceId': instance_id}) - self.name = 'dummy' + self.orchestrationInstance = type("O", (), {"instanceId": instance_id}) + self.name = "dummy" self.taskId = 1 - self.input = type('I', (), {'value': ''}) + self.input = type("I", (), {"value": ""}) self.pastEvents = [] self.newEvents = [] def HasField(self, field): - return (field == 'orchestratorRequest' and self.kind == 'orchestrator') or \ - (field == 'activityRequest' and self.kind == 'activity') + return (field == "orchestratorRequest" and self.kind == "orchestrator") or ( + field == "activityRequest" and self.kind == "activity" + ) def WhichOneof(self, _): - return f'{self.kind}Request' + return f"{self.kind}Request" class DummyCompletionToken: @@ -48,33 +49,40 @@ def test_worker_concurrency_loop_async(): async def dummy_orchestrator(req, stub, completionToken): await asyncio.sleep(0.1) - stub.CompleteOrchestratorTask('ok') + stub.CompleteOrchestratorTask("ok") async def dummy_activity(req, stub, completionToken): await asyncio.sleep(0.1) - stub.CompleteActivityTask('ok') + stub.CompleteActivityTask("ok") # Patch the worker's _execute_orchestrator and _execute_activity grpc_worker._execute_orchestrator = dummy_orchestrator grpc_worker._execute_activity = dummy_activity - orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)] - activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)] + orchestrator_requests = [DummyRequest("orchestrator", f"orch{i}") for i in range(3)] + activity_requests = [DummyRequest("activity", f"act{i}") for i in range(4)] async def run_test(): # Clear stub state before each run stub.completed.clear() worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run()) for req in orchestrator_requests: - grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) + grpc_worker._async_worker_manager.submit_orchestration( + dummy_orchestrator, req, stub, DummyCompletionToken() + ) for req in activity_requests: - grpc_worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) + grpc_worker._async_worker_manager.submit_activity( + dummy_activity, req, stub, DummyCompletionToken() + ) await asyncio.sleep(1.0) - orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator') - activity_count = sum(1 for t, _ in stub.completed if t == 'activity') - assert orchestrator_count == 3, f"Expected 3 orchestrator completions, got 
{orchestrator_count}" + orchestrator_count = sum(1 for t, _ in stub.completed if t == "orchestrator") + activity_count = sum(1 for t, _ in stub.completed if t == "activity") + assert orchestrator_count == 3, ( + f"Expected 3 orchestrator completions, got {orchestrator_count}" + ) assert activity_count == 4, f"Expected 4 activity completions, got {activity_count}" grpc_worker._async_worker_manager._shutdown = True await worker_task + asyncio.run(run_test()) asyncio.run(run_test()) diff --git a/tox.ini b/tox.ini new file mode 100644 index 0000000..c12ca9a --- /dev/null +++ b/tox.ini @@ -0,0 +1,103 @@ +[tox] +skipsdist = True +minversion = 3.10.0 +envlist = + py{39,310,311,312,313,314} + flake8, + ruff, + mypy, + +[testenv] +# you can run tox with the e2e pytest marker using tox factors: +# tox -e py39,py310,py311,py312,py313,py314 -- e2e +# or single one with: +# tox -e py310-e2e +# to use custom grpc endpoint and not capture print statements (-s arg in pytest): +# DAPR_GRPC_ENDPOINT=localhost:12345 tox -e py310-e2e -- -s +setenv = + PYTHONDONTWRITEBYTECODE=1 +deps = -rdev-requirements.txt +commands = + !e2e: pytest {posargs} -q -k "not e2e" --cov=durabletask --cov-branch --cov-report=term-missing --cov-report=xml + e2e: pytest {posargs} -q -k e2e +commands_pre = + pip3 install -e {toxinidir}/ +allowlist_externals = pip3 +pass_env = DAPR_GRPC_ENDPOINT,DAPR_HTTP_ENDPOINT,DAPR_RUNTIME_HOST,DAPR_GRPC_PORT,DAPR_HTTP_PORT + +[testenv:flake8] +basepython = python3 +usedevelop = False +deps = + flake8==7.3.0 + pip +commands = + flake8 . + +[testenv:ruff] +basepython = python3 +usedevelop = False +deps = ruff +commands = + ruff check --select I --fix + ruff format + +[testenv:examples] +passenv = HOME +basepython = python3 +changedir = ./examples/ +deps = + mechanical-markdown +commands = + ./validate.sh conversation + ./validate.sh crypto + ./validate.sh metadata + ./validate.sh error_handling + ./validate.sh pubsub-simple + ./validate.sh pubsub-streaming + ./validate.sh pubsub-streaming-async + ./validate.sh state_store + ./validate.sh state_store_query + ./validate.sh secret_store + ./validate.sh invoke-simple + ./validate.sh invoke-custom-data + ./validate.sh demo_actor + ./validate.sh invoke-binding + ./validate.sh grpc_proxying + ./validate.sh w3c-tracing + ./validate.sh distributed_lock + ./validate.sh configuration + ./validate.sh demo_workflow + ./validate.sh workflow + ./validate.sh jobs + ./validate.sh ../ +commands_pre = + pip3 install -e {toxinidir}/ +allowlist_externals=* + +[testenv:example-component] +; This environment is used to validate a specific example component. 
+; Usage: tox -e example-component -- component_name +; Example: tox -e example-component -- conversation +passenv = HOME +basepython = python3 +changedir = ./examples/ +deps = + mechanical-markdown +commands = + ./validate.sh {posargs} + +commands_pre = + pip3 install -e {toxinidir}/ +allowlist_externals=* + +[testenv:type] +basepython = python3 +usedevelop = False +deps = -rdev-requirements.txt +commands = + mypy --config-file mypy.ini +commands_pre = + pip3 install -e {toxinidir}/ +allowlist_externals=* + From a37c93b47571ac4d8af5fd4abe1fc13e7495d7cc Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Fri, 10 Oct 2025 06:10:51 -0500 Subject: [PATCH 2/3] Allow the host to be overridden on its own, defaulting the port when DAPR_GRPC_PORT is unset Co-authored-by: Albert Callarisa Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- durabletask/internal/shared.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 2f63f10..6d8609a 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -42,8 +42,8 @@ def get_default_host_address() -> str: # Host/port split overrides host = os.environ.get("DAPR_GRPC_HOST") or os.environ.get("DAPR_RUNTIME_HOST") - port = os.environ.get("DAPR_GRPC_PORT") - if host and port: + if host: + port = os.environ.get("DAPR_GRPC_PORT", "4001") return f"{host}:{port}" # Default to durabletask-go default port From 1b7d7df1ae861d5c548ce4ec69b9a703b3f8a927 Mon Sep 17 00:00:00 2001 From: Filinto Duran <1373693+filintod@users.noreply.github.com> Date: Fri, 10 Oct 2025 08:15:35 -0500 Subject: [PATCH 3/3] Migrate dev dependencies to `pyproject.toml`, remove `dev-requirements.txt`, and update related files and workflows. Drop flake8 in favor of ruff, which covers both formatting and linting. Update the README. Signed-off-by: Filinto Duran <1373693+filintod@users.noreply.github.com> --- .flake8 | 7 --- .github/workflows/pr-validation.yml | 2 +- README.md | 27 ++++++++--- dev-requirements.txt | 8 ---- durabletask/client.py | 8 ++-- examples/components/statestore.yaml | 16 +++++++ pyproject.toml | 14 +++++- requirements.txt | 7 +-- tox.ini | 74 ++--------------------------- 9 files changed, 60 insertions(+), 103 deletions(-) delete mode 100644 .flake8 delete mode 100644 dev-requirements.txt create mode 100644 examples/components/statestore.yaml diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 1b01528..0000000 --- a/.flake8 +++ /dev/null @@ -1,7 +0,0 @@ -[flake8] -ignore = E203,E501,W503,E701,E704,F821,C901 -extend-exclude = .tox,venv,.venv,build,**/.venv,**/venv,*pb2_grpc.py,*pb2.py -per-file-ignores= - examples/**:F541 setup.py:E121 - tests/**:F541,E712 -max-line-length = 100 \ No newline at end of file diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 3c3313a..7eeb732 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -28,7 +28,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r dev-requirements.txt + pip install '.[dev]' - name: Lint with flake8 run: | flake8 . --count --show-source --statistics --exit-zero diff --git a/README.md b/README.md index 4a45d9b..9357f0b 100644 --- a/README.md +++ b/README.md @@ -162,7 +162,9 @@ The following is more information about how to develop this project. 
Note that d ### Generating protobufs ```sh -pip3 install -r dev-requirements.txt +# install dev dependencies for generating protobufs and running tests +pip3 install '.[dev]' + make gen-proto ``` @@ -170,25 +172,38 @@ This will download the `orchestrator_service.proto` from the `microsoft/durablet ### Running unit tests -Unit tests can be run using the following command from the project root. Unit tests _don't_ require a sidecar process to be running. +Unit tests can be run using the following command from the project root. +Unit tests _don't_ require a sidecar process to be running. + +To run the unit tests on a specific Python version (e.g., 3.11), run the following command from the project root: ```sh -make test-unit +tox -e py311 ``` ### Running E2E tests -The E2E (end-to-end) tests require a sidecar process to be running. You can use the Dapr sidecar for this or run a Durable Task test sidecar using the following command: +The E2E (end-to-end) tests require a sidecar process to be running. + +For tests that don't use multi-app activities, you can run the Durable Task test sidecar with the following command: ```sh go install github.com/dapr/durabletask-go@main durabletask-go --port 4001 ``` -To run the E2E tests, run the following command from the project root: +Certain features, such as multi-app activities, require the full Dapr runtime to be running. + +```shell +dapr init || true + +dapr run --app-id test-app --dapr-grpc-port 4001 --components-path ./examples/components/ +``` + +To run the E2E tests on a specific Python version (e.g., 3.11), run the following command from the project root: ```sh -make test-e2e +tox -e py311-e2e ``` ## Contributing diff --git a/dev-requirements.txt b/dev-requirements.txt deleted file mode 100644 index bccde37..0000000 --- a/dev-requirements.txt +++ /dev/null @@ -1,8 +0,0 @@ -# TODO: move to pyproject optional-dependencies -pytest-asyncio>=0.23 -flake8 -tox>=4.0.0 -pytest -pytest-cov -grpcio-tools==1.75.1 -protobuf>=6.31.1 \ No newline at end of file diff --git a/durabletask/client.py b/durabletask/client.py index 1fe8688..79475ec 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -137,12 +137,14 @@ def schedule_new_orchestration( ) -> str: name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator) + input_pb = ( + wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None + ) + req = pb.CreateInstanceRequest( name=name, instanceId=instance_id if instance_id else uuid.uuid4().hex, - input=wrappers_pb2.StringValue(value=shared.to_json(input)) - if input is not None - else None, + input=input_pb, scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, version=wrappers_pb2.StringValue(value=""), orchestrationIdReusePolicy=reuse_id_policy, diff --git a/examples/components/statestore.yaml b/examples/components/statestore.yaml new file mode 100644 index 0000000..a2b567a --- /dev/null +++ b/examples/components/statestore.yaml @@ -0,0 +1,16 @@ +apiVersion: dapr.io/v1alpha1 +kind: Component +metadata: + name: statestore +spec: + type: state.redis + version: v1 + metadata: + - name: redisHost + value: localhost:6379 + - name: redisPassword + value: "" + - name: actorStateStore + value: "true" + - name: keyPrefix + value: "workflow" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f4e83d5..b36cc98 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,19 @@ markers = [ "e2e: mark a test as an end-to-end test that requires a running sidecar" ] 
+[project.optional-dependencies] +dev = [ + "pytest", + "pytest-asyncio>=0.23", + "flake8==7.3.0", + "tox>=4.0.0", + "pytest-cov", + "ruff", + + # grpc gen + "grpcio-tools==1.75.1", +] + [tool.ruff] target-version = "py39" # TODO: update to py310 when we drop support for py39 line-length = 100 @@ -59,7 +72,6 @@ select = [ "E", # pycodestyle errors "W", # pycodestyle warnings "F", # pyflakes - "I", # isort "C", # flake8-comprehensions "B", # flake8-bugbear "UP", # pyupgrade diff --git a/requirements.txt b/requirements.txt index 07426eb..7b226d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1 @@ -autopep8 -grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newer versions are backwards compatible -protobuf -pytest -pytest-cov -asyncio +# requirements in pyproject.toml \ No newline at end of file diff --git a/tox.ini b/tox.ini index c12ca9a..59a0908 100644 --- a/tox.ini +++ b/tox.ini @@ -6,6 +6,8 @@ envlist = flake8, ruff, mypy, +# TODO: switch runner to uv (tox-uv plugin) +runner = virtualenv [testenv] # you can run tox with the e2e pytest marker using tox factors: @@ -16,7 +18,7 @@ envlist = # DAPR_GRPC_ENDPOINT=localhost:12345 tox -e py310-e2e -- -s setenv = PYTHONDONTWRITEBYTECODE=1 -deps = -rdev-requirements.txt +deps = .[dev] commands = !e2e: pytest {posargs} -q -k "not e2e" --cov=durabletask --cov-branch --cov-report=term-missing --cov-report=xml e2e: pytest {posargs} -q -k e2e @@ -25,79 +27,9 @@ commands_pre = allowlist_externals = pip3 pass_env = DAPR_GRPC_ENDPOINT,DAPR_HTTP_ENDPOINT,DAPR_RUNTIME_HOST,DAPR_GRPC_PORT,DAPR_HTTP_PORT -[testenv:flake8] -basepython = python3 -usedevelop = False -deps = - flake8==7.3.0 - pip -commands = - flake8 . - [testenv:ruff] basepython = python3 usedevelop = False -deps = ruff commands = ruff check --select I --fix ruff format - -[testenv:examples] -passenv = HOME -basepython = python3 -changedir = ./examples/ -deps = - mechanical-markdown -commands = - ./validate.sh conversation - ./validate.sh crypto - ./validate.sh metadata - ./validate.sh error_handling - ./validate.sh pubsub-simple - ./validate.sh pubsub-streaming - ./validate.sh pubsub-streaming-async - ./validate.sh state_store - ./validate.sh state_store_query - ./validate.sh secret_store - ./validate.sh invoke-simple - ./validate.sh invoke-custom-data - ./validate.sh demo_actor - ./validate.sh invoke-binding - ./validate.sh grpc_proxying - ./validate.sh w3c-tracing - ./validate.sh distributed_lock - ./validate.sh configuration - ./validate.sh demo_workflow - ./validate.sh workflow - ./validate.sh jobs - ./validate.sh ../ -commands_pre = - pip3 install -e {toxinidir}/ -allowlist_externals=* - -[testenv:example-component] -; This environment is used to validate a specific example component. -; Usage: tox -e example-component -- component_name -; Example: tox -e example-component -- conversation -passenv = HOME -basepython = python3 -changedir = ./examples/ -deps = - mechanical-markdown -commands = - ./validate.sh {posargs} - -commands_pre = - pip3 install -e {toxinidir}/ -allowlist_externals=* - -[testenv:type] -basepython = python3 -usedevelop = False -deps = -rdev-requirements.txt -commands = - mypy --config-file mypy.ini -commands_pre = - pip3 install -e {toxinidir}/ -allowlist_externals=* -
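The retry tests in PATCH 1/3 assert specific timer fire-at times after each failure. The schedule they imply is delay_n = min(first_retry_interval * backoff_coefficient**(n - 1), max_retry_interval) for the n-th retry; this is an inference from the asserted timestamps, not a quote of the SDK's internals, and `next_retry_delay` below is a hypothetical helper, not part of the durabletask API:

```python
# Minimal sketch of the backoff schedule the retry tests assert.
# next_retry_delay is a hypothetical helper (not the SDK's API); the formula
# is inferred from the fire-at times checked in the tests above.
from datetime import timedelta
from typing import Optional


def next_retry_delay(
    attempt: int,  # 1-based index of the retry being scheduled
    first_retry_interval: timedelta,
    backoff_coefficient: float,
    max_retry_interval: Optional[timedelta] = None,
) -> timedelta:
    # Exponential backoff, optionally capped at max_retry_interval.
    delay = first_retry_interval * (backoff_coefficient ** (attempt - 1))
    if max_retry_interval is not None and delay > max_retry_interval:
        delay = max_retry_interval
    return delay


# Matches test_when_all_with_retry (first=2s, coeff=4, max=5s):
# first retry fires at +2s, second at +5s (2s * 4 = 8s, capped at 5s).
assert next_retry_delay(1, timedelta(seconds=2), 4, timedelta(seconds=5)) == timedelta(seconds=2)
assert next_retry_delay(2, timedelta(seconds=2), 4, timedelta(seconds=5)) == timedelta(seconds=5)

# Matches the when_any retry test (first=1s, coeff=2, max=10s): +1s, then +2s.
assert next_retry_delay(1, timedelta(seconds=1), 2, timedelta(seconds=10)) == timedelta(seconds=1)
assert next_retry_delay(2, timedelta(seconds=1), 2, timedelta(seconds=10)) == timedelta(seconds=2)
```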
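PATCH 2/3 changes `get_default_host_address` so the host can be overridden without also setting the port. A minimal sketch of the resulting resolution order follows, assuming the surrounding code falls back to `localhost:4001` (the durabletask-go default port shown in the README) and omitting any endpoint-style overrides the real function may handle before this block:

```python
# Sketch of the host/port fallback from PATCH 2/3 (resolve_host_address is a
# standalone illustration, not the SDK function itself). Before the patch,
# both DAPR_GRPC_HOST and DAPR_GRPC_PORT had to be set for the override to
# apply; after it, the host alone is enough and the port defaults to 4001.
import os


def resolve_host_address() -> str:
    host = os.environ.get("DAPR_GRPC_HOST") or os.environ.get("DAPR_RUNTIME_HOST")
    if host:
        port = os.environ.get("DAPR_GRPC_PORT", "4001")
        return f"{host}:{port}"
    # Assumed default when no override is set (durabletask-go's default port).
    return "localhost:4001"


# e.g. DAPR_GRPC_HOST=myhost            -> "myhost:4001"
#      DAPR_GRPC_HOST=myhost, DAPR_GRPC_PORT=50001 -> "myhost:50001"
```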